diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5da4878513eb..2225b033aa0a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -387,13 +387,14 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, cache_position] # 3. Prepare base model inputs + input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs["input_ids"] = None + if inputs_embeds is not None and not self.config.is_encoder_decoder and cache_position[0] == 0: + model_inputs[input_ids_key] = None model_inputs["inputs_embeds"] = inputs_embeds else: # `clone` calls in this function ensure a consistent stride. See #32227 - model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format) + model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) model_inputs["inputs_embeds"] = None # 4. Create missing `position_ids` on the fly @@ -421,8 +422,8 @@ def prepare_inputs_for_generation( batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device + batch_size, sequence_length = model_inputs[input_ids_key].shape + device = model_inputs[input_ids_key].device # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create # the 4D causal mask exists, it should be present in the base model (XXXModel class). @@ -455,6 +456,8 @@ def prepare_inputs_for_generation( if key not in model_inputs: model_inputs[key] = value + # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples) + model_inputs.pop("labels", None) return model_inputs def _prepare_model_inputs( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 822be354fb9d..07c1fa622ea3 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1682,45 +1682,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. 
input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index e26dce1edfc2..19540a7498f5 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2561,45 +2561,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. 
input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index ae37f546e510..5c4fdfb472c3 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1333,43 +1333,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 93298c4e80e5..6f79d2a7d005 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1285,43 +1285,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. 
input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 78384e6ce2f7..5ee7ae21f9d5 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -915,6 +915,8 @@ def forward( ) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + # Overwrite -- hardcoded key return (`is_decoder=True`) + input_shape = input_ids.shape # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 4d50f9bb5925..3f865c037c01 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -1191,7 +1191,7 @@ def __init__(self, config: FSMTConfig): @add_end_docstrings(FSMT_GENERATION_EXAMPLE) def forward( self, - input_ids: torch.LongTensor, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None, @@ -1263,30 +1263,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index f96bfd82b526..ee1ad90bfcea 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2437,36 +2437,6 @@ def forward( encoder_global_attentions=outputs.encoder_global_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - global_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] - - return { - "input_ids": None, # encoder_outputs is defined. 
input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "global_attention_mask": global_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 8f9385c0fe76..d351e798ac7f 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -2085,42 +2085,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 1588aa28aa2d..cc35a3504255 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1621,43 +1621,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. 
input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index bbb3381bd973..2d7c7d85daed 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -16,7 +16,7 @@ import copy import math -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -1438,43 +1438,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids: torch.LongTensor, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None, - **kwargs, - ) -> Dict: - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. 
input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index a10d62d6dcc3..95cd7c65ef32 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1647,43 +1647,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 6a7406f11b5b..9051414d7414 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1820,45 +1820,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - 
"decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 5a466c0cec01..f68a4bb76b3e 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -1475,43 +1475,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index cedefc4f4642..9c095be16506 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1762,44 +1762,6 @@ def _unpack_router_logits(self, router_outputs): total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None return total_router_logits, total_expert_indexes - # Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. 
input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 35f91ca73566..a737ef14d647 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1390,43 +1390,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 77c0b32e6433..f90a8d2deb26 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1588,37 +1588,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. 
input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 4f6984a7bef6..490fefc686a5 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -16,7 +16,7 @@ import copy import math -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -1372,43 +1372,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids: torch.LongTensor, - past_key_values: Optional[List[torch.FloatTensor]] = None, - attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - **kwargs, # TODO: Check if this is needed. It is unused? - ) -> Dict[str, Any]: - # cut decoder_input_ids if past is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - return { - "input_ids": None, # encoder_outputs is defined. input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 003e4f15d2d9..137bd5ad828d 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -2018,35 +2018,6 @@ def _compute_loss(self, logits, labels, ignore_index=-100): return loss - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation." - - if past_key_values: - decoder_input_ids = decoder_input_ids[:, -1:] - # first step, decoder_cached_states are empty - return { - "input_ids": None, # encoder_outputs is defined. 
input_ids not needed - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index bc375b68e947..5e6f13ca478f 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1172,6 +1172,8 @@ def prepare_inputs_for_generation( n_docs=None, **kwargs, ): + # Overwritten -- `do_marginalize` is explicitly set in the output + if past_key_values is not None: # if past is defined use only last decoder_input_ids decoder_input_ids = decoder_input_ids[:, -1:] diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index f1495ddc8c00..c39e85bacdd3 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1702,45 +1702,6 @@ def _unpack_router_logits(self, router_outputs): total_expert_indexes.append(expert_indexes) return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - output_router_logits = kwargs.get("output_router_logits", True) - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - "output_router_logits": output_router_logits, - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 43e3f3afa4a8..91596f013ab4 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1791,44 +1791,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - 
remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index a7d1e5bacc65..bd621fc2fb3a 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -1302,45 +1302,6 @@ def forward( encoder_attentions=encoder_outputs.attentions, ) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - head_mask=None, - decoder_head_mask=None, - decoder_attention_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - return { - "decoder_input_ids": input_ids, - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "decoder_attention_mask": decoder_attention_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, - } - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 2363ed04959d..81326c07d6cc 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -1519,6 +1519,8 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): + # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + empty_past_kv = past_key_values is None # Omit tokens covered by past_key_values diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1727aed1117b..02f4f1b6127a 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3841,6 +3841,38 @@ def test_prepare_inputs_for_generation_decoder_llm(self): self.assertTrue(model_inputs["input_ids"] is not None) self.assertTrue(model_inputs["inputs_embeds"] is None) + def test_prepare_inputs_for_generation_encoder_decoder_llm(self): + """ + Same as `test_prepare_inputs_for_generation_decoder_llm` but for encoder-decoder models. Main difference: we + should look for `decoder_input_ids`, instead of `input_ids`. 
+ """ + model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5") + model = model.to(torch_device) + + # 1. Sanity check: the model's `prepare_inputs_for_generation` comes from `GenerationMixin` + self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation)) + + # 2. If we pass input ids by themselves, we should get back the same input ids -- with the encoder-decoder key + decoder_input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device) + model_inputs = model.prepare_inputs_for_generation(decoder_input_ids) + self.assertTrue(torch.all(model_inputs["decoder_input_ids"] == decoder_input_ids)) + + # 3. If we pass the attention mask too, we will get back the attention mask. Encoder-decoder models usually + # don't use `position_ids` + decoder_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device) + model_inputs = model.prepare_inputs_for_generation( + decoder_input_ids, decoder_attention_mask=decoder_attention_mask + ) + self.assertTrue(torch.all(model_inputs["decoder_attention_mask"] == decoder_attention_mask)) + self.assertTrue("position_ids" not in model_inputs) + + # 4. `use_cache` (and other kwargs, like the encoder outputs) are forwarded + self.assertFalse("use_cache" in model_inputs) # From the previous input, there is no `use_cache` + model_inputs = model.prepare_inputs_for_generation(decoder_input_ids, use_cache=True, encoder_outputs="foo") + self.assertTrue(model_inputs["use_cache"] is True) + self.assertTrue(model_inputs["encoder_outputs"] == "foo") + # See the decoder-only test for more corner cases. The code is the same, so we don't repeat it here. + def test_generate_compile_fullgraph_tiny(self): """ Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)