From 0bf42a417f5d144730d2de6a13060a0f0e6a183c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 9 Oct 2024 18:32:08 +0000 Subject: [PATCH 1/6] poof --- src/transformers/generation/utils.py | 10 +++-- src/transformers/models/bart/modeling_bart.py | 39 ------------------- .../modeling_bigbird_pegasus.py | 39 ------------------- .../models/blenderbot/modeling_blenderbot.py | 37 ------------------ .../modeling_blenderbot_small.py | 37 ------------------ .../models/blip/modeling_blip_text.py | 28 ------------- src/transformers/models/fsmt/modeling_fsmt.py | 26 +------------ src/transformers/models/led/modeling_led.py | 30 -------------- .../models/longt5/modeling_longt5.py | 36 ----------------- .../models/m2m_100/modeling_m2m_100.py | 37 ------------------ .../models/marian/modeling_marian.py | 39 +------------------ .../models/mbart/modeling_mbart.py | 37 ------------------ src/transformers/models/mt5/modeling_mt5.py | 39 ------------------- src/transformers/models/mvp/modeling_mvp.py | 37 ------------------ .../models/nllb_moe/modeling_nllb_moe.py | 38 ------------------ .../models/pegasus/modeling_pegasus.py | 37 ------------------ .../models/pegasus_x/modeling_pegasus_x.py | 31 --------------- .../models/plbart/modeling_plbart.py | 39 +------------------ .../models/prophetnet/modeling_prophetnet.py | 29 -------------- src/transformers/models/rag/modeling_rag.py | 2 + .../modeling_switch_transformers.py | 39 ------------------- src/transformers/models/t5/modeling_t5.py | 38 ------------------ src/transformers/models/umt5/modeling_umt5.py | 39 ------------------- .../models/zamba/modeling_zamba.py | 2 + tests/generation/test_utils.py | 32 +++++++++++++++ 25 files changed, 45 insertions(+), 752 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5da4878513eb..231f69297fc0 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -387,13 +387,15 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, cache_position] # 3. Prepare base model inputs + # Encoder-decoder models have slightly different keys for the decoder `input_ids` + input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs["input_ids"] = None + model_inputs[input_ids_key] = None model_inputs["inputs_embeds"] = inputs_embeds else: # `clone` calls in this function ensure a consistent stride. See #32227 - model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format) + model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) model_inputs["inputs_embeds"] = None # 4. Create missing `position_ids` on the fly @@ -421,8 +423,8 @@ def prepare_inputs_for_generation( batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device + batch_size, sequence_length = model_inputs[input_ids_key].shape + device = model_inputs[input_ids_key].device # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create # the 4D causal mask exists, it should be present in the base model (XXXModel class). diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 822be354fb9d..07c1fa622ea3 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1682,45 +1682,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index e26dce1edfc2..19540a7498f5 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2561,45 +2561,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index ae37f546e510..5c4fdfb472c3 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1333,43 +1333,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 93298c4e80e5..6f79d2a7d005 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1285,43 +1285,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 78384e6ce2f7..1f1e0d3c1860 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -914,34 +914,6 @@ def forward( cross_attentions=outputs.cross_attentions, ) - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), - "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), - "is_decoder": True, - } - def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 4d50f9bb5925..3f865c037c01 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -1191,7 +1191,7 @@ def __init__(self, config: FSMTConfig): @add_end_docstrings(FSMT_GENERATION_EXAMPLE) def forward( self, - input_ids: torch.LongTensor, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, @@ -1263,30 +1263,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index f96bfd82b526..ee1ad90bfcea 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2437,36 +2437,6 @@ def forward( encoder_global_attentions=outputs.encoder_global_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - global_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "global_attention_mask": global_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 8f9385c0fe76..d351e798ac7f 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -2085,42 +2085,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 1588aa28aa2d..cc35a3504255 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1621,43 +1621,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index bbb3381bd973..2d7c7d85daed 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -16,7 +16,7 @@ import copy import math -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -1438,43 +1438,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids: torch.LongTensor, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None, - **kwargs, - ) -> Dict: - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index a10d62d6dcc3..95cd7c65ef32 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1647,43 +1647,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 6a7406f11b5b..9051414d7414 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1820,45 +1820,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 5a466c0cec01..f68a4bb76b3e 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -1475,43 +1475,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index cedefc4f4642..9c095be16506 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1762,44 +1762,6 @@ def _unpack_router_logits(self, router_outputs): total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None return total_router_logits, total_expert_indexes - # Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 35f91ca73566..a737ef14d647 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1390,43 +1390,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 77c0b32e6433..f90a8d2deb26 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1588,37 +1588,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 4f6984a7bef6..490fefc686a5 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -16,7 +16,7 @@ import copy import math -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -1372,43 +1372,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids: torch.LongTensor, - past_key_values: Optional[List[torch.FloatTensor]] = None, - attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - **kwargs, # TODO: Check if this is needed. It is unused? - ) -> Dict[str, Any]: - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 003e4f15d2d9..137bd5ad828d 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -2018,35 +2018,6 @@ def _compute_loss(self, logits, labels, ignore_index=-100): return loss - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation." - - if past_key_values: - decoder_input_ids = decoder_input_ids[:, -1:] - # first step, decoder_cached_states are empty - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index bc375b68e947..5e6f13ca478f 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1172,6 +1172,8 @@ def prepare_inputs_for_generation( n_docs=None, **kwargs, ): + # Overwritten -- `do_marginalize` is explicitly set in the output + if past_key_values is not None: # if past is defined use only last decoder_input_ids decoder_input_ids = decoder_input_ids[:, -1:] diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index f1495ddc8c00..c39e85bacdd3 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1702,45 +1702,6 @@ def _unpack_router_logits(self, router_outputs): total_expert_indexes.append(expert_indexes) return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - output_router_logits = kwargs.get("output_router_logits", True) - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - "output_router_logits": output_router_logits, - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 43e3f3afa4a8..91596f013ab4 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1791,44 +1791,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index a7d1e5bacc65..bd621fc2fb3a 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -1302,45 +1302,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 2363ed04959d..81326c07d6cc 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -1519,6 +1519,8 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): + # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + empty_past_kv = past_key_values is None # Omit tokens covered by past_key_values diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1727aed1117b..02f4f1b6127a 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3841,6 +3841,38 @@ def test_prepare_inputs_for_generation_decoder_llm(self): self.assertTrue(model_inputs["input_ids"] is not None) self.assertTrue(model_inputs["inputs_embeds"] is None) + def test_prepare_inputs_for_generation_encoder_decoder_llm(self): + """ + Same as `test_prepare_inputs_for_generation_decoder_llm` but for encoder-decoder models. Main difference: we + should look for `decoder_input_ids`, instead of `input_ids`. + """ + model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5") + model = model.to(torch_device) + + # 1. Sanity check: the model's `prepare_inputs_for_generation` comes from `GenerationMixin` + self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation)) + + # 2. If we pass input ids by themselves, we should get back the same input ids -- with the encoder-decoder key + decoder_input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device) + model_inputs = model.prepare_inputs_for_generation(decoder_input_ids) + self.assertTrue(torch.all(model_inputs["decoder_input_ids"] == decoder_input_ids)) + + # 3. If we pass the attention mask too, we will get back the attention mask. Encoder-decoder models usually + # don't use `position_ids` + decoder_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device) + model_inputs = model.prepare_inputs_for_generation( + decoder_input_ids, decoder_attention_mask=decoder_attention_mask + ) + self.assertTrue(torch.all(model_inputs["decoder_attention_mask"] == decoder_attention_mask)) + self.assertTrue("position_ids" not in model_inputs) + + # 4. `use_cache` (and other kwargs, like the encoder outputs) are forwarded + self.assertFalse("use_cache" in model_inputs) # From the previous input, there is no `use_cache` + model_inputs = model.prepare_inputs_for_generation(decoder_input_ids, use_cache=True, encoder_outputs="foo") + self.assertTrue(model_inputs["use_cache"] is True) + self.assertTrue(model_inputs["encoder_outputs"] == "foo") + # See the decoder-only test for more corner cases. The code is the same, so we don't repeat it here. + def test_generate_compile_fullgraph_tiny(self): """ Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash) From a44719722e58c3baf0eda7ae628a1d8547b0dbd8 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 9 Oct 2024 18:36:49 +0000 Subject: [PATCH 2/6] rm comment --- src/transformers/generation/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 231f69297fc0..e67b820c9610 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -387,7 +387,6 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, cache_position] # 3. Prepare base model inputs - # Encoder-decoder models have slightly different keys for the decoder `input_ids` input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: From 369b61455abe07ec7fe3f8852bddf2f37f40f6f8 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 11 Oct 2024 10:55:19 +0000 Subject: [PATCH 3/6] revert blip 2 --- .../models/blip/modeling_blip_text.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 1f1e0d3c1860..5ee7ae21f9d5 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -914,6 +914,36 @@ def forward( cross_attentions=outputs.cross_attentions, ) + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + # Overwrite -- hardcoded key return (`is_decoder=True`) + + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: From 6c4ae1b50c328e0e01cbe02e8b05600f9dc52215 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 11 Oct 2024 14:24:12 +0000 Subject: [PATCH 4/6] handle trainer case --- src/transformers/generation/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e67b820c9610..3b3688c7fc9d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -456,6 +456,8 @@ def prepare_inputs_for_generation( if key not in model_inputs: model_inputs[key] = value + # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples) + model_inputs.pop("labels", None) return model_inputs def _prepare_model_inputs( From ed666f3ecf4aa188f65cda47c88245f01760a68c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 11 Oct 2024 14:48:42 +0000 Subject: [PATCH 5/6] encoder-decoder models don't try to forward inputs_embeds --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3b3688c7fc9d..2225b033aa0a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -389,7 +389,7 @@ def prepare_inputs_for_generation( # 3. Prepare base model inputs input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: + if inputs_embeds is not None and not self.config.is_encoder_decoder and cache_position[0] == 0: model_inputs[input_ids_key] = None model_inputs["inputs_embeds"] = inputs_embeds else: From af479685db31ca30a5dac141167abf3f24584601 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 11 Oct 2024 14:49:07 +0000 Subject: [PATCH 6/6] test all