System Info
- transformers version: 4.55.0
- Platform: Linux-6.8.0-71-generic-x86_64-with-glibc2.39
- Python version: 3.12.2
- Huggingface_hub version: 0.34.3
- Safetensors version: 0.6.1
- Accelerate version: 1.9.0
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.7.1+cu126 (NA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: no
Who can help?
@qubvel @eustlb
Information
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
Running the Gemma3n model on audio data, following the usage description from Hugging Face:
import io

import librosa
import numpy as np
import requests
import torch
from transformers import AutoProcessor, Gemma3nForConditionalGeneration

# 1. Load the processor and model
processor = AutoProcessor.from_pretrained("google/gemma-3n-E4B-it")
model = Gemma3nForConditionalGeneration.from_pretrained(
    "google/gemma-3n-E4B-it",
    torch_dtype="auto",
    device_map="auto",
).eval()

# 2. Download and process the problematic audio file
audio_url = "https://styletts.github.io/wavs/styletts/abstract.mp3"
print(f"Downloading audio from: {audio_url}")
response = requests.get(audio_url)
response.raise_for_status()
audio_bytes = response.content

print("Processing audio to 16kHz mono float32...")
try:
    audio_data, _ = librosa.load(
        io.BytesIO(audio_bytes),
        sr=16000,
        mono=True,
        res_type="scipy",
    )
    processed_audio = audio_data.astype(np.float32)
    print("Audio processing complete.")
except Exception as e:
    print(f"Failed to process audio: {e}")
    exit()

# 3. Build the chat messages with the processed audio attached
content = [
    {"type": "text", "text": "Some text prompt"},
    {"type": "audio", "audio": processed_audio},
]
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": content},
]

# 4. Tokenize, generate, and decode only the newly generated tokens
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device, dtype=model.dtype)
input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
    outputs = outputs[0][input_len:]

response_text = processor.decode(outputs, skip_special_tokens=True)
final_response = response_text.split("model\n")[-1].strip()
This results in the following error:
Traceback (most recent call last):
File "***/gemma-api/bug.py", line 51, in <module>
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/transformers/generation/utils.py", line 2634, in generate
result = self._sample(
^^^^^^^^^^^^^
File "/***gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/transformers/generation/utils.py", line 3615, in _sample
outputs = self(**model_inputs, return_dict=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/***gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/transformers/utils/generic.py", line 959, in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/transformers/models/gemma3n/modeling_gemma3n.py", line 2283, in forward
outputs = self.model(
^^^^^^^^^^^
File "/***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/transformers/utils/generic.py", line 959, in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "***/gemma-api-r8Np3HKg-py3.12/lib/python3.12/site-packages/transformers/models/gemma3n/modeling_gemma3n.py", line 2117, in forward
_, special_audio_mask = self.get_placeholder_mask(
^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Gemma3nModel.get_placeholder_mask() missing 1 required positional argument: 'image_features'
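The root cause appears to be the audio-only path: the call at modeling_gemma3n.py line 2117 presumably passes only the embeddings and the audio features as keywords, while the signature below declares image_features as a required positional parameter. A minimal standalone reduction of the same pattern (hypothetical names, not the actual model code):

# Hypothetical reduction of the failing pattern: a required positional
# parameter omitted at an audio-only call site raises the same TypeError.
def get_placeholder_mask(input_ids, inputs_embeds, image_features, audio_features):
    ...

get_placeholder_mask(input_ids=None, inputs_embeds=None, audio_features=None)
# TypeError: get_placeholder_mask() missing 1 required positional argument: 'image_features'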
It can be fixed by modifying the modeling_gemma3n.py file at lines 1966 to 1969, making the arguments optional:
OLD

def get_placeholder_mask(
    self,
    input_ids: torch.LongTensor,
    inputs_embeds: torch.FloatTensor,
    image_features: torch.FloatTensor,
    audio_features: torch.FloatTensor,
):
FIXED

def get_placeholder_mask(
    self,
    input_ids: torch.LongTensor | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    image_features: torch.FloatTensor | None = None,
    audio_features: torch.FloatTensor | None = None,
):
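For anyone who cannot edit the installed package, a runtime monkeypatch along the same lines should also work. This is a minimal sketch, assuming transformers 4.55.0 and that the original method body already tolerates None features (which the fix above implies):

# Minimal runtime workaround sketch (assumes transformers 4.55.0).
# Wraps the original method so omitted arguments default to None,
# mirroring the signature change above without editing the package.
from transformers.models.gemma3n import modeling_gemma3n

_orig_get_placeholder_mask = modeling_gemma3n.Gemma3nModel.get_placeholder_mask

def _patched_get_placeholder_mask(
    self, input_ids=None, inputs_embeds=None, image_features=None, audio_features=None
):
    # Forward positionally; the original body must handle None features,
    # as the reporter's fix implies.
    return _orig_get_placeholder_mask(
        self, input_ids, inputs_embeds, image_features, audio_features
    )

modeling_gemma3n.Gemma3nModel.get_placeholder_mask = _patched_get_placeholder_mask

Apply the patch before calling model.generate; since the method is looked up on the class at call time, it takes effect even for a model that was instantiated earlier.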
Expected behavior
The script runs without an error.