diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 6912a4596370..d22d6d208bc4 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -1118,7 +1118,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         position_ids=None,
         use_cache=True,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1169,6 +1169,9 @@ def prepare_inputs_for_generation(
             batch_size=batch_size,
         )

+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -1176,7 +1179,6 @@ def prepare_inputs_for_generation(
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
-                "num_logits_to_keep": num_logits_to_keep,
             }
         )
         return model_inputs
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 9684fd174733..92f3c57d8670 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -1382,7 +1382,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         position_ids=None,
         use_cache=True,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1433,6 +1433,9 @@ def prepare_inputs_for_generation(
             batch_size=batch_size,
         )

+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -1440,7 +1443,6 @@ def prepare_inputs_for_generation(
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
-                "num_logits_to_keep": num_logits_to_keep,
             }
         )
         return model_inputs
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 62917c73f332..23eaee57857a 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -1131,7 +1131,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         position_ids=None,
         use_cache=True,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1181,6 +1181,9 @@ def prepare_inputs_for_generation(
             batch_size=batch_size,
         )

+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -1188,7 +1191,6 @@ def prepare_inputs_for_generation(
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
-                "num_logits_to_keep": num_logits_to_keep,
             }
         )
         return model_inputs
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index bf6ff76189d4..3035afdf9405 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -1094,7 +1094,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         position_ids=None,
         use_cache=True,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1148,6 +1148,10 @@ def prepare_inputs_for_generation(
             cache_position=cache_position,
             batch_size=batch_size,
         )
+
+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -1155,7 +1159,6 @@ def prepare_inputs_for_generation(
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
-                "num_logits_to_keep": num_logits_to_keep,
             }
         )
         return model_inputs
diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py
index 5d4f8e408eb0..f57bdd27fee6 100644
--- a/src/transformers/models/idefics2/modeling_idefics2.py
+++ b/src/transformers/models/idefics2/modeling_idefics2.py
@@ -1643,7 +1643,7 @@ def prepare_inputs_for_generation(
         past_key_values=None,
         attention_mask=None,
         inputs_embeds=None,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         past_length = 0
@@ -1687,6 +1687,9 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}

+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         image_hidden_states = kwargs.get("image_hidden_states", None)
         if image_hidden_states is not None:
             pixel_values = None
@@ -1703,7 +1706,6 @@ def prepare_inputs_for_generation(
                 "pixel_values": pixel_values,
                 "pixel_attention_mask": pixel_attention_mask,
                 "image_hidden_states": image_hidden_states,
-                "num_logits_to_keep": num_logits_to_keep,
             }
         )
         return model_inputs
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index 162478a7258c..c273b021d736 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -1348,7 +1348,7 @@ def prepare_inputs_for_generation(
         output_router_logits=False,
         position_ids=None,
         use_cache=True,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1373,6 +1373,9 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -1381,7 +1384,6 @@ def prepare_inputs_for_generation(
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
                 "output_router_logits": output_router_logits,
-                "num_logits_to_keep": num_logits_to_keep,
             }
         )
         return model_inputs
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 022ae5ce74c4..051b4f539cdd 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1244,7 +1244,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         position_ids=None,
         use_cache=True,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1295,6 +1295,9 @@ def prepare_inputs_for_generation(
             batch_size=batch_size,
         )

+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -1302,7 +1305,6 @@ def prepare_inputs_for_generation(
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
-                "num_logits_to_keep": num_logits_to_keep,
             }
         )
         return model_inputs
"num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 394c80edb540..ae53156d9ba2 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -377,6 +377,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: r""" Args: @@ -385,6 +386,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: Example: @@ -518,6 +525,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, ) logits = outputs[0] @@ -558,6 +566,7 @@ def prepare_inputs_for_generation( pixel_values=None, attention_mask=None, cache_position=None, + num_logits_to_keep=None, **kwargs, ): # Trigger the new behavior if we have more than image embeddings seq length tokens for images @@ -572,6 +581,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 723d54c92dd9..5fe029f13e73 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -721,6 +721,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]: r""" Args: @@ -729,6 +730,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -890,6 +896,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, ) logits = outputs[0] @@ -931,6 +938,7 @@ def prepare_inputs_for_generation( image_sizes=None, attention_mask=None, cache_position=None, + num_logits_to_keep=None, **kwargs, ): legacy_processing = ( @@ -944,6 +952,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 3430fbe590aa..f616c8df9255 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -767,6 +767,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]: r""" Args: @@ -778,6 +779,10 @@ def forward( Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 

         Returns:

@@ -973,6 +978,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            num_logits_to_keep=num_logits_to_keep,
         )

         logits = outputs[0]
@@ -1014,6 +1020,7 @@ def prepare_inputs_for_generation(
         pixel_values_videos=None,
         image_sizes=None,
         attention_mask=None,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         if past_key_values is not None:
@@ -1057,6 +1064,9 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids}

+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 240e229e0bb0..862157832b91 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -1090,7 +1090,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         position_ids=None,
         use_cache=True,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1118,6 +1118,9 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -1125,7 +1128,6 @@ def prepare_inputs_for_generation(
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
-                "num_logits_to_keep": num_logits_to_keep,
             }
         )
         return model_inputs
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 919f32abc7fc..22aa9010692a 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -1348,7 +1348,7 @@ def prepare_inputs_for_generation(
         output_router_logits=False,
         position_ids=None,
         use_cache=True,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1373,6 +1373,9 @@ def prepare_inputs_for_generation(
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -1381,7 +1384,6 @@ def prepare_inputs_for_generation(
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
                 "output_router_logits": output_router_logits,
-                "num_logits_to_keep": num_logits_to_keep,
             }
         )
         return model_inputs
diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index 548732b371a5..719f3ff2fd17 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -1123,7 +1123,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         position_ids=None,
         use_cache=True,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1174,6 +1174,9 @@ def prepare_inputs_for_generation(
             batch_size=batch_size,
         )

+        if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, @@ -1181,7 +1184,6 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 587ef92e4585..ccc376232a9d 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1162,7 +1162,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=0, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1213,6 +1213,9 @@ def prepare_inputs_for_generation( batch_size=batch_size, ) + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, @@ -1220,7 +1223,6 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 8eff8cce50cc..7c09177e27d7 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -356,6 +356,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]: r""" Args: @@ -364,6 +365,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -458,6 +464,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, ) logits = outputs.logits @@ -503,6 +510,7 @@ def prepare_inputs_for_generation( attention_mask=None, token_type_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): model_inputs = self.language_model.prepare_inputs_for_generation( @@ -511,6 +519,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 90a7f355992e..193dc860bd13 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -972,7 +972,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=0, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1023,6 +1023,9 @@ def prepare_inputs_for_generation( batch_size=batch_size, ) + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, @@ -1030,7 +1033,6 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 3c647a9d8d81..f20a0074702f 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1264,7 +1264,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=0, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1315,6 +1315,9 @@ def prepare_inputs_for_generation( batch_size=batch_size, ) + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, @@ -1322,7 +1325,6 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 4652294980fd..c1857b73ec39 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1304,7 +1304,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=0, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1355,6 +1355,9 @@ def prepare_inputs_for_generation( batch_size=batch_size, ) + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, @@ -1362,7 +1365,6 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, 
- "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 59413730ad4a..55f7c00bb394 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1162,7 +1162,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=0, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1213,6 +1213,9 @@ def prepare_inputs_for_generation( batch_size=batch_size, ) + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, @@ -1220,7 +1223,6 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index c08735f45345..3e3c331e91cb 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1358,7 +1358,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=0, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1409,6 +1409,9 @@ def prepare_inputs_for_generation( batch_size=batch_size, ) + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, @@ -1416,7 +1419,6 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 1ec4665fcfb7..d86770c408cd 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -1250,7 +1250,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=0, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1301,6 +1301,9 @@ def prepare_inputs_for_generation( batch_size=batch_size, ) + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, @@ -1308,7 +1311,6 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 90603fd4e51e..b904102f4508 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1138,7 +1138,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=0, + 
num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1189,6 +1189,9 @@ def prepare_inputs_for_generation( batch_size=batch_size, ) + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + model_inputs.update( { "position_ids": position_ids, @@ -1196,7 +1199,6 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 425d46bd7741..b9263ad15cbf 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -417,6 +417,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, VideoLlavaCausalLMOutputWithPast]: r""" Args: @@ -425,6 +426,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + Returns: Example: @@ -627,6 +633,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, ) logits = outputs[0] @@ -668,6 +675,7 @@ def prepare_inputs_for_generation( pixel_values_videos=None, attention_mask=None, cache_position=None, + num_logits_to_keep=None, **kwargs, ): # Trigger the new behavior if we have more than image embeddings seq length tokens for images @@ -682,6 +690,7 @@ def prepare_inputs_for_generation( inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, **kwargs, ) diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index b1df10fdb3dc..e036d6fb7667 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -379,6 +379,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, VipLlavaCausalLMOutputWithPast]: r""" Args: @@ -387,6 +388,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+
+
         Returns:

         Example:
@@ -512,6 +519,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            num_logits_to_keep=num_logits_to_keep,
         )

         logits = outputs[0]
@@ -552,6 +560,7 @@ def prepare_inputs_for_generation(
         pixel_values=None,
         attention_mask=None,
         cache_position=None,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # Trigger the new behavior if we have more than image embeddings seq length tokens for images
@@ -566,6 +575,7 @@ def prepare_inputs_for_generation(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            num_logits_to_keep=num_logits_to_keep,
             **kwargs,
         )
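
A minimal, self-contained sketch of the pattern these hunks apply (the function and variable names below are illustrative, not part of the patch). `prepare_inputs_for_generation` now defaults `num_logits_to_keep` to `None` and forwards the key only when a value was explicitly set, so a model whose `forward` does not accept the kwarg never receives an unexpected keyword argument; inside `forward`, slicing the hidden states before the LM head projection is what saves the memory:

import torch
from torch import nn


def prepare_inputs_for_generation(input_ids, num_logits_to_keep=None, **kwargs):
    model_inputs = {"input_ids": input_ids}
    # Forward the kwarg only when the caller set it explicitly. Note the
    # identity check on the value (`is not None`), not a string comparison.
    if num_logits_to_keep is not None:
        model_inputs["num_logits_to_keep"] = num_logits_to_keep
    return model_inputs


def lm_logits(hidden_states, lm_head, num_logits_to_keep=0):
    # `hidden_states` is [batch, seq_len, hidden]; projecting only the last
    # `num_logits_to_keep` positions shrinks the logits tensor from
    # [batch, seq_len, vocab] to [batch, num_logits_to_keep, vocab].
    # `0` keeps every position (the special case noted in the docstrings).
    if num_logits_to_keep == 0:
        return lm_head(hidden_states)
    return lm_head(hidden_states[:, -num_logits_to_keep:, :])


batch, seq_len, hidden, vocab = 2, 16, 8, 32
head = nn.Linear(hidden, vocab, bias=False)
states = torch.randn(batch, seq_len, hidden)
assert lm_logits(states, head, num_logits_to_keep=0).shape == (batch, seq_len, vocab)
assert lm_logits(states, head, num_logits_to_keep=1).shape == (batch, 1, vocab)
assert "num_logits_to_keep" not in prepare_inputs_for_generation(torch.tensor([[1, 2]]))

During decoding only the final position's logits are consumed, so requesting a single position reduces the largest per-step activation from [batch, seq_len, vocab] to [batch, 1, vocab], which is where the saving for long prompts and large vocabularies comes from.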