
Gemma3n get_placeholder_mask issue #39991

@Znerual

Description


System Info

  • transformers version: 4.55.0
  • Platform: Linux-6.8.0-71-generic-x86_64-with-glibc2.39
  • Python version: 3.12.2
  • Huggingface_hub version: 0.34.3
  • Safetensors version: 0.6.1
  • Accelerate version: 1.9.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.7.1+cu126 (NA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no

Who can help?

@qubvel @eustlb

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Running the Gemma3n model on audio data, following the usage example from the Hugging Face model card:

import io
import librosa
import numpy as np
import requests
import torch
from transformers import AutoProcessor, Gemma3nForConditionalGeneration

processor = AutoProcessor.from_pretrained("google/gemma-3n-E4B-it")
model = Gemma3nForConditionalGeneration.from_pretrained(
        "google/gemma-3n-E4B-it",
        torch_dtype="auto", 
        device_map="auto"
    ).eval()


# 2. Download and Process the Problematic Audio File
audio_url = "https://styletts.github.io/wavs/styletts/abstract.mp3"
print(f"Downloading audio from: {audio_url}")
response = requests.get(audio_url)
response.raise_for_status()
audio_bytes = response.content

print("Processing audio to 16kHz mono float32...")
try:
    audio_data, _ = librosa.load(
        io.BytesIO(audio_bytes),
        sr=16000,
        mono=True,
        res_type='scipy'
    )
    processed_audio = audio_data.astype(np.float32)
    print("Audio processing complete.")
except Exception as e:
    print(f"Failed to process audio: {e}")
    exit()

content = [
   {"type": "text", "text": "Some text prompt"},
   {"type": "audio", "audio": processed_audio}
]
        
messages = [
    {"role": "system", "content": [{"type": "text", "text" : "You are a helpful assistant."}]},
    {"role": "user", "content": content}
]

inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=model.dtype)
input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
    outputs = outputs[0][input_len:] 

response_text = processor.decode(outputs, skip_special_tokens=True)
final_response = response_text.split("model\n")[-1].strip()
print(final_response)

This fails with the following error:

Traceback (most recent call last):
  File "***/gemma-api/bug.py", line 51, in <module>
    outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/transformers/generation/utils.py", line 2634, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/***gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/transformers/generation/utils.py", line 3615, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/***gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/transformers/utils/generic.py", line 959, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/transformers/models/gemma3n/modeling_gemma3n.py", line 2283, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/transformers/utils/generic.py", line 959, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/transformers/models/gemma3n/modeling_gemma3n.py", line 2117, in forward
    _, special_audio_mask = self.get_placeholder_mask(
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Gemma3nModel.get_placeholder_mask() missing 1 required positional argument: 'image_features'

The error can be worked around by making the arguments of get_placeholder_mask optional in modeling_gemma3n.py (lines 1966 to 1969):

OLD

    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        image_features: torch.FloatTensor,
        audio_features: torch.FloatTensor,
    ):

FIXED

    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        image_features: torch.FloatTensor | None = None,
        audio_features: torch.FloatTensor | None = None,
    ):
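To illustrate why the optional defaults resolve the TypeError, here is a minimal standalone sketch (hypothetical toy classes, not the real Gemma3nModel): the traceback shows get_placeholder_mask being invoked for the audio-only path without an image_features argument, so a signature with required parameters raises, while one that defaults them to None does not.

```python
# Toy classes mimicking the old and fixed signatures of get_placeholder_mask.

class Before:
    # All parameters required, as in transformers 4.55.0.
    def get_placeholder_mask(self, input_ids, inputs_embeds, image_features, audio_features):
        return image_features, audio_features

class After:
    # All parameters optional, as in the proposed fix.
    def get_placeholder_mask(self, input_ids=None, inputs_embeds=None,
                             image_features=None, audio_features=None):
        return image_features, audio_features

# Audio-only call, i.e. no image_features keyword, as in the traceback:
try:
    Before().get_placeholder_mask(input_ids=[1], inputs_embeds=[0.0], audio_features=[0.5])
except TypeError as e:
    print(e)  # ... missing 1 required positional argument: 'image_features'

# With optional defaults the same call succeeds:
_, audio_mask = After().get_placeholder_mask(input_ids=[1], inputs_embeds=[0.0], audio_features=[0.5])
print(audio_mask)  # [0.5]
```

The same audio-only keyword call pattern is what the Gemma3n forward pass uses at modeling_gemma3n.py line 2117, which is why only the audio path trips over the required image_features parameter.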

Expected behavior

The script runs without an error and returns a generated response for the audio input.
