
Gemma2ForCausalLM: ValueError in prepare_inputs_for_generation when using custom input embeddings #32479

@serteal

Description


System Info

  • transformers version: 4.44.0
  • Platform: Linux-5.4.0-189-generic-x86_64-with-glibc2.31
  • Python version: 3.10.14
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: 0.32.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: NO
  • Using GPU in script?: YES
  • GPU type: NVIDIA RTX A6000

Who can help?

@ArthurZucker
@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

messages = [
    {"role": "user", "content": "Write me a poem about Machine Learning."},
]
# Render the chat template to a string, then tokenize it.
template = tokenizer.apply_chat_template(messages, tokenize=False)
# Note: this is a BatchEncoding (input_ids + attention_mask), not a bare tensor.
input_ids = tokenizer(template, return_tensors="pt").to("cuda")

# Look up the token embeddings: shape (batch_size, sequence_length, hidden_size).
embedding_layer = model.get_input_embeddings()
inputs_embeds = embedding_layer(input_ids["input_ids"])

# Raises ValueError inside Gemma2ForCausalLM.prepare_inputs_for_generation.
outputs = model.generate(
    inputs_embeds=inputs_embeds,
    max_new_tokens=32,
)

Expected behavior

Calling Gemma2ForCausalLM.generate() with custom input embeddings (the inputs_embeds argument) raises an error. In my use case I modify the input embeddings before generation; the reproduction above passes unmodified embeddings taken straight from the embedding layer and already fails.

Expected behavior:
When custom input embeddings are passed through the inputs_embeds argument, .generate() should process them and return generated tokens. This works with other models, such as "meta-llama/Meta-Llama-3-8B-Instruct".

Actual behavior:
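The .generate() method raises a ValueError in Gemma2ForCausalLM.prepare_inputs_for_generation. When past_key_values is a HybridCache and attention_mask is 2-D, that function unpacks inputs_embeds.shape into two values (batch_size, sequence_length). However, inputs_embeds is 3-D, with shape (batch_size, sequence_length, hidden_size), so the unpacking fails with "too many values to unpack (expected 2)".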
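To see the failure in isolation, here is a minimal sketch in plain Python, independent of transformers. A tuple stands in for inputs_embeds.shape; the sizes (sequence length 20, hidden size 2304) are illustrative assumptions, and unpack_batch_and_seq is a hypothetical helper, not a transformers function. Taking only the first two dimensions, as the helper does, works for both 2-D token-id shapes and 3-D embedding shapes.

```python
def unpack_batch_and_seq(shape):
    """Hypothetical helper: return (batch_size, sequence_length) from a shape.

    Slicing the first two dims works for 2-D (batch, seq) shapes and for
    3-D (batch, seq, hidden) embedding shapes alike.
    """
    batch_size, sequence_length = shape[:2]
    return batch_size, sequence_length


# Illustrative 3-D embedding shape: (batch_size, sequence_length, hidden_size).
embeds_shape = (1, 20, 2304)

try:
    # Mirrors the failing line: two names, but the shape has three entries.
    batch_size, sequence_length = embeds_shape
except ValueError as err:
    print("ValueError:", err)

print(unpack_batch_and_seq(embeds_shape))  # (1, 20)
```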

The error trace looks like the following:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[5], line 21
     18 embedding_layer = model.get_input_embeddings()
     19 inputs_embeds = embedding_layer(input_ids["input_ids"])
---> 21 outputs = model.generate(
     22     inputs_embeds=inputs_embeds,
     23     max_new_tokens=32,
     24 )

File <full_path>/miniforge3/envs/myenv/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File <full_path>/miniforge3/envs/myenv/lib/python3.10/site-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2016     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2017         input_ids=input_ids,
   2018         expand_size=generation_config.num_return_sequences,
   2019         is_encoder_decoder=self.config.is_encoder_decoder,
   2020         **model_kwargs,
   2021     )
   2023     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2024     result = self._sample(
   2025         input_ids,
   2026         logits_processor=prepared_logits_processor,
   2027         logits_warper=prepared_logits_warper,
   2028         stopping_criteria=prepared_stopping_criteria,
   2029         generation_config=generation_config,
   2030         synced_gpus=synced_gpus,
   2031         streamer=streamer,
   2032         **model_kwargs,
   2033     )
   2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2036     # 11. prepare logits warper
   2037     prepared_logits_warper = (
   2038         self._get_logits_warper(generation_config, device=input_ids.device)
   2039         if generation_config.do_sample
   2040         else None
   2041     )

File <full_path>/miniforge3/envs/myenv/lib/python3.10/site-packages/transformers/generation/utils.py:2975, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2969 model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
   2971 while self._has_unfinished_sequences(
   2972     this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
   2973 ):
   2974     # prepare model inputs
-> 2975     model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2977     # prepare variable output controls (note: some models won't accept all output controls)
   2978     model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})

File <full_path>/miniforge3/envs/myenv/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:1223, in Gemma2ForCausalLM.prepare_inputs_for_generation(self, input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, position_ids, use_cache, **kwargs)
   1221 if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2:
   1222     if inputs_embeds is not None:
-> 1223         batch_size, sequence_length = inputs_embeds.shape
   1224         device = inputs_embeds.device
   1225     else:

ValueError: too many values to unpack (expected 2)
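Until prepare_inputs_for_generation handles 3-D embeddings, one possible workaround is to skip .generate() and decode manually. This is a sketch under stated assumptions, not a transformers API: greedy_generate_from_embeds is a hypothetical helper that calls the model's forward pass with inputs_embeds and use_cache=False, takes the argmax of the last position's logits, and appends that token's embedding to the sequence. Because no KV cache is created, the HybridCache branch that fails is never reached. The cost is that the full sequence is recomputed at every step.

```python
import torch


@torch.no_grad()
def greedy_generate_from_embeds(model, inputs_embeds, attention_mask,
                                max_new_tokens=32, eos_token_id=None):
    """Hypothetical helper: greedy decoding from input embeddings, no KV cache.

    inputs_embeds: (batch, seq_len, hidden); attention_mask: (batch, seq_len).
    Returns the newly generated token ids, shape (batch, n_generated).
    """
    embed = model.get_input_embeddings()
    generated = []
    for _ in range(max_new_tokens):
        # Full forward pass over the whole sequence; no cache is kept.
        out = model(inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    use_cache=False)
        # Greedy choice from the logits of the last position: (batch,)
        next_token = out.logits[:, -1, :].argmax(dim=-1)
        generated.append(next_token)
        # Stop only once every sequence in the batch has emitted EOS.
        if eos_token_id is not None and (next_token == eos_token_id).all():
            break
        # Embed the new token and extend the inputs and the mask by one step.
        next_embeds = embed(next_token[:, None])  # (batch, 1, hidden)
        inputs_embeds = torch.cat([inputs_embeds, next_embeds], dim=1)
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
            dim=1,
        )
    return torch.stack(generated, dim=1)
```

With the objects from the reproduction above, it could be called as greedy_generate_from_embeds(model, inputs_embeds, input_ids["attention_mask"], max_new_tokens=32, eos_token_id=tokenizer.eos_token_id), and the result decoded with tokenizer.decode(new_tokens[0], skip_special_tokens=True). Recomputing the whole sequence each step is quadratic in length, so this is only reasonable for short outputs.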
