diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 4d7d9c4b7807..52ef4a9601bc 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -1194,9 +1194,15 @@ def prepare_inputs_for_generation(
             # TODO: use `next_tokens` directly instead.
             model_inputs = {"input_ids": input_ids.contiguous()}
 
-        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if position_ids is not None:
+            input_length = position_ids.shape[-1]
+        elif input_ids is not None:
+            input_length = input_ids.shape[-1]
+        else:
+            input_length = inputs_embeds.shape[-2]
         if cache_position is None:
-            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+            input_tensor = inputs_embeds if input_ids is None else input_ids
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_tensor.device)
         else:
             cache_position = cache_position[-input_length:]
 
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index c60c67d46e1b..eba59ca492c8 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -1198,9 +1198,15 @@ def prepare_inputs_for_generation(
             # TODO: use `next_tokens` directly instead.
             model_inputs = {"input_ids": input_ids.contiguous()}
 
-        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if position_ids is not None:
+            input_length = position_ids.shape[-1]
+        elif input_ids is not None:
+            input_length = input_ids.shape[-1]
+        else:
+            input_length = inputs_embeds.shape[-2]
         if cache_position is None:
-            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+            input_tensor = inputs_embeds if input_ids is None else input_ids
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_tensor.device)
         else:
             cache_position = cache_position[-input_length:]
 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index f5f3dc02ee9d..2adb80b7db97 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1295,9 +1295,15 @@ def prepare_inputs_for_generation(
             # TODO: use `next_tokens` directly instead.
             model_inputs = {"input_ids": input_ids.contiguous()}
 
-        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if position_ids is not None:
+            input_length = position_ids.shape[-1]
+        elif input_ids is not None:
+            input_length = input_ids.shape[-1]
+        else:
+            input_length = inputs_embeds.shape[-2]
         if cache_position is None:
-            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+            input_tensor = inputs_embeds if input_ids is None else input_ids
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_tensor.device)
         else:
             cache_position = cache_position[-input_length:]
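
All three hunks make the same change: `prepare_inputs_for_generation` previously read the prompt length and target device from `input_ids`, which is not usable when generation is driven purely by `inputs_embeds`; the patched code falls back to `inputs_embeds` for both (its sequence dimension is the second-to-last, hence `shape[-2]` rather than `shape[-1]`). A minimal sketch of the call pattern this supports is below; the checkpoint name and generation arguments are illustrative assumptions, not taken from the diff:

```python
# Sketch of embeddings-driven generation, the code path these hunks handle.
# The checkpoint and kwargs are assumptions; any of the three patched model
# families (Cohere, Gemma, Llama) would exercise the same branch.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "meta-llama/Llama-2-7b-hf"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

prompt_ids = tokenizer("Hello, world", return_tensors="pt").input_ids
# Embed the prompt manually and pass generate() the embeddings instead of
# token ids, so prepare_inputs_for_generation must derive input_length and
# the cache_position device from inputs_embeds.
inputs_embeds = model.get_input_embeddings()(prompt_ids)

with torch.no_grad():
    out = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```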