from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Minimal reproduction: call `generate()` on Gemma 2 with precomputed input
# embeddings (`inputs_embeds`) instead of token ids.
model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
messages = [
    {"role": "user", "content": "Write me a poem about Machine Learning."},
]
# add_generation_prompt=True appends the assistant-turn header so the model
# produces a reply rather than continuing the user's message.
template = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
# The rendered chat template already carries its special tokens (e.g. BOS),
# so the tokenizer must not add them a second time. The tensors are moved to
# wherever `device_map="auto"` placed the model instead of assuming "cuda".
input_ids = tokenizer(
    template,
    return_tensors="pt",
    add_special_tokens=False,
).to(model.device)
embedding_layer = model.get_input_embeddings()
inputs_embeds = embedding_layer(input_ids["input_ids"])
# Forward the tokenizer's attention mask explicitly; without token ids,
# generate() cannot derive padding information from the embeddings.
outputs = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=input_ids["attention_mask"],
    max_new_tokens=32,
)
I'm encountering an error when attempting to use custom input embeddings with the Gemma2ForCausalLM model's .generate() method. Specifically, I'm modifying the input embeddings for the model and then trying to generate output from these custom embeddings.
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[5], line 21
18 embedding_layer = model.get_input_embeddings()
19 inputs_embeds = embedding_layer(input_ids["input_ids"])
---> 21 outputs = model.generate(
22 inputs_embeds=inputs_embeds,
23 max_new_tokens=32,
24 )
File <full_path>/miniforge3/envs/myenv/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File <full_path>/miniforge3/envs/myenv/lib/python3.10/site-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
2016 input_ids, model_kwargs = self._expand_inputs_for_generation(
2017 input_ids=input_ids,
2018 expand_size=generation_config.num_return_sequences,
2019 is_encoder_decoder=self.config.is_encoder_decoder,
2020 **model_kwargs,
2021 )
2023 # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2024 result = self._sample(
2025 input_ids,
2026 logits_processor=prepared_logits_processor,
2027 logits_warper=prepared_logits_warper,
2028 stopping_criteria=prepared_stopping_criteria,
2029 generation_config=generation_config,
2030 synced_gpus=synced_gpus,
2031 streamer=streamer,
2032 **model_kwargs,
2033 )
2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
2036 # 11. prepare logits warper
2037 prepared_logits_warper = (
2038 self._get_logits_warper(generation_config, device=input_ids.device)
2039 if generation_config.do_sample
2040 else None
2041 )
File <full_path>/miniforge3/envs/myenv/lib/python3.10/site-packages/transformers/generation/utils.py:2975, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
2969 model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
2971 while self._has_unfinished_sequences(
2972 this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
2973 ):
2974 # prepare model inputs
-> 2975 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
2977 # prepare variable output controls (note: some models won't accept all output controls)
2978 model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
File <full_path>/miniforge3/envs/myenv/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:1223, in Gemma2ForCausalLM.prepare_inputs_for_generation(self, input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, position_ids, use_cache, **kwargs)
1221 if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2:
1222 if inputs_embeds is not None:
-> 1223 batch_size, sequence_length = inputs_embeds.shape
1224 device = inputs_embeds.device
1225 else:
ValueError: too many values to unpack (expected 2)
System Info
`transformers` version: 4.44.0
Who can help?
@ArthurZucker
@gante
Information
Tasks
`examples` folder (such as GLUE/SQuAD, ...)
Reproduction
Expected behavior
I'm encountering an error when attempting to use custom input embeddings with the Gemma2ForCausalLM model's `.generate()` method. Specifically, I'm modifying the input embeddings for the model and then trying to generate output from these custom embeddings.
Expected behavior:
When using the `inputs_embeds` argument to supply custom input embeddings, the `.generate()` function should be able to process these embeddings and produce a generated output. This is what happens in other models such as "meta-llama/Meta-Llama-3-8B-Instruct".
Actual behavior:
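The `.generate()` method raises a `ValueError` in the `prepare_inputs_for_generation` function. The error suggests that the `inputs_embeds` tensor doesn't have the expected shape, resulting in a "too many values to unpack (expected 2)" error. Concretely, the failing line unpacks `inputs_embeds.shape` into two names (`batch_size, sequence_length`), while the embeddings tensor has three dimensions.
The error trace looks like the following: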