diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 6c31035234bb..13f310669200 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -141,7 +141,7 @@ The legacy format is essentially the same data structure but organized different - The tensors have the same shape `[batch_size, num_heads, seq_len, head_dim]`. - The format is less flexible and doesn't support features like quantization or offloading. -If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~DynamicCache.from_legacy_cache`] and [`DynamicCache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format. +If your project depends on this legacy format, we recommend converting to [`DynamicCache`] with [`~DynamicCache.from_legacy_cache`]. Note that the legacy cache format is deprecated and no longer used in `Transformers`. You can convert back to a tuple of tuples with the [`~DynamicCache.to_legacy_cache`] function, which is helpful if you have custom logic for manipulating a cache in a specific format. ```py import torch diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index b4478bcbac01..9ee097bba21f 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -28,7 +28,6 @@ if is_sklearn_available(): from sklearn.metrics import roc_curve -from ..cache_utils import Cache from ..pytorch_utils import isin_mps_friendly from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor @@ -295,9 +294,7 @@ def _update_past_and_masks( has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None if has_past_key_values: new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv - self.assistant_kwargs["past_key_values"] = _crop_past_key_values( - self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens - ) + self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens) self.assistant_kwargs = _prepare_attention_mask( self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder ) @@ -1180,47 +1177,6 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, return candidate_ids, candidate_logits -def _crop_past_key_values(model, past_key_values, max_length): - """Crops the past key values up to a certain maximum length.""" - new_past = [] - if isinstance(past_key_values, Cache): - past_key_values.crop(max_length) - elif model.config.is_encoder_decoder: - for idx in range(len(past_key_values)): - new_past.append( - ( - past_key_values[idx][0][:, :, :max_length, :], - past_key_values[idx][1][:, :, :max_length, :], - past_key_values[idx][2], - past_key_values[idx][3], - ) - ) - past_key_values = tuple(new_past) - # gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model - elif "gptbigcode" in model.__class__.__name__.lower() or ( - model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower() - ): - if model.config.multi_query: - for idx in range(len(past_key_values)): - past_key_values[idx] = past_key_values[idx][:, :max_length, :] - else: - for idx in range(len(past_key_values)): - past_key_values[idx] = past_key_values[idx][:, :, :max_length, :] - elif
past_key_values is not None: - for idx in range(len(past_key_values)): - if past_key_values[idx] != ([], []): - new_past.append( - ( - past_key_values[idx][0][:, :, :max_length, :], - past_key_values[idx][1][:, :, :max_length, :], - ) - ) - else: - new_past.append((past_key_values[idx][0], past_key_values[idx][1])) - past_key_values = tuple(new_past) - return past_key_values - - def _prepare_attention_mask(model_kwargs: dict[str, Any], new_length: int, is_encoder_decoder: bool) -> dict[str, Any]: """Expands or crops the model's mask for decoding purposes, to the defined length""" diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1daab346d94d..e9b28eb102dd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -68,7 +68,6 @@ EarlyExitCandidateGenerator, PromptLookupCandidateGenerator, UniversalSpeculativeDecodingGenerator, - _crop_past_key_values, _prepare_attention_mask, _prepare_token_type_ids, ) @@ -567,15 +566,7 @@ def prepare_inputs_for_generation( # 1. Handle BC: model_inputs = {} - # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`) - if self._supports_cache_class: - model_inputs["cache_position"] = cache_position - # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this - # function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly - # (this alternative is not as robust as calling `generate` and letting it create `cache_position`) - elif cache_position is None: - past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + model_inputs["cache_position"] = cache_position # 2. Generic cache-dependent input preparation if past_key_values is not None: @@ -1014,12 +1005,6 @@ def _update_model_kwargs_for_generation( model_kwargs["cache_position"] = torch.cat((past_positions, new_positions)) return model_kwargs - def _reorder_cache(self, past_key_values, beam_idx): - raise NotImplementedError( - f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to" - f" enable beam search for {self.__class__}" - ) - def _get_candidate_generator( self, generation_config: GenerationConfig, @@ -1559,13 +1544,6 @@ def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer): def _validate_model_kwargs(self, model_kwargs: dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" - # If a `Cache` instance is passed, checks whether the model is compatible with it - if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class: - raise ValueError( - f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please " - "check the model documentation for supported cache formats." - ) - # Excludes arguments that are handled before calling any model function if self.config.is_encoder_decoder: for key in ["decoder_input_ids"]: @@ -1975,21 +1953,23 @@ def _get_cache( self._cache.reset() return self._cache - def _supports_default_dynamic_cache(self) -> bool: + @classmethod + def _supports_default_dynamic_cache(cls) -> bool: """ Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. 
- This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which - uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in - order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed - for `HybridMambaAttentionDynamicCache`). + This adds an exception for some models, like `Jamba`, which use their own `HybridMambaAttentionDynamicCache` + and do not need to initialize the Cache in advance in order to save memory (because no back and forth + `to_legacy_cache` and `from_legacy_cache` will be performed for `HybridMambaAttentionDynamicCache`). """ - return ( - self._supports_cache_class - and "jamba" not in self.__class__.__name__.lower() - and "zamba" not in self.__class__.__name__.lower() - and "bamba" not in self.__class__.__name__.lower() - and "minimax" not in self.__class__.__name__.lower() - and "lfm2" not in self.__class__.__name__.lower() + # NOTE: remove xlnet/reformer once those models are deprecated; they use non-standard architectures/cache names + return not cls._is_stateful and all( + special_model_name not in cls.__name__.lower() + for special_model_name in [ + "reformer", + "minimax", + "xlnet", + "lfm2", + ] ) def _prepare_cache_for_generation( @@ -2076,7 +2056,7 @@ def _prepare_cache_for_generation( model_kwargs=model_kwargs, ) elif generation_config.cache_implementation == "quantized": - if not self._supports_quantized_cache: + if self.config.is_encoder_decoder or not self._supports_default_dynamic_cache(): raise ValueError( "This model does not support the quantized cache. If you want your model to support quantized " "cache, please open an issue and tag @zucchini-nlp." ) @@ -3708,33 +3688,6 @@ def _sample( else: return input_ids - # Auxiliary functions for beam search - def _temporary_reorder_cache(self, past_key_values, beam_idx): - """ - Temporary function to handle the different types of cache reordering processes while we roll out `Cache`. - - TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need - for this function, with `Cache.reorder_cache` being the sole remaining code path - """ - model_class = self.__class__.__name__.lower() - # Exception 1: code path for models using the legacy cache format - if isinstance(past_key_values, (tuple, list)): - past_key_values = self._reorder_cache(past_key_values, beam_idx) - # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their - # cache format is standardized, to avoid adding complexity to the codebase. - elif "gptbigcode" in model_class: - if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)): - raise ValueError( - f"Using an unsupported cache format with {model_class}. Currently, it only supports the " - "legacy tuple format or `DynamicCache`" - ) - past_key_values = self._reorder_cache(past_key_values, beam_idx) - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - # Standard code path: use the `Cache.reorder_cache` - else: - past_key_values.reorder_cache(beam_idx) - return past_key_values - @staticmethod def _flatten_beam_dim(tensor: torch.Tensor) -> torch.Tensor: """[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]""" @@ -4230,11 +4183,13 @@ def _beam_search( # beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`) # pluck the cache from the beam indices that will be used in the next iteration + # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc. if model_kwargs.get("past_key_values", None) is not None: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - past_key_values=model_kwargs["past_key_values"], - beam_idx=self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]), - ) + beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len]) + if hasattr(self, "_reorder_cache"): + model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) + else: + model_kwargs["past_key_values"].reorder_cache(beam_idx) cur_len = cur_len + 1 is_early_stop_heuristic_unsatisfied = self._check_early_stop_heuristic( @@ -4537,10 +4492,14 @@ def _group_beam_search( # (that way the memory peak does not include outputs.logits) del outputs + # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc. if model_kwargs.get("past_key_values", None) is not None: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - model_kwargs["past_key_values"], reordering_indices - ) + if hasattr(self, "_reorder_cache"): + model_kwargs["past_key_values"] = self._reorder_cache( + model_kwargs["past_key_values"], reordering_indices + ) + else: + model_kwargs["past_key_values"].reorder_cache(reordering_indices) # increase cur_len cur_len = cur_len + 1 @@ -4774,10 +4733,12 @@ def _constrained_beam_search( # (that way the memory peak does not include outputs.logits) del outputs + # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc. if model_kwargs.get("past_key_values", None) is not None: - model_kwargs["past_key_values"] = self._temporary_reorder_cache( - model_kwargs["past_key_values"], beam_idx - ) + if hasattr(self, "_reorder_cache"): + model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) + else: + model_kwargs["past_key_values"].reorder_cache(beam_idx) if return_dict_in_generate and output_scores: beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))) @@ -5002,8 +4963,7 @@ def _assisted_decoding( new_cur_len = input_ids.shape[1] # 4.2. Discard past key values relative to unused assistant tokens - new_cache_size = new_cur_len - 1 - outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) + outputs.past_key_values.crop(new_cur_len - 1) # 5. Update the candidate generation strategy if needed candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3202ef47c19d..f89827c2e630 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1971,13 +1971,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # Flex Attention support _supports_flex_attn = False - # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`? - _supports_cache_class = False + # Has support for `torch.compile(fullgraph=True)` _supports_static_cache = False - # Has support for a `QuantoQuantizedCache` instance as `past_key_values` - _supports_quantized_cache = False - # A tensor parallel plan to be applied to the model when TP is enabled.
For # top-level models, this attribute is currently defined in respective model # code. For base models, this attribute comes from diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 7285c8ba569a..005665d324b9 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -271,12 +271,6 @@ def __init__(self, config: AlbertConfig): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def prune_heads(self, heads: list[int]) -> None: if len(heads) == 0: return @@ -302,13 +296,17 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: bool = False, ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + key_layer = self.key(hidden_states) + value_layer = self.value(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -378,9 +376,21 @@ def forward( return super().forward(hidden_states, attention_mask, output_attentions=output_attentions) batch_size, seq_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. 
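The attention refactors above (and the analogous ones in the files that follow) replace the `transpose_for_scores` helper with inline `view`/`transpose` calls. The two forms are equivalent: on a 4D tensor, `permute(0, 2, 1, 3)` and `transpose(1, 2)` produce the same layout, and `-1` lets `view` infer the sequence dimension. A minimal sketch of the equivalence, with assumed (illustrative) tensor sizes:

```py
import torch

# Illustrative sizes, not taken from the diff: batch of 2, sequence length 5,
# 4 attention heads with a head size of 8.
batch_size, seq_length, num_attention_heads, attention_head_size = 2, 5, 4, 8
hidden_states = torch.randn(batch_size, seq_length, num_attention_heads * attention_head_size)

# Old helper: view to [batch, seq, heads, head_dim], then permute to [batch, heads, seq, head_dim].
old_layer = hidden_states.view(batch_size, seq_length, num_attention_heads, attention_head_size).permute(0, 2, 1, 3)

# New inline form: the same view (with -1 inferring the sequence dimension), then transpose(1, 2).
new_layer = hidden_states.view(batch_size, -1, num_attention_heads, attention_head_size).transpose(1, 2)

assert torch.equal(old_layer, new_layer)  # identical [batch, heads, seq, head_dim] layout
```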
diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index 8e1b1b168bfa..448ef08632e3 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -316,8 +316,7 @@ class ArceePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index a42d04717e00..99c603012556 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -631,7 +631,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = False _supports_sdpa = True - _supports_cache_class = True + _supports_attention_backend = True _can_record_outputs = { "hidden_states": AriaTextDecoderLayer, @@ -664,8 +664,6 @@ class AriaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index fa371db78c7a..aa9aea69f8ac 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1286,7 +1286,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = False _supports_sdpa = True - _supports_cache_class = True + _supports_attention_backend = True _can_record_outputs = { "hidden_states": AriaTextDecoderLayer, diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index f37a5151feac..32a2c8bad1ea 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -149,17 +149,28 @@ def __init__(self, config: ASTConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, 
self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 85f67ab6d16f..974d2a5e4d3c 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -26,6 +26,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -422,10 +423,11 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - autocorrelation_factor: int = 3, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + autocorrelation_factor: Optional[int] = 3, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -440,6 +442,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -448,69 +451,63 @@ def __init__( self.autocorrelation_factor = autocorrelation_factor - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else 
hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) # (1) period-based dependencies discovery # Resize (truncation or zero filling) @@ -631,7 +628,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped class AutoformerEncoderLayer(GradientCheckpointingLayer): @@ -673,7 +670,7 @@ def forward( returned tensors for more detail. 
""" residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -709,7 +706,7 @@ def forward( class AutoformerDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: AutoformerConfig): + def __init__(self, config: AutoformerConfig, layer_idx=None): super().__init__() self.embed_dim = config.d_model @@ -719,6 +716,7 @@ def __init__(self, config: AutoformerConfig): dropout=config.attention_dropout, is_decoder=True, autocorrelation_factor=config.autocorrelation_factor, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -731,6 +729,7 @@ def __init__(self, config: AutoformerConfig): dropout=config.attention_dropout, is_decoder=True, autocorrelation_factor=config.autocorrelation_factor, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -760,9 +759,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -788,15 +788,13 @@ def forward( residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -805,20 +803,18 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -826,9 +822,6 @@ def forward( # added layer norm here as an improvement hidden_states = self.encoder_attn_layer_norm(hidden_states) - # 
add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -849,9 +842,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -1047,7 +1037,9 @@ def __init__(self, config: AutoformerConfig): self.embed_positions = AutoformerSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model ) - self.layers = nn.ModuleList([AutoformerDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [AutoformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)] + ) self.layernorm_embedding = nn.LayerNorm(config.d_model) # https://github.com/thuml/Autoformer/blob/e6371e24f2ae2dd53e472edefdd5814c5176f864/models/Autoformer.py#L74 @@ -1071,6 +1063,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, AutoFormerDecoderOutput]: r""" Args: @@ -1149,6 +1142,22 @@ def forward( input_shape = inputs_embeds.size()[:-1] + if self.gradient_checkpointing and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -1167,7 +1176,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1187,8 +1195,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask, @@ -1196,16 +1202,14 @@ def forward( encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) (hidden_states, residual_trend) = layer_outputs[0] trend = trend + residual_trend - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1219,17 +1223,26 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, trend, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [ + hidden_states, + trend, + past_key_values, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] if v is not None ) return AutoFormerDecoderOutput( last_hidden_state=hidden_states, trend=trend, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1431,6 +1444,7 @@ def forward( output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[AutoformerModelOutput, tuple]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -1612,6 +1626,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) else: decoder_outputs = AutoFormerDecoderOutput() diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index ccd8d3a56acc..da420c82114f 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -93,10 +93,9 @@ class AyaVisionPreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn = True _supports_sdpa = True - _supports_quantized_cache = False 
_supports_static_cache = False _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py index a533ca32daa4..58c118d73fad 100644 --- a/src/transformers/models/aya_vision/modular_aya_vision.py +++ b/src/transformers/models/aya_vision/modular_aya_vision.py @@ -90,7 +90,6 @@ def pixel_shuffle(self, image_features): # B, S, D class AyaVisionPreTrainedModel(LlavaPreTrainedModel): - _supports_quantized_cache = False _supports_static_cache = False def _init_weights(self, module): diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 0113ce6b8c85..3e63239970f6 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1041,7 +1041,7 @@ class BambaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 9c61066479c9..937b41113bd1 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -812,7 +812,7 @@ class BambaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index a4201faca56c..6f01ccd0d265 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import functional as F +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...generation.logits_process import ( AlternatingCodebooksLogitsProcessor, @@ -66,7 +67,7 @@ class BarkSelfAttention(nn.Module): # adapted from GPTNeoSelfAttention and Bark code # BarkSelfAttention can have two attention type, i.e full attention or causal attention - def __init__(self, config, is_causal=False): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() # regularization @@ -90,6 +91,7 @@ def __init__(self, config, is_causal=False): self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.bias) self.is_causal = is_causal + self.layer_idx = layer_idx if is_causal: block_size = config.block_size bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size) @@ -155,6 +157,7 @@ def forward( head_mask=None, use_cache=False, output_attentions=False, + cache_position=None, ): # calculate query, key, values for all heads in batch and move head forward to be the batch dim query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2) @@ -164,15 +167,7 @@ def forward( value = self._split_heads(value, self.num_heads, self.head_dim) if past_key_values is not None: - past_key = past_key_values[0] - past_value = past_key_values[1] - key = torch.cat((past_key, key), 
dim=-2) - value = torch.cat((past_value, value), dim=-2) - - if use_cache is True: - present = (key, value) - else: - present = None + key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position}) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) @@ -180,11 +175,7 @@ def forward( attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs + return attn_output, attn_weights class BarkSelfFlashAttention2(BarkSelfAttention): @@ -229,6 +220,7 @@ def forward( head_mask=None, use_cache=False, output_attentions=False, + cache_position=None, ): batch_size, query_len, _ = hidden_states.size() @@ -240,18 +232,7 @@ def forward( value = self._split_heads(value, self.num_heads, self.head_dim) if past_key_values is not None: - # (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features) - past_key = past_key_values[0].transpose(1, 2) - past_value = past_key_values[1].transpose(1, 2) - # and merge on seq_length - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=1) - - if use_cache is True: - # (batch, head, seq_length, head_features) - present = (key.transpose(1, 2), value.transpose(1, 2)) - else: - present = None + key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position}) attn_output = _flash_attention_forward( query, @@ -268,12 +249,7 @@ def forward( attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present) - if output_attentions: - attn_weights = None - outputs += (attn_weights,) - - return outputs + return attn_output, None BARK_ATTENTION_CLASSES = { @@ -299,7 +275,7 @@ def forward(self, hidden_states): class BarkBlock(GradientCheckpointingLayer): - def __init__(self, config, is_causal=False): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if is_causal: @@ -311,7 +287,9 @@ def __init__(self, config, is_causal=False): self.layernorm_1 = nn.LayerNorm(config.hidden_size) self.layernorm_2 = nn.LayerNorm(config.hidden_size) - self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](config, is_causal=is_causal) + self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation]( + config, is_causal=is_causal, layer_idx=layer_idx + ) self.mlp = BarkMLP(config) @@ -323,6 +301,7 @@ def forward( head_mask=None, use_cache=False, output_attentions=False, + cache_position=None, ): intermediary_hidden_states = self.layernorm_1(hidden_states) @@ -333,6 +312,7 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) attn_output = attn_outputs[0] # output_attn: output, present_key_values, (attn_weights) @@ -343,12 +323,7 @@ def forward( self.layernorm_2(intermediary_hidden_states) ) - if use_cache: - outputs = (intermediary_hidden_states,) + outputs - else: - outputs = (intermediary_hidden_states,) + outputs[1:] - - return outputs # hidden_states, ((present), attentions) + return (intermediary_hidden_states,) + outputs @auto_docstring @@ -411,7 +386,7 @@ def __init__(self, config): self.drop = nn.Dropout(config.dropout) - self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)]) + self.layers = nn.ModuleList([BarkBlock(config, is_causal=True, layer_idx=i) for i in range(config.num_layers)]) 
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.layernorm_final = nn.LayerNorm(config.hidden_size, bias=config.bias) @@ -428,17 +403,17 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): self.input_embeds_layer = new_embeddings - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, cache_position=None, **kwargs): # Overwritten -- bark has a model-specific hack input_embeds = kwargs.get("input_embeds", None) attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - if past_key_values is not None: + if cache_position[0] != 0: # Omit tokens covered by past_key_values seq_len = input_ids.shape[1] - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -481,6 +456,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "use_cache": kwargs.get("use_cache"), "position_ids": position_ids, "attention_mask": attention_mask, + "cache_position": cache_position, } return { "input_ids": input_ids, @@ -488,6 +464,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "use_cache": kwargs.get("use_cache"), "position_ids": position_ids, "attention_mask": attention_mask, + "cache_position": cache_position, } @auto_docstring @@ -504,6 +481,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]: r""" input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*): @@ -546,11 +524,24 @@ def forward( device = input_ids.device if input_ids is not None else input_embeds.device - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.layers)) - else: - past_length = past_key_values[0][0].size(-2) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if position_ids is None: position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) @@ -579,37 +570,27 @@ def forward( hidden_states = self.drop(input_embeds + position_embeds) output_shape = input_shape + (hidden_states.size(-1),) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- ) - use_cache = False - - present_key_values = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - for i, (block, past_layer_key_values) in enumerate(zip(self.layers, past_key_values)): + for i, block in enumerate(self.layers): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) outputs = block( hidden_states, - past_key_values=past_layer_key_values, + past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = outputs[0] - if use_cache: - present_key_values = present_key_values + (outputs[1],) - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) hidden_states = self.layernorm_final(hidden_states) @@ -621,34 +602,22 @@ def forward( logits = self.lm_head(hidden_states) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( - v for v in [None, logits, present_key_values, all_hidden_states, all_self_attentions] if v is not None + v for v in [None, logits, past_key_values, all_hidden_states, all_self_attentions] if v is not None ) return CausalLMOutputWithPast( loss=loss, logits=logits, - past_key_values=present_key_values, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> tuple[tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - # Necessary for beam_search - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" @@ -1007,7 +976,9 @@ def __init__(self, config): self.drop = nn.Dropout(config.dropout) - self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)]) + self.layers = nn.ModuleList( + [BarkBlock(config, is_causal=False, layer_idx=i) for i in range(config.num_layers)] + ) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.layernorm_final = nn.LayerNorm(config.hidden_size) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index f2729321f059..77665f5313b1 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -268,7 +268,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class BartEncoderLayer(GradientCheckpointingLayer): @@ -310,7 +310,7 @@ def forward( returned tensors for more detail. 
""" residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -411,7 +411,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -428,7 +428,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -455,9 +455,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -496,7 +493,7 @@ class BartPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): @@ -1109,7 +1106,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1142,10 +1138,6 @@ def forward( cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1156,19 +1148,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1543,17 +1534,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -1981,15 +1961,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for 
layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "BartForCausalLM", diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 2b5fb795eaea..9c964467e513 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -260,11 +260,6 @@ def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> N if self.has_relative_position_bias: self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -274,11 +269,22 @@ def forward( interpolate_pos_encoding: bool = False, resolution: Optional[tuple[int]] = None, ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -345,10 +351,22 @@ def forward( resolution=resolution, ) - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attn_bias = None if self.has_relative_position_bias: diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 32c0406026dd..48832a1cf090 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer @@ -45,6 +46,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_bert import BertConfig @@ -189,7 +191,7 @@ def forward( class BertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -214,12 +216,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -227,53 +226,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention 
module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -315,20 +326,17 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs class BertSdpaSelfAttention(BertSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from BertSelfAttention + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -336,8 +344,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. @@ -356,38 +365,59 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask - current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
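The SDPA hunk above implements the same caching contract as the eager path: the whole cache object travels through every layer, each layer indexes it with its own `layer_idx`, and the `is_updated` flag ensures the cross-attention key/value projections run only once per generation. A minimal sketch of that contract, assuming the `EncoderDecoderCache`/`DynamicCache` API from `transformers.cache_utils` (the tensor shapes are illustrative, not taken from the diff):

```py
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# One sub-cache for decoder self-attention, one for cross-attention.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

layer_idx = 0
key = torch.randn(1, 12, 7, 64)    # [batch, num_heads, encoder_seq_len, head_dim]
value = torch.randn(1, 12, 7, 64)

# First decoding step: the layer projects the encoder states, stores them,
# and marks its cross-attention entry as filled.
cache.cross_attention_cache.update(key, value, layer_idx)
cache.is_updated[layer_idx] = True

# Every later step skips the key/value projections and reuses the stored tensors.
if cache.is_updated.get(layer_idx):
    key = cache.cross_attention_cache.key_cache[layer_idx]
    value = cache.cross_attention_cache.value_cache[layer_idx]
```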
@@ -417,10 +447,7 @@ def forward( attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None class BertSelfOutput(nn.Module): @@ -444,10 +471,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BertAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = BertSelfOutput(config) self.pruned_heads = set() @@ -470,6 +499,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -477,17 +507,19 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -524,17 +556,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BertLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BertAttention(config) + self.attention = BertAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = BertAttention(config, position_embedding_type="absolute") + self.crossattention = BertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) @@ -545,28 +577,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder 
uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -574,33 +599,23 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -610,10 +625,10 @@ def feed_forward_chunk(self, attention_output): class BertEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([BertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -628,6 +643,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -640,13 +656,21 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and self.config.is_decoder and not 
isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -654,13 +678,12 @@ def forward( layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -669,12 +692,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -683,7 +709,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -893,6 +919,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -918,8 +945,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): @@ -1004,6 +1036,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -1170,7 +1203,8 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **kwargs, + cache_position: Optional[torch.Tensor] = None, + **loss_kwargs, ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, 
sequence_length)`, *optional*): @@ -1196,6 +1230,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] @@ -1203,7 +1238,7 @@ def forward( lm_loss = None if labels is not None: - lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **kwargs) + lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs) if not return_dict: output = (prediction_scores,) + outputs[2:] @@ -1218,14 +1253,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class BertForMaskedLM(BertPreTrainedModel): diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index bd65e88ae855..9dd0f3931101 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -22,15 +22,14 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import ( - auto_docstring, - logging, -) +from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_bert_generation import BertGenerationConfig @@ -54,7 +53,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->BertGeneration class BertGenerationSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -79,12 +78,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -92,53 +88,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = 
self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -180,11 +188,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs BERT_GENERATION_SELF_ATTENTION_CLASSES = { @@ -194,10 +198,12 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BertGeneration,BERT->BERT_GENERATION class BertGenerationAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = BERT_GENERATION_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = BertGenerationSelfOutput(config) self.pruned_heads = set() @@ -220,6 +226,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -227,17 +234,19 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + 
attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -277,17 +286,19 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->BertGeneration class BertGenerationLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BertGenerationAttention(config) + self.attention = BertGenerationAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = BertGenerationAttention(config, position_embedding_type="absolute") + self.crossattention = BertGenerationAttention( + config, position_embedding_type="absolute", layer_idx=layer_idx + ) self.intermediate = BertGenerationIntermediate(config) self.output = BertGenerationOutput(config) @@ -298,28 +309,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -327,33 +331,23 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + 
output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -362,12 +356,11 @@ def feed_forward_chunk(self, attention_output): return layer_output -# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->BertGeneration class BertEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([BertGenerationLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([BertGenerationLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -382,6 +375,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -394,13 +388,21 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -408,13 +410,12 @@ def forward( layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -423,12 +424,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -437,7 +441,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -678,8 +682,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -873,14 +882,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "BertGenerationDecoder", diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 1ac24c1bc238..3058bdc94f6e 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -296,7 +297,7 @@ def forward( class BigBirdSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -314,11 +315,7 @@ def __init__(self, config): self.dropout = 
nn.Dropout(config.attention_probs_dropout_prob) self.is_decoder = config.is_decoder - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) + self.layer_idx = layer_idx def forward( self, @@ -329,43 +326,41 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + if is_cross_attention and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = past_key_value.key_cache[self.layer_idx] + value_layer = past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + key_layer, value_layer = past_key_value.update( + key_layer, + value_layer, + self.layer_idx, + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -392,11 +387,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs class BigBirdBlockSparseAttention(nn.Module): @@ -423,11 +414,6 @@ def __init__(self, config, seed=None): self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -450,9 +436,21 @@ def forward( if to_seq_length % to_block_size != 0: raise ValueError("Key/Value sided sequence length must be multiple of block size") - query_layer = self.transpose_for_scores(self.query(hidden_states)) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) context_layer, attention_probs = self.bigbird_block_sparse_attention( query_layer, @@ -478,9 +476,7 @@ def forward( ) context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - return outputs + return context_layer, attention_probs @staticmethod def torch_bmm_nd(inp_1, inp_2, ndim=None): @@ -1310,7 +1306,7 @@ def __init__(self, config, seed=None): self.seed = seed if self.config.attention_type == "original_full": - self.self = BigBirdSelfAttention(config) + self.self = BigBirdSelfAttention(config, layer_idx=seed) elif self.config.attention_type == "block_sparse": self.self = BigBirdBlockSparseAttention(config, seed) else: @@ -1320,7 +1316,7 @@ def __init__(self, config, seed=None): self.output = BigBirdSelfOutput(config) - def set_attention_type(self, value: str): + def set_attention_type(self, value: str, layer_idx=None): if value not in ["original_full", "block_sparse"]: raise ValueError( f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" @@ -1332,7 +1328,7 @@ def set_attention_type(self, value: str): self.attention_type = value if value == "original_full": # copy all weights to new full attention class - attn_weights = BigBirdSelfAttention(self.config) + attn_weights = BigBirdSelfAttention(self.config, layer_idx=layer_idx) else: # copy all weights to new sparse attention class attn_weights = BigBirdBlockSparseAttention(self.config, self.seed) @@ -1359,6 +1355,7 @@ def forward( to_mask=None, from_blocked_mask=None, to_blocked_mask=None, + cache_position=None, ): # fp16 compatibility if band_mask is not None: @@ -1370,12 +1367,13 @@ def forward( if self.attention_type == "original_full": self_outputs = self.self( hidden_states, - 
attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) else: if encoder_hidden_states is not None: @@ -1433,11 +1431,11 @@ def __init__(self, config, seed=None): if self.add_cross_attention: if not self.is_decoder: raise TypeError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = BigBirdAttention(config) + self.crossattention = BigBirdAttention(config, seed=seed) self.intermediate = BigBirdIntermediate(config) self.output = BigBirdOutput(config) - def set_attention_type(self, value: str): + def set_attention_type(self, value: str, layer_idx=None): if value not in ["original_full", "block_sparse"]: raise ValueError( f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" @@ -1446,10 +1444,10 @@ def set_attention_type(self, value: str): if value == self.attention_type: return self.attention_type = value - self.attention.set_attention_type(value) + self.attention.set_attention_type(value, layer_idx=layer_idx) if self.add_cross_attention: - self.crossattention.set_attention_type(value) + self.crossattention.set_attention_type(value, layer_idx=layer_idx) def forward( self, @@ -1464,33 +1462,27 @@ def forward( blocked_encoder_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, band_mask=band_mask, from_mask=from_mask, to_mask=to_mask, from_blocked_mask=blocked_encoder_mask, to_blocked_mask=blocked_encoder_mask, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -1498,35 +1490,23 @@ def forward( " cross-attention layers by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + 
cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) - outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs + return (layer_output,) + outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) @@ -1554,8 +1534,8 @@ def set_attention_type(self, value: str): if value == self.attention_type: return self.attention_type = value - for layer in self.layer: - layer.set_attention_type(value) + for i, layer in enumerate(self.layer): + layer.set_attention_type(value, layer_idx=i) def forward( self, @@ -1573,6 +1553,7 @@ def forward( to_mask=None, blocked_encoder_mask=None, return_dict=True, + cache_position=None, ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -1585,14 +1566,21 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -1604,13 +1592,12 @@ def forward( from_mask, to_mask, blocked_encoder_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -1619,12 +1606,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -1633,7 +1623,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -1868,6 +1858,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, # NOOP kwargs, for now ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, tuple[torch.FloatTensor]]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1894,8 +1885,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -2007,6 +2003,7 @@ def forward( to_mask=to_mask, blocked_encoder_mask=blocked_encoder_mask, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] @@ -2396,6 +2393,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[CausalLMOutputWithCrossAttentions, tuple[torch.FloatTensor]]: r""" @@ -2420,6 +2418,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, **kwargs, ) @@ -2448,15 +2447,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - class 
BigBirdClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 1e126fbcaff8..2466400b82b3 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -108,7 +108,7 @@ def forward(self, input_ids: torch.Tensor): # Copied from transformers.models.big_bird.modeling_big_bird.BigBirdSelfAttention with BigBird->BigBirdPegasus class BigBirdPegasusSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -126,11 +126,7 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.is_decoder = config.is_decoder - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) + self.layer_idx = layer_idx def forward( self, @@ -141,43 +137,41 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + if is_cross_attention and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = past_key_value.key_cache[self.layer_idx] + value_layer = past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + key_layer, value_layer = past_key_value.update( + key_layer, + value_layer, + self.layer_idx, + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
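The `past_key_value.update(...)` call introduced above replaces the manual `torch.cat` bookkeeping from the legacy path. A minimal sketch of those semantics with a standalone `DynamicCache` (toy shapes, not BigBirdPegasus code):

```py
import torch
from transformers import DynamicCache

cache = DynamicCache()
batch, heads, head_dim = 1, 4, 8

# prefill with a 3-token prompt, then append one generated token
k0, v0 = torch.randn(batch, heads, 3, head_dim), torch.randn(batch, heads, 3, head_dim)
cache.update(k0, v0, layer_idx=0)

k1, v1 = torch.randn(batch, heads, 1, head_dim), torch.randn(batch, heads, 1, head_dim)
# update() appends the new states and returns the full, concatenated tensors
key_layer, value_layer = cache.update(k1, v1, layer_idx=0)

print(key_layer.shape)          # torch.Size([1, 4, 4, 8])
print(cache.get_seq_length(0))  # 4
```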
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -204,11 +198,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.big_bird.modeling_big_bird.BigBirdBlockSparseAttention with BigBird->BigBirdPegasus @@ -236,11 +226,6 @@ def __init__(self, config, seed=None): self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -263,9 +248,21 @@ def forward( if to_seq_length % to_block_size != 0: raise ValueError("Key/Value sided sequence length must be multiple of block size") - query_layer = self.transpose_for_scores(self.query(hidden_states)) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) context_layer, attention_probs = self.bigbird_block_sparse_attention( query_layer, @@ -291,9 +288,7 @@ def forward( ) context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - return outputs + return context_layer, attention_probs @staticmethod def torch_bmm_nd(inp_1, inp_2, ndim=None): @@ -1331,7 +1326,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class BigBirdPegasusEncoderLayer(GradientCheckpointingLayer): @@ -1492,7 +1487,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -1509,7 +1504,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -1534,9 +1529,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1573,7 +1565,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"] _skip_keys_device_placement = "past_key_values" 
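Across these hunks the attention modules now return plain `(attn_output, attn_weights)` pairs: the cache object is mutated in place, so it no longer needs to be threaded back through every return tuple. A hypothetical toy module (`ToySelfAttention` is invented purely for illustration) showing that calling convention:

```py
import torch
from transformers import DynamicCache

class ToySelfAttention(torch.nn.Module):
    def __init__(self, layer_idx=0):
        super().__init__()
        self.layer_idx = layer_idx

    def forward(self, hidden_states, past_key_value=None):
        key = value = hidden_states.unsqueeze(1)  # fake single-head K/V
        if past_key_value is not None:
            # mutates the shared cache in place
            key, value = past_key_value.update(key, value, self.layer_idx)
        attn_weights = torch.softmax(key.sum(-1), dim=-1)
        return hidden_states, attn_weights  # no past_key_value in the tuple

cache = DynamicCache()
layer = ToySelfAttention()
out, weights = layer(torch.randn(1, 3, 8), past_key_value=cache)
print(cache.get_seq_length(0))  # 3 -- the cache was updated in place
```

This is why the decoder-layer hunks can drop the `if use_cache: outputs += (past_key_value,)` lines: callers simply read the cache object they passed in.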
_supports_param_buffer_assignment = False - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): @@ -2265,7 +2257,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -2298,9 +2289,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -2313,19 +2301,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -2656,17 +2643,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -3064,15 +3040,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "BigBirdPegasusForCausalLM", diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 390b57a4179d..543c6cba5c42 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -245,7 +245,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class BioGptDecoderLayer(GradientCheckpointingLayer): @@ -307,7 +307,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -335,9 +335,6 @@ def forward( if output_attentions: outputs 
+= (self_attn_weights,) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -349,7 +346,7 @@ class BioGptPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): @@ -635,7 +632,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = None - next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) @@ -660,9 +656,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -672,19 +665,18 @@ def forward( hidden_states = self.layer_norm(hidden_states) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -779,15 +771,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class BioGptForTokenClassification(BioGptPreTrainedModel): diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 24d1c77fb688..0994ff64693e 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -132,7 +132,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -160,9 +160,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -174,7 +171,7 @@ class BioGptPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): @@ -460,7 +457,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = None - next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) @@ -485,9 +481,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if 
output_attentions: all_self_attns += (layer_outputs[1],) @@ -497,19 +490,18 @@ def forward( hidden_states = self.layer_norm(hidden_states) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -604,15 +596,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class BioGptForTokenClassification(BioGptPreTrainedModel): diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index c2e84df25685..a9fb4a30f0c5 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -311,8 +311,7 @@ class BitNetPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 2af2d994489d..a94a31a04b2b 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -267,7 +267,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT @@ -310,7 +310,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -331,12 +331,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT @@ -410,7 +405,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -427,7 +422,7 @@ def forward( residual 
= hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -452,9 +447,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -466,7 +458,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): @@ -1063,7 +1055,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1096,9 +1087,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1112,19 +1100,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1480,17 +1467,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Blenderbot class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel): @@ -1631,15 +1607,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "BlenderbotForCausalLM", diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index f81718d68cb1..c6abb963009f 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ 
b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -251,7 +251,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL @@ -294,7 +294,7 @@ def forward( returned tensors for more detail. """ residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -396,7 +396,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -413,7 +413,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -440,9 +440,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -454,7 +451,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): @@ -1046,7 +1043,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1079,9 +1075,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1092,19 +1085,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1434,17 +1426,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - 
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->BlenderbotSmall class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel): @@ -1585,15 +1566,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "BlenderbotSmallForCausalLM", diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 2dac6b3493fc..821bd783c67a 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -97,7 +98,7 @@ def forward( # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L97 class BlipTextSelfAttention(nn.Module): - def __init__(self, config, is_cross_attention): + def __init__(self, config, is_cross_attention, layer_idx=None): super().__init__() self.config = config if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): @@ -109,6 +110,7 @@ def __init__(self, config, is_cross_attention): self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.layer_idx = layer_idx self.query = nn.Linear(config.hidden_size, self.all_head_size) if is_cross_attention: @@ -136,11 +138,6 @@ def save_attention_map(self, attention_map): def get_attention_map(self): return self.attention_map - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -148,32 +145,60 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
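The reworked BLIP text attention below branches on `EncoderDecoderCache`, whose layout this minimal sketch illustrates (toy tensors; in the model, the `update()` call does the storing):

```py
import torch
from transformers import DynamicCache, EncoderDecoderCache

# One DynamicCache for decoder self-attention, one for cross-attention,
# plus an is_updated flag so encoder-derived key/values are projected
# only once and then reused on every decoding step.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

k = torch.randn(1, 4, 6, 8)  # [batch, heads, encoder_seq, head_dim]
v = torch.randn(1, 4, 6, 8)

if not cache.is_updated.get(0):  # first step: store encoder K/V for layer 0
    cache.cross_attention_cache.update(k, v, layer_idx=0)
    cache.is_updated[0] = True

# later steps reuse the stored tensors instead of re-projecting
key_layer = cache.cross_attention_cache.key_cache[0]
print(key_layer.shape)  # torch.Size([1, 4, 6, 8])
```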
is_cross_attention = encoder_hidden_states is not None + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - if is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) - past_key_value = (key_layer, value_layer) + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. 
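`cache_position` is new to these signatures. A short sketch of how it is typically constructed (the numbers are made up), and why the hunk above nulls it for cross-attention, where the cached encoder key/values are not indexed by decoding position:

```py
import torch

# absolute indices of the tokens being processed this step,
# offset by whatever is already in the cache
past_seen_tokens = 7  # e.g. past_key_values.get_seq_length()
new_tokens = 1        # one token per decoding step
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
print(cache_position)  # tensor([7])
```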
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -216,10 +241,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert -> BlipText @@ -239,9 +261,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#242 class BlipTextAttention(nn.Module): - def __init__(self, config, is_cross_attention=False): + def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() - self.self = BlipTextSelfAttention(config, is_cross_attention) + self.self = BlipTextSelfAttention(config, is_cross_attention, layer_idx=layer_idx) self.output = BlipTextSelfOutput(config) self.pruned_heads = set() @@ -269,18 +291,18 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -324,10 +346,12 @@ def __init__(self, config, layer_num): self.config = config self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BlipTextAttention(config) + self.attention = BlipTextAttention(config, layer_idx=layer_num) self.layer_num = layer_num if self.config.is_decoder: - self.crossattention = BlipTextAttention(config, is_cross_attention=self.config.is_decoder) + self.crossattention = BlipTextAttention( + config, is_cross_attention=self.config.is_decoder, layer_idx=layer_num + ) self.intermediate = BlipTextIntermediate(config) self.output = BlipTextOutput(config) @@ -338,42 +362,37 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - 
past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] - - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] + outputs = self_attention_outputs[1:] if encoder_hidden_states is not None: cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) - outputs = (layer_output,) + outputs - - outputs = outputs + (present_key_value,) - - return outputs + return (layer_output,) + outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) @@ -401,6 +420,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: if self.gradient_checkpointing and self.training: if use_cache: @@ -408,11 +428,25 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False + + return_legacy_cache = False + if use_cache: + if not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + # The model acts as encoder decoder but is not an encoder decoder. 
So we cast all cache objects to + # `EncoderDecoderCache` type assuming that the incoming cache is from `self_attention` + elif isinstance(past_key_values, DynamicCache): + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.is_decoder else None - - next_decoder_cache = () if use_cache else None + all_cross_attentions = () if output_attentions and encoder_hidden_states is not None else None for i in range(self.config.num_hidden_layers): layer_module = self.layer[i] @@ -420,7 +454,6 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -428,26 +461,29 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + if encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -456,7 +492,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -669,6 +705,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, is_decoder: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor`, *optional*): @@ -717,8 +754,13 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length + past_key_values_length)).to(device) @@ -776,6 +818,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -833,6 +876,7 @@ def forward( return_logits: Optional[bool] = False, is_decoder: Optional[bool] = True, reduction: Optional[str] = "mean", + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], 
CausalLMOutputWithCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor`, *optional*): Sequence of @@ -874,6 +918,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, + cache_position=cache_position, ) sequence_output = outputs[0] @@ -908,40 +953,15 @@ def forward( def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): # Overwrite -- hardcoded key return (`is_decoder=True`) - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + **model_kwargs, + ) + model_inputs["is_decoder"] = True - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), - "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), - "is_decoder": True, - } - - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past + return model_inputs __all__ = ["BlipTextModel", "BlipTextLMHeadModel", "BlipTextPreTrainedModel"] diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 21c533dd4f13..8235767b7e06 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1830,10 +1830,8 @@ def forward( class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): config_class = Blip2Config main_input_name = "pixel_values" - _supports_cache_class = True - _supports_static_cache = True - _supports_quantized_cache = False # not all LM bacbones support (e.g. 
T5) + _supports_static_cache = True _keep_in_fp32_modules = ["query_tokens", "qformer"] _supports_flash_attn = False # because self.qformer does not support FA2 diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 0360bdf6f724..242061eb4e40 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -433,9 +433,8 @@ class BloomPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["BloomBlock"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_static_cache = True - _supports_quantized_cache = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -506,7 +505,7 @@ def forward( ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -846,7 +845,7 @@ def forward( ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -915,29 +914,6 @@ def forward( attentions=transformer_outputs.attentions, ) - def _reorder_cache( - self, past: tuple[tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - # Get a copy of `beam_idx` on all the devices where we need those indices. - device_to_beam_idx = { - past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past - } - reordered_past = tuple( - ( - layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), - layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), - ) - for layer_past in past - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -980,7 +956,7 @@ def forward( ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). 
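The Bloom docstring updates above change how callers measure the cached length. A minimal sketch mirroring the dual-format branch these models keep (`past_length` is a hypothetical helper name for illustration):

```py
import torch
from transformers import Cache, DynamicCache

def past_length(past_key_values):
    # Cache objects expose get_seq_length(); legacy tuple-of-tuples
    # require indexing into the first layer's key tensor
    if past_key_values is None:
        return 0
    if isinstance(past_key_values, Cache):
        return past_key_values.get_seq_length()
    return past_key_values[0][0].shape[-2]  # legacy format

cache = DynamicCache()
cache.update(torch.randn(1, 4, 7, 8), torch.randn(1, 4, 7, 8), layer_idx=0)
print(past_length(cache))  # 7
```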
Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1116,7 +1092,7 @@ def forward( ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1207,7 +1183,7 @@ def forward( ) -> Union[tuple, QuestionAnsweringModelOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 1635c7b3d454..47f68fe23c0a 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN, QuickGELUActivation +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -36,6 +37,7 @@ from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int +from ...utils.deprecation import deprecate_kwarg from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig @@ -401,7 +403,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower class BridgeTowerSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -426,12 +428,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -439,53 +438,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: 
Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + if is_cross_attention and encoder_attention_mask is not None: attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] + else: + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -527,11 +538,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs BRIDGE_TOWER_SELF_ATTENTION_CLASSES = { @@ -541,10 +548,12 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower,BERT->BRIDGE_TOWER class BridgeTowerAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = BRIDGE_TOWER_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = BridgeTowerSelfOutput(config) self.pruned_heads = set() @@ -567,6 +576,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -574,17 +584,19 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: 
Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -592,14 +604,14 @@ def forward( class BridgeTowerBertCrossLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BridgeTowerAttention(config) + self.attention = BridgeTowerAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention - self.crossattention = BridgeTowerAttention(config) + self.crossattention = BridgeTowerAttention(config, layer_idx=layer_idx) self.intermediate = BridgeTowerIntermediate(config) self.output = BridgeTowerOutput(config) @@ -612,6 +624,7 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attention_outputs = self.attention( @@ -629,16 +642,16 @@ def forward( cross_attention_outputs = self.crossattention( attention_output, - attention_mask=attention_mask, + attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] # add cross attentions if we output attention weights - outputs = outputs + cross_attention_outputs[1:-1] + outputs = outputs + cross_attention_outputs[1:] layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output @@ -654,17 +667,17 @@ def feed_forward_chunk(self, attention_output): class BridgeTowerTextLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BridgeTowerAttention(config) + self.attention = BridgeTowerAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = BridgeTowerAttention(config, position_embedding_type="absolute") + self.crossattention = BridgeTowerAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = BridgeTowerIntermediate(config) self.output = BridgeTowerOutput(config) @@ -675,28 +688,27 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: 
Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -704,34 +716,22 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) - outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs + return (layer_output,) + outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) @@ -741,10 +741,12 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->BridgeTowerText class BridgeTowerTextEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([BridgeTowerTextLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList( + [BridgeTowerTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) self.gradient_checkpointing = False def forward( @@ -759,6 +761,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states 
else None all_self_attentions = () if output_attentions else None @@ -771,13 +774,21 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -785,13 +796,12 @@ def forward( layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -800,12 +810,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -814,7 +827,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -1041,6 +1054,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1066,8 +1080,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -1120,6 +1139,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -1172,10 +1192,10 @@ def __init__(self, config): 
ln.bias.data = self.vision_model.visual.ln_post.bias.data self.cross_modal_image_layers = nn.ModuleList( - [BridgeTowerBertCrossLayer(text_config) for _ in range(config.num_hidden_layers)] + [BridgeTowerBertCrossLayer(text_config, layer_idx=i) for i in range(config.num_hidden_layers)] ) self.cross_modal_text_layers = nn.ModuleList( - [BridgeTowerBertCrossLayer(text_config) for _ in range(config.num_hidden_layers)] + [BridgeTowerBertCrossLayer(text_config, layer_idx=i) for i in range(config.num_hidden_layers)] ) # Class token => Linear => Tanh diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 3dff8f3b2cf0..dcb0d243d031 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer @@ -41,6 +42,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_camembert import CamembertConfig @@ -139,7 +141,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Camembert class CamembertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -164,12 +166,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -177,53 +176,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -265,21 +276,18 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->Camembert class CamembertSdpaSelfAttention(CamembertSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from CamembertSelfAttention + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -287,8 +295,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. @@ -307,38 +316,59 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. 
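The branch logic above (repeated in every eager and SDPA attention this diff touches) follows one contract: pick the self-attention or cross-attention sub-cache out of an `EncoderDecoderCache`, reuse the cross-attention projections once `is_updated` is set, and otherwise project and call `Cache.update`. A minimal standalone sketch of that contract, illustrative only (the `project_and_cache` helper and all shapes are invented, not part of this diff):

```py
import torch

from transformers.cache_utils import DynamicCache, EncoderDecoderCache

batch, heads, head_dim, layer_idx = 1, 2, 8, 0
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

def project_and_cache(step_len, is_cross_attention, cache_position):
    # stand-ins for the self.key(...)/self.value(...) projections
    key = torch.randn(batch, heads, step_len, head_dim)
    value = torch.randn(batch, heads, step_len, head_dim)
    curr = cache.cross_attention_cache if is_cross_attention else cache.self_attention_cache
    if is_cross_attention and cache.is_updated.get(layer_idx):
        # encoder states are fixed during decoding, so reuse the stored projections
        return curr.key_cache[layer_idx], curr.value_cache[layer_idx]
    key, value = curr.update(
        key, value, layer_idx, {"cache_position": None if is_cross_attention else cache_position}
    )
    if is_cross_attention:
        cache.is_updated[layer_idx] = True
    return key, value

for step in range(2):
    # cross-attention K/V are written once on step 0 and reused afterwards;
    # self-attention K/V grow by one position per decoding step
    project_and_cache(5, is_cross_attention=True, cache_position=None)
    k, v = project_and_cache(1, is_cross_attention=False, cache_position=torch.tensor([step]))

assert k.shape[-2] == 2
assert cache.cross_attention_cache.key_cache[layer_idx].shape[-2] == 5
```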
@@ -368,10 +398,7 @@ def forward( attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->Camembert @@ -397,10 +424,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->Camembert,ROBERTA->CAMEMBERT class CamembertAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = CAMEMBERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = CamembertSelfOutput(config) self.pruned_heads = set() @@ -423,6 +452,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -430,17 +460,19 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -480,17 +512,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->Camembert class CamembertLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = CamembertAttention(config) + self.attention = CamembertAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = CamembertAttention(config, position_embedding_type="absolute") + self.crossattention = CamembertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = CamembertIntermediate(config) self.output = CamembertOutput(config) @@ -501,28 +533,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: 
Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -530,33 +555,23 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -567,10 +582,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->Camembert class CamembertEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([CamembertLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([CamembertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -585,6 +600,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, 
return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -597,13 +613,21 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -611,13 +635,12 @@ def forward( layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -626,12 +649,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -640,7 +666,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -815,6 +841,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -840,8 +867,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): @@ -926,6 +958,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = 
encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -1541,14 +1574,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index e19fe2534e1c..9866aad87a4e 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -410,11 +410,6 @@ def __init__(self, config): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, from_tensor: torch.Tensor, @@ -423,16 +418,27 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - mixed_query_layer = self.query(from_tensor) + batch_size, seq_length, _ = from_tensor.shape # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. - key_layer = self.transpose_for_scores(self.key(to_tensor)) - value_layer = self.transpose_for_scores(self.value(to_tensor)) - - query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = ( + self.key(to_tensor) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(to_tensor) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(from_tensor) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 0100f191d278..fe4899c7e932 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -377,10 +377,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON @@ -430,7 +427,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -453,9 +450,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -504,7 +498,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -526,9 +520,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -823,8 +814,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn = True _supports_sdpa = True - _supports_quantized_cache = True - _supports_cache_class = True + _supports_static_cache = True _supports_param_buffer_assignment = False _supports_flex_attn = True @@ -1009,7 +999,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -1028,9 +1017,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1040,16 +1026,14 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = next_decoder_cache - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 6ab3ade7c25d..afe7bdb06a3b 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -926,11 +926,8 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # 
past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + attention_mask = torch.ones(((batch_size, seq_length)), device=device) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index cd4af11d873d..16a079c3b83c 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN, get_activation +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationConfig, GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -268,7 +269,7 @@ class ClvpSelfAttention(nn.Module): Multi-headed attention to combine Absolute and Rotary Positional Embeddings into a single Attention module. """ - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -281,6 +282,7 @@ def __init__(self, config): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.layer_idx = layer_idx if hasattr(config, "max_position_embeddings"): max_positions = config.max_position_embeddings @@ -302,10 +304,11 @@ def forward( rotary_pos_emb: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, use_cache: Optional[bool] = False, head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[tuple[torch.FloatTensor]]]: # Raise error when position_ids is None but rotary_pos_emb is provided, because we need that when applying # rotary_pos_emb to query and key states. 
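The hunk below swaps the manual `torch.cat` of past and current key/value states for `Cache.update`, which appends in place and hands back the full sequence. A quick equivalence sketch, illustrative only (shapes invented, not part of this diff):

```py
import torch

from transformers.cache_utils import DynamicCache

past_key = torch.randn(1, 2, 5, 8)    # [batch, heads, past_len, head_dim]
past_value = torch.randn(1, 2, 5, 8)
key_states = torch.randn(1, 2, 1, 8)  # one freshly projected decoding step
value_states = torch.randn(1, 2, 1, 8)

# legacy pattern this diff removes
legacy_key = torch.cat((past_key, key_states), dim=-2)
legacy_value = torch.cat((past_value, value_states), dim=-2)

# Cache-based pattern this diff adds
cache = DynamicCache()
cache.update(past_key, past_value, layer_idx=0)
key, value = cache.update(key_states, value_states, layer_idx=0)

assert torch.equal(key, legacy_key) and torch.equal(value, legacy_value)
```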
@@ -320,14 +323,9 @@ def forward( value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if past_key_value is not None: - past_key, past_value = past_key_value - key_states = torch.cat((past_key, key_states), dim=-2) - value_states = torch.cat((past_value, value_states), dim=-2) - - if use_cache is True: - present = (key_states, value_states) - else: - present = None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) if rotary_pos_emb is not None: rotary_emb_dim = rotary_pos_emb.shape[-1] @@ -385,10 +383,7 @@ def forward( attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, present, attn_weights + return attn_output, attn_weights class ClvpGatedLinearUnit(nn.Module): @@ -464,7 +459,7 @@ def forward( hidden_states = self.input_rmsnorm(hidden_states) - attention_outputs = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=attention_mask, @@ -472,8 +467,6 @@ def forward( output_attentions=output_attentions, ) - hidden_states = attention_outputs[0] - hidden_states = residual + hidden_states residual = hidden_states @@ -481,12 +474,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (attention_outputs[-1],) - - return outputs + return hidden_states, attn_weights # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Clvp @@ -608,13 +596,13 @@ def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.Fl class ClvpDecoderLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = ClvpSelfAttention(config) + self.attn = ClvpSelfAttention(config, layer_idx=layer_idx) self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = ClvpDecoderMLP(inner_dim, config) @@ -622,12 +610,13 @@ def __init__(self, config): def forward( self, hidden_states: Optional[tuple[torch.FloatTensor]], - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -639,9 +628,9 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) attn_output = attn_outputs[0] - outputs = attn_outputs[1:] # residual connection hidden_states = attn_output + residual @@ -651,12 +640,7 @@ def forward( # residual connection hidden_states = residual + feed_forward_hidden_states - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs + return (hidden_states,) + attn_outputs[1:] class ClvpConditioningEncoder(nn.Module): @@ -1007,7 +991,9 @@ 
def __init__(self, config): self.position_embeds_layer = nn.Embedding(self.config.max_position_embeddings, self.config.hidden_size) self.drop = nn.Dropout(self.config.embd_pdrop) - self.layers = nn.ModuleList([ClvpDecoderLayer(self.config) for _ in range(self.config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [ClvpDecoderLayer(self.config, layer_idx=i) for i in range(self.config.num_hidden_layers)] + ) self.layer_norm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_epsilon) self.gradient_checkpointing = False @@ -1042,6 +1028,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1068,11 +1055,24 @@ def forward( if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if past_key_values is None: - past_key_values_length = 0 - past_key_values = tuple([None] * len(self.layers)) - else: - past_key_values_length = past_key_values[0][0].size(-2) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if position_ids is None: position_ids = torch.arange( past_key_values_length, input_shape[-1] + past_key_values_length, dtype=torch.long, device=device @@ -1104,18 +1104,10 @@ def forward( output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - for i, (block, past_key_value) in enumerate(zip(self.layers, past_key_values)): + for i, block in enumerate(self.layers): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1127,26 +1119,26 @@ def forward( attention_mask, position_ids, head_mask[i], + cache_position, ) else: outputs = block( hidden_states, - past_key_value=past_key_value, + past_key_value=past_key_values, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + all_cross_attentions = all_cross_attentions + (outputs[2],) hidden_states = self.layer_norm(hidden_states) @@ -1156,16 +1148,19 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -1205,6 +1200,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1226,6 +1222,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1332,56 +1329,28 @@ def _prepare_model_inputs( return inputs, input_name, model_kwargs def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, conditioning_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + conditioning_embeds=None, + cache_position=None, + **kwargs, ): # Overwritten: has `conditioning_embeds`-related logic input_ids_length = input_ids.shape[-1] - token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, 
remove_prefix_length:] - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -input_ids.shape[1] :] - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None - - if conditioning_embeds is not None and past_key_values is not None: - position_ids = torch.tensor([input_ids_length], dtype=torch.long, device=input_ids.device) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "token_type_ids": token_type_ids, - } + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, ) + if conditioning_embeds is not None and cache_position[0] != 0: + model_inputs["position_ids"] = torch.tensor([input_ids_length], dtype=torch.long, device=input_ids.device) + return model_inputs @auto_docstring @@ -1399,6 +1368,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1426,6 +1396,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1456,20 +1427,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> tuple[tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. 
- """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" @@ -1708,6 +1665,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = False, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, ClvpOutput]: r""" conditioning_encoder_inputs_embeds (`torch.FloatTensor`, *optional*): @@ -1762,6 +1720,7 @@ def forward( inputs_embeds=conditioning_embeds, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) speech_ids = decoder_outputs[0] diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index f39364180171..acc7877ede74 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -286,8 +286,7 @@ class CodeGenPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["CodeGenBlock"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True def __init__(self, *inputs, **kwargs): @@ -676,19 +675,5 @@ def forward( attentions=transformer_outputs.attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> tuple[tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or - [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. 
- """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - __all__ = ["CodeGenForCausalLM", "CodeGenModel", "CodeGenPreTrainedModel"] diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 6e43f257a387..9a601591372b 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -344,8 +344,7 @@ class CoherePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 82ac1b2b61ae..10e65a180278 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -321,8 +321,7 @@ class Cohere2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py b/src/transformers/models/colqwen2/modeling_colqwen2.py index b0703f665e9b..7ef0acf080be 100644 --- a/src/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/transformers/models/colqwen2/modeling_colqwen2.py @@ -43,7 +43,6 @@ class ColQwen2PreTrainedModel(PreTrainedModel): _no_split_modules = [] _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True def _init_weights(self, module): std = ( diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py index f63e865a7142..3d2a57f1c350 100644 --- a/src/transformers/models/colqwen2/modular_colqwen2.py +++ b/src/transformers/models/colqwen2/modular_colqwen2.py @@ -228,7 +228,6 @@ def __call__( class ColQwen2PreTrainedModel(ColPaliPreTrainedModel): _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True @dataclass diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index bdac1fecc1c3..a43ebfa0259d 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -325,11 +325,6 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -338,8 +333,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - batch_size = hidden_states.size(0) + batch_size, seq_length, _ = hidden_states.shape # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
@@ -353,9 +347,16 @@ def forward( mixed_key_conv_attn_layer = self.key_conv_attn_layer(hidden_states.transpose(1, 2)) mixed_key_conv_attn_layer = mixed_key_conv_attn_layer.transpose(1, 2) - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + mixed_query_layer = self.query(hidden_states) + query_layer = mixed_query_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + key_layer = mixed_key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = mixed_value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) conv_attn_layer = torch.multiply(mixed_key_conv_attn_layer, mixed_query_layer) conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer) diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index f3ecf4930f6e..d68521c6dd52 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel @@ -60,11 +61,12 @@ def forward(self, hidden_states: torch.Tensor): class CpmAntAttention(nn.Module): - def __init__(self, config: CpmAntConfig): + def __init__(self, config: CpmAntConfig, layer_idx=None): super().__init__() self.dim_model = config.hidden_size self.num_heads = config.num_attention_heads self.dim_head = config.dim_head + self.layer_idx = layer_idx self.project_q = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False) self.project_k = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False) @@ -86,8 +88,9 @@ def forward( attention_mask: torch.BoolTensor, position_bias: torch.Tensor, output_attentions: Optional[bool] = False, - past_key_values: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ): """ Args: @@ -120,8 +123,7 @@ def forward( value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3) if past_key_values is not None: - key = torch.cat([past_key_values[0], key], dim=-2) - value = torch.cat([past_key_values[1], value], dim=-2) + key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position}) len_k = key.size(-2) # (batch_size, num_heads, len_q, dim_head) @ (batch_size, num_heads, dim_head, len_k) -> (batch_size, num_heads, len_q, len_k) @@ -156,18 +158,14 @@ def forward( score = self.attention_out(score) - past_key_values = None - if use_cache: - past_key_values = (key, value) - - return score, attn_weights, past_key_values + return score, attn_weights class CpmAntSelfAttentionBlock(nn.Module): - def __init__(self, config: CpmAntConfig): + def __init__(self, config: CpmAntConfig, layer_idx=None): super().__init__() self.layernorm_before_attention = CpmAntLayerNorm(config) - self.self_attention = CpmAntAttention(config) + self.self_attention = CpmAntAttention(config, layer_idx=layer_idx) if config.dropout_p: self.dropout = torch.nn.Dropout(config.dropout_p) else: @@ 
-179,8 +177,9 @@ def forward( attention_mask: torch.Tensor, position_bias: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, - past_key_values: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ): """ Args: @@ -199,17 +198,22 @@ def forward( (see `past_key_values`). """ outputs = self.layernorm_before_attention(hidden_states) - outputs = self.self_attention( - outputs, outputs, attention_mask, position_bias, output_attentions, past_key_values, use_cache + outputs, attn_weights = self.self_attention( + outputs, + outputs, + attention_mask, + position_bias, + output_attentions, + past_key_values, + use_cache, + cache_position, ) - outputs, attn_weights, current_key_value = outputs - if self.dropout is not None: outputs = self.dropout(outputs) hidden_states = hidden_states + outputs - return hidden_states, attn_weights, current_key_value + return hidden_states, attn_weights class CpmAntDenseGatedACT(nn.Module): @@ -286,9 +290,9 @@ def forward( class CpmAntTransformerBlock(nn.Module): - def __init__(self, config: CpmAntConfig): + def __init__(self, config: CpmAntConfig, layer_idx=None): super().__init__() - self.self_att = CpmAntSelfAttentionBlock(config) + self.self_att = CpmAntSelfAttentionBlock(config, layer_idx=layer_idx) self.ffn = CpmAntFFNBlock(config) def forward( @@ -297,8 +301,9 @@ def forward( attention_mask: torch.Tensor, position_bias: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, - past_key_values: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ): """ Args: @@ -316,27 +321,25 @@ def forward( If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
""" - hidden_states = self.self_att( + hidden_states, attn_weights = self.self_att( hidden_states, attention_mask=attention_mask, position_bias=position_bias, output_attentions=output_attentions, past_key_values=past_key_values, use_cache=use_cache, + cache_position=cache_position, ) - hidden_states, attn_weights, current_key_value = hidden_states - hidden_states = self.ffn(hidden_states) - - return hidden_states, attn_weights, current_key_value + return hidden_states, attn_weights class CpmAntEncoder(nn.Module): def __init__(self, config: CpmAntConfig): super().__init__() self.num_layers = config.num_hidden_layers - self.layers = nn.ModuleList([CpmAntTransformerBlock(config) for ith in range(self.num_layers)]) + self.layers = nn.ModuleList([CpmAntTransformerBlock(config, layer_idx=i) for i in range(self.num_layers)]) self.output_layernorm = CpmAntLayerNorm(config) @@ -347,8 +350,9 @@ def forward( position_bias: torch.Tensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - past_key_values: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = None, + cache_postion: Optional[torch.Tensor] = None, ): """ Args: @@ -370,7 +374,6 @@ def forward( """ all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - current_key_values = () if use_cache else None for i, layer in enumerate(self.layers): if output_hidden_states: @@ -380,21 +383,19 @@ def forward( attention_mask, position_bias, output_attentions=output_attentions, - past_key_values=past_key_values[i] if past_key_values else None, + past_key_values=past_key_values, use_cache=use_cache, ) - hidden_states, attn_weights, current_key_value = layer_outputs + hidden_states, attn_weights = layer_outputs if output_attentions: all_self_attns += (attn_weights,) - if current_key_value is not None: - current_key_values = current_key_values + (current_key_value,) hidden_states = self.output_layernorm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) - return hidden_states, current_key_values, all_hidden_states, all_self_attns + return hidden_states, all_hidden_states, all_self_attns # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt @@ -592,6 +593,7 @@ def forward( past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]: r""" @@ -634,17 +636,24 @@ def forward( position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1) span = torch.full((batch, seq_length), 0, dtype=dtype, device=device) - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * self.encoder.num_layers) - input_ids = input_ids.contiguous() - hidden_states = self.input_embedding(input_ids) - segment_states = self.segment_embedding(segment) - hidden_states = hidden_states + segment_states - else: - past_length = past_key_values[0][0].size(-2) - segment_states = self.segment_embedding(segment) - hidden_states = self.input_embedding(input_ids) + segment_states[:, -1:, :] + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. 
" + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + input_ids = input_ids.contiguous() + hidden_states = self.input_embedding(input_ids) + segment_states = self.segment_embedding(segment) + if past_length != 0: + segment_states = segment_states[:, -1:, :] + + hidden_states = hidden_states + segment_states attention_mask = self._prepare_attention_mask(input_ids, span, context, length) position_bias = self.position_bias(position, position, segment, segment) @@ -653,7 +662,7 @@ def forward( position_bias = position_bias[:, :, past_length:, :] hidden_states = hidden_states[:, past_length:, :] - hidden_states, present_key_values, all_hidden_states, all_attentions = self.encoder( + hidden_states, all_hidden_states, all_attentions = self.encoder( hidden_states, attention_mask, position_bias, @@ -661,6 +670,7 @@ def forward( output_hidden_states, past_key_values, use_cache, + cache_position, ) if past_length == 0: @@ -677,14 +687,17 @@ def forward( new_hidden_states += (hidden_state[:, self.prompt_length :, :],) all_hidden_states = new_hidden_states + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( - v for v in [hidden_states, present_key_values, all_hidden_states, all_attentions] if v is not None + v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=present_key_values, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, ) @@ -719,6 +732,7 @@ def forward( labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, attention_mask: Optional[torch.Tensor] = None, # dummy parameter for text-generation pipeline + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[tuple, CausalLMOutputWithPast]: r""" @@ -751,7 +765,13 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict model_output = self.cpmant( - input_ids, output_attentions, output_hidden_states, past_key_values, use_cache, return_dict + input_ids, + output_attentions, + output_hidden_states, + past_key_values, + use_cache, + return_dict, + cache_position, ) hidden_states = model_output.last_hidden_state if return_dict else model_output[0] @@ -786,12 +806,5 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - def _reorder_cache(self, past_key_values, beam_idx): - past_key_values = [list(each) if each is not None else each for each in past_key_values] - for key_value_layer in past_key_values: - key_value_layer[0] = key_value_layer[0][beam_idx] - key_value_layer[1] = key_value_layer[1][beam_idx] - return past_key_values - __all__ = ["CpmAntForCausalLM", "CpmAntModel", "CpmAntPreTrainedModel"] diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index f0f7eecbad74..d7807065e55f 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -370,8 +370,7 @@ class CsmPreTrainedModel(PreTrainedModel): _supports_sdpa = True # does not because of Mimi codec model # _supports_flex_attn = True - 
_supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index 266327b13e5f..ffde5c82eb54 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -133,8 +133,7 @@ class CsmPreTrainedModel(PreTrainedModel): _supports_sdpa = True # does not because of Mimi codec model # _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 883803741d7e..ba1a737efc5d 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel @@ -83,10 +84,11 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N class MultiHeadAttention(nn.Module): - def __init__(self, d_model_size, num_heads): + def __init__(self, d_model_size, num_heads, layer_idx=None): super().__init__() self.num_heads = num_heads self.d_model_size = d_model_size + self.layer_idx = layer_idx self.depth = int(d_model_size / self.num_heads) @@ -129,6 +131,7 @@ def forward( head_mask=None, use_cache=False, output_attentions=False, + cache_position=None, ): batch_size = q.shape[0] @@ -139,26 +142,16 @@ def forward( q = self.split_into_heads(q, batch_size) k = self.split_into_heads(k, batch_size) v = self.split_into_heads(v, batch_size) - if layer_past is not None: - past_key, past_value = layer_past[0], layer_past[1] - k = torch.cat((past_key, k), dim=-2) - v = torch.cat((past_value, v), dim=-2) - if use_cache is True: - present = torch.stack((k, v)) - else: - present = (None,) + if layer_past is not None: + k, v = layer_past.update(k, v, self.layer_idx, {"cache_position": cache_position}) output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask) scaled_attention = output[0].permute([0, 2, 1, 3]) attn = output[1] original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size) output = self.dense(original_size_attention) - - outputs = (output, present) - if output_attentions: - outputs = outputs + (attn,) - return outputs + return output, attn def point_wise_feed_forward_network(d_model_size, dff): @@ -166,10 +159,10 @@ def point_wise_feed_forward_network(d_model_size, dff): class EncoderLayer(nn.Module): - def __init__(self, d_model_size, num_heads, dff, rate=0.1): + def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_idx=None): super().__init__() - self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads) + self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, layer_idx=layer_idx) self.ffn = point_wise_feed_forward_network(d_model_size, dff) self.layernorm1 = nn.LayerNorm(d_model_size, eps=1e-6) @@ -179,7 +172,15 @@ def __init__(self, d_model_size, num_heads, dff, rate=0.1): self.dropout2 = nn.Dropout(rate) def forward( - self, x, mask, layer_past=None, 
attention_mask=None, head_mask=None, use_cache=False, output_attentions=False + self, + x, + mask, + layer_past=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + cache_position=None, ): normed = self.layernorm1(x) attn_outputs = self.multi_head_attention( @@ -192,6 +193,7 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) attn_output = attn_outputs[0] attn_output = self.dropout1(attn_output) @@ -242,7 +244,10 @@ def __init__(self, config): self.dropout = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList( - [EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop) for _ in range(config.n_layer)] + [ + EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop, layer_idx=i) + for i in range(config.n_layer) + ] ) self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) @@ -276,6 +281,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, # NOOP kwargs, for now ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]: r""" @@ -332,11 +338,17 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if position_ids is None: position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) @@ -387,38 +399,40 @@ def forward( hidden_states = self.dropout(hidden_states) - presents = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, h in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) outputs = h( hidden_states, mask, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=attention_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present = outputs[:2] - if use_cache is True: - presents = presents + (present,) - + hidden_states = outputs[0] if output_attentions: - all_attentions += (outputs[2],) + all_attentions += (outputs[1],) hidden_states = self.layernorm(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, ) @@ -462,6 +476,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]: r""" @@ -520,6 +535,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] @@ -552,7 +568,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cac # only last tokens for inputs_ids if past is defined in kwargs if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -565,20 +581,6 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cac return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} - @staticmethod - def _reorder_cache( - past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> tuple[tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. 
- """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 94dedbfb38db..419bcb7b68b4 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -246,7 +246,6 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index f447ff6258de..93d217eed128 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -38,6 +39,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_data2vec_text import Data2VecTextConfig @@ -139,7 +141,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Data2VecText class Data2VecTextSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -164,12 +166,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -177,53 +176,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated 
as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -265,11 +276,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput @@ -294,10 +301,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText,BERT->DATA2VEC_TEXT class Data2VecTextAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = DATA2VEC_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = Data2VecTextSelfOutput(config) self.pruned_heads = set() @@ -320,6 +329,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -327,17 +337,19 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - 
encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -377,17 +389,19 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText class Data2VecTextLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = Data2VecTextAttention(config) + self.attention = Data2VecTextAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = Data2VecTextAttention(config, position_embedding_type="absolute") + self.crossattention = Data2VecTextAttention( + config, position_embedding_type="absolute", layer_idx=layer_idx + ) self.intermediate = Data2VecTextIntermediate(config) self.output = Data2VecTextOutput(config) @@ -398,28 +412,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -427,33 +434,23 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + 
encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -464,10 +461,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Data2VecText class Data2VecTextEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([Data2VecTextLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([Data2VecTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -482,6 +479,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -494,13 +492,21 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -508,13 +514,12 @@ def forward( layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -523,12 +528,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -537,7 +545,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -649,6 +657,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -674,8 +683,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -728,6 +742,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -788,6 +803,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: r""" @@ -830,6 +846,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] @@ -857,14 +874,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - 
reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel): diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index ede5404571ab..2cf64ac21f81 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -261,11 +261,6 @@ def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = if self.has_relative_position_bias: self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -275,11 +270,22 @@ def forward( interpolate_pos_encoding: bool = False, resolution: Optional[tuple[int]] = None, ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -347,10 +353,22 @@ def forward( resolution=resolution, ) - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attn_bias = None if self.has_relative_position_bias: diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 3bef3e329379..0c5bc0dd9c3e 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -811,8 +811,7 @@ class DbrxPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module: nn.Module): diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index c2c0a2a64545..689fff29e25c 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -452,7 +452,7 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True - _supports_cache_class = True + _supports_static_cache = False def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index 836b9dc9de2a..ff36b6d43f43 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -458,8 +458,7 @@ class DeepseekV2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 6c00b64eef79..708a370171bd 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -497,8 +497,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 
52ca6bdac9cf..2966d465f66d 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -214,17 +214,28 @@ def __init__(self, config: DeiTConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py b/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py index f57754cc585a..cf41dc2e29e8 100755 --- a/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py +++ b/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py @@ -539,7 +539,7 @@ def forward( past_key_values_length = 0 if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + past_key_values_length = past_key_values.get_seq_length() # Adapted from paddlenlp.transformers.ernie_m.ErnieMModel if attention_mask is None: diff --git a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py index 6974c0024702..25c56354c3fa 100644 --- a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py @@ -919,7 +919,7 @@ def forward( num_batch = input_ids.shape[0] pasts_or_spout_value = None if past_key_values is not None: - num_pasts_contexts = past_key_values[0][0].shape[2] + num_pasts_contexts = past_key_values.get_seq_length() elif self.config.d_spout and spout is not None: # `spout` is a special input vector specific to GPTSAN # This controls the output by projecting embedded information such as the class of sentences during learning. 
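The deprecated-model hunks above and the ones that follow all make the same substitution: the cached sequence length is read through the `Cache` API instead of from the shape of the first layer's key tensor. A minimal sketch of the two idioms, assuming a model that accepts a `DynamicCache` (`gpt2` is only an example checkpoint, not part of this diff):

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello there", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, past_key_values=DynamicCache(), use_cache=True)

cache = outputs.past_key_values
# legacy tuple-of-tuples idiom: past_length = cache[0][0].shape[2]
# Cache API idiom used throughout this diff:
past_length = cache.get_seq_length()
print(past_length == inputs["input_ids"].shape[1])  # True
```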
diff --git a/src/transformers/models/deprecated/nezha/modeling_nezha.py b/src/transformers/models/deprecated/nezha/modeling_nezha.py index 2ef4a560952e..2bead71cadd4 100644 --- a/src/transformers/models/deprecated/nezha/modeling_nezha.py +++ b/src/transformers/models/deprecated/nezha/modeling_nezha.py @@ -925,7 +925,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 848f3f971e05..2d473cd423df 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -592,7 +592,7 @@ def forward( use_cache = False if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + past_key_values_length = past_key_values.get_seq_length() seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: @@ -794,7 +794,7 @@ def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py index df3fce3b5205..914428b96a23 100755 --- a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py @@ -908,7 +908,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -1122,7 +1122,7 @@ def prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/deprecated/realm/modeling_realm.py b/src/transformers/models/deprecated/realm/modeling_realm.py index e88a75bd1bf2..68787c60e903 100644 --- a/src/transformers/models/deprecated/realm/modeling_realm.py +++ b/src/transformers/models/deprecated/realm/modeling_realm.py @@ -1047,7 +1047,7 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if attention_mask is None: attention_mask = 
torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) diff --git a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py index 0599c3b592f1..ed4c96c89bfb 100755 --- a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py @@ -559,7 +559,7 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale @@ -886,7 +886,7 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py index 57f2c2610e82..19d7988a691e 100644 --- a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -597,7 +597,7 @@ def forward(self, inputs_shape, device, attention_mask=None, past_key_values=Non if past_key_values is not None: # position_ids is the same for every token when decoding a single step # Without the int() cast, it doesn't work in some cases when exporting to ONNX - prev_num_input_ids = past_key_values[0][0].shape[2] + prev_num_input_ids = past_key_values.get_seq_length() num_input_ids = inputs_shape[1] + prev_num_input_ids position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * ( int(self.padding_idx + num_input_ids) diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index f801a7f60372..da0f616eda77 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -67,7 +67,6 @@ class DiaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True _supports_static_cache = True main_input_name = "input_ids" _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index 5dfa78ce3644..7da15d7c10b9 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -62,7 +62,6 @@ class DiaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True _supports_static_cache = True main_input_name = "input_ids" _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 26de9466431d..e2b093fd8ec0 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -533,8 +533,7 @@ class 
DiffLlamaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = False - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = False _can_record_outputs = { diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 0618e0d2ded8..140d16bd33b9 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -270,19 +270,27 @@ def __init__(self, config, dim, num_heads, kernel_size, dilation): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 3, 1, 2, 4) - def forward( self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - query_layer = self.transpose_for_scores(self.query(hidden_states)) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Apply the scale factor before computing attention weights. It's usually more efficient because # attention weights are typically a bigger tensor compared to query. 
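The `transpose_for_scores` removals in the vision models (deit, dinat, dinov2, dpt, and the files below) inline the same reshape. For the common 4-D layout these helpers implement, the two spellings are interchangeable, because `permute(0, 2, 1, 3)` on a `[batch, seq, heads, head_dim]` view is exactly `transpose(1, 2)`. A quick self-contained check with arbitrary demo shapes:

```py
import torch

batch_size, seq_length, num_heads, head_size = 2, 5, 4, 8
hidden_states = torch.randn(batch_size, seq_length, num_heads * head_size)

# old helper: view to [batch, seq, heads, head_size], then permute(0, 2, 1, 3)
old = hidden_states.view(batch_size, seq_length, num_heads, head_size).permute(0, 2, 1, 3)
# new inlined form: view with -1 on the sequence dim, then transpose(1, 2)
new = hidden_states.view(batch_size, -1, num_heads, head_size).transpose(1, 2)

assert torch.equal(old, new)  # both are [batch, heads, seq, head_size]
```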
diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index e266727bf9ef..102b15a5fbdd 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -202,17 +202,28 @@ def __init__(self, config: Dinov2Config) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index 69236ab67feb..0c09e2f75d18 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -223,17 +223,28 @@ def __init__(self, config: Dinov2WithRegistersConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, 
self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index 969150c7c7d4..29aa4b196102 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -494,8 +494,6 @@ class DogePreTrainedModel(PreTrainedModel): _supports_flash_attn = False _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True _supports_static_cache = False _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 4ac1420e2d30..665aa1b85db0 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -417,8 +417,7 @@ class Dots1PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 8614a4de6e11..fd9fd489938b 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -326,17 +326,28 @@ def __init__(self, config: DPTConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 613dea9473b1..c22b47c55ad4 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin 
from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -39,11 +40,8 @@ ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import ( - ModelOutput, - auto_docstring, - logging, -) +from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_electra import ElectraConfig @@ -200,7 +198,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Electra class ElectraSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -225,12 +223,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -238,53 +233,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -326,11 +333,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput @@ -355,10 +358,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra,BERT->ELECTRA class ElectraAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = ELECTRA_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = ElectraSelfOutput(config) self.pruned_heads = set() @@ -381,6 +386,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -388,17 +394,19 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -438,17 +446,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra class ElectraLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ElectraAttention(config) + self.attention = ElectraAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if 
cross attention is added") - self.crossattention = ElectraAttention(config, position_embedding_type="absolute") + self.crossattention = ElectraAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = ElectraIntermediate(config) self.output = ElectraOutput(config) @@ -459,28 +467,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -488,33 +489,23 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -525,10 +516,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Electra class ElectraEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = 
config - self.layer = nn.ModuleList([ElectraLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([ElectraLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -543,6 +534,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -555,13 +547,21 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -569,13 +569,12 @@ def forward( layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -584,12 +583,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -598,7 +600,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -750,8 +752,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(input_shape, device=device) @@ -1574,15 +1581,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache - def _reorder_cache(self, past_key_values, 
beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "ElectraForCausalLM", diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index cdf5eee993aa..6633abc49461 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1097,8 +1097,7 @@ class Emu3PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn = True _supports_sdpa = True - _supports_quantized_cache = True - _supports_cache_class = True + _supports_static_cache = True _supports_param_buffer_assignment = False _supports_flex_attn = True diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 48ec65a5e21e..38c7e2197bee 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -607,9 +607,5 @@ def resize_token_embeddings(self, *args, **kwargs): " model.decoder.resize_token_embeddings(...))" ) - def _reorder_cache(self, past_key_values, beam_idx): - # apply decoder cache reordering here - return self.decoder._reorder_cache(past_key_values, beam_idx) - __all__ = ["EncoderDecoderModel"] diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 79898516126d..a5d55bafd754 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -41,6 +42,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_ernie import ErnieConfig @@ -125,7 +127,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Ernie class ErnieSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -150,12 +152,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -163,53 +162,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, 
encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -251,11 +262,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Ernie @@ -280,10 +287,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Ernie,BERT->ERNIE class ErnieAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = ERNIE_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = ErnieSelfOutput(config) self.pruned_heads = set() @@ -306,6 +315,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -313,17 +323,19 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - 
encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -363,17 +375,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Ernie class ErnieLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = ErnieAttention(config) + self.attention = ErnieAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = ErnieAttention(config, position_embedding_type="absolute") + self.crossattention = ErnieAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = ErnieIntermediate(config) self.output = ErnieOutput(config) @@ -384,28 +396,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -413,33 +418,22 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + 
output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -450,10 +444,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Ernie class ErnieEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([ErnieLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([ErnieLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -468,6 +462,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -480,13 +475,21 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -494,13 +497,12 @@ def forward( layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -509,12 +511,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -523,7 +528,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -767,8 +772,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -1061,15 +1071,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class ErnieForMaskedLM(ErniePreTrainedModel): diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index eb028ec898ba..c9388e588139 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -255,7 +255,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): class EsmSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.config = config @@ -285,6 +285,7 @@ def __init__(self, config, position_embedding_type=None): self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx @deprecate_kwarg("past_key_value", version="4.54.0") def forward( @@ -390,8 +391,8 
@@ class EsmFlashAttention2(EsmSelfAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. @@ -504,9 +505,9 @@ def forward( class EsmAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() - self.self = ESM_ATTENTION_CLASSES[config._attn_implementation](config) + self.self = ESM_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) self.output = EsmSelfOutput(config) self.pruned_heads = set() self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -539,6 +540,7 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): hidden_states_ln = self.LayerNorm(hidden_states) self_outputs = self.self( @@ -604,6 +606,7 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): self_attention_outputs = self.attention( hidden_states, @@ -676,6 +679,7 @@ def forward( output_attentions=False, output_hidden_states=False, return_dict=True, + cache_position=None, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 9e31eb1c9020..8e03e28c0d11 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -642,8 +642,7 @@ class FalconPreTrainedModel(PreTrainedModel): _no_split_modules = ["FalconDecoderLayer"] _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True def __init__(self, *inputs, **kwargs): @@ -727,7 +726,7 @@ def forward( ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1029,7 +1028,7 @@ def forward( ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). 
Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1086,30 +1085,6 @@ def forward( attentions=transformer_outputs.attentions, ) - def _reorder_cache( - self, past: tuple[tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - - # Get a copy of `beam_idx` on all the devices where we need those indices. - device_to_beam_idx = { - past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past - } - reordered_past = tuple( - ( - layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), - layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), - ) - for layer_past in past - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -1151,7 +1126,7 @@ def forward( ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1277,7 +1252,7 @@ def forward( ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1356,7 +1331,7 @@ def forward( ) -> Union[tuple, QuestionAnsweringModelOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 82173de99fae..a4ab2fe8d16d 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -1152,7 +1152,6 @@ class FalconH1PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True # Note: only supports FalconHybridMambaAttentionDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 89c6abc411d7..bd4f63375be3 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -930,7 +930,6 @@ class FalconH1PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True # Note: only supports FalconHybridMambaAttentionDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/flaubert/modeling_flaubert.py b/src/transformers/models/flaubert/modeling_flaubert.py index f9a549d20559..7699d3f31fc6 100644 --- a/src/transformers/models/flaubert/modeling_flaubert.py +++ b/src/transformers/models/flaubert/modeling_flaubert.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu, get_activation +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, @@ -87,6 +88,7 @@ def __init__(self, n_heads, dim, config): self.layer_id = next(MultiHeadAttention.NEW_ID) self.dim = dim self.n_heads = n_heads + self.head_dim = dim // n_heads self.dropout = config.attention_dropout assert self.dim % self.n_heads == 0 @@ -111,50 +113,57 @@ def prune_heads(self, heads): self.dim = attention_head_size * self.n_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, input, mask, kv=None, cache=None, head_mask=None, output_attentions=False): + def forward( + self, + input, + mask, + kv=None, + cache=None, + head_mask=None, + output_attentions=False, + cache_position=None, + ): """ Self-attention (if kv is None) or attention over source sentence (provided by kv). 
""" # Input is (bs, qlen, dim) # Mask is (bs, klen) (non-causal) or (bs, klen, klen) bs, qlen, dim = input.size() - if kv is None: - klen = qlen if cache is None else cache["slen"] + qlen - else: - klen = kv.size(1) - # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' - n_heads = self.n_heads - dim_per_head = self.dim // n_heads - mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen) - - def shape(x): - """projection""" - return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) - - def unshape(x): - """compute context""" - return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) - - q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head) - if kv is None: - k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head) - elif cache is None or self.layer_id not in cache: - k = v = kv - k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head) + is_cross_attention = kv is not None + mask_reshape = (bs, 1, qlen, -1) if mask.dim() == 3 else (bs, 1, 1, -1) + q = self.q_lin(input).view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) if cache is not None: - if self.layer_id in cache: - if kv is None: - k_, v_ = cache[self.layer_id] - k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head) - v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head) + if isinstance(cache, EncoderDecoderCache): + is_updated = cache.is_updated.get(self.layer_id) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = cache.cross_attention_cache else: - k, v = cache[self.layer_id] - cache[self.layer_id] = (k, v) + curr_past_key_value = cache.self_attention_cache + else: + curr_past_key_value = cache - q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head) + current_states = kv if is_cross_attention else input + if is_cross_attention and cache is not None and is_updated: + # reuse k,v, cross_attentions + k = curr_past_key_value.key_cache[self.layer_id] + v = curr_past_key_value.value_cache[self.layer_id] + else: + k = self.k_lin(current_states) + v = self.v_lin(current_states) + k = k.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) + v = v.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) + + if cache is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + k, v = curr_past_key_value.update(k, v, self.layer_id, {"cache_position": cache_position}) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + cache.is_updated[self.layer_id] = True + + q = q / math.sqrt(self.head_dim) # (bs, n_heads, qlen, head_dim) scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen) mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen) scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen) @@ -166,8 +175,8 @@ def unshape(x): if head_mask is not None: weights = weights * head_mask - context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) - context = unshape(context) # (bs, qlen, dim) + context = torch.matmul(weights, v) # (bs, n_heads, qlen, head_dim) + context = context.transpose(1, 
2).contiguous().view(bs, -1, self.n_heads * self.head_dim) outputs = (self.out_lin(context),) if output_attentions: @@ -814,6 +823,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutput]: r""" langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -848,6 +858,9 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device + if not isinstance(cache, Cache): + cache = EncoderDecoderCache.from_legacy_cache(cache) + if lengths is None: if input_ids is not None: lengths = (input_ids != self.pad_index).sum(dim=1).long() @@ -893,7 +906,7 @@ def forward( # do not recompute cached elements if cache is not None and input_ids is not None: - _slen = slen - cache["slen"] + _slen = slen - cache.get_seq_length() input_ids = input_ids[:, -_slen:] position_ids = position_ids[:, -_slen:] if langs is not None: @@ -935,6 +948,7 @@ def forward( cache=cache, head_mask=head_mask[i], output_attentions=output_attentions, + cache_position=cache_position, ) attn = attn_outputs[0] if output_attentions: @@ -951,13 +965,6 @@ def forward( attn = nn.functional.dropout(attn, p=self.dropout, training=self.training) tensor = tensor + attn - # encoder attention (for decoder only) - # if self.is_decoder and src_enc is not None: - # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache) - # attn = nn.functional.dropout(attn, p=self.dropout, training=self.training) - # tensor = tensor + attn - # tensor = self.layer_norm15[i](tensor) - # FFN if not self.pre_norm: tensor = tensor + self.ffns[i](tensor) @@ -972,13 +979,6 @@ def forward( if output_hidden_states: hidden_states = hidden_states + (tensor,) - # update cache length - if cache is not None: - cache["slen"] += tensor.size(1) - - # move back sequence length to dimension 0 - # tensor = tensor.transpose(0, 1) - if not return_dict: return tuple(v for v in [tensor, hidden_states, attentions] if v is not None) diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index a1a30f369a10..64a61e66b520 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -442,11 +442,6 @@ def __init__(self, config: FlavaPossibleConfigs) -> None: self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -454,11 +449,22 @@ def forward( head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + 
self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 9e3000ad1338..52f3d027c80c 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -35,6 +35,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import ( @@ -452,7 +453,7 @@ def forward( class DecoderLayer(nn.Module): - def __init__(self, config: FSMTConfig): + def __init__(self, config: FSMTConfig, layer_idx=None): super().__init__() self.embed_dim = config.d_model @@ -460,6 +461,7 @@ def __init__(self, config: FSMTConfig): embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -471,6 +473,7 @@ def __init__(self, config: FSMTConfig): config.decoder_attention_heads, dropout=config.attention_dropout, encoder_decoder_attention=True, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -488,12 +491,10 @@ def forward( cross_attn_layer_head_mask=None, decoder_padding_mask=None, output_attentions=False, + cache_position=None, ): residual = x - if layer_state is None: - layer_state = {} - # Self Attention x, self_attn_weights = self.self_attn( query=x, @@ -503,6 +504,7 @@ def forward( attn_mask=causal_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) x = nn.functional.dropout(x, p=self.dropout, training=self.training) x = residual + x @@ -518,6 +520,7 @@ def forward( layer_state=layer_state, # mutates layer state layer_head_mask=cross_attn_layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) x = nn.functional.dropout(x, p=self.dropout, training=self.training) x = residual + x @@ -534,9 +537,8 @@ def forward( return ( x, self_attn_weights, - layer_state, cross_attn_weights, - ) # layer_state = cache for decoding + ) class FSMTDecoder(nn.Module): @@ -559,7 +561,7 @@ def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding): self.embed_positions = SinusoidalPositionalEmbedding( config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx ) - self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.decoder_layers)]) # type: list[DecoderLayer] + self.layers = nn.ModuleList([DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) # type: list[DecoderLayer] if is_deepspeed_zero3_enabled(): import deepspeed @@ -585,10 +587,11 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[list[torch.FloatTensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + output_hidden_states: 
Optional[bool] = False, + return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ): """ Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al., @@ -645,6 +648,17 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + x += positions x = nn.functional.dropout(x, p=self.dropout, training=self.training) @@ -656,7 +670,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attns = () if output_attentions else None - next_decoder_cache = [] # check if head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -676,23 +689,19 @@ def forward( if dropout_probability < self.layerdrop: continue - layer_state = past_key_values[idx] if past_key_values is not None else None - - x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer( + x, layer_self_attn, layer_cross_attn = decoder_layer( x, encoder_hidden_states, encoder_attn_mask=encoder_padding_mask, decoder_padding_mask=decoder_padding_mask, - layer_state=layer_state, + layer_state=past_key_values, causal_mask=decoder_causal_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), output_attentions=output_attentions, + cache_position=cache_position, ) - if use_cache: - next_decoder_cache.append(layer_past.copy()) - if output_attentions: all_self_attns += (layer_self_attn,) all_cross_attns += (layer_cross_attn,) @@ -709,15 +718,16 @@ def forward( x = self.output_projection(x) - next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( - v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attns] if v is not None + v for v in [x, past_key_values, all_hidden_states, all_self_attns, all_cross_attns] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=x, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attns, @@ -741,6 +751,7 @@ def __init__( dropout=0.0, bias=True, encoder_decoder_attention=False, # otherwise self_attention + layer_idx=None, ): super().__init__() self.embed_dim = embed_dim @@ -749,6 +760,7 @@ def __init__( self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" self.scaling = self.head_dim**-0.5 + self.layer_idx = layer_idx self.encoder_decoder_attention = encoder_decoder_attention self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -757,64 +769,65 @@ def __init__( self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.cache_key 
= "encoder_decoder" if self.encoder_decoder_attention else "self" - def _shape(self, tensor, seq_len, bsz): - return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) - def forward( self, query, key: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, - layer_state: Optional[dict[str, Optional[Tensor]]] = None, + layer_state: Optional[Cache] = None, attn_mask: Optional[Tensor] = None, layer_head_mask: Optional[Tensor] = None, - output_attentions=False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[Tensor, Optional[Tensor]]: """Input shape: Time(SeqLen) x Batch x Channel""" - static_kv: bool = self.encoder_decoder_attention tgt_len, bsz, embed_dim = query.size() assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] - # get here for encoder decoder cause of static_kv - if layer_state is not None: # reuse k,v and encoder_padding_mask - saved_state = layer_state.get(self.cache_key, {}) - if "prev_key" in saved_state and static_kv: - # previous time steps are cached - no need to recompute key and value if they are static - key = None - else: - saved_state = None - layer_state = {} - q = self.q_proj(query) * self.scaling - if static_kv: - if key is None: - k = v = None + if layer_state is not None: + if isinstance(layer_state, EncoderDecoderCache): + is_updated = layer_state.is_updated.get(self.layer_idx) + if self.encoder_decoder_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = layer_state.cross_attention_cache + else: + curr_past_key_value = layer_state.self_attention_cache else: - k = self.k_proj(key) - v = self.v_proj(key) + curr_past_key_value = layer_state + + # NOTE: FSMT has format (seq_len, BS, model_dim) ofr inputs + current_states = key if self.encoder_decoder_attention else query + if self.encoder_decoder_attention and layer_state is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - k = self.k_proj(query) - v = self.v_proj(query) - - q = self._shape(q, tgt_len, bsz) - if k is not None: - k = self._shape(k, -1, bsz) - if v is not None: - v = self._shape(v, -1, bsz) - - if saved_state is not None: - k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz) - - # Update cache - layer_state[self.cache_key] = { - "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim), - "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim), - "prev_key_padding_mask": key_padding_mask if not static_kv else None, - } + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3) + value_states = value_states.view(-1, bsz, self.num_heads, self.head_dim).permute(1, 2, 0, 3) + + if layer_state is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not self.encoder_decoder_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if self.encoder_decoder_attention: + layer_state.is_updated[self.layer_idx] = 
True - assert k is not None - src_len = k.size(1) - attn_weights = torch.bmm(q, k.transpose(1, 2)) + query_states = self.q_proj(query) * self.scaling + + # Reshape back to 3D tensors for `bmm` + query_states = query_states.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + key_states = key_states.reshape(bsz * self.num_heads, -1, self.head_dim) + value_states = value_states.reshape(bsz * self.num_heads, -1, self.head_dim) + + assert key_states is not None + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len) if attn_mask is not None: @@ -857,45 +870,14 @@ def forward( training=self.training, ) - assert v is not None - attn_output = torch.bmm(attn_probs, v) + assert value_states is not None + attn_output = torch.bmm(attn_probs, value_states) assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights_reshaped - def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz): - # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) - if "prev_key" in saved_state: - _prev_key = saved_state["prev_key"] - assert _prev_key is not None - prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) - if static_kv: - k = prev_key - else: - assert k is not None - k = torch.cat([prev_key, k], dim=1) - if "prev_value" in saved_state: - _prev_value = saved_state["prev_value"] - assert _prev_value is not None - prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) - if static_kv: - v = prev_value - else: - assert v is not None - v = torch.cat([prev_value, v], dim=1) - assert k is not None and v is not None - prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None) - if prev_key_padding_mask is not None: - if static_kv: - new_key_padding_mask = prev_key_padding_mask - else: - new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1) - else: - new_key_padding_mask = key_padding_mask - return k, v, new_key_padding_mask - def fill_with_neg_inf(t): """FP16-compatible function that fills a input_ids with -inf.""" @@ -953,6 +935,7 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1033,6 +1016,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1098,6 +1082,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1161,6 +1146,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = outputs[0] @@ -1189,17 +1175,6 @@ def forward( def 
prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = [] - for layer_past in past_key_values: - # get the correct batch idx from decoder layer's batch dim for cross and self-attn - layer_past_new = { - attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() - } - reordered_past.append(layer_past_new) - return reordered_past - def get_encoder(self): return self.model.encoder diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index d2838fa8e03b..57e256c8fab3 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -41,7 +41,6 @@ class FuyuPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True _no_split_modules = [] _skip_keys_device_placement = "past_key_values" @@ -390,14 +389,5 @@ def prepare_inputs_for_generation( return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["FuyuForCausalLM", "FuyuPreTrainedModel", "FuyuModel"] diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e5b06aaa64f6..2a5c08f1b1d1 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -313,8 +313,7 @@ class GemmaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index fceb1cf9d005..3a45c3275abf 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -343,8 +343,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 3f50d5f17be4..5f939c8248a5 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -433,8 +433,7 @@ class Gemma3PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 9cb17a6a5b84..c63f9b31c04e 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1489,8 +1489,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel): _supports_flash_attn = True 
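Across this diff, the per-model `_reorder_cache` staticmethods (FSMT and Fuyu here, GIT and GPT-2 further down) disappear because `Cache` objects let beam search reorder hypotheses generically. For reference, a minimal self-contained sketch of the gather those removed methods performed per layer; the tensor shapes are illustrative:

```py
import torch

def reorder_layer(key: torch.Tensor, value: torch.Tensor, beam_idx: torch.Tensor):
    # Gather the selected beam hypotheses along the batch dimension,
    # exactly what each removed `_reorder_cache` did per layer.
    key = key.index_select(0, beam_idx.to(key.device))
    value = value.index_select(0, beam_idx.to(value.device))
    return key, value

key = torch.randn(4, 2, 5, 8)  # (num_beams, num_heads, seq_len, head_dim)
value = torch.randn(4, 2, 5, 8)
beam_idx = torch.tensor([0, 0, 2, 1])  # beams chosen at this decoding step
key, value = reorder_layer(key, value, beam_idx)
```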
_supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index ad52c63d3d38..c6b1aa1f88db 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -151,11 +151,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -165,11 +160,24 @@ def forward( output_attentions: Optional[bool] = False, pixel_values_present: Optional[bool] = False, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) cutoff = self.image_patch_tokens if pixel_values_present else 0 - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if past_key_value is not None: # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component. key_layer_past, value_layer_past = past_key_value.update( @@ -178,8 +186,6 @@ def forward( key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2) value_layer = torch.cat([value_layer[:, :, :cutoff, :], value_layer_past], dim=2) - query_layer = self.transpose_for_scores(mixed_query_layer) - # Take the dot product between "query" and "key" to get the raw attention scores. 
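FLAVA, GIT, and GLPN all replace the `transpose_for_scores` helper with an inlined `view` + `transpose(1, 2)`. The two spellings are equivalent, which a quick standalone check confirms:

```py
import torch

batch_size, seq_length, num_heads, head_dim = 2, 5, 4, 16
hidden_states = torch.randn(batch_size, seq_length, num_heads * head_dim)

# Removed helper: view to (b, s, h, d), then permute to (b, h, s, d).
old = hidden_states.view(batch_size, seq_length, num_heads, head_dim).permute(0, 2, 1, 3)
# Inlined replacement used throughout this diff.
new = hidden_states.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

assert torch.equal(old, new)
```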
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -450,8 +456,6 @@ class GitPreTrainedModel(PreTrainedModel): config_class = GitConfig base_model_prefix = "git" supports_gradient_checkpointing = True - _supports_cache_class = True - _supports_quantized_cache = True def _init_weights(self, module): """Initialize the weights""" @@ -1095,7 +1099,5 @@ def forward( past_key_values_length = 0 if past_key_values is not None: past_key_values_length = ( - past_key_values[0][0].shape[2] - if not isinstance(past_key_values, Cache) - else past_key_values.get_seq_length() + past_key_values.get_seq_length() ) @@ -1452,13 +1456,5 @@ def prepare_inputs_for_generation( "use_cache": use_cache, } - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["GitForCausalLM", "GitModel", "GitPreTrainedModel", "GitVisionModel"] diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index ccb4cb583ac5..147ccde41adb 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -330,8 +330,7 @@ class GlmPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 1a1b5abe5710..b1c6421fe16e 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -334,8 +334,7 @@ class Glm4PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index b607c48cc30b..2e8a4149d76b 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -406,7 +406,7 @@ class Glm4vPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True + _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py index 8715a09613a3..b21d2f14d765 100755 --- a/src/transformers/models/glpn/modeling_glpn.py +++ b/src/transformers/models/glpn/modeling_glpn.py @@ -126,11 +126,6 @@ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ ) self.layer_norm = nn.LayerNorm(hidden_size) - def transpose_for_scores(self, hidden_states): - new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - hidden_states = hidden_states.view(new_shape) - return hidden_states.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -138,7 +133,12 @@ def forward( width, output_attentions=False, ): - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( +
self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if self.sr_ratio > 1: batch_size, seq_len, num_channels = hidden_states.shape @@ -150,8 +150,16 @@ def forward( hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1) hidden_states = self.layer_norm(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index c85d2c962f65..f11f12cd5409 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -280,10 +280,10 @@ class GotOcr2PreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index a98bf0c235e0..78cc23380972 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -562,7 +562,7 @@ class GPT2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_attention_backend = True - _supports_cache_class = True + _supports_static_cache = True def __init__(self, *inputs, **kwargs): @@ -785,7 +785,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1169,7 +1169,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1323,7 +1323,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). 
Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1424,20 +1424,6 @@ def forward( attentions=transformer_outputs.attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> tuple[tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" @@ -1486,7 +1472,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1619,7 +1605,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1705,7 +1691,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py index 46a3dfea4410..127a0eed4732 100644 --- a/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py @@ -134,6 +134,7 @@ def __init__( self.attention_softmax_in_fp32 = attention_softmax_in_fp32 self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 self.multi_query = multi_query + self.num_key_value_heads = 1 if multi_query else n_head self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 2d2418d1f700..aae9a3e9b027 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -14,7 +14,7 @@ """PyTorch GPTBigCode model.""" import math -from typing import Optional, Union +from typing import Callable, Optional, Union import torch import torch.utils.checkpoint @@ -22,27 +22,27 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available -from ...modeling_layers import GradientCheckpointingLayer +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import is_torch_greater_or_equal_than_2_2 +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( auto_docstring, + can_return_tuple, logging, ) from .configuration_gpt_bigcode import GPTBigCodeConfig if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward + pass logger = logging.get_logger(__name__) @@ -78,6 +78,49 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor return x +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class GPTBigCodeAttention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() @@ -90,6 +133,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.head_dim = self.embed_dim // self.num_heads self.kv_heads = 1 if self.multi_query else self.num_heads self.kv_dim = self.kv_heads * self.head_dim + self.num_key_value_groups = self.num_heads // self.kv_heads self.split_size = self.embed_dim self.is_causal = True @@ -100,6 +144,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): ) self.scale_attn_weights = config.scale_attn_weights + self.scaling = self.head_dim**-0.5 if config.scale_attn_weights else 1.0 self.is_cross_attention = is_cross_attention self.layer_idx = layer_idx @@ -120,418 +165,93 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.c_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.attn_dropout = config.attn_pdrop self.resid_dropout = nn.Dropout(config.resid_pdrop) - def _get_mask_value(self, device, dtype): - # torch.where expects a tensor. We use a cache to avoid recreating it every time.
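`repeat_kv` is the standard grouped/multi-query helper: it broadcasts the key/value heads up to the full number of attention heads with an `expand` (no copy) followed by a single `reshape`. A small sketch verifying the docstring's `repeat_interleave` equivalence:

```py
import torch

batch, num_kv_heads, n_rep, seq_len, head_dim = 2, 1, 8, 4, 16
kv = torch.randn(batch, num_kv_heads, seq_len, head_dim)

expanded = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
expanded = expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# Same result as the docstring's reference implementation.
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=n_rep, dim=1))
```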
- if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: - self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) - return self.mask_value - - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - dtype = query.dtype - softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype - upcast = dtype != softmax_dtype - - unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 - scale_factor = unscale**-1 - if self.scale_attn_weights: - scale_factor /= self.head_dim**0.5 - - # MQA models: (batch_size, query_length, num_heads * head_dim) - # MHA models: (batch_size, num_heads, query_length, head_dim) - query_shape = query.shape - batch_size = query_shape[0] - key_length = key.size(-1) - if self.multi_query: - # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) - # -> (batch_size, query_length, num_heads, key_length) - query_length = query_shape[1] - attn_shape = (batch_size, query_length, self.num_heads, key_length) - attn_view = (batch_size, query_length * self.num_heads, key_length) - # No copy needed for MQA 2, or when layer_past is provided. - query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) - else: - # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) - # -> (batch_size, num_heads, query_length, key_length) - query_length = query_shape[2] - attn_shape = (batch_size, self.num_heads, query_length, key_length) - attn_view = (batch_size * self.num_heads, query_length, key_length) - # Always copies - query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) - # No copy when layer_past is provided. - key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) - - attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) - if query.device.type == "cpu": - # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. - # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086, - # but the fix has not been released as of pytorch version 2.0.0. - attn_weights = torch.zeros_like(attn_weights) - beta = 1 - else: - beta = 0 - attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) - - if upcast: - # Use a fused kernel to prevent a large overhead from casting and scaling. - # Sub-optimal when the key length is not a multiple of 8. - if attention_mask is None: - attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) - else: - mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype) - else: - if attention_mask is not None: - mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) - - # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. 
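The custom `_attn` kernels being deleted fused unscaling, boolean masking, and a float32 upcast into single ops. The unfused equivalent is a masked softmax computed in float32 and cast back; below is a minimal sketch with names of my own choosing (the new `eager_attention_forward` reaches the same numerics via additive float masks instead of boolean ones):

```py
import torch

def masked_softmax_fp32(scores: torch.Tensor, bool_mask: torch.Tensor) -> torch.Tensor:
    # Fill masked-out positions with the fp32 minimum, softmax in float32,
    # then cast back to the original dtype -- what the fused kernels computed in one pass.
    dtype = scores.dtype
    mask_value = torch.finfo(torch.float32).min
    scores = torch.where(bool_mask, scores.float(), torch.full((), mask_value))
    return torch.nn.functional.softmax(scores, dim=-1).to(dtype)

scores = torch.randn(1, 4, 3, 3, dtype=torch.float16)
mask = torch.ones(1, 1, 3, 3, dtype=torch.bool).tril()
print(masked_softmax_fp32(scores, mask).dtype)  # torch.float16
```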
- attn_weights = torch.where(attention_mask, attn_weights, mask_value) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - - attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - if head_mask is not None: - if self.multi_query: - head_mask = head_mask.transpose(1, 2) - attn_weights = attn_weights * head_mask - - if self.multi_query: - attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) - else: - attn_output = torch.matmul(attn_weights, value) - - return attn_output, attn_weights - def forward( self, hidden_states: torch.Tensor, - layer_past: Optional[torch.Tensor] = None, + layer_past: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[ tuple[torch.Tensor, Optional[torch.Tensor]], tuple[torch.Tensor, Optional[torch.Tensor], tuple[torch.Tensor, ...]], ]: - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn") or not self.is_cross_attention: - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." - ) - - query = self.q_attn(hidden_states) - key_value = self.c_attn(encoder_hidden_states) - attention_mask = encoder_attention_mask - elif self.multi_query: - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) - else: - # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), - # i.e., the memory layout is not the same as GPT2. - # This makes the concatenation with past_key_value more efficient. - query, key_value = ( - self.c_attn(hidden_states) - .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) - .transpose(1, 2) - .split((self.head_dim, 2 * self.head_dim), dim=3) - ) + input_shape = hidden_states.shape[:-1] if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None - - key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - - attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) - - if not self.multi_query: - attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) - attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - if self.multi_query: - # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) - attn_weights = attn_weights.transpose(1, 2) - outputs += (attn_weights,) - - return outputs # a, present, (attentions) - - -class GPTBigCodeFlashAttention2(GPTBigCodeAttention): - """ - GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module - stays untouched. The only required change would be on the forward pass where it needs to correctly call the public - API of flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 
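The removed flash-attention path also carried PEFT-oriented dtype bookkeeping: when layer norms silently upcast hidden states to float32, they must be cast back to the compute dtype before calling the kernel. A rough sketch of that target-dtype selection (simplified; the original also consulted `config._pre_quantization_dtype` for quantized models):

```py
import torch

def flash_target_dtype(module_dtype: torch.dtype) -> torch.dtype:
    # Prefer the active autocast dtype; otherwise fall back to the attention weights' dtype.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    return module_dtype

print(flash_target_dtype(torch.float16))  # torch.float16 outside an autocast context
```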
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + if isinstance(layer_past, EncoderDecoderCache): + is_updated = layer_past.is_updated.get(self.layer_idx) + if self.is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = layer_past.cross_attention_cache + else: + curr_past_key_value = layer_past.self_attention_cache + else: + curr_past_key_value = layer_past - def forward( - self, - hidden_states: torch.Tensor, - layer_past: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Union[ - tuple[torch.Tensor, Optional[torch.Tensor]], - tuple[torch.Tensor, Optional[torch.Tensor], tuple[torch.Tensor, ...]], - ]: - if encoder_hidden_states is not None: + if self.is_cross_attention: if not hasattr(self, "q_attn") or not self.is_cross_attention: raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." ) - - query = self.q_attn(hidden_states) - key_value = self.c_attn(encoder_hidden_states) - attention_mask = encoder_attention_mask - elif self.multi_query: - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) - else: - # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), - # i.e., the memory layout is not the same as GPT2. - # This makes the concatenation with past_key_value more efficient. 
- query, key_value = ( - self.c_attn(hidden_states) - .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) - .transpose(1, 2) - .split((self.head_dim, 2 * self.head_dim), dim=3) - ) - - if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None - - key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - if self.multi_query: - batch_size, query_length, _ = query.shape - query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim) - key = key.unsqueeze(2) - value = value.unsqueeze(2) + if layer_past is not None and is_updated: + # reuse k,v, cross_attentions + key = curr_past_key_value.key_cache[self.layer_idx] + value = curr_past_key_value.value_cache[self.layer_idx] + else: + query = self.q_attn(hidden_states).view(*input_shape, -1, self.head_dim).transpose(1, 2) + key, value = self.c_attn(encoder_hidden_states).split((self.head_dim, self.head_dim), dim=-1) else: - query_length = query.shape[2] - batch_size, _, tgt, _ = key.shape - query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim) - key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) - value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) - - attn_dropout = self.attn_pdrop if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query.dtype - device_type = query.device.type if query.device.type != "mps" else "cpu" - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = ( - torch.get_autocast_dtype(device_type) - if hasattr(torch, "get_autocast_dtype") - else torch.get_autocast_gpu_dtype() + if self.multi_query: + query, key, value = ( + self.c_attn(hidden_states).unsqueeze(1).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=3) ) - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + query = query.view(*input_shape, -1, self.head_dim).transpose(1, 2) else: - target_dtype = self.c_attn.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - query = query.to(target_dtype) - key = key.to(target_dtype) - value = value.to(target_dtype) + query, key, value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split(3 * [self.head_dim], dim=3) + ) - attn_output = _flash_attention_forward( + if layer_past is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not self.is_cross_attention else None + key, value = curr_past_key_value.update(key, value, self.layer_idx, {"cache_position": cache_position}) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if self.is_cross_attention: + layer_past.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, query, key, value, attention_mask, - query_length, - dropout=attn_dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - attn_output = self.c_proj(attn_weights_reshaped) - attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - - if output_attentions: - if self.multi_query: - # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) - attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2) - else: - attn_weights_reshaped = None - - outputs += (attn_weights_reshaped,) - - return outputs # a, present, (attentions) - - -class GPTBigCodeSdpaAttention(GPTBigCodeAttention): - def _attn(self, query, key, value, attention_mask=None): - scale = None - if not self.scale_attn_weights: - scale = 1 - - # MQA models: (batch_size, query_length, num_heads * head_dim) - # MHA models: (batch_size, num_heads, query_length, head_dim) - query_shape = query.shape - batch_size = query_shape[0] - key.shape[-2] - - if self.multi_query: - query_length = query_shape[1] - - # SDPA requires the dimension [..., sequence_length, head_dim]. - query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2) - - # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions. - key = key.unsqueeze(1) - value = value.unsqueeze(1) - - # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend - # and flash attention backend (No available kernel. Aborting execution.) from the shapes - # query = [batch_size, num_heads, query_length, head_dim] - # key = [batch_size, 1, past_length, head_dim] - # value = [batch_size, 1, past_length, head_dim] - # - # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check. - if is_torch_greater_or_equal_than_2_2: - key = key.expand(-1, self.num_heads, -1, -1) - value = value.expand(-1, self.num_heads, -1, -1) - else: - query_length = query_shape[-1] - - # See the comment above. 
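The FSMT and GPTBigCode attention rewrites share the same `EncoderDecoderCache` convention: cross-attention key/values are projected from the encoder states once, and the `is_updated` flag skips re-projection on every later decoding step. A standalone sketch of that lifecycle, assuming a Transformers version contemporary with this diff (the tensors and `layer_idx` are made up for illustration):

```py
import torch
from transformers import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
layer_idx = 0
k = v = torch.randn(1, 4, 7, 16)  # (batch, heads, src_len, head_dim)

# First decoding step: project the encoder states once and store them.
if not cache.is_updated.get(layer_idx):
    cache.cross_attention_cache.update(k, v, layer_idx)
    cache.is_updated[layer_idx] = True

# Later steps: re-use the stored states instead of re-projecting.
k = cache.cross_attention_cache.key_cache[layer_idx]
v = cache.cross_attention_cache.value_cache[layer_idx]
```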
- if query.device.type == "cuda" and attention_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not - # create a causal mask in case query_length == 1. - is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False - - sdpa_result = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=self.attn_pdrop if self.training else 0.0, - is_causal=is_causal, - scale=scale, + dropout=0.0 if not self.training else self.attn_dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, ) - if self.multi_query: - # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim) - sdpa_result = sdpa_result.transpose(1, 2) - - # Reshape is kind of expensive here, as it does a memory copy, - # but I did not manage to make away without it (logits do not match when using view) - # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim) - sdpa_result = sdpa_result.reshape(query_shape) - - return sdpa_result, None - - def forward( - self, - hidden_states: torch.Tensor, - layer_past: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Union[ - tuple[torch.Tensor, Optional[torch.Tensor]], - tuple[torch.Tensor, Optional[torch.Tensor], tuple[torch.Tensor, ...]], - ]: - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn") or not self.is_cross_attention: - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." - ) - - query = self.q_attn(hidden_states) - key_value = self.c_attn(encoder_hidden_states) - attention_mask = encoder_attention_mask - elif self.multi_query: - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) - else: - # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), - # i.e., the memory layout is not the same as GPT2. - # This makes the concatenation with past_key_value more efficient. - query, key_value = ( - self.c_attn(hidden_states) - .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) - .transpose(1, 2) - .split((self.head_dim, 2 * self.head_dim), dim=3) - ) - - if layer_past is not None: - key_value = torch.cat((layer_past, key_value), dim=-2) - present = key_value if use_cache else None - - key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) - - if not output_attentions: - # Difference with the original implementation: there is no need to transpose the key here, - # as SDPA expects seq_length to be at index -2 for the key as well - attn_output, attn_weights = self._attn(query, key, value, attention_mask) - else: - # TODO: Improve this warning with e.g. 
`model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`." - ' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask) - - if not self.multi_query: - attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - if self.multi_query: - # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) - attn_weights = attn_weights.transpose(1, 2) - outputs += (attn_weights,) - - return outputs + return attn_output, attn_weights class GPTBigCodeMLP(nn.Module): @@ -552,14 +272,7 @@ def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states -GPTBIGCODE_ATTENTION_CLASSES = { - "eager": GPTBigCodeAttention, - "flash_attention_2": GPTBigCodeFlashAttention2, - "sdpa": GPTBigCodeSdpaAttention, -} - - -class GPTBigCodeBlock(GradientCheckpointingLayer): +class GPTBigCodeBlock(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -567,7 +280,7 @@ def __init__(self, config, layer_idx=None): self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) @@ -575,9 +288,7 @@ def __init__(self, config, layer_idx=None): if config.multi_query: raise NotImplementedError("Cross-attention not implemented for MQA") - self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation]( - config, is_cross_attention=True, layer_idx=layer_idx - ) + self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) @@ -586,13 +297,14 @@ def __init__(self, config, layer_idx=None): def forward( self, hidden_states: Optional[tuple[torch.Tensor]], - layer_past: Optional[torch.Tensor] = None, + layer_past: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[ tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor] @@ -606,6 +318,8 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, + **kwargs, ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] @@ -628,24 +342,19 @@ def forward( encoder_hidden_states=encoder_hidden_states, 
encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, + cache_position=cache_position, + **kwargs, ) attn_output = cross_attn_outputs[0] # residual connection hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights residual = hidden_states hidden_states = self.ln_2(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states) - # residual connection hidden_states = residual + feed_forward_hidden_states - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs # hidden_states, present, (attentions, cross_attentions) + return (hidden_states,) + outputs @auto_docstring @@ -722,6 +431,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): self.wte = new_embeddings + @can_return_tuple @auto_docstring def forward( self, @@ -738,11 +448,13 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -760,16 +472,9 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.gradient_checkpointing and self.training and use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") elif input_ids is not None: - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) batch_size = input_ids.shape[0] @@ -782,81 +487,44 @@ def forward( if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - device = input_ids.device if input_ids is not None else inputs_embeds.device + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0].size(-2) - - if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_length > 0: - position_ids = position_ids[:, past_length : input_shape[-1] + past_length :] - elif position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) - - # Self-attention mask. - query_length = input_shape[-1] - key_length = past_length + query_length - self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + position_ids=position_ids, + past_key_values=past_key_values, + ) if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None encoder_attention_mask = ( encoder_attention_mask.bool() if (encoder_attention_mask is not None and 0 in encoder_attention_mask) else None ) else: - # 4d mask is passed through the layers - if attention_mask is not None: - self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( - dtype=torch.bool, device=self_attention_mask.device - ) - - # MQA models: (batch_size, query_length, n_heads, key_length) - # MHA models: (batch_size, n_heads, query_length, key_length) - self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) - - if self._use_sdpa and head_mask is None and not output_attentions: - # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. - dtype = self.wte.weight.dtype - min_dtype = torch.finfo(dtype).min - self_attention_mask = torch.where( - self_attention_mask, - torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), - torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device), - ) - - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - if self.multi_query: - # gpt_bigcode using MQA has the bad taste to use a causal mask with shape - # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose. 
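With the hand-rolled `self.bias` slicing gone, GPTBigCode now derives everything from `cache_position`: the pre-fill positions come from the cache length, `position_ids` default to `cache_position`, and `create_causal_mask` builds the mask. A toy sketch of that position bookkeeping, mirroring the logic in this hunk:

```py
import torch
from transformers import DynamicCache

past_key_values = DynamicCache()
inputs_embeds = torch.randn(1, 3, 16)  # (batch, seq_len, hidden)

past_seen_tokens = past_key_values.get_seq_length()  # 0 on the first forward pass
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
position_ids = cache_position.unsqueeze(0)

print(cache_position)     # tensor([0, 1, 2])
print(position_ids.shape)  # torch.Size([1, 3])
```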
- self_attention_mask = self_attention_mask.transpose(1, 2) - - if ( - query_length > 1 - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - ): - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - self_attention_mask = AttentionMaskConverter._unmask_unattended( - self_attention_mask, min_dtype=min_dtype - ) - - attention_mask = self_attention_mask - # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if ( @@ -877,46 +545,42 @@ def forward( # head_mask has shape n_layer x batch x n_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device) if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds hidden_states = self.drop(hidden_states) - output_shape = input_shape + (hidden_states.size(-1),) - presents = [] if use_cache else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) outputs = block( hidden_states, - layer_past, - attention_mask, + past_key_values, + causal_mask, head_mask[i], encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] - if use_cache: - presents.append(outputs[1]) - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + all_cross_attentions = all_cross_attentions + (outputs[2],) hidden_states = self.ln_f(hidden_states) @@ -925,16 +589,12 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: - return tuple( - v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None - ) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -964,75 +624,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - # Overwritten -- `past_key_values` with uncommon shape - - token_type_ids = kwargs.get("token_type_ids", 
None) - # Omit tokens covered by past_key_values - if past_key_values: - if self.config.multi_query: - past_length = past_key_values[0].shape[1] - else: - past_length = past_key_values[0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -input_ids.shape[1] :] - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - else: - position_ids = None - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - ) - return model_inputs - - def _get_initial_cache_position(self, seq_length, device, model_kwargs): - """ - Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length. - Since gpt bigcode is special, the method is overridden here, other models use it from `generation.utils.py`. - """ - past_length = 0 - if "past_key_values" in model_kwargs: - if self.config.multi_query: - past_length = model_kwargs["past_key_values"][0].shape[1] - else: - past_length = model_kwargs["past_key_values"][0].shape[2] - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - else: - cur_len = seq_length - model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=device) - return model_kwargs - @auto_docstring def forward( self, @@ -1050,12 +641,13 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: r""" input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1086,6 +678,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] @@ -1113,17 +706,6 @@ def forward( cross_attentions=transformer_outputs.cross_attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> tuple[tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values) - @auto_docstring( custom_intro=""" @@ -1164,11 +746,12 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs, ) -> Union[tuple, SequenceClassifierOutputWithPast]: r""" input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1197,6 +780,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) @@ -1299,7 +883,7 @@ def forward( r""" input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 896b2123c67b..e6df1a4225e3 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -477,8 +477,6 @@ class GPTNeoPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True - _supports_cache_class = True - _supports_quantized_cache = True _supports_static_cache = False # TODO: needs a HybridCache def __init__(self, *inputs, **kwargs): @@ -542,7 +540,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -817,7 +815,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -883,20 +881,6 @@ def forward( attentions=transformer_outputs.attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> tuple[tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or - [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" @@ -941,7 +925,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1064,7 +1048,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -1146,7 +1130,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 7359d7b46e39..511ac1a29c8e 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -363,8 +363,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 30a2ce2fbc5f..6f8674b00bc9 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -47,8 +47,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): base_model_prefix = "gpt_neox_japanese" _no_split_modules = ["GPTNeoXJapaneseLayer"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): @@ -750,15 +749,6 @@ def forward( attentions=outputs.attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - __all__ = [ "GPTNeoXJapaneseForCausalLM", diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 5c8e9e81c2e5..7fcc7451ac15 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -472,8 +472,6 @@ class GPTJPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTJBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True - _supports_cache_class = True - _supports_quantized_cache = True _supports_static_cache = True _supports_param_buffer_assignment = False @@ -1017,20 +1015,6 @@ def forward( attentions=transformer_outputs.attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> tuple[tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or - [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. 
- """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index f2cf41c249cc..87e2dfd793c2 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -308,8 +308,7 @@ class GranitePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index d4d8b3767ef5..2cfaa235828d 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -282,7 +282,7 @@ def forward(self, hidden_states: torch.Tensor): @auto_docstring class GraniteSpeechPreTrainedModel(PreTrainedModel): config_class = GraniteSpeechConfig - _supports_cache_class = True + _supports_flash_attn = True _supports_sdpa = True diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 824c7ccd8b7e..132c243493f1 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -591,8 +591,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): @@ -1022,14 +1021,5 @@ def forward( router_logits=outputs.router_logits, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["GraniteMoeForCausalLM", "GraniteMoeModel", "GraniteMoePreTrainedModel"] diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index fdfdae611221..761bee178f34 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1164,8 +1164,7 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _is_stateful = True @@ -1739,15 +1738,6 @@ def forward( router_logits=outputs.router_logits, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - 
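
For reference, every `_reorder_cache` override deleted in this PR (GPT-J, GPT-Neo, Granite, GraniteMoe, and the hybrid/shared variants above) performed the same batch-axis `index_select` over the legacy tuple cache. A minimal standalone sketch of that legacy behavior (the helper name is illustrative, not a `transformers` API):

```py
import torch

def reorder_legacy_cache(past_key_values, beam_idx):
    # legacy cache: tuple over layers of (key, value) tensors, each shaped
    # [batch_size, num_heads, seq_len, head_dim]; beam search permutes the batch dim
    return tuple(
        tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)
        for layer_past in past_key_values
    )
```

With `Cache` objects, this permutation is handled by the cache's own reordering logic during `generate()`, so the per-model copies are dead code.
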
def prepare_inputs_for_generation( self, input_ids, @@ -1805,14 +1795,5 @@ def prepare_inputs_for_generation( ) return model_inputs - def _supports_default_dynamic_cache(self) -> bool: - """ - Function overwritten as this class uses its own `HybridMambaAttentionDynamicCache` - and do not need to initialize the Cache in advance in order to save memory - (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed - for `HybridMambaAttentionDynamicCache`). - """ - return False - __all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"] diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index f894d1a8f8a7..eea61219ac11 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -383,14 +383,5 @@ def prepare_inputs_for_generation( ) return model_inputs - def _supports_default_dynamic_cache(self) -> bool: - """ - Function overwritten as this class uses its own `HybridMambaAttentionDynamicCache` - and do not need to initialize the Cache in advance in order to save memory - (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed - for `HybridMambaAttentionDynamicCache`). - """ - return False - __all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"] diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 527ee691d493..009ecbf0ddbb 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -509,8 +509,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): @@ -1054,14 +1053,5 @@ def forward( router_logits=outputs.router_logits, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"] diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 743f74a1215b..197c99c57be6 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -1183,11 +1183,6 @@ def __init__(self, config, num_attention_heads=None): self.dropout = nn.Dropout(config.attention_dropout) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, queries: torch.Tensor, @@ -1196,9 +1191,18 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> 
tuple[torch.Tensor]: - query_layer = self.transpose_for_scores(self.query(queries)) - key_layer = self.transpose_for_scores(self.key(keys)) - value_layer = self.transpose_for_scores(self.value(values)) + batch_size, seq_length, _ = queries.shape + query_layer = ( + self.query(queries) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + value_layer = ( + self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index c238fa200f95..7140b89b44ea 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -315,8 +315,7 @@ class HeliumPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 6f2d9cf3aec7..c4ab2cb2ec35 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -306,7 +306,6 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index 89fd716f885f..5d9c9b17e496 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -226,11 +226,6 @@ def __init__(self, config): self.softmax = IntSoftmax(self.act_bit, quant_mode=self.quant_mode, force_dequant=config.force_dequant) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -254,9 +249,14 @@ def forward( ) # Transpose - query_layer = self.transpose_for_scores(query_layer) - key_layer = self.transpose_for_scores(key_layer) - value_layer = self.transpose_for_scores(value_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
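        # resulting scores have shape [batch_size, num_heads, seq_len, seq_len]; I-BERT then
        # scales them by 1/sqrt(attention_head_size) and applies its integer-only IntSoftmax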
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -564,7 +564,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = None # `config.add_cross_attention` is not supported - next_decoder_cache = None # `config.use_cache` is not supported for i, layer_module in enumerate(self.layer): if output_hidden_states: @@ -592,7 +591,6 @@ def forward( v for v in [ hidden_states, - next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -601,7 +599,6 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -765,7 +762,6 @@ def forward( return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index cf5fad7bb5fd..9812204e91c7 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -878,7 +878,7 @@ class IdeficsPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"] _supports_sdpa = True - _supports_cache_class = True + _supports_flash_attn = True _supports_static_cache = False # IDEFICS cannot compile due to dynamic control flow when checking inputs _supports_attention_backend = True @@ -1554,12 +1554,5 @@ def _update_model_kwargs_for_generation( model_kwargs["image_hidden_states"] = outputs.image_hidden_states return model_kwargs - @staticmethod - def _reorder_cache(past, beam_idx): - reordered_past = () - for layer_past in past: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past - __all__ = ["IdeficsForVisionText2Text", "IdeficsModel", "IdeficsPreTrainedModel"] diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index e596bce9dea4..6c93643a3f4f 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -459,7 +459,7 @@ class Idefics2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 2461f4b95678..021e0d9e709e 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -476,7 +476,7 @@ class Idefics3PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 2c16928f0abb..a0e89f406f31 100644 --- 
a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -229,17 +229,28 @@ def __init__(self, config: IJepaConfig) -> None: self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index d0372118b958..2041c615dfab 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -333,40 +334,61 @@ def _merge_heads(self, tensor, num_heads, attn_head_size): def forward( self, hidden_states: torch.Tensor, - layer_past: Optional[bool] = None, + layer_past: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple: - if encoder_hidden_states is not None: + is_cross_attention = encoder_hidden_states is not None + bsz, seq_len, _ = hidden_states.shape + + if layer_past is not None: + if isinstance(layer_past, EncoderDecoderCache): + is_updated = layer_past.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = layer_past.cross_attention_cache + else: + curr_past_key_value = layer_past.self_attention_cache + else: + curr_past_key_value = layer_past + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention: if not hasattr(self, "q_attn"): raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. 
" "Please make sure to instantiate class with `ImageGPTAttention(..., is_cross_attention=True)`." ) - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) - attention_mask = encoder_attention_mask + if layer_past is not None and is_updated: + # reuse k,v, cross_attentions, and compute only q + query = query = self.q_attn(hidden_states) + key = curr_past_key_value.key_cache[self.layer_idx] + value = curr_past_key_value.value_cache[self.layer_idx] + else: + query = query = self.q_attn(hidden_states) + key, value = self.c_attn(current_states).split(self.split_size, dim=2) + key = key.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value = value.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) + query, key, value = self.c_attn(current_states).split(self.split_size, dim=2) + key = key.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value = value.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) if layer_past is not None: - past_key, past_value = layer_past - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key, value = curr_past_key_value.update(key, value, self.layer_idx, {"cache_position": cache_position}) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + layer_past.is_updated[self.layer_idx] = True - if use_cache is True: - present = (key, value) - else: - present = None + query = query.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) if self.reorder_and_upcast_attn: attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) @@ -377,11 +399,7 @@ def forward( attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, attn_weights class ImageGPTMLP(nn.Module): @@ -420,13 +438,14 @@ def __init__(self, config, layer_idx=None): def forward( self, hidden_states: torch.Tensor, - layer_past: Optional[bool] = None, + layer_past: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple: residual = hidden_states hidden_states = self.ln_1(hidden_states) @@ -437,8 +456,9 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + attn_output = attn_outputs[0] outputs = attn_outputs[1:] # residual connection hidden_states = attn_output + residual @@ -454,16 +474,18 @@ def forward( hidden_states = self.ln_cross_attn(hidden_states) cross_attn_outputs = 
self.crossattention( hidden_states, + layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) attn_output = cross_attn_outputs[0] # residual connection hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights residual = hidden_states hidden_states = self.ln_2(hidden_states) @@ -471,9 +493,7 @@ def forward( # residual connection hidden_states = residual + feed_forward_hidden_states - outputs = (hidden_states,) + (outputs if use_cache else outputs[1:]) - - return outputs # hidden_states, present, (attentions, cross_attentions) + return (hidden_states,) + outputs @auto_docstring @@ -565,12 +585,13 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs: Any, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -618,14 +639,28 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values + if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) if position_ids is None: position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) @@ -677,27 +712,15 @@ def forward( hidden_states = hidden_states + token_type_embeds hidden_states = self.drop(hidden_states) - output_shape = input_shape + (hidden_states.size(-1),) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, block in enumerate(self.h): # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) # Ensure that attention_mask is always on the same device as hidden_states if attention_mask is not None: attention_mask = attention_mask.to(hidden_states.device) @@ -708,23 +731,21 @@ def forward( outputs = block( hidden_states, - layer_past, + past_key_values, attention_mask, head_mask[i], encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + all_cross_attentions = all_cross_attentions + (outputs[2],) # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: @@ -733,22 +754,25 @@ def forward( hidden_states = hidden_states.to("cuda:" + str(k + 1)) hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(*output_shape) + # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -798,12 +822,13 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs: Any, ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -867,6 +892,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] @@ -894,20 +920,6 @@ def forward( cross_attentions=transformer_outputs.cross_attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> tuple[tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @auto_docstring( custom_intro=""" @@ -945,7 +957,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 9718e8fb736e..abd55a22bfd5 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -522,7 +522,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class InformerProbSparseAttention(nn.Module): @@ -741,7 +741,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped # source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py @@ -814,7 +814,7 @@ def forward( returned tensors for more detail. 
""" residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -924,7 +924,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -941,7 +941,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -968,9 +968,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1268,7 +1265,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1302,9 +1298,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1315,19 +1308,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 3d46275bdc81..79d7c661141f 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -430,7 +430,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped # source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index ad94340b5060..c5af37a0d942 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -339,9 +339,8 @@ class InstructBlipPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True - 
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _no_split_modules = [ "InstructBlipQFormerEmbeddings", @@ -1354,9 +1353,8 @@ def forward( class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin): config_class = InstructBlipConfig main_input_name = "pixel_values" - _supports_cache_class = True + _supports_static_cache = True - _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 def __init__(self, config: InstructBlipConfig): diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 68ed042a3d09..0f62721d6736 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -826,9 +826,8 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True - _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _no_split_modules = [ "InstructBlipVideoQFormerEmbeddings", @@ -1360,9 +1359,8 @@ def forward( class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin): config_class = InstructBlipVideoConfig main_input_name = "pixel_values" - _supports_cache_class = True + _supports_static_cache = True - _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 def __init__(self, config: InstructBlipVideoConfig): diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 687a6dd1e3a2..d5768ef4a8c2 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -520,10 +520,10 @@ class InternVLPreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 4b15f42eadc6..c7145cd82563 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1071,7 +1071,7 @@ class JambaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True def _init_weights(self, module): diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index bb8dec3b46ca..e85023a20866 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -62,8 +62,7 @@ class JanusPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn = True _supports_sdpa = True 
-    _supports_quantized_cache = True
-    _supports_cache_class = True
+    _supports_static_cache = True

     _supports_param_buffer_assignment = False

diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py
index df94df47b938..588500bae200 100644
--- a/src/transformers/models/janus/modular_janus.py
+++ b/src/transformers/models/janus/modular_janus.py
@@ -389,8 +389,7 @@ class JanusPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = ["past_key_values", "causal_mask"]
     _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_quantized_cache = True
-    _supports_cache_class = True
+    _supports_static_cache = True

     _supports_param_buffer_assignment = False

diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index 7009c2d7d553..7f725462a4de 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -828,7 +828,6 @@ class JetMoePreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_cache_class = True

     def _init_weights(self, module):
         """Initialize the weights."""
diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py
index da44dae8e343..e3cbd12f7332 100644
--- a/src/transformers/models/kosmos2/modeling_kosmos2.py
+++ b/src/transformers/models/kosmos2/modeling_kosmos2.py
@@ -23,6 +23,7 @@
 from torch import nn

 from ...activations import ACT2FN
+from ...cache_utils import Cache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -677,9 +678,10 @@ def __init__(
         embed_dim: int,
         num_heads: int,
         dropout: float = 0.0,
-        is_decoder: bool = False,
-        add_inner_attn_layernorm: bool = False,
-        bias: bool = True,
+        is_decoder: Optional[bool] = False,
+        add_inner_attn_layernorm: Optional[bool] = False,
+        bias: Optional[bool] = True,
+        layer_idx: Optional[int] = None,
     ):
         super().__init__()
         self.config = config
@@ -695,6 +697,7 @@ def __init__(
         )
         self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
+        self.layer_idx = layer_idx

         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -706,22 +709,17 @@ def __init__(
         if add_inner_attn_layernorm:
             self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

-    def _shape(self, projection: torch.Tensor) -> torch.Tensor:
-        new_projection_shape = projection.size()[:-1] + (self.num_heads, self.head_dim)
-        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
-        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
-        return new_projection
-
     def forward(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""

         # if key_value_states are provided this layer is used as a cross-attention 
layer @@ -729,33 +727,40 @@ def forward( is_cross_attention = encoder_hidden_states is not None batch_size, seq_length = hidden_states.shape[:2] - # use encoder_hidden_states if cross attention - current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - # checking that the `sequence_length` of the `past_key_value` is the same as the he provided - # `encoder_hidden_states` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + query_states = self.q_proj(hidden_states) + query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = self._shape(self.k_proj(current_states)) - value_states = self._shape(self.v_proj(current_states)) - if past_key_value is not None and not is_cross_attention: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - query_states = self._shape(self.q_proj(hidden_states)) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward @@ -785,7 +790,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class Kosmos2TextFFN(nn.Module): @@ -812,7 +817,7 @@ def forward(self, hidden_states): class Kosmos2TextBlock(GradientCheckpointingLayer): - def __init__(self, config: Kosmos2TextConfig): + def __init__(self, config: Kosmos2TextConfig, layer_idx=None): super().__init__() self.embed_dim = config.embed_dim @@ -823,6 +828,7 @@ def __init__(self, config: Kosmos2TextConfig): dropout=config.attention_dropout, is_decoder=True, add_inner_attn_layernorm=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -835,6 +841,7 @@ def __init__(self, config: Kosmos2TextConfig): dropout=config.attention_dropout, is_decoder=True, add_inner_attn_layernorm=False, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -849,33 +856,28 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - hidden_states = self.self_attn_layer_norm(hidden_states) - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - 
cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: if not hasattr(self, "encoder_attn"): @@ -885,26 +887,21 @@ def forward( ) residual = hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states @@ -918,10 +915,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (present_key_value,) - return outputs @@ -948,7 +941,7 @@ def __init__(self, config: Kosmos2TextConfig): padding_idx=config.pad_token_id, ) - self.layers = nn.ModuleList([Kosmos2TextBlock(config) for _ in range(config.layers)]) + self.layers = nn.ModuleList([Kosmos2TextBlock(config, layer_idx=i) for i in range(config.layers)]) self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) self.gradient_checkpointing = False @@ -1027,6 +1020,8 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1045,8 +1040,24 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 # We don't need img info. 
when `past_key_values_length` > 0 if past_key_values_length > 0: @@ -1073,18 +1084,10 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - present_key_value_states = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1104,8 +1107,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask, @@ -1113,16 +1114,14 @@ def forward( encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, **kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - present_key_value_states += (layer_outputs[3 if output_attentions else 1],) - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1132,13 +1131,16 @@ def forward( # add final layer norm hidden_states = self.layer_norm(hidden_states) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1304,6 +1306,8 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" @@ -1336,6 +1340,8 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, **kwargs, ) @@ -1391,6 +1397,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: r""" @@ -1435,6 +1442,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, + cache_position=cache_position, **kwargs, ) lm_logits = self.lm_head(outputs[0]) @@ -1466,9 +1474,11 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - if past_key_values is not 
None: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + if cache_position[0] != 0: image_embeds = None image_embeds_position_mask = None + # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation) elif image_embeds_position_mask is not None: batch_size, seq_len = inputs_embeds.size()[:-1] if inputs_embeds is not None else input_ids.size() @@ -1497,16 +1507,6 @@ def prepare_inputs_for_generation( return model_inputs - @staticmethod - # Copied from transformers.models.umt5.modeling_umt5.UMT5ForConditionalGeneration._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - class Kosmos2ImageToTextProjection(nn.Module): """The layer that transforms the image model's output to part of the text model's input (namely, image features)""" @@ -1532,7 +1532,7 @@ def forward(self, features): latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1) key_value_states = torch.cat([hidden_states, latent_query], dim=1) - hidden_states, attn_weights, _ = self.x_attn( + hidden_states, attn_weights = self.x_attn( hidden_states=latent_query, encoder_hidden_states=key_value_states, past_key_value=None, diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index 41b5800b2392..04b38cd10aa2 100644 --- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -119,7 +119,7 @@ class KyutaiSpeechToTextPreTrainedModel(PreTrainedModel): _no_split_modules = ["KyutaiSpeechToTextDecoderLayer", "MimiTransformerLayer"] _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True + main_input_name = "input_ids" def _init_weights(self, module): diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 66637bedd8d2..7f6a861a674d 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -123,11 +123,6 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def compute_qkv(self, hidden_states): if self.fast_qkv: qkv = self.qkv_linear(hidden_states) @@ -154,12 +149,13 @@ def forward( rel_pos=None, rel_2d_pos=None, ): - q, k, v = self.compute_qkv(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query, key, value = self.compute_qkv(hidden_states) # (B, L, H*D) -> (B, H, L, D) - query_layer = self.transpose_for_scores(q) - key_layer = self.transpose_for_scores(k) - value_layer = self.transpose_for_scores(v) + query_layer = query.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + key_layer = key.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = value.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) 
query_layer = query_layer / math.sqrt(self.attention_head_size) # [BSZ, NAT, L, L] diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 05f662b12a9f..8b5628541092 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -245,11 +245,6 @@ def __init__(self, config): self.has_relative_attention_bias = config.has_relative_attention_bias self.has_spatial_attention_bias = config.has_spatial_attention_bias - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - def cogview_attention(self, attention_scores, alpha=32): """ https://huggingface.co/papers/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation @@ -271,11 +266,22 @@ def forward( rel_pos=None, rel_2d_pos=None, ): - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow. 
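The attention hunks above all converge on the same caching pattern: per-layer `(key, value)` tuples are replaced by a single cache object that every layer reads and writes through its `layer_idx`. Self-attention entries grow by one position per decoding step, while cross-attention entries are computed once from the encoder output and then reused, guarded by the cache's `is_updated` flag. The snippet below is a minimal, self-contained sketch of that split; `ToyLayerCache`, `ToyEncoderDecoderCache`, and `attend` are illustrative stand-ins, not the `transformers` API.

```py
import torch


class ToyLayerCache:
    """Per-layer KV store; `update` appends along the sequence axis (dim 2)."""

    def __init__(self, num_layers):
        self.key_cache = [None] * num_layers
        self.value_cache = [None] * num_layers

    def update(self, key, value, layer_idx):
        if self.key_cache[layer_idx] is None:
            self.key_cache[layer_idx] = key
            self.value_cache[layer_idx] = value
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value], dim=2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


class ToyEncoderDecoderCache:
    """Two sub-caches, mirroring the self-/cross-attention split in the diff."""

    def __init__(self, num_layers):
        self.self_attention_cache = ToyLayerCache(num_layers)
        self.cross_attention_cache = ToyLayerCache(num_layers)
        self.is_updated = {}


def attend(cache, layer_idx, kv_source, is_cross_attention):
    # Cross-attention KV is written once, then reused on every later step.
    if is_cross_attention and cache.is_updated.get(layer_idx, False):
        sub = cache.cross_attention_cache
        return sub.key_cache[layer_idx], sub.value_cache[layer_idx]
    sub = cache.cross_attention_cache if is_cross_attention else cache.self_attention_cache
    # `kv_source` stands in for the k_proj/v_proj outputs reshaped to (B, H, L, D).
    key, value = sub.update(kv_source, kv_source, layer_idx)
    if is_cross_attention:
        cache.is_updated[layer_idx] = True
    return key, value


cache = ToyEncoderDecoderCache(num_layers=1)
encoder_kv = torch.randn(1, 4, 7, 8)  # (batch, heads, src_len, head_dim)
for _ in range(3):                    # three decoding steps of one token each
    new_token_kv = torch.randn(1, 4, 1, 8)
    k_self, _ = attend(cache, 0, new_token_kv, is_cross_attention=False)
    k_cross, _ = attend(cache, 0, encoder_kv, is_cross_attention=True)
print(k_self.shape)   # torch.Size([1, 4, 3, 8]) -- grows with each step
print(k_cross.shape)  # torch.Size([1, 4, 7, 8]) -- fixed after the first pass
```

Because the cache object carries the reuse state itself, the decoder layers no longer need to thread `present_key_value` tuples through their outputs, which is why the `outputs += (present_key_value,)` lines disappear throughout this diff.
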
diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py
index 4055d9148570..b665858862ab 100755
--- a/src/transformers/models/led/modeling_led.py
+++ b/src/transformers/models/led/modeling_led.py
@@ -25,6 +25,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
+from ...cache_utils import Cache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
@@ -765,9 +766,10 @@ def __init__(
self,
embed_dim: int,
num_heads: int,
- dropout: float = 0.0,
- is_decoder: bool = False,
- bias: bool = True,
+ dropout: Optional[float] = 0.0,
+ is_decoder: Optional[bool] = False,
+ bias: Optional[bool] = True,
+ layer_idx: Optional[int] = None,
):
super().__init__()
self.embed_dim = embed_dim
@@ -781,24 +783,23 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
+ self.layer_idx = layer_idx

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[tuple[torch.Tensor]] = None,
+ past_key_value: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
"""Input shape: Batch x Time x Channel"""

# if key_value_states are provided this layer is used as a cross-attention layer
@@ -808,40 +809,44 @@ def forward(
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- if is_cross_attention and past_key_value is not None:
+
+ if past_key_value is not None:
+ if isinstance(past_key_value, EncoderDecoderCache):
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_value.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_value.self_attention_cache
+ else:
+ curr_past_key_value = past_key_value
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
- key_states = past_key_value[0]
- value_states = past_key_value[1]
- elif is_cross_attention:
- # cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ key_states = curr_past_key_value.key_cache[self.layer_idx]
+ 
value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) @@ -964,7 +969,7 @@ def forward( class LEDDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: LEDConfig): + def __init__(self, config: LEDConfig, layer_idx=None): super().__init__() self.embed_dim = config.d_model @@ -973,6 +978,7 @@ def __init__(self, config: LEDConfig): num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -984,6 +990,7 @@ def __init__(self, config: LEDConfig): config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -998,9 +1005,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ): """ Args: @@ -1022,15 +1030,13 @@ def forward( residual = hidden_states # Self-Attention - # decoder uni-directional self-attention 
cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -1042,23 +1048,19 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -1074,7 +1076,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -1629,7 +1631,7 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = Non self.max_target_positions, config.d_model, ) - self.layers = nn.ModuleList([LEDDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([LEDDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1651,6 +1653,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -1744,12 +1747,27 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None @@ -1779,18 +1797,10 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if output_attentions else None - next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1809,8 +1819,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, combined_attention_mask, @@ -1818,16 +1826,13 @@ def forward( encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - if output_attentions: all_self_attns += (layer_outputs[1],) all_cross_attentions += (layer_outputs[2],) @@ -1836,16 +1841,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1901,6 +1908,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], LEDSeq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1986,6 +1994,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -2071,6 +2080,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], LEDSeq2SeqLMOutput]: r""" 
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -2190,6 +2200,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias @@ -2218,17 +2229,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 1c9e02662447..dc58e8f2e1eb 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -542,8 +542,6 @@ class Lfm2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True _supports_static_cache = False _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index d2dd1c751664..deddd105da9e 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -186,7 +186,7 @@ def forward(self, bbox=None, position_ids=None): class LiltSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -221,6 +221,7 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.channel_shrink_ratio = config.channel_shrink_ratio + self.layer_idx = layer_idx def transpose_for_scores(self, x, r=1): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size // r) @@ -338,9 +339,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class LiltAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() - self.self = LiltSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = LiltSelfAttention(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.output = LiltSelfOutput(config) self.pruned_heads = set() @@ -421,11 +422,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class LiltLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = LiltAttention(config) + self.attention = LiltAttention(config, layer_idx=layer_idx) self.intermediate = 
LiltIntermediate(config) self.output = LiltOutput(config) @@ -482,10 +483,10 @@ def layout_feed_forward_chunk(self, attention_output): class LiltEncoder(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertEncoder.__init__ with Bert->Lilt - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([LiltLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([LiltLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3a078129b035..c5a57563639f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -314,8 +314,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index b3781f65c08f..24fc14bc96f0 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -436,8 +436,7 @@ class Llama4PreTrainedModel(PreTrainedModel): _supports_flash_attn = False _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index b39360bb5863..7cd79de12d3c 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -117,10 +117,10 @@ class LlavaPreTrainedModel(PreTrainedModel): base_model_prefix = "" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 249b85a5598d..c019fb275cf3 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -228,10 +228,10 @@ class LlavaNextPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index e2b2effba949..2d73a51a2856 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -169,10 +169,10 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True 
_no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index e82f18b0ed0c..af16955b413c 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -282,10 +282,10 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 537b816b79f8..b0613a2ea543 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1249,7 +1249,7 @@ class LongT5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" supports_gradient_checkpointing = True _no_split_modules = ["LongT5Block"] - _supports_cache_class = True + _supports_static_cache = False # TODO: @raushan more involved due to local/global attn @property @@ -2113,30 +2113,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - assert reordered_layer_past_states[0].shape == layer_past_states[0].shape - assert len(reordered_layer_past_states) == len(layer_past_states) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - @auto_docstring class LongT5EncoderModel(LongT5PreTrainedModel): diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py index bc8a84d17e01..4138cb0b82a9 100644 --- a/src/transformers/models/lxmert/modeling_lxmert.py +++ b/src/transformers/models/lxmert/modeling_lxmert.py @@ -319,22 +319,21 @@ def __init__(self, config, ctx_dim=None): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, context, attention_mask=None, output_attentions=False): - 
mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(context) - mixed_value_layer = self.value(context) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(context).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + value_layer = ( + self.value(context) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 1f33da304f52..cc663aa7432d 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -332,7 +332,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100 @@ -375,7 +375,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -396,12 +396,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100 @@ -475,7 +470,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -492,7 +487,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -516,10 +511,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (past_key_value,) - return outputs @@ -532,7 +523,7 @@ class M2M100PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + # Doesn't support `compile` (dynamic control flow). 
Can be fixed but low usage model _supports_static_cache = False @@ -1111,7 +1102,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if output_attentions else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1154,9 +1144,6 @@ def forward( if skip_the_layer: continue - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) all_cross_attentions += (layer_outputs[2],) @@ -1167,19 +1154,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1464,14 +1450,5 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["M2M100ForConditionalGeneration", "M2M100Model", "M2M100PreTrainedModel"] diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 3782d4bc3dff..94c913ad7a66 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -267,7 +267,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN @@ -310,7 +310,7 @@ def forward( returned tensors for more detail. 
""" residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -412,7 +412,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -429,7 +429,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -455,10 +455,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (past_key_value,) - return outputs @@ -470,7 +466,7 @@ class MarianPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]): @@ -1061,7 +1057,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1093,9 +1088,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1106,19 +1098,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1589,17 +1580,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Marian class MarianDecoderWrapper(MarianPreTrainedModel): @@ -1740,14 +1720,5 @@ def 
forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["MarianForCausalLM", "MarianModel", "MarianMTModel", "MarianPreTrainedModel"] diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 41dba3a25635..9fb5a7469bcc 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -742,15 +742,6 @@ def forward( attentions=encoder_outputs.attentions, ) - # Copied from transformers.models.bert.modeling_bert.BertModel._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 6f3ca84a2485..773101861114 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -277,7 +277,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class MBartEncoderLayer(GradientCheckpointingLayer): @@ -319,7 +319,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -340,12 +340,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights class MBartDecoderLayer(GradientCheckpointingLayer): @@ -418,7 +413,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -435,7 +430,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -460,9 +455,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -500,7 +492,7 @@ class MBartPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): @@ -1103,7 +1095,6 @@ def forward( all_hidden_states = () if 
output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1135,10 +1126,6 @@ def forward( cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1151,19 +1138,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1506,17 +1492,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -1943,15 +1918,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "MBartForCausalLM", diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 5410f6bf0ee1..7ed94107b753 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -43,6 +44,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_megatron_bert import MegatronBertConfig @@ -178,7 +180,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MegatronBert class 
diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
index 5410f6bf0ee1..7ed94107b753 100755
--- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py
+++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
@@ -27,6 +27,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...cache_utils import Cache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
@@ -43,6 +44,7 @@
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import ModelOutput, auto_docstring, logging
+from ...utils.deprecation import deprecate_kwarg
 from .configuration_megatron_bert import MegatronBertConfig
@@ -178,7 +180,7 @@ def forward(

 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MegatronBert
 class MegatronBertSelfAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
+    def __init__(self, config, position_embedding_type=None, layer_idx=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -203,12 +205,9 @@ def __init__(self, config, position_embedding_type=None):
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

         self.is_decoder = config.is_decoder
+        self.layer_idx = layer_idx

-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
+    @deprecate_kwarg("encoder_attention_mask", version="4.55.0")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -216,53 +215,65 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
-        mixed_query_layer = self.query(hidden_states)
+        batch_size, seq_length, _ = hidden_states.shape
+        query_layer = self.query(hidden_states)
+        query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+            1, 2
+        )

         # If this is instantiated as a cross-attention module, the keys
         # and values come from an encoder; the attention mask needs to be
         # such that the encoder's padding tokens are not attended to.
         is_cross_attention = encoder_hidden_states is not None
+        if is_cross_attention and encoder_attention_mask is not None:
+            attention_mask = encoder_attention_mask
+
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated = past_key_value.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
+                    curr_past_key_value = past_key_value.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_value.self_attention_cache
+            else:
+                curr_past_key_value = past_key_value

-        if is_cross_attention and past_key_value is not None:
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_layer = past_key_value[0]
-            value_layer = past_key_value[1]
-            attention_mask = encoder_attention_mask
-        elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
-            attention_mask = encoder_attention_mask
-        elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+            key_layer = curr_past_key_value.key_cache[self.layer_idx]
+            value_layer = curr_past_key_value.value_cache[self.layer_idx]
         else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-
-        use_cache = past_key_value is not None
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_layer, value_layer)
+            key_layer = self.key(current_states)
+            key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+                1, 2
+            )
+            value_layer = self.value(current_states)
+            value_layer = value_layer.view(
+                batch_size, -1, self.num_attention_heads, self.attention_head_size
+            ).transpose(1, 2)
+
+            if past_key_value is not None:
+                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+                cache_position = cache_position if not is_cross_attention else None
+                key_layer, value_layer = curr_past_key_value.update(
+                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
+                )
+                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+                if is_cross_attention:
+                    past_key_value.is_updated[self.layer_idx] = True

         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             query_length, key_length = query_layer.shape[2], key_layer.shape[2]
-            if use_cache:
+            if past_key_value is not None:
                 position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                     -1, 1
                 )
@@ -304,11 +315,7 @@ def forward(
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)

-        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
-
-        if self.is_decoder:
-            outputs = outputs + (past_key_value,)
-        return outputs
+        return context_layer, attention_probs


 # Based transformers.models.bert.modeling_bert.BertSelfOutput. Moved LayerNorm to MegatronBertAttention below.
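The rewritten self-attention above distinguishes two streams inside one `EncoderDecoderCache`: a growing self-attention cache and a cross-attention cache that is filled once and then only read, guarded by the `is_updated` flag. A toy sketch of that pattern outside any module; the shapes are illustrative and stand in for the real key/value projections:

```py
import torch
from transformers import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
layer_idx = 0

# First decoding step: project the encoder states once and store them.
if not cache.is_updated.get(layer_idx, False):
    k = torch.randn(1, 8, 10, 64)  # [batch, heads, encoder_len, head_dim]
    v = torch.randn(1, 8, 10, 64)
    k, v = cache.cross_attention_cache.update(k, v, layer_idx)
    cache.is_updated[layer_idx] = True

# Every later step: reuse the stored projections instead of recomputing them.
k = cache.cross_attention_cache.key_cache[layer_idx]
v = cache.cross_attention_cache.value_cache[layer_idx]
```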
@@ -326,10 +333,10 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.

 # Based transformers.models.bert.modeling_bert.BertAttention. Added LayerNorm.
 class MegatronBertAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
         self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.self = MegatronBertSelfAttention(config)
+        self.self = MegatronBertSelfAttention(config, layer_idx=layer_idx)
         self.output = MegatronBertSelfOutput(config)
         self.pruned_heads = set()
@@ -357,19 +364,19 @@ def forward(
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor]:
         ln_outputs = self.ln(hidden_states)
         self_outputs = self.self(
             ln_outputs,
-            attention_mask,
-            head_mask,
-            encoder_hidden_states,
-            encoder_attention_mask,
-            past_key_value,
-            output_attentions,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
@@ -407,17 +414,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to

 # Based on transformers.models.bert.modeling_bert.BertLayer. Added LayerNorm.
 class MegatronBertLayer(GradientCheckpointingLayer):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
-        self.attention = MegatronBertAttention(config)
+        self.attention = MegatronBertAttention(config, layer_idx=layer_idx)
         self.is_decoder = config.is_decoder
         self.add_cross_attention = config.add_cross_attention
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise TypeError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = MegatronBertAttention(config)
+            self.crossattention = MegatronBertAttention(config, layer_idx=layer_idx)
         self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.intermediate = MegatronBertIntermediate(config)
         self.output = MegatronBertOutput(config)
@@ -429,28 +436,22 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor]:
-        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
             hidden_states,
-            attention_mask,
-            head_mask,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
             output_attentions=output_attentions,
-            past_key_value=self_attn_past_key_value,
+            past_key_value=past_key_value,
+            cache_position=cache_position,
         )
         attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

-        # if decoder, the last output is tuple of self-attn cache
-        if self.is_decoder:
-            outputs = self_attention_outputs[1:-1]
-            present_key_value = self_attention_outputs[-1]
-        else:
-            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
-
-        cross_attn_present_key_value = None
         if self.is_decoder and encoder_hidden_states is not None:
             if not hasattr(self, "crossattention"):
                 raise AttributeError(
@@ -458,34 +459,22 @@ def forward(
                     " by setting `config.add_cross_attention=True`"
                 )

-            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
-            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
             cross_attention_outputs = self.crossattention(
                 attention_output,
-                attention_mask,
-                head_mask,
-                encoder_hidden_states,
-                encoder_attention_mask,
-                cross_attn_past_key_value,
-                output_attentions,
+                attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
             )
             attention_output = cross_attention_outputs[0]
-            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
-
-            # add cross-attn cache to positions 3,4 of present_key_value tuple
-            cross_attn_present_key_value = cross_attention_outputs[-1]
-            present_key_value = present_key_value + cross_attn_present_key_value
+            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

         layer_output = apply_chunking_to_forward(
             self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
         )
-        outputs = (layer_output,) + outputs
-
-        # if decoder, return the attn key/values as the last output
-        if self.is_decoder:
-            outputs = outputs + (present_key_value,)
-
-        return outputs
+        return (layer_output,) + outputs

     def feed_forward_chunk(self, attention_output):
         ln_output = self.ln(attention_output)
@@ -498,7 +487,7 @@ class MegatronBertEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.layer = nn.ModuleList([MegatronBertLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([MegatronBertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])

         # The final layer norm. We removed the 1st LN, moved LN to each hidden layer and this one
         # is simply the final LN (Transformer's BERT has it attached to each hidden layer).
@@ -517,6 +506,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
         return_dict: Optional[bool] = True,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -524,17 +514,25 @@ def forward(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+            )
+            return_legacy_cache = True
+            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
-        next_decoder_cache = () if use_cache else None

         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             layer_head_mask = head_mask[i] if head_mask is not None else None
-            past_key_value = past_key_values[i] if past_key_values is not None else None

             layer_outputs = layer_module(
                 hidden_states,
@@ -542,16 +540,15 @@ def forward(
                 layer_head_mask,
                 encoder_hidden_states,
                 encoder_attention_mask,
-                past_key_value,
+                past_key_values,
                 output_attentions,
+                cache_position,
             )

             # Because we moved the layer-norm at the end of the hidden layer, we have non-normali-
             # zed data here. If that's really needed, we must apply LN to match Transformer's BERT.
             hidden_states = layer_outputs[0]
-            if use_cache:
-                next_decoder_cache += (layer_outputs[-1],)
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
                 if self.config.add_cross_attention:
@@ -563,12 +560,15 @@ def forward(
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

+        if return_legacy_cache:
+            past_key_values = past_key_values.to_legacy_cache()
+
         if not return_dict:
             return tuple(
                 v
                 for v in [
                     hidden_states,
-                    next_decoder_cache,
+                    past_key_values,
                     all_hidden_states,
                     all_self_attentions,
                     all_cross_attentions,
@@ -577,7 +577,7 @@ def forward(
             )
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
+            past_key_values=past_key_values,
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
             cross_attentions=all_cross_attentions,
@@ -785,6 +785,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -810,8 +811,13 @@ def forward(
         batch_size, seq_length = input_shape
         device = input_ids.device if input_ids is not None else inputs_embeds.device

-        # past_key_values_length
-        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        past_key_values_length = 0
+        if past_key_values is not None:
+            past_key_values_length = (
+                past_key_values[0][0].shape[-2]
+                if not isinstance(past_key_values, Cache)
+                else past_key_values.get_seq_length()
+            )

         if attention_mask is None:
             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
@@ -858,6 +864,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
@@ -1027,6 +1034,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
         r"""
@@ -1067,6 +1075,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )

         sequence_output = outputs[0]
@@ -1094,14 +1103,6 @@ def forward(
             cross_attentions=outputs.cross_attentions,
         )

-    def _reorder_cache(self, past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
-

 @auto_docstring
 class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py
index 3ca8fe2b3918..6afc4fdb91e3 100644
--- a/src/transformers/models/mimi/modeling_mimi.py
+++ b/src/transformers/models/mimi/modeling_mimi.py
@@ -1375,7 +1375,7 @@ class MimiPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_cache_class = True
+    _supports_static_cache = True

     def _init_weights(self, module):
diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py
index a5ad9c68ad02..057f8d809781 100644
--- a/src/transformers/models/minimax/modeling_minimax.py
+++ b/src/transformers/models/minimax/modeling_minimax.py
@@ -588,8 +588,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _supports_cache_class = True  # Note: only supports MiniMaxCache
-    _supports_quantized_cache = False
+    # Note: only supports MiniMaxCache
     _supports_static_cache = False
     _supports_attention_backend = True
     _can_record_outputs = {
diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py
index 9e2ee0b17b03..423ae27717c4 100644
--- a/src/transformers/models/minimax/modular_minimax.py
+++ b/src/transformers/models/minimax/modular_minimax.py
@@ -470,9 +470,8 @@ def forward(

 class MiniMaxPreTrainedModel(MixtralPreTrainedModel):
-    _supports_cache_class = True  # Note: only supports MiniMaxCache
+    # Note: only supports MiniMaxCache
     _supports_static_cache = False
-    _supports_quantized_cache = False
     _can_record_outputs = {
         "router_logits": OutputRecorder(MiniMaxSparseMoeBlock, index=1),
         "hidden_states": MiniMaxDecoderLayer,
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index cc03b14c553d..a6fbfbb4898a 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -259,8 +259,7 @@ class MistralPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _supports_cache_class = True
-    _supports_quantized_cache = True
+    _supports_static_cache = True
     _supports_attention_backend = True
     _can_record_outputs = {
diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py
index 8abcab6d3647..20e81a8404d1 100644
--- a/src/transformers/models/mistral3/modeling_mistral3.py
+++ b/src/transformers/models/mistral3/modeling_mistral3.py
@@ -182,10 +182,10 @@ class Mistral3PreTrainedModel(PreTrainedModel):
     base_model_prefix = ""
     supports_gradient_checkpointing = True
     _skip_keys_device_placement = "past_key_values"
-    _supports_cache_class = True
+
     _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_quantized_cache = True
+    _supports_static_cache = True
     _supports_flex_attn = True
     _supports_attention_backend = True
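Across these files the per-model `_supports_cache_class` and `_supports_quantized_cache` flags are dropped, leaving `_supports_static_cache` as the remaining opt-in. A sketch of how that flag is typically exercised from user code; the checkpoint name is just one of the models touched in this patch, and the snippet is illustrative rather than a prescribed workflow:

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

inputs = tokenizer("The capital of France is", return_tensors="pt")
# `cache_implementation="static"` pre-allocates the KV cache to a fixed shape,
# which is what makes the decoding step compilable with `torch.compile`.
out = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
```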
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index ac64ace04683..cc78dfecd5e4 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -388,8 +388,6 @@ class MixtralPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _supports_cache_class = True
-    _supports_quantized_cache = True
     _supports_static_cache = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
     _supports_attention_backend = True
     _can_record_outputs = {
diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py
index 0c418e70a910..fcfa4a3db72e 100644
--- a/src/transformers/models/mllama/modeling_mllama.py
+++ b/src/transformers/models/mllama/modeling_mllama.py
@@ -849,11 +849,10 @@ class MllamaPreTrainedModel(PreTrainedModel):
         "MllamaCrossAttentionDecoderLayer",
         "MllamaSelfAttentionDecoderLayer",
     ]
-    _supports_cache_class = True
+    _supports_static_cache = False  # static cache cannot have different shapes for each layer
     _supports_sdpa = True
     _supports_flash_attn = True
-    _supports_quantized_cache = True
     _supports_flex_attn = True
     _supports_attention_backend = True

@@ -1603,7 +1602,6 @@ def forward(
 )
 class MllamaModel(MllamaPreTrainedModel):
     _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
-    _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting

     def __init__(self, config: MllamaConfig):
         super().__init__(config)
@@ -1763,7 +1761,6 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         "^multi_modal_projector": "model.multi_modal_projector",
         "^language_model.lm_head": "lm_head",
     }
-    _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config: MllamaConfig):
diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py
index b1c267c959bc..91508d099711 100644
--- a/src/transformers/models/mobilebert/modeling_mobilebert.py
+++ b/src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -231,11 +231,6 @@ def __init__(self, config):
         )
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

-    def transpose_for_scores(self, x):
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(
         self,
         query_tensor: torch.Tensor,
@@ -245,13 +240,22 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
     ) -> tuple[torch.Tensor]:
-        mixed_query_layer = self.query(query_tensor)
-        mixed_key_layer = self.key(key_tensor)
-        mixed_value_layer = self.value(value_tensor)
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-        key_layer = self.transpose_for_scores(mixed_key_layer)
-        value_layer = self.transpose_for_scores(mixed_value_layer)
+        batch_size, seq_length, _ = query_tensor.shape
+        query_layer = (
+            self.query(query_tensor)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(key_tensor)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(value_tensor)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )

         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py
index 1b483fe958c0..3f882b9850ff 100755
--- a/src/transformers/models/mobilevit/modeling_mobilevit.py
+++ b/src/transformers/models/mobilevit/modeling_mobilevit.py
@@ -211,17 +211,23 @@ def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        mixed_query_layer = self.query(hidden_states)
-
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(mixed_query_layer)
+        batch_size, seq_length, _ = hidden_states.shape
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )

         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
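The repeated `transpose_for_scores` helpers removed here are replaced by an inline `view(...).transpose(1, 2)`. The two forms produce identical layouts, which the following standalone check demonstrates (toy dimensions):

```py
import torch

batch, seq, heads, head_dim = 2, 5, 4, 8
hidden = torch.randn(batch, seq, heads * head_dim)

# Old helper: view to [batch, seq, heads, head_dim], then permute(0, 2, 1, 3).
old = hidden.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3)
# Inlined form used in the diff: view with -1 for seq, then transpose(1, 2).
new = hidden.view(batch, -1, heads, head_dim).transpose(1, 2)

assert torch.equal(old, new)  # same [batch, heads, seq, head_dim] layout
```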
diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
index a997f40e916a..011db51daac4 100644
--- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
@@ -224,8 +224,6 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = False
     _supports_gradient_checkpointing = True
-    _supports_cache_class = True
-    _supports_quantized_cache = True
     _supports_static_cache = False
     _supports_attention_backend = True
     _can_record_outputs = {
@@ -422,11 +420,10 @@ def forward(
         **kwargs,
     ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
-        Args:
-            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

         Returns: [`~modeling_outputs.CausalLMOutputWithPast`]
@@ -484,15 +481,6 @@ def forward(
             attentions=outputs.attentions,
         )

-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
-

 @auto_docstring(
     custom_intro="""
diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
index 46a3674af00c..7609c2f1febf 100644
--- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
@@ -401,8 +401,6 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = False
     _supports_gradient_checkpointing = True
-    _supports_cache_class = True
-    _supports_quantized_cache = True
     _supports_static_cache = False
     _supports_attention_backend = True
     _can_record_outputs = {
@@ -599,11 +597,10 @@ def forward(
         **kwargs,
     ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
-        Args:
-            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

         Returns: [`~modeling_outputs.CausalLMOutputWithPast`]
@@ -661,15 +658,6 @@ def forward(
             attentions=outputs.attentions,
         )

-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
-

 @auto_docstring(
     custom_intro="""
diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py
index 33d0a75f3722..9c2641d978a1 100644
--- a/src/transformers/models/moonshine/modeling_moonshine.py
+++ b/src/transformers/models/moonshine/modeling_moonshine.py
@@ -461,7 +461,7 @@ class MoonshinePreTrainedModel(PreTrainedModel):
     _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"]
     _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_cache_class = True
+    _supports_static_cache = True

     # TODO arthur, how do we separate when it cross / self coming from different layer?
diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py
index 5e56fee5e018..4e2882fb81dc 100644
--- a/src/transformers/models/moonshine/modular_moonshine.py
+++ b/src/transformers/models/moonshine/modular_moonshine.py
@@ -496,7 +496,7 @@ class MoonshinePreTrainedModel(PreTrainedModel):
     _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"]
     _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_cache_class = True
+    _supports_static_cache = True

     # TODO arthur, how do we separate when it cross / self coming from different layer?
diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index 3b09eba5e0d9..4cde9816bf79 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -807,7 +807,7 @@ class MoshiPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["MoshiDecoderLayer", "MimiTransformerLayer"]
     _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_cache_class = True
+
     main_input_name = "input_ids"

     def _init_weights(self, module):
@@ -2526,19 +2526,5 @@ def _check_and_maybe_initialize_inputs(

         return input_ids, user_audio_codes, moshi_audio_codes, concat_unconditional_inputs

-    @staticmethod
-    def _reorder_cache(
-        past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> tuple[tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-        beam_idx at every generation step.
-        """
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past_key_values
-        )
-

 __all__ = ["MoshiForCausalLM", "MoshiForConditionalGeneration", "MoshiModel", "MoshiPreTrainedModel"]
diff --git a/src/transformers/models/mpnet/modeling_mpnet.py b/src/transformers/models/mpnet/modeling_mpnet.py
index efa74f191f13..82698f8ecfb2 100644
--- a/src/transformers/models/mpnet/modeling_mpnet.py
+++ b/src/transformers/models/mpnet/modeling_mpnet.py
@@ -144,11 +144,6 @@ def __init__(self, config):

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

-    def transpose_for_scores(self, x):
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(
         self,
         hidden_states,
@@ -158,13 +153,22 @@ def forward(
         output_attentions=False,
         **kwargs,
     ):
-        q = self.q(hidden_states)
-        k = self.k(hidden_states)
-        v = self.v(hidden_states)
-
-        q = self.transpose_for_scores(q)
-        k = self.transpose_for_scores(k)
-        v = self.transpose_for_scores(v)
+        batch_size, seq_length, _ = hidden_states.shape
+        q = (
+            self.q(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        k = (
+            self.k(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        v = (
+            self.v(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )

         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(q, k.transpose(-1, -2))
diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py
index b3005728e7b9..81680bef7950 100644
--- a/src/transformers/models/mpt/modeling_mpt.py
+++ b/src/transformers/models/mpt/modeling_mpt.py
@@ -23,6 +23,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
 from torch.nn import functional as F

+from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from ...modeling_layers import GradientCheckpointingLayer
@@ -69,7 +70,7 @@ class MptAttention(nn.Module):
     Using torch or triton attention implementation enables user to also use additive bias.
     """

-    def __init__(self, config: MptConfig):
+    def __init__(self, config: MptConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.n_heads = config.n_heads
@@ -83,13 +84,15 @@ def __init__(self, config: MptConfig):
         self.clip_qkv = config.attn_config.clip_qkv
         self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
         self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.layer_idx = layer_idx

     def forward(
         self,
         hidden_states: torch.Tensor,
         position_bias: torch.Tensor,
-        past_key_value: Optional[tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
     ):
         batch_size, seq_length = hidden_states.shape[:2]

@@ -103,16 +106,11 @@ def forward(
         value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)

         if past_key_value is not None:
-            if len(past_key_value) != 0:
-                key_states = torch.cat([past_key_value[0], key_states], dim=2)
-                value_states = torch.cat([past_key_value[1], value_states], dim=2)
-                past_key_value = (key_states, value_states)
-            else:
-                past_key_value = (key_states, value_states)
+            cache_kwargs = {"cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
-
-        query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
+        query_length = seq_length if past_key_value is None else seq_length + past_key_value.get_seq_length()

         if position_bias is not None:
             if len(position_bias.shape) != 3:
@@ -137,7 +135,7 @@ def forward(
         context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
         attn_output = self.out_proj(context_states)

-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights


 class MptMLP(nn.Module):
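In the `MptAttention` rewrite just above, the manual `torch.cat` bookkeeping becomes a single `Cache.update` call, which appends the new keys/values for the given layer and returns the full sequence. A self-contained sketch of that call pattern with toy shapes:

```py
import torch
from transformers import DynamicCache

cache = DynamicCache()
layer_idx = 0

# Prefill: 5 tokens' worth of keys/values enter the cache.
k = torch.randn(1, 4, 5, 16)  # [batch, heads, seq_len, head_dim]
v = torch.randn(1, 4, 5, 16)
k_all, v_all = cache.update(k, v, layer_idx, {"cache_position": torch.arange(5)})

# Decode: one new token is appended; `update` returns the full sequence so far.
k1 = torch.randn(1, 4, 1, 16)
v1 = torch.randn(1, 4, 1, 16)
k_all, v_all = cache.update(k1, v1, layer_idx, {"cache_position": torch.tensor([5])})
assert k_all.shape[-2] == 6
```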
@@ -162,7 +160,7 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.

 class MptBlock(GradientCheckpointingLayer):
-    def __init__(self, config: MptConfig):
+    def __init__(self, config: MptConfig, layer_idx: Optional[int] = None):
         super().__init__()
         hidden_size = config.hidden_size
@@ -171,7 +169,7 @@ def __init__(self, config: MptConfig):
         self.norm_1.bias = None

         self.num_heads = config.n_heads
-        self.attn = MptAttention(config)
+        self.attn = MptAttention(config, layer_idx)

         self.norm_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         # backward compatibility with weights on the Hub
@@ -187,9 +185,10 @@ def forward(
         hidden_states: torch.Tensor,
         position_bias: torch.Tensor,
         attention_mask: torch.Tensor,
-        layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        layer_past: Optional[Cache] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
     ):
         # hidden_states: [batch_size, seq_length, hidden_size]
         # Layer norm at the beginning of the transformer layer.
@@ -198,11 +197,12 @@ def forward(
         residual = hidden_states

         # Self attention.
-        attn_outputs, attn_weights, past_key_value = self.attn(
+        attn_outputs, attn_weights = self.attn(
             layernorm_output,
             position_bias=position_bias,
             attention_mask=attention_mask,
             past_key_value=layer_past,
+            cache_position=cache_position,
         )

         hidden_states = self.resid_attn_dropout(attn_outputs) + residual
@@ -214,15 +214,7 @@ def forward(
         # MLP.
         output = self.ffn(layernorm_output, residual)
-        outputs = (output,)
-
-        if use_cache:
-            outputs += (past_key_value,)
-
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs  # hidden_states, present, attentions
+        return output, attn_weights


 @auto_docstring
@@ -285,7 +277,7 @@ def __init__(self, config: MptConfig):
         self.wte = nn.Embedding(config.vocab_size, self.hidden_size)

         # Transformer blocks
-        self.blocks = nn.ModuleList([MptBlock(config) for _ in range(config.n_layers)])
+        self.blocks = nn.ModuleList([MptBlock(config, layer_idx=i) for i in range(config.n_layers)])

         # Final Layer Norm
         self.norm_f = LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
@@ -310,18 +302,19 @@ def set_input_embeddings(self, new_embeddings: torch.Tensor):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        past_key_values: Optional[Union[tuple[tuple[torch.Tensor, torch.Tensor], ...], Cache]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs,  # NOOP kwargs, for now
     ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
-            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
             (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

             If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
@@ -348,31 +341,34 @@ def forward(
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if past_key_values is None:
-            past_key_values = tuple([None] * len(self.blocks))
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False

         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)

+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `DynamicCache` instead, e.g. "
+                "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
+            )
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
         hidden_states = inputs_embeds

-        presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None

-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # Compute alibi tensor: check build_alibi_tensor documentation
-        seq_length_with_past = seq_length
-        past_key_values_length = 0
-        if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        seq_length_with_past = seq_length + past_key_values_length
         if attention_mask is None:
             attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         else:
@@ -385,38 +381,41 @@ def forward(
         )
         causal_mask = causal_mask.bool()

-        for block, layer_past in zip(self.blocks, past_key_values):
+        for block in self.blocks:
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             outputs = block(
                 hidden_states,
-                layer_past=layer_past,
+                layer_past=past_key_values,
                 attention_mask=causal_mask,
                 use_cache=use_cache,
                 output_attentions=output_attentions,
                 position_bias=alibi,
+                cache_position=cache_position,
             )

             hidden_states = outputs[0]
-            if use_cache is True:
-                presents = presents + (outputs[1],)
-
             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                all_self_attentions = all_self_attentions + (outputs[1],)

         # Add last hidden state
         hidden_states = self.norm_f(hidden_states)

+        if return_legacy_cache:
+            past_key_values = past_key_values.to_legacy_cache()
+
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

         if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+            return tuple(
+                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
+            )

         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=presents,
+            past_key_values=past_key_values,
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
         )
@@ -457,11 +456,12 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
-            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
             (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

             If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
@@ -487,6 +487,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )
         hidden_states = transformer_outputs[0]
@@ -516,29 +517,6 @@ def forward(
             attentions=transformer_outputs.attentions,
         )

-    def _reorder_cache(
-        self, past: tuple[tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
-    ) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-        beam_idx at every generation step.
-
-        Output shares the same memory storage as `past`.
-        """
-        # Get a copy of `beam_idx` on all the devices where we need those indices.
-        device_to_beam_idx = {
-            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
-        }
-        reordered_past = tuple(
-            (
-                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
-                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
-            )
-            for layer_past in past
-        )
-        return reordered_past
-

 @auto_docstring(
     custom_intro="""
@@ -579,7 +557,7 @@ def forward(
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
-            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
             (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

             If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
@@ -703,7 +681,7 @@ def forward(
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
-            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

             If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
@@ -781,7 +759,7 @@ def forward(
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
-            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
             (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

             If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
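`MptModel.forward` above keeps old callers working by wrapping tuple caches on the way in and unwrapping on the way out. A condensed sketch of that compatibility shim as a standalone helper; `maybe_wrap_legacy` is a hypothetical name for illustration only:

```py
import torch
from transformers import Cache, DynamicCache


def maybe_wrap_legacy(past_key_values, use_cache=True):
    # Mirrors the pattern added to MptModel.forward: wrap tuples on the way in,
    # remember to unwrap on the way out so old callers still receive tuples.
    return_legacy = use_cache and not isinstance(past_key_values, Cache)
    if return_legacy:
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values, return_legacy


legacy = tuple((torch.randn(1, 4, 3, 16), torch.randn(1, 4, 3, 16)) for _ in range(2))
cache, return_legacy = maybe_wrap_legacy(legacy)
# ... run the decoder layers against `cache` ...
if return_legacy:
    cache = cache.to_legacy_cache()
```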
diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py
index a7fd783d848d..159299aa3053 100644
--- a/src/transformers/models/mra/modeling_mra.py
+++ b/src/transformers/models/mra/modeling_mra.py
@@ -555,32 +555,39 @@ def __init__(self, config, position_embedding_type=None):
         self.initial_prior_first_n_blocks = config.initial_prior_first_n_blocks
         self.initial_prior_diagonal_n_blocks = config.initial_prior_diagonal_n_blocks

-    def transpose_for_scores(self, layer):
-        new_layer_shape = layer.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        layer = layer.view(*new_layer_shape)
-        return layer.permute(0, 2, 1, 3)
-
     def forward(self, hidden_states, attention_mask=None):
-        mixed_query_layer = self.query(hidden_states)
-
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-
-        batch_size, num_heads, seq_len, head_dim = query_layer.size()
+        batch_size, seq_len, _ = hidden_states.shape
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )

         # revert changes made by get_extended_attention_mask
         attention_mask = 1.0 + attention_mask / 10000.0
         attention_mask = (
-            attention_mask.squeeze().repeat(1, num_heads, 1).reshape(batch_size * num_heads, seq_len).int()
+            attention_mask.squeeze()
+            .repeat(1, self.num_attention_heads, 1)
+            .reshape(batch_size * self.num_attention_heads, seq_len)
+            .int()
         )

         # The CUDA kernels are most efficient with inputs whose size is a multiple of a GPU's warp size (32). Inputs
         # smaller than this are padded with zeros.
         gpu_warp_size = 32

-        if head_dim < gpu_warp_size:
-            pad_size = batch_size, num_heads, seq_len, gpu_warp_size - head_dim
+        if self.attention_head_size < gpu_warp_size:
+            pad_size = batch_size, self.num_attention_heads, seq_len, gpu_warp_size - self.attention_head_size

             query_layer = torch.cat([query_layer, torch.zeros(pad_size, device=query_layer.device)], dim=-1)
             key_layer = torch.cat([key_layer, torch.zeros(pad_size, device=key_layer.device)], dim=-1)
@@ -597,10 +604,10 @@ def forward(self, hidden_states, attention_mask=None):
             initial_prior_diagonal_n_blocks=self.initial_prior_diagonal_n_blocks,
         )

-        if head_dim < gpu_warp_size:
-            context_layer = context_layer[:, :, :, :head_dim]
+        if self.attention_head_size < gpu_warp_size:
+            context_layer = context_layer[:, :, :, : self.attention_head_size]

-        context_layer = context_layer.reshape(batch_size, num_heads, seq_len, head_dim)
+        context_layer = context_layer.reshape(batch_size, self.num_attention_heads, seq_len, self.attention_head_size)

         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py
index 0341c3afcfde..87dcd1d41f0e 100644
--- a/src/transformers/models/mt5/modeling_mt5.py
+++ b/src/transformers/models/mt5/modeling_mt5.py
@@ -757,9 +757,8 @@ class MT5PreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     is_parallelizable = True
     supports_gradient_checkpointing = True
-    _supports_quantized_cache = False  # enc-dec models don't support yet
     _supports_static_cache = True
-    _supports_cache_class = True
+
     _no_split_modules = ["MT5Block"]
     _keep_in_fp32_modules = ["wo"]

@@ -1857,37 +1856,6 @@ def forward(
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return self._shift_right(labels)

-    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache
-    def _reorder_cache(self, past_key_values, beam_idx):
-        # if decoder past is not included in output
-        # speedy decoding is disabled and no need to reorder
-        if past_key_values is None:
-            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
-            return past_key_values
-
-        reordered_decoder_past = ()
-        for layer_past_states in past_key_values:
-            # get the correct batch idx from layer past batch dim
-            # batch dim of `past` is at 2nd position
-            reordered_layer_past_states = ()
-            for layer_past_state in layer_past_states:
-                # need to set correct `past` for each of the four key / value states
-                reordered_layer_past_states = reordered_layer_past_states + (
-                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
-                )
-
-            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
-                raise ValueError(
-                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
-                )
-            if len(reordered_layer_past_states) != len(layer_past_states):
-                raise ValueError(
-                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
-                )
-
-            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
-        return reordered_decoder_past
-

 @auto_docstring
 class MT5EncoderModel(MT5PreTrainedModel):
2062a2333cc0..b7f092ea64a8 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import ( ClassifierFreeGuidanceLogitsProcessor, GenerationConfig, @@ -189,11 +190,12 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_causal: bool = False, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + is_causal: Optional[bool] = False, config: Optional[MusicgenConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -210,6 +212,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -220,10 +223,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -244,42 +248,35 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = 
curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -301,11 +298,11 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class MusicgenDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: MusicgenDecoderConfig): + def __init__(self, config: MusicgenDecoderConfig, layer_idx=None): super().__init__() self.embed_dim = config.hidden_size @@ -317,6 +314,7 @@ def __init__(self, config: MusicgenDecoderConfig): bias=False, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -330,6 +328,7 @@ def __init__(self, config: MusicgenDecoderConfig): is_decoder=True, bias=False, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False) @@ -346,9 +345,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -372,42 +372,35 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = 
self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -421,10 +414,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (present_key_value,) - return outputs @@ -477,7 +466,9 @@ def __init__(self, config: MusicgenDecoderConfig): config.hidden_size, ) - self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [MusicgenDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) self.layer_norm = nn.LayerNorm(config.hidden_size) self.attn_implementation = config._attn_implementation @@ -506,6 +497,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): @@ -565,8 +557,24 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
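For readers tracking the new compatibility path: below is a minimal sketch of what `EncoderDecoderCache.from_legacy_cache` does with the legacy tuple-of-tuples format, and how `get_seq_length()` replaces the old `past_key_values[0][0].shape[2]` lookup. Layer count and tensor shapes are illustrative, not taken from a real checkpoint.

```py
# Minimal sketch of the legacy-to-Cache conversion performed above.
import torch
from transformers import EncoderDecoderCache

batch, heads, seq, head_dim = 1, 4, 5, 8
legacy = tuple(
    (
        torch.zeros(batch, heads, seq, head_dim),  # self-attn keys
        torch.zeros(batch, heads, seq, head_dim),  # self-attn values
        torch.zeros(batch, heads, seq, head_dim),  # cross-attn keys
        torch.zeros(batch, heads, seq, head_dim),  # cross-attn values
    )
    for _ in range(2)  # two decoder layers, purely illustrative
)

cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(cache.get_seq_length())  # 5, replacing past_key_values[0][0].shape[2]
legacy_again = cache.to_legacy_cache()  # what the return_legacy_cache path restores
```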
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) @@ -586,23 +594,13 @@ def forward( # embed positions positions = self.embed_positions(input, past_key_values_length) - hidden_states = inputs_embeds + positions.to(inputs_embeds.device) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -620,8 +618,6 @@ def forward( if self.training and (dropout_probability < self.layerdrop): continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask, @@ -629,15 +625,12 @@ def forward( encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -650,16 +643,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -772,6 +767,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): @@ -831,6 +827,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -898,6 +895,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: 
Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: r""" @@ -960,6 +958,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1875,6 +1874,7 @@ def prepare_inputs_for_generation( encoder_outputs=None, decoder_delay_pattern_mask=None, guidance_scale=None, + cache_position=None, **kwargs, ): # Overwritten -- MusicGen has custom processing @@ -1896,16 +1896,15 @@ def prepare_inputs_for_generation( decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length + if cache_position[-1] >= decoder_input_ids.shape[1]: + decoder_input_ids = decoder_input_ids[:, -cache_position.shape[0] :] + elif ( + decoder_input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + decoder_input_ids = decoder_input_ids[:, cache_position] else: # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + decoder_input_ids = decoder_input_ids[:, -1:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 0111ba51d6a8..e415967b0a90 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import ( ClassifierFreeGuidanceLogitsProcessor, GenerationConfig, @@ -197,11 +198,12 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_causal: bool = False, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + is_causal: Optional[bool] = False, config: Optional[MusicgenMelodyConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -218,6 +220,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -228,10 +231,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -252,42 +256,35 @@ def forward( # get query proj query_states = 
self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
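The musicgen_melody hunks repeat the `is_updated` bookkeeping introduced above. A minimal sketch of that contract, assuming the `EncoderDecoderCache`/`DynamicCache` APIs from `cache_utils` and illustrative shapes:

```py
# Sketch of the is_updated contract used by the cross-attention branches:
# encoder K/V are projected once, then re-read on every later decoding step.
import torch
from transformers import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
layer_idx = 0
encoder_k = torch.randn(1, 4, 7, 8)  # (batch, heads, encoder_seq, head_dim)
encoder_v = torch.randn(1, 4, 7, 8)

if not cache.is_updated.get(layer_idx):
    cache.cross_attention_cache.update(encoder_k, encoder_v, layer_idx)
    cache.is_updated[layer_idx] = True

# later steps skip the projection and read the cached tensors directly
key_states = cache.cross_attention_cache.key_cache[layer_idx]
print(key_states.shape)  # torch.Size([1, 4, 7, 8])
```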
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -309,11 +306,11 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class MusicgenMelodyDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: MusicgenMelodyDecoderConfig): + def __init__(self, config: MusicgenMelodyDecoderConfig, layer_idx=None): super().__init__() self.embed_dim = config.hidden_size @@ -325,6 +322,7 @@ def __init__(self, config: MusicgenMelodyDecoderConfig): bias=False, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -341,9 +339,10 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -360,15 +359,13 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -381,16 +378,7 @@ def forward( hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs + return hidden_states, self_attn_weights @auto_docstring @@ -444,7 +432,9 @@ def __init__(self, config: MusicgenMelodyDecoderConfig): 
config.hidden_size, ) - self.layers = nn.ModuleList([MusicgenMelodyDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [MusicgenMelodyDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) self.layer_norm = nn.LayerNorm(config.hidden_size) self.attn_implementation = config._attn_implementation @@ -473,6 +463,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): @@ -526,9 +517,24 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) @@ -556,22 +562,12 @@ def forward( # embed positions positions = self.embed_positions(inputs_embeds, past_key_values_length) - hidden_states = inputs_embeds + positions.to(inputs_embeds.device) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - next_decoder_cache = () if use_cache else None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: @@ -589,21 +585,16 @@ def forward( if self.training and (dropout_probability < self.layerdrop): continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - if output_attentions: all_attentions += (layer_outputs[1],) @@ -613,12 +604,16 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None) + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_attentions, ) @@ -708,6 +703,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): @@ -760,6 +756,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -827,6 +824,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, MusicgenMelodyOutputWithPast]: r""" input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): @@ -881,6 +879,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1793,7 +1792,7 @@ def prepare_inputs_for_generation( decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if decoder_input_ids.shape[1] > past_length: diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 9e5136d27ad0..1223d23fba20 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from 
...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, @@ -97,9 +98,10 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -114,24 +116,23 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, attn_prompt: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -143,35 +144,38 @@ def forward( # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states.
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True if attn_prompt is not None: key_states = torch.cat([attn_prompt[0].expand(bsz, -1, -1, -1), key_states], dim=2) @@ -181,9 +185,10 @@ def forward( attention_mask = torch.cat([prompt_mask, attention_mask], dim=(-1)) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) @@ -242,7 +247,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped class MvpEncoderLayer(GradientCheckpointingLayer): @@ -284,7 +289,7 @@ def forward( returned tensors for more detail. 
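Across this patch the attention modules drop their third return value; the cache object is mutated in place instead. A minimal sketch of that contract with `DynamicCache` (illustrative shapes, not MVP-specific):

```py
# Sketch of the new return contract: update() mutates the cache in place,
# so attention no longer needs to return a present_key_value tuple.
import torch
from transformers import DynamicCache

cache = DynamicCache()
k = torch.zeros(1, 4, 3, 8)  # (batch, heads, seq, head_dim)
v = torch.zeros(1, 4, 3, 8)

# update() appends the new states and returns the accumulated K/V,
# which is why callers now unpack only (attn_output, attn_weights)
k_all, v_all = cache.update(k, v, layer_idx=0)
print(k_all.shape, cache.get_seq_length())  # torch.Size([1, 4, 3, 8]) 3
```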
""" residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -309,16 +314,11 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights class MvpDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: MvpConfig): + def __init__(self, config: MvpConfig, layer_idx=None): super().__init__() self.embed_dim = config.d_model @@ -327,6 +327,7 @@ def __init__(self, config: MvpConfig): num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -338,6 +339,7 @@ def __init__(self, config: MvpConfig): config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -354,9 +356,10 @@ def forward( cross_attn_layer_head_mask: Optional[torch.Tensor] = None, self_attn_prompt: Optional[torch.Tensor] = None, cross_attn_prompt: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -383,45 +386,37 @@ def forward( residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, attn_prompt=self_attn_prompt, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, attn_prompt=cross_attn_prompt, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, 
training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -436,9 +431,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -735,7 +727,7 @@ def __init__( config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([MvpDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([MvpDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.use_prompt = use_prompt @@ -776,6 +768,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -862,12 +855,27 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) @@ -893,18 +901,10 @@ def forward( self_attn_prompt = self.self_attn_prompt(prompt_ids) cross_attn_prompt = self.cross_attn_prompt(prompt_ids) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -924,8 +924,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask, @@ -935,15 +933,12 @@ def forward( cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), cross_attn_prompt=(cross_attn_prompt[idx] if self.use_prompt else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -954,16 +949,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1028,6 +1025,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1111,6 +1109,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1196,6 +1195,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1285,6 +1285,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias @@ -1312,17 +1313,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # 
cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -1739,6 +1729,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -1809,15 +1800,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "MvpForCausalLM", diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index fe75aa212e56..745f88b88742 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -588,8 +588,7 @@ class NemotronPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 1e0dddeac937..ea584137632c 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -511,11 +512,12 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_causal: bool = False, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + is_causal: Optional[bool] = False, config: Optional[NllbMoeConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -532,6 +534,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -542,10 +545,11 @@ def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: 
Unpack[FlashAttentionKwargs], @@ -566,42 +570,35 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `encoder_hidden_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == encoder_hidden_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -623,7 +620,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class NllbMoeEncoderLayer(GradientCheckpointingLayer): @@ -669,7 +666,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -709,7 +706,7 @@ def forward( class NllbMoeDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): + def __init__(self, config: NllbMoeConfig, is_sparse: bool = False, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model self.is_sparse = is_sparse @@ -719,6 +716,7 @@ def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -731,6 +729,7 @@ def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) if not self.is_sparse: @@ -748,10 +747,11 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -779,42 +779,35 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, -
past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.attn_dropout(hidden_states) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.cross_attention_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( + hidden_states, cross_attn_weights = self.cross_attention( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.attn_dropout(hidden_states) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value += cross_attn_present_key_value - # Fully Connected residual = hidden_states @@ -833,7 +826,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states, present_key_value) + outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) @@ -1112,7 +1105,7 @@ def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = self.layers = nn.ModuleList() for i in range(config.decoder_layers): is_sparse = (i + 1) % sparse_step == 0 if sparse_step > 0 else False - self.layers.append(NllbMoeDecoderLayer(config, is_sparse)) + self.layers.append(NllbMoeDecoderLayer(config, is_sparse, layer_idx=i)) self.layer_norm = nn.LayerNorm(config.d_model) @@ -1135,6 +1128,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ): r""" Args: @@ -1222,12 +1216,28 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
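The `cache_position` argument threaded through these signatures replaces shape-based bookkeeping during generation. A minimal, self-contained sketch of the trimming logic used in the `prepare_inputs_for_generation` hunks (illustrative values):

```py
# Sketch of cache_position-based input trimming: positions already in the
# cache are dropped, so only the new token(s) are fed to the decoder.
import torch

decoder_input_ids = torch.tensor([[5, 17, 23, 42]])
cache_position = torch.tensor([3])  # only position 3 still needs computing

if decoder_input_ids.shape[1] != cache_position.shape[0]:
    decoder_input_ids = decoder_input_ids[:, cache_position]
print(decoder_input_ids)  # tensor([[42]])
```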
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = self._update_causal_mask( attention_mask, input_shape, @@ -1249,19 +1259,11 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_probs = () if output_router_logits else None all_cross_attentions = () if output_attentions else None - present_key_value_states = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1285,8 +1287,6 @@ def forward( layer_head_mask = head_mask[idx] if head_mask is not None else None cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - past_key_value = past_key_values[idx] if past_key_values is not None else None - # under fsdp or deepspeed zero3 all gpus must run in sync layer_outputs = decoder_layer( hidden_states, @@ -1295,10 +1295,11 @@ def forward( encoder_attention_mask=encoder_attention_mask, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_router_logits=output_router_logits, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -1306,12 +1307,9 @@ def forward( if skip_the_layer: continue - if use_cache: - present_key_value_states += (layer_outputs[1],) - if output_attentions: - all_self_attns += (layer_outputs[2],) - all_cross_attentions += (layer_outputs[3],) + all_self_attns += (layer_outputs[1],) + all_cross_attentions += (layer_outputs[2],) if output_router_logits: all_router_probs += (layer_outputs[-1],) @@ -1322,12 +1320,15 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + past_key_values, all_hidden_states, all_self_attns, all_cross_attentions, @@ -1337,7 +1338,7 @@ def forward( ) return MoEModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1476,6 +1477,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], Seq2SeqMoEModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1555,6 +1557,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1625,6 +1628,7 @@ def forward( output_hidden_states:
Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], Seq2SeqMoEOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1696,6 +1700,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) @@ -1767,15 +1772,6 @@ def _unpack_router_logits(self, router_outputs): total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None return total_router_logits, total_expert_indexes - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "NllbMoeForConditionalGeneration", diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index f5b940157ded..babd8acc09f7 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -167,17 +167,23 @@ def iterative_inv(self, mat, n_iter=6): ) return value - def transpose_for_scores(self, layer): - new_layer_shape = layer.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - layer = layer.view(*new_layer_shape) - return layer.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask=None, output_attentions=False): - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) query_layer = query_layer / math.sqrt(math.sqrt(self.attention_head_size)) key_layer = key_layer / math.sqrt(math.sqrt(self.attention_head_size)) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 8459ba57a658..77e41e2d6292 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -293,8 +293,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 8eb966f313ec..5dd56e2eddfb 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -298,8 +298,7 @@ class Olmo2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa 
= True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 420d239b2d23..6c491754d448 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -705,8 +705,7 @@ class OlmoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index 9bac40553d9f..ab1ae0b9744b 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -494,11 +494,6 @@ def __init__(self, config, hidden_size, num_attention_heads, dropout): self.out_proj = nn.Linear(hidden_size, hidden_size) self.dropout = nn.Dropout(dropout) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, queries: torch.Tensor, @@ -507,9 +502,18 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - query_layer = self.transpose_for_scores(self.query(queries)) - key_layer = self.transpose_for_scores(self.key(keys)) - value_layer = self.transpose_for_scores(self.value(values)) + batch_size, seq_length, _ = queries.shape + query_layer = ( + self.query(queries) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + value_layer = ( + self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 2f5a07f79ac6..b6641f4820f5 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -312,8 +312,7 @@ class OPTPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): @@ -872,15 +871,6 @@ def forward( attentions=outputs.attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index e3c17bc18ede..5ab2b93e32a4 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -113,8 +113,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["PaliGemmaMultiModalProjector"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flash_attn = True _supports_sdpa = True diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 1ff46faf6efc..bf2dc59d249d 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -309,7 +309,6 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index dfd28ea2b0a9..c613fd8955af 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -106,7 +106,6 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 9d0f3a9dd880..1eec8b4166dc 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -266,7 +266,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS @@ -309,7 +309,7 @@ def forward( """ residual = hidden_states hidden_states = 
self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -330,12 +330,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS @@ -409,7 +404,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -426,7 +421,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -451,9 +446,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -465,7 +457,7 @@ class PegasusPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): @@ -1108,7 +1100,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1141,9 +1132,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1156,19 +1144,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1537,17 +1524,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached 
cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus class PegasusDecoderWrapper(PegasusPreTrainedModel): @@ -1710,14 +1686,5 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"] diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 29dbbdf32832..a371be87de2e 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -287,7 +287,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class PegasusXGlobalLocalAttention(nn.Module): @@ -705,7 +705,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -721,7 +721,7 @@ def forward( residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -744,10 +744,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (past_key_value,) - return outputs @@ -761,7 +757,7 @@ class PegasusXPreTrainedModel(PreTrainedModel): # Flaky logits _supports_sdpa = False _supports_flex_attn = True - _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): @@ -1364,7 +1360,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) @@ -1387,9 +1382,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1402,19 +1394,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in 
[hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1724,17 +1715,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->PegasusX class PegasusXDecoderWrapper(PegasusXPreTrainedModel): diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py b/src/transformers/models/perception_lm/modeling_perception_lm.py index 4a0e01a2415b..da244141c7d5 100644 --- a/src/transformers/models/perception_lm/modeling_perception_lm.py +++ b/src/transformers/models/perception_lm/modeling_perception_lm.py @@ -91,10 +91,10 @@ class PerceptionLMPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True + _supports_flash_attn = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_flex_attn = True _supports_attention_backend = True @@ -130,10 +130,14 @@ class PerceptionLMModelOutputWithPast(BaseModelOutputWithPast): `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + Image hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`. + Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ image_hidden_states: Optional[torch.FloatTensor] = None + video_hidden_states: Optional[torch.FloatTensor] = None @@ -157,7 +161,10 @@ class PerceptionLMCausalLMOutputWithPast(ModelOutput): `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + Image hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`. + Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
""" loss: Optional[torch.FloatTensor] = None @@ -166,6 +173,7 @@ class PerceptionLMCausalLMOutputWithPast(ModelOutput): hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None + video_hidden_states: Optional[torch.FloatTensor] = None diff --git a/src/transformers/models/perception_lm/modular_perception_lm.py b/src/transformers/models/perception_lm/modular_perception_lm.py index 05c9e8e9d043..3258fcd79fa9 100644 --- a/src/transformers/models/perception_lm/modular_perception_lm.py +++ b/src/transformers/models/perception_lm/modular_perception_lm.py @@ -96,10 +96,44 @@ class PerceptionLMPreTrainedModel(LlavaPreTrainedModel): class PerceptionLMModelOutputWithPast(LlavaModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + Image hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`. + Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + video_hidden_states: Optional[torch.FloatTensor] = None class PerceptionLMCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + Image hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + video_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_videos, sequence_length, hidden_size)`. + Video hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
+ """ + video_hidden_states: Optional[torch.FloatTensor] = None diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index af7d1db9e272..ef69edc1870d 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -388,8 +388,7 @@ class PersimmonPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["PersimmonDecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_sdpa = True _supports_flash_attn = True diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 527574f613bf..d08d73d87adc 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -298,8 +298,7 @@ class PhiPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 865a3973adf8..fb92b54105a5 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -290,8 +290,7 @@ class Phi3PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index a464e67e68f9..58008f692bd8 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1592,8 +1592,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 328d749cd6be..df2978f123e6 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -889,8 +889,7 @@ class PhimoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) def _init_weights(self, module): diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 93c281dc5bab..900660aa5f73 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -350,7 +350,7 @@ def forward( @auto_docstring class Pix2StructPreTrainedModel(PreTrainedModel): config_class = Pix2StructConfig - _supports_cache_class = True + 
_supports_static_cache = False @property @@ -1037,37 +1037,6 @@ def __init__(self, config): self.post_init() self.gradient_checkpointing = False - # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" - ) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - def get_input_embeddings(self): return self.embed_tokens diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 95835fd977cb..51c13b87322a 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -463,7 +463,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class PLBartEncoderLayer(GradientCheckpointingLayer): @@ -505,7 +505,7 @@ def forward( returned tensors for more detail. 
""" residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -775,7 +775,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -792,7 +792,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -819,9 +819,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1036,7 +1033,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1069,10 +1065,6 @@ def forward( cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1083,19 +1075,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1440,17 +1431,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - class PLBartClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" @@ -1769,15 +1749,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = [ "PLBartForCausalLM", 
diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index 5202e61de846..2aa8568954d1 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -596,17 +596,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - class PLBartClassificationHead(BartClassificationHead): pass diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 0cebfaef8c3a..6b64c1fd8fd8 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -576,7 +576,7 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = False supports_gradient_checkpointing = True - _supports_cache_class = True + _supports_static_cache = False _no_split_modules = ["Pop2PianoBlock"] _keep_in_fp32_modules = ["wo"] @@ -1332,35 +1332,5 @@ def generate( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" - ) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - __all__ = ["Pop2PianoForConditionalGeneration", "Pop2PianoPreTrainedModel"] diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 99596434270f..d9c7807d5280 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -26,6 +26,7 @@ from torch.nn import LayerNorm from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from 
...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput @@ -386,10 +387,10 @@ def forward(self, inputs_shape, device, attention_mask=None, past_key_values=Non ) if position_ids is None: - if past_key_values is not None: + if past_key_values is not None and past_key_values.get_seq_length() != 0: # position_ids is the same for every token when decoding a single step # Without the int() cast, it doesn't work in some cases when exporting to ONNX - prev_num_input_ids = past_key_values[0][0].shape[2] + prev_num_input_ids = past_key_values.get_seq_length() num_input_ids = inputs_shape[1] + prev_num_input_ids position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * ( int(self.padding_idx + num_input_ids) @@ -415,11 +416,7 @@ def _forward(self, position_ids): class ProphetNetAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__( - self, - config: ProphetNetConfig, - num_attn_heads: int, - ): + def __init__(self, config: ProphetNetConfig, num_attn_heads: int, layer_idx: Optional[int] = None): super().__init__() hidden_size = config.hidden_size @@ -427,6 +424,7 @@ def __init__( self.dropout = config.dropout self.num_attn_heads = num_attn_heads self.head_dim = hidden_size // num_attn_heads + self.layer_idx = layer_idx assert self.head_dim * num_attn_heads == hidden_size, ( "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and" @@ -439,17 +437,15 @@ def __init__( self.out_proj = nn.Linear(hidden_size, hidden_size) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states, key_value_states: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, layer_head_mask: Optional[Tensor] = None, - past_key_value: Optional[tuple[Tensor]] = None, - output_attentions: bool = False, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[Tensor, Optional[Tensor]]: batch_size, tgt_len, hidden_size = hidden_states.size() @@ -465,32 +461,41 @@ def forward( # previous time steps are cached - no need to recompute key and value if they are static query_states = self.query_proj(hidden_states) / (self.head_dim**0.5) - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.key_proj(key_value_states), -1, batch_size) - value_states = self._shape(self.value_proj(key_value_states), -1, batch_size) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = 
self._shape(self.key_proj(hidden_states), -1, batch_size) - value_states = self._shape(self.value_proj(hidden_states), -1, batch_size) - - if is_cross_attention: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - # project states into the correct shape - proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + key_states = self.key_proj(current_states) + value_states = self.value_proj(current_states) + key_states = key_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + query_states = query_states.view(batch_size, tgt_len, self.num_attn_heads, self.head_dim).transpose(1, 2) src_len = key_states.size(2) + attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3)) expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len) if attn_weights.size() != expected_shape: @@ -538,7 +543,7 @@ def forward( attn_output = self.out_proj(attn_output) attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped class ProphetNetFeedForward(nn.Module): @@ -565,7 +570,7 @@ def forward(self, hidden_states): class ProphetNetNgramSelfAttention(nn.Module): - def __init__(self, config: ProphetNetConfig): + def __init__(self, config: ProphetNetConfig, layer_idx=None): super().__init__() self.hidden_size = config.hidden_size @@ -576,6 +581,7 @@ def __init__(self, config: ProphetNetConfig): self.attention_dropout = config.attention_dropout self.head_dim = config.hidden_size // self.num_attn_heads self.ngram = config.ngram + self.layer_idx = layer_idx assert self.head_dim * self.num_attn_heads == config.hidden_size, ( "config.hidden_size must be divisible by num_attn_heads" @@ -610,6 +616,7 @@ def forward( main_relative_position_buckets=None, predict_relative_position_buckets=None, position_ids=None, + cache_position=None, ): batch_size, ngram_sequence_length, hidden_size = hidden_states.size() assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], ( @@ -631,9 +638,9 @@ def forward( value_states = self._shape(value_states, -1, batch_size) proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) - query_states = query_states.view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.reshape(*proj_shape) + key_states = 
key_states.reshape(*proj_shape) +        value_states = value_states.reshape(*proj_shape)          # chunk into main stream and predict stream         hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1) @@ -646,15 +653,16 @@ def forward(         main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:]         main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:]  -        # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim) +        # ProphetNet has two separate attention layers, one for self-attention and one for cross-attention +        # This module only needs the self-attention cache, so unwrap it when an `EncoderDecoderCache` is passed          if past_key_value is not None: -            prev_main_key_states = past_key_value[0] -            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2) -            prev_main_value_states = past_key_value[1] -            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2) -  -            # Update cache -            past_key_value = (main_key_states, main_value_states) +            if isinstance(past_key_value, EncoderDecoderCache): +                curr_past_key_value = past_key_value.self_attention_cache +            else: +                curr_past_key_value = past_key_value +            main_key_states, main_value_states = curr_past_key_value.update( +                main_key_states, main_value_states, self.layer_idx, {"cache_position": cache_position} +            )          # get seq_length of main stream only         sequence_length = ngram_sequence_length // (1 + self.ngram) @@ -776,7 +784,7 @@ def forward(          attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)  -        return attn_output, main_attn_probs, predict_attn_probs, past_key_value +        return attn_output, main_attn_probs, predict_attn_probs      def get_main_relative_pos_embeddings(         self, hidden_states, attn_weights, position_ids, main_relative_position_buckets @@ -906,7 +914,7 @@ def forward(         output_attentions: bool = False,     ):         # 1st residual block -        attention_output, attn_weights, _ = self.self_attn( +        attention_output, attn_weights = self.self_attn(             hidden_states=hidden_states,             attention_mask=attention_mask,             layer_head_mask=layer_head_mask, @@ -931,15 +939,15 @@ class ProphetNetDecoderLayer(GradientCheckpointingLayer):     Decoder block for Prophetnet     """  -    def __init__(self, config: ProphetNetConfig): +    def __init__(self, config: ProphetNetConfig, layer_idx=None):         super().__init__()         # 1st residual block -        self.self_attn = ProphetNetNgramSelfAttention(config) +        self.self_attn = ProphetNetNgramSelfAttention(config, layer_idx=layer_idx)         self.self_attn_layer_norm = LayerNorm(config.hidden_size)          # 2nd residual block         if config.add_cross_attention: -            self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads) +            self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads, layer_idx=layer_idx)             self.cross_attn_layer_norm = LayerNorm(config.hidden_size)          # 3rd residual block @@ -959,15 +967,14 @@ def forward(         predict_relative_position_buckets=None,         position_ids=None,         past_key_value=None, -        use_cache: bool = True, -        output_attentions: bool = False, +        use_cache: Optional[bool] = True, +        output_attentions: Optional[bool] = False, +        cache_position: Optional[torch.Tensor] = None,     ):         # 1st residual block -        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 -        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None -        ngram_attention_output, self_attn_weights, 
self_attn_weights_ngram = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, extended_predict_attention_mask=extended_predict_attention_mask, @@ -977,24 +984,19 @@ def forward( ) hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attn_weights = None if encoder_hidden_states is not None: # 2nd residual block - attention_output, cross_attn_weights, cross_attn_present_key_value = self.cross_attn( + attention_output, cross_attn_weights = self.cross_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attn_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # 3rd residual block feed_forward_output = self.feed_forward(hidden_states) hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) @@ -1004,9 +1006,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -1160,7 +1159,9 @@ def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedd self.position_embeddings = ProphetNetPositionalEmbeddings(config) self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None) - self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)]) + self.layers = nn.ModuleList( + [ProphetNetDecoderLayer(config, layer_idx=i) for i in range(config.num_decoder_layers)] + ) self.embeddings_layer_norm = LayerNorm(config.hidden_size) self.gradient_checkpointing = False @@ -1188,6 +1189,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, ProphetNetDecoderModelOutput]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -1225,13 +1227,32 @@ def forward( batch_size, sequence_length = inputs_embeds.shape[:2] + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + main_stream_pos_embed, position_ids = self.position_embeddings( (batch_size, sequence_length), device=inputs_embeds.device, past_key_values=past_key_values, ) - if past_key_values is not None: + if past_key_values_length != 0: main_relative_position_buckets, predict_relative_position_buckets = None, None else: ( @@ -1246,7 +1267,7 @@ def forward( ngram_embeddings = self.ngram_embeddings.weight # prepare attention mask - if past_key_values is not None: + if past_key_values_length != 0: assert hidden_states.size(1) == 1, ( "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1" ) @@ -1288,15 +1309,6 @@ def forward( all_ngram_stream_attns = () if output_attentions else None all_cross_attns = () if output_attentions and self.config.add_cross_attention else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - present_key_values = () if use_cache else None - # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): if attn_mask is not None: @@ -1311,8 +1323,6 @@ def forward( if self.config.ngram > 0: all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, extended_attention_mask, @@ -1324,16 +1334,13 @@ def forward( main_relative_position_buckets=main_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - present_key_values += (layer_outputs[4 if output_attentions else 1],) - if output_attentions: all_main_stream_attns += (layer_outputs[1],) all_ngram_stream_attns += (layer_outputs[2],) @@ -1346,6 +1353,9 @@ def forward( if self.config.ngram > 0: all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + # split last_hidden_state for return last_hidden_state = hidden_states[:, :sequence_length] last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None @@ -1356,7 +1366,7 @@ def forward( for v in [ last_hidden_state, last_hidden_state_ngram, - present_key_values, + past_key_values, all_main_stream_hidden_states, all_ngram_stream_hidden_states, all_main_stream_attns, @@ -1368,7 +1378,7 @@ def forward( return ProphetNetDecoderModelOutput( last_hidden_state=last_hidden_state, last_hidden_state_ngram=last_hidden_state_ngram, - past_key_values=present_key_values, + past_key_values=past_key_values, hidden_states=all_main_stream_hidden_states, hidden_states_ngram=all_ngram_stream_hidden_states, attentions=all_main_stream_attns, @@ -1516,6 +1526,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: 
Optional[torch.Tensor] = None, ) -> Union[tuple, ProphetNetSeq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1587,6 +1598,7 @@ def forward( output_hidden_states=output_hidden_states, use_cache=use_cache, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1657,6 +1669,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, ProphetNetSeq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1722,6 +1735,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) batch_size, sequence_length = ( decoder_input_ids.shape if decoder_input_ids is not None else decoder_inputs_embeds.shape[:2] @@ -1791,18 +1805,6 @@ def _compute_loss(self, logits, labels, ignore_index=-100): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - @staticmethod - # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - def get_encoder(self): return self.prophetnet.encoder @@ -2025,16 +2027,6 @@ def prepare_inputs_for_generation( "use_cache": use_cache, } - @staticmethod - # Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel): """ diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 285789569626..e58c08c223e7 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -262,8 +262,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index b890b81297d1..d576c801a4ce 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -87,7 +87,6 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True _supports_static_cache = False _supports_attention_backend = True diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 
eaa612dc83d6..cf2d802abb6b 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -325,7 +325,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True + _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 569fd4bdc25e..eafcbaf01926 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -188,7 +188,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, None + return attn_output, attn_weights # Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO @@ -231,7 +231,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -252,12 +252,7 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights @auto_docstring diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 7ec40f25814d..9503c92bff6e 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -746,7 +746,6 @@ class Qwen2MoePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index d6435667a5de..d2f9b535c9b0 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -658,7 +658,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _supports_sdpa = True - _supports_cache_class = True + _supports_static_cache = True _supports_attention_backend = True diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 2ebdbc97565e..b6e537ec1825 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -288,8 +288,7 @@ class Qwen3PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 1d0f3275235f..f7b180558ad6 100644 --- 
a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -411,8 +411,6 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 1af81cf95947..89f9f7d1b93e 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -21,7 +21,7 @@ import torch from torch import nn -from ...cache_utils import EncoderDecoderCache +from ...cache_utils import Cache, EncoderDecoderCache from ...configuration_utils import PretrainedConfig from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList from ...modeling_outputs import ModelOutput @@ -50,7 +50,7 @@ class RetrievAugLMMarginOutput(ModelOutput): doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`): Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and `question_encoder_last_hidden_state`. - past_key_values (`list[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). @@ -115,7 +115,7 @@ class RetrievAugLMMarginOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None doc_scores: Optional[torch.FloatTensor] = None - past_key_values: Optional[list[torch.FloatTensor]] = None + past_key_values: Optional[Cache] = None retrieved_doc_embeds: Optional[torch.FloatTensor] = None retrieved_doc_ids: Optional[torch.LongTensor] = None context_input_ids: Optional[torch.LongTensor] = None @@ -141,7 +141,7 @@ class RetrievAugLMOutput(ModelOutput): doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`): Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and `question_encoder_last_hidden_state`. - past_key_values (`list[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, sequence_length, embed_size_per_head)`). 
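(Context for the `Cache` annotations above: a minimal sketch of the legacy round trip they rely on. Tensor shapes are toy values, and the snippet only uses the `EncoderDecoderCache` helpers already referenced elsewhere in this diff.)

```py
import torch
from transformers import EncoderDecoderCache

# Legacy encoder-decoder layer entry: (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value),
# each of shape [batch_size, num_heads, seq_len, head_dim]; 2 toy layers here.
legacy = tuple(
    (torch.rand(1, 2, 4, 8), torch.rand(1, 2, 4, 8), torch.rand(1, 2, 6, 8), torch.rand(1, 2, 6, 8))
    for _ in range(2)
)

cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(cache.self_attention_cache.get_seq_length())  # 4

# Round-trip back to tuples for code that still expects the legacy format.
legacy_again = cache.to_legacy_cache()
```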
@@ -205,7 +205,7 @@ class RetrievAugLMOutput(ModelOutput): logits: Optional[torch.FloatTensor] = None doc_scores: Optional[torch.FloatTensor] = None - past_key_values: Optional[list[torch.FloatTensor]] = None + past_key_values: Optional[Cache] = None retrieved_doc_embeds: Optional[torch.FloatTensor] = None retrieved_doc_ids: Optional[torch.LongTensor] = None context_input_ids: Optional[torch.LongTensor] = None @@ -439,7 +439,7 @@ def forward( encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, doc_scores: Optional[torch.FloatTensor] = None, context_input_ids: Optional[torch.LongTensor] = None, context_attention_mask: Optional[torch.LongTensor] = None, @@ -713,7 +713,7 @@ def forward( encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None, + past_key_values: Optional[Cache] = None, context_input_ids: Optional[torch.LongTensor] = None, context_attention_mask: Optional[torch.LongTensor] = None, doc_scores: Optional[torch.FloatTensor] = None, @@ -1204,6 +1204,8 @@ def _reorder_stacked(hidden_states, new_order): + if isinstance(past_key_values, EncoderDecoderCache): + reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past) return reordered_past def marginalize(self, seq_logits, doc_scores, n_docs=None): @@ -1225,7 +1227,7 @@ def forward( encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None, + past_key_values: Optional[Cache] = None, context_input_ids: Optional[torch.LongTensor] = None, context_attention_mask: Optional[torch.LongTensor] = None, doc_scores: Optional[torch.FloatTensor] = None, @@ -1563,6 +1565,15 @@ def extend_enc_output(tensor, num_beams=None): generation_config=generation_config, stopping_criteria=stopping_criteria ) + self._prepare_cache_for_generation( + generation_config, + model_kwargs, + assistant_model=None, + batch_size=input_ids.shape[0], + max_cache_length=generation_config.max_length - 1, + device=input_ids.device, + ) + if generation_config.num_beams == 1: if generation_config.num_return_sequences > 1: raise ValueError( @@ -1581,6 +1592,14 @@ def extend_enc_output(tensor, num_beams=None): elif generation_config.num_beams > 1: if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + # 11.
interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) return self._beam_search( input_ids, logits_processor=pre_processor, diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 02af226c5bf5..b6c2325a692e 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -511,8 +511,6 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["cache"] _supports_flash_attn = False _supports_sdpa = False # we can't compare with eager for now - _supports_cache_class = True - _supports_quantized_cache = True def _init_weights(self, module): std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width) @@ -807,15 +805,5 @@ def forward( hidden_states=outputs.hidden_states, ) - # Ignore copy - def _reorder_cache(self, past_key_values, beam_idx): - for layer in self.layers: - if hasattr(layer.temporal_block, "key_states"): - k_state = layer.temporal_block.key_states - v_state = layer.temporal_block.value_states - k_state = k_state.index_select(0, beam_idx.to(k_state.device)) - v_state = v_state.index_select(0, beam_idx.to(v_state.device)) - return None - __all__ = ["RecurrentGemmaForCausalLM", "RecurrentGemmaModel", "RecurrentGemmaPreTrainedModel"] diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index b9915efd1e53..d0a80755b7d7 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -20,7 +20,7 @@ from dataclasses import dataclass from functools import reduce from operator import mul -from typing import Optional, Union +from typing import Any, Iterable, Optional, Union import numpy as np import torch @@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel @@ -60,6 +61,119 @@ ) +class ReformerDynamicCache(DynamicCache): + """ + A dynamic cache that stores past buckets instead of key/values. + """ + + def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None: + super().__init__() + self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + self.buckets_cache: list[torch.Tensor] = [] + self.states_cache: list[torch.Tensor] = [] + + if _distributed_cache_data is not None: + for buckets, states in _distributed_cache_data: + self.buckets_cache.append(buckets) + self.states_cache.append(states) + + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. 
+ """ + if layer_idx < len(self): + return (self.buckets_cache[layer_idx], self.states_cache[layer_idx]) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.buckets_cache[layer_idx], self.states_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.states_cache) + + def update( + self, + buckets: torch.Tensor, + states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `ReformerDynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += states.shape[-2] + + # Update the cache + if states is not None: + if len(self.states_cache) <= layer_idx: + self.states_cache.append(states) + else: + self.states_cache[layer_idx] = torch.cat([self.states_cache[layer_idx], states], dim=1) + + if buckets is not None: + if len(self.buckets_cache) <= layer_idx: + self.buckets_cache.append(buckets) + else: + self.buckets_cache[layer_idx] = torch.cat([self.buckets_cache[layer_idx], buckets], dim=-1) + else: + # `ReformerLocalAttn` passes `None` to buckets as the module uses no buckets + self.buckets_cache.append(torch.tensor([], device=self.states_cache[layer_idx].device)) + + return self.buckets_cache[layer_idx], self.states_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + return None + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: + """Converts the `ReformerDynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + buckets, states = self.buckets_cache[layer_idx], self.states_cache[layer_idx] + buckets = buckets if buckets.numel() != 0 else None + legacy_cache += ((buckets, states),) + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_buckets_states: Optional[tuple[tuple[torch.FloatTensor, torch.FloatTensor]]] = None + ) -> "ReformerDynamicCache": + """Converts a cache in the legacy cache format into an equivalent `ReformerDynamicCache`. Used for + backward compatibility.""" + cache = cls() + if past_buckets_states is not None: + for layer_idx in range(len(past_buckets_states)): + buckets, states = past_buckets_states[layer_idx] + cache.update(buckets, states, layer_idx) + return cache + + def _stable_argsort(vector, dim): # this function scales the vector so that torch.argsort is stable. 
# torch.argsort is not stable on its own @@ -316,7 +430,7 @@ def _split_seq_length_dim_to(self, vectors, dim_factor_1, dim_factor_2, num_attn class LSHSelfAttention(nn.Module, EfficientAttentionMixin): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config @@ -328,6 +442,7 @@ def __init__(self, config): self.hash_seed = config.hash_seed self.is_decoder = config.is_decoder self.max_position_embeddings = config.max_position_embeddings + self.layer_idx = layer_idx self.dropout = config.lsh_attention_probs_dropout_prob @@ -356,6 +471,7 @@ def forward( past_buckets_states=None, use_cache=False, output_attentions=False, + cache_position=None, **kwargs, ): sequence_length = hidden_states.shape[1] @@ -364,16 +480,13 @@ def forward( # num hashes can optionally be overwritten by user num_hashes = num_hashes if num_hashes is not None else self.num_hashes - do_cached_attention = use_cache and past_buckets_states[1] is not None - # check if cache shall be used and that hidden states are already cached - if do_cached_attention: + exists_cache = past_buckets_states is not None and len(past_buckets_states) > self.layer_idx + if exists_cache: assert sequence_length == 1, ( "At the moment, auto-regressive language generation is only possible one word at a time. Make sure" f" that input sequence length {sequence_length} equals 1, when `past_buckets_states` is passed." ) - past_buckets = past_buckets_states[0] - past_states = past_buckets_states[1] # get query vector query_vectors = self.query_key(hidden_states) @@ -381,7 +494,10 @@ def forward( query_vectors, self.num_attention_heads, self.attention_head_size ) - if past_buckets is not None: + past_buckets = past_buckets_states.buckets_cache[self.layer_idx] + past_states = past_buckets_states.states_cache[self.layer_idx] + + if past_buckets.numel() != 0: key_value_hidden_states, sorted_bucket_idx, buckets = self._get_relevant_hid_states_and_buckets( query_vectors=query_vectors, attention_mask=attention_mask, @@ -425,7 +541,7 @@ def forward( value_vectors = self.value(hidden_states) # if query key is not already split - if not do_cached_attention or past_buckets is None: + if not exists_cache or past_buckets.numel() == 0: query_key_vectors = self._split_hidden_size_dim( query_key_vectors, self.num_attention_heads, self.attention_head_size ) @@ -434,7 +550,7 @@ def forward( ) # cache buckets for next incremental decoding - if do_cached_attention and past_buckets is None and key_value_hidden_states.shape[1] >= self.chunk_length: + if exists_cache and key_value_hidden_states.shape[1] >= self.chunk_length: buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask) # free memory @@ -448,7 +564,7 @@ def forward( ) do_standard_self_attention = (sequence_length <= self.chunk_length) or ( - use_cache and past_buckets_states[1] is not None + exists_cache and past_states is not None ) # LSH attention only makes sense if chunked attention should be performed if not do_standard_self_attention: @@ -498,7 +614,7 @@ def forward( "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and" " `config.num_chunks_before` are set to 0." 
) - elif do_cached_attention and past_buckets is not None: + elif exists_cache and past_buckets.numel() != 0: # use max sequence length sorted_bucket_idx_per_hash = sorted_bucket_idx else: @@ -526,7 +642,7 @@ def forward( attention_mask=attention_mask, head_mask=head_mask, do_standard_self_attention=do_standard_self_attention, - do_cached_attention=do_cached_attention, + use_cache=exists_cache, ) # free memory @@ -537,7 +653,7 @@ def forward( # sort clusters back to correct ordering out_vectors, logits = ReverseSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx) - if not do_standard_self_attention or (do_cached_attention and past_buckets is not None): + if not do_standard_self_attention or (exists_cache and past_buckets.numel() != 0): # sum up all hash rounds if num_hashes > 1: out_vectors = self._split_seq_length_dim_to( @@ -721,7 +837,7 @@ def _attend( attention_mask, head_mask, do_standard_self_attention, - do_cached_attention, + use_cache, ): # look at previous and following chunks if chunked attention if not do_standard_self_attention: @@ -741,12 +857,12 @@ def _attend( sorted_bucket_idx_per_hash, -1, self.chunk_length, self.num_attention_heads ) key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) - elif do_cached_attention and query_key_dots.ndim > 4: + elif use_cache and query_key_dots.ndim > 4: key_value_bucket_idx = sorted_bucket_idx_per_hash query_bucket_idx = ( key_value_bucket_idx.new_ones(key_value_bucket_idx.shape[:-1] + (1,)) * key_value_bucket_idx.max() ) - elif do_cached_attention and query_key_dots.ndim <= 4: + elif use_cache and query_key_dots.ndim <= 4: query_bucket_idx = (query_key_dots.shape[-1] - 1) * torch.ones_like(query_key_dots)[:, :, :, -1] key_value_bucket_idx = torch.arange( query_key_dots.shape[-1], dtype=torch.long, device=query_key_dots.device @@ -762,7 +878,7 @@ def _attend( self_mask_value = self.self_mask_value_float32 mask_value = self.mask_value_float32 - if not do_cached_attention: + if not use_cache: mask = self._compute_attn_mask( query_bucket_idx, key_value_bucket_idx, @@ -1016,7 +1132,7 @@ def backward(ctx, grad_out_vectors, grad_logits): class LocalSelfAttention(nn.Module, EfficientAttentionMixin): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.num_attention_heads = config.num_attention_heads @@ -1029,6 +1145,7 @@ def __init__(self, config): self.attention_head_size = config.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size self.hidden_size = config.hidden_size + self.layer_idx = layer_idx # projection matrices self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=False) @@ -1055,13 +1172,16 @@ def forward( batch_size = hidden_states.shape[0] # check if cache shall be used and that hidden states are already cached - if use_cache and past_buckets_states[1] is not None: - assert past_buckets_states[0] is None, ( + if past_buckets_states is not None and len(past_buckets_states) > self.layer_idx: + past_buckets = past_buckets_states.buckets_cache[self.layer_idx] + past_states = past_buckets_states.states_cache[self.layer_idx] + + assert past_buckets.numel() == 0, ( "LocalSelfAttention should not make use of `buckets`. There seems to be an error when caching" " hidden_states_and_buckets." 
) key_value_hidden_states = self._retrieve_relevant_hidden_states( - past_buckets_states[1], self.chunk_length, self.num_chunks_before + past_states, self.chunk_length, self.num_chunks_before ) key_value_hidden_states = torch.cat([key_value_hidden_states, hidden_states], dim=1) @@ -1262,15 +1382,15 @@ def __init__(self, config, layer_id=0): self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh": - self.self_attention = LSHSelfAttention(config) + self.self_attention = LSHSelfAttention(config, layer_idx=layer_id) elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local": - self.self_attention = LocalSelfAttention(config) + self.self_attention = LocalSelfAttention(config, layer_idx=layer_id) elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == {"lsh", "local"}: # get correct attn layers if self.attn_layers[self.layer_id] == "lsh": - self.self_attention = LSHSelfAttention(config) + self.self_attention = LSHSelfAttention(config, layer_idx=layer_id) else: - self.self_attention = LocalSelfAttention(config) + self.self_attention = LocalSelfAttention(config, layer_idx=layer_id) else: raise NotImplementedError( f"Only attn layer types 'lsh' and 'local' exist, but got `config.attn_layers`: {self.attn_layers}. " @@ -1289,25 +1409,21 @@ def forward( orig_sequence_length=None, output_attentions=False, buckets=None, + cache_position=None, ): hidden_states = self.layer_norm(hidden_states) - # make sure cached hidden states is set to None for backward pass - if past_buckets_states is not None: - past_buckets_states_layer = past_buckets_states[self.layer_id] - else: - past_buckets_states_layer = None - # use cached buckets for backprob if buckets not None for LSHSelfAttention self_attention_outputs = self.self_attention( hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, num_hashes=num_hashes, - past_buckets_states=past_buckets_states_layer, + past_buckets_states=past_buckets_states, use_cache=use_cache, output_attentions=output_attentions, buckets=buckets, + cache_position=cache_position, ) # add buckets if necessary @@ -1317,24 +1433,26 @@ def forward( buckets = None # cache hidden states for future use - if use_cache: - if past_buckets_states[self.layer_id][0] is None: - # padded input should not be cached - past_buckets = ( - buckets[:, :, :, :orig_sequence_length] - if (buckets is not None and orig_sequence_length > 1) - else buckets + if use_cache and past_buckets_states is not None: + # padded input should not be cached during prefill + states = ( + hidden_states[:, :orig_sequence_length] + if len(past_buckets_states.states_cache) <= self.layer_id + else hidden_states + ) + buckets = ( + buckets[:, :, :, :orig_sequence_length] + if ( + len(past_buckets_states.buckets_cache) <= self.layer_id + and buckets is not None + and orig_sequence_length > 1 ) - else: - past_buckets = torch.cat([past_buckets_states[self.layer_id][0], buckets], dim=-1) - - if past_buckets_states[self.layer_id][1] is None: - # padded input should not be cached - past_states = hidden_states[:, :orig_sequence_length] - else: - past_states = torch.cat([past_buckets_states[self.layer_id][1], hidden_states], dim=1) + else buckets + ) + buckets, hidden_states = past_buckets_states.update( + buckets, states[:, :orig_sequence_length], self.layer_id + ) - past_buckets_states[self.layer_id] = (past_buckets, past_states) # compute attention feed forward output attention_output = 
self.output(self_attention_outputs.hidden_states) @@ -1708,8 +1826,15 @@ def forward( all_attentions = [] # init cached hidden states if necessary - if past_buckets_states is None: - past_buckets_states = [((None), (None)) for i in range(len(self.layers))] + return_legacy_cache = False + if use_cache and not isinstance(past_buckets_states, ReformerDynamicCache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `ReformerDynamicCache` instead, e.g. " + "`past_key_values=ReformerDynamicCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_buckets_states = ReformerDynamicCache.from_legacy_cache(past_buckets_states) # concat same tensor for reversible ResNet hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) @@ -1734,11 +1859,15 @@ def forward( # Apply dropout hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + next_cache = past_buckets_states if use_cache else None + if return_legacy_cache: + next_cache = past_buckets_states.to_legacy_cache() + return ReformerEncoderOutput( hidden_states=hidden_states, all_hidden_states=all_hidden_states, all_attentions=all_attentions, - past_buckets_states=past_buckets_states, + past_buckets_states=next_cache, ) @@ -2087,7 +2216,7 @@ def _pad_to_mult_of_chunk_length( # Extend `inputs_embeds` with padding to match least common multiple chunk_length if inputs_embeds is not None: - padded_inputs_embeds = self.embeddings(padded_input_ids, position_ids) + padded_inputs_embeds = self.get_input_embeddings()(padded_input_ids) inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2) input_shape = inputs_embeds.size() return input_ids, inputs_embeds, attention_mask, position_ids, input_shape @@ -2240,6 +2369,9 @@ def _reorder_cache(self, past_key_values, beam_idx): # hidden states reord_hidden_states = layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)) reord_past_buckets_states.append((reord_buckets, reord_hidden_states)) + + if isinstance(past_key_values, ReformerDynamicCache): + reord_past_buckets_states = ReformerDynamicCache.from_legacy_cache(reord_past_buckets_states) return reord_past_buckets_states diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 774fc46ef7fc..93dd4a31fc96 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -39,6 +40,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_rembert import RemBertConfig @@ -199,7 +201,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class RemBertSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -218,12 +220,9 @@ def
__init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -231,45 +230,61 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -296,11 +311,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RemBert @@ -319,9 +330,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RemBertAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() - self.self = RemBertSelfAttention(config) + self.self = RemBertSelfAttention(config, layer_idx=layer_idx) self.output = RemBertSelfOutput(config) self.pruned_heads = set() @@ -344,6 +355,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") # Copied from transformers.models.bert.modeling_bert.BertAttention.forward def forward( self, @@ -352,17 +364,19 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -401,17 +415,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RemBertLayer(GradientCheckpointingLayer): - 
def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RemBertAttention(config) + self.attention = RemBertAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = RemBertAttention(config) + self.crossattention = RemBertAttention(config, layer_idx=layer_idx) self.intermediate = RemBertIntermediate(config) self.output = RemBertOutput(config) @@ -423,28 +437,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -452,33 +459,23 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if 
self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk @@ -494,7 +491,7 @@ def __init__(self, config): self.config = config self.embedding_hidden_mapping_in = nn.Linear(config.input_embedding_size, config.hidden_size) - self.layer = nn.ModuleList([RemBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([RemBertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -509,6 +506,7 @@ def forward( output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: if self.gradient_checkpointing and self.training: if use_cache: @@ -516,18 +514,27 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + hidden_states = self.embedding_hidden_mapping_in(hidden_states) all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -535,13 +542,12 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -550,12 +555,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -564,7 +572,7 @@ ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -699,6 +707,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( @@ -724,8 +733,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -772,6 +786,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -1004,15 +1019,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index ecf3a6cc5314..003b2fa519bb 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer @@ -41,6 +42,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_roberta import RobertaConfig @@ -138,7 +140,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta class RobertaSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -163,12 +165,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -176,53 +175,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, 
encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -264,21 +275,18 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->Roberta class RobertaSdpaSelfAttention(RobertaSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from RobertaSelfAttention + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -286,8 +294,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. 
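(Context for the cache rewiring in the hunk below: a rough, self-contained sketch of the `EncoderDecoderCache` mechanics it relies on. Tensors and shapes are toy values, not taken from the model.)

```py
import torch
from transformers import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

# Decoder self-attention: each step appends its new key/value states along the seq dim.
k = torch.rand(1, 2, 1, 8)  # [batch_size, num_heads, seq_len, head_dim]
v = torch.rand(1, 2, 1, 8)
k_all, v_all = cache.self_attention_cache.update(k, v, layer_idx=0)

# Cross-attention: computed once from the encoder output and then re-used;
# `is_updated` marks layer 0 as filled so later steps skip the recomputation.
ck, cv = torch.rand(1, 2, 5, 8), torch.rand(1, 2, 5, 8)
cache.cross_attention_cache.update(ck, cv, layer_idx=0)
cache.is_updated[0] = True
```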
@@ -306,38 +315,58 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states.
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. @@ -367,10 +397,7 @@ def forward( attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None # Copied from transformers.models.bert.modeling_bert.BertSelfOutput @@ -396,10 +423,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta,BERT->ROBERTA class RobertaAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = RobertaSelfOutput(config) self.pruned_heads = set() @@ -422,6 +451,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -429,17 +459,19 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if 
we output them @@ -479,17 +511,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta class RobertaLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RobertaAttention(config) + self.attention = RobertaAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = RobertaAttention(config, position_embedding_type="absolute") + self.crossattention = RobertaAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = RobertaIntermediate(config) self.output = RobertaOutput(config) @@ -500,28 +532,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -529,33 +554,23 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + 
cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -566,10 +581,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta class RobertaEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([RobertaLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -584,6 +599,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -596,13 +612,21 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -610,13 +634,12 @@ def forward( layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -625,12 +648,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -639,7 +665,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -755,6 +781,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -780,8 +807,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): @@ -866,6 +898,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -1006,14 +1039,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class RobertaForMaskedLM(RobertaPreTrainedModel): diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index e8636281e399..e04faa1b6387 100644 --- 
a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -39,6 +40,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_roberta_prelayernorm import RobertaPreLayerNormConfig @@ -137,7 +139,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->RobertaPreLayerNorm class RobertaPreLayerNormSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -162,12 +164,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -175,53 +174,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
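The branching that follows reads from one of two sub-caches, so it helps to keep the shape of the new cache object in mind. Below is a minimal sketch (not part of the patch; variable names are illustrative only) of how such a cache is put together:

```py
from transformers import DynamicCache, EncoderDecoderCache

# one sub-cache for decoder self-attention, one for cross-attention
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# self-attention K/V grow by one step per generated token
self_cache = past_key_values.self_attention_cache
# cross-attention K/V are written once (encoder length is fixed) and re-used
cross_cache = past_key_values.cross_attention_cache
```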
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. 
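All of the old tuple bookkeeping removed above collapses into the single `update` call: the cache appends the new key/value states along the sequence dimension and hands back the full tensors. A toy illustration of that contract (standalone tensors, assumed `[batch, heads, seq, head_dim]` layout):

```py
import torch
from transformers import DynamicCache

cache = DynamicCache()
k, v = cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer_idx=0)
print(k.shape)  # torch.Size([1, 2, 4, 8]) -- prompt of length 4

# one decoding step later: past and new states come back concatenated
k, v = cache.update(torch.randn(1, 2, 1, 8), torch.randn(1, 2, 1, 8), layer_idx=0)
print(k.shape)  # torch.Size([1, 2, 5, 8])
```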
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -263,11 +274,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs class RobertaPreLayerNormSelfOutput(nn.Module): @@ -284,9 +291,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RobertaPreLayerNormAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() - self.self = RobertaPreLayerNormSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = RobertaPreLayerNormSelfAttention( + config, position_embedding_type=position_embedding_type, layer_idx=layer_idx + ) self.output = RobertaPreLayerNormSelfOutput(config) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pruned_heads = set() @@ -310,6 +319,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -317,8 +327,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: hidden_states_pre_layer_norm = self.LayerNorm(hidden_states) self_outputs = self.self( @@ -329,6 +340,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -367,17 +379,19 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RobertaPreLayerNorm class RobertaPreLayerNormLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RobertaPreLayerNormAttention(config) + self.attention = RobertaPreLayerNormAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = RobertaPreLayerNormAttention(config, position_embedding_type="absolute") + self.crossattention = RobertaPreLayerNormAttention( + 
config, position_embedding_type="absolute", layer_idx=layer_idx + ) self.intermediate = RobertaPreLayerNormIntermediate(config) self.output = RobertaPreLayerNormOutput(config) @@ -388,28 +402,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -417,33 +424,23 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -454,10 +451,12 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->RobertaPreLayerNorm class RobertaPreLayerNormEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([RobertaPreLayerNormLayer(config) for _ in 
range(config.num_hidden_layers)]) + self.layer = nn.ModuleList( + [RobertaPreLayerNormLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) self.gradient_checkpointing = False def forward( @@ -472,6 +471,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -484,13 +484,21 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -498,13 +506,12 @@ def forward( layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -513,12 +520,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -527,7 +537,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -676,8 +686,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -874,14 +889,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - 
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 0a2e74e579df..8d98140aff9a 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -39,6 +40,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_roc_bert import RoCBertConfig @@ -252,7 +254,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->RoCBert class RoCBertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -277,12 +279,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -290,53 +289,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
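The `is_updated` flag consulted in the rewritten branch below is what lets cross-attention skip re-projecting the encoder states after the first step. A hedged sketch of that life cycle (toy tensors; `layer_idx` and the shapes are assumptions for illustration):

```py
import torch
from transformers import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
layer_idx = 0

if not cache.is_updated.get(layer_idx, False):
    # first generation step: project encoder states and store them
    enc_k = torch.randn(1, 2, 7, 8)  # stands in for self.key(encoder_hidden_states)
    enc_v = torch.randn(1, 2, 7, 8)
    cache.cross_attention_cache.update(enc_k, enc_v, layer_idx)
    cache.is_updated[layer_idx] = True
else:
    # every later step: plain reads, no projection
    enc_k = cache.cross_attention_cache.key_cache[layer_idx]
    enc_v = cache.cross_attention_cache.value_cache[layer_idx]
```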
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. 
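Just below, the relative-position branch swaps its guard from the old `use_cache` flag to `past_key_value is not None`: with a cache present the query covers only the newest token, whose absolute position is `key_length - 1`. A small numeric sketch of the two cases:

```py
import torch

key_length = 6  # five cached positions plus the new one

# cached decoding: single-token query sits at the last position
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long).view(-1, 1)

# no cache: full-length query, one row per position
query_length = key_length
position_ids_l_full = torch.arange(query_length, dtype=torch.long).view(-1, 1)

position_ids_r = torch.arange(key_length, dtype=torch.long).view(1, -1)
distance = position_ids_l - position_ids_r  # relative offsets for the embedding
```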
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -378,11 +389,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RoCBert @@ -407,10 +414,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RoCBert,BERT->ROC_BERT class RoCBertAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = ROC_BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = RoCBertSelfOutput(config) self.pruned_heads = set() @@ -433,6 +442,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -440,17 +450,19 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -490,17 +502,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RoCBert class RoCBertLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RoCBertAttention(config) + self.attention = RoCBertAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a 
decoder model if cross attention is added") - self.crossattention = RoCBertAttention(config, position_embedding_type="absolute") + self.crossattention = RoCBertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = RoCBertIntermediate(config) self.output = RoCBertOutput(config) @@ -511,28 +523,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -540,33 +545,23 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -577,10 +572,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->RoCBert class RoCBertEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() 
self.config = config - self.layer = nn.ModuleList([RoCBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([RoCBertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -595,6 +590,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -607,13 +603,21 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -621,13 +625,12 @@ def forward( layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -636,12 +639,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -650,7 +656,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -871,8 +877,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -1459,7 +1470,7 @@ def prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = 
past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -1482,15 +1493,6 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, } - # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 8439fed19cfc..58c2320ecda6 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -188,7 +189,7 @@ def forward(self, input_ids=None, token_type_ids=None, inputs_embeds=None): class RoFormerSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -208,11 +209,7 @@ def __init__(self, config): self.is_decoder = config.is_decoder self.rotary_value = config.rotary_value - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) + self.layer_idx = layer_idx def forward( self, @@ -221,50 +218,58 @@ def forward( sinusoidal_pos=None, head_mask=None, encoder_hidden_states=None, - encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): - mixed_query_layer = self.query(hidden_states) - query_layer = self.transpose_for_scores(mixed_query_layer) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
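As elsewhere in this patch, the `transpose_for_scores` helper disappears in favor of the inline `view(...).transpose(1, 2)` seen on the query projection above; the two spellings produce identical tensors, as this self-contained check shows:

```py
import torch

batch, seq, heads, head_dim = 2, 5, 4, 8
x = torch.randn(batch, seq, heads * head_dim)

old = x.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3)  # removed helper
new = x.view(batch, -1, heads, head_dim).transpose(1, 2)       # inlined form
assert torch.equal(old, new)  # [batch, heads, seq, head_dim] either way
```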
is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - if sinusoidal_pos is not None: - if self.rotary_value: - query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings( - sinusoidal_pos, query_layer, key_layer, value_layer - ) - else: - query_layer, key_layer = self.apply_rotary_position_embeddings( - sinusoidal_pos, query_layer, key_layer - ) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + if past_key_value is not None: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. 
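Note the `cache_position = cache_position if not is_cross_attention else None` line in the update path above: self-attention writes land at explicit positions that advance with each generated token, while cross-attention keys span the fixed encoder sequence and are written exactly once, so no positions apply. A sketch of what each call site would pass (the step index is an assumed example):

```py
import torch

step = 5  # writing K/V for the sixth generated token (0-indexed)

# self-attention update: tell the cache which slot the new states occupy
self_attn_cache_position = torch.tensor([step])

# cross-attention update: whole encoder sequence written once, positions unused
cross_attn_cache_position = None
```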
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -291,11 +296,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs @staticmethod def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None): @@ -341,9 +342,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RoFormerAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() - self.self = RoFormerSelfAttention(config) + self.self = RoFormerSelfAttention(config, layer_idx=layer_idx) self.output = RoFormerSelfOutput(config) self.pruned_heads = set() @@ -374,9 +375,9 @@ def forward( sinusoidal_pos=None, head_mask=None, encoder_hidden_states=None, - encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): self_outputs = self.self( hidden_states, @@ -384,9 +385,9 @@ def forward( sinusoidal_pos, head_mask, encoder_hidden_states, - encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -425,17 +426,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RoFormerLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RoFormerAttention(config) + self.attention = RoFormerAttention(config, layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = RoFormerAttention(config) + self.crossattention = RoFormerAttention(config, layer_idx) self.intermediate = RoFormerIntermediate(config) self.output = RoFormerOutput(config) @@ -449,27 +450,20 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - sinusoidal_pos, - head_mask, + attention_mask=attention_mask, + sinusoidal_pos=sinusoidal_pos, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not 
hasattr(self, "crossattention"): raise ValueError( @@ -477,35 +471,23 @@ def forward( "layers by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - sinusoidal_pos, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + sinusoidal_pos=sinusoidal_pos, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) - outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs + return (layer_output,) + outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) @@ -520,7 +502,7 @@ def __init__(self, config): self.embed_positions = RoFormerSinusoidalPositionalEmbedding( config.max_position_embeddings, config.hidden_size // config.num_attention_heads ) - self.layer = nn.ModuleList([RoFormerLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([RoFormerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -535,6 +517,7 @@ def forward( output_attentions=False, output_hidden_states=False, return_dict=True, + cache_position=None, ): if self.gradient_checkpointing and self.training: if use_cache: @@ -542,22 +525,31 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head] sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1], past_key_values_length)[None, None, :, :] - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -566,13 +558,12 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -581,12 +572,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -595,7 +589,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -837,6 +831,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple[torch.Tensor]]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -862,8 +857,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) @@ -909,6 +909,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] @@ -1064,6 +1065,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: 
Optional[torch.Tensor] = None, **kwargs, ) -> Union[CausalLMOutputWithCrossAttentions, tuple[torch.Tensor]]: r""" @@ -1103,6 +1105,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] @@ -1130,15 +1133,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) - + layer_past[2:], - ) - return reordered_past - class RoFormerClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 81b529175cf6..4614eaff3e85 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -526,7 +526,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as @@ -720,7 +720,7 @@ def forward( r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. 
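Throughout these hunks, including the RWKV docstrings here, the legacy length probe `past_key_values[0][0].shape[-2]` gives way to `get_seq_length()`, which works for any cache class rather than only for tuples. A quick sketch of the equivalence (toy tensors, assumed shapes):

```py
import torch
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer_idx=0)

print(cache.get_seq_length())  # 4 -- same number the old shape[-2] probe read
```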
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 21403244ee68..d461ee6e3dca 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -1029,16 +1030,14 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -1050,45 +1049,42 @@ def forward( # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `encoder_hidden_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == encoder_hidden_states.shape[1] - ): + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, 
torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -1140,7 +1136,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped # Copied from transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeDenseActDense with NllbMoe->SeamlessM4T,DenseActDense->FeedForwardNetwork, d_model->hidden_size @@ -1204,7 +1200,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, @@ -1230,7 +1226,7 @@ def forward( class SeamlessM4TDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_attention_heads=None): + def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_attention_heads=None, layer_idx=None): super().__init__() decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim decoder_attention_heads = ( @@ -1243,6 +1239,7 @@ def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_atte num_heads=decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -1250,7 +1247,11 @@ def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_atte self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.cross_attention = SeamlessM4TAttention( - self.embed_dim, decoder_attention_heads, config.attention_dropout, is_decoder=True + self.embed_dim, + decoder_attention_heads, + config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, ) 
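The rewritten attention above follows the same recipe in every model touched by this patch: select the self- or cross-attention sub-cache, project key/value states only when the layer's slot is still empty, and flag cross-attention slots as updated so later decoding steps reuse them. A standalone sketch of that control flow, assuming `DynamicCache` and `EncoderDecoderCache` from `transformers.cache_utils` (shapes and variable names are illustrative):

```py
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
layer_idx = 0
# projected encoder states: [batch_size, num_heads, src_len, head_dim]
enc_k, enc_v = torch.randn(1, 4, 7, 16), torch.randn(1, 4, 7, 16)

for step in range(3):
    if cache.is_updated.get(layer_idx):
        # later steps: reuse the key/value states written on step 0
        k = cache.cross_attention_cache.key_cache[layer_idx]
        v = cache.cross_attention_cache.value_cache[layer_idx]
    else:
        # first step: store them once; cross-attention passes no cache_position
        k, v = cache.cross_attention_cache.update(enc_k, enc_v, layer_idx)
        cache.is_updated[layer_idx] = True
    assert k.shape[-2] == 7  # the source length never grows across steps
```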
self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) @@ -1265,9 +1266,10 @@ def forward( attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -1291,41 +1293,33 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.attn_dropout(hidden_states) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.cross_attention_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( + hidden_states, cross_attn_weights = self.cross_attention( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, attention_mask=encoder_attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.attn_dropout(hidden_states) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value += cross_attn_present_key_value - # Fully Connected residual = hidden_states @@ -1336,12 +1330,7 @@ def forward( hidden_states = residual + hidden_states - outputs = (hidden_states, present_key_value) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs + return hidden_states, self_attn_weights, cross_attn_weights ############ SUB-MODELS related code ################ @@ -1748,12 +1737,13 @@ def __init__( ) layers = [] - for _ in range(config.decoder_layers): + for i in range(config.decoder_layers): layers.append( SeamlessM4TDecoderLayer( config, decoder_attention_heads=config.decoder_attention_heads, decoder_ffn_dim=config.decoder_ffn_dim, + layer_idx=i, ) ) self.layers = nn.ModuleList(layers) @@ -1782,6 +1772,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1803,12 +1794,28 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # 
past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) @@ -1827,18 +1834,10 @@ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) @@ -1849,27 +1848,22 @@ if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[1],) - if output_attentions: - all_self_attns += (layer_outputs[2],) + all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[3],) + all_cross_attentions += (layer_outputs[2],) hidden_states = self.layer_norm(hidden_states) @@ -1877,16 +1871,18 @@ if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1930,6 +1926,7 @@ def forward( 
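As the new deprecation warning above suggests, callers holding a legacy tuple can convert it once up front rather than triggering the warning on every forward. A sketch of the round trip, assuming the old per-layer layout the removed comments describe (self-attention key/value states at positions 1 and 2, cross-attention at 3 and 4):

```py
import torch
from transformers.cache_utils import EncoderDecoderCache

# one decoder layer's worth of legacy states
legacy = (
    (
        torch.randn(1, 4, 3, 16),  # self-attention keys
        torch.randn(1, 4, 3, 16),  # self-attention values
        torch.randn(1, 4, 7, 16),  # cross-attention keys
        torch.randn(1, 4, 7, 16),  # cross-attention values
    ),
)
cache = EncoderDecoderCache.from_legacy_cache(legacy)
assert cache.get_seq_length() == 3  # length of the self-attention stream

# round trip for call sites that still expect tuples
legacy_again = cache.to_legacy_cache()
assert torch.equal(legacy_again[0][0], legacy[0][0])
```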
output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1967,6 +1964,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -2055,6 +2053,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -2086,6 +2085,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) @@ -2114,16 +2114,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - def _tie_weights(self) -> None: if getattr(self.config, "tie_word_embeddings", True): output_embeddings = self.get_output_embeddings() @@ -2733,16 +2723,6 @@ def generate( **kwargs, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -3010,16 +2990,6 @@ def generate( **kwargs, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -3093,6 +3063,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -3153,6 +3124,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(decoder_outputs[0]) @@ -3347,16 +3319,6 @@ def generate( return waveform, waveform_lengths - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always 
the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -3694,16 +3656,6 @@ def generate( return waveform, waveform_lengths - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -4144,16 +4096,6 @@ def generate( return waveform, waveform_lengths - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - __all__ = [ "SeamlessM4TForTextToSpeech", diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 2210e1426dd2..3f4595eeee62 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module @@ -901,53 +902,56 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, projection: torch.Tensor) -> torch.Tensor: - new_projection_shape = projection.size()[:-1] + (self.num_heads, self.head_dim) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection - def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" is_cross_attention = encoder_hidden_states is not None batch_size, seq_length = hidden_states.shape[:2] - # use encoder_hidden_states if cross attention - current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - # checking that the `sequence_length` of the `past_key_value` is the same as the he provided - # `encoder_hidden_states` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + 
curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = self._shape(self.k_proj(current_states)) - value_states = self._shape(self.v_proj(current_states)) - if past_key_value is not None and not is_cross_attention: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - query_states = self._shape(self.q_proj(hidden_states) * self.scaling) - attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + query_states = self.q_proj(hidden_states) + query_states = query_states.reshape(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) if attention_mask is not None: attention_scores = attention_scores + attention_mask @@ -962,10 +966,7 @@ def forward( context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) attn_output = self.out_proj(context_states) - if output_attentions: - return attn_output, attn_weights, past_key_value - else: - return attn_output, None, past_key_value + return attn_output, attn_weights # Copied from transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeDenseActDense with NllbMoe->SeamlessM4Tv2,DenseActDense->FeedForwardNetwork, d_model->hidden_size @@ -1030,7 +1031,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, @@ -1057,7 +1058,9 @@ def forward( # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TDecoderLayer with SeamlessM4T->SeamlessM4Tv2 class SeamlessM4Tv2DecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_attention_heads=None): + def __init__( + self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_attention_heads=None, layer_idx=None + ): super().__init__() decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim decoder_attention_heads = ( @@ -1070,6 +1073,7 @@ def __init__(self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_at num_heads=decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -1077,7 +1081,11 @@ def __init__(self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_at self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.cross_attention = SeamlessM4Tv2Attention( - self.embed_dim, decoder_attention_heads, config.attention_dropout, is_decoder=True + self.embed_dim, + decoder_attention_heads, + config.attention_dropout, + is_decoder=True, + layer_idx=layer_idx, ) self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) @@ -1092,9 +1100,10 @@ def forward( attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -1118,41 +1127,33 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # 
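The `cache_position` argument threaded through these layers is simply the absolute indices that the current tokens occupy in the self-attention cache; cross-attention forwards `None` because the encoder states are written once as a whole block. A minimal sketch of how a caller might derive it:

```py
import torch

past_len = 5                      # tokens already in the self-attention cache
input_ids = torch.tensor([[42]])  # one new token for this decoding step

cache_position = torch.arange(past_len, past_len + input_ids.shape[1])
print(cache_position)  # tensor([5])
```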
add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.attn_dropout(hidden_states) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.cross_attention_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attention( + hidden_states, cross_attn_weights = self.cross_attention( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, attention_mask=encoder_attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.attn_dropout(hidden_states) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value += cross_attn_present_key_value - # Fully Connected residual = hidden_states @@ -1163,12 +1164,7 @@ def forward( hidden_states = residual + hidden_states - outputs = (hidden_states, present_key_value) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs + return hidden_states, self_attn_weights, cross_attn_weights class SeamlessM4Tv2TextToUnitDecoderLayer(GradientCheckpointingLayer): @@ -1220,7 +1216,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, @@ -1246,12 +1242,7 @@ def forward( hidden_states = residual + hidden_states hidden_states = self.conv_layer_norm(hidden_states) - outputs = (hidden_states, present_key_value) - - if output_attentions: - outputs += self_attn_weights - - return outputs + return hidden_states, self_attn_weights ############ SUB-MODELS related code ################ @@ -1786,12 +1777,13 @@ def __init__( ) layers = [] - for _ in range(config.decoder_layers): + for i in range(config.decoder_layers): layers.append( SeamlessM4Tv2DecoderLayer( config, decoder_attention_heads=config.decoder_attention_heads, decoder_ffn_dim=config.decoder_ffn_dim, + layer_idx=i, ) ) self.layers = nn.ModuleList(layers) @@ -1820,6 +1812,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1841,12 +1834,28 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if 
inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) @@ -1865,18 +1874,10 @@ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) @@ -1887,27 +1888,23 @@ if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[1],) - if output_attentions: - all_self_attns += (layer_outputs[2],) + all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[3],) + all_cross_attentions += (layer_outputs[2],) hidden_states = self.layer_norm(hidden_states) @@ -1915,16 +1912,18 @@ if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -2937,16 +2936,6 @@ def generate( **kwargs, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # 
cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -3224,17 +3213,6 @@ def generate( **kwargs, ) - @staticmethod - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -3317,6 +3295,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -3377,6 +3356,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(decoder_outputs[0]) @@ -3602,17 +3582,6 @@ def generate( return waveform, waveform_lengths - @staticmethod - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -3989,17 +3958,6 @@ def generate( return waveform, waveform_lengths - @staticmethod - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - @auto_docstring( custom_intro=""" @@ -4484,17 +4442,6 @@ def generate( return waveform, waveform_lengths - @staticmethod - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past - __all__ = [ "SeamlessM4Tv2ForTextToSpeech", diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index 81c220446103..b998e0546d8f 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ b/src/transformers/models/segformer/modeling_segformer.py @@ -152,11 +152,6 @@ def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ ) self.layer_norm = 
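Every `_reorder_cache` override deleted in this patch performed the same per-layer `index_select` along the beam dimension, which is why they can all be dropped in favor of a single generic implementation on the cache classes. The underlying operation, as a plain-tensor sketch:

```py
import torch

# per-layer cached states: [num_beams, num_heads, seq_len, head_dim]
key_cache = [torch.randn(2, 4, 3, 16)]
value_cache = [torch.randn(2, 4, 3, 16)]

beam_idx = torch.tensor([1, 0])  # beam search decided to swap the two beams
key_cache = [k.index_select(0, beam_idx) for k in key_cache]
value_cache = [v.index_select(0, beam_idx) for v in value_cache]
```

The removed comments note that cached cross-attention states are identical across beams, so reordering them would be a no-op.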
nn.LayerNorm(hidden_size) - def transpose_for_scores(self, hidden_states): - new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - hidden_states = hidden_states.view(new_shape) - return hidden_states.permute(0, 2, 1, 3) - def forward( self, hidden_states, @@ -164,7 +159,12 @@ def forward( width, output_attentions=False, ): - query_layer = self.transpose_for_scores(self.query(hidden_states)) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) if self.sr_ratio > 1: batch_size, seq_len, num_channels = hidden_states.shape @@ -176,8 +176,16 @@ def forward( hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1) hidden_states = self.layer_norm(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index b3f4388cb53c..b21a2a12cccf 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -299,7 +299,6 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index da584a63fc47..afa85c915bd3 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -292,8 +292,7 @@ class SmolLM3PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 2fd0776edf71..a0da6da35076 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -57,7 +57,7 @@ class SmolVLMPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True + _supports_attention_backend = True def _init_weights(self, module): diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 4c6e91ffeeac..f5451a0d1e25 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -503,9 +503,5 @@ def resize_token_embeddings(self, *args, 
**kwargs): " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" ) - def _reorder_cache(self, past_key_values, beam_idx): - # apply decoder cache reordering here - return self.decoder._reorder_cache(past_key_values, beam_idx) - __all__ = ["SpeechEncoderDecoderModel"] diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index b304f457beb3..ae8a1595b86c 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, @@ -213,11 +214,12 @@ def __init__( self, embed_dim: int, num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_causal: bool = False, + dropout: Optional[float] = 0.0, + is_decoder: Optional[bool] = False, + bias: Optional[bool] = True, + is_causal: Optional[bool] = False, config: Optional[Speech2TextConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -234,6 +236,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -244,10 +247,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -268,42 +272,35 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = 
self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -325,7 +322,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT @@ -368,7 +365,7 @@ def forward( """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -389,18 +386,13 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights # copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT # TODO: change copy when applying cache class class Speech2TextDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: Speech2TextConfig): + def __init__(self, config: Speech2TextConfig, 
layer_idx=None): super().__init__() self.embed_dim = config.d_model @@ -411,6 +403,7 @@ def __init__(self, config: Speech2TextConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -423,6 +416,7 @@ def __init__(self, config: Speech2TextConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -438,9 +432,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -464,42 +459,35 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -514,9 +502,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -766,7 +751,9 @@ def __init__(self, config: Speech2TextConfig): self.padding_idx, ) - self.layers = nn.ModuleList([Speech2TextDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [Speech2TextDecoderLayer(config, layer_idx=i) for i in 
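Threading `layer_idx` from the `ModuleList` construction down into each attention module, as done here and in the SeamlessM4T decoders above, is what lets every layer address its own slot of the shared cache. A toy sketch of the pattern, assuming `DynamicCache` from `transformers.cache_utils` (the module is illustrative, not the model's real attention):

```py
import torch
from torch import nn
from transformers.cache_utils import DynamicCache

class ToyAttention(nn.Module):
    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx  # the cache slot this layer owns

    def forward(self, k, v, cache: DynamicCache):
        # append this step's key/value states to this layer's slot
        return cache.update(k, v, self.layer_idx)

layers = nn.ModuleList(ToyAttention(i) for i in range(2))
cache = DynamicCache()
for layer in layers:
    layer(torch.randn(1, 4, 1, 16), torch.randn(1, 4, 1, 16), cache)
assert cache.get_seq_length() == 1  # one token cached in every layer
```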
range(config.decoder_layers)] + ) self.layer_norm = nn.LayerNorm(config.d_model) @@ -794,6 +781,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -878,12 +866,27 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = self._update_causal_mask( attention_mask, input_shape, @@ -903,18 +906,10 @@ def forward( hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -932,8 +927,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask, @@ -941,15 +934,13 @@ def forward( encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -961,16 +952,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -1093,6 +1086,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1185,6 +1179,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1249,6 +1244,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1326,6 +1322,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) @@ -1350,14 +1347,5 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return 
reordered_past
-
__all__ = ["Speech2TextForConditionalGeneration", "Speech2TextModel", "Speech2TextPreTrainedModel"]
diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py
index 8f46fd00c8f8..00655c40608f 100644
--- a/src/transformers/models/speecht5/modeling_speecht5.py
+++ b/src/transformers/models/speecht5/modeling_speecht5.py
@@ -24,6 +24,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss

 from ...activations import ACT2FN
+from ...cache_utils import Cache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...integrations.fsdp import is_fsdp_managed_module
@@ -809,7 +810,7 @@ def forward(
         self,
         input_ids: torch.Tensor,
         attention_mask: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[list[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
     ):
         if input_ids is not None:
             input_shape = input_ids.size()
@@ -817,7 +818,14 @@
         else:
             raise ValueError("You have to specify `decoder_input_ids`")

-        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        past_key_values_length = 0
+        if past_key_values is not None:
+            past_key_values_length = (
+                past_key_values[0][0].shape[-2]
+                if not isinstance(past_key_values, Cache)
+                else past_key_values.get_seq_length()
+            )
+
         positions = self.embed_positions(input_ids, past_key_values_length)

         inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
@@ -853,9 +861,10 @@ def __init__(
         self,
         embed_dim: int,
         num_heads: int,
-        dropout: float = 0.0,
-        is_decoder: bool = False,
-        bias: bool = True,
+        dropout: Optional[float] = 0.0,
+        is_decoder: Optional[bool] = False,
+        bias: Optional[bool] = True,
+        layer_idx: Optional[int] = None,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -870,25 +879,24 @@
         )
         self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
+        self.layer_idx = layer_idx

         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         position_bias: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""

         # if key_value_states are provided this layer is used as a cross-attention layer
@@ -899,40 +907,44 @@
         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        if is_cross_attention and past_key_value is not None:
+
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated = past_key_value.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_states
from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) @@ -1000,7 +1012,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped class SpeechT5FeedForward(nn.Module): @@ -1065,7 +1077,7 @@ def forward( returned tensors for more detail. 
""" residual = hidden_states - hidden_states, attn_weights, _ = self.attention( + hidden_states, attn_weights = self.attention( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -1089,13 +1101,14 @@ def forward( class SpeechT5DecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: SpeechT5Config): + def __init__(self, config: SpeechT5Config, layer_idx=None): super().__init__() self.self_attn = SpeechT5Attention( embed_dim=config.hidden_size, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = nn.Dropout(config.hidden_dropout) self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -1105,6 +1118,7 @@ def __init__(self, config: SpeechT5Config): config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -1119,9 +1133,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ): """ Args: @@ -1144,43 +1159,36 @@ def forward( residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.dropout(hidden_states) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = self.dropout(hidden_states) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected hidden_states = hidden_states + self.feed_forward(hidden_states) hidden_states = self.final_layer_norm(hidden_states) @@ -1190,9 +1198,6 @@ def forward( if output_attentions: outputs 
+= (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -1481,7 +1486,7 @@ def __init__(self, config: SpeechT5Config): super().__init__(config) self.layerdrop = config.decoder_layerdrop - self.layers = nn.ModuleList([SpeechT5DecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([SpeechT5DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.gradient_checkpointing = False @@ -1501,6 +1506,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -1570,7 +1576,24 @@ def forward( input_shape = hidden_states.size()[:-1] - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, hidden_states, past_key_values_length @@ -1585,18 +1608,10 @@ def forward( synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1619,8 +1634,6 @@ def forward( if skip_the_layer and not synced_gpus: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask, @@ -1628,15 +1641,12 @@ def forward( encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) @@ -1646,17 +1656,19 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -1691,6 +1703,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: decoder_hidden_states = self.prenet(input_values, speaker_embeddings) @@ -1706,6 +1719,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) return outputs @@ -1743,6 +1757,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: decoder_hidden_states, attention_mask = self.prenet(input_values, attention_mask, past_key_values) @@ -1758,6 +1773,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) return outputs @@ -1789,6 +1805,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: outputs = self.wrapped_decoder( hidden_states=input_values, @@ 
-1802,6 +1819,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) return outputs @@ -1996,6 +2014,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`): @@ -2070,6 +2089,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, **decoder_args, ) @@ -2152,6 +2172,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, Seq2SeqLMOutput]: r""" input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -2247,6 +2268,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, + cache_position=cache_position, ) logits = self.text_decoder_postnet(outputs[0]) @@ -2272,15 +2294,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - def _generate_speech( model: SpeechT5PreTrainedModel, @@ -2485,6 +2498,7 @@ def forward( speaker_embeddings: Optional[torch.FloatTensor] = None, labels: Optional[torch.FloatTensor] = None, stop_labels: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, Seq2SeqSpectrogramOutput]: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -2567,6 +2581,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, + cache_position=cache_position, ) outputs_before_postnet, outputs_after_postnet, logits = self.speech_decoder_postnet(outputs[0]) @@ -2838,6 +2853,7 @@ def forward( speaker_embeddings: Optional[torch.FloatTensor] = None, labels: Optional[torch.FloatTensor] = None, stop_labels: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, Seq2SeqSpectrogramOutput]: r""" input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -2925,6 +2941,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, + cache_position=cache_position, ) _, spectrogram, logits = self.speech_decoder_postnet(outputs[0]) diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 3b4e8f560026..d11a1eb60d87 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -24,11 +24,7 @@ from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import ( - BaseModelOutput, - ModelOutput, - QuestionAnsweringModelOutput, -) +from ...modeling_outputs import BaseModelOutput, ModelOutput, QuestionAnsweringModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from 
...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 66551b53cbdc..9e9511202863 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -619,9 +619,8 @@ class StableLmPreTrainedModel(PreTrainedModel): _no_split_modules = ["StableLmDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True - _supports_cache_class = True _supports_sdpa = True - _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index c3f10cea1171..9e95bd88b8ee 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -296,8 +296,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index ce92e7b66bb7..56506bc7a235 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -233,7 +233,6 @@ def forward( return hidden_state, all_hidden_states -# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->SuperGlue class SuperGlueSelfAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() @@ -261,11 +260,6 @@ def __init__(self, config, position_embedding_type=None): self.is_decoder = config.is_decoder - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -273,58 +267,38 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) - # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
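+        # `current_states` below selects the encoder states for cross-attention and
+        # `hidden_states` for self-attention, folding the four legacy key/value
+        # branches into a single projection path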
is_cross_attention = encoder_hidden_states is not None
-
-        if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
-            key_layer = past_key_value[0]
-            value_layer = past_key_value[1]
-            attention_mask = encoder_attention_mask
-        elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
-            attention_mask = encoder_attention_mask
-        elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
-        else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-
-        use_cache = past_key_value is not None
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_layer, value_layer)
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+
+        batch_size = hidden_states.shape[0]
+        key_layer = (
+            self.key(current_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(current_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )

         # Take the dot product between "query" and "key" to get the raw attention scores.
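+        # query/key layers are shaped (batch, num_heads, seq_len, head_dim), so the
+        # matmul below produces (batch, num_heads, query_len, key_len) attention scores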
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( - -1, 1 - ) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r @@ -364,7 +338,7 @@ def forward( outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) if self.is_decoder: - outputs = outputs + (past_key_value,) + outputs = outputs + (None,) return outputs @@ -383,12 +357,12 @@ def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor: } -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->SuperGlue,BERT->SUPERGLUE class SuperGlueAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() self.self = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, ) self.output = SuperGlueSelfOutput(config) self.pruned_heads = set() @@ -417,18 +391,16 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index 4b16bc954dc4..de61e5b2d259 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -296,11 +296,6 @@ def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[ self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -309,11 +304,21 @@ def forward( output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: batch_size, dim, num_channels = hidden_states.shape - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = 
self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # cosine attention attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize( diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index d18c126fe4b2..14ec4791ac8d 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -467,11 +467,6 @@ def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[ self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -480,11 +475,21 @@ def forward( output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: batch_size, dim, num_channels = hidden_states.shape - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # cosine attention attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize( diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 63f1bfadeebe..7ede856b416b 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -765,7 +765,7 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): config_class = SwitchTransformersConfig base_model_prefix = "switch_transformers" supports_gradient_checkpointing = True - _supports_cache_class = True + _supports_static_cache = False _no_split_modules = ["SwitchTransformersBlock"] @@ -1720,38 +1720,6 @@ def _unpack_router_logits(self, router_outputs): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past_key_values - - reordered_decoder_past = () - for 
layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - "expected reordered_layer_past_states to have the same shape than layer_past_states, " - f"but got {reordered_layer_past_states[0].shape} and {layer_past_states[0].shape}" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - "expected layer_past_states to have the same length as reordered_layer_past_states, " - f"but got {len(layer_past_states)} and {len(reordered_layer_past_states)}" - ) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 216ea793fd87..23a43615b1e0 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -771,9 +771,8 @@ class T5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True - _supports_quantized_cache = False # enc-dec models don't support yet _supports_static_cache = True - _supports_cache_class = True + _no_split_modules = ["T5Block"] _keep_in_fp32_modules = ["wo"] @@ -1827,36 +1826,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" - ) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - @auto_docstring class T5EncoderModel(T5PreTrainedModel): diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index f37995ae32ce..2cd8798883fa 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ 
-584,8 +584,7 @@ class T5GemmaPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _supports_static_cache = True _supports_attention_backend = True _can_record_outputs = { diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 2629c82d78d1..8545bc1021c6 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -26,11 +26,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_tapas import TapasConfig @@ -279,7 +281,7 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs class TapasSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -297,12 +299,9 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states, @@ -312,36 +311,59 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
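+        # the cache branch below replaces the old tuple indexing: cross-attention
+        # key/value states are computed once from the encoder output and re-used,
+        # while self-attention states grow with each `curr_past_key_value.update(...)`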
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = ( + self.key(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) - if self.is_decoder: - past_key_value = (key_layer, value_layer) + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. 
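+        # during cached generation `key_layer` holds past and current states, so the
+        # scores below span the full key length while queries cover only the new tokens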
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -389,9 +411,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class TapasAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() - self.self = TapasSelfAttention(config) + self.self = TapasSelfAttention(config, layer_idx=layer_idx) self.output = TapasSelfOutput(config) self.pruned_heads = set() @@ -414,6 +436,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") # Copied from transformers.models.bert.modeling_bert.BertAttention.forward def forward( self, @@ -422,17 +445,19 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -471,17 +496,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class TapasLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = TapasAttention(config) + self.attention = TapasAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = TapasAttention(config) + self.crossattention = TapasAttention(config, layer_idx=layer_idx) self.intermediate = TapasIntermediate(config) self.output = TapasOutput(config) @@ -493,28 +518,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + 
cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -522,33 +540,23 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk @@ -562,7 +570,7 @@ class TapasEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([TapasLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([TapasLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -577,7 +585,16 @@ def forward( output_attentions=False, output_hidden_states=False, return_dict=True, + cache_position=None, ): + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): @@ -594,6 +611,7 @@ def forward( encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 0732ec7a3c64..3e74c55b8feb 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -432,7 +432,7 @@ def forward( attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer, BART->TIME_SERIES_TRANSFORMER @@ -475,7 +475,7 @@ def forward( returned tensors for more detail. """ residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( + hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -577,7 +577,7 @@ def forward( residual = hidden_states # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, @@ -594,7 +594,7 @@ def forward( if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, @@ -621,9 +621,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (past_key_value,) - return outputs @@ -1038,7 +1035,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1072,9 +1068,6 @@ def forward( ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1085,19 +1078,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: - next_cache = past_key_values.to_legacy_cache() + past_key_values = past_key_values.to_legacy_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return 
BaseModelOutputWithPastAndCrossAttentions(
         last_hidden_state=hidden_states,
-        past_key_values=next_cache,
+        past_key_values=past_key_values,
         hidden_states=all_hidden_states,
         attentions=all_self_attns,
         cross_attentions=all_cross_attentions,
diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py
index fdc0cae068a8..ca10d41d6ce6 100644
--- a/src/transformers/models/trocr/modeling_trocr.py
+++ b/src/transformers/models/trocr/modeling_trocr.py
@@ -23,6 +23,7 @@
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
+from ...cache_utils import Cache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import (
     _prepare_4d_attention_mask,
@@ -150,10 +151,11 @@
         num_heads: int,
         kdim: Optional[int] = None,
         vdim: Optional[int] = None,
-        dropout: float = 0.0,
-        is_decoder: bool = False,
-        bias: bool = True,
-        is_cross_attention: bool = False,
+        dropout: Optional[float] = 0.0,
+        is_decoder: Optional[bool] = False,
+        bias: Optional[bool] = True,
+        is_cross_attention: Optional[bool] = False,
+        layer_idx: Optional[int] = None,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -169,6 +171,7 @@
         )
         self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
+        self.layer_idx = layer_idx

         self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
         self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
@@ -176,17 +179,15 @@
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""

@@ -197,40 +198,44 @@
         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        if is_cross_attention and past_key_value is not None:
+
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated = past_key_value.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_states from cache
+                    curr_past_key_value = past_key_value.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_value.self_attention_cache
+            else:
+                curr_past_key_value = past_key_value
+
+        current_states = key_value_states if is_cross_attention else hidden_states
+        if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states =
torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) @@ -286,11 +291,11 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped class TrOCRDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: TrOCRConfig): + def __init__(self, config: TrOCRConfig, layer_idx=None): super().__init__() self.embed_dim = config.hidden_size @@ -300,6 +305,7 @@ def __init__(self, config: TrOCRConfig): num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -317,6 +323,7 @@ def __init__(self, config: TrOCRConfig): dropout=config.attention_dropout, is_decoder=True, is_cross_attention=True, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -332,9 +339,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_value: 
Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ): """ Args: @@ -357,15 +365,13 @@ def forward( residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -373,30 +379,24 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None - if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -412,9 +412,6 @@ def forward( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -470,7 +467,7 @@ def __init__(self, config: TrOCRConfig): else: self.layernorm_embedding = None - self.layers = nn.ModuleList([TrOCRDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([TrOCRDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -496,6 +493,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -581,8 +579,24 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." 
+ ) + use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -612,18 +626,10 @@ def forward( encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -642,8 +648,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask, @@ -651,15 +655,13 @@ def forward( encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -670,16 +672,18 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, @@ -756,6 +760,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -836,6 +841,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + 
cache_position=cache_position, ) logits = self.output_projection(outputs[0]) @@ -858,14 +864,5 @@ def forward( cross_attentions=outputs.cross_attentions, ) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["TrOCRForCausalLM", "TrOCRPreTrainedModel"] diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 23af62e4a1dc..a0848a667b48 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -254,7 +254,7 @@ class UdopPreTrainedModel(PreTrainedModel): config_class = UdopConfig base_model_prefix = "transformer" supports_gradient_checkpointing = True - _supports_cache_class = True + _supports_static_cache = False _keep_in_fp32_modules = ["wo"] @@ -1887,37 +1887,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache - def _reorder_cache(self, past_key_values, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past_key_values is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past_key_values - - reordered_decoder_past = () - for layer_past_states in past_key_values: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), - ) - - if reordered_layer_past_states[0].shape != layer_past_states[0].shape: - raise ValueError( - f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" - ) - if len(reordered_layer_past_states) != len(layer_past_states): - raise ValueError( - f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" - ) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - @auto_docstring class UdopEncoderModel(UdopPreTrainedModel): diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 4b9d96db2f21..62bbd9d7f7c6 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -507,7 +507,7 @@ class UMT5PreTrainedModel(PreTrainedModel): config_class = UMT5Config base_model_prefix = "transformer" supports_gradient_checkpointing = True - _supports_cache_class = True + _supports_static_cache = True _no_split_modules = ["UMT5Block"] _keep_in_fp32_modules = ["wo"] @@ -1388,15 +1388,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return 
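The TrOCR hunks above add a compatibility shim: legacy tuple caches are converted to `EncoderDecoderCache` on entry and converted back on exit when `return_legacy_cache` is set. A minimal sketch of that round trip with toy tensors (illustrative only, not part of the changeset):

```py
# Illustrative sketch with toy shapes [batch, heads, seq_len, head_dim]:
# 2 decoder layers, each with (self_k, self_v, cross_k, cross_v) states.
import torch
from transformers import EncoderDecoderCache

legacy = tuple(tuple(torch.randn(1, 4, 5, 8) for _ in range(4)) for _ in range(2))

cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(cache.get_seq_length())  # 5, replacing the old `past_key_values[0][0].shape[2]` probe

# `to_legacy_cache()` restores the tuple format for callers that still expect it
roundtrip = cache.to_legacy_cache()
assert torch.equal(roundtrip[0][0], legacy[0][0])
```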
diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py
index 23af62e4a1dc..a0848a667b48 100644
--- a/src/transformers/models/udop/modeling_udop.py
+++ b/src/transformers/models/udop/modeling_udop.py
@@ -254,7 +254,7 @@ class UdopPreTrainedModel(PreTrainedModel):
     config_class = UdopConfig
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
-    _supports_cache_class = True
+    _supports_static_cache = False
 
     _keep_in_fp32_modules = ["wo"]
 
@@ -1887,37 +1887,6 @@ def forward(
             encoder_attentions=encoder_outputs.attentions,
         )
 
-    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache
-    def _reorder_cache(self, past_key_values, beam_idx):
-        # if decoder past is not included in output
-        # speedy decoding is disabled and no need to reorder
-        if past_key_values is None:
-            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
-            return past_key_values
-
-        reordered_decoder_past = ()
-        for layer_past_states in past_key_values:
-            # get the correct batch idx from layer past batch dim
-            # batch dim of `past` is at 2nd position
-            reordered_layer_past_states = ()
-            for layer_past_state in layer_past_states:
-                # need to set correct `past` for each of the four key / value states
-                reordered_layer_past_states = reordered_layer_past_states + (
-                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
-                )
-
-            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
-                raise ValueError(
-                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
-                )
-            if len(reordered_layer_past_states) != len(layer_past_states):
-                raise ValueError(
-                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
-                )
-
-            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
-        return reordered_decoder_past
-
 
 @auto_docstring
 class UdopEncoderModel(UdopPreTrainedModel):
diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py
index 4b9d96db2f21..62bbd9d7f7c6 100644
--- a/src/transformers/models/umt5/modeling_umt5.py
+++ b/src/transformers/models/umt5/modeling_umt5.py
@@ -507,7 +507,7 @@ class UMT5PreTrainedModel(PreTrainedModel):
     config_class = UMT5Config
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
-    _supports_cache_class = True
+    _supports_static_cache = True
     _no_split_modules = ["UMT5Block"]
     _keep_in_fp32_modules = ["wo"]
 
@@ -1388,15 +1388,6 @@ def forward(
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return self._shift_right(labels)
 
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
-
 
 @auto_docstring
 class UMT5EncoderModel(UMT5PreTrainedModel):
diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index d1c2c0e00a81..195923642e57 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -338,7 +338,6 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = False,
diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index 21bd79613da1..62aaff87b709 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -343,7 +343,6 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = False,
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index 8ed85c94e53c..b7f55915f1a3 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -131,10 +131,10 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["VideoLlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
-    _supports_cache_class = True
+    _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_quantized_cache = True
+    _supports_static_cache = True
     _supports_attention_backend = True
diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py
index 39784ca889ce..d7c93bdda0d9 100755
--- a/src/transformers/models/videomae/modeling_videomae.py
+++ b/src/transformers/models/videomae/modeling_videomae.py
@@ -239,22 +239,18 @@ def __init__(self, config: VideoMAEConfig) -> None:
             self.q_bias = None
             self.v_bias = None
 
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(
         self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
     ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
+        batch_size, seq_length, _ = hidden_states.shape
         k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
         keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
         values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
         queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)
 
-        key_layer = self.transpose_for_scores(keys)
-        value_layer = self.transpose_for_scores(values)
-        query_layer = self.transpose_for_scores(queries)
+        key_layer = keys.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+        value_layer = values.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+        query_layer = queries.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
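The same mechanical refactor repeats across the vision models that follow: the removed `transpose_for_scores` helper is inlined as a `view(...).transpose(1, 2)` chain. A quick equivalence check with toy dimensions (illustrative only, not part of the changeset):

```py
# Illustrative sketch: both formulations yield the same [batch, heads, seq, head_dim] layout.
import torch

batch_size, seq_len, num_heads, head_size = 2, 7, 3, 4
x = torch.randn(batch_size, seq_len, num_heads * head_size)

old = x.view(x.size()[:-1] + (num_heads, head_size)).permute(0, 2, 1, 3)  # removed helper
new = x.view(batch_size, -1, num_heads, head_size).transpose(1, 2)  # inlined form
assert torch.equal(old, new)
```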
diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py
index f54cc65822d8..2600605fc604 100755
--- a/src/transformers/models/vilt/modeling_vilt.py
+++ b/src/transformers/models/vilt/modeling_vilt.py
@@ -328,17 +328,23 @@ def __init__(self, config):
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
 
-    def transpose_for_scores(self, x):
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
-        mixed_query_layer = self.query(hidden_states)
-
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(mixed_query_layer)
+        batch_size, seq_length, _ = hidden_states.shape
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
 
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 94ced611b0f0..9a9c986562d0 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -118,10 +118,10 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
     base_model_prefix = ""
     supports_gradient_checkpointing = True
     _skip_keys_device_placement = "past_key_values"
-    _supports_cache_class = True
+    _supports_flash_attn = True
     _supports_sdpa = True
-    _supports_quantized_cache = True
+    _supports_static_cache = True
     _supports_flex_attn = True
     _supports_attention_backend = True
diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
index 6bda9bf1221d..70d3ccedee3e 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -597,9 +597,5 @@ def forward(
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    def _reorder_cache(self, past_key_values, beam_idx):
-        # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past_key_values, beam_idx)
-
 
 __all__ = ["VisionEncoderDecoderModel"]
diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py
index 305cc68a39ec..255406c6ce2f 100755
--- a/src/transformers/models/visual_bert/modeling_visual_bert.py
+++ b/src/transformers/models/visual_bert/modeling_visual_bert.py
@@ -193,11 +193,6 @@ def __init__(self, config):
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
 
-    def transpose_for_scores(self, x):
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(
         self,
         hidden_states,
@@ -205,12 +200,22 @@ def forward(
         head_mask=None,
         output_attentions=False,
     ):
-        mixed_query_layer = self.query(hidden_states)
-
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
+        batch_size, seq_length, _ = hidden_states.shape
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
 
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
@@ -1367,21 +1372,18 @@ def __init__(self, config):
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
 
-    def transpose_for_scores(self, x):
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(self, query, key, attention_mask):
+        batch_size, seq_length, _ = query.shape
         attention_mask = attention_mask.to(query.dtype)
         attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
         attention_mask = (1.0 - attention_mask) * torch.finfo(query.dtype).min
 
-        mixed_query_layer = self.query(query)
-        mixed_key_layer = self.key(key)
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-        key_layer = self.transpose_for_scores(mixed_key_layer)
+        query_layer = (
+            self.query(query).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+        )
+        key_layer = (
+            self.key(key).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+        )
 
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py
index 8e38f83cacaa..58738c006345 100644
--- a/src/transformers/models/vit/modeling_vit.py
+++ b/src/transformers/models/vit/modeling_vit.py
@@ -218,17 +218,28 @@ def __init__(self, config: ViTConfig) -> None:
         self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
         self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
 
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(
-        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+        self,
+        hidden_states,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
     ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(self.query(hidden_states))
+        batch_size, seq_length, _ = hidden_states.shape
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py
index ddd582eca2a0..31ee70ad3462 100755
--- a/src/transformers/models/vit_mae/modeling_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_vit_mae.py
@@ -378,17 +378,28 @@ def __init__(self, config: ViTMAEConfig) -> None:
         self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
         self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
 
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(
-        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+        self,
+        hidden_states,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
     ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(self.query(hidden_states))
+        batch_size, seq_length, _ = hidden_states.shape
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py
index 0c3da4fffa37..c3640dadef02 100644
--- a/src/transformers/models/vit_msn/modeling_vit_msn.py
+++ b/src/transformers/models/vit_msn/modeling_vit_msn.py
@@ -216,17 +216,28 @@ def __init__(self, config: ViTMSNConfig) -> None:
         self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
         self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
 
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(
-        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+        self,
+        hidden_states,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
    ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(self.query(hidden_states))
+        batch_size, seq_length, _ = hidden_states.shape
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
diff --git a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
index af2ca9825e48..86d85bb53bbe 100644
--- a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
+++ b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
@@ -148,17 +148,28 @@ def __init__(self, config: VitPoseBackboneConfig) -> None:
         self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
         self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
 
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(
-        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+        self,
+        hidden_states,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
     ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(self.query(hidden_states))
+        batch_size, seq_length, _ = hidden_states.shape
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py
index ca7c3046a52c..54f07e0be0a6 100755
--- a/src/transformers/models/vivit/modeling_vivit.py
+++ b/src/transformers/models/vivit/modeling_vivit.py
@@ -209,17 +209,28 @@ def __init__(self, config: VivitConfig) -> None:
         self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
         self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
 
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(
-        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+        self,
+        hidden_states,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
     ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(self.query(hidden_states))
+        batch_size, seq_length, _ = hidden_states.shape
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
diff --git a/src/transformers/models/vjepa2/modeling_vjepa2.py b/src/transformers/models/vjepa2/modeling_vjepa2.py
index 52cb2c7baa17..0a7326b916fa 100644
--- a/src/transformers/models/vjepa2/modeling_vjepa2.py
+++ b/src/transformers/models/vjepa2/modeling_vjepa2.py
@@ -243,14 +243,6 @@ def __init__(
         self.scaling = self.attention_head_size**-0.5
         self.is_causal = False
 
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads,
-            self.attention_head_size,
-        )
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def _get_frame_pos(self, ids):
         tokens_per_frame = int(self.grid_size * self.grid_size)
         return ids // tokens_per_frame
@@ -309,11 +301,22 @@ def forward(
         output_attentions: bool = False,
         head_mask: Optional[torch.Tensor] = None,
     ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
-        mixed_query_layer = self.query(hidden_states)
-
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(mixed_query_layer)
+        batch_size, seq_length, _ = hidden_states.shape
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
 
         pos_ids = self.get_position_ids(hidden_states, masks=position_mask)
         key_layer = self.apply_rotary_embeddings(key_layer, pos_ids)
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index 88ca7c254095..a5ee3378c119 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -530,7 +530,6 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = False,
diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index 1bed6ce27b46..f72ce7bd40ba 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -1152,7 +1152,7 @@ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None
                 for v in [cache_cls.key_cache, cache_cls.value_cache]:
                     layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu())
                 all_past_key_values.append(tuple(layer_past_key_values))
-            return tuple(all_past_key_values)
+            return EncoderDecoderCache.from_legacy_cache(tuple(all_past_key_values))
 
         else:
             all_past_key_values = []
            for v in range(len(values)):
@@ -1199,7 +1199,6 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
                     for i in range(len(seek_outputs[0][key]))
                 )
             elif key == "past_key_values":
-                past_key_value_type = kwargs.get("past_key_values")
                 if seek_outputs[0][key] is not None:
                     outputs[key] = tuple(
                         tuple(
@@ -1208,8 +1207,8 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
                         )
                         for i in range(len(seek_outputs[0][key]))
                     )
-                    if past_key_value_type is not None and isinstance(past_key_value_type, EncoderDecoderCache):
-                        outputs[key] = past_key_value_type.from_legacy_cache(outputs[key])
+                    if isinstance(seek_outputs[0][key], EncoderDecoderCache):
+                        outputs[key] = EncoderDecoderCache.from_legacy_cache(outputs[key])
                 else:
                     outputs[key] = None
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index 1d5c808517cc..a03353c68878 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -364,7 +364,7 @@ def forward(
         attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
         attn_output = self.out_proj(attn_output)
 
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
 
 
 # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER
@@ -407,7 +407,7 @@ def forward(
         """
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states, attn_weights, _ = self.self_attn(
+        hidden_states, attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             layer_head_mask=layer_head_mask,
@@ -428,12 +428,7 @@ def forward(
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
 
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs
+        return hidden_states, attn_weights
 
 
 class WhisperDecoderLayer(GradientCheckpointingLayer):
@@ -503,7 +498,7 @@ def forward(
         hidden_states = self.self_attn_layer_norm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=past_key_value,
             attention_mask=attention_mask,
@@ -519,7 +514,7 @@ def forward(
         if encoder_hidden_states is not None:
             residual = hidden_states
             hidden_states = self.encoder_attn_layer_norm(hidden_states)
-            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+            hidden_states, cross_attn_weights = self.encoder_attn(
                 hidden_states=hidden_states,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
@@ -530,9 +525,6 @@ def forward(
             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
             hidden_states = residual + hidden_states
 
-            # add cross-attn to positions 1 of present_key_value tuple
-            present_key_value = (present_key_value, cross_attn_present_key_value)
-
         # Fully Connected
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
@@ -547,9 +539,6 @@ def forward(
         if output_attentions:
             outputs += (self_attn_weights, cross_attn_weights)
 
-        if use_cache:
-            outputs += (present_key_value,)
-
         return outputs
 
 
@@ -563,7 +552,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _supports_cache_class = True
+    _supports_static_cache = True
 
     def _init_weights(self, module):
@@ -1520,15 +1509,6 @@ def forward(
             cross_attentions=outputs.cross_attentions,
         )
 
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
-
 
 @auto_docstring(
     custom_intro="""
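The decoder refactors above and below all share the same cross-attention bookkeeping: encoder key/value projections are computed once, stored in the cache, and flagged via `is_updated` so later steps skip the projection. A toy sketch of that flow (illustrative only, not part of the changeset):

```py
# Illustrative sketch of the `is_updated` flag for a single decoder layer.
import torch
from transformers import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
k = torch.randn(1, 2, 6, 4)  # projected encoder states: [batch, heads, src_len, head_dim]
v = torch.randn(1, 2, 6, 4)

# first decoding step: project the encoder states once and store them
cache.cross_attention_cache.update(k, v, layer_idx=0)
cache.is_updated[0] = True

# later steps skip the projection and read the stored tensors back
if cache.is_updated.get(0):
    k_reused = cache.cross_attention_cache.key_cache[0]
    v_reused = cache.cross_attention_cache.value_cache[0]
assert torch.equal(k_reused, k) and torch.equal(v_reused, v)
```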
diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index 65e3e6284046..15fc5c2178a3 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -22,6 +22,7 @@
 from torch import nn
 
 from ...activations import ACT2FN
+from ...cache_utils import Cache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_layers import GradientCheckpointingLayer
@@ -107,9 +108,10 @@ def __init__(
         self,
         embed_dim: int,
         num_heads: int,
-        dropout: float = 0.0,
-        is_decoder: bool = False,
-        bias: bool = True,
+        dropout: Optional[float] = 0.0,
+        is_decoder: Optional[bool] = False,
+        bias: Optional[bool] = True,
+        layer_idx: Optional[int] = None,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -124,23 +126,22 @@ def __init__(
             )
         self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
+        self.layer_idx = layer_idx
 
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
 
@@ -149,43 +150,48 @@ def forward(
         is_cross_attention = key_value_states is not None
 
         bsz, tgt_len, _ = hidden_states.size()
+        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
 
         # get query proj
         query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        if is_cross_attention and past_key_value is not None:
+
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated = past_key_value.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_states from cache
+                    curr_past_key_value = past_key_value.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_value.self_attention_cache
+            else:
+                curr_past_key_value = past_key_value
+
+        current_states = key_value_states if is_cross_attention else hidden_states
+        if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            key_states = curr_past_key_value.key_cache[self.layer_idx]
+            value_states = curr_past_key_value.value_cache[self.layer_idx]
         else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
+            key_states = self.k_proj(current_states)
+            value_states = self.v_proj(current_states)
+            key_states = key_states.view(bsz, src_len, -1, self.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, src_len, -1, self.head_dim).transpose(1, 2)
+
+            if past_key_value is not None:
+                # save all key/value_states to cache to be re-used for fast auto-regressive generation
+                cache_position = cache_position if not is_cross_attention else None
+                key_states, value_states = curr_past_key_value.update(
+                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+                )
+                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+                if is_cross_attention:
+                    past_key_value.is_updated[self.layer_idx] = True
 
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.reshape(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)
 
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
@@ -251,11 +257,11 @@ def forward(
 
         attn_output = self.out_proj(attn_output)
 
-        return attn_output, attn_weights_reshaped, past_key_value
+        return attn_output, attn_weights_reshaped
 
 
 class XGLMDecoderLayer(GradientCheckpointingLayer):
-    def __init__(self, config: XGLMConfig):
+    def __init__(self, config: XGLMConfig, layer_idx=None):
         super().__init__()
         self.embed_dim = config.d_model
@@ -264,6 +270,7 @@ def __init__(self, config: XGLMConfig):
             num_heads=config.attention_heads,
             dropout=config.attention_dropout,
             is_decoder=True,
+            layer_idx=layer_idx,
         )
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
@@ -275,6 +282,7 @@ def __init__(self, config: XGLMConfig):
                 num_heads=config.attention_heads,
                 dropout=config.attention_dropout,
                 is_decoder=True,
+                layer_idx=layer_idx,
             )
             self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
 
@@ -292,9 +300,10 @@ def forward(
         encoder_attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Args:
@@ -318,42 +327,35 @@ def forward(
         hidden_states = self.self_attn_layer_norm(hidden_states)
 
         # Self Attention
-        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
-        # add present self-attn cache to positions 1,2 of present_key_value tuple
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
-            past_key_value=self_attn_past_key_value,
+            past_key_value=past_key_value,
             attention_mask=attention_mask,
             layer_head_mask=layer_head_mask,
             output_attentions=output_attentions,
+            cache_position=cache_position,
         )
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
 
         # Cross-Attention Block
-        cross_attn_present_key_value = None
         cross_attn_weights = None
         if encoder_hidden_states is not None:
             residual = hidden_states
             hidden_states = self.encoder_attn_layer_norm(hidden_states)
 
-            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
-            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
-
-            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+            hidden_states, cross_attn_weights = self.encoder_attn(
                 hidden_states=hidden_states,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
                 layer_head_mask=cross_attn_layer_head_mask,
-                past_key_value=cross_attn_past_key_value,
+                past_key_value=past_key_value,
                 output_attentions=output_attentions,
+                cache_position=cache_position,
             )
             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
             hidden_states = residual + hidden_states
 
-            # add cross-attn to positions 3,4 of present_key_value tuple
-            present_key_value = present_key_value + cross_attn_present_key_value
-
         # Fully Connected
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
@@ -368,9 +370,6 @@ def forward(
         if output_attentions:
             outputs += (self_attn_weights, cross_attn_weights)
 
-        if use_cache:
-            outputs += (present_key_value,)
-
         return outputs
 
 
@@ -419,7 +418,7 @@ def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = No
             config.d_model,
             config.pad_token_id,
         )
-        self.layers = nn.ModuleList([XGLMDecoderLayer(config) for _ in range(config.num_layers)])
+        self.layers = nn.ModuleList([XGLMDecoderLayer(config, layer_idx=i) for i in range(config.num_layers)])
         self.layer_norm = nn.LayerNorm(config.d_model)
 
         self.gradient_checkpointing = False
@@ -448,6 +447,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
@@ -486,7 +486,31 @@ def forward(
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+                )
+                use_cache = False
+
+        # initialize `past_key_values`
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+            )
+            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, input_shape, inputs_embeds, past_key_values_length
+        )
 
         if position_ids is None:
             position_ids = torch.arange(
@@ -497,13 +521,6 @@ def forward(
             )
             position_ids = position_ids.unsqueeze(0)
 
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask, input_shape, inputs_embeds, past_key_values_length
-        )
-
         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -516,18 +533,10 @@ def forward(
         )
         hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
-        next_decoder_cache = () if use_cache else None
 
         # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
         for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
@@ -546,8 +555,6 @@ def forward(
             if dropout_probability < self.layerdrop:
                 continue
 
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask,
@@ -555,15 +562,13 @@ def forward(
                 encoder_attention_mask=encoder_attention_mask,
                 layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                 cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
-                past_key_value=past_key_value,
+                past_key_value=past_key_values,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                cache_position=cache_position,
             )
             hidden_states = layer_outputs[0]
 
-            if use_cache:
-                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
-
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
 
@@ -576,16 +581,18 @@ def forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            past_key_values = past_key_values.to_legacy_cache()
+
         if not return_dict:
             return tuple(
                 v
-                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
                 if v is not None
             )
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=next_cache,
+            past_key_values=past_key_values,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
             cross_attentions=all_cross_attentions,
@@ -639,6 +646,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
@@ -685,6 +693,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )
 
         logits = self.lm_head(outputs[0])
@@ -712,14 +721,5 @@ def forward(
             cross_attentions=outputs.cross_attentions,
         )
 
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
-
 
 __all__ = ["XGLMForCausalLM", "XGLMModel", "XGLMPreTrainedModel"]
config): self.layer_id = next(MultiHeadAttention.NEW_ID) self.dim = dim self.n_heads = n_heads + self.head_dim = dim // n_heads self.dropout = config.attention_dropout assert self.dim % self.n_heads == 0 @@ -525,50 +527,57 @@ def prune_heads(self, heads): self.dim = attention_head_size * self.n_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, input, mask, kv=None, cache=None, head_mask=None, output_attentions=False): + def forward( + self, + input, + mask, + kv=None, + cache=None, + head_mask=None, + output_attentions=False, + cache_position=None, + ): """ Self-attention (if kv is None) or attention over source sentence (provided by kv). """ # Input is (bs, qlen, dim) # Mask is (bs, klen) (non-causal) or (bs, klen, klen) bs, qlen, dim = input.size() - if kv is None: - klen = qlen if cache is None else cache["slen"] + qlen - else: - klen = kv.size(1) - # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' - n_heads = self.n_heads - dim_per_head = self.dim // n_heads - mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen) - - def shape(x): - """projection""" - return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) - - def unshape(x): - """compute context""" - return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) - - q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head) - if kv is None: - k = shape(self.k_lin(input)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v_lin(input)) # (bs, n_heads, qlen, dim_per_head) - elif cache is None or self.layer_id not in cache: - k = v = kv - k = shape(self.k_lin(k)) # (bs, n_heads, qlen, dim_per_head) - v = shape(self.v_lin(v)) # (bs, n_heads, qlen, dim_per_head) + is_cross_attention = kv is not None + mask_reshape = (bs, 1, qlen, -1) if mask.dim() == 3 else (bs, 1, 1, -1) + q = self.q_lin(input).view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) if cache is not None: - if self.layer_id in cache: - if kv is None: - k_, v_ = cache[self.layer_id] - k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head) - v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head) + if isinstance(cache, EncoderDecoderCache): + is_updated = cache.is_updated.get(self.layer_id) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = cache.cross_attention_cache else: - k, v = cache[self.layer_id] - cache[self.layer_id] = (k, v) + curr_past_key_value = cache.self_attention_cache + else: + curr_past_key_value = cache - q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head) + current_states = kv if is_cross_attention else input + if is_cross_attention and cache is not None and is_updated: + # reuse k,v, cross_attentions + k = curr_past_key_value.key_cache[self.layer_id] + v = curr_past_key_value.value_cache[self.layer_id] + else: + k = self.k_lin(current_states) + v = self.v_lin(current_states) + k = k.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) + v = v.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2) + + if cache is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + k, v = curr_past_key_value.update(k, v, self.layer_id, {"cache_position": cache_position}) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + 
cache.is_updated[self.layer_id] = True + + q = q / math.sqrt(self.head_dim) # (bs, n_heads, qlen, head_dim) scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen) mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen) scores.masked_fill_(mask, torch.finfo(scores.dtype).min) # (bs, n_heads, qlen, klen) @@ -580,8 +589,8 @@ def unshape(x): if head_mask is not None: weights = weights * head_mask - context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) - context = unshape(context) # (bs, qlen, dim) + context = torch.matmul(weights, v) # (bs, n_heads, qlen, head_dim) + context = context.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.head_dim) outputs = (self.out_lin(context),) if output_attentions: @@ -785,6 +794,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, # Dummy kwargs for now ) -> Union[tuple, BaseModelOutput]: r""" @@ -821,45 +831,38 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device + if not isinstance(cache, Cache): + cache = EncoderDecoderCache.from_legacy_cache(cache) + if lengths is None: if input_ids is not None: lengths = (input_ids != self.pad_index).sum(dim=1).long() else: lengths = torch.tensor([slen] * bs, device=device) - # mask = input_ids != self.pad_index # check inputs assert lengths.size(0) == bs assert lengths.max().item() <= slen - # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 - # assert (src_enc is None) == (src_len is None) - # if src_enc is not None: - # assert self.is_decoder - # assert src_enc.size(0) == bs # generate masks mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask) - # if self.is_decoder and src_enc is not None: - # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] # position_ids if position_ids is None: position_ids = self.position_ids[:, :slen] else: assert position_ids.size() == (bs, slen) # (slen, bs) - # position_ids = position_ids.transpose(0, 1) # langs if langs is not None: assert langs.size() == (bs, slen) # (slen, bs) - # langs = langs.transpose(0, 1) # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.n_layers) # do not recompute cached elements if cache is not None and input_ids is not None: - _slen = slen - cache["slen"] + _slen = slen - cache.get_seq_length() input_ids = input_ids[:, -_slen:] position_ids = position_ids[:, -_slen:] if langs is not None: @@ -894,6 +897,7 @@ def forward( cache=cache, head_mask=head_mask[i], output_attentions=output_attentions, + cache_position=cache_position, ) attn = attn_outputs[0] if output_attentions: @@ -902,13 +906,6 @@ def forward( tensor = tensor + attn tensor = self.layer_norm1[i](tensor) - # encoder attention (for decoder only) - # if self.is_decoder and src_enc is not None: - # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache) - # attn = nn.functional.dropout(attn, p=self.dropout, training=self.training) - # tensor = tensor + attn - # tensor = self.layer_norm15[i](tensor) - # FFN tensor = tensor + self.ffns[i](tensor) tensor = self.layer_norm2[i](tensor) @@ -918,13 +915,6 @@ def forward( if output_hidden_states: hidden_states = hidden_states + (tensor,) - # update cache length - if cache is not None: - cache["slen"] += tensor.size(1) - - # move back sequence length to dimension 
0 - # tensor = tensor.transpose(0, 1) - if not return_dict: return tuple(v for v in [tensor, hidden_states, attentions] if v is not None) return BaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions) @@ -1026,6 +1016,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[tuple, MaskedLMOutput]: r""" @@ -1068,6 +1059,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, **kwargs, ) diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index f66396b135c5..43c6680da6a2 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer @@ -41,6 +42,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_xlm_roberta import XLMRobertaConfig @@ -139,7 +141,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->XLMRoberta class XLMRobertaSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -164,12 +166,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -177,53 +176,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, 
the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. 
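Reviewer aside: the hunk above folds the removed `transpose_for_scores` helper into inline `view(...).transpose(1, 2)` calls. A minimal standalone check of the equivalence, with illustrative sizes not tied to any checkpoint:

```py
import torch

batch_size, seq_length, num_heads, head_size = 2, 5, 4, 8
hidden_states = torch.randn(batch_size, seq_length, num_heads * head_size)

# Old helper: view to (bs, seq, heads, head_size), then move heads forward.
def transpose_for_scores(x):
    new_shape = x.size()[:-1] + (num_heads, head_size)
    return x.view(new_shape).permute(0, 2, 1, 3)

# New inlined form used throughout the diff.
inlined = hidden_states.view(batch_size, -1, num_heads, head_size).transpose(1, 2)

# Both produce (bs, num_heads, seq, head_size) with identical contents.
assert torch.equal(transpose_for_scores(hidden_states), inlined)
```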
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -265,21 +276,18 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->XLMRoberta class XLMRobertaSdpaSelfAttention(XLMRobertaSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from XLMRobertaSelfAttention + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -287,8 +295,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. @@ -307,38 +316,59 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. 
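Reviewer aside: the eager and SDPA paths above now share one pattern: cross-attention key/value states are projected once, stored in `EncoderDecoderCache.cross_attention_cache`, and flagged via `is_updated` so later decoding steps read them back instead of recomputing. A minimal sketch of that pattern outside the model, with toy shapes:

```py
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

bsz, n_heads, enc_len, head_dim, layer_idx = 1, 2, 7, 4, 0
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

# First decoder step: project encoder states once and store them.
enc_k = torch.randn(bsz, n_heads, enc_len, head_dim)
enc_v = torch.randn(bsz, n_heads, enc_len, head_dim)
if not cache.is_updated.get(layer_idx):
    cache.cross_attention_cache.update(enc_k, enc_v, layer_idx)
    cache.is_updated[layer_idx] = True

# Subsequent steps: reuse the stored tensors; the encoder length never grows.
key_layer = cache.cross_attention_cache.key_cache[layer_idx]
value_layer = cache.cross_attention_cache.value_cache[layer_idx]
assert key_layer.shape[-2] == enc_len and value_layer.shape[-2] == enc_len
```

This is also why the diff passes the whole `past_key_values` object down to every layer instead of slicing out a per-layer tuple: the cache itself is indexed by `layer_idx`.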
@@ -368,10 +398,7 @@ def forward( attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->XLMRoberta @@ -397,10 +424,12 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->XLMRoberta,ROBERTA->XLM_ROBERTA class XLMRobertaAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = XLM_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = XLMRobertaSelfOutput(config) self.pruned_heads = set() @@ -423,6 +452,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -430,17 +460,19 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -480,17 +512,17 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->XLMRoberta class XLMRobertaLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = XLMRobertaAttention(config) + self.attention = XLMRobertaAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = XLMRobertaAttention(config, position_embedding_type="absolute") + self.crossattention = XLMRobertaAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = XLMRobertaIntermediate(config) self.output = XLMRobertaOutput(config) @@ -501,28 +533,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: 
Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -530,33 +555,23 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -567,10 +582,10 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->XLMRoberta class XLMRobertaEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([XLMRobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([XLMRobertaLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -585,6 +600,7 @@ def forward( output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, 
return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -597,13 +613,21 @@ def forward( ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -611,13 +635,12 @@ def forward( layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -626,12 +649,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -640,7 +666,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -746,6 +772,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -771,8 +798,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): @@ -857,6 +889,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = 
encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -998,14 +1031,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 7cbeaadb184c..7e3592847bdb 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer @@ -40,6 +41,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_xlm_roberta_xl import XLMRobertaXLConfig @@ -136,7 +138,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->XLMRobertaXL class XLMRobertaXLSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -161,12 +163,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -174,53 +173,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # 
such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask - if is_cross_attention and past_key_value is not None: + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. 
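Reviewer aside: each rewritten attention `forward` is wrapped with `@deprecate_kwarg("encoder_attention_mask", version="4.55.0")`, since the encoder mask is now folded into `attention_mask` at the top of the method. A hedged sketch of the observable effect on a hypothetical standalone function (the exact warning text and version gating are internal to `transformers.utils.deprecation`):

```py
from transformers.utils.deprecation import deprecate_kwarg

# Hypothetical function decorated the same way as the attention layers above.
@deprecate_kwarg("encoder_attention_mask", version="4.55.0")
def attend(hidden_states, attention_mask=None, encoder_attention_mask=None):
    # Mirror the diff: a provided encoder mask simply takes over attention_mask.
    if encoder_attention_mask is not None:
        attention_mask = encoder_attention_mask
    return attention_mask

# Passing the deprecated kwarg still works but is expected to emit a
# FutureWarning until the kwarg is removed in the stated version.
attend("h", encoder_attention_mask="enc-mask")
```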
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -262,21 +273,18 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs # Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->XLMRobertaXL class XLMRobertaXLSdpaSelfAttention(XLMRobertaXLSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from XLMRobertaXLSelfAttention + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -284,8 +292,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. @@ -304,38 +313,59 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention # mask needs to be such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. 
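Reviewer aside: note how the hunks above forward `cache_position` to `Cache.update` only for self-attention (`cache_position if not is_cross_attention else None`); cross-attention entries are written once and have no growing position to track. A minimal sketch of the positions a caller would supply across prefill and one decode step, assuming the usual convention from `generate`:

```py
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
prompt_len = 6

# Prefill: positions span the whole prompt.
cache_position = torch.arange(prompt_len)
k = v = torch.randn(1, 2, prompt_len, 4)
cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": cache_position})

# One decode step: a single position, offset by what the cache already holds.
past = cache.get_seq_length()
cache_position = torch.arange(past, past + 1)
k_new = v_new = torch.randn(1, 2, 1, 4)
cache.update(k_new, v_new, layer_idx=0, cache_kwargs={"cache_position": cache_position})

assert cache.get_seq_length() == prompt_len + 1
```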
@@ -365,10 +395,7 @@ def forward( attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None class XLMRobertaXLSelfOutput(nn.Module): @@ -391,11 +418,13 @@ def forward(self, hidden_states, input_tensor): class XLMRobertaXLAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.self = XLMROBERTAXL_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = XLMRobertaXLSelfOutput(config) self.pruned_heads = set() @@ -418,6 +447,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states, @@ -427,6 +457,7 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): intermediate = self.self_attn_layer_norm(hidden_states) self_outputs = self.self( @@ -437,6 +468,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -471,17 +503,19 @@ def forward(self, hidden_states, input_tensor): class XLMRobertaXLLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = XLMRobertaXLAttention(config) + self.attention = XLMRobertaXLAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = XLMRobertaXLAttention(config, position_embedding_type="absolute") + self.crossattention = XLMRobertaXLAttention( + config, position_embedding_type="absolute", layer_idx=layer_idx + ) self.intermediate = XLMRobertaXLIntermediate(config) self.output = XLMRobertaXLOutput(config) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -495,26 +529,19 @@ def forward( encoder_attention_mask=None, past_key_value=None, output_attentions=False, + cache_position=None, ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache 
- if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -522,34 +549,22 @@ def forward( " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) - outputs = (layer_output,) + outputs - - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - - return outputs + return (layer_output,) + outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.LayerNorm(attention_output) @@ -562,7 +577,7 @@ class XLMRobertaXLEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([XLMRobertaXLLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([XLMRobertaXLLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False @@ -578,6 +593,7 @@ def forward( output_attentions=False, output_hidden_states=False, return_dict=True, + cache_position=None, ): if self.gradient_checkpointing and self.training: if use_cache: @@ -585,17 +601,26 @@ def forward( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -603,13 +628,12 @@ def forward( layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -620,12 +644,15 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -634,7 +661,7 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -738,6 +765,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -763,8 +791,13 @@ def forward( batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): @@ -849,6 +882,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None @@ -994,7 +1028,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + past_length = past_key_values.get_seq_length() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -1014,14 +1048,6 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti "past_key_values": past_key_values, } - def _reorder_cache(self, 
past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - @auto_docstring class XLMRobertaXLForMaskedLM(XLMRobertaXLPreTrainedModel): diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 84cf9f6d5349..6266ec88f545 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( @@ -38,6 +39,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_xmod import XmodConfig @@ -136,7 +138,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Xmod class XmodSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -161,12 +163,9 @@ def __init__(self, config, position_embedding_type=None): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -174,53 +173,65 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
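Reviewer aside: the model-level `forward` in all three files (and `prepare_inputs_for_generation` in the XL file above) now derives the past length from either a legacy tuple or a `Cache` object; X-MOD repeats the same dual path further below. The branch is small enough to check in isolation, as in this sketch mirroring the diff:

```py
import torch
from transformers.cache_utils import Cache, DynamicCache

def past_length(past_key_values):
    # Legacy tuples expose the length through the key tensor's sequence
    # dimension; Cache objects expose it via get_seq_length().
    if past_key_values is None:
        return 0
    if isinstance(past_key_values, Cache):
        return past_key_values.get_seq_length()
    return past_key_values[0][0].shape[-2]

cache = DynamicCache()
cache.update(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4), layer_idx=0)

assert past_length(cache) == 3
assert past_length(cache.to_legacy_cache()) == 3
assert past_length(None) == 0
```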
is_cross_attention = encoder_hidden_states is not None + if is_cross_attention and encoder_attention_mask is not None: + attention_mask = encoder_attention_mask + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.key_cache[self.layer_idx] + value_layer = curr_past_key_value.value_cache[self.layer_idx] else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. 
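Reviewer aside: as with the other files in this diff, X-MOD constructs each layer with its index (see the `XmodLayer(config, layer_idx=i)` and `nn.ModuleList` changes below), so its attention modules know which slot of the shared cache to read and update. A hypothetical mini-module illustrating why the index matters:

```py
import torch
from torch import nn
from transformers.cache_utils import DynamicCache

# Hypothetical toy layers, wired like `XmodLayer(config, layer_idx=i)`.
class ToyLayer(nn.Module):
    def __init__(self, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx

    def forward(self, k, v, cache):
        # Each layer writes to (and reads from) its own cache slot.
        return cache.update(k, v, self.layer_idx)

layers = nn.ModuleList([ToyLayer(i) for i in range(3)])
cache = DynamicCache()
k = v = torch.randn(1, 2, 4, 8)
for layer in layers:
    layer(k, v, cache)

assert len(cache.key_cache) == 3  # one entry per layer, no cross-talk
```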
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_value is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -262,11 +273,7 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs class XmodSelfOutput(nn.Module): @@ -285,9 +292,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class XmodAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() - self.self = XmodSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = XmodSelfAttention(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.output = XmodSelfOutput(config) self.pruned_heads = set() self.pre_norm = config.pre_norm @@ -311,6 +318,7 @@ def prune_heads(self, heads): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_attention_mask", version="4.55.0") def forward( self, hidden_states: torch.Tensor, @@ -318,8 +326,9 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: residual = hidden_states if self.pre_norm: @@ -332,6 +341,7 @@ def forward( encoder_attention_mask, past_key_value, output_attentions, + cache_position, ) attention_output = self.output(self_outputs[0], residual) if not self.pre_norm: @@ -425,17 +435,17 @@ def lang_adapter(self, lang_ids: torch.Tensor, hidden_states: torch.Tensor): class XmodLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = XmodAttention(config) + self.attention = XmodAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = XmodAttention(config, position_embedding_type="absolute") + self.crossattention = XmodAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = XmodIntermediate(config) self.output = XmodOutput(config) self.pre_norm = config.pre_norm @@ -448,28 +458,21 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: 
@@ -425,17 +435,17 @@ def lang_adapter(self, lang_ids: torch.Tensor, hidden_states: torch.Tensor):
 
 
 class XmodLayer(GradientCheckpointingLayer):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
-        self.attention = XmodAttention(config)
+        self.attention = XmodAttention(config, layer_idx=layer_idx)
         self.is_decoder = config.is_decoder
         self.add_cross_attention = config.add_cross_attention
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = XmodAttention(config, position_embedding_type="absolute")
+            self.crossattention = XmodAttention(config, position_embedding_type="absolute", layer_idx=layer_idx)
         self.intermediate = XmodIntermediate(config)
         self.output = XmodOutput(config)
         self.pre_norm = config.pre_norm
@@ -448,28 +458,21 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor]:
-        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
             hidden_states,
-            attention_mask,
-            head_mask,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
             output_attentions=output_attentions,
-            past_key_value=self_attn_past_key_value,
+            past_key_value=past_key_value,
+            cache_position=cache_position,
         )
         attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
 
-        # if decoder, the last output is tuple of self-attn cache
-        if self.is_decoder:
-            outputs = self_attention_outputs[1:-1]
-            present_key_value = self_attention_outputs[-1]
-        else:
-            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
-
-        cross_attn_present_key_value = None
         if self.is_decoder and encoder_hidden_states is not None:
             if not hasattr(self, "crossattention"):
                 raise ValueError(
@@ -477,23 +480,17 @@ def forward(
                     " by setting `config.add_cross_attention=True`"
                 )
 
-            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
-            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
             cross_attention_outputs = self.crossattention(
                 attention_output,
-                attention_mask,
-                head_mask,
-                encoder_hidden_states,
-                encoder_attention_mask,
-                cross_attn_past_key_value,
-                output_attentions,
+                attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
             )
             attention_output = cross_attention_outputs[0]
-            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
-
-            # add cross-attn cache to positions 3,4 of present_key_value tuple
-            cross_attn_present_key_value = cross_attention_outputs[-1]
-            present_key_value = present_key_value + cross_attn_present_key_value
+            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights
 
         residual = attention_output
         if self.pre_norm:
@@ -507,13 +504,7 @@ def forward(
         layer_output = self.output(intermediate_output, residual, lang_ids)
         if not self.pre_norm:
             layer_output = self.output.LayerNorm(layer_output)
-        outputs = (layer_output,) + outputs
-
-        # if decoder, return the attn key/values as the last output
-        if self.is_decoder:
-            outputs = outputs + (present_key_value,)
-
-        return outputs
+        return (layer_output,) + outputs
 
     def feed_forward_chunk(self, attention_output):
         return self.intermediate(attention_output)
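Because a `Cache` is mutated in place, the per-layer `present_key_value` plumbing disappears entirely: one object is threaded through the whole stack and accumulates every layer's states. A runnable toy illustration (shapes are arbitrary):

```py
import torch
from transformers import DynamicCache

cache = DynamicCache()
num_layers, batch, heads, seq, head_dim = 2, 1, 4, 3, 16

for layer_idx in range(num_layers):
    # each "layer" writes its key/value states into the shared cache
    k = torch.rand(batch, heads, seq, head_dim)
    v = torch.rand(batch, heads, seq, head_dim)
    cache.update(k, v, layer_idx)

print(cache.get_seq_length())  # 3: one shared object tracks the whole stack
```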
@@ -523,7 +514,7 @@ class XmodEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.layer = nn.ModuleList([XmodLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([XmodLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
         self.is_pre_norm = config.pre_norm
         if self.is_pre_norm:
             self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -542,6 +533,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
         return_dict: Optional[bool] = True,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -549,17 +541,26 @@ def forward(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
+
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+            )
+            return_legacy_cache = True
+            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
-        next_decoder_cache = () if use_cache else None
 
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            past_key_value = past_key_values[i] if past_key_values is not None else None
 
             layer_outputs = layer_module(
                 hidden_states,
@@ -568,13 +569,12 @@ def forward(
                 layer_head_mask,
                 encoder_hidden_states,
                 encoder_attention_mask,
-                past_key_value,
+                past_key_values,
                 output_attentions,
+                cache_position,
             )
 
             hidden_states = layer_outputs[0]
-            if use_cache:
-                next_decoder_cache += (layer_outputs[-1],)
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
                 if self.config.add_cross_attention:
@@ -586,12 +586,15 @@ def forward(
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
+        if return_legacy_cache:
+            past_key_values = past_key_values.to_legacy_cache()
+
         if not return_dict:
             return tuple(
                 v
                 for v in [
                     hidden_states,
-                    next_decoder_cache,
+                    past_key_values,
                     all_hidden_states,
                     all_self_attentions,
                     all_cross_attentions,
@@ -600,7 +603,7 @@ def forward(
             )
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
+            past_key_values=past_key_values,
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
             cross_attentions=all_cross_attentions,
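The `return_legacy_cache` bookkeeping above preserves backward compatibility: a legacy tuple is converted to an `EncoderDecoderCache` on entry and converted back on exit. The round-trip itself, with dummy tensors (the legacy layout is one `(self_key, self_value, cross_key, cross_value)` tuple per layer):

```py
import torch
from transformers import EncoderDecoderCache

legacy = tuple(
    (torch.rand(1, 4, 3, 16), torch.rand(1, 4, 3, 16), torch.rand(1, 4, 7, 16), torch.rand(1, 4, 7, 16))
    for _ in range(2)  # two layers
)
cache = EncoderDecoderCache.from_legacy_cache(legacy)
roundtrip = cache.to_legacy_cache()
assert torch.equal(roundtrip[0][0], legacy[0][0])
```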
@@ -744,6 +747,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
         lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -774,8 +778,13 @@ def forward(
         batch_size, seq_length = input_shape
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
-        # past_key_values_length
-        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        past_key_values_length = 0
+        if past_key_values is not None:
+            past_key_values_length = (
+                past_key_values[0][0].shape[-2]
+                if not isinstance(past_key_values, Cache)
+                else past_key_values.get_seq_length()
+            )
 
         if lang_ids is None:
             if self.config.default_language is None:
@@ -836,6 +845,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
@@ -900,6 +910,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
@@ -947,6 +958,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )
 
         sequence_output = outputs[0]
@@ -974,15 +986,6 @@ def forward(
             cross_attentions=outputs.cross_attentions,
         )
 
-    # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache
-    def _reorder_cache(self, past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
-
 
 @auto_docstring
 class XmodForMaskedLM(XmodPreTrainedModel):
diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py
index 96ae1fe1d9c0..e686f9741b41 100755
--- a/src/transformers/models/yolos/modeling_yolos.py
+++ b/src/transformers/models/yolos/modeling_yolos.py
@@ -264,17 +264,28 @@ def __init__(self, config: YolosConfig) -> None:
         self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
         self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
 
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(
-        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+        self,
+        hidden_states,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
     ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(self.query(hidden_states))
+        batch_size, seq_length, _ = hidden_states.shape
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
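The inlined `view(...).transpose(1, 2)` chain in YOLOS (and the identical rewrites below) is a drop-in replacement for the deleted `transpose_for_scores` helper; both produce `[batch, heads, seq, head_dim]`. A quick equivalence check with illustrative sizes:

```py
import torch

batch, seq, heads, head_dim = 2, 5, 4, 16
x = torch.rand(batch, seq, heads * head_dim)

old = x.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3)  # transpose_for_scores
new = x.view(batch, -1, heads, head_dim).transpose(1, 2)       # inlined replacement
assert torch.equal(old, new)
```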
diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py
index da35490c59ed..1d999ea4ca50 100644
--- a/src/transformers/models/yoso/modeling_yoso.py
+++ b/src/transformers/models/yoso/modeling_yoso.py
@@ -341,17 +341,23 @@ def __init__(self, config, position_embedding_type=None):
             groups=config.num_attention_heads,
         )
 
-    def transpose_for_scores(self, layer):
-        new_layer_shape = layer.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        layer = layer.view(*new_layer_shape)
-        return layer.permute(0, 2, 1, 3)
-
     def forward(self, hidden_states, attention_mask=None, output_attentions=False):
-        mixed_query_layer = self.query(hidden_states)
-
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(mixed_query_layer)
+        batch_size, seq_length, _ = hidden_states.shape
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
 
         if self.use_conv:
             conv_value_layer = self.conv(value_layer * attention_mask[:, None, :, None])
diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index 06d689ae6807..317ae28e4490 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -788,7 +788,7 @@ class ZambaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn = False
     _supports_sdpa = False
-    _supports_cache_class = True  # Note: only supports ZambaHybridDynamicCache
+    # Note: only supports ZambaHybridDynamicCache
     _is_stateful = True
 
     def _init_weights(self, module):
diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
index ce7555058a4e..45e638ba6f8a 100644
--- a/src/transformers/models/zamba2/modeling_zamba2.py
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -1179,7 +1179,7 @@ class Zamba2PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_flex_attn = True
     _supports_sdpa = True
-    _supports_cache_class = True  # Note: only supports Zamba2HybridDynamicCache
+    # Note: only supports Zamba2HybridDynamicCache
     _is_stateful = True
 
     def _init_weights(self, module):
diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py
index 032a2dd5cbf4..b912a63419a2 100644
--- a/src/transformers/models/zamba2/modular_zamba2.py
+++ b/src/transformers/models/zamba2/modular_zamba2.py
@@ -902,7 +902,7 @@ class Zamba2PreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_flex_attn = True
     _supports_sdpa = True
-    _supports_cache_class = True  # Note: only supports Zamba2HybridDynamicCache
+    # Note: only supports Zamba2HybridDynamicCache
    _is_stateful = True
 
     def _init_weights(self, module):
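With the `_supports_cache_class` flag removed, the Zamba classes keep only the comment about their hybrid cache; code that needs to know whether a model can take a default `DynamicCache` now asks `_supports_default_dynamic_cache()`, the classmethod converted earlier in this diff. A hedged sketch (the checkpoint name is illustrative):

```py
from transformers import AutoModelForCausalLM, DynamicCache

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
if model._supports_default_dynamic_cache():
    past_key_values = DynamicCache()  # safe default for this architecture
```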
diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py
index 48ff8174186e..900425633847 100644
--- a/src/transformers/models/zoedepth/modeling_zoedepth.py
+++ b/src/transformers/models/zoedepth/modeling_zoedepth.py
@@ -799,11 +799,6 @@ def __init__(self, hidden_size, num_attention_heads, dropout):
 
         self.dropout = nn.Dropout(dropout)
 
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
     def forward(
         self,
         queries: torch.Tensor,
@@ -812,9 +807,18 @@ def forward(
         attention_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = False,
     ) -> tuple[torch.Tensor]:
-        query_layer = self.transpose_for_scores(self.query(queries))
-        key_layer = self.transpose_for_scores(self.key(keys))
-        value_layer = self.transpose_for_scores(self.value(values))
+        batch_size, seq_length, _ = queries.shape
+        query_layer = (
+            self.query(queries)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+        )
+        value_layer = (
+            self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+        )
 
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
diff --git a/src/transformers/utils/args_doc.py b/src/transformers/utils/args_doc.py
index a08dd7fff363..3b521a27c565 100644
--- a/src/transformers/utils/args_doc.py
+++ b/src/transformers/utils/args_doc.py
@@ -353,17 +353,13 @@ class ModelArgs:
             blocks) that can be used to speed up sequential decoding. This typically consists in the
             `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or
             `config.use_cache=True`.
 
-            Two formats are allowed:
-            - a [`~cache_utils.Cache`] instance, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
-            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
-            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
-            cache format.
+            Only a [`~cache_utils.Cache`] instance is allowed as input, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+            If no `past_key_values` are passed, a [`~cache_utils.DynamicCache`] will be initialized by default.
 
-            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
-            legacy cache format will be returned.
+            The model will output the same cache format that is fed as input.
 
-            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
-            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            If `past_key_values` are used, the user is expected to input only unprocessed `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, unprocessed_length)` instead of all `input_ids`
             of shape `(batch_size, sequence_length)`.
         """,
         "shape": None,
@@ -969,12 +965,6 @@ class ClassAttrs:
     _supports_flex_attn = r"""
     Whether the model's attention implementation supports FlexAttention.
     """
-    _supports_cache_class = r"""
-    Whether the model supports a `Cache` instance as `past_key_values`.
-    """
-    _supports_quantized_cache = r"""
-    Whether the model supports a `QuantoQuantizedCache` instance as `past_key_values`.
-    """
     _supports_static_cache = r"""
     Whether the model supports a `StaticCache` instance as `past_key_values`.
     """
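The reworded `past_key_values` docstring pins down the contract: pass a `Cache` (or nothing, in which case a `DynamicCache` is created), and on follow-up calls feed only the tokens the cache has not yet seen. In practice that looks roughly like this (the checkpoint name is illustrative):

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The cache API", return_tensors="pt")
out = model(**inputs, past_key_values=DynamicCache(), use_cache=True)

# decode step: only the unprocessed token goes in, the rest comes from the cache
next_token = out.logits[:, -1:].argmax(-1)
out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)
```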
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index ea00871d27ba..531bf70d5eeb 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1001,7 +1001,7 @@ def test_contrastive_generate(self):
                 self.skipTest(reason="Stateful models don't support contrastive search generation")
 
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
-            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
                 self.skipTest(reason="Won't fix: old model with different cache format")
 
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
@@ -1030,7 +1030,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
                 self.skipTest(reason="Stateful models don't support contrastive search generation")
 
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
-            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
                 self.skipTest(reason="Won't fix: old model with different cache format")
 
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
@@ -1070,10 +1070,8 @@ def test_contrastive_generate_low_memory(self):
             if model_class._is_stateful:
                 self.skipTest(reason="Stateful models don't support contrastive search generation")
 
-            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
                 self.skipTest(reason="Won't fix: old model with different cache format")
-            if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
-                self.skipTest(reason="TODO: fix me")
 
             config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
@@ -1112,22 +1110,16 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
         for model_class in self.all_generative_model_classes:
             if model_class._is_stateful:
                 self.skipTest(reason="Stateful models don't support assisted generation")
-            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
                 self.skipTest(reason="Won't fix: old model with different cache format")
             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in [
-                    "bigbirdpegasus",
-                    "led",
-                    "mega",
                     "moshi",
-                    "speech2text",
                     "git",
                     "prophetnet",
-                    "seamlessm4t",
-                    "clvp",
                     "mllama",  # special cache sizes
-                    "blip2",  # overridden `generate()`
+                    "blip2",  # overridden `generate()` for all BLIP models
                     "instructblip",
                     "instructblipvideo",
                 ]
@@ -1196,23 +1188,16 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
         for model_class in self.all_generative_model_classes:
             if model_class._is_stateful:
                 self.skipTest(reason="Stateful models don't support assisted generation")
-            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
                 self.skipTest(reason="Won't fix: old model with different cache format")
             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in [
-                    "bigbirdpegasus",
-                    "led",
-                    "mega",
                     "moshi",
-                    "speech2text",
                     "git",
                     "prophetnet",
-                    "seamlessm4t",
-                    "clvp",
-                    "fuyu",
                     "mllama",  # special cache sizes
-                    "blip2",  # overridden `generate()`
+                    "blip2",  # overridden `generate()` for all BLIP models
                     "instructblip",
                     "instructblipvideo",
                    # All models below: shouldn't suggest image tokens. Can be fixed by passing
                    # `suppress_ids` to candidate generator: @joaa @raushan
@@ -1340,22 +1325,16 @@ def test_assisted_decoding_sample(self):
         for model_class in self.all_generative_model_classes:
             if model_class._is_stateful:
                 self.skipTest(reason="Stateful models don't support assisted generation")
-            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
                 self.skipTest(reason="Won't fix: old model with different cache format")
             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in [
-                    "bigbirdpegasus",
-                    "led",
-                    "mega",
                     "moshi",
-                    "speech2text",
                     "git",
                     "prophetnet",
-                    "seamlessm4t",
-                    "clvp",
                     "mllama",  # special cache sizes
-                    "blip2",  # overridden `generate()`
+                    "blip2",  # overridden `generate()` for all BLIP models
                     "instructblip",
                     "instructblipvideo",
                 ]
@@ -2059,12 +2038,15 @@ def test_generate_with_static_cache(self):
     @pytest.mark.generate
     def test_generate_with_quant_cache(self):
         for model_class in self.all_generative_model_classes:
-            if not model_class._supports_quantized_cache:
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+
+            if (
+                config.get_text_config(decoder=True).is_encoder_decoder
+                or not model_class._supports_default_dynamic_cache()
+            ):
                 self.skipTest(reason="This model does not support the quantized cache format")
 
-            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
             config.is_decoder = True
-
             model = model_class(config).to(torch_device).eval()
             generation_kwargs = {
                 "max_new_tokens": 5,
@@ -2509,14 +2491,10 @@ def _check_generate_outputs(self, output, config, use_cache=False, num_return_se
         # Past Key Value States -- a few notes here:
         # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
         # 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the
-        #    standard cache format (e.g.gptbigcode )
+        #    standard cache format (e.g. mamba architecture)
         models_without_standard_cache = (
             "bamba",
-            "ctrl",
-            "fsmt",
             "granitemoehybrid",
-            "gptbigcode",
-            "mega",
             "reformer",
             "jamba",
             "mamba",
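The rewritten quant-cache test now gates on encoder-decoder topology and `_supports_default_dynamic_cache()` instead of the deleted `_supports_quantized_cache` flag, since the quantized cache reuses the dynamic layout. End to end, a quantized cache is requested through `generate` roughly as below (a hedged sketch; it needs `optimum-quanto` installed, and the checkpoint name is illustrative):

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Quantized caches trade precision for memory", return_tensors="pt")

out = model.generate(
    **inputs,
    max_new_tokens=5,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 4},
)
```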
diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py
index a6aac8e3829a..c75175a811fb 100644
--- a/tests/models/bert/test_modeling_bert.py
+++ b/tests/models/bert/test_modeling_bert.py
@@ -737,9 +737,11 @@ def test_sdpa_ignored_mask(self):
                 torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
             )
 
-            # Case where query length != kv_length.
-            res_eager = model(**inp, past_key_values=pkv)
-            res_sdpa = model_sdpa(**inp, past_key_values=pkv)
+            # Case where query length != kv_length. Note that the model needs to be a decoder so we can use the cache
+            model.config.is_decoder = True
+            model_sdpa.config.is_decoder = True
+            res_eager = model(**inp, past_key_values=pkv, use_cache=True)
+            res_sdpa = model_sdpa(**inp, past_key_values=pkv, use_cache=True)
             self.assertTrue(
                 torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
             )
diff --git a/tests/models/big_bird/test_modeling_big_bird.py b/tests/models/big_bird/test_modeling_big_bird.py
index bdab0f73b653..8ec874d0f7a8 100644
--- a/tests/models/big_bird/test_modeling_big_bird.py
+++ b/tests/models/big_bird/test_modeling_big_bird.py
@@ -284,6 +284,7 @@ def create_and_check_decoder_model_past_large_inputs(
             attention_mask=next_attention_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
+            use_cache=False,
             output_hidden_states=True,
         )["hidden_states"][0]
         output_from_past = model(
@@ -292,6 +293,7 @@ def create_and_check_decoder_model_past_large_inputs(
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
             past_key_values=past_key_values,
+            use_cache=True,
             output_hidden_states=True,
         )["hidden_states"][0]
 
diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py
index 729122951e05..f2b6969c4109 100644
--- a/tests/models/mllama/test_modeling_mllama.py
+++ b/tests/models/mllama/test_modeling_mllama.py
@@ -34,6 +34,7 @@
     Expectations,
     cleanup,
     require_bitsandbytes,
+    require_optimum_quanto,
     require_read_token,
     require_torch,
     require_torch_accelerator,
@@ -344,6 +345,12 @@ def _check_attentions_for_generate(
 
         self.assertListEqual([layer_attention.shape for layer_attention in iter_attentions], expected_shapes)
 
+    @require_optimum_quanto
+    @pytest.mark.generate
+    @unittest.skip("Mllama actually uses an encoder-decoder cache and thus can't support a quantized cache")
+    def test_generate_with_quant_cache(self):
+        pass
+
") def test_sdpa_can_compile_dynamic(self): pass diff --git a/tests/models/mvp/test_modeling_mvp.py b/tests/models/mvp/test_modeling_mvp.py index 9ee1077d7d9b..1dc16992633c 100644 --- a/tests/models/mvp/test_modeling_mvp.py +++ b/tests/models/mvp/test_modeling_mvp.py @@ -770,9 +770,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/perception_lm/test_processor_perception_lm.py b/tests/models/perception_lm/test_processor_perception_lm.py index 7ae377d14a74..28f5a56e4c3a 100644 --- a/tests/models/perception_lm/test_processor_perception_lm.py +++ b/tests/models/perception_lm/test_processor_perception_lm.py @@ -21,7 +21,7 @@ AutoTokenizer, PerceptionLMProcessor, ) -from transformers.testing_utils import require_vision +from transformers.testing_utils import require_read_token, require_vision from transformers.utils import is_torch_available, is_vision_available from ...test_processing_common import ProcessorTesterMixin @@ -34,11 +34,12 @@ import torch -# TEST_MODEL_PATH = "facebook/Perception-LM-1B" -TEST_MODEL_PATH = "shumingh/plm_1b_hf" # should be replaced by the above once checkpoints are merged +TEST_MODEL_PATH = "facebook/Perception-LM-1B" @require_vision +@require_read_token +@unittest.skip("Fequires read token and we didn't requests access yet. 
+@unittest.skip("Requires a read token and we haven't requested access yet. FIXME @ydshieh when you are back :)")
 class PerceptionLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     processor_class = PerceptionLMProcessor
diff --git a/tests/models/prophetnet/test_modeling_prophetnet.py b/tests/models/prophetnet/test_modeling_prophetnet.py
index 58b535fd0ae7..2b3fc9aa9e30 100644
--- a/tests/models/prophetnet/test_modeling_prophetnet.py
+++ b/tests/models/prophetnet/test_modeling_prophetnet.py
@@ -737,7 +737,7 @@ def create_and_check_decoder_model_attention_mask_past(
 
         # get two different outputs
         output_from_no_past = model(next_input_ids)["last_hidden_state"]
-        output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
+        output_from_past = model(next_tokens, past_key_values=past_key_values, use_cache=True)["last_hidden_state"]
 
         # select random slice
         random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py
index 9c0fa0fa394f..c3a8025e0fa5 100644
--- a/tests/models/speecht5/test_modeling_speecht5.py
+++ b/tests/models/speecht5/test_modeling_speecht5.py
@@ -354,9 +354,9 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
         next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)
 
         output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
-        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
-            "last_hidden_state"
-        ]
+        output_from_past = model(
+            next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values, use_cache=True
+        )["last_hidden_state"]
 
         # select random slice
         random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 5c8d0d88d4bb..b33c246e7cfe 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -853,10 +853,10 @@ def test_can_init_all_missing_weights(self):
             addition_year = int(match_object.group(1))
 
         for model_class in self.all_model_classes:
-            # For now, skip everything older than 2025 and "important models" (too much models to patch otherwise)
+            # For now, skip everything older than 2024 and "important models" (too many models to patch otherwise)
             # Use `supports_cache_class` as a proxy to judge "important" models in order to prioritize them
             # TODO: relax this as we patch more and more models
-            if addition_year < 2024 and not model_class._supports_cache_class:
+            if addition_year < 2024:
                 self.skipTest(reason=f"{model_class} is not a priorited model for now.")
 
             # Monkey patch the method to add a seed (we do it on PreTrainedModel._initialize_weights, which wraps
@@ -1590,18 +1590,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
                     head_dim = model.config.hidden_size // model.config.num_attention_heads
                     cache_shape = (batch_size, num_heads, 0, head_dim)
-                    empty_pkv = tuple(
-                        (
-                            torch.rand(cache_shape, dtype=torch.float, device=torch_device),
-                            torch.rand(cache_shape, dtype=torch.float, device=torch_device),
-                        )
-                        for i in range(model.config.num_hidden_layers)
-                    )
-                    empty_pkv = (
-                        DynamicCache.from_legacy_cache(empty_pkv)
-                        if model_class._supports_cache_class
-                        else empty_pkv
-                    )
+                    empty_pkv = DynamicCache()
 
                     cache_length = 9
                     cache_shape = (batch_size, num_heads, cache_length, head_dim)
@@ -1612,11 +1601,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
                     )
                     non_empty_pkv = tuple(
                         (
                             torch.rand(cache_shape, dtype=torch.float, device=torch_device),
                             torch.rand(cache_shape, dtype=torch.float, device=torch_device),
                         )
                         for i in range(model.config.num_hidden_layers)
                     )
-                    non_empty_pkv = (
-                        DynamicCache.from_legacy_cache(non_empty_pkv)
-                        if model_class._supports_cache_class
-                        else non_empty_pkv
-                    )
+                    non_empty_pkv = DynamicCache.from_legacy_cache(non_empty_pkv)
 
                     inps = copy.deepcopy(inputs_to_test[0])
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index 0c3c90768deb..ed8f7d4da158 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -46,7 +46,6 @@
     AutoModelForCausalLM,
     AutoTokenizer,
     Cache,
-    ClvpForCausalLM,
     DynamicCache,
     Gemma2Config,
     GenerationConfig,
@@ -122,36 +121,6 @@ def test_dynamic_cache_retrocompatibility(self):
                 torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx])
             )
 
-    def test_reorder_cache_retrocompatibility(self):
-        """Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
-        legacy_reorder_fn = ClvpForCausalLM._reorder_cache  # An example of a legacy `_reorder_cache` function
-
-        legacy_cache = ()
-        new_cache = DynamicCache()
-
-        # Creates a new cache with 10 layers in both formats
-        for layer_idx in range(10):
-            new_key = torch.rand((4, 4, 8, 16))
-            new_value = torch.rand((4, 4, 8, 16))
-            new_cache.update(new_key, new_value, layer_idx)
-            legacy_cache += ((new_key, new_value),)
-
-        # Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4
-        # and batch_size=1
-        beam_idx = torch.randint(low=0, high=4, size=(4,))
-
-        legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx)
-        new_cache.reorder_cache(beam_idx)
-
-        # Let's check that the results are the same
-        for layer_idx in range(10):
-            for key_value_idx in range(2):
-                self.assertTrue(
-                    torch.allclose(
-                        new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx]
-                    )
-                )
-
     def test_static_cache_mha_mqa_gqa(self):
         """
         Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query