From 1a59a20e020309e303ef8dabafac5cc5b7be908c Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 22 Mar 2024 19:55:45 -0700
Subject: [PATCH 1/2] Generate: fix generation with inputs_embeds for llama
 and gemma

The changes in https://github.com/huggingface/transformers/pull/29467
break generation with inputs_embeds when input_ids is None, because they
assume input_ids is non-None even for the prefill forward pass, i.e. the
first forward call, which runs without past_key_values.
---
 src/transformers/models/gemma/modeling_gemma.py | 10 ++++++++--
 src/transformers/models/llama/modeling_llama.py | 10 ++++++++--
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index c60c67d46e1b..eba59ca492c8 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -1198,9 +1198,15 @@ def prepare_inputs_for_generation(
             # TODO: use `next_tokens` directly instead.
             model_inputs = {"input_ids": input_ids.contiguous()}
 
-        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if position_ids is not None:
+            input_length = position_ids.shape[-1]
+        elif input_ids is not None:
+            input_length = input_ids.shape[-1]
+        else:
+            input_length = inputs_embeds.shape[-2]
         if cache_position is None:
-            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+            input_tensor = inputs_embeds if input_ids is None else input_ids
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_tensor.device)
         else:
             cache_position = cache_position[-input_length:]
 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index f5f3dc02ee9d..2adb80b7db97 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1295,9 +1295,15 @@ def prepare_inputs_for_generation(
             # TODO: use `next_tokens` directly instead.
             model_inputs = {"input_ids": input_ids.contiguous()}
 
-        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if position_ids is not None:
+            input_length = position_ids.shape[-1]
+        elif input_ids is not None:
+            input_length = input_ids.shape[-1]
+        else:
+            input_length = inputs_embeds.shape[-2]
         if cache_position is None:
-            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+            input_tensor = inputs_embeds if input_ids is None else input_ids
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_tensor.device)
         else:
             cache_position = cache_position[-input_length:]
 

From a55caaf20ab847a45aa13a52a4ebbde011c34b6b Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 22 Mar 2024 21:07:17 -0700
Subject: [PATCH 2/2] Make corresponding change to cohere

---
 src/transformers/models/cohere/modeling_cohere.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 4d7d9c4b7807..52ef4a9601bc 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -1194,9 +1194,15 @@ def prepare_inputs_for_generation(
             # TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()} - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if position_ids is not None: + input_length = position_ids.shape[-1] + elif input_ids is not None: + input_length = input_ids.shape[-1] + else: + input_length = inputs_embeds.shape[-2] if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + input_tensor = inputs_embeds if input_ids is None else input_ids + cache_position = torch.arange(past_length, past_length + input_length, device=input_tensor.device) else: cache_position = cache_position[-input_length:]