diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 7bed417c0109..0a718ba6d512 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1182,6 +1182,8 @@
         title: Gemma3
       - local: model_doc/gemma3n
         title: Gemma3n
+      - local: model_doc/gemma4
+        title: Gemma4
       - local: model_doc/git
         title: GIT
       - local: model_doc/glm46v
diff --git a/docs/source/en/model_doc/gemma4.md b/docs/source/en/model_doc/gemma4.md
new file mode 100644
index 000000000000..cc6b0bc8471a
--- /dev/null
+++ b/docs/source/en/model_doc/gemma4.md
@@ -0,0 +1,242 @@
+
+*This model was released on {release_date} and added to Hugging Face Transformers on 2026-04-01.*
+
+
+# Gemma4
+
+## Overview
+
+[Gemma 4](INSET_PAPER_LINK) is a multimodal model with pretrained and instruction-tuned variants, available in 1B, 13B, and 27B parameters. The architecture is mostly the same as the previous Gemma versions. The key differences are a vision processor that can output images with a fixed token budget and a spatial 2D RoPE that encodes vision-specific information across the height and width axes.
+
+You can find all the original Gemma 4 checkpoints under the [Gemma 4](https://huggingface.co/collections/google/gemma-4-release-67c6c6f89c4f76621268bb6d) release.
+
+### Gemma4 Vision Model
+
+The key difference from previous Gemma releases is the new design for processing **images of different sizes** using a **fixed budget of tokens**. Unlike many models that squash every image into a fixed square (like 224×224), Gemma 4 keeps the image's natural aspect ratio while resizing it to fit the budget. There are a couple of constraints to follow (see the sizing sketch at the end of this section):
+- The total number of pixels must fit within a patch budget
+- Both height and width must be divisible by **48** (= patch size 16 × pooling kernel 3)
+
+> [!IMPORTANT]
+> Gemma 4 does **not** apply the standard ImageNet mean/std normalization that many other vision models use. The model's own patch embedding layer handles the final scaling internally (shifting values to the [-1, 1] range).
+
+The number of "soft tokens" (aka vision tokens) the image processor can produce is configurable. The supported options are outlined below and the default is **280 soft tokens** per image.
+
+
+| Soft Tokens | Patches (before pooling) | Approx. Image Area |
+|:-----------:|:------------------------:|:-------------------:|
+| 70 | 630 | ~161K pixels |
+| 140 | 1,260 | ~323K pixels |
+| **280** | **2,520** | **~645K pixels** |
+| 560 | 5,040 | ~1.3M pixels |
+| 1,120 | 10,080 | ~2.6M pixels |
+
+
+To encode positional information for each patch in the image, Gemma 4 uses a learned 2D position embedding table. The position table stores up to 10,240 positions per axis, which allows the model to handle very large images. Each position is a learned vector of the same dimensions as the patch embedding. The 2D RoPE that Gemma 4 uses independently rotates half the attention head dimensions for the x-axis and the other half for the y-axis. This allows the model to understand spatial relationships like "above," "below," "left of," and "right of."
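+
+The snippet below is only an illustrative sketch of the sizing rule above, not the processor's actual implementation (the helper name is made up): it shrinks an image while keeping its aspect ratio and snaps both sides to multiples of 48 so the patch count stays within the selected soft-token budget.
+
+```py
+import math
+
+def fit_image_size(height, width, soft_tokens=280, patch_size=16, pool=3):
+    multiple = patch_size * pool                          # both sides must be divisible by 48
+    max_pixels = soft_tokens * pool**2 * patch_size**2    # e.g. 280 tokens -> 2,520 patches -> ~645K pixels
+    scale = min(1.0, math.sqrt(max_pixels / (height * width)))  # shrink (never enlarge) to fit the budget
+    new_height = max(multiple, int(height * scale) // multiple * multiple)
+    new_width = max(multiple, int(width * scale) // multiple * multiple)
+    return new_height, new_width
+
+print(fit_image_size(1080, 1920))  # (576, 1056) -> 2,376 patches, within the default 280-token budget
+```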
+
+
+## Usage examples
+
+The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.
+
+
+
+
+```py
+import torch
+from transformers import pipeline
+
+pipeline = pipeline(
+    task="image-text-to-text",
+    model="google/gemma-4-2b-pt",
+    device=0,
+    dtype=torch.bfloat16
+)
+pipeline(
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
+    text=" What is shown in this image?"
+)
+```
+
+
+
+
+```py
+import torch
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+model = AutoModelForImageTextToText.from_pretrained(
+    "google/gemma-4-2b-it",
+    dtype=torch.bfloat16,
+    device_map="auto",
+    attn_implementation="sdpa"
+)
+processor = AutoProcessor.from_pretrained(
+    "google/gemma-4-2b-it",
+    padding_side="left"
+)
+
+messages = [
+    {
+        "role": "user", "content": [
+            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+            {"type": "text", "text": "What is shown in this image?"},
+        ]
+    },
+]
+inputs = processor.apply_chat_template(
+    messages,
+    tokenize=True,
+    return_dict=True,
+    return_tensors="pt",
+    add_generation_prompt=True,
+).to(model.device)
+
+output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
+print(processor.decode(output[0], skip_special_tokens=True))
+```
+
+### Function calling
+
+TODO: add function calling examples.
+
+### Quantization
+
+Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
+
+The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.
+
+```py
+# pip install torchao
+import torch
+from transformers import TorchAoConfig, Gemma4ForConditionalGeneration, AutoProcessor
+
+quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
+model = Gemma4ForConditionalGeneration.from_pretrained(
+    "google/gemma-4-2b-it",
+    dtype=torch.bfloat16,
+    device_map="auto",
+    quantization_config=quantization_config
+)
+processor = AutoProcessor.from_pretrained(
+    "google/gemma-4-2b-it",
+    padding_side="left"
+)
+
+messages = [
+    {
+        "role": "system",
+        "content": [
+            {"type": "text", "text": "You are a helpful assistant."}
+        ]
+    },
+    {
+        "role": "user", "content": [
+            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+            {"type": "text", "text": "What is shown in this image?"},
+        ]
+    },
+]
+inputs = processor.apply_chat_template(
+    messages,
+    tokenize=True,
+    return_dict=True,
+    return_tensors="pt",
+    add_generation_prompt=True,
+).to(model.device)
+
+output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
+print(processor.decode(output[0], skip_special_tokens=True))
+```
+
+## Gemma4AudioConfig
+
+[[autodoc]] Gemma4AudioConfig
+
+## Gemma4VisionConfig
+
+[[autodoc]] Gemma4VisionConfig
+
+## Gemma4TextConfig
+
+[[autodoc]] Gemma4TextConfig
+
+## Gemma4Config
+
+[[autodoc]] Gemma4Config
+
+## Gemma4AudioFeatureExtractor
+
+[[autodoc]] Gemma4AudioFeatureExtractor
+    - __call__
+
+## Gemma4ImageProcessorPil
+
+[[autodoc]] Gemma4ImageProcessorPil
+    - preprocess
+
+## Gemma4ImageProcessor
+
+[[autodoc]] Gemma4ImageProcessor
+    - preprocess
+
+## Gemma4VideoProcessor
+
+[[autodoc]] Gemma4VideoProcessor
+    - preprocess
+
+## Gemma4Processor
+
+[[autodoc]] Gemma4Processor
+    - __call__
+
+## Gemma4PreTrainedModel
+
+[[autodoc]] Gemma4PreTrainedModel
+    - forward
+
+## 
Gemma4AudioModel + +[[autodoc]] Gemma4AudioModel + - forward + +## Gemma4VisionModel + +[[autodoc]] Gemma4VisionModel + - forward + +## Gemma4TextModel + +[[autodoc]] Gemma4TextModel + - forward + +## Gemma4ForCausalLM + +[[autodoc]] Gemma4ForCausalLM + +## Gemma4Model + +[[autodoc]] Gemma4Model + - forward + +## Gemma4ForConditionalGeneration + +[[autodoc]] Gemma4ForConditionalGeneration + - forward \ No newline at end of file diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 51612dddb2cc..be0ec0497c87 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -1285,15 +1285,16 @@ def _prepare_position_ids(model_kwargs: dict[str, Any], new_length: int, is_enco def _prepare_token_type_ids(model_kwargs: dict[str, Any], new_length: int) -> dict[str, Any]: """Expands or crops the model's token_type_ids for decoding purposes, to the defined length""" - if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None: + if model_kwargs.get("token_type_ids") is None: return model_kwargs + # Multimodal models call this arg `mm_token_type_ids` token_type_ids = model_kwargs["token_type_ids"] final_token_type = token_type_ids[:, -1].unsqueeze(-1) type_length_diff = new_length - token_type_ids.shape[1] if type_length_diff < 0: - token_type_ids = token_type_ids[:, :type_length_diff] + model_kwargs["token_type_ids"] = token_type_ids[:, :type_length_diff] elif type_length_diff > 0: token_type_copies = final_token_type.repeat(1, type_length_diff) model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0ee06a3dd389..ffb7266a5b2f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -538,7 +538,7 @@ def prepare_inputs_for_generation( model_inputs["token_type_ids"] = token_type_ids # 3. 
Slice model inputs if it's an input that should have the same length as `input_ids` - for model_input_name in [position_ids_key, "token_type_ids"]: + for model_input_name in [position_ids_key, "token_type_ids", "mm_token_type_ids"]: model_input = model_inputs.get(model_input_name) if model_input is not None and model_input.shape[-1] != sequence_length: # Input can be 2D or 3D, and we always slice on `seq-length` (last dim) @@ -567,7 +567,9 @@ def prepare_inputs_for_generation( attention_mask=attention_mask, past_key_values=model_inputs.get("past_key_values"), position_ids=model_inputs.get(position_ids_key), + # The following kwargs are not used in the main function - only on a few models with overloaded `create_masks_for_generate` token_type_ids=model_inputs.get("token_type_ids"), + mm_token_type_ids=model_inputs.get("mm_token_type_ids"), is_first_iteration=is_first_iteration, ) @@ -919,6 +921,12 @@ def _update_model_kwargs_for_generation( if (token_type_ids := model_kwargs.get("token_type_ids")) is not None: model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -num_new_tokens:]], dim=-1) + # update mm_token_type_ids with zeros (only-text) + if (mm_token_type_ids := model_kwargs.get("mm_token_type_ids")) is not None: + model_kwargs["mm_token_type_ids"] = torch.cat( + [mm_token_type_ids, mm_token_type_ids.new_zeros((mm_token_type_ids.shape[0], num_new_tokens))], dim=-1 + ) + # Position ids (2D or 3D sometimes) position_ids_key = "position_ids" if not is_encoder_decoder else "decoder_position_ids" if (position_ids := model_kwargs.get(position_ids_key)) is not None: diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index d8f308d3d0bd..213b91e3a115 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -596,10 +596,9 @@ def __init__( self.block_size = block_size self.hidden_dim = config.hidden_size self.activation_scheme = activation_scheme - self.num_experts = config.num_local_experts if hasattr(config, "num_local_experts") else config.num_experts - self.intermediate_dim = ( - config.moe_intermediate_size if hasattr(config, "moe_intermediate_size") else config.intermediate_size - ) + self.num_experts = getattr(config, "num_local_experts", config.num_experts) + self.intermediate_dim = getattr(config, "moe_intermediate_size", config.intermediate_size) + self.act_fn = ACT2FN[getattr(config, "hidden_activation", config.hidden_act)] if self.has_gate: gu_proj_out, gu_proj_in = 2 * self.intermediate_dim, self.hidden_dim @@ -633,8 +632,6 @@ def __init__( self.gate_up_proj_activation_scale = nn.Parameter(torch.ones(self.num_experts, dtype=torch.float32)) self.down_proj_activation_scale = nn.Parameter(torch.ones(self.num_experts, dtype=torch.float32)) - self.act_fn = ACT2FN[config.hidden_act] - def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: gate, up = gate_up.chunk(2, dim=-1) return self.act_fn(gate) * up diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index 15403eea0519..7a2bb8c8ea24 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -14,6 +14,7 @@ import math import warnings +from collections.abc import Callable from functools import wraps from typing import TYPE_CHECKING, Optional, TypedDict @@ -183,6 +184,76 @@ def _compute_linear_scaling_rope_parameters( return inv_freq, attention_factor +def _compute_proportional_rope_parameters( + 
config: Optional["PreTrainedConfig"] = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + layer_type: str | None = None, + head_dim_key: str = "head_dim", +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with proportional RoPE. + + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. This function assumes that the config will provide at least the following + properties: + + * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. + * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. + * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. + + Additionally, this function will make use of the following properties if they are found in the config: + + * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be + derived as hidden_size // num_attention_heads. + * partial_rotary_factor (`float`, *optional*, defaults to 1.0): The proportion of the embedding dimension + to apply rotary positional encoding, e.g., [0.0, 0.25, 0.5, 0.75, 1.0]. Unlike other RoPE functions + that use this parameter, proportional RoPE will always return an encoding that is the size of + `head_dim`. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + # For backward compatibility standardize the `rope_parameters_dict` if it uses old format + config.standardize_rope_params() + rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters + + head_dim = getattr(config, head_dim_key, None) or config.hidden_size // config.num_attention_heads + base = rope_parameters_dict["rope_theta"] + factor = rope_parameters_dict.get("factor", 1.0) + rope_proportion = rope_parameters_dict.get("partial_rotary_factor", 1.0) + + attention_factor = 1.0 # Unused in this type of RoPE + + rope_angles = int(rope_proportion * head_dim // 2) + + inv_freq_rotated = 1.0 / ( + base + ** (torch.arange(0, 2 * rope_angles, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / head_dim) + ) + + nope_angles = head_dim // 2 - rope_angles + if nope_angles > 0: + inv_freq = torch.cat( + ( + inv_freq_rotated, + torch.zeros(nope_angles, dtype=torch.float32, device=device), + ), + dim=0, + ) + else: + inv_freq = inv_freq_rotated + + inv_freq /= factor + return inv_freq, attention_factor + + def _compute_dynamic_ntk_parameters( config: Optional["PreTrainedConfig"] = None, device: Optional["torch.device"] = None, @@ -558,16 +629,17 @@ def _compute_llama3_parameters( # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters # from the model config. You can append new {'rope_type': callable} pairs to this rope_parameters to enable custom RoPE # parameterizations, as long as the callable has the same signature. 
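+# Illustration of the "proportional" entry added below: it rotates int(partial_rotary_factor * head_dim // 2)
+# frequency pairs and zero-pads the remaining ones. For example, with head_dim=256 and
+# partial_rotary_factor=0.25, 32 of the 128 frequency pairs receive rotation and the other 96 do not.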
-ROPE_INIT_FUNCTIONS = { +ROPE_INIT_FUNCTIONS: dict[str, Callable[..., tuple["torch.Tensor", float]]] = { "linear": _compute_linear_scaling_rope_parameters, "dynamic": _compute_dynamic_ntk_parameters, "yarn": _compute_yarn_parameters, "longrope": _compute_longrope_parameters, "llama3": _compute_llama3_parameters, + "proportional": _compute_proportional_rope_parameters, } -class RopeParameters(TypedDict, total=False): +class RopeParameters(TypedDict): """ Args: rope_theta (`float`): @@ -896,6 +968,20 @@ def _validate_llama3_rope_parameters(self, rope_parameters: dict, ignore_keys: s f"{original_max_position_embeddings} and max_position_embeddings={self.max_position_embeddings}" ) + def _validate_proportional_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): + required_keys = {"rope_type", "rope_theta"} + rope_type = rope_parameters["rope_type"] + received_keys = set(rope_parameters.keys()) + self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + partial_rotary_factor = rope_parameters.get("partial_rotary_factor") + if partial_rotary_factor is None: + logger.warning( + "`rope_parameters`'s partial_rotary_factor is None. This will default to 1.0 in the computation, " + "making this equivalent to the linear_scaling RoPE type. Provide a value in the range [0.0, 1.0) to " + "make use of the proportional RoPE funcitonality." + ) + @staticmethod def _check_received_keys( rope_type: str, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 550197fd62f6..989be9eb114e 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -152,6 +152,7 @@ from .gemma2 import * from .gemma3 import * from .gemma3n import * + from .gemma4 import * from .git import * from .glm import * from .glm4 import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 13ae48545a20..2c0fe88d0e74 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -174,6 +174,10 @@ ("gemma3n_audio", "Gemma3nAudioConfig"), ("gemma3n_text", "Gemma3nTextConfig"), ("gemma3n_vision", "Gemma3nVisionConfig"), + ("gemma4", "Gemma4Config"), + ("gemma4_audio", "Gemma4AudioConfig"), + ("gemma4_text", "Gemma4TextConfig"), + ("gemma4_vision", "Gemma4VisionConfig"), ("git", "GitConfig"), ("glm", "GlmConfig"), ("glm4", "Glm4Config"), @@ -689,6 +693,10 @@ ("gemma3n_audio", "Gemma3nAudioEncoder"), ("gemma3n_text", "Gemma3nForCausalLM"), ("gemma3n_vision", "TimmWrapperModel"), + ("gemma4", "Gemma4ForConditionalGeneration"), + ("gemma4_audio", "Gemma4AudioModel"), + ("gemma4_text", "Gemma4ForCausalLM"), + ("gemma4_vision", "Gemma4VisionModel"), ("git", "GIT"), ("glm", "GLM"), ("glm4", "GLM4"), @@ -1098,6 +1106,9 @@ ("gemma3n_audio", "gemma3n"), ("gemma3n_text", "gemma3n"), ("gemma3n_vision", "gemma3n"), + ("gemma4_audio", "gemma4"), + ("gemma4_text", "gemma4"), + ("gemma4_vision", "gemma4"), ("glm4v_vision", "glm4v"), ("glm4v_moe_vision", "glm4v_moe"), ("glm4v_text", "glm4v"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 69745c5847be..111c56efb436 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -46,6 +46,7 @@ ("dia", "DiaFeatureExtractor"), ("encodec", "EncodecFeatureExtractor"), ("gemma3n", 
"Gemma3nAudioFeatureExtractor"), + ("gemma4", "Gemma4AudioFeatureExtractor"), ("glmasr", "WhisperFeatureExtractor"), ("granite_speech", "GraniteSpeechFeatureExtractor"), ("higgs_audio_v2_tokenizer", "DacFeatureExtractor"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 1e868161c9f2..7ae16550008f 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -125,6 +125,7 @@ ("fuyu", {"torchvision": "FuyuImageProcessor", "pil": "FuyuImageProcessorPil"}), ("gemma3", {"torchvision": "Gemma3ImageProcessor", "pil": "Gemma3ImageProcessorPil"}), ("gemma3n", {"torchvision": "SiglipImageProcessor", "pil": "SiglipImageProcessorPil"}), + ("gemma4", {"torchvision": "Gemma4ImageProcessor", "pil": "Gemma4ImageProcessorPil"}), ("git", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), ("glm46v", {"torchvision": "Glm46VImageProcessor", "pil": "Glm46VImageProcessorPil"}), ("glm4v", {"torchvision": "Glm4vImageProcessor", "pil": "Glm4vImageProcessorPil"}), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 0958e1cba2d8..7edff302436b 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -173,6 +173,10 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("gemma3n_audio", "Gemma3nAudioEncoder"), ("gemma3n_text", "Gemma3nTextModel"), ("gemma3n_vision", "TimmWrapperModel"), + ("gemma4", "Gemma4Model"), + ("gemma4_audio", "Gemma4AudioModel"), + ("gemma4_text", "Gemma4TextModel"), + ("gemma4_vision", "Gemma4VisionModel"), ("git", "GitModel"), ("glm", "GlmModel"), ("glm4", "Glm4Model"), @@ -528,6 +532,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("fsmt", "FSMTForConditionalGeneration"), ("funnel", "FunnelForPreTraining"), ("gemma3", "Gemma3ForConditionalGeneration"), + ("gemma4", "Gemma4ForConditionalGeneration"), ("glmasr", "GlmAsrForConditionalGeneration"), ("gpt-sw3", "GPT2LMHeadModel"), ("gpt2", "GPT2LMHeadModel"), @@ -647,6 +652,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("gemma3_text", "Gemma3ForCausalLM"), ("gemma3n", "Gemma3nForConditionalGeneration"), ("gemma3n_text", "Gemma3nForCausalLM"), + ("gemma4", "Gemma4ForConditionalGeneration"), + ("gemma4_text", "Gemma4ForCausalLM"), ("git", "GitForCausalLM"), ("glm", "GlmForCausalLM"), ("glm4", "Glm4ForCausalLM"), @@ -970,6 +977,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("fuyu", "FuyuForCausalLM"), ("gemma3", "Gemma3ForConditionalGeneration"), ("gemma3n", "Gemma3nForConditionalGeneration"), + ("gemma4", "Gemma4ForConditionalGeneration"), ("git", "GitForCausalLM"), ("glm46v", "Glm46VForConditionalGeneration"), ("glm4v", "Glm4vForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index cb4c9c61d601..262480b71485 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -81,6 +81,7 @@ ("fuyu", "FuyuProcessor"), ("gemma3", "Gemma3Processor"), ("gemma3n", "Gemma3nProcessor"), + ("gemma4", "Gemma4Processor"), ("git", "GitProcessor"), ("glm46v", "Glm46VProcessor"), ("glm4v", "Glm4vProcessor"), diff --git a/src/transformers/models/auto/video_processing_auto.py b/src/transformers/models/auto/video_processing_auto.py index 3b7d36a9aeb2..ca6faa440ac8 
100644 --- a/src/transformers/models/auto/video_processing_auto.py +++ b/src/transformers/models/auto/video_processing_auto.py @@ -53,6 +53,7 @@ VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict( [ ("ernie4_5_vl_moe", "Ernie4_5_VLMoeVideoProcessor"), + ("gemma4", "Gemma4VideoProcessor"), ("glm46v", "Glm46VVideoProcessor"), ("glm4v", "Glm4vVideoProcessor"), ("instructblip", "InstructBlipVideoVideoProcessor"), diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index d37df841ca17..edca10b4f48e 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -130,21 +130,18 @@ def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): self.with_scale = with_scale if self.with_scale: - self.weight = nn.Parameter(torch.ones(dim)) - else: - self.register_buffer("weight", torch.tensor(1.0), persistent=False) - - def _norm(self, x): - return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + self.weight = nn.Parameter(torch.ones(dim), requires_grad=True) - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - output = self._norm(x.float()) * self.weight.float() - return output.type_as(x) + def _norm(self, hidden_states: torch.Tensor): + mean_squared = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps + # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to addess compiler differences between Torch and JAX + return hidden_states * torch.pow(mean_squared, -0.5) - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.eps}" + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + normed_output = self._norm(hidden_states.float()) + if self.with_scale: + normed_output = normed_output * self.weight.float() + return normed_output.type_as(hidden_states) # ==== Audio Encoder ==== diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index c52110fbb9e5..d5633a689687 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -46,7 +46,6 @@ Gemma3Attention, Gemma3DecoderLayer, Gemma3ForCausalLM, - Gemma3RMSNorm, Gemma3RotaryEmbedding, Gemma3TextModel, Gemma3TextScaledWordEmbedding, @@ -480,25 +479,25 @@ class Gemma3nCausalLMOutputWithPast(PaliGemmaCausalLMOutputWithPast): audio_hidden_states: torch.FloatTensor | None = None -class Gemma3nRMSNorm(Gemma3RMSNorm): +class Gemma3nRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): - super().__init__(dim, eps=eps) - del self.weight + super().__init__() + self.eps = eps self.with_scale = with_scale if self.with_scale: - self.weight = nn.Parameter(torch.ones(dim)) - else: - self.register_buffer("weight", torch.tensor(1.0), persistent=False) + self.weight = nn.Parameter(torch.ones(dim), requires_grad=True) - def _norm(self, x): - return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + def _norm(self, hidden_states: torch.Tensor): + mean_squared = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps + # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to addess compiler differences between Torch and JAX + return hidden_states * torch.pow(mean_squared, -0.5) - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Llama does x.to(float16) * w whilst Gemma2 is (x * 
w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - output = self._norm(x.float()) * self.weight.float() - return output.type_as(x) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + normed_output = self._norm(hidden_states.float()) + if self.with_scale: + normed_output = normed_output * self.weight.float() + return normed_output.type_as(hidden_states) # ==== Audio Encoder ==== diff --git a/src/transformers/models/gemma4/__init__.py b/src/transformers/models/gemma4/__init__.py new file mode 100644 index 000000000000..d108443c16cb --- /dev/null +++ b/src/transformers/models/gemma4/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_gemma4 import * + from .feature_extraction_gemma4 import * + from .image_processing_gemma4 import * + from .image_processing_pil_gemma4 import * + from .modeling_gemma4 import * + from .processing_gemma4 import * + from .video_processing_gemma4 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/gemma4/configuration_gemma4.py b/src/transformers/models/gemma4/configuration_gemma4.py new file mode 100644 index 000000000000..a605d9a862ed --- /dev/null +++ b/src/transformers/models/gemma4/configuration_gemma4.py @@ -0,0 +1,347 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Literal + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...utils import auto_docstring, logging +from ...utils.type_validators import interval + + +logger = logging.get_logger(__name__) + + +@auto_docstring(checkpoint="google/gemma-4-e2b-it") +@strict +class Gemma4AudioConfig(PreTrainedConfig): + r""" + subsampling_conv_channels (`list[int]`, defaults to `[128, 32]`): + Channel sizes for the convolutional layers in the Sub-sample Convolution Projection. + residual_weight (`float`, defaults to `0.5`): + Scaling applied to hidden_states prior to combining with the residual in the feedforward. + attention_chunk_size (`int`, defaults to `12`): + The sub-sequence size for attention processing. 
+ attention_context_left (`int`, defaults to `13`): + The leftward context size for the attention chunk. + attention_context_right (`int`, defaults to `0`): + The rightward context size for the attention chunk. + attention_logit_cap (`float`, defaults to `50.0`): + Cap applied to attention weights. + attention_invalid_logits_value (`float`, defaults to `1e-9`): + Value to use for invalid logits in attention. + use_clipped_linears (`bool`, defaults to `True`): + If true, apply clipping to the Linear layers, drawing bounds from the model checkpoint. + gradient_clipping (`float`, defaults to `1e10`): + Clipping value used to stabilize extremely large gradient values. + output_proj_dims (`int`, defaults to `1536`): + Dimension of the final linear projection from `hidden_size` to the model's output. + """ + + model_type = "gemma4_audio" + + hidden_size: int = 1024 + num_hidden_layers: int = 12 + num_attention_heads: int = 8 + hidden_act: str = "silu" + + # subsampling parameters + subsampling_conv_channels: list[int] | tuple[int, int] = (128, 32) + + # conformer parameters + conv_kernel_size: int = 5 + residual_weight: float = 0.5 + attention_chunk_size: int = 12 + attention_context_left: int = 13 + attention_context_right: int = 0 + attention_logit_cap: float = 50.0 + attention_invalid_logits_value: float = -1.0e9 + + use_clipped_linears: bool = True + rms_norm_eps: float = 1e-6 + gradient_clipping: float = 1e10 + output_proj_dims: int = 1536 + initializer_range: float = interval(min=0.0, max=1.0)(default=0.02) + + def __post_init__(self, **kwargs): + # JSON serialization converts tuples to lists, convert back + if isinstance(self.subsampling_conv_channels, tuple): + self.subsampling_conv_channels = list(self.subsampling_conv_channels) + super().__post_init__(**kwargs) + + +@auto_docstring(checkpoint="google/gemma-4-e2b-it") +@strict +class Gemma4TextConfig(PreTrainedConfig): + r""" + use_bidirectional_attention (`str`, *optional*): + Controls bidirectional attention behavior. When set to `"vision"`, vision tokens + attend bidirectionally while text tokens use causal attention. When set to `"all"`, + all tokens use bidirectional attention. + vocab_size_per_layer_input (`int`, defaults to 262144): + Vocabulary size for the per-layer input embeddings. Used by models with per-layer + residual streams where a smaller embedding is added at each decoder layer. + hidden_size_per_layer_input (`int`, defaults to 256): + Hidden dimension for the per-layer input embeddings. Controls the width of the + per-layer residual embedding vectors. + num_global_key_value_heads (`int`, *optional*): + Number of key-value heads for global (full) attention layers. If `None`, defaults + to `num_key_value_heads`. + global_head_dim (`int`, defaults to 512): + Dimension of each attention head in global (full) attention layers. + attention_k_eq_v (`bool`, defaults to `False`): + Whether keys and values share the same projection weights. When `True`, the key + projection output is reused as the value projection. + num_kv_shared_layers (`int`, defaults to 0): + Number of consecutive decoder layers that share the same key-value projections. + A value of 0 means no sharing (each layer has independent KV projections). + enable_moe_block (`bool`, defaults to `False`): + Whether to enable Mixture-of-Experts (MoE) blocks in the decoder layers. When + `True`, eligible layers will use a sparse MoE feed-forward network. 
+ use_double_wide_mlp (`bool`, defaults to `False`): + Whether to use a double-width MLP with fused gate and up projections. + top_k_experts (`int`, *optional*): + Number of experts activated per token in MoE layers. Only used when + `enable_moe_block=True`. + moe_intermediate_size (`int`, *optional*): + Intermediate (hidden) size of each expert's feed-forward network in MoE layers. + Only used when `enable_moe_block=True`. + """ + + model_type = "gemma4_text" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + vocab_size: int = 262_144 + hidden_size: int = 2304 + intermediate_size: int = 9216 + num_hidden_layers: int = 30 + num_attention_heads: int = 8 + num_key_value_heads: int = 4 + head_dim: int = 256 + hidden_activation: str = "gelu_pytorch_tanh" + max_position_embeddings: int = 131_072 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int | None = 0 + eos_token_id: int | list[int] | None = 1 + bos_token_id: int | None = 2 + tie_word_embeddings: bool = True + rope_parameters: dict | None = None + attention_bias: bool = False + attention_dropout: int | float | None = 0.0 + sliding_window: int = 512 + layer_types: list[str] | None = None + final_logit_softcapping: float | None = None + use_bidirectional_attention: Literal["all", "vision"] | None = None + vocab_size_per_layer_input: int = 262_144 + hidden_size_per_layer_input: int = 256 + num_global_key_value_heads: int | None = None + global_head_dim: int = 512 + attention_k_eq_v: bool = False + num_kv_shared_layers: int = 0 + enable_moe_block: bool = False + use_double_wide_mlp: bool = False + num_experts: int | None = None + top_k_experts: int | None = None + moe_intermediate_size: int | None = None + + def __post_init__(self, **kwargs): + if self.use_bidirectional_attention == "all": + self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds + + if self.layer_types is None: + sliding_window_pattern = 6 # by default 5:1 + self.layer_types = [ + "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention" + for i in range(self.num_hidden_layers) + ] + + if self.layer_types and (last_layer_type := self.layer_types[-1]) != "full_attention": + logger.warning( + f"Last layer must use `full_attention`, but got `{last_layer_type}`. Forcing last layer to `full_attention`." 
+ ) + self.layer_types[-1] = "full_attention" + + default_rope_params: dict[Literal["full_attention", "sliding_attention"] : dict[str, Any]] = { + "sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0}, + "full_attention": {"rope_type": "proportional", "partial_rotary_factor": 0.25, "rope_theta": 1_000_000.0}, + } + if self.rope_parameters is None: + self.rope_parameters = default_rope_params + + super().__post_init__(**kwargs) + + def convert_rope_params_to_dict(self, **kwargs): + # No need to handle BC for new models, because they have no old-format `rope_scaling` + return kwargs + + +@auto_docstring(checkpoint="google/gemma-4-e2b-it") +@strict +class Gemma4VisionConfig(PreTrainedConfig): + r""" + pooling_kernel_size (`int`, *optional*): + Spatial pooling kernel size applied after patchification. + position_embedding_size (`int`, defaults to 10240): + Maximum number of position embeddings for the vision encoder. Controls the size of + the learned 2D position embedding table used by the patch embedder. + use_clipped_linears (`bool`, defaults to `False`): + Whether to use weight-clipped linear layers. When enabled, linear layer weights are + clamped to a fixed range during the forward pass to improve numerical stability. + standardize (`bool`, defaults to `False`): + If true, applies a bias and scale to the soft tokens returned from the pooler. + """ + + model_type = "gemma4_vision" + base_model_tp_plan = { + "encoder.layers.*.self_attn.q_proj": "colwise", + "encoder.layers.*.self_attn.k_proj": "colwise", + "encoder.layers.*.self_attn.v_proj": "colwise", + "encoder.layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "encoder.layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", + "encoder.layers.*.self_attn.o_proj": "rowwise", + "encoder.layers.*.mlp.gate_proj": "colwise", + "encoder.layers.*.mlp.up_proj": "colwise", + "encoder.layers.*.mlp.down_proj": "rowwise", + } + default_theta = 100.0 + + hidden_size: int = 768 + intermediate_size: int = 3072 + num_hidden_layers: int = 16 + num_attention_heads: int = 12 + num_key_value_heads: int = 12 + head_dim: int = 64 + hidden_activation: str = "gelu_pytorch_tanh" + rms_norm_eps: float = 1e-6 + max_position_embeddings: int = 131_072 + attention_bias: bool | None = False + attention_dropout: float | None = 0.0 + rope_parameters: dict | None = None + pooling_kernel_size: int = 3 + patch_size: int = 16 + position_embedding_size: int = 10 * 1024 + use_clipped_linears: bool = False + standardize: bool = False + initializer_range: float = 0.02 + + def __post_init__(self, **kwargs): + if self.rope_parameters is None: + self.rope_parameters = {"rope_type": "default", "rope_theta": 100.0} + + super().__post_init__(**kwargs) + + +@auto_docstring(checkpoint="google/gemma-4-e2b-it") +@strict +class Gemma4Config(PreTrainedConfig): + r""" + boi_token_id (`int`, *optional*, defaults to 255999): + The begin-of-image token index to wrap the image prompt. + eoi_token_id (`int`, *optional*, defaults to 258882): + The end-of-image token index to wrap the image prompt. + boa_token_id (`int`, *optional*, defaults to 256000): + The begin-of-audio token index to wrap the audio prompt. + eoa_token_index (`int`, *optional*, defaults to 258883): + The end-of-audio token index to wrap the audio prompt. 
+ + Example: + + ```python + >>> from transformers import ( + >>> Gemma4AudioConfig, + >>> Gemma4Config, + >>> Gemma4ForConditionalGeneration, + >>> Gemma4TextConfig, + >>> Gemma4VisionConfig, + >>> ) + + >>> # Initializing a Gemma 4 Audio config. + >>> audio_config = Gemma4AudioConfig() + + >>> # Initializing a Gemma 4 Text config. + >>> text_config = Gemma4TextConfig() + + >>> # Initializing a Gemma 4 vision config. + >>> vision_config = Gemma4VisionConfig() + + >>> # Initializing a Gemma 4 config similar to google/gemma-4-e2b-it + >>> configuration = Gemma4Config(text_config, vision_config, audio_config) + + >>> # Initializing a model from the google/gemma-4-e2b-it configuration + >>> model = Gemma4ForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma4" + sub_configs = { + "text_config": Gemma4TextConfig, + "vision_config": Gemma4VisionConfig, + "audio_config": Gemma4AudioConfig, + } + + text_config: Gemma4TextConfig | dict[str, Any] | None = None + vision_config: Gemma4VisionConfig | dict[str, Any] | None = None + audio_config: Gemma4AudioConfig | dict[str, Any] | None = None + boi_token_id: int | None = 255_999 + eoi_token_id: int | None = 258_882 + image_token_id: int | None = 258_880 + video_token_id: int | None = 258_884 + boa_token_id: int | None = 256_000 + eoa_token_index: int | None = 258_883 + audio_token_id: int | None = 258_881 + initializer_range: float | None = 0.02 + tie_word_embeddings: bool = True + + def __post_init__(self, **kwargs): + if self.text_config is None: + self.text_config = Gemma4TextConfig() + logger.info("text_config is None. Using default Gemma4TextConfig.") + elif isinstance(self.text_config, dict): + self.text_config = Gemma4TextConfig(**self.text_config) + + if self.vision_config is None: + logger.info("vision_config is None. Gemma4Model.vision_tower will not be initialized.") + if isinstance(self.vision_config, dict): + self.vision_config = Gemma4VisionConfig(**self.vision_config) + + if self.audio_config is None: + logger.info("audio_config is None. Gemma4Model.audio_tower will not be initialized.") + if isinstance(self.audio_config, dict): + self.audio_config = Gemma4AudioConfig(**self.audio_config) + + super().__post_init__(**kwargs) + + +__all__ = ["Gemma4AudioConfig", "Gemma4Config", "Gemma4TextConfig", "Gemma4VisionConfig"] diff --git a/src/transformers/models/gemma4/convert_gemma4_weights.py b/src/transformers/models/gemma4/convert_gemma4_weights.py new file mode 100644 index 000000000000..cc9005afc8f8 --- /dev/null +++ b/src/transformers/models/gemma4/convert_gemma4_weights.py @@ -0,0 +1,1231 @@ +# Copyright 2026 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint. 
+ +python src/transformers/models/gemma4/convert_gemma4_weights.py \ + --variant='gemma-4-e2b' \ + --include_chat_template \ + --include_response_schema \ + --tokenizer_path="$HOME/tokenizers/gemma4/gemma4_cleaned_262144.model" \ + --checkpoint_path="$HOME/gemma4/checkpoints/gemma_e2b_it_orbax" \ + --output_path="$HOME/gemma4/checkpoints/gemma_e2b_it_safetensors" +""" + +import ast +import json +import os +import pathlib +from collections.abc import Iterable, Mapping +from typing import Any + +import accelerate +import jax +import numpy as np +import torch +import tree +from absl import app, flags, logging +from jax.sharding import SingleDeviceSharding +from orbax import checkpoint as obc +from orbax.checkpoint import args as obc_args +from orbax.checkpoint import type_handlers + +from transformers import ( + Gemma4AudioConfig, + Gemma4AudioFeatureExtractor, + Gemma4Config, + Gemma4ForCausalLM, + Gemma4ForConditionalGeneration, + Gemma4ImageProcessor, + Gemma4Processor, + Gemma4TextConfig, + Gemma4VideoProcessor, + Gemma4VisionConfig, + GemmaTokenizer, + GenerationConfig, + RopeParameters, +) +from transformers.tokenization_utils_sentencepiece import SentencePieceExtractor +from transformers.utils.hub import cached_file + + +# ==== Internal Constants and Classes ==== + +# The correct chat templates were already uploaded to those 2 repos, so download from there +_CHAT_TEMPLATE = pathlib.Path(cached_file("gg-hf-gg/gemma-4-E4B-it", "chat_template.jinja")).read_text() +_CHAT_TEMPLATE_LARGE = pathlib.Path(cached_file("gg-hf-gg/gemma-4-31B-it", "chat_template.jinja")).read_text() + +_RESPONSE_SCHEMA = { + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "thinking": { + "type": "string", + "x-regex": r"<\|channel\>(?:thought\n)?(.+?)", + }, + "tool_calls": { + "x-regex-iterator": r"<\|tool_call\>(.*?)", + "type": "array", + "items": { + "type": "object", + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "properties": { + "name": {"type": "string", "x-regex": r"call:([^{]+)"}, + "arguments": { + "type": "string", + "x-regex": r"call:[^{]+(\{.*\})", + "x-mapping-regex": {r"<\|\"\|>": '"', r"(\{|,)\s*([a-zA-Z_]\w+):": r'\1"\2":'}, + }, + }, + }, + }, + }, + }, + }, +} + +_DTYPES = {"float32", "bfloat16", "float16"} + +_SLIDING_WINDOW_PATTERN = 6 + +_AUDIO_ENCODER_PARAMETER = "AudioEncoder/encoder" +_AUDIO_ENCODER_CONFORMER = f"{_AUDIO_ENCODER_PARAMETER}/conformer/stacked_layers" +_AUDIO_ENCODER_SSCP = f"{_AUDIO_ENCODER_PARAMETER}/feature" + +_TRANSFORMER_PARAMETER = "transformer" +_TRANSFORMER_DECODER_BLOCK = f"{_TRANSFORMER_PARAMETER}/stacked_layers/attention_type_" +_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK) +_TRANSFORMER_EMBEDDER = f"{_TRANSFORMER_PARAMETER}/embedder" +_TRANSFORMER_FINAL_NORM = "transformer/final_norm" +_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/" +_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX) + +_VISION_ENCODER_PARAMETER = "PatchInputVariablePoolingEncoder_0" +_VISION_ENCODER_VIT_PARAMETER = f"{_VISION_ENCODER_PARAMETER}/_model/vit" +_VISION_ENCODER_ENTRY = f"{_VISION_ENCODER_VIT_PARAMETER}/entry" +_VISION_ENCODER_EXIT = f"{_VISION_ENCODER_VIT_PARAMETER}/exit" +_VISION_ENCODER_STANDARDIZE = f"{_VISION_ENCODER_PARAMETER}/standardize" +_VISION_ENCODER_TRANSFORMER = f"{_VISION_ENCODER_VIT_PARAMETER}/transformer/stacked_layers/block" + +_VARIANT_GEMMA_4_E2B = "gemma-4-e2b" +_VARIANT_GEMMA_4_E4B = "gemma-4-e4b" +_VARIANT_GEMMA_4_26B_A4B = 
"gemma-4-26b-a4b" +_VARIANT_GEMMA_4_31B = "gemma-4-31b" + +_LARGE_MODEL_VARIANTS = { + _VARIANT_GEMMA_4_31B, + _VARIANT_GEMMA_4_26B_A4B, +} + +_ON_DEVICE_VISION_CONFIG = Gemma4VisionConfig( + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=16, + num_attention_heads=12, + num_key_value_heads=12, + head_dim=64, + global_head_dim=64, + default_output_length=280, + pooling_kernel_size=3, + use_clipped_linears=True, +) + +_LARGE_MODEL_VISION_CONFIG = Gemma4VisionConfig( + hidden_size=1152, + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=72, + global_head_dim=72, + default_output_length=280, + pooling_kernel_size=3, + use_clipped_linears=False, + standardize=True, +) + +_DEFAULT_AUDIO_CONFIG = Gemma4AudioConfig() + +_ROPE_PARAMS: dict[str, RopeParameters] = { + "full_attention": RopeParameters( + rope_theta=1_000_000.0, + rope_type="proportional", + partial_rotary_factor=0.25, + ), + "sliding_attention": RopeParameters( + rope_theta=10000.0, + rope_type="default", + ), +} + +_DEFAULT_LAYER_TYPES = ["sliding_attention"] * 5 + ["full_attention"] +_GEMMA_4_E2B_LAYER_TYPES = ["sliding_attention"] * 4 + ["full_attention"] + + +_VARIANTS: Mapping[str, Gemma4Config] = { + _VARIANT_GEMMA_4_E2B: Gemma4Config( + text_config=Gemma4TextConfig( + hidden_size=1536, + hidden_size_per_layer_input=256, + intermediate_size=4 * 1536, + num_hidden_layers=35, + layer_types=_GEMMA_4_E2B_LAYER_TYPES * 7, + num_attention_heads=8, + num_key_value_heads=1, + num_global_key_value_heads=None, + attention_k_eq_v=False, + use_bidirectional_attention=None, + num_kv_shared_layers=20, + use_double_wide_mlp=True, + final_logit_softcapping=30.0, + rope_parameters=_ROPE_PARAMS, + ), + vision_config=_ON_DEVICE_VISION_CONFIG, + audio_config=_DEFAULT_AUDIO_CONFIG, + vision_soft_tokens_per_image=280, + ), + _VARIANT_GEMMA_4_E4B: Gemma4Config( + text_config=Gemma4TextConfig( + hidden_size=2560, + hidden_size_per_layer_input=256, + intermediate_size=4 * 2560, + num_hidden_layers=42, + layer_types=_DEFAULT_LAYER_TYPES * 7, + num_attention_heads=8, + num_key_value_heads=2, + num_global_key_value_heads=None, + global_head_dim=512, # Global attention layers use 512-dim heads + attention_k_eq_v=False, + use_bidirectional_attention=None, + num_kv_shared_layers=18, + final_logit_softcapping=30.0, + rope_parameters=_ROPE_PARAMS, + ), + vision_config=_ON_DEVICE_VISION_CONFIG, + vision_soft_tokens_per_image=280, + audio_config=_DEFAULT_AUDIO_CONFIG, + ), + _VARIANT_GEMMA_4_31B: Gemma4Config( + text_config=Gemma4TextConfig( + hidden_size=5376, + hidden_size_per_layer_input=0, + intermediate_size=4 * 5376, + num_hidden_layers=60, + layer_types=_DEFAULT_LAYER_TYPES * 10, + num_attention_heads=32, + num_key_value_heads=16, + num_global_key_value_heads=4, + attention_k_eq_v=True, + use_bidirectional_attention="vision", + num_kv_shared_layers=0, + sliding_window=1024, + final_logit_softcapping=30.0, + rope_parameters=_ROPE_PARAMS, + max_position_embeddings=262_144, + ), + vision_config=_LARGE_MODEL_VISION_CONFIG, + vision_soft_tokens_per_image=280, + ), + _VARIANT_GEMMA_4_26B_A4B: Gemma4Config( + text_config=Gemma4TextConfig( + hidden_size=2816, + hidden_size_per_layer_input=0, + intermediate_size=2112, # Shared expert FFW + num_hidden_layers=30, + layer_types=_DEFAULT_LAYER_TYPES * 5, + num_attention_heads=16, + num_key_value_heads=8, + num_global_key_value_heads=2, + attention_k_eq_v=True, + use_bidirectional_attention="vision", + num_kv_shared_layers=0, + 
enable_moe_block=True, + num_experts=128, + moe_intermediate_size=704, + top_k_experts=8, + sliding_window=1024, + final_logit_softcapping=30.0, + rope_parameters=_ROPE_PARAMS, + max_position_embeddings=262_144, + ), + vision_config=_LARGE_MODEL_VISION_CONFIG, + vision_soft_tokens_per_image=280, + ), +} + + +# ==== Flags ==== + +_AUDIO_DTYPE = flags.DEFINE_enum( + name="audio_dtype", + default="bfloat16", + help="The floating point precision (aka dtype) of the model.", + enum_values=_DTYPES, +) + +_CHECKPOINT_PATH = flags.DEFINE_string( + name="checkpoint_path", + default=None, + help="Path to the Orbax checkpoint.", + required=True, +) + +_INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool( + name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer" +) + +_INCLUDE_RESPONSE_SCHEMA = flags.DEFINE_bool( + name="include_response_schema", + default=False, + help="If true, will save the default response schema with the tokenizer", +) + +_OUTPUT_PATH = flags.DEFINE_string( + name="output_path", + default=None, + help="Path to store the HF checkpoint.", + required=True, +) + +_TEXT_DTYPE = flags.DEFINE_enum( + name="text_dtype", + default="bfloat16", + help="The floating point precision (aka dtype) of the model.", + enum_values=_DTYPES, +) + +_TEXT_ONLY = flags.DEFINE_bool( + name="text_only", + default=False, + help="If True, saves a Gemma4ForCasualLM model instead of a Gemma4ForConditionalGeneration model.", +) + +_TOKENIZER_PATH = flags.DEFINE_string( + name="tokenizer_path", + default=None, + help="Path to the SentencePiece model file.", + required=True, +) + +_VARIANT = flags.DEFINE_enum( + name="variant", + default=None, + help="The model variant to convert.", + enum_values=set(_VARIANTS.keys()), + required=True, +) + +_VERBOSE = flags.DEFINE_bool( + name="verbose", + default=False, + help="If true, log the path, shape, and dtype of every converted layer.", +) + +_VISION_DTYPE = flags.DEFINE_enum( + name="vision_dtype", + default="bfloat16", + help="The floating point precision (aka dtype) of the model.", + enum_values=_DTYPES, +) + + +def convert_audio_encoder_weights( + config, # Gemma4AudioConfig + path: str, + param: str, + weights: np.ndarray, +) -> Iterable[tuple[str, np.ndarray]]: + converted_paths: list[str] = [] + converted_weights: list[Any] = [] + + # The conformer uses its own internal dimension (1024 by default via conf_hidden_size). + # Since we now use the default hidden_size=1024 (same as conf_hidden_size), + # we use config.conf_hidden_size for reshaping conformer weights. 
+ + if path.startswith(_AUDIO_ENCODER_CONFORMER): + assert weights.shape[0] == config.num_hidden_layers + + for i, matrix in enumerate(weights): + if "fflayer_end" in path: + base = f"layers.{i}.feed_forward2" + + if path.endswith("ffn_layer1/ClippedEinsum_0"): + converted_paths.append(f"{base}.ffw_layer_1.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + elif path.endswith("ffn_layer2/ClippedEinsum_0"): + converted_paths.append(f"{base}.ffw_layer_2.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + elif path.endswith("ffn_layer1"): + converted_paths.append(f"{base}.ffw_layer_1.linear.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("ffn_layer2"): + converted_paths.append(f"{base}.ffw_layer_2.linear.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("post_layer_norm"): + converted_paths.append(f"{base}.post_layer_norm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_layer_norm"): + converted_paths.append(f"{base}.pre_layer_norm.weight") + converted_weights.append(matrix) + elif "fflayer_start" in path: + base = f"layers.{i}.feed_forward1" + + if path.endswith("ffn_layer1/ClippedEinsum_0"): + converted_paths.append(f"{base}.ffw_layer_1.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + elif path.endswith("ffn_layer2/ClippedEinsum_0"): + converted_paths.append(f"{base}.ffw_layer_2.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + elif path.endswith("ffn_layer1"): + converted_paths.append(f"{base}.ffw_layer_1.linear.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("ffn_layer2"): + converted_paths.append(f"{base}.ffw_layer_2.linear.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("post_layer_norm"): + converted_paths.append(f"{base}.post_layer_norm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_layer_norm"): + converted_paths.append(f"{base}.pre_layer_norm.weight") + converted_weights.append(matrix) + elif path.endswith("final_ln"): + converted_paths.append(f"layers.{i}.norm_out.weight") + converted_weights.append(matrix) + elif "lconv" in path: + base = f"layers.{i}.lconv1d" + + if path.endswith("linear_start/ClippedEinsum_0"): + converted_paths.append(f"{base}.linear_start.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + elif path.endswith("linear_end/ClippedEinsum_0"): + converted_paths.append(f"{base}.linear_end.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + elif path.endswith("conv_norm"): + converted_paths.append(f"{base}.conv_norm.weight") + converted_weights.append(matrix) + elif path.endswith("depthwise_conv1d"): + converted_paths.append(f"{base}.depthwise_conv1d.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("linear_end"): + converted_paths.append(f"{base}.linear_end.linear.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("linear_start"): + converted_paths.append(f"{base}.linear_start.linear.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("ln"): + converted_paths.append(f"{base}.pre_layer_norm.weight") + converted_weights.append(matrix) + elif "trans_atten" in path: + base = f"layers.{i}" + + if param == "per_dim_scale": + converted_paths.append(f"{base}.self_attn.per_dim_scale") + converted_weights.append(matrix) + + if path.endswith("query_key_value_projection/ClippedEinsum_0"): + 
converted_paths.append(f"{base}.self_attn.q_proj.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + converted_paths.append(f"{base}.self_attn.k_proj.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + converted_paths.append(f"{base}.self_attn.v_proj.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + elif path.endswith("post/ClippedEinsum_0"): + converted_paths.append(f"{base}.self_attn.post.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + + if path.endswith("query_key_value_projection"): + converted_paths.extend( + [ + f"{base}.self_attn.q_proj.linear.weight", + f"{base}.self_attn.k_proj.linear.weight", + f"{base}.self_attn.v_proj.linear.weight", + ] + ) + converted_weights.extend( + [ + m.reshape(config.hidden_size, config.hidden_size).transpose() + for m in matrix.transpose(1, 0, 2, 3) + ] + ) + elif path.endswith("pos_proj"): + converted_paths.append(f"{base}.self_attn.relative_k_proj.weight") + converted_weights.append(matrix.reshape(config.hidden_size, config.hidden_size).transpose()) + elif path.endswith("post"): + converted_paths.append(f"{base}.self_attn.post.linear.weight") + converted_weights.append(matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.hidden_size)) + elif path.endswith("post_norm"): + converted_paths.append(f"{base}.norm_post_attn.weight") + converted_weights.append(matrix) + elif path.endswith("pre_norm"): + converted_paths.append(f"{base}.norm_pre_attn.weight") + converted_weights.append(matrix) + elif path.startswith(_AUDIO_ENCODER_SSCP): + if path.endswith("input_proj"): + converted_paths.append("subsample_conv_projection.input_proj_linear.weight") + converted_weights.append( + weights.transpose(2, 0, 1).reshape(config.hidden_size, config.subsampling_conv_channels[1] ** 2) + ) + elif "norm_" in path: + index = int(path[-1]) + converted_paths.append(f"subsample_conv_projection.layer{index}.norm.weight") + converted_weights.append(weights) + elif "subsampling_" in path: + index = int(path[-1]) + converted_paths.append(f"subsample_conv_projection.layer{index}.conv.weight") + converted_weights.append(weights.transpose(3, 2, 0, 1)) + + elif path.endswith("output_projection"): + if param == "kernel": + converted_paths.append("output_proj.weight") + converted_weights.append(weights.transpose()) + elif param == "bias": + converted_paths.append("output_proj.bias") + converted_weights.append(weights) + + if (cpl := len(converted_paths)) != (cwl := len(converted_weights)): + raise ValueError( + "The `converted_paths` and `converted_weights` should be the same " + f"length. Got {cpl} and {cwl}, respectively, for {path}." + ) + + return zip(converted_paths, converted_weights) + + +def convert_vision_encoder_weights( + config, # Gemma4VisionConfig + path: str, + param: str, + weights: np.ndarray, +) -> Iterable[tuple[str, np.ndarray]]: + """Convert vision encoder weights from JAX checkpoint to HuggingFace format. + + Args: + config: Vision config with num_hidden_layers, hidden_size, etc. + path: Path in the JAX checkpoint (e.g., "VisionEncoder_0/entry/input_projection") + param: Parameter type (e.g., "w", "scale", "pos_emb") + weights: NumPy array of weights + + Returns: + Iterable of (hf_path, converted_weights) tuples + """ + converted_paths: list[str] = [] + converted_weights: list[Any] = [] + + # Patch Embedder - Entry + # TODO(philculliton): These do not appear to be used currently - they should be loaded by Gemma4VisionPatchEmbedder, by all appearances, but are not currently. 
+ if path == f"{_VISION_ENCODER_ENTRY}/input_projection": + if param == "w": + converted_paths.append("patch_embedder.input_proj.weight") + # Shape: (768, 768) -> transpose to (768, 768) for nn.Linear + converted_weights.append(weights.transpose()) + elif path == _VISION_ENCODER_ENTRY: + if param == "pos_emb": + converted_paths.append("patch_embedder.position_embedding_table") + # Shape: (10240, 2, 768) -> transpose to (2, 10240, 768) + converted_weights.append(weights.transpose(1, 0, 2)) + + # Pooler - Exit: convert the learnable scale parameter for vision output scaling + elif path == _VISION_ENCODER_EXIT: + if param == "scale": + converted_paths.append("pooler.scale") + # JAX shape is (1, 1, d_model), keep as-is for nn.Parameter + converted_weights.append(weights) + + elif path == _VISION_ENCODER_STANDARDIZE: + if param == "bias": + converted_paths.append("std_bias") + converted_weights.append(weights) + else: + converted_paths.append("std_scale") + converted_weights.append(weights) + + # Transformer Layers (stacked format) + elif path.startswith(_VISION_ENCODER_TRANSFORMER): + # All vision transformer layers are stacked in dimension 0 + num_layers = weights.shape[0] + assert num_layers == config.num_hidden_layers, f"Expected {config.num_hidden_layers} layers, got {num_layers}" + + for i, matrix in enumerate(weights): + base_path = f"encoder.layers.{i}" + + # Handle clipped einsum states (`ClippedEinsum_0` target paths). + if path.endswith("attn_vec_einsum/ClippedEinsum_0"): + converted_paths.append(f"{base_path}.self_attn.o_proj.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + if path.endswith("kv_einsum/ClippedEinsum_0"): + # NOTE: In JAX reference implementations of Gemma, k_proj and v_proj are performed with a single einsum + # operation. We split this into two operations in Transformers, but they are passed the same input and + # share the same activation bounds for clipping, thus we re-use the same matrix for both. + converted_paths.append(f"{base_path}.self_attn.k_proj.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + converted_paths.append(f"{base_path}.self_attn.v_proj.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + if path.endswith("q_einsum/ClippedEinsum_0"): + converted_paths.append(f"{base_path}.self_attn.q_proj.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + if path.endswith("gating_einsum/ClippedEinsum_0"): + # NOTE: In JAX reference implementations of Gemma, gate_proj and up_proj are performed with a single + # einsum operation. We split this into two operations in Transformers, but they are passed the same + # input and share the same activation bounds for clipping, thus we re-use the same matrix for both. + converted_paths.append(f"{base_path}.mlp.gate_proj.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + converted_paths.append(f"{base_path}.mlp.up_proj.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + if path.endswith("linear/ClippedEinsum_0"): + converted_paths.append(f"{base_path}.mlp.down_proj.{param.removeprefix('clip_')}") + converted_weights.append(matrix) + + # Handle clipped einsum states (`compression_einsum` target paths). + # The target path specifies the activation direction (`input` or `output`), + # and the parameter holds `clip_min` or `clip_max`. 
+ if "/compression_einsum/" in path: + direction = path.split("/")[-1].split("_")[0] # Extracts "input" or "output" + hf_suffix = f"{direction}_{param.removeprefix('clip_')}" + einsum_type = path.split("/compression_einsum/")[0].split("/")[-1] + + if einsum_type == "attn_vec_einsum": + converted_paths.append(f"{base_path}.self_attn.o_proj.{hf_suffix}") + converted_weights.append(matrix) + elif einsum_type == "kv_einsum": + converted_paths.append(f"{base_path}.self_attn.k_proj.{hf_suffix}") + converted_weights.append(matrix) + converted_paths.append(f"{base_path}.self_attn.v_proj.{hf_suffix}") + converted_weights.append(matrix) + elif einsum_type == "q_einsum": + converted_paths.append(f"{base_path}.self_attn.q_proj.{hf_suffix}") + converted_weights.append(matrix) + elif einsum_type == "gating_einsum": + converted_paths.append(f"{base_path}.mlp.gate_proj.{hf_suffix}") + converted_weights.append(matrix) + converted_paths.append(f"{base_path}.mlp.up_proj.{hf_suffix}") + converted_weights.append(matrix) + elif einsum_type == "linear": + converted_paths.append(f"{base_path}.mlp.down_proj.{hf_suffix}") + converted_weights.append(matrix) + + if path.endswith("attn/attn_vec_einsum"): + # Shape: (12, 64, 768) -> reshape to (768, 768) for o_proj + converted_paths.append(f"{base_path}.self_attn.o_proj.linear.weight") + converted_weights.append( + matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.num_attention_heads * config.head_dim) + ) + elif path.endswith("attn/kv_einsum"): + # Shape: (2, 12, 768, 64) -> split into k_proj and v_proj + converted_paths.extend( + [ + f"{base_path}.self_attn.k_proj.linear.weight", + f"{base_path}.self_attn.v_proj.linear.weight", + ] + ) + k_proj_weights, v_proj_weights = matrix.transpose(0, 2, 1, 3) + kv_proj_shape = (config.hidden_size, config.num_key_value_heads * config.head_dim) + converted_weights.extend( + [ + k_proj_weights.reshape(kv_proj_shape).transpose(), + v_proj_weights.reshape(kv_proj_shape).transpose(), + ] + ) + elif path.endswith("attn/q_einsum"): + # Shape: (12, 768, 64) -> reshape to (768, 768) for q_proj + converted_paths.append(f"{base_path}.self_attn.q_proj.linear.weight") + converted_weights.append( + matrix.transpose(1, 0, 2) + .reshape(config.hidden_size, config.num_attention_heads * config.head_dim) + .transpose() + ) + elif path.endswith("mlp/gating_einsum"): + # Shape: (2, 3072, 768) -> split into gate_proj and up_proj + converted_paths.extend( + [ + f"{base_path}.mlp.gate_proj.linear.weight", + f"{base_path}.mlp.up_proj.linear.weight", + ] + ) + gate_proj_weight, up_proj_weight = matrix + converted_weights.extend([gate_proj_weight, up_proj_weight]) + elif path.endswith("mlp/linear"): + # Shape: (3072, 768) -> transpose for down_proj + converted_paths.append(f"{base_path}.mlp.down_proj.linear.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("post_attention_norm"): + converted_paths.append(f"{base_path}.post_attention_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("post_ffw_norm"): + converted_paths.append(f"{base_path}.post_feedforward_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_attention_norm"): + converted_paths.append(f"{base_path}.input_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_ffw_norm"): + converted_paths.append(f"{base_path}.pre_feedforward_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("attn/query_norm/scale") or path.endswith("attn/query_norm"): + # Vision Q/K 
norms: JAX trained scale values (~-0.6) are not directly + # usable because the OSS modules expect different shapes and the HF + # RMSNorm uses scale_shift=1.0 (formula: weight + 1.0). + # We use zeros to get identity: (0 + 1.0) = 1.0, matching the blaze + # reference which also uses zeros(head_dim) -> (1+0) = 1.0 identity. + converted_paths.append(f"{base_path}.self_attn.q_norm.weight") + converted_weights.append(matrix) + elif path.endswith("attn/key_norm/scale") or path.endswith("attn/key_norm"): + converted_paths.append(f"{base_path}.self_attn.k_norm.weight") + converted_weights.append(matrix) + + if (cpl := len(converted_paths)) != (cwl := len(converted_weights)): + raise ValueError( + "The `converted_paths` and `converted_weights` should be the same " + f"length. Got {cpl} and {cwl}, respectively, for {path}." + ) + + return zip(converted_paths, converted_weights) + + +def convert_transformer_weights( + config: Gemma4TextConfig, + path: str, + param: str, + weights: np.ndarray, +) -> Iterable[tuple[str, np.ndarray]]: + if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX): + path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:] + + converted_paths: list[str] = [] + converted_weights: list[Any] = [] + + # Handle new checkpoint format: transformer/layer_N/... + # TODO(philculliton):Direct handling for unstacked checkpoint type, needs to be merged to allow for unified tensor handling + if path.startswith(f"{_TRANSFORMER_PARAMETER}/layer_"): + # Extract layer number from path like "transformer/layer_0/attn/q_einsum" + layer_str = path.split("/")[1] # "layer_0" + layer_idx = int(layer_str.replace("layer_", "")) # 0 + base_path = f"layers.{layer_idx}" + + # Determine head_dim from actual checkpoint weight dimensions + # For q_einsum/key_norm, the last dimension tells us the head_dim + # Otherwise fall back to config + if path.endswith("attn/key_norm") or path.endswith("attn/query_norm"): + head_dim = weights.shape[0] # The norm dimension IS the head_dim + elif path.endswith("attn/q_einsum"): + head_dim = weights.shape[-1] # Last dimension is head_dim + else: + # Fall back to config-based determination + head_dim = ( + config.global_head_dim + if config.layer_types[layer_idx] == "full_attention" and config.global_head_dim + else config.head_dim + ) + + # Note: In new format, weights are per-layer (not batched), so no enumerate loop needed + matrix = weights + + if path.endswith("attn/attn_vec_einsum"): + converted_paths.append(f"{base_path}.self_attn.o_proj.weight") + converted_weights.append( + matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.num_attention_heads * head_dim) + ) + elif path.endswith("attn/kv_einsum"): + converted_paths.extend( + [ + f"{base_path}.self_attn.k_proj.weight", + f"{base_path}.self_attn.v_proj.weight", + ] + ) + k_proj_weights, v_proj_weights = matrix.transpose(0, 2, 1, 3) + kv_proj_shape = (config.hidden_size, config.num_key_value_heads * head_dim) + converted_weights.extend( + [ + k_proj_weights.reshape(kv_proj_shape).transpose(), + v_proj_weights.reshape(kv_proj_shape).transpose(), + ] + ) + elif path.endswith("attn/k_einsum"): + converted_paths.append(f"{base_path}.self_attn.k_proj.weight") + converted_weights.append( + matrix.transpose(1, 0, 2) + .reshape(config.hidden_size, config.num_global_key_value_heads * head_dim) + .transpose() + ) + elif path.endswith("attn/q_einsum"): + converted_paths.append(f"{base_path}.self_attn.q_proj.weight") + converted_weights.append( + matrix.transpose(1, 0, 2) + .reshape(config.hidden_size, 
config.num_attention_heads * head_dim) + .transpose() + ) + elif path.endswith("attn/query_norm"): + converted_paths.append(f"{base_path}.self_attn.q_norm.weight") + converted_weights.append(matrix) + elif path.endswith("attn/key_norm"): + converted_paths.append(f"{base_path}.self_attn.k_norm.weight") + converted_weights.append(matrix) + elif path.endswith("mlp/gating_einsum"): + converted_paths.extend([f"{base_path}.mlp.gate_proj.weight", f"{base_path}.mlp.up_proj.weight"]) + gate_proj_weight, up_proj_weight = matrix + converted_weights.extend([gate_proj_weight, up_proj_weight]) + elif path.endswith("mlp/linear"): + converted_paths.append(f"{base_path}.mlp.down_proj.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("per_layer_input_gate"): + converted_paths.append(f"{base_path}.per_layer_input_gate.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("per_layer_projection"): + converted_paths.append(f"{base_path}.per_layer_projection.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("post_attention_norm"): + converted_paths.append(f"{base_path}.post_attention_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("post_ffw_norm"): + converted_paths.append(f"{base_path}.post_feedforward_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("post_per_layer_input_norm"): + converted_paths.append(f"{base_path}.post_per_layer_input_norm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_attention_norm"): + converted_paths.append(f"{base_path}.input_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_ffw_norm"): + converted_paths.append(f"{base_path}.pre_feedforward_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith(layer_str) and param == "skip_scale": + converted_paths.append(f"{base_path}.layer_scalar") + converted_weights.append(matrix) + + # Handle old checkpoint format: transformer/stacked_layers/attention_type_N/... 
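+    # Illustrative index arithmetic for the stacked (old) format handled below: each
+    # `attention_type_N` group stores every Nth layer, so stacked entry `i` of group `g` maps to
+    #   layer_idx = _SLIDING_WINDOW_PATTERN * i + g
+    # e.g. with a pattern length of P, entry i=2 of group g=1 lands at global layer 2 * P + 1.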
+ elif path.startswith(_TRANSFORMER_DECODER_BLOCK): + attention_type_index = int(path[_TRANSFORMER_DECODER_BLOCK_LEN]) + expected_layers_per_group = config.num_hidden_layers / _SLIDING_WINDOW_PATTERN + observed_layers_per_group = weights.shape[0] + assert observed_layers_per_group == expected_layers_per_group, ( + f"Expected {observed_layers_per_group=} to be {expected_layers_per_group=}" + ) + + for i, matrix in enumerate(weights): + layer_idx = _SLIDING_WINDOW_PATTERN * i + attention_type_index + base_path = f"layers.{layer_idx}" + head_dim = ( + config.global_head_dim + if config.layer_types[layer_idx] == "full_attention" and config.global_head_dim + else config.head_dim + ) + + if param == "skip_scale": + converted_paths.append(f"{base_path}.layer_scalar") + converted_weights.append(matrix) + elif path.endswith("attn/attn_vec_einsum"): + converted_paths.append(f"{base_path}.self_attn.o_proj.weight") + converted_weights.append( + matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.num_attention_heads * head_dim) + ) + elif path.endswith("attn/kv_einsum"): + converted_paths.extend( + [ + f"{base_path}.self_attn.k_proj.weight", + f"{base_path}.self_attn.v_proj.weight", + ] + ) + k_proj_weights, v_proj_weights = matrix.transpose(0, 2, 1, 3) + kv_proj_shape = (config.hidden_size, config.num_key_value_heads * head_dim) + converted_weights.extend( + [ + k_proj_weights.reshape(kv_proj_shape).transpose(), + v_proj_weights.reshape(kv_proj_shape).transpose(), + ] + ) + elif path.endswith("attn/k_einsum"): + converted_paths.append(f"{base_path}.self_attn.k_proj.weight") + converted_weights.append( + matrix.transpose(1, 0, 2) + .reshape(config.hidden_size, config.num_global_key_value_heads * head_dim) + .transpose() + ) + elif path.endswith("attn/q_einsum"): + converted_paths.append(f"{base_path}.self_attn.q_proj.weight") + converted_weights.append( + matrix.transpose(1, 0, 2) + .reshape(config.hidden_size, config.num_attention_heads * head_dim) + .transpose() + ) + elif path.endswith("attn/query_norm"): + converted_paths.append(f"{base_path}.self_attn.q_norm.weight") + converted_weights.append(matrix) + elif path.endswith("attn/key_norm"): + converted_paths.append(f"{base_path}.self_attn.k_norm.weight") + converted_weights.append(matrix) + elif path.endswith("mlp/gating_einsum"): + # NOTE: The JAX implementations changes the type of the primary `mlp` for MOE models and adds a new + # `mlp2` that operates _before_ `mlp`. In Hugging Face Transformers we keep the type of `mlp` constant + # and add an `experts` that operates after `mlp`, so we need to invert this assignment when using MOE arch. + if config.enable_moe_block: + # MoE expert weights: matrix shape [num_experts, 2, moe_intermediate_size, hidden_size] + # -> experts.gate_up_proj (nn.Parameter, shape [E, 2*moe_inter, hidden]) + num_experts, _, expert_inter, hidden_size = matrix.shape + gate_up_proj_weight = np.asarray(matrix).reshape(num_experts, 2 * expert_inter, hidden_size) + converted_paths.append(f"{base_path}.experts.gate_up_proj") + converted_weights.append(gate_up_proj_weight) + else: + # Dense MLP: matrix shape [2, intermediate_size, hidden_size] + gate_proj_weight, up_proj_weight = matrix + converted_paths.extend([f"{base_path}.mlp.gate_proj.weight", f"{base_path}.mlp.up_proj.weight"]) + converted_weights.extend([gate_proj_weight, up_proj_weight]) + elif path.endswith("mlp/linear"): + # NOTE: The JAX implementations changes the type of the primary `mlp` for MOE models and adds a new + # `mlp2` that operates _before_ `mlp`. 
In Hugging Face Transformers we keep the type of `mlp` constant + # and add an `experts` that operates after `mlp`, so we need to invert this assignment when using MOE arch. + if config.enable_moe_block: + # MoE expert down_proj: matrix shape [num_experts, moe_inter, hidden] + # -> experts.down_proj (nn.Parameter, shape [E, hidden, moe_inter]) + converted_paths.append(f"{base_path}.experts.down_proj") + converted_weights.append(matrix.transpose(0, 2, 1)) + else: + # Dense MLP down_proj + converted_paths.append(f"{base_path}.mlp.down_proj.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("mlp/router_logits"): + # MoE router: matrix shape [hidden_size, num_experts] + # -> router.proj.weight (nn.Linear, shape [num_experts, hidden_size]) + converted_paths.append(f"{base_path}.router.proj.weight") + converted_weights.append(matrix.transpose()) + elif param == "router_scale" and path.endswith("mlp"): + # MoE router scale: shape [hidden_size] + converted_paths.append(f"{base_path}.router.scale") + converted_weights.append(matrix) + elif param == "per_expert_scale" and path.endswith("mlp"): + # MoE per-expert scale: shape [num_experts] + converted_paths.append(f"{base_path}.router.per_expert_scale") + converted_weights.append(matrix) + elif path.endswith("mlp2/gating_einsum"): + # Shared expert: matrix shape [2, intermediate_size, hidden_size] + # -> mlp.gate_proj.weight + mlp.up_proj.weight (nn.Linear) + converted_paths.extend([f"{base_path}.mlp.gate_proj.weight", f"{base_path}.mlp.up_proj.weight"]) + gate_proj_weight, up_proj_weight = matrix + converted_weights.extend([gate_proj_weight, up_proj_weight]) + elif path.endswith("mlp2/linear"): + # Shared expert down_proj: matrix shape [intermediate_size, hidden_size] + # -> mlp.down_proj.weight (nn.Linear, needs transpose) + converted_paths.append(f"{base_path}.mlp.down_proj.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("per_layer_input_gate"): + converted_paths.append(f"{base_path}.per_layer_input_gate.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("per_layer_projection"): + converted_paths.append(f"{base_path}.per_layer_projection.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("post_attention_norm"): + converted_paths.append(f"{base_path}.post_attention_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("post_ffw_norm"): + converted_paths.append(f"{base_path}.post_feedforward_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("post_ffw1_norm"): + converted_paths.append(f"{base_path}.post_feedforward_layernorm_2.weight") + converted_weights.append(matrix) + elif path.endswith("post_ffw2_norm"): + converted_paths.append(f"{base_path}.post_feedforward_layernorm_1.weight") + converted_weights.append(matrix) + elif path.endswith("pre_ffw2_norm"): + converted_paths.append(f"{base_path}.pre_feedforward_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("post_per_layer_input_norm"): + converted_paths.append(f"{base_path}.post_per_layer_input_norm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_attention_norm"): + converted_paths.append(f"{base_path}.input_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_ffw_norm"): + # NOTE: The JAX implementations changes the type of the primary `mlp` for MOE models and adds a new + # `mlp2` that operates _before_ `mlp`. 
In Hugging Face Transformer we keep the type of `mlp` constant + # and add an `mlp2` that operates after `mlp`, so we need to invert this assignment when using MOE arch. + if config.enable_moe_block: + # pre_ffw_norm is the pre-norm for ffw1 (MoE); in HF, MoE is mlp_2 + converted_paths.append(f"{base_path}.pre_feedforward_layernorm_2.weight") + else: + converted_paths.append(f"{base_path}.pre_feedforward_layernorm.weight") + converted_weights.append(matrix) + elif path == _TRANSFORMER_EMBEDDER: + if param == "input_embedding": + converted_paths.append("embed_tokens.weight") + converted_weights.append(weights) + elif param == "per_layer_embeddings": + converted_paths.append("embed_tokens_per_layer.weight") + # JAX uses an einsum, but Transformers uses a Linear, so reshapes are required here and in modeling file. + vocab_size, num_layers, hidden_dim = weights.shape + converted_weights.append(weights.reshape(vocab_size, num_layers * hidden_dim)) + elif path.startswith(_TRANSFORMER_EMBEDDER): + # TODO: ryanmullins - support multimodal norms and projections + if path.endswith("per_layer_model_projection"): + converted_paths.append("per_layer_model_projection.weight") + converted_weights.append( + weights.reshape( + config.hidden_size, config.num_hidden_layers * config.hidden_size_per_layer_input + ).transpose() + ) + elif path.endswith("per_layer_projection_norm"): + converted_paths.append("per_layer_projection_norm.weight") + converted_weights.append(weights) + elif path == _TRANSFORMER_FINAL_NORM: + converted_paths = ["norm.weight"] + converted_weights = [weights] + + if (cpl := len(converted_paths)) != (cwl := len(converted_weights)): + raise ValueError( + "The `converted_paths` and `converted_weights` should be the same " + f"length. Got {cpl} and {cwl}, respectively, for {path}." + ) + + return zip(converted_paths, converted_weights) + + +def _restore_checkpoint(checkpoint_path: str) -> dict: + """Restores an Orbax checkpoint, handling multi-device sharded checkpoints. + + Reads the checkpoint metadata to build a target tree structure and uses + SingleDeviceSharding to consolidate all shards onto a single CPU device. + """ + metadata_path = os.path.join(checkpoint_path, "_METADATA") + with open(metadata_path, "rb") as f: + metadata = json.loads(f.read()) + + tree_metadata = metadata["tree_metadata"] + + # Build a nested dict matching the checkpoint's tree structure + target = {} + for key_str in tree_metadata: + keys = ast.literal_eval(key_str) + d = target + for k in keys[:-1]: + if k not in d: + d[k] = {} + d = d[k] + d[keys[-1]] = np.zeros(1) # placeholder leaf + + device = jax.devices("cpu")[0] + sharding = SingleDeviceSharding(device) + + restore_args_tree = tree.map_structure(lambda _: type_handlers.ArrayRestoreArgs(sharding=sharding), target) + restore = obc_args.PyTreeRestore(item=target, restore_args=restore_args_tree) + + checkpointer = obc.PyTreeCheckpointer() + return checkpointer.restore(checkpoint_path, args=restore) + + +def convert(checkpoint_path: str, config: Gemma4Config) -> dict[str, torch.Tensor]: + """Loads Orbax checkpoint from `input_path` and converts it to HF tree.""" + ckpt = _restore_checkpoint(checkpoint_path) + hf_tree: dict[str, torch.Tensor] = {} + + text_path_prefix = "model" + if not _TEXT_ONLY.value: + text_path_prefix += ".language_model" + + def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None: + # Convert directly to float32 in a single step to avoid an extra intermediate copy. 
+ # The old code did np.asarray(weights) then .astype("float32"), keeping two full copies alive. + weights_f32 = np.asarray(weights, dtype=np.float32) + del weights # allow GC of the input (JAX array or numpy view) + t = torch.from_numpy(weights_f32) # shares memory with weights_f32 + if t.dtype != target_dtype: + hf_tree[path] = t.to(target_dtype) + del t, weights_f32 # free the float32 intermediate + else: + hf_tree[path] = t + if _VERBOSE.value: + logging.info( + "%s converted shape=%s with dtype=%s", + path, + hf_tree[path].shape, + target_dtype, + ) + + for path_tuple, value in tree.flatten_with_path(ckpt): + param = path_tuple[-1] + if "params" in path_tuple: + path_tuple = path_tuple[2:] + path_tuple = path_tuple[:-1] + path = "/".join(path_tuple) if len(path_tuple) > 1 else path_tuple[0] + + if path.endswith("audio_input_projection") and not _TEXT_ONLY.value: + update_tree("model.embed_audio.embedding_projection.weight", value.transpose(), config.audio_config.dtype) + elif path.endswith("mm_input_projection") and not _TEXT_ONLY.value: + update_tree( + "model.embed_vision.embedding_projection.weight", value.transpose(), config.vision_config.dtype + ) + elif path.startswith(_TRANSFORMER_PARAMETER): + for hf_path, weights in convert_transformer_weights(config.text_config, path, param, value): + update_tree(f"{text_path_prefix}.{hf_path}", weights, config.text_config.dtype) + elif path.startswith(_VISION_ENCODER_PARAMETER) and not _TEXT_ONLY.value: + for hf_path, weights in convert_vision_encoder_weights(config.vision_config, path, param, value): + update_tree(f"model.vision_tower.{hf_path}", weights, config.vision_config.dtype) + elif path.startswith(_AUDIO_ENCODER_PARAMETER) and not _TEXT_ONLY.value: + for hf_path, weights in convert_audio_encoder_weights(config.audio_config, path, param, value): + update_tree(f"model.audio_tower.{hf_path}", weights, config.audio_config.dtype) + + hf_tree["lm_head.weight"] = hf_tree[f"{text_path_prefix}.embed_tokens.weight"] + + return hf_tree + + +def main(*args): + del args + + output_path = _OUTPUT_PATH.value + variant = _VARIANT.value + + config = _VARIANTS[variant] + config.text_config.dtype = getattr(torch, _TEXT_DTYPE.value) + config.vision_config.dtype = getattr(torch, _VISION_DTYPE.value) + if (audio_config := config.audio_config) is not None: + audio_config.dtype = getattr(torch, _AUDIO_DTYPE.value) + + if _INCLUDE_CHAT_TEMPLATE.value: + # Chat template is included for instruction tuned models, which treat + # both "" and "" as generation stoppers. 
+ config.eos_token_id = [1, 106] + + logging.info( + "Converting Gemma 4 (%s) @ %s (language) and %s (vision)", + variant, + _TEXT_DTYPE.value, + _VISION_DTYPE.value, + ) + state_tree = convert(_CHECKPOINT_PATH.value, config) + logging.info("Converted Gemma 4 (%s) state tree from Orbax to Hugging Face.", variant) + + with accelerate.init_empty_weights(): + if _TEXT_ONLY.value: + config = config.text_config + model = Gemma4ForCausalLM(config=config) + else: + model = Gemma4ForConditionalGeneration(config=config) + + model.load_state_dict(state_tree, assign=True) + logging.info( + "Loaded Gemma 4 (%s) in Hugging Face Transformers as a %s instance.", + variant, + type(model).__name__, + ) + model.save_pretrained(output_path, state_dict=state_tree, safe_serialization=True) + logging.info( + "Saved Gemma 4 (%s) to SafeTensors in %s using %s", + variant, + output_path, + type(model).__name__, + ) + del model + del state_tree + + chat_template = _CHAT_TEMPLATE_LARGE if variant in _LARGE_MODEL_VARIANTS else _CHAT_TEMPLATE + chat_template_kwargs = {"chat_template": chat_template} if _INCLUDE_CHAT_TEMPLATE.value else {} + response_schema_kwargs = {"response_schema": _RESPONSE_SCHEMA} if _INCLUDE_RESPONSE_SCHEMA.value else {} + + sentencepiece_extractor = SentencePieceExtractor(_TOKENIZER_PATH.value) + vocab, _, merges = sentencepiece_extractor.extract() + tokenizer = GemmaTokenizer( + vocab=vocab, + merges=merges, + add_bos_token=False, + padding_side="left", + extra_special_tokens={ + "image_token": "<|image|>", + "boi_token": "<|image>", + "eoi_token": "", + "audio_token": "<|audio|>", + "boa_token": "<|audio>", + "eoa_token": "", + "sot_token": "<|turn>", + "eot_token": "", + "soc_token": "<|channel>", + "eoc_token": "", + "think_token": "<|think|>", + "escape_token": '<|"|>', + "str_token": "<|tool_response>", + "etr_token": "", + "stc_token": "<|tool_call>", + "etc_token": "", + "std_token": "<|tool>", + "etd_token": "", + }, + **chat_template_kwargs, + **response_schema_kwargs, + ) + + # Update config multimodal token IDs from the tokenizer. + # The Gemma4 SPM (262144 vocab) has native <|image> (255999) and <|audio> (256000) + # tokens, plus (258882) and (258883) for delimiters. + # Only and are added as new tokens (IDs >= 262144). 
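+    # Sanity-check sketch for the note above (the actual values are resolved from the tokenizer
+    # below, never hard-coded; treat these numbers as expectations, not guarantees):
+    #   config.boi_token_id -> 255999 (native "<|image>" SPM entry)
+    #   config.boa_token_id -> 256000 (native "<|audio>" SPM entry)
+    #   config.image_token_id / config.audio_token_id -> >= 262144 (newly added tokens)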
+ config.image_token_id = tokenizer.image_token_id + config.boi_token_id = tokenizer.convert_tokens_to_ids(tokenizer.boi_token) + config.eoi_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eoi_token) + config.audio_token_id = tokenizer.audio_token_id + config.boa_token_id = tokenizer.convert_tokens_to_ids(tokenizer.boa_token) + config.eoa_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eoa_token) + logging.info( + "Set multimodal token IDs from tokenizer: image=%d, boi=%d, eoi=%d, audio=%d, boa=%d, eoa=%d", + config.image_token_id, + config.boi_token_id, + config.eoi_token_id, + config.audio_token_id, + config.boa_token_id, + config.eoa_token_id, + ) + # Re-save the config with correct token IDs + config.save_pretrained(output_path) + + if _TEXT_ONLY.value: + tokenizer.save_pretrained(output_path) + logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path) + else: + vision_config = config.vision_config + feature_extractor = Gemma4AudioFeatureExtractor() + image_processor = Gemma4ImageProcessor( + image_seq_length=vision_config.default_output_length, + do_normalize=False, + max_soft_tokens=vision_config.default_output_length, + pooling_kernel_size=3, + ) + video_processor = Gemma4VideoProcessor() + processor = Gemma4Processor( + image_processor=image_processor, + feature_extractor=feature_extractor, + video_processor=video_processor, + tokenizer=tokenizer, + image_seq_length=vision_config.default_output_length, + **chat_template_kwargs, + ) + processor.save_pretrained(output_path) + + logging.info("Saved Gemma4Processor for %s to %s", variant, output_path) + del feature_extractor, image_processor, processor + + generation_config = GenerationConfig( + pad_token_id=config.get_text_config().pad_token_id, + bos_token_id=config.get_text_config().bos_token_id, + eos_token_id=( + tokenizer.convert_tokens_to_ids([tokenizer.eos_token, tokenizer.eot_token, tokenizer.str_token]) + if _INCLUDE_CHAT_TEMPLATE.value + else config.get_text_config().eos_token_id + ), + temperature=1.0, + do_sample=True, + top_k=64, + top_p=0.95, + ) + generation_config.save_pretrained(output_path) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/transformers/models/gemma4/feature_extraction_gemma4.py b/src/transformers/models/gemma4/feature_extraction_gemma4.py new file mode 100644 index 000000000000..38382e8a85cb --- /dev/null +++ b/src/transformers/models/gemma4/feature_extraction_gemma4.py @@ -0,0 +1,298 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
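+
+# Minimal usage sketch for the extractor defined in this module (assuming 16 kHz mono input;
+# shapes follow from the defaults of 20 ms frames, 10 ms hop, and 128 mel bins):
+#
+#   extractor = Gemma4AudioFeatureExtractor()
+#   features = extractor(np.zeros(16_000, dtype=np.float32), return_tensors="np")
+#   features["input_features"].shape       # (1, ~100, 128): roughly one frame per 10 ms
+#   features["input_features_mask"].shape  # (1, ~100)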
+ +import math +import warnings +from collections.abc import Sequence + +import numpy as np + +from ...audio_utils import mel_filter_bank, window_function +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import PaddingStrategy, TensorType, logging + + +logger = logging.get_logger(__name__) + + +def _unfold(array: np.ndarray, dimension: int, size: int, step: int) -> np.ndarray: + """A basic NumPy equivalent of PyTorch's unfold for 2D arrays along the last dim.""" + if array.ndim != 2: + raise ValueError("This unfold implementation currently supports 2D arrays (batch, time).") + if dimension != -1 and dimension != array.ndim - 1: + raise ValueError("This unfold implementation only supports unfolding the last dimension.") + + batch_size, original_length = array.shape + num_frames = (original_length - size) // step + 1 + + if num_frames <= 0: + return np.zeros((batch_size, 0, size), dtype=array.dtype) + + output_shape = (batch_size, num_frames, size) + output_strides = (array.strides[0], array.strides[1] * step, array.strides[1]) + + return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides) + + +class Gemma4AudioFeatureExtractor(SequenceFeatureExtractor): + """An audio feature extractor Universal Speech Models https://huggingface.co/papers/2303.01037. + + Args: + feature_size (`int`, *optional*, defaults to 128): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + padding_value (`float`, *optional*, defaults to 0.0): + Padding value used to pad the audio. Should correspond to silences. + return_attention_mask (`bool`, *optional*, defaults to `True`): + Whether to return the attention mask for the generated MEL spectrograms. + frame_length_ms (`float`, *optional*, defaults to 20.0): + The length of a frame in milliseconds. + hop_length_ms (`float`, *optional*, defaults to 10.0): + Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. + min_frequency (`float`, *optional*, defaults to 0.0): + The minimum frequency (in Hz) for the Mel filterbank. + max_frequency (`float`, *optional*, defaults to 8000.0): + The maximum frequency (in Hz) for the Mel filterbank. + preemphasis (`float`, *optional*, defaults to 0.0): + The preemphasis coefficient. + preemphasis_htk_flavor (`bool`, *optional*, defaults to `True`): + Whether to use HTK-style preemphasis. + fft_overdrive (`bool`, *optional*, defaults to `False`): + Whether to use FFT overdrive. + dither (`float`, *optional*, defaults to 0.0): + Adds dithering. In other words, adds a small Gaussian noise to each frame. + E.g. use 0.0001 to add dithering with a normal distribution centered + around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech). + The value 0.0 means no dithering. + Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces + the high log_mel_fbank values for signals with hard-zero sections, + when VAD cutoff is present in the signal. + input_scale_factor (`float`, *optional*, defaults to 1.0): + Scaling factor applied to the input waveform. + mel_floor (`float`, *optional*, defaults to 0.001): + Minimum value for Mel spectrograms to avoid log(0). + per_bin_mean (`Optional[Sequence[float]]`, *optional*): + Mean values for per-bin normalization. 
+ per_bin_stddev (`Optional[Sequence[float]]`, *optional*): + Standard deviation values for per-bin normalization. + """ + + model_input_names = ["input_features", "input_features_mask"] + + def __init__( + self, + feature_size: int = 128, + sampling_rate: int = 16_000, + padding_value: float = 0.0, + return_attention_mask: bool = True, + frame_length_ms: float = 20.0, + hop_length_ms: float = 10.0, + min_frequency: float = 0.0, + max_frequency: float = 8000.0, + preemphasis: float = 0.0, + preemphasis_htk_flavor: bool = True, + fft_overdrive: bool = False, + dither: float = 0.0, + input_scale_factor: float = 1.0, + mel_floor: float = 1e-3, + per_bin_mean: Sequence[float] | None = None, + per_bin_stddev: Sequence[float] | None = None, + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) + + self.min_frequency = min_frequency + self.max_frequency = max_frequency + self.preemphasis = preemphasis + self.preemphasis_htk_flavor = preemphasis_htk_flavor + self.fft_overdrive = fft_overdrive + self.dither = dither + self.input_scale_factor = input_scale_factor + self.frame_length = int(round(sampling_rate * frame_length_ms / 1000.0)) + self.hop_length = int(round(sampling_rate * hop_length_ms / 1000.0)) + self.mel_floor = np.array(mel_floor, dtype=np.float64) + + fft_length = 2 ** math.ceil(math.log2(self.frame_length)) + if self.fft_overdrive: + fft_length *= 2 + self.fft_length = fft_length + + # Use periodic Hann window, matching sl.STFT default (signal.hann_window) + # For even frame_length: window[n] = 0.5 - 0.5 * cos(2*pi*n / frame_length) + self.window = window_function(self.frame_length).astype(np.float32) + + # Use HuggingFace's mel_filter_bank for compatibility. + # Suppress the expected warning about all-zero upper mel filters; + # with fft_length=512 (257 bins) and 128 mel filters the uppermost + # triangular filter falls between frequency bins, which is harmless. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.mel_filters = mel_filter_bank( + num_frequency_bins=self.fft_length // 2 + 1, + num_mel_filters=feature_size, + min_frequency=min_frequency, + max_frequency=max_frequency, + sampling_rate=self.sampling_rate, + norm=None, + mel_scale="htk", + ) + + if per_bin_mean is not None: + self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, feature_size) + else: + self.per_bin_mean = None + + if per_bin_stddev is not None: + self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, feature_size) + else: + self.per_bin_stddev = None + + def _extract_spectrogram(self, waveform: np.ndarray, attention_mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """""" + if waveform.ndim == 1: # If single waveform, add batch dimension + waveform = np.expand_dims(waveform, axis=0) + + if self.dither > 0.0: + waveform = waveform + self.dither * np.random.randn(*waveform.shape).astype(waveform.dtype) + + if self.input_scale_factor != 1.0: + waveform = waveform * self.input_scale_factor + + # Semicausal time padding: prepend frame_length // 2 zeros so that the + # first STFT frame is centered at t=0, matching sl.STFT(time_padding='semicausal'). 
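+        # Worked example with the defaults (16 kHz sampling, 20 ms frames, 10 ms hop), for orientation only:
+        #   frame_length = round(16000 * 20 / 1000) = 320 samples
+        #   hop_length   = round(16000 * 10 / 1000) = 160 samples
+        #   fft_length   = 2 ** ceil(log2(320))     = 512 (1024 with fft_overdrive)
+        #   rfft bins    = 512 // 2 + 1             = 257, projected onto 128 mel filters
+        #   pad_left     = 320 // 2                 = 160 leading zero samples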
+ pad_left = self.frame_length // 2 + waveform = np.pad(waveform, ((0, 0), (pad_left, 0)), mode="constant") + attention_mask = np.pad(attention_mask, (pad_left, 0), mode="constant", constant_values=0) + + frame_size_for_unfold = self.frame_length + 1 + + # NumPy equivalent of unfold for [B, NumFrames, frame_size_for_unfold] + frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length) + + if self.preemphasis > 0.0: + if self.preemphasis_htk_flavor: + first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis) + rest_in_frame = frames_to_process[..., 1:-1] - self.preemphasis * frames_to_process[..., :-2] + frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1) + else: + frames = frames_to_process[..., 1:] - self.preemphasis * frames_to_process[..., :-1] + else: + frames = frames_to_process[..., :-1] + + # Apply window, then RFFT. np.fft.rfft with n=fft_length implicitly + # right-pads frames to fft_length. + frames = frames * self.window # Broadcasting window + stft = np.fft.rfft(frames, n=self.fft_length, axis=-1) + + magnitude_spec = np.abs(stft) + + mel_spec = np.matmul(magnitude_spec, self.mel_filters) + log_mel_spec = np.log(mel_spec + self.mel_floor) + + if self.per_bin_mean is not None: + log_mel_spec = log_mel_spec - self.per_bin_mean # Broadcasting + + if self.per_bin_stddev is not None: + log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting + + mel_spectrogram = log_mel_spec.squeeze(0) + num_mel_frames = mel_spectrogram.shape[0] + + # Build a frame-aware mask: a mel frame is valid only when every sample + # in its analysis window [i*hop, i*hop + frame_size - 1] is real audio. + # We check this by looking at the last sample of each frame's window. + frame_end_indices = np.arange(num_mel_frames) * self.hop_length + frame_size_for_unfold - 1 + mask = attention_mask[frame_end_indices].astype(bool) + return mel_spectrogram, mask + + def __call__( + self, + raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], + padding: bool | str | PaddingStrategy = "longest", + max_length: int | None = 480_000, + truncation: bool = True, + pad_to_multiple_of: int | None = 128, + return_tensors: str | TensorType | None = None, + return_attention_mask: bool | None = True, + **kwargs, + ) -> BatchFeature: + """Creates a batch of MEL spectrograms from the provided raw speech. + + This implementation uses a different algorithm for windowing and preemphasis compared to the built-in + `transformers.audio_utils.spectrogram()` function that _will_ result in different outputs. Consider this + carefully when selecting an audio feature extractor, especially with pre-trained models. + + Args: + raw_speech: + The audio for which MEL spectrograms are created. + padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `"longest"`): + The padding strategy to use for batches of audio with different lengths. + max_length (`int`, *optional*, defaults to 480000): + If provided, defines the maximum length of the audio to allow. Audio longer than this will be + truncated if `truncation=True`. + truncation (`bool`, *optional*, defaults to `True`): + Whether or not to truncate audio above `max_length`. + pad_to_multiple_of (`int`, *optional*, defaults to 128): + When padding, pad to a multiple of this value. The default value is defined for optimal TPU support. + return_tensors (`Union[str, TensorType]`, *optional*, defaults to `None`): + The type of tensors to return (e.g., NumPy, or Torch). 
+ return_attention_mask (`bool`, *optional*, defaults to `True`): + Whether to return the attention mask for the generated MEL spectrograms. + """ + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence)) + is_batched = is_batched_numpy or is_batched_sequence + + if is_batched: + raw_speech = [np.asarray([rs]).T for rs in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech) + + if not is_batched: # always return a batch + raw_speech = [np.asarray([raw_speech])] + + batched_speech = self.pad( + BatchFeature({"input_features": raw_speech}), + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + prepared_speech = [] + prepared_speech_mask = [] + for speech, mask in zip(batched_speech.input_features, batched_speech.attention_mask): + speech, mask = self._extract_spectrogram(speech.T, mask) + prepared_speech.append(speech.astype(np.float32)) + prepared_speech_mask.append(mask) + + prepared_speech = [speech * mask[..., None] for speech, mask in zip(prepared_speech, prepared_speech_mask)] + + return BatchFeature( + {"input_features": prepared_speech, "input_features_mask": prepared_speech_mask}, + tensor_type=return_tensors, + ) + + +__all__ = ["Gemma4AudioFeatureExtractor"] diff --git a/src/transformers/models/gemma4/image_processing_gemma4.py b/src/transformers/models/gemma4/image_processing_gemma4.py new file mode 100644 index 000000000000..88510d052daf --- /dev/null +++ b/src/transformers/models/gemma4/image_processing_gemma4.py @@ -0,0 +1,220 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_backends import TorchvisionBackend +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput, PILImageResampling +from ...processing_utils import ImagesKwargs, Unpack +from ...utils import TensorType, auto_docstring, logging +from .image_processing_pil_gemma4 import _SUPPORTED_SOFT_TOKENS, get_aspect_ratio_preserving_size + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.siglip2.image_processing_siglip2.convert_image_to_patches +def convert_image_to_patches(image: "torch.Tensor", patch_size: int) -> "torch.Tensor": + """ + Convert 3D tensor image of shape (num_channels, image_height, image_width) into 2D tensor of patches of shape + (num_patches_height * num_patches_width, patch_size * patch_size * num_channels). 
+ """ + num_channels, image_height, image_width = image.shape + num_patches_height = image_height // patch_size + num_patches_width = image_width // patch_size + patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size) + patched_image = patched_image.permute(1, 3, 2, 4, 0) + patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1) + return patched_image + + +# Adopted from Siglip2 (mask -> position ids) +def pad_along_first_dim( + image: "torch.Tensor", positions: "torch.Tensor", target_length: int +) -> tuple["torch.Tensor", "torch.Tensor"]: + """ + Pad the tensor along the first dimension. + """ + current_length = image.shape[0] + padding_length = target_length - current_length + if padding_length > 0: + padding = [0, 0] * (image.ndim - 1) + [0, padding_length] + pos_padding = (0, 0, 0, padding_length) + image = torch.nn.functional.pad(image, padding, mode="constant", value=0) + positions = torch.nn.functional.pad(positions, pos_padding, mode="constant", value=-1) + return image, positions + + +class Gemma4ImageProcessorKwargs(ImagesKwargs, total=False): + """ + patch_size (`int`, *optional*): + Size of each image patch in pixels. + max_soft_tokens (`int`, *optional*): + Maximum number of soft (vision) tokens per image. + Must be one of {70, 140, 280, 560, 1120}. + pooling_kernel_size (`int`, *optional*): + Spatial pooling kernel size applied after patchification. + """ + + patch_size: int + max_soft_tokens: int + pooling_kernel_size: int + + +@auto_docstring(custom_intro="Constructs a Gemma4 image processor.") +class Gemma4ImageProcessor(TorchvisionBackend): + resample = PILImageResampling.BICUBIC + image_mean = [0.0, 0.0, 0.0] + image_std = [1.0, 1.0, 1.0] + size = None + default_to_square = True + do_convert_rgb = True + do_resize = True + do_rescale = True + do_normalize = False + patch_size = 16 + max_soft_tokens = 280 + pooling_kernel_size = 3 + valid_kwargs = Gemma4ImageProcessorKwargs + model_input_names = ["pixel_values", "image_position_ids", "num_soft_tokens_per_image"] + + def __init__(self, **kwargs: Unpack[Gemma4ImageProcessorKwargs]): + super().__init__(**kwargs) + + if self.max_soft_tokens not in _SUPPORTED_SOFT_TOKENS: + raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {self.max_soft_tokens}.") + + def _validate_preprocess_kwargs(self, **kwargs): + # Gemma4 uses aspect_ratio_preserving_resize driven by patch_size, + # max_soft_tokens, and pooling_kernel_size — not the standard `size` + # parameter. Temporarily disable do_resize so the base validation + # doesn't require `size` to be set. 
+ kwargs["do_resize"] = False + super()._validate_preprocess_kwargs(**kwargs) + + def aspect_ratio_preserving_resize( + self, + image: torch.Tensor, + patch_size: int, + max_patches: int, + pooling_kernel_size: int, + resample: F.InterpolationMode, + ) -> torch.Tensor: + height, width = image.shape[-2], image.shape[-1] + target_height, target_width = get_aspect_ratio_preserving_size( + height=height, + width=width, + patch_size=patch_size, + max_patches=max_patches, + pooling_kernel_size=pooling_kernel_size, + ) + + if target_height == height and target_width == width: + return image + + return F.resize( + image, + size=[target_height, target_width], + interpolation=resample, + antialias=True, + ) + + def preprocess( + self, + images: ImageInput, + **kwargs: Unpack[Gemma4ImageProcessorKwargs], + ) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + resample: "PILImageResampling | F.InterpolationMode | int | None", + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: float | list[float] | None, + image_std: float | list[float] | None, + return_tensors: str | TensorType | None, + patch_size: int | None = None, + max_soft_tokens: int | None = None, + pooling_kernel_size: int | None = None, + **kwargs, + ) -> BatchFeature: + if max_soft_tokens not in _SUPPORTED_SOFT_TOKENS: + raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {max_soft_tokens}.") + + # Compute max_patches from max_soft_tokens and pooling_kernel_size + max_patches = max_soft_tokens * pooling_kernel_size**2 + + # Process each image individually: resize, rescale/normalize, patchify, pad. + # Images have different aspect ratios and thus different resized dimensions, + # so patchification and padding must happen per-image before stacking. + pixel_values = [] + position_ids = [] + num_soft_tokens_per_image = [] + + for image in images: + # Step 1: Aspect-ratio-preserving resize + if do_resize: + image = self.aspect_ratio_preserving_resize( + image=image, + patch_size=patch_size, + max_patches=max_patches, + pooling_kernel_size=pooling_kernel_size, + resample=resample, + ) + + # Step 2: Rescale pixel values (typically to [0, 1]) and optionally identity normalize + image = self.rescale_and_normalize(image, do_rescale, rescale_factor, do_normalize, image_mean, image_std) + + # Step 3: Patchify the image + # (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels) + patch_height = image.shape[-2] // patch_size + patch_width = image.shape[-1] // patch_size + patches = convert_image_to_patches(image, patch_size) + num_soft_tokens_per_image.append(patches.shape[0] // pooling_kernel_size**2) + + # Step 5: Compute position IDs + device = image.device + patch_grid = torch.meshgrid( + torch.arange(patch_width, device=device), + torch.arange(patch_height, device=device), + indexing="xy", + ) + stacked_grid = torch.stack(patch_grid, dim=-1) + real_positions = stacked_grid.reshape(patches.shape[0], 2) + + # Step 6. 
Pad pacthes and positions to `max_patches` + patches, positions = pad_along_first_dim(patches, real_positions, max_patches) + pixel_values.append(patches) + position_ids.append(positions) + + # Stack into batch tensors + pixel_values = torch.stack(pixel_values, dim=0) # (batch, max_patches, patch_pixels) + position_ids = torch.stack(position_ids, dim=0) # (batch, max_patches, 2) + + data = { + "pixel_values": pixel_values, + "image_position_ids": position_ids, + "num_soft_tokens_per_image": num_soft_tokens_per_image, + } + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["Gemma4ImageProcessor"] diff --git a/src/transformers/models/gemma4/image_processing_pil_gemma4.py b/src/transformers/models/gemma4/image_processing_pil_gemma4.py new file mode 100644 index 000000000000..d58f6a4c2ec6 --- /dev/null +++ b/src/transformers/models/gemma4/image_processing_pil_gemma4.py @@ -0,0 +1,278 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np + +from ...image_processing_backends import PilBackend +from ...image_processing_utils import BatchFeature +from ...image_transforms import resize +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, Unpack +from ...utils import TensorType, auto_docstring, is_vision_available, logging + + +if is_vision_available(): + from ...image_utils import PILImageResampling + + +logger = logging.get_logger(__name__) + +_SUPPORTED_SOFT_TOKENS = (70, 140, 280, 560, 1120) + + +def get_aspect_ratio_preserving_size( + height: int, + width: int, + patch_size: int, + max_patches: int, + pooling_kernel_size: int, +) -> tuple[int, int]: + """ + Image is resized to preserve aspect ratio so it fits within the patch budget. + Target dimensions are the largest that: + 1) Produce at most `max_patches` patches when patchified with `patch_size` + 2) Have height and width divisible by `pooling_kernel_size * patch_size` + """ + total_px = height * width + target_px = max_patches * (patch_size**2) + factor = math.sqrt(target_px / total_px) + ideal_height = factor * height + ideal_width = factor * width + side_mult = pooling_kernel_size * patch_size + + # Round down to nearest multiple of side_mult + target_height = int(math.floor(ideal_height / side_mult)) * side_mult + target_width = int(math.floor(ideal_width / side_mult)) * side_mult + + # Handle edge cases where one or both dimensions round to 0 + if target_height == 0 and target_width == 0: + raise ValueError( + "Attempting to resize to a 0 x 0 image. Resized height should be divisble by " + f"`pooling_kernel_size * patch_size`={pooling_kernel_size * patch_size}." 
+ ) + + max_side_length = (max_patches // pooling_kernel_size**2) * side_mult + if target_height == 0: + target_height = side_mult + target_width = min( + int(math.floor(width / height)) * side_mult, + max_side_length, + ) + elif target_width == 0: + target_width = side_mult + target_height = min( + int(math.floor(height / width)) * side_mult, + max_side_length, + ) + + if target_height * target_width > target_px: + raise ValueError( + f"Resizing [{height}x{width}] to [{target_height}x{target_width}] " + f"but this exceeds {max_patches} patches with patch_size {patch_size}" + ) + + return target_height, target_width + + +# Copied from transformers.models.siglip2.image_processing_pil_siglip2.convert_image_to_patches +def convert_image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray: + """ + Convert 3D array image of shape (num_channels, image_height, image_width) into 2D array of patches of shape + (num_patches_height * num_patches_width, patch_size * patch_size * num_channels). + """ + num_channels, image_height, image_width = image.shape + num_patches_height = image_height // patch_size + num_patches_width = image_width // patch_size + patched_image = image.reshape(num_channels, num_patches_height, patch_size, num_patches_width, patch_size) + patched_image = patched_image.transpose(1, 3, 2, 4, 0) + patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1) + return patched_image + + +# Adopted from Siglip2 (mask -> position ids) +def pad_along_first_dim(image: np.ndarray, positions: np.ndarray, target_length: int) -> tuple[np.ndarray, np.ndarray]: + """ + Pad the image along the first dimension. + """ + current_length = image.shape[0] + padding_length = target_length - current_length + if padding_length > 0: + paddings = [(0, padding_length)] + [(0, 0)] * (image.ndim - 1) + pos_paddings = [(0, padding_length), (0, 0)] + image = np.pad(image, paddings, mode="constant", constant_values=0) + positions = np.pad(positions, pos_paddings, mode="constant", constant_values=-1) + return image, positions + + +class Gemma4ImageProcessorKwargs(ImagesKwargs, total=False): + """ + patch_size (`int`, *optional*): + Size of each image patch in pixels. + max_soft_tokens (`int`, *optional*): + Maximum number of soft (vision) tokens per image. + Must be one of {70, 140, 280, 560, 1120}. + pooling_kernel_size (`int`, *optional*): + Spatial pooling kernel size applied after patchification. + """ + + patch_size: int + max_soft_tokens: int + pooling_kernel_size: int + + +@auto_docstring(custom_intro="Constructs a Gemma4 image processor.") +class Gemma4ImageProcessorPil(PilBackend): + valid_kwargs = Gemma4ImageProcessorKwargs + model_input_names = ["pixel_values", "image_position_ids", "num_soft_tokens_per_image"] + + do_resize = True + resample = PILImageResampling.BICUBIC + do_rescale = True + rescale_factor = 1 / 255 + do_normalize = False + image_mean = [0.0, 0.0, 0.0] + image_std = [1.0, 1.0, 1.0] + do_convert_rgb = True + patch_size = 16 + max_soft_tokens = 280 + pooling_kernel_size = 3 + + def __init__(self, **kwargs: Unpack[Gemma4ImageProcessorKwargs]) -> None: + super().__init__(**kwargs) + + if self.max_soft_tokens not in _SUPPORTED_SOFT_TOKENS: + raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {self.max_soft_tokens}.") + + def _validate_preprocess_kwargs(self, **kwargs): + # Gemma4 uses aspect_ratio_preserving_resize driven by patch_size, + # max_soft_tokens, and pooling_kernel_size — not the standard `size` + # parameter. 
Temporarily disable do_resize so the base validation + # doesn't require `size` to be set. + kwargs["do_resize"] = False + super()._validate_preprocess_kwargs(**kwargs) + + @auto_docstring + def preprocess( + self, + images: ImageInput, + **kwargs: Unpack[Gemma4ImageProcessorKwargs], + ) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def aspect_ratio_preserving_resize( + self, + image: np.ndarray, + patch_size: int, + max_patches: int, + pooling_kernel_size: int, + resample: PILImageResampling, + ) -> np.ndarray: + height, width = image.shape[-2], image.shape[-1] + target_height, target_width = get_aspect_ratio_preserving_size( + height=height, + width=width, + patch_size=patch_size, + max_patches=max_patches, + pooling_kernel_size=pooling_kernel_size, + ) + + if target_height == height and target_width == width: + return image + + return resize( + image, + size=(target_height, target_width), + resample=resample, + ) + + def _preprocess( + self, + images: list[np.ndarray], + do_resize: bool, + resample: "PILImageResampling | int | None", + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: float | list[float] | None, + image_std: float | list[float] | None, + return_tensors: str | TensorType | None, + max_soft_tokens: int | None = None, + patch_size: int | None = None, + pooling_kernel_size: int | None = None, + **kwargs, + ) -> BatchFeature: + if max_soft_tokens not in _SUPPORTED_SOFT_TOKENS: + raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {max_soft_tokens}.") + + # Compute max_patches from max_soft_tokens and pooling_kernel_size + max_patches = max_soft_tokens * pooling_kernel_size**2 + + # Process each image individually: resize, rescale/normalize, patchify, pad. + # Images have different aspect ratios and thus different resized dimensions, + # so patchification and padding must happen per-image before stacking. 
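+        # Worked example (illustrative numbers, not from a checkpoint): with the class defaults
+        # patch_size=16, pooling_kernel_size=3 and max_soft_tokens=280, max_patches = 280 * 3**2 = 2520.
+        # A 768x1024 image is resized to 672x912 (both divisible by 48), patchified into
+        # 42 * 57 = 2394 patches (266 soft tokens after 3x3 pooling), then padded up to 2520 patches.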
+ pixel_values = [] + position_ids = [] + num_soft_tokens_per_image = [] + + for image in images: + # Step 1: Aspect-ratio-preserving resize + if do_resize: + image = self.aspect_ratio_preserving_resize( + image=image, + patch_size=patch_size, + max_patches=max_patches, + pooling_kernel_size=pooling_kernel_size, + resample=resample, + ) + + # Step 2: Rescale pixel values from [0, 255] to [0, 1] + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor) + + # Step 3: Identity normalization because Gemma4 was trained with pixels in [0, 1] + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std) + + # Step 4: Patchify the image + # image is (C, H, W) numpy array; add batch dimension for reshape + # (num_channels, height, width) -> (num_patches, patch_size * patch_size * num_channels) + patches = convert_image_to_patches(image, patch_size) + num_soft_tokens_per_image.append(patches.shape[0] // pooling_kernel_size**2) + + # Step 5: Compute position IDs + patch_height = image.shape[-2] // patch_size + patch_width = image.shape[-1] // patch_size + grid_x, grid_y = np.meshgrid(np.arange(patch_width), np.arange(patch_height), indexing="xy") + real_positions = np.stack([grid_x, grid_y], axis=-1).reshape(patches.shape[0], 2) + + patches, positions = pad_along_first_dim(patches, real_positions, max_patches) + + pixel_values.append(patches) + position_ids.append(positions) + + # Stack into batch arrays and convert to tensors + pixel_values = np.stack(pixel_values, axis=0) # (batch, max_patches, patch_pixels) + position_ids = np.stack(position_ids, axis=0) # (batch, max_patches, 2) + + data = { + "pixel_values": pixel_values, + "image_position_ids": position_ids, + "num_soft_tokens_per_image": num_soft_tokens_per_image, + } + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["Gemma4ImageProcessorPil"] diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py new file mode 100644 index 000000000000..f690c0425c8c --- /dev/null +++ b/src/transformers/models/gemma4/modeling_gemma4.py @@ -0,0 +1,2564 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma4/modular_gemma4.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma4.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable +from dataclasses import dataclass +from functools import cached_property +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F + +from ... 
import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PreTrainedConfig +from ...generation import GenerationMixin +from ...integrations import use_experts_implementation, use_kernelized_func +from ...masking_utils import ( + create_bidirectional_mask, + create_causal_mask, + create_masks_for_generate, + create_sliding_window_causal_mask, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..auto.modeling_auto import AutoModel +from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Gemma4 outputs, with hidden states and attentions. + """ +) +class Gemma4ModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + audio_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state. + """ + + image_hidden_states: torch.FloatTensor | None = None + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Gemma4 causal language model (or autoregressive) outputs. + """ +) +class Gemma4CausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. 
+ image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. + audio_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + image_hidden_states: torch.FloatTensor | None = None + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring +class Gemma4AudioModelOutput(BaseModelOutputWithPooling): + r""" + attention_mask (`torch.BoolTensor`, *optional*): + A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding. + """ + + attention_mask: torch.BoolTensor | None = None + + +class Gemma4ClippableLinear(nn.Module): + def __init__( + self, + config: Gemma4VisionConfig | Gemma4AudioConfig, + in_features: int, + out_features: int, + ) -> None: + super().__init__() + self.use_clipped_linears = config.use_clipped_linears + self.linear = nn.Linear(in_features, out_features, bias=False) + + if self.use_clipped_linears: + self.register_buffer("input_min", torch.tensor(-float("inf"))) + self.register_buffer("input_max", torch.tensor(float("inf"))) + self.register_buffer("output_min", torch.tensor(-float("inf"))) + self.register_buffer("output_max", torch.tensor(float("inf"))) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.use_clipped_linears: + hidden_states = torch.clamp(hidden_states, self.input_min, self.input_max) + + hidden_states = self.linear(hidden_states) + + if self.use_clipped_linears: + hidden_states = torch.clamp(hidden_states, self.output_min, self.output_max) + + return hidden_states + + +class Gemma4RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): + super().__init__() + self.eps = eps + self.with_scale = with_scale + + if self.with_scale: + self.weight = nn.Parameter(torch.ones(dim), requires_grad=True) + + def _norm(self, hidden_states: torch.Tensor): + mean_squared = hidden_states.pow(2).mean(-1, keepdim=True) + self.eps + # Use torch.pow() (over torch.sqrt() or torch.rsqrt()) to addess compiler differences between Torch and JAX + return hidden_states * torch.pow(mean_squared, -0.5) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + normed_output = self._norm(hidden_states.float()) + if self.with_scale: + normed_output = normed_output * self.weight.float() + return normed_output.type_as(hidden_states) + + +class Gemma4AudioRelPositionalEncoding(nn.Module): + """Sinusoidal relative positional encoding for the audio encoder. + + Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with + concatenated [sin..., cos...] layout matching the original Gemma4 convention. 
+ """ + + inv_timescales: torch.Tensor + + def __init__(self, config: Gemma4AudioConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.context_size = ( + config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right + ) + min_timescale = 1.0 + max_timescale = 10000.0 + num_timescales = self.hidden_size // 2 + log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1) + inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment) + self.register_buffer("inv_timescales", inv_timescales.unsqueeze(0).unsqueeze(0), persistent=False) + + @torch.no_grad() + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + position_ids = torch.arange(12, -1, -1, device=hidden_states.device) + position_ids = position_ids[..., None] + scaled_time = position_ids * self.inv_timescales.to(device=hidden_states.device) + pos_embed = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1) + return pos_embed.to(dtype=hidden_states.dtype) + + +class Gemma4AudioAttention(nn.Module): + """Chunked local attention with relative position bias""" + + def __init__(self, config: Gemma4AudioConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.attention_logits_soft_cap = config.attention_logit_cap + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_heads = config.num_attention_heads + + self.q_scale = (self.head_dim**-0.5) / math.log(2) + self.k_scale = math.log(1 + math.e) / math.log(2) + + self.chunk_size = config.attention_chunk_size + self.max_past_horizon = config.attention_context_left - 1 + self.max_future_horizon = config.attention_context_right + self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon + + self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim) + self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim) + self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim) + self.post = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size) + + self.relative_k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.per_dim_scale = nn.Parameter(torch.zeros(self.head_dim)) + + self.register_buffer("softcap", torch.tensor(self.attention_logits_soft_cap), persistent=False) + + def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Splits a `(batch_size, seq_len, num_heads, head_dim)` tensor into non-overlapping blocks of `chunk_size` along the sequence dim.""" + batch_size, seq_len, num_heads, head_dim = hidden_states.shape + num_blocks = (seq_len + self.chunk_size - 1) // self.chunk_size + pad = num_blocks * self.chunk_size - seq_len + hidden_states = F.pad(hidden_states, (0, 0, 0, 0, 0, pad)) + return hidden_states.reshape(batch_size, num_blocks, self.chunk_size, num_heads, head_dim).contiguous() + + def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Extracts overlapping context windows of `context_size` for every block, strided by `chunk_size`.""" + batch_size, seq_len, num_heads, head_dim = hidden_states.shape + hidden_states = F.pad( + hidden_states, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1) + ) + hidden_states = hidden_states.unfold(1, self.context_size, self.chunk_size) + hidden_states = 
torch.movedim(hidden_states, -1, 2) + return hidden_states.contiguous() + + def _rel_shift(self, x: torch.Tensor) -> torch.Tensor: + """Relative position shift for blocked attention. See appendix B of https://huggingface.co/papers/1901.02860.""" + batch_size, num_heads, num_blocks, block_size, position_length = x.shape + context_size = self.context_size + x = F.pad(x, (0, context_size + 1 - position_length)) + x = x.view(batch_size, num_heads, num_blocks, block_size * (context_size + 1)) + x = x[..., : block_size * context_size] + return x.view(batch_size, num_heads, num_blocks, block_size, context_size) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: torch.BoolTensor | None = None, + ) -> tuple[torch.Tensor, None]: + batch_size, seq_length, _ = hidden_states.shape + hidden_shape = (batch_size, seq_length, self.num_heads, self.head_dim) + + query_states = self.q_proj(hidden_states).float().view(hidden_shape) + key_states = self.k_proj(hidden_states).float().view(hidden_shape) + value_states = self.v_proj(hidden_states).float().view(hidden_shape) + + query_states = query_states * self.q_scale * F.softplus(self.per_dim_scale) + key_states = key_states * self.k_scale + + query_states = self._convert_to_block(query_states) + key_states = self._extract_block_context(key_states) + value_states = self._extract_block_context(value_states) + num_blocks = query_states.shape[1] + + relative_key_states = self.relative_k_proj(position_embeddings) + relative_key_states = relative_key_states.view(-1, self.num_heads, self.head_dim) + relative_key_states = relative_key_states.to(dtype=query_states.dtype) + + queries = query_states.permute(0, 3, 1, 2, 4) + matrix_ac = queries @ key_states.permute(0, 3, 1, 4, 2) + + queries_flat = queries.reshape(batch_size, self.num_heads, -1, self.head_dim) + matrix_bd = queries_flat @ relative_key_states.permute(1, 2, 0) + matrix_bd = matrix_bd.reshape(batch_size, self.num_heads, num_blocks, self.chunk_size, -1) + matrix_bd = self._rel_shift(matrix_bd) + + attn_weights = matrix_ac + matrix_bd + attn_weights = attn_weights / self.softcap + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * self.softcap + + if attention_mask is not None: + attn_weights = attn_weights.masked_fill( + attention_mask.logical_not(), self.config.attention_invalid_logits_value + ) + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) + attn_output = attn_weights @ value_states.permute(0, 3, 1, 2, 4) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, num_blocks * self.chunk_size, -1) + attn_output = attn_output[:, :seq_length].contiguous() + attn_output = self.post(attn_output.to(dtype=self.post.linear.weight.dtype)) + + return attn_output, attn_weights + + +class Gemma4AudioSubSampleConvProjectionLayer(nn.Module): + def __init__(self, in_channels, out_channels, norm_eps): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(2, 2), + padding=1, + bias=False, + ) + self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False) + self.act = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor | None = None): + if mask is not None: + mask = mask.to(device=hidden_states.device) + hidden_states = hidden_states * mask[:, None, :, None] + + hidden_states = self.conv(hidden_states.to(self.conv.weight.dtype)) + hidden_states = 
self.act(self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()) + + if mask is not None: + mask = mask[:, ::2] + + return hidden_states, mask + + +class Gemma4AudioSubSampleConvProjection(nn.Module): + def __init__(self, config: Gemma4AudioConfig): + super().__init__() + self.layer0 = Gemma4AudioSubSampleConvProjectionLayer( + in_channels=1, + out_channels=config.subsampling_conv_channels[0], + norm_eps=config.rms_norm_eps, + ) + self.layer1 = Gemma4AudioSubSampleConvProjectionLayer( + in_channels=config.subsampling_conv_channels[0], + out_channels=config.subsampling_conv_channels[1], + norm_eps=config.rms_norm_eps, + ) + proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1] + self.input_proj_linear = nn.Linear(proj_input_dim, config.hidden_size, bias=False) + + def forward( + self, + input_features: torch.Tensor, + input_features_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states = input_features.unsqueeze(1) + hidden_states, mask = self.layer0(hidden_states, input_features_mask) + hidden_states, mask = self.layer1(hidden_states, mask) + + batch_size, _, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1) + return self.input_proj_linear(hidden_states), mask + + +class Gemma4AudioFeedForward(nn.Module): + def __init__(self, config: Gemma4AudioConfig): + super().__init__() + self.config = config + + self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4) + self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size) + + self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size) + self.post_layer_norm = Gemma4RMSNorm(config.hidden_size) + self.act_fn = ACT2FN[config.hidden_act] + + self.gradient_clipping = config.gradient_clipping + self.post_layer_scale = config.residual_weight + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # This is needed to avoid any underflow/overflow issues when clipping + gradient_clipping = min(self.gradient_clipping, torch.finfo(self.ffw_layer_1.linear.weight.dtype).max) + + residual = hidden_states + hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping) + hidden_states = self.pre_layer_norm(hidden_states) + + hidden_states = self.ffw_layer_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.ffw_layer_2(hidden_states) + + hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping) + hidden_states = self.post_layer_norm(hidden_states) + hidden_states *= self.post_layer_scale + hidden_states += residual + + return hidden_states + + +# TODO: this could be imported from Voxtral realtime +class Gemma4AudioCausalConv1d(nn.Conv1d): + # def __init__( + # self, + # in_channels: int, + # out_channels: int, + # kernel_size: int, + # # cache_key: str, + # stride: int = 1, + # dilation: int = 1, + # bias: bool = True, + # ): + # super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias) + # self.cache_key = cache_key + + @cached_property + def left_pad(self): + effective_kernel_size = (self.kernel_size[0] - 1) * self.dilation[0] + 1 + return effective_kernel_size - self.stride[0] + + def forward( + self, + x: torch.Tensor, + # padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, # TODO: we might want to add a cache? 
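+        # NOTE: without a cache, `forward` simply left-pads by `self.left_pad` zeros on every call,
+        # which keeps the convolution causal (outputs never depend on future frames).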
+ ) -> torch.Tensor: + # if padding_cache is not None: + # x = padding_cache.update(x, self.cache_key, self) + # else: + # x = nn.functional.pad(x, (self.left_pad, 0)) + x = nn.functional.pad(x, (self.left_pad, 0)) + + return super().forward(x) + + +class Gemma4AudioLightConv1d(nn.Module): + def __init__(self, config: Gemma4AudioConfig): + super().__init__() + self.config = config + + self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2) + self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size) + self.depthwise_conv1d = Gemma4AudioCausalConv1d( + in_channels=config.hidden_size, + out_channels=config.hidden_size, + kernel_size=config.conv_kernel_size, + groups=config.hidden_size, + bias=False, + ) + + self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps, with_scale=True) + self.conv_norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps, with_scale=True) + self.act_fn = ACT2FN[config.hidden_act] + + self.gradient_clipping = config.gradient_clipping + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.pre_layer_norm(hidden_states) + hidden_states = self.linear_start(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=-1) + + hidden_states = self.depthwise_conv1d(hidden_states.transpose(1, 2)).transpose(1, 2) + + # This is needed to avoid any underflow/overflow issues when clipping + gradient_clipping = min(self.gradient_clipping, torch.finfo(self.linear_start.linear.weight.dtype).max) + hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping) + hidden_states = self.conv_norm(hidden_states) + + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_end(hidden_states) + hidden_states += residual + return hidden_states + + +class Gemma4AudioLayer(nn.Module): + def __init__(self, config: Gemma4AudioConfig, layer_idx: int): + super().__init__() + self.config = config + + self.feed_forward1 = Gemma4AudioFeedForward(config) + self.feed_forward2 = Gemma4AudioFeedForward(config) + self.self_attn = Gemma4AudioAttention(config, layer_idx) + self.lconv1d = Gemma4AudioLightConv1d(config) + + self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size) + self.norm_post_attn = Gemma4RMSNorm(config.hidden_size) + self.norm_out = Gemma4RMSNorm(config.hidden_size) + + self.gradient_clipping = config.gradient_clipping + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.BoolTensor | None, + position_embeddings: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + # This is needed to avoid any underflow/overflow issues when clipping + gradient_clipping = min(self.gradient_clipping, torch.finfo(self.norm_pre_attn.weight.dtype).max) + + hidden_states = self.feed_forward1(hidden_states) + residual = hidden_states + + hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping) + hidden_states = self.norm_pre_attn(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + ) + + hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping) + hidden_states = self.norm_post_attn(hidden_states) + hidden_states += residual + + hidden_states = self.lconv1d(hidden_states) + hidden_states = self.feed_forward2(hidden_states) + + hidden_states = torch.clamp(hidden_states, -gradient_clipping, 
gradient_clipping)
+        hidden_states = self.norm_out(hidden_states)
+
+        return hidden_states
+
+
+# ---- Vision Encoder Layers ----
+
+
+class Gemma4VisionPatchEmbedder(nn.Module):
+    def __init__(self, config: Gemma4VisionConfig):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.patch_size = config.patch_size
+        self.position_embedding_size = config.position_embedding_size
+
+        self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
+        self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))
+
+    def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
+        """Prepare patch positions map for matmul with position embedding table."""
+        # Expand and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul.
+        clamped_positions = pixel_position_ids.clamp(min=0)
+        one_hot = F.one_hot(clamped_positions, num_classes=self.position_embedding_size)
+        one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table)
+        # Compute positional embeddings and sum across x and y.
+        position_embeddings = one_hot @ self.position_embedding_table
+        position_embeddings = position_embeddings.sum(dim=1)
+        # Zero out embeddings for any padding patches.
+        position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings)
+        return position_embeddings
+
+    def forward(
+        self, pixel_values: torch.Tensor, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor
+    ) -> torch.Tensor:
+        # Gemma4 applies no normalization and instead scales in model code
+        pixel_values = 2 * (pixel_values - 0.5)
+        hidden_states = self.input_proj(pixel_values.to(self.input_proj.weight.dtype))
+        position_embeddings = self._position_embeddings(pixel_position_ids, padding_positions)
+        return hidden_states + position_embeddings
+
+
+class Gemma4VisionPooler(nn.Module):
+    """Scaling and optional spatial pooling for vision encodings"""
+
+    def __init__(self, config: Gemma4VisionConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.root_hidden_size = self.hidden_size**0.5
+
+    def _avg_pool_by_positions(
+        self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        2D spatial pooling according to patch positions.
+        Pools the input tokens by averaging patches within a `k x k` grid, where `k` is determined by the ratio between
+        input and output lengths.
+        """
+        input_seq_len = hidden_states.shape[1]
+        k = int((input_seq_len // length) ** 0.5)
+        k_squared = k**2
+        if k_squared * length != input_seq_len:
+            raise ValueError(
+                f"Cannot pool {hidden_states.shape} to {length}: {k=}^2 times {length=} must be {input_seq_len}."
+            )
+
+        # Clamp padding positions (which are -1) to 0 so they don't break one_hot.
+        # Padding patches have zero hidden states so they contribute nothing to the average.
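+        # Example (illustrative): with input_seq_len=2520 padded patches and length=280, k = 3 and each
+        # output token is the average of a 3x3 neighborhood of patches (hence the division by k**2 = 9 below).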
+ clamped_positions = pixel_position_ids.clamp(min=0) + max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1 + kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor") + kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1] + weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared + output = weights.transpose(1, 2) @ hidden_states.float() + mask = torch.logical_not((weights == 0).all(dim=1)) + return output.to(hidden_states.dtype), mask + + def forward( + self, + hidden_states: torch.Tensor, + pixel_position_ids: torch.Tensor, + padding_positions: torch.Tensor, + output_length: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if output_length > hidden_states.shape[1]: + raise ValueError( + f"Cannot output more soft tokens (requested {output_length}) than there are patches" + f" ({hidden_states.shape[1]}). Change the value of `num_soft_tokens` when processing." + ) + + hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0.0) + + if hidden_states.shape[1] != output_length: + hidden_states, padding_positions = self._avg_pool_by_positions( + hidden_states, pixel_position_ids, output_length + ) + + hidden_states *= self.root_hidden_size + return hidden_states, padding_positions + + +class Gemma4VisionMLP(nn.Module): + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size) + self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size) + self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Gemma4VisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Gemma4VisionConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: Gemma4VisionConfig | None = None, + device: torch.device | None = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
+ """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + # The reference implementation computes RoPE frequencies INDEPENDENTLY + # for each spatial dimension using the partitioned head_dim (head_dim // ndim), + # so both x and y dimensions get identical frequency ranges. + # This is different from splitting the global inv_freq between dimensions. + spatial_dim = dim // 2 + + attention_factor = 1.0 # Unused in this type of RoPE + inv_freq = 1.0 / ( + base + ** (torch.arange(0, spatial_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / spatial_dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + + # Multidimensional positions: [batch, num_patches, ndim]. Apply rotations to each spatial dim separately + all_cos, all_sin = [], [] + for i in range(2): + dim_position_ids = position_ids[:, :, i] + dim_position_ids_expanded = dim_position_ids[:, None, :].float() + + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ dim_position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + all_cos.append(cos) + all_sin.append(sin) + + cos = torch.cat(all_cos, dim=-1).to(dtype=x.dtype) + sin = torch.cat(all_sin, dim=-1).to(dtype=x.dtype) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + x (`torch.Tensor`): The tensor to embed. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + return (x * cos) + (rotate_half(x) * sin) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + dropout: float | int = 0.0, + scaling: float | None = None, + softcap: float | None = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * softcap + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def apply_multidimensional_rope( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + unsqueeze_dim: int = 2, +) -> torch.Tensor: + """Applies multidimensional RoPE to inputs. + + Args: + x (`torch.Tensor`): The tensor to embed. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + If position_ids.ndim + 2 == x.ndim, then this function passes through to `apply_rotary_pos_emb()`. + Otherwise, position_ids is used to split the inputs, x, into multiple pieces, where each piece is fed to + `apply_rotary_pos_emb()`, and then concatenated back together. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + + Returns: + Tensor of shape [B, L, N, H] with RoPE applied. 
+ """ + ndim = position_ids.shape[-1] + num_input_channels = x.shape[-1] + num_rotated_channels_per_dim = 2 * (num_input_channels // (2 * ndim)) + + if num_rotated_channels_per_dim <= 0: + raise ValueError( + "Invalid configuration: num_rotated_channels_per_dim must be > 0, got" + f" {num_rotated_channels_per_dim} (num_input_channels={num_input_channels}," + f" ndim={ndim})" + ) + + # Correctly split the input tensor into ndim parts + split_sizes = [num_rotated_channels_per_dim] * ndim + x_parts = torch.split(x, split_sizes, dim=-1) + cos_parts = torch.split(cos, split_sizes, dim=-1) + sin_parts = torch.split(sin, split_sizes, dim=-1) + y_parts = [ + apply_rotary_pos_emb( + x=x_parts[k], + cos=cos_parts[k], + sin=sin_parts[k], + unsqueeze_dim=unsqueeze_dim, + ) + for k in range(ndim) + ] + return torch.cat(y_parts, dim=-1) + + +@use_kernelized_func(apply_rotary_pos_emb) +class Gemma4VisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Gemma4VisionConfig, layer_idx: int): + super().__init__() + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = 1.0 + self.attention_dropout = self.config.attention_dropout + self.is_causal = False + self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_attention_heads * self.head_dim) + self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim) + self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim) + self.o_proj = Gemma4ClippableLinear(config, config.num_attention_heads * self.head_dim, config.hidden_size) + + self.q_norm = Gemma4RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma4RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + cos, sin = position_embeddings + + query_states = self.q_proj(hidden_states).view(hidden_shape) + query_states = self.q_norm(query_states) + query_states = apply_multidimensional_rope(query_states, cos, sin, position_ids) + query_states = query_states.transpose(1, 2) + + key_states = self.k_proj(hidden_states).view(hidden_shape) + key_states = self.k_norm(key_states) + key_states = apply_multidimensional_rope(key_states, cos, sin, position_ids) + key_states = key_states.transpose(1, 2) + + value_states = self.v_proj(hidden_states).view(hidden_shape) + value_states = self.v_norm(value_states) + value_states = value_states.transpose(1, 2) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if 
self.training else 0.0, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Gemma4VisionEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Gemma4VisionConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx) + self.mlp = Gemma4VisionMLP(config) + self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Gemma4VisionEncoder(nn.Module): + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.config = config + self.num_layers = config.num_hidden_layers + self.rotary_emb = Gemma4VisionRotaryEmbedding(config) + self.layers = nn.ModuleList( + [Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)] + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor, + pixel_position_ids: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + r""" + pixel_position_ids (torch.Tensor): + Patch positions as (x, y) coordinates in the image as [batch, num_patches, 2]. 
+ """ + attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + ) + + # embed positions + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, pixel_position_ids) + + # decoder layers + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + position_ids=pixel_position_ids, + **kwargs, + ) + + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + +class Gemma4TextMLP(nn.Module): + def __init__(self, config: Gemma4TextConfig, layer_idx: int): + super().__init__() + first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers + is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Gemma4TextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.layer_types = set(config.layer_types) + self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {} + self.rope_type: dict[str, str] = {} + + for layer_type in self.layer_types: + rope_params = self.config.rope_parameters[layer_type] + if rope_params is None: + continue + + if (rope_type := rope_params["rope_type"]) != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + else: + rope_init_fn = self.compute_default_rope_parameters + + self.rope_init_fns[layer_type] = rope_init_fn + self.rope_type[layer_type] = rope_type + + rope_init_fn_kwargs = {"device": device, "layer_type": layer_type} + if layer_type == "full_attention" and rope_type == "proportional": + rope_init_fn_kwargs["head_dim_key"] = "global_head_dim" + + curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, **rope_init_fn_kwargs) + self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False) + self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False) + setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling) + + @staticmethod + def compute_default_rope_parameters( + config: Gemma4TextConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + layer_type: str | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. 
Unused for this type of RoPE. + layer_type (`str`, *optional*): + The current layer type if the model has different RoPE parameters per type. + Should not be used unless `config.layer_types is not None` + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + # For backward compatibility standardize the `rope_parameters_dict` if it uses old format + base = config.rope_parameters[layer_type]["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids, layer_type=None): + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * attention_scaling + sin = emb.sin() * attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@use_kernelized_func(apply_rotary_pos_emb) +class Gemma4TextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Gemma4TextConfig, layer_idx: int): + super().__init__() + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None + self.config = config + self.layer_idx = layer_idx + self.is_sliding = self.layer_type == "sliding_attention" + self.sliding_window = config.sliding_window if self.is_sliding else None + + self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim + self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding + num_key_value_heads = ( + config.num_global_key_value_heads if self.use_alternative_attention else config.num_key_value_heads + ) + self.num_key_value_groups = config.num_attention_heads // num_key_value_heads + self.scaling = 1.0 + self.attention_dropout = self.config.attention_dropout + self.is_causal = config.use_bidirectional_attention != "all" + + # Shared kv cache + first_kv_shared_layer_idx = self.config.num_hidden_layers - getattr(self.config, "num_kv_shared_layers", 0) + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + prev_layers = config.layer_types[:first_kv_shared_layer_idx] + if self.is_kv_shared_layer: + # For shared layers, find the last non-shared layer of the same type before sharing starts + self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx]) + self.store_full_length_kv = False + else: + self.kv_shared_layer_index = None + # For non-shared layers, store full-length kv if this is the last non-shared layer 
of its type + self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index( + config.layer_types[layer_idx] + ) + + self.q_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps) + self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False) + + self.k_proj = nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = ( + nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias) + if not self.use_alternative_attention + else None + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + cos, sin = position_embeddings + + query_states = self.q_proj(hidden_states).view(hidden_shape) + query_states = self.q_norm(query_states) + query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) + query_states = query_states.transpose(1, 2) + + # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer + if self.is_kv_shared_layer and past_key_values is not None: + key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index] + # Device of past layer may be different from current one + key_states = key_states.to(query_states.device) + value_states = value_states.to(query_states.device) + else: + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states + + key_states = self.k_norm(key_states) + key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) + key_states = key_states.transpose(1, 2) + + value_states = self.v_norm(value_states) + value_states = value_states.transpose(1, 2) + + if past_key_values is not None: + if not self.is_kv_shared_layer: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + if self.store_full_length_kv: + if not hasattr(past_key_values, "shared_layers"): + past_key_values.shared_layers = {} + past_key_values.shared_layers[self.layer_idx] = key_states, value_states + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_experts_implementation +class Gemma4TextExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config: Gemma4TextConfig): + 
super().__init__() + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + +class Gemma4TextRouter(nn.Module): + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.scalar_root_size = self.hidden_size**-0.5 + self.eps = config.rms_norm_eps + + self.norm = Gemma4RMSNorm(self.hidden_size, eps=self.eps, with_scale=False) + self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.scale = nn.Parameter(torch.ones(self.hidden_size)) + self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts)) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states * self.scale * self.scalar_root_size + + expert_scores = self.proj(hidden_states) # [B*S, E] + router_probabilities = nn.functional.softmax(expert_scores, dim=-1) + + # topk returns both values (probabilities) and indices directly + top_k_weights, top_k_index = torch.topk( + router_probabilities, + k=self.config.top_k_experts, + dim=-1, + ) # both [B*S, K] + + # Normalize the top-k weights so they sum to 1 per token + top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) + + # Apply per-expert scale directly to the weights + top_k_weights = top_k_weights * self.per_expert_scale[top_k_index] + + return router_probabilities, top_k_weights, top_k_index + + +class Gemma4TextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx) + self.mlp = Gemma4TextMLP(config, layer_idx) + self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + 
self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.register_buffer("layer_scalar", torch.ones(1)) + + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + if self.hidden_size_per_layer_input: + self.act_fn = ACT2FN[config.hidden_activation] + self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False) + self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False) + self.post_per_layer_input_norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + self.enable_moe_block = config.enable_moe_block + if self.enable_moe_block: + self.router = Gemma4TextRouter(config) + self.experts = Gemma4TextExperts(config) + self.post_feedforward_layernorm_1 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm_2 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm_2 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + per_layer_input: torch.Tensor = None, + position_embeddings: torch.Tensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + if self.enable_moe_block: + hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states) + + # Take hidden states before MLP here + hidden_states_flat = residual.reshape(-1, residual.shape[-1]) + _, top_k_weights, top_k_index = self.router(hidden_states_flat) + hidden_states_2 = self.pre_feedforward_layernorm_2(hidden_states_flat) + hidden_states_2 = self.experts(hidden_states_2, top_k_index, top_k_weights) + hidden_states_2 = hidden_states_2.reshape(residual.shape) + hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2) + + # Combine mlp and moe outputs + hidden_states = hidden_states_1 + hidden_states_2 + + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + if self.hidden_size_per_layer_input: + residual = hidden_states + hidden_states = self.per_layer_input_gate(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = hidden_states * per_layer_input + hidden_states = self.per_layer_projection(hidden_states) + hidden_states = self.post_per_layer_input_norm(hidden_states) + hidden_states = residual + hidden_states + + hidden_states *= self.layer_scalar + return hidden_states + + +class Gemma4TextScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. 
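+
+    For example, `Gemma4TextModel` instantiates this embedding with `embed_scale = hidden_size**0.5`, so with
+    `hidden_size = 3072` every token embedding is multiplied by roughly 55.4 before entering the decoder stack.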
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.scalar_embed_scale = embed_scale + self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) + + +# ---- Model Classes ---- + + +class Gemma4PreTrainedModel(PreTrainedModel): + config: Gemma4Config + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = True + _supports_attention_backend = True + _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"] + _skip_keys_device_placement = ["past_key_values"] + input_modalities = ("image", "text", "video", "audio") + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Gemma4VisionPatchEmbedder): + init.ones_(module.position_embedding_table) + elif isinstance(module, Gemma4AudioRelPositionalEncoding): + min_timescale = 1.0 + max_timescale = 10000.0 + num_timescales = module.hidden_size // 2 + log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1) + inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment) + init.copy_(module.inv_timescales, inv_timescales.unsqueeze(0).unsqueeze(0)) + elif isinstance(module, Gemma4AudioAttention): + init.constant_(module.softcap, module.attention_logits_soft_cap) + init.zeros_(module.per_dim_scale) + elif isinstance(module, Gemma4TextRotaryEmbedding): + for layer_type, rope_init_fn in module.rope_init_fns.items(): + rope_init_fn_kwargs = {"layer_type": layer_type} + if layer_type == "full_attention" and module.rope_type[layer_type] == "proportional": + rope_init_fn_kwargs["head_dim_key"] = "global_head_dim" + + curr_inv_freq, _ = rope_init_fn(module.config, **rope_init_fn_kwargs) + init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq) + init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq) + elif isinstance(module, Gemma4VisionRotaryEmbedding): + rope_fn = ( + ROPE_INIT_FUNCTIONS[module.rope_type] + if module.rope_type != "default" + else module.compute_default_rope_parameters + ) + buffer_value, _ = rope_fn(module.config) + init.copy_(module.inv_freq, buffer_value) + init.copy_(module.original_inv_freq, buffer_value) + elif isinstance(module, Gemma4TextScaledWordEmbedding): + init.constant_(module.embed_scale, module.scalar_embed_scale) + elif isinstance(module, Gemma4TextRouter): + init.ones_(module.scale) + init.ones_(module.per_expert_scale) + elif isinstance(module, Gemma4TextExperts): + std = self.config.initializer_range + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) + elif isinstance(module, Gemma4TextDecoderLayer): + init.ones_(module.layer_scalar) + elif isinstance(module, Gemma4ClippableLinear) and module.use_clipped_linears: + init.constant_(module.input_min, -float("inf")) + init.constant_(module.input_max, float("inf")) + init.constant_(module.output_min, -float("inf")) + init.constant_(module.output_max, float("inf")) + elif isinstance(module, Gemma4VisionModel) and module.config.standardize: + init.zeros_(module.std_bias) + init.ones_(module.std_scale) + + +@auto_docstring(custom_intro="The base Gemma 4 
language model without a language modeling head.") +class Gemma4TextModel(Gemma4PreTrainedModel): + config: Gemma4TextConfig + input_modalities = ("text",) + _can_record_outputs = { + "router_logits": OutputRecorder(Gemma4TextRouter, index=0), + "hidden_states": Gemma4TextDecoderLayer, + "attentions": Gemma4TextAttention, + } + + def __init__(self, config: Gemma4TextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Gemma4 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402 + self.embed_tokens = Gemma4TextScaledWordEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 + ) + self.layers = nn.ModuleList( + [Gemma4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Gemma4TextRotaryEmbedding(config) + self.gradient_checkpointing = False + self.unique_layer_types = set(self.config.layer_types) + + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + if self.hidden_size_per_layer_input: + self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding( + config.vocab_size_per_layer_input, + config.num_hidden_layers * config.hidden_size_per_layer_input, + self.padding_idx, + embed_scale=config.hidden_size_per_layer_input**0.5, + ) + self.per_layer_input_scale = 2.0**-0.5 + self.per_layer_model_projection = nn.Linear( + config.hidden_size, + config.num_hidden_layers * config.hidden_size_per_layer_input, + bias=False, + ) + self.per_layer_model_projection_scale = config.hidden_size**-0.5 + self.per_layer_projection_norm = Gemma4RMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + per_layer_inputs: torch.Tensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + r""" + per_layer_inputs (`torch.Tensor` of shape `(batch_size, sequence_length, num_hidden_layers, hidden_size_per_layer_input)`, *optional*): + Pre-computed per-layer input embeddings. When provided, these are used directly instead of being + computed from `input_ids` via `get_per_layer_inputs()`. This is primarily used by the multimodal + model (`Gemma4Model`) which pre-computes per-layer inputs from the original `input_ids` *before* + merging multimodal soft tokens into `inputs_embeds` — at which point the original token ids are + no longer recoverable. 
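+
+        Example (illustrative sketch of a plain text-only forward pass; the checkpoint id below is a placeholder,
+        not a released model):
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, Gemma4TextModel
+
+        >>> model = Gemma4TextModel.from_pretrained("google/gemma-4-text-checkpoint", dtype=torch.bfloat16)
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-4-text-checkpoint")
+
+        >>> inputs = tokenizer("Plants create energy through a process known as", return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+        >>> outputs.last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)
+        ```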
+ """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if input_ids is not None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.hidden_size_per_layer_input: + if per_layer_inputs is None: + per_layer_inputs = self.get_per_layer_inputs(input_ids, inputs_embeds) + per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + # embed positions + hidden_states = inputs_embeds + position_embeddings = {} + for layer_type in self.unique_layer_types: + position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) + + # decoder layers + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + per_layer_input = per_layer_inputs[:, :, i, :] if per_layer_inputs is not None else None + + hidden_states = decoder_layer( + hidden_states, + per_layer_input, + position_embeddings=position_embeddings[self.config.layer_types[i]], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], + position_ids=position_ids, + past_key_values=past_key_values, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def get_per_layer_inputs(self, input_ids: torch.Tensor | None, inputs_embeds: torch.Tensor | None) -> torch.Tensor: + if not self.hidden_size_per_layer_input: + raise RuntimeError( + "Attempting to call get_per_layer_inputs() from a model initialized with a config that does not support" + f" per-layer embeddings. {self.config}" + ) + + # If only inputs_embeds are provided, reverse main embedding to find the input_ids - this allows to `generate` + # from `inputs_embeds` only as other models (otherwise it would need the value from both embeddings) + if input_ids is None: + with torch.no_grad(): + input_ids = ( + ( + inputs_embeds[:, :, None, :] + == self.embed_tokens.weight[None, None, :, :] * self.config.hidden_size**0.5 + ) + .all(dim=3) + .nonzero()[:, 2] + ) + try: + input_ids = input_ids.view(inputs_embeds.shape[:2]) + except RuntimeError: + raise RuntimeError( + "It seems like you tried to call `forward` from `inputs_embeds` without providing `input_ids`, and that " + "the `inputs_embeds` you provided do not exactly match the embedding weights. 
Since Gemma4 needs to reverse " + "the embedding to compute another embedding, make sure you provide exact `inputs_embeds`" + ) + + return self.embed_tokens_per_layer(input_ids).reshape( + *input_ids.shape, + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: torch.Tensor | None = None, + ) -> torch.Tensor: + if not self.hidden_size_per_layer_input: + raise RuntimeError( + "Attempting to call project_per_layer_inputs() from a model initialized with a config that does not" + f" support per-layer embeddings. {self.config}" + ) + + per_layer_projection = self.per_layer_model_projection(inputs_embeds) * self.per_layer_model_projection_scale + per_layer_projection = per_layer_projection.reshape( + *inputs_embeds.shape[:-1], + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + + if per_layer_inputs is None: + return per_layer_projection + + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale + + +@auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.") +class Gemma4ForCausalLM(Gemma4PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config: Gemma4TextConfig + base_model_prefix = "model" + + def __init__(self, config: Gemma4TextConfig): + super().__init__(config) + self.model = Gemma4TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, Gemma4ForCausalLM + + >>> model = Gemma4ForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" 
+ ```""" + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def sliding_window_mask_function(sliding_window: tuple[int, int]) -> Callable: + """ + This creates uni/bidirectional attention mask with sliding window. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + left_window_size, right_window_size = sliding_window + + dist = q_idx - kv_idx + left_mask = (dist >= 0) & (dist < left_window_size) + right_mask = (dist < 0) & (-dist < right_window_size) + return left_mask | right_mask + + return inner_mask + + +class Gemma4AudioModel(Gemma4PreTrainedModel): + """An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture.""" + + config: Gemma4AudioConfig + main_input_name = "input_features" + base_model_prefix = "model.audio_tower" # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained() + _can_record_outputs = { + "hidden_states": Gemma4AudioLayer, + "attentions": Gemma4AudioAttention, + } + + def __init__(self, config: Gemma4AudioConfig): + super().__init__(config) + self.config = config + + self.subsample_conv_projection = Gemma4AudioSubSampleConvProjection(config) + self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config) + self.layers = nn.ModuleList( + [Gemma4AudioLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.output_proj = nn.Linear(config.hidden_size, config.output_proj_dims, bias=True) + + self.post_init() + + def _convert_4d_mask_to_blocked_5d(self, mask_4d: torch.Tensor) -> torch.Tensor: + """ + Convert a standard 4D attention mask `[batch_size, 1, seq_len, seq_len]` to the 5D blocked format + `[batch_size, 1, num_blocks, chunk_size, context_size]` expected by the chunked local attention, + """ + batch_size, _, seq_len, _ = mask_4d.shape + device = mask_4d.device + + chunk_size = self.config.attention_chunk_size + max_past_horizon = self.config.attention_context_left - 1 + max_future_horizon = self.config.attention_context_right + + num_blocks = (seq_len + chunk_size - 1) // chunk_size + padded_seq_len = num_blocks * chunk_size + pad_amount = padded_seq_len - seq_len + + mask_4d = F.pad(mask_4d, (0, pad_amount, 0, pad_amount), value=False) + mask_5d = mask_4d.reshape(batch_size, 1, num_blocks, chunk_size, padded_seq_len) + mask_5d = F.pad(mask_5d, (max_past_horizon, max_future_horizon), value=False) + + block_starts = torch.arange(num_blocks, 
device=device) * chunk_size + offsets = torch.arange(chunk_size + max_past_horizon + max_future_horizon, device=device) + kv_indices = block_starts[:, None] + offsets[None, :] + kv_indices = kv_indices[None, None, :, None, :].expand(batch_size, 1, -1, chunk_size, -1) + + return mask_5d.gather(-1, kv_indices) + + @merge_with_config_defaults + @capture_outputs + @auto_docstring(custom_intro="Encodes audio features to soft tokens.") + def forward( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.BoolTensor]: + hidden_states, output_mask = self.subsample_conv_projection(input_features, attention_mask) + position_embeddings = self.rel_pos_enc(hidden_states) + + attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=output_mask, + and_mask_function=sliding_window_mask_function( + (self.config.attention_context_left - 1, self.config.attention_context_right) + ), + ) + attention_mask = self._convert_4d_mask_to_blocked_5d(attention_mask) + + for encoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.output_proj(hidden_states) + return Gemma4AudioModelOutput(last_hidden_state=hidden_states, attention_mask=output_mask) + + +class Gemma4VisionModel(Gemma4PreTrainedModel): + """The Gemma 4 Vision Encoder.""" + + config = Gemma4VisionConfig + _can_record_outputs = { + "hidden_states": Gemma4VisionEncoderLayer, + "attentions": Gemma4VisionAttention, + } + + def __init__(self, config: Gemma4VisionConfig): + super().__init__(config) + self.patch_embedder = Gemma4VisionPatchEmbedder(config) + self.encoder = Gemma4VisionEncoder(config) + self.pooler = Gemma4VisionPooler(config) + + if self.config.standardize: + self.register_buffer("std_bias", torch.empty(self.config.hidden_size)) + self.register_buffer("std_scale", torch.empty(self.config.hidden_size)) + + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring(custom_intro="Encodes image pixels to soft tokens from patches.") + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_position_ids: torch.LongTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + r""" + pixel_values (`torch.FloatTensor` or `list[torch.FloatTensor]`): + The images to encode. Either a single `[batch, channels, height, width]` tensor + (all images same size) or a list of `[1, channels, height, width]` tensors (different sizes). + pixel_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`): + The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1). 
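+            For example, an image contributing two valid patches followed by one padding slot would be encoded
+            as `[[0, 0], [1, 0], [-1, -1]]`.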
+ """ + pooling_kernel_size = self.config.pooling_kernel_size + output_length = pixel_values.shape[-2] // (pooling_kernel_size * pooling_kernel_size) + + padding_positions = (pixel_position_ids == -1).all(dim=-1) + inputs_embeds = self.patch_embedder(pixel_values, pixel_position_ids, padding_positions) + output = self.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~padding_positions, # encoder expects True=valid, padding_positions is True=padding + pixel_position_ids=pixel_position_ids, + **kwargs, + ) + + hidden_states, pooler_mask = self.pooler( + hidden_states=output.last_hidden_state, + pixel_position_ids=pixel_position_ids, + padding_positions=padding_positions, + output_length=output_length, + ) + + # Strip padding tokens. pooler_mask is True = valid, False = padding. + hidden_states = hidden_states[pooler_mask] + + if self.config.standardize: + hidden_states = (hidden_states - self.std_bias) * self.std_scale + + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + +class Gemma4MultimodalEmbedder(nn.Module): + """Embeds token ids or soft tokens for multimodal content into language model space.""" + + def __init__( + self, + multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig, + text_config: Gemma4TextConfig, + ): + super().__init__() + + self.multimodal_hidden_size = getattr(multimodal_config, "output_proj_dims", multimodal_config.hidden_size) + self.eps = multimodal_config.rms_norm_eps + self.text_hidden_size = text_config.hidden_size + self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False) + self.embedding_pre_projection_norm = Gemma4RMSNorm(self.multimodal_hidden_size, eps=self.eps, with_scale=False) + + def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: + """Embeds token ids or soft tokens for multimodal content into language model space. + Args: + inputs_embeds: A torch.Tensor containing the soft tokens to embed. + Returns: + A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`. + """ + embs_normed = self.embedding_pre_projection_norm(inputs_embeds) + return self.embedding_projection(embs_normed) + + +# Identical as Gemma3 but modular can't resolve if we simply import. FIXME: @cyril +def token_type_ids_mask_function( + token_type_ids: torch.Tensor | None, + image_group_ids: torch.Tensor | None, +) -> Callable | None: + """ + This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, + not start and end indices. + """ + # Do not return an additional mask in this case + if token_type_ids is None: + return None + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + seq_length = image_group_ids.shape[-1] + + # clamp indices because with static cache they can go beyond `image_group_ids.shape[-1]` + q_idx_clamped = q_idx.clamp(max=seq_length - 1) + kv_idx_clamped = kv_idx.clamp(max=seq_length - 1) + + # Unmask if the q and kv come from same group which is not -1 (i.e. 
non-text) + q_group = image_group_ids[batch_idx, q_idx_clamped] + kv_group = image_group_ids[batch_idx, kv_idx_clamped] + q_group = torch.where(q_idx < seq_length, q_group, -1) + kv_group = torch.where(kv_idx < seq_length, kv_group, -1) + return (q_group == kv_group) & (q_group >= 0) + + return inner_mask + + +# Similar to Gemma3 but `sliding_mask_kwargs` and `mask_kwargs` are different and `token_type_ids->mm_token_type_ids` +def create_causal_mask_mapping( + config: PreTrainedConfig, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None, + position_ids: torch.Tensor | None, + mm_token_type_ids: torch.Tensor | None = None, + pixel_values: torch.FloatTensor | None = None, + is_training: bool = False, + is_first_iteration: bool | None = None, + **kwargs, +) -> dict: + """ + Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping + for all kinds of forward passes. Gemma4 uses a bidirectional mask for images. + + Uses `pixel_values` as an optional input to disambiguate edge cases. + """ + if is_training and mm_token_type_ids is None: + raise ValueError("`mm_token_type_ids` is required as a model input when training") + + mask_kwargs = { + "config": config.get_text_config(), + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + sliding_mask_kwargs = mask_kwargs.copy() + + # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized + # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other + # means). Determining prefill in that case requires checking data values, which is not compile-compatible. + is_first_iteration = ( + is_first_iteration + if is_first_iteration is not None + else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None) + ) + if mm_token_type_ids is not None and is_first_iteration: + # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to + # undo the causal masking) + + # First find where a new vision block starts. Vision tokens cannot attend to + # future vision tokens, but can attend to all prev tokens and to itself bidirectionally + is_vision = (mm_token_type_ids == 1) | (mm_token_type_ids == 2) + is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1) + is_prev_vision[..., 0] = False + new_vision_starts = is_vision & ~is_prev_vision + vision_group_ids = torch.cumsum(new_vision_starts.int(), dim=1) - 1 + vision_group_ids = torch.where(is_vision, vision_group_ids, -1) + sliding_mask_kwargs["or_mask_function"] = token_type_ids_mask_function( + mm_token_type_ids.to(inputs_embeds.device), vision_group_ids + ) + + return { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs), + } + + +@auto_docstring( + custom_intro=""" + The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a + language modeling head. 
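+    Image, video, and audio inputs are encoded into soft tokens and merged into the text token embeddings
+    before the language model decoder runs.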
+ """ +) +class Gemma4Model(Gemma4PreTrainedModel): + # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch + accepts_loss_kwargs = False + + def __init__(self, config: Gemma4Config): + super().__init__(config) + self.vocab_size = config.text_config.vocab_size + + language_model = AutoModel.from_config(config=config.text_config) + self.language_model = language_model + self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input + self.vision_tower = AutoModel.from_config(config.vision_config) if config.vision_config is not None else None + self.embed_vision = ( + Gemma4MultimodalEmbedder(config.vision_config, config.text_config) + if config.vision_config is not None + else None + ) + self.audio_tower = AutoModel.from_config(config.audio_config) if config.audio_config is not None else None + self.embed_audio = ( + Gemma4MultimodalEmbedder(config.audio_config, config.text_config) + if config.audio_config is not None + else None + ) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.") + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_position_ids: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*): + The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1). + """ + vision_outputs = self.vision_tower( + pixel_values=pixel_values, + pixel_position_ids=image_position_ids, + **kwargs, + ) + last_hidden_state = vision_outputs.last_hidden_state + vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state) + return vision_outputs + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + ) -> tuple[torch.BoolTensor, torch.BoolTensor, torch.BoolTensor]: + """ + Obtains mask for multimodal placeholders (replaced by soft tokens) and hard text tokens. + + Masks will be obtained from `mm_token_type_ids`, `input_ids`, or `inputs_embeds` as available and in that + precedence order. If passing `input_ids` or `inputs_embeds`, the image mask will be derived using + `config.image_token_id`. Same goes for audio and video masks + + Args: + input_ids: A tensor containing the hard token IDs from the text tokenizer. + inputs_embeds: A tensor containing the embeddings for all hard text tokens. 
+ + Returns: + image_mask, video_mask, audio_mask + """ + if input_ids is not None: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + special_audio_mask = input_ids == self.config.audio_token_id + else: + special_image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + ).all(-1) + special_video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + ).all(-1) + special_audio_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + ).all(-1) + + return special_image_mask, special_video_mask, special_audio_mask + + @merge_with_config_defaults + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + input_features_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + mm_token_type_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + image_position_ids: torch.LongTensor | None = None, + video_position_ids: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Gemma4ModelOutputWithPast: + r""" + input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`): + The attention mask for the input audio. + image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*): + 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding. + Passed through to the vision encoder for positional embedding computation. + video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*): + 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding. + Passed through to the vision encoder for positional embedding computation. 
+        """
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        image_mask, video_mask, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds)
+        multimodal_mask = image_mask | video_mask | audio_mask
+
+        # Replace multimodal placeholder ids with PAD if the special token is OOV, to avoid index errors
+        llm_input_ids = None
+        if inputs_embeds is None:
+            llm_input_ids = input_ids.clone()
+            llm_input_ids[multimodal_mask] = self.config.text_config.pad_token_id
+            inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+        if self.config.get_text_config().hidden_size_per_layer_input:
+            pad_embedding = self.language_model.embed_tokens.weight[self.config.text_config.pad_token_id, :]
+            llm_inputs_embeds = torch.where(multimodal_mask[..., None], pad_embedding.view(1, 1, -1), inputs_embeds)
+            per_layer_inputs = self.language_model.get_per_layer_inputs(llm_input_ids, llm_inputs_embeds)
+        else:
+            per_layer_inputs = None
+
+        # Merge text and images
+        if pixel_values is not None:
+            image_features = self.get_image_features(pixel_values, image_position_ids, return_dict=True).pooler_output
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+
+            # Confirm the number of soft tokens from the vision tower matches the number of slots in the embeddings.
+            n_image_tokens = image_mask.sum()
+            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            torch_compilable_check(
+                inputs_embeds[image_mask].numel() == image_features.numel(),
+                f"Image features and image tokens do not match, tokens: {n_image_tokens}, features:"
+                f" {image_features.shape[0]}",
+            )
+
+            inputs_embeds = inputs_embeds.masked_scatter(
+                image_mask.to(inputs_embeds.device), image_features.to(inputs_embeds.device)
+            )
+
+        if pixel_values_videos is not None:
+            video_features = self.get_video_features(
+                pixel_values_videos, video_position_ids, return_dict=True
+            ).pooler_output
+            video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
+
+            # Confirm the number of soft tokens from the vision tower matches the number of slots in the embeddings.
+            n_video_tokens = video_mask.sum()
+            video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            torch_compilable_check(
+                inputs_embeds[video_mask].numel() == video_features.numel(),
+                f"Video features and video tokens do not match, tokens: {n_video_tokens}, features:"
+                f" {video_features.shape[0]}",
+            )
+
+            inputs_embeds = inputs_embeds.masked_scatter(
+                video_mask.to(inputs_embeds.device), video_features.to(inputs_embeds.device)
+            )
+
+        # Merge text and audio
+        if input_features is not None and input_features_mask is not None:
+            audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True)
+            audio_features = audio_output.pooler_output
+            audio_mask_from_encoder = audio_output.attention_mask  # True = valid
+
+            # Strip padding tokens: only keep real (non-padding) audio soft tokens.
+            # audio_mask_from_encoder is True for valid positions, False for padding tokens.
+            # This mirrors the vision encoder's padding stripping (see Gemma4VisionEncoder.forward).
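+            # Boolean-mask indexing flattens (num_audio, num_frames, hidden_size) into
+            # (num_valid_frames, hidden_size), matching the flat layout `masked_scatter` expects below.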
+            audio_features = audio_features[audio_mask_from_encoder]
+
+            n_audio_tokens = audio_mask.sum()
+            audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            torch_compilable_check(
+                inputs_embeds[audio_mask].numel() == audio_features.numel(),
+                f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features:"
+                f" {audio_features.shape[0] * audio_features.shape[1]}",
+            )
+
+            inputs_embeds = inputs_embeds.masked_scatter(
+                audio_mask.to(inputs_embeds.device), audio_features.to(inputs_embeds.device)
+            )
+
+        # It may already have been prepared by, e.g., `generate`
+        if position_ids is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+            position_ids = position_ids.unsqueeze(0)
+
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            if self.config.get_text_config().use_bidirectional_attention == "vision":
+                # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs
+                causal_mask_mapping = create_causal_mask_mapping(
+                    self.config,
+                    inputs_embeds,
+                    attention_mask,
+                    past_key_values,
+                    position_ids,
+                    mm_token_type_ids,
+                    pixel_values,
+                    is_training=self.training,
+                )
+            else:
+                # Smaller Gemma models use a conventional causal attention mask
+                causal_mask_mapping = create_masks_for_generate(
+                    self.config,
+                    inputs_embeds,
+                    attention_mask,
+                    past_key_values,
+                    position_ids,
+                )
+
+        outputs = self.language_model(
+            per_layer_inputs=per_layer_inputs,
+            attention_mask=causal_mask_mapping,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            return_dict=True,
+            **kwargs,
+        )
+
+        return Gemma4ModelOutputWithPast(
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=image_features if pixel_values is not None else None,
+            audio_hidden_states=audio_features if input_features is not None else None,
+        )
+
+    @can_return_tuple
+    @auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.")
+    def get_audio_features(
+        self,
+        input_features: torch.Tensor,
+        input_features_mask: torch.Tensor,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple | Gemma4AudioModelOutput:
+        r"""
+        input_features (`torch.FloatTensor` of shape `(num_images, seq_length, num_features)`):
+            The tensors corresponding to the input audio.
+        input_features_mask (`torch.FloatTensor` of shape `(num_images, seq_length)`):
+            The attention mask for the input audio.
+        """
+        if self.audio_tower is None:
+            raise ValueError(
+                "Audio features were requested, but the model was initialized without an audio_config. "
+                "Cannot process audio without an audio tower and audio embedder."
+ ) + + audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True, **kwargs) + audio_outputs.pooler_output = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state) + + return audio_outputs + + @can_return_tuple + @auto_docstring(custom_intro="Projects the last hidden state from the vision encoder into language model space.") + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_position_ids: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*): + 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding. + Passed through to the vision encoder for positional embedding computation. + """ + pixel_values_videos = pixel_values_videos.flatten(0, 1) + video_position_ids = video_position_ids.flatten(0, 1) + vision_outputs = self.vision_tower( + pixel_values=pixel_values_videos, + pixel_position_ids=video_position_ids, + **kwargs, + ) + last_hidden_state = vision_outputs.last_hidden_state + vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state) + return vision_outputs + + +@auto_docstring( + custom_intro=""" + The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling + head. + """ +) +class Gemma4ForConditionalGeneration(Gemma4PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + base_model_prefix = "model" + + def __init__(self, config: Gemma4Config): + super().__init__(config) + self.model = Gemma4Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_position_ids: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ): + r""" + image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*): + 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding. + Passed through to the vision encoder for positional embedding computation. 
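+
+        This is a thin wrapper that forwards the call to [`Gemma4Model.get_image_features`].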
+ """ + return self.model.get_image_features(pixel_values, image_position_ids, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + input_features_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + image_position_ids: torch.LongTensor | None = None, + video_position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + mm_token_type_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Gemma4CausalLMOutputWithPast: + r""" + input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`): + The attention mask for the input audio. + image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*): + 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding. + Passed through to the vision encoder for positional embedding computation. + video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*): + 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding. + Passed through to the vision encoder for positional embedding computation. + """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + input_features=input_features, + attention_mask=attention_mask, + input_features_mask=input_features_mask, + position_ids=position_ids, + past_key_values=past_key_values, + mm_token_type_ids=mm_token_type_ids, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + image_position_ids=image_position_ids, + video_position_ids=video_position_ids, + return_dict=True, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None: + logits = logits / final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * final_logit_softcapping + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.get_text_config().vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + + return Gemma4CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + audio_hidden_states=outputs.audio_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + position_ids=None, + pixel_values=None, + pixel_values_videos=None, + input_features=None, + attention_mask=None, + input_features_mask=None, + token_type_ids=None, + use_cache=True, + logits_to_keep=None, + labels=None, + is_first_iteration=False, + **kwargs, + ): + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + is_first_iteration=is_first_iteration, + **kwargs, + ) + + # If we're in cached decoding stage, multimodal inputs are already cached and can be dropped + if is_first_iteration or not use_cache: + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_values_videos"] = pixel_values_videos + model_inputs["input_features"] = input_features + model_inputs["input_features_mask"] = input_features_mask + + return model_inputs + + @staticmethod + def create_masks_for_generate( + config: PreTrainedConfig, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None, + position_ids: torch.Tensor | None, + mm_token_type_ids: torch.Tensor | None = None, + is_first_iteration: bool | None = False, + **kwargs, + ) -> dict: + if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision": + # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs + return create_causal_mask_mapping( + config, + inputs_embeds, + attention_mask, + past_key_values, + position_ids, + mm_token_type_ids, + is_first_iteration=is_first_iteration, + **{k: v for k, v in kwargs.items() if k != "pixel_values"}, + ) + else: + # Smaller Gemma models use a conventional casual attention mask + return create_masks_for_generate( + config, inputs_embeds, attention_mask, past_key_values, position_ids, **kwargs + ) + + +__all__ = [ + "Gemma4AudioModel", + "Gemma4ForCausalLM", + "Gemma4ForConditionalGeneration", + "Gemma4Model", + "Gemma4PreTrainedModel", + "Gemma4TextModel", + "Gemma4VisionModel", +] diff --git a/src/transformers/models/gemma4/modular_gemma4.py b/src/transformers/models/gemma4/modular_gemma4.py new file mode 100644 index 000000000000..a97273802213 --- /dev/null +++ b/src/transformers/models/gemma4/modular_gemma4.py @@ -0,0 +1,2160 @@ +# 
Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable +from dataclasses import dataclass +from functools import cached_property + +import torch +from torch import nn +from torch.nn import functional as F + +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...configuration_utils import PreTrainedConfig +from ...integrations import use_kernelized_func +from ...masking_utils import ( + create_bidirectional_mask, + create_causal_mask, + create_masks_for_generate, + create_sliding_window_causal_mask, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + logging, + torch_compilable_check, +) +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..auto.modeling_auto import AutoModel +from ..gemma3.modeling_gemma3 import ( + Gemma3Attention, + Gemma3DecoderLayer, + Gemma3ForCausalLM, + Gemma3MLP, + Gemma3RotaryEmbedding, + Gemma3TextModel, + Gemma3TextScaledWordEmbedding, +) +from ..gemma3n.modeling_gemma3n import ( + Gemma3nCausalLMOutputWithPast, + Gemma3nForConditionalGeneration, + Gemma3nModel, + Gemma3nModelOutputWithPast, + Gemma3nMultimodalEmbedder, + Gemma3nRMSNorm, + apply_rotary_pos_emb, + eager_attention_forward, +) +from ..llama.modeling_llama import LlamaRotaryEmbedding +from ..mixtral.modeling_mixtral import MixtralExperts +from ..moonshine_streaming.modeling_moonshine_streaming import sliding_window_mask_function +from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig + + +logger = logging.get_logger(__name__) + + +class Gemma4ModelOutputWithPast(Gemma3nModelOutputWithPast): + pass + + +class Gemma4CausalLMOutputWithPast(Gemma3nCausalLMOutputWithPast): + pass + + +@dataclass +@auto_docstring +class Gemma4AudioModelOutput(BaseModelOutputWithPooling): + r""" + attention_mask (`torch.BoolTensor`, *optional*): + A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding. 
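+        `Gemma4Model` uses this mask to drop padded audio frames before the remaining audio soft tokens are
+        scattered into the text embedding sequence.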
+ """ + + attention_mask: torch.BoolTensor | None = None + + +class Gemma4ClippableLinear(nn.Module): + def __init__( + self, + config: Gemma4VisionConfig | Gemma4AudioConfig, + in_features: int, + out_features: int, + ) -> None: + super().__init__() + self.use_clipped_linears = config.use_clipped_linears + self.linear = nn.Linear(in_features, out_features, bias=False) + + if self.use_clipped_linears: + self.register_buffer("input_min", torch.tensor(-float("inf"))) + self.register_buffer("input_max", torch.tensor(float("inf"))) + self.register_buffer("output_min", torch.tensor(-float("inf"))) + self.register_buffer("output_max", torch.tensor(float("inf"))) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.use_clipped_linears: + hidden_states = torch.clamp(hidden_states, self.input_min, self.input_max) + + hidden_states = self.linear(hidden_states) + + if self.use_clipped_linears: + hidden_states = torch.clamp(hidden_states, self.output_min, self.output_max) + + return hidden_states + + +class Gemma4RMSNorm(Gemma3nRMSNorm): + pass + + +class Gemma4AudioRelPositionalEncoding(nn.Module): + """Sinusoidal relative positional encoding for the audio encoder. + + Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with + concatenated [sin..., cos...] layout matching the original Gemma4 convention. + """ + + inv_timescales: torch.Tensor + + def __init__(self, config: Gemma4AudioConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.context_size = ( + config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right + ) + min_timescale = 1.0 + max_timescale = 10000.0 + num_timescales = self.hidden_size // 2 + log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1) + inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment) + self.register_buffer("inv_timescales", inv_timescales.unsqueeze(0).unsqueeze(0), persistent=False) + + @torch.no_grad() + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + position_ids = torch.arange(12, -1, -1, device=hidden_states.device) + position_ids = position_ids[..., None] + scaled_time = position_ids * self.inv_timescales.to(device=hidden_states.device) + pos_embed = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1) + return pos_embed.to(dtype=hidden_states.dtype) + + +class Gemma4AudioAttention(nn.Module): + """Chunked local attention with relative position bias""" + + def __init__(self, config: Gemma4AudioConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.attention_logits_soft_cap = config.attention_logit_cap + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_heads = config.num_attention_heads + + self.q_scale = (self.head_dim**-0.5) / math.log(2) + self.k_scale = math.log(1 + math.e) / math.log(2) + + self.chunk_size = config.attention_chunk_size + self.max_past_horizon = config.attention_context_left - 1 + self.max_future_horizon = config.attention_context_right + self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon + + self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim) + self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim) + self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim) + self.post = 
Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size) + + self.relative_k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.per_dim_scale = nn.Parameter(torch.zeros(self.head_dim)) + + self.register_buffer("softcap", torch.tensor(self.attention_logits_soft_cap), persistent=False) + + def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Splits a `(batch_size, seq_len, num_heads, head_dim)` tensor into non-overlapping blocks of `chunk_size` along the sequence dim.""" + batch_size, seq_len, num_heads, head_dim = hidden_states.shape + num_blocks = (seq_len + self.chunk_size - 1) // self.chunk_size + pad = num_blocks * self.chunk_size - seq_len + hidden_states = F.pad(hidden_states, (0, 0, 0, 0, 0, pad)) + return hidden_states.reshape(batch_size, num_blocks, self.chunk_size, num_heads, head_dim).contiguous() + + def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Extracts overlapping context windows of `context_size` for every block, strided by `chunk_size`.""" + batch_size, seq_len, num_heads, head_dim = hidden_states.shape + hidden_states = F.pad( + hidden_states, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1) + ) + hidden_states = hidden_states.unfold(1, self.context_size, self.chunk_size) + hidden_states = torch.movedim(hidden_states, -1, 2) + return hidden_states.contiguous() + + def _rel_shift(self, x: torch.Tensor) -> torch.Tensor: + """Relative position shift for blocked attention. See appendix B of https://huggingface.co/papers/1901.02860.""" + batch_size, num_heads, num_blocks, block_size, position_length = x.shape + context_size = self.context_size + x = F.pad(x, (0, context_size + 1 - position_length)) + x = x.view(batch_size, num_heads, num_blocks, block_size * (context_size + 1)) + x = x[..., : block_size * context_size] + return x.view(batch_size, num_heads, num_blocks, block_size, context_size) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: torch.BoolTensor | None = None, + ) -> tuple[torch.Tensor, None]: + batch_size, seq_length, _ = hidden_states.shape + hidden_shape = (batch_size, seq_length, self.num_heads, self.head_dim) + + query_states = self.q_proj(hidden_states).float().view(hidden_shape) + key_states = self.k_proj(hidden_states).float().view(hidden_shape) + value_states = self.v_proj(hidden_states).float().view(hidden_shape) + + query_states = query_states * self.q_scale * F.softplus(self.per_dim_scale) + key_states = key_states * self.k_scale + + query_states = self._convert_to_block(query_states) + key_states = self._extract_block_context(key_states) + value_states = self._extract_block_context(value_states) + num_blocks = query_states.shape[1] + + relative_key_states = self.relative_k_proj(position_embeddings) + relative_key_states = relative_key_states.view(-1, self.num_heads, self.head_dim) + relative_key_states = relative_key_states.to(dtype=query_states.dtype) + + queries = query_states.permute(0, 3, 1, 2, 4) + matrix_ac = queries @ key_states.permute(0, 3, 1, 4, 2) + + queries_flat = queries.reshape(batch_size, self.num_heads, -1, self.head_dim) + matrix_bd = queries_flat @ relative_key_states.permute(1, 2, 0) + matrix_bd = matrix_bd.reshape(batch_size, self.num_heads, num_blocks, self.chunk_size, -1) + matrix_bd = self._rel_shift(matrix_bd) + + attn_weights = matrix_ac + matrix_bd + attn_weights = attn_weights / self.softcap + attn_weights = 
torch.tanh(attn_weights) + attn_weights = attn_weights * self.softcap + + if attention_mask is not None: + attn_weights = attn_weights.masked_fill( + attention_mask.logical_not(), self.config.attention_invalid_logits_value + ) + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) + attn_output = attn_weights @ value_states.permute(0, 3, 1, 2, 4) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, num_blocks * self.chunk_size, -1) + attn_output = attn_output[:, :seq_length].contiguous() + attn_output = self.post(attn_output.to(dtype=self.post.linear.weight.dtype)) + + return attn_output, attn_weights + + +class Gemma4AudioSubSampleConvProjectionLayer(nn.Module): + def __init__(self, in_channels, out_channels, norm_eps): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(2, 2), + padding=1, + bias=False, + ) + self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False) + self.act = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor | None = None): + if mask is not None: + mask = mask.to(device=hidden_states.device) + hidden_states = hidden_states * mask[:, None, :, None] + + hidden_states = self.conv(hidden_states.to(self.conv.weight.dtype)) + hidden_states = self.act(self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()) + + if mask is not None: + mask = mask[:, ::2] + + return hidden_states, mask + + +class Gemma4AudioSubSampleConvProjection(nn.Module): + def __init__(self, config: Gemma4AudioConfig): + super().__init__() + self.layer0 = Gemma4AudioSubSampleConvProjectionLayer( + in_channels=1, + out_channels=config.subsampling_conv_channels[0], + norm_eps=config.rms_norm_eps, + ) + self.layer1 = Gemma4AudioSubSampleConvProjectionLayer( + in_channels=config.subsampling_conv_channels[0], + out_channels=config.subsampling_conv_channels[1], + norm_eps=config.rms_norm_eps, + ) + proj_input_dim = (config.subsampling_conv_channels[0] // 4) * config.subsampling_conv_channels[1] + self.input_proj_linear = nn.Linear(proj_input_dim, config.hidden_size, bias=False) + + def forward( + self, + input_features: torch.Tensor, + input_features_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states = input_features.unsqueeze(1) + hidden_states, mask = self.layer0(hidden_states, input_features_mask) + hidden_states, mask = self.layer1(hidden_states, mask) + + batch_size, _, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1) + return self.input_proj_linear(hidden_states), mask + + +class Gemma4AudioFeedForward(nn.Module): + def __init__(self, config: Gemma4AudioConfig): + super().__init__() + self.config = config + + self.ffw_layer_1 = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 4) + self.ffw_layer_2 = Gemma4ClippableLinear(config, config.hidden_size * 4, config.hidden_size) + + self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size) + self.post_layer_norm = Gemma4RMSNorm(config.hidden_size) + self.act_fn = ACT2FN[config.hidden_act] + + self.gradient_clipping = config.gradient_clipping + self.post_layer_scale = config.residual_weight + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # This is needed to avoid any underflow/overflow issues when clipping + gradient_clipping = min(self.gradient_clipping, 
torch.finfo(self.ffw_layer_1.linear.weight.dtype).max) + + residual = hidden_states + hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping) + hidden_states = self.pre_layer_norm(hidden_states) + + hidden_states = self.ffw_layer_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.ffw_layer_2(hidden_states) + + hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping) + hidden_states = self.post_layer_norm(hidden_states) + hidden_states *= self.post_layer_scale + hidden_states += residual + + return hidden_states + + +# TODO: this could be imported from Voxtral realtime +class Gemma4AudioCausalConv1d(nn.Conv1d): + # def __init__( + # self, + # in_channels: int, + # out_channels: int, + # kernel_size: int, + # # cache_key: str, + # stride: int = 1, + # dilation: int = 1, + # bias: bool = True, + # ): + # super().__init__(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, bias=bias) + # self.cache_key = cache_key + + @cached_property + def left_pad(self): + effective_kernel_size = (self.kernel_size[0] - 1) * self.dilation[0] + 1 + return effective_kernel_size - self.stride[0] + + def forward( + self, + x: torch.Tensor, + # padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, # TODO: we might want to add a cache? + ) -> torch.Tensor: + # if padding_cache is not None: + # x = padding_cache.update(x, self.cache_key, self) + # else: + # x = nn.functional.pad(x, (self.left_pad, 0)) + x = nn.functional.pad(x, (self.left_pad, 0)) + + return super().forward(x) + + +class Gemma4AudioLightConv1d(nn.Module): + def __init__(self, config: Gemma4AudioConfig): + super().__init__() + self.config = config + + self.linear_start = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size * 2) + self.linear_end = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size) + self.depthwise_conv1d = Gemma4AudioCausalConv1d( + in_channels=config.hidden_size, + out_channels=config.hidden_size, + kernel_size=config.conv_kernel_size, + groups=config.hidden_size, + bias=False, + ) + + self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps, with_scale=True) + self.conv_norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps, with_scale=True) + self.act_fn = ACT2FN[config.hidden_act] + + self.gradient_clipping = config.gradient_clipping + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.pre_layer_norm(hidden_states) + hidden_states = self.linear_start(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=-1) + + hidden_states = self.depthwise_conv1d(hidden_states.transpose(1, 2)).transpose(1, 2) + + # This is needed to avoid any underflow/overflow issues when clipping + gradient_clipping = min(self.gradient_clipping, torch.finfo(self.linear_start.linear.weight.dtype).max) + hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping) + hidden_states = self.conv_norm(hidden_states) + + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_end(hidden_states) + hidden_states += residual + return hidden_states + + +class Gemma4AudioLayer(nn.Module): + def __init__(self, config: Gemma4AudioConfig, layer_idx: int): + super().__init__() + self.config = config + + self.feed_forward1 = Gemma4AudioFeedForward(config) + self.feed_forward2 = Gemma4AudioFeedForward(config) + self.self_attn = Gemma4AudioAttention(config, layer_idx) + 
self.lconv1d = Gemma4AudioLightConv1d(config) + + self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size) + self.norm_post_attn = Gemma4RMSNorm(config.hidden_size) + self.norm_out = Gemma4RMSNorm(config.hidden_size) + + self.gradient_clipping = config.gradient_clipping + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.BoolTensor | None, + position_embeddings: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + # This is needed to avoid any underflow/overflow issues when clipping + gradient_clipping = min(self.gradient_clipping, torch.finfo(self.norm_pre_attn.weight.dtype).max) + + hidden_states = self.feed_forward1(hidden_states) + residual = hidden_states + + hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping) + hidden_states = self.norm_pre_attn(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + ) + + hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping) + hidden_states = self.norm_post_attn(hidden_states) + hidden_states += residual + + hidden_states = self.lconv1d(hidden_states) + hidden_states = self.feed_forward2(hidden_states) + + hidden_states = torch.clamp(hidden_states, -gradient_clipping, gradient_clipping) + hidden_states = self.norm_out(hidden_states) + + return hidden_states + + +# ---- Vision Encoder Layers ---- + + +class Gemma4VisionPatchEmbedder(nn.Module): + def __init__(self, config: Gemma4VisionConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.patch_size = config.patch_size + self.position_embedding_size = config.position_embedding_size + + self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False) + self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size)) + + def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor: + """Prepare patch positions map for matmul with positon embedding table.""" + # Expanding and permute patch positions to (batch_size, num_patches, 2, position_embedding_size) for matmul. + clamped_positions = pixel_position_ids.clamp(min=0) + one_hot = F.one_hot(clamped_positions, num_classes=self.position_embedding_size) + one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table) + # Compute positional embeddings and sum across x and y. + position_embeddings = one_hot @ self.position_embedding_table + position_embeddings = position_embeddings.sum(dim=1) + # Zero out embeddings for any padding patches. 
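+        # (Padding patches carry the position (-1, -1); zeroing their position embeddings here ensures padded
+        # slots contribute no positional signal before the encoder mask and the pooler's masked average are
+        # applied downstream.)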
+        position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings)
+        return position_embeddings
+
+    def forward(
+        self, pixel_values: torch.Tensor, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor
+    ) -> torch.Tensor:
+        # Gemma4 applies no normalization and instead scales in model code
+        pixel_values = 2 * (pixel_values - 0.5)
+        hidden_states = self.input_proj(pixel_values.to(self.input_proj.weight.dtype))
+        position_embeddings = self._position_embeddings(pixel_position_ids, padding_positions)
+        return hidden_states + position_embeddings
+
+
+class Gemma4VisionPooler(nn.Module):
+    """Scaling and optional spatial pooling for vision encodings"""
+
+    def __init__(self, config: Gemma4VisionConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.root_hidden_size = self.hidden_size**0.5
+
+    def _avg_pool_by_positions(
+        self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        2D spatial pooling according to patch positions.
+        Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio
+        between input and output lengths.
+        """
+        input_seq_len = hidden_states.shape[1]
+        k = int((input_seq_len // length) ** 0.5)
+        k_squared = k**2
+        if k_squared * length != input_seq_len:
+            raise ValueError(
+                f"Cannot pool {hidden_states.shape} to {length}: {k=}^2 times {length=} must be {input_seq_len}."
+            )
+
+        # Clamp padding positions (which are -1) to 0 so they don't break one_hot.
+        # Padding patches have zero hidden states so they contribute nothing to the average.
+        clamped_positions = pixel_position_ids.clamp(min=0)
+        max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1
+        kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor")
+        kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1]
+        weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared
+        output = weights.transpose(1, 2) @ hidden_states.float()
+        mask = torch.logical_not((weights == 0).all(dim=1))
+        return output.to(hidden_states.dtype), mask
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pixel_position_ids: torch.Tensor,
+        padding_positions: torch.Tensor,
+        output_length: int | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if output_length > hidden_states.shape[1]:
+            raise ValueError(
+                f"Cannot output more soft tokens (requested {output_length}) than there are patches"
+                f" ({hidden_states.shape[1]}). Change the value of `num_soft_tokens` when processing."
+            )
+
+        hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0.0)
+
+        if hidden_states.shape[1] != output_length:
+            hidden_states, padding_positions = self._avg_pool_by_positions(
+                hidden_states, pixel_position_ids, output_length
+            )
+
+        hidden_states *= self.root_hidden_size
+        return hidden_states, padding_positions
+
+
+class Gemma4VisionMLP(Gemma3MLP):
+    def __init__(self, config: Gemma4VisionConfig):
+        super().__init__(config)
+        self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
+        self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
+        self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
+
+
+def apply_multidimensional_rope(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    position_ids: torch.Tensor,
+    unsqueeze_dim: int = 2,
+) -> torch.Tensor:
+    """Applies multidimensional RoPE to inputs.
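+
+    As an illustration with hypothetical sizes (not taken from any released config): for `ndim=2` spatial axes
+    and `head_dim=64`, each head is split into two contiguous halves of 32 channels; the first half is rotated
+    with the x-coordinate frequencies and the second half with the y-coordinate frequencies before the halves
+    are concatenated back together.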
+
+    Args:
+        x (`torch.Tensor`): The tensor to embed.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            Multidimensional positions whose last dimension (`ndim`) determines how the channels of `x`, `cos`,
+            and `sin` are split. Each of the `ndim` pieces is rotated independently with `apply_rotary_pos_emb()`
+            and the pieces are then concatenated back together.
+        unsqueeze_dim (`int`, *optional*, defaults to 2):
+            The dimension along which to unsqueeze `cos` and `sin` so that they broadcast against `x`. For
+            example, if `x` has the shape `[batch_size, heads, seq_len, head_dim]`, set `unsqueeze_dim=1`; if `x`
+            has the shape `[batch_size, seq_len, heads, head_dim]`, set `unsqueeze_dim=2`.
+
+    Returns:
+        Tensor of shape [B, L, N, H] with RoPE applied.
+    """
+    ndim = position_ids.shape[-1]
+    num_input_channels = x.shape[-1]
+    num_rotated_channels_per_dim = 2 * (num_input_channels // (2 * ndim))
+
+    if num_rotated_channels_per_dim <= 0:
+        raise ValueError(
+            "Invalid configuration: num_rotated_channels_per_dim must be > 0, got"
+            f" {num_rotated_channels_per_dim} (num_input_channels={num_input_channels},"
+            f" ndim={ndim})"
+        )
+
+    # Split the input tensor into ndim parts, one per spatial axis
+    split_sizes = [num_rotated_channels_per_dim] * ndim
+    x_parts = torch.split(x, split_sizes, dim=-1)
+    cos_parts = torch.split(cos, split_sizes, dim=-1)
+    sin_parts = torch.split(sin, split_sizes, dim=-1)
+    y_parts = [
+        apply_rotary_pos_emb(
+            x=x_parts[k],
+            cos=cos_parts[k],
+            sin=sin_parts[k],
+            unsqueeze_dim=unsqueeze_dim,
+        )
+        for k in range(ndim)
+    ]
+    return torch.cat(y_parts, dim=-1)
+
+
+class Gemma4VisionRotaryEmbedding(LlamaRotaryEmbedding):
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Gemma4VisionConfig | None = None,
+        device: torch.device | None = None,
+        seq_len: int | None = None,
+    ) -> tuple["torch.Tensor", float]:
+        """
+        Computes the inverse frequencies according to the original RoPE implementation.
+        Args:
+            config ([`~transformers.PreTrainedConfig`]):
+                The model configuration.
+            device (`torch.device`):
+                The device to use for initialization of the inverse frequencies.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+        """
+        base = config.rope_parameters["rope_theta"]
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+        # The reference implementation computes RoPE frequencies INDEPENDENTLY
+        # for each spatial dimension using the partitioned head_dim (head_dim // ndim),
+        # so both x and y dimensions get identical frequency ranges.
+        # This is different from splitting the global inv_freq between dimensions.
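+        # Worked example under a hypothetical head_dim of 64: spatial_dim is 32, so inv_freq below has 16
+        # entries (one per rotated channel pair), and the same 16 frequencies are reused for both the x and y
+        # axes in the forward pass.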
+        spatial_dim = dim // 2
+
+        attention_factor = 1.0  # Unused in this type of RoPE
+        inv_freq = 1.0 / (
+            base
+            ** (torch.arange(0, spatial_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / spatial_dim)
+        )
+        return inv_freq, attention_factor
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+
+        # Multidimensional positions: [batch, num_patches, ndim]. Apply rotations to each spatial dim separately
+        all_cos, all_sin = [], []
+        for i in range(2):
+            dim_position_ids = position_ids[:, :, i]
+            dim_position_ids_expanded = dim_position_ids[:, None, :].float()
+
+            with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+                freqs = (inv_freq_expanded.float() @ dim_position_ids_expanded.float()).transpose(1, 2)
+                emb = torch.cat((freqs, freqs), dim=-1)
+                cos = emb.cos() * self.attention_scaling
+                sin = emb.sin() * self.attention_scaling
+            all_cos.append(cos)
+            all_sin.append(sin)
+
+        cos = torch.cat(all_cos, dim=-1).to(dtype=x.dtype)
+        sin = torch.cat(all_sin, dim=-1).to(dtype=x.dtype)
+        return cos, sin
+
+
+class Gemma4VisionAttention(Gemma3Attention):
+    def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
+        super().__init__(config, layer_idx)
+        del self.attn_logit_softcapping
+        del self.sliding_window
+        del self.is_sliding
+        self.scaling = 1.0
+        self.is_causal = False
+        self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
+        self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_attention_heads * self.head_dim)
+        self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
+        self.o_proj = Gemma4ClippableLinear(config, config.num_attention_heads * self.head_dim, config.hidden_size)
+        self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: torch.Tensor = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        cos, sin = position_embeddings
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape)
+        query_states = self.q_norm(query_states)
+        query_states = apply_multidimensional_rope(query_states, cos, sin, position_ids)
+        query_states = query_states.transpose(1, 2)
+
+        key_states = self.k_proj(hidden_states).view(hidden_shape)
+        key_states = self.k_norm(key_states)
+        key_states = apply_multidimensional_rope(key_states, cos, sin, position_ids)
+        key_states = key_states.transpose(1, 2)
+
+        value_states = self.v_proj(hidden_states).view(hidden_shape)
+        value_states = self.v_norm(value_states)
+        value_states = value_states.transpose(1, 2)
+
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=self.attention_dropout if self.training else 0.0,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+# Same forward as Gemma3 but no cache
+class Gemma4VisionEncoderLayer(Gemma3DecoderLayer):
+    def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
+        self.mlp = Gemma4VisionMLP(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: torch.Tensor = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            **kwargs,
+        )
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.pre_feedforward_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Gemma4VisionEncoder(nn.Module):
+    def __init__(self, config: Gemma4VisionConfig):
+        super().__init__()
+        self.config = config
+        self.num_layers = config.num_hidden_layers
+        self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
+        self.layers = nn.ModuleList(
+            [Gemma4VisionEncoderLayer(config=config, layer_idx=i) for i in range(self.num_layers)]
+        )
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        attention_mask: torch.Tensor,
+        pixel_position_ids: torch.LongTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPast:
+        r"""
+        pixel_position_ids (`torch.Tensor`):
+            Patch positions as (x, y) coordinates in the image, of shape `[batch, num_patches, 2]`.
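+            Padding patches are expected to be marked with `(-1, -1)`, following the convention used by the
+            image processor and by `Gemma4VisionModel.forward`.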
+ """ + attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + ) + + # embed positions + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, pixel_position_ids) + + # decoder layers + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + position_ids=pixel_position_ids, + **kwargs, + ) + + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + +# ---- Text model ---- + + +class Gemma4TextMLP(Gemma3MLP): + def __init__(self, config: Gemma4TextConfig, layer_idx: int): + first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers + is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer + super().__init__() + self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1) + + +class Gemma4TextRotaryEmbedding(Gemma3RotaryEmbedding): + def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None): + nn.Module.__init__(self) + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.layer_types = set(config.layer_types) + self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {} + self.rope_type: dict[str, str] = {} + + for layer_type in self.layer_types: + rope_params = self.config.rope_parameters[layer_type] + if rope_params is None: + continue + + if (rope_type := rope_params["rope_type"]) != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + else: + rope_init_fn = self.compute_default_rope_parameters + + self.rope_init_fns[layer_type] = rope_init_fn + self.rope_type[layer_type] = rope_type + + rope_init_fn_kwargs = {"device": device, "layer_type": layer_type} + if layer_type == "full_attention" and rope_type == "proportional": + rope_init_fn_kwargs["head_dim_key"] = "global_head_dim" + + curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, **rope_init_fn_kwargs) + self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False) + self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False) + setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling) + + +@use_kernelized_func(apply_rotary_pos_emb) +class Gemma4TextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Gemma4TextConfig, layer_idx: int): + super().__init__() + self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None + self.config = config + self.layer_idx = layer_idx + self.is_sliding = self.layer_type == "sliding_attention" + self.sliding_window = config.sliding_window if self.is_sliding else None + + self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim + self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding + num_key_value_heads = ( + config.num_global_key_value_heads if self.use_alternative_attention else config.num_key_value_heads + ) + self.num_key_value_groups = config.num_attention_heads // num_key_value_heads + self.scaling = 1.0 + self.attention_dropout = self.config.attention_dropout + self.is_causal = config.use_bidirectional_attention != 
"all" + + # Shared kv cache + first_kv_shared_layer_idx = self.config.num_hidden_layers - getattr(self.config, "num_kv_shared_layers", 0) + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + prev_layers = config.layer_types[:first_kv_shared_layer_idx] + if self.is_kv_shared_layer: + # For shared layers, find the last non-shared layer of the same type before sharing starts + self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx]) + self.store_full_length_kv = False + else: + self.kv_shared_layer_index = None + # For non-shared layers, store full-length kv if this is the last non-shared layer of its type + self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index( + config.layer_types[layer_idx] + ) + + self.q_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps) + self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False) + + self.k_proj = nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = ( + nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias) + if not self.use_alternative_attention + else None + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + cos, sin = position_embeddings + + query_states = self.q_proj(hidden_states).view(hidden_shape) + query_states = self.q_norm(query_states) + query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) + query_states = query_states.transpose(1, 2) + + # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer + if self.is_kv_shared_layer and past_key_values is not None: + key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index] + # Device of past layer may be different from current one + key_states = key_states.to(query_states.device) + value_states = value_states.to(query_states.device) + else: + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states + + key_states = self.k_norm(key_states) + key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) + key_states = key_states.transpose(1, 2) + + value_states = self.v_norm(value_states) + value_states = value_states.transpose(1, 2) + + if past_key_values is not None: + if not self.is_kv_shared_layer: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + if self.store_full_length_kv: + if not hasattr(past_key_values, "shared_layers"): + past_key_values.shared_layers = {} + past_key_values.shared_layers[self.layer_idx] = key_states, value_states + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + 
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Gemma4TextExperts(MixtralExperts): + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.num_experts = config.num_experts + self.intermediate_dim = config.moe_intermediate_size + self.act_fn = ACT2FN[config.hidden_activation] + + +class Gemma4TextRouter(nn.Module): + def __init__(self, config: Gemma4TextConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.scalar_root_size = self.hidden_size**-0.5 + self.eps = config.rms_norm_eps + + self.norm = Gemma4RMSNorm(self.hidden_size, eps=self.eps, with_scale=False) + self.proj = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.scale = nn.Parameter(torch.ones(self.hidden_size)) + self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts)) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states * self.scale * self.scalar_root_size + + expert_scores = self.proj(hidden_states) # [B*S, E] + router_probabilities = nn.functional.softmax(expert_scores, dim=-1) + + # topk returns both values (probabilities) and indices directly + top_k_weights, top_k_index = torch.topk( + router_probabilities, + k=self.config.top_k_experts, + dim=-1, + ) # both [B*S, K] + + # Normalize the top-k weights so they sum to 1 per token + top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) + + # Apply per-expert scale directly to the weights + top_k_weights = top_k_weights * self.per_expert_scale[top_k_index] + + return router_probabilities, top_k_weights, top_k_index + + +class Gemma4TextDecoderLayer(Gemma3DecoderLayer): + def __init__(self, config: Gemma4TextConfig | Gemma4VisionConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx) + self.mlp = Gemma4TextMLP(config, layer_idx) + self.register_buffer("layer_scalar", torch.ones(1)) + + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + if self.hidden_size_per_layer_input: + self.act_fn = ACT2FN[config.hidden_activation] + self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False) + self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False) + self.post_per_layer_input_norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + self.enable_moe_block = config.enable_moe_block + if self.enable_moe_block: + self.router = Gemma4TextRouter(config) + self.experts = Gemma4TextExperts(config) + self.post_feedforward_layernorm_1 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm_2 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm_2 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + per_layer_input: torch.Tensor = None, + position_embeddings: torch.Tensor = None, + attention_mask: torch.Tensor | None = 
None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + if self.enable_moe_block: + hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states) + + # Take hidden states before MLP here + hidden_states_flat = residual.reshape(-1, residual.shape[-1]) + _, top_k_weights, top_k_index = self.router(hidden_states_flat) + hidden_states_2 = self.pre_feedforward_layernorm_2(hidden_states_flat) + hidden_states_2 = self.experts(hidden_states_2, top_k_index, top_k_weights) + hidden_states_2 = hidden_states_2.reshape(residual.shape) + hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2) + + # Combine mlp and moe outputs + hidden_states = hidden_states_1 + hidden_states_2 + + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + if self.hidden_size_per_layer_input: + residual = hidden_states + hidden_states = self.per_layer_input_gate(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = hidden_states * per_layer_input + hidden_states = self.per_layer_projection(hidden_states) + hidden_states = self.post_per_layer_input_norm(hidden_states) + hidden_states = residual + hidden_states + + hidden_states *= self.layer_scalar + return hidden_states + + +class Gemma4TextScaledWordEmbedding(Gemma3TextScaledWordEmbedding): + pass + + +# ---- Model Classes ---- + + +class Gemma4PreTrainedModel(PreTrainedModel): + config: Gemma4Config + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = True + _supports_attention_backend = True + _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"] + _skip_keys_device_placement = ["past_key_values"] + input_modalities = ("image", "text", "video", "audio") + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Gemma4VisionPatchEmbedder): + init.ones_(module.position_embedding_table) + elif isinstance(module, Gemma4AudioRelPositionalEncoding): + min_timescale = 1.0 + max_timescale = 10000.0 + num_timescales = module.hidden_size // 2 + log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1) + inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment) + init.copy_(module.inv_timescales, inv_timescales.unsqueeze(0).unsqueeze(0)) + elif isinstance(module, Gemma4AudioAttention): + init.constant_(module.softcap, module.attention_logits_soft_cap) + init.zeros_(module.per_dim_scale) + elif isinstance(module, Gemma4TextRotaryEmbedding): + for layer_type, rope_init_fn in module.rope_init_fns.items(): + rope_init_fn_kwargs = {"layer_type": layer_type} + if layer_type == "full_attention" and module.rope_type[layer_type] == "proportional": + rope_init_fn_kwargs["head_dim_key"] = "global_head_dim" + + 
curr_inv_freq, _ = rope_init_fn(module.config, **rope_init_fn_kwargs) + init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq) + init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq) + elif isinstance(module, Gemma4VisionRotaryEmbedding): + rope_fn = ( + ROPE_INIT_FUNCTIONS[module.rope_type] + if module.rope_type != "default" + else module.compute_default_rope_parameters + ) + buffer_value, _ = rope_fn(module.config) + init.copy_(module.inv_freq, buffer_value) + init.copy_(module.original_inv_freq, buffer_value) + elif isinstance(module, Gemma4TextScaledWordEmbedding): + init.constant_(module.embed_scale, module.scalar_embed_scale) + elif isinstance(module, Gemma4TextRouter): + init.ones_(module.scale) + init.ones_(module.per_expert_scale) + elif isinstance(module, Gemma4TextExperts): + std = self.config.initializer_range + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) + elif isinstance(module, Gemma4TextDecoderLayer): + init.ones_(module.layer_scalar) + elif isinstance(module, Gemma4ClippableLinear) and module.use_clipped_linears: + init.constant_(module.input_min, -float("inf")) + init.constant_(module.input_max, float("inf")) + init.constant_(module.output_min, -float("inf")) + init.constant_(module.output_max, float("inf")) + elif isinstance(module, Gemma4VisionModel) and module.config.standardize: + init.zeros_(module.std_bias) + init.ones_(module.std_scale) + + +@auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.") +class Gemma4TextModel(Gemma3TextModel): + config: Gemma4TextConfig + _can_record_outputs = { + "router_logits": OutputRecorder(Gemma4TextRouter, index=0), + "hidden_states": Gemma4TextDecoderLayer, + "attentions": Gemma4TextAttention, + } + + def __init__(self, config: Gemma4TextConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [Gemma4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.rotary_emb = Gemma4TextRotaryEmbedding(config) + self.unique_layer_types = set(self.config.layer_types) + + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + if self.hidden_size_per_layer_input: + self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding( + config.vocab_size_per_layer_input, + config.num_hidden_layers * config.hidden_size_per_layer_input, + self.padding_idx, + embed_scale=config.hidden_size_per_layer_input**0.5, + ) + self.per_layer_input_scale = 2.0**-0.5 + self.per_layer_model_projection = nn.Linear( + config.hidden_size, + config.num_hidden_layers * config.hidden_size_per_layer_input, + bias=False, + ) + self.per_layer_model_projection_scale = config.hidden_size**-0.5 + self.per_layer_projection_norm = Gemma4RMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps) + + def get_per_layer_inputs(self, input_ids: torch.Tensor | None, inputs_embeds: torch.Tensor | None) -> torch.Tensor: + if not self.hidden_size_per_layer_input: + raise RuntimeError( + "Attempting to call get_per_layer_inputs() from a model initialized with a config that does not support" + f" per-layer embeddings. 
{self.config}" + ) + + # If only inputs_embeds are provided, reverse main embedding to find the input_ids - this allows to `generate` + # from `inputs_embeds` only as other models (otherwise it would need the value from both embeddings) + if input_ids is None: + with torch.no_grad(): + input_ids = ( + ( + inputs_embeds[:, :, None, :] + == self.embed_tokens.weight[None, None, :, :] * self.config.hidden_size**0.5 + ) + .all(dim=3) + .nonzero()[:, 2] + ) + try: + input_ids = input_ids.view(inputs_embeds.shape[:2]) + except RuntimeError: + raise RuntimeError( + "It seems like you tried to call `forward` from `inputs_embeds` without providing `input_ids`, and that " + "the `inputs_embeds` you provided do not exactly match the embedding weights. Since Gemma4 needs to reverse " + "the embedding to compute another embedding, make sure you provide exact `inputs_embeds`" + ) + + return self.embed_tokens_per_layer(input_ids).reshape( + *input_ids.shape, + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: torch.Tensor | None = None, + ) -> torch.Tensor: + if not self.hidden_size_per_layer_input: + raise RuntimeError( + "Attempting to call project_per_layer_inputs() from a model initialized with a config that does not" + f" support per-layer embeddings. {self.config}" + ) + + per_layer_projection = self.per_layer_model_projection(inputs_embeds) * self.per_layer_model_projection_scale + per_layer_projection = per_layer_projection.reshape( + *inputs_embeds.shape[:-1], + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + + if per_layer_inputs is None: + return per_layer_projection + + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + per_layer_inputs: torch.Tensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + r""" + per_layer_inputs (`torch.Tensor` of shape `(batch_size, sequence_length, num_hidden_layers, hidden_size_per_layer_input)`, *optional*): + Pre-computed per-layer input embeddings. When provided, these are used directly instead of being + computed from `input_ids` via `get_per_layer_inputs()`. This is primarily used by the multimodal + model (`Gemma4Model`) which pre-computes per-layer inputs from the original `input_ids` *before* + merging multimodal soft tokens into `inputs_embeds` — at which point the original token ids are + no longer recoverable. 
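+ When the text config disables per-layer inputs (`hidden_size_per_layer_input` evaluating to 0), this
+ argument is ignored.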
+ """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if input_ids is not None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.hidden_size_per_layer_input: + if per_layer_inputs is None: + per_layer_inputs = self.get_per_layer_inputs(input_ids, inputs_embeds) + per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + # embed positions + hidden_states = inputs_embeds + position_embeddings = {} + for layer_type in self.unique_layer_types: + position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) + + # decoder layers + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + per_layer_input = per_layer_inputs[:, :, i, :] if per_layer_inputs is not None else None + + hidden_states = decoder_layer( + hidden_states, + per_layer_input, + position_embeddings=position_embeddings[self.config.layer_types[i]], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], + position_ids=position_ids, + past_key_values=past_key_values, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring(custom_intro="The base Gemma 4 language model with a language modeling head.") +class Gemma4ForCausalLM(Gemma3ForCausalLM): + base_model_prefix = "model" + + +class Gemma4AudioModel(Gemma4PreTrainedModel): + """An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture.""" + + config: Gemma4AudioConfig + main_input_name = "input_features" + base_model_prefix = "model.audio_tower" # prefix for Gemma4ForConditionalGeneration saved checkpoints, required for Gemma4AudioModel.from_pretrained() + _can_record_outputs = { + "hidden_states": Gemma4AudioLayer, + "attentions": Gemma4AudioAttention, + } + + def __init__(self, config: Gemma4AudioConfig): + super().__init__(config) + self.config = config + + self.subsample_conv_projection = Gemma4AudioSubSampleConvProjection(config) + self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config) + self.layers = nn.ModuleList( + [Gemma4AudioLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.output_proj = nn.Linear(config.hidden_size, config.output_proj_dims, bias=True) + + self.post_init() + + def _convert_4d_mask_to_blocked_5d(self, mask_4d: torch.Tensor) -> torch.Tensor: + """ + Convert a standard 4D attention mask 
`[batch_size, 1, seq_len, seq_len]` to the 5D blocked format + `[batch_size, 1, num_blocks, chunk_size, context_size]` expected by the chunked local attention, + """ + batch_size, _, seq_len, _ = mask_4d.shape + device = mask_4d.device + + chunk_size = self.config.attention_chunk_size + max_past_horizon = self.config.attention_context_left - 1 + max_future_horizon = self.config.attention_context_right + + num_blocks = (seq_len + chunk_size - 1) // chunk_size + padded_seq_len = num_blocks * chunk_size + pad_amount = padded_seq_len - seq_len + + mask_4d = F.pad(mask_4d, (0, pad_amount, 0, pad_amount), value=False) + mask_5d = mask_4d.reshape(batch_size, 1, num_blocks, chunk_size, padded_seq_len) + mask_5d = F.pad(mask_5d, (max_past_horizon, max_future_horizon), value=False) + + block_starts = torch.arange(num_blocks, device=device) * chunk_size + offsets = torch.arange(chunk_size + max_past_horizon + max_future_horizon, device=device) + kv_indices = block_starts[:, None] + offsets[None, :] + kv_indices = kv_indices[None, None, :, None, :].expand(batch_size, 1, -1, chunk_size, -1) + + return mask_5d.gather(-1, kv_indices) + + @merge_with_config_defaults + @capture_outputs + @auto_docstring(custom_intro="Encodes audio features to soft tokens.") + def forward( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.BoolTensor]: + hidden_states, output_mask = self.subsample_conv_projection(input_features, attention_mask) + position_embeddings = self.rel_pos_enc(hidden_states) + + attention_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=hidden_states, + attention_mask=output_mask, + and_mask_function=sliding_window_mask_function( + (self.config.attention_context_left - 1, self.config.attention_context_right) + ), + ) + attention_mask = self._convert_4d_mask_to_blocked_5d(attention_mask) + + for encoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.output_proj(hidden_states) + return Gemma4AudioModelOutput(last_hidden_state=hidden_states, attention_mask=output_mask) + + +class Gemma4VisionModel(Gemma4PreTrainedModel): + """The Gemma 4 Vision Encoder.""" + + config = Gemma4VisionConfig + _can_record_outputs = { + "hidden_states": Gemma4VisionEncoderLayer, + "attentions": Gemma4VisionAttention, + } + + def __init__(self, config: Gemma4VisionConfig): + super().__init__(config) + self.patch_embedder = Gemma4VisionPatchEmbedder(config) + self.encoder = Gemma4VisionEncoder(config) + self.pooler = Gemma4VisionPooler(config) + + if self.config.standardize: + self.register_buffer("std_bias", torch.empty(self.config.hidden_size)) + self.register_buffer("std_scale", torch.empty(self.config.hidden_size)) + + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring(custom_intro="Encodes image pixels to soft tokens from patches.") + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_position_ids: torch.LongTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + r""" + pixel_values (`torch.FloatTensor` or `list[torch.FloatTensor]`): + The images to encode. Either a single `[batch, channels, height, width]` tensor + (all images same size) or a list of `[1, channels, height, width]` tensors (different sizes). 
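+ Values are expected in the `[0, 1]` range; the patch embedder rescales them to `[-1, 1]` internally
+ (see `Gemma4VisionPatchEmbedder.forward`), so no mean/std normalization should be applied beforehand.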
+ pixel_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`): + The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1). + """ + pooling_kernel_size = self.config.pooling_kernel_size + output_length = pixel_values.shape[-2] // (pooling_kernel_size * pooling_kernel_size) + + padding_positions = (pixel_position_ids == -1).all(dim=-1) + inputs_embeds = self.patch_embedder(pixel_values, pixel_position_ids, padding_positions) + output = self.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~padding_positions, # encoder expects True=valid, padding_positions is True=padding + pixel_position_ids=pixel_position_ids, + **kwargs, + ) + + hidden_states, pooler_mask = self.pooler( + hidden_states=output.last_hidden_state, + pixel_position_ids=pixel_position_ids, + padding_positions=padding_positions, + output_length=output_length, + ) + + # Strip padding tokens. pooler_mask is True = valid, False = padding. + hidden_states = hidden_states[pooler_mask] + + if self.config.standardize: + hidden_states = (hidden_states - self.std_bias) * self.std_scale + + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + +class Gemma4MultimodalEmbedder(Gemma3nMultimodalEmbedder): + def __init__( + self, + multimodal_config: Gemma4AudioConfig | Gemma4VisionConfig, + text_config: Gemma4TextConfig, + ): + # Audio tower may use a different output dimension (output_proj_dims) than the + # internal hidden_size. Use the tower-specific dimension if specified. + super().__init__(multimodal_config, text_config) + del self.embedding + del self.hard_embedding_norm + del self.soft_embedding_norm + del self.vocab_offset + del self.vocab_size + del self.embedding_post_projection_norm + + self.multimodal_hidden_size = getattr(multimodal_config, "output_proj_dims", multimodal_config.hidden_size) + self.embedding_pre_projection_norm = Gemma4RMSNorm(self.multimodal_hidden_size, eps=self.eps, with_scale=False) + + def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: + """Embeds token ids or soft tokens for multimodal content into language model space. + Args: + inputs_embeds: A torch.Tensor containing the soft tokens to embed. + Returns: + A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`. + """ + embs_normed = self.embedding_pre_projection_norm(inputs_embeds) + return self.embedding_projection(embs_normed) + + +# Identical as Gemma3 but modular can't resolve if we simply import. FIXME: @cyril +def token_type_ids_mask_function( + token_type_ids: torch.Tensor | None, + image_group_ids: torch.Tensor | None, +) -> Callable | None: + """ + This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, + not start and end indices. + """ + # Do not return an additional mask in this case + if token_type_ids is None: + return None + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + seq_length = image_group_ids.shape[-1] + + # clamp indices because with static cache they can go beyond `image_group_ids.shape[-1]` + q_idx_clamped = q_idx.clamp(max=seq_length - 1) + kv_idx_clamped = kv_idx.clamp(max=seq_length - 1) + + # Unmask if the q and kv come from same group which is not -1 (i.e. 
non-text) + q_group = image_group_ids[batch_idx, q_idx_clamped] + kv_group = image_group_ids[batch_idx, kv_idx_clamped] + q_group = torch.where(q_idx < seq_length, q_group, -1) + kv_group = torch.where(kv_idx < seq_length, kv_group, -1) + return (q_group == kv_group) & (q_group >= 0) + + return inner_mask + + +# Similar to Gemma3 but `sliding_mask_kwargs` and `mask_kwargs` are different and `token_type_ids->mm_token_type_ids` +def create_causal_mask_mapping( + config: PreTrainedConfig, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None, + position_ids: torch.Tensor | None, + mm_token_type_ids: torch.Tensor | None = None, + pixel_values: torch.FloatTensor | None = None, + is_training: bool = False, + is_first_iteration: bool | None = None, + **kwargs, +) -> dict: + """ + Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping + for all kinds of forward passes. Gemma4 uses a bidirectional mask for images. + + Uses `pixel_values` as an optional input to disambiguate edge cases. + """ + if is_training and mm_token_type_ids is None: + raise ValueError("`mm_token_type_ids` is required as a model input when training") + + mask_kwargs = { + "config": config.get_text_config(), + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + sliding_mask_kwargs = mask_kwargs.copy() + + # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized + # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other + # means). Determining prefill in that case requires checking data values, which is not compile-compatible. + is_first_iteration = ( + is_first_iteration + if is_first_iteration is not None + else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None) + ) + if mm_token_type_ids is not None and is_first_iteration: + # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to + # undo the causal masking) + + # First find where a new vision block starts. Vision tokens cannot attend to + # future vision tokens, but can attend to all prev tokens and to itself bidirectionally + is_vision = (mm_token_type_ids == 1) | (mm_token_type_ids == 2) + is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1) + is_prev_vision[..., 0] = False + new_vision_starts = is_vision & ~is_prev_vision + vision_group_ids = torch.cumsum(new_vision_starts.int(), dim=1) - 1 + vision_group_ids = torch.where(is_vision, vision_group_ids, -1) + sliding_mask_kwargs["or_mask_function"] = token_type_ids_mask_function( + mm_token_type_ids.to(inputs_embeds.device), vision_group_ids + ) + + return { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs), + } + + +@auto_docstring( + custom_intro=""" + The base Gemma 4 model comprising a vision backbone, an audio backbone, and a language model without a + language modeling head. 
+ """ +) +class Gemma4Model(Gemma3nModel): + def __init__(self, config: Gemma4Config): + super().__init__(config) + del self.vision_tower + del self.embed_vision + self.vision_tower = AutoModel.from_config(config.vision_config) if config.vision_config is not None else None + self.embed_vision = ( + Gemma4MultimodalEmbedder(config.vision_config, config.text_config) + if config.vision_config is not None + else None + ) + del self.audio_tower + del self.embed_audio + self.audio_tower = AutoModel.from_config(config.audio_config) if config.audio_config is not None else None + self.embed_audio = ( + Gemma4MultimodalEmbedder(config.audio_config, config.text_config) + if config.audio_config is not None + else None + ) + + @can_return_tuple + @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.") + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_position_ids: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*): + The patch positions as (x, y) coordinates in the image. Padding patches are indicated by (-1, -1). + """ + vision_outputs = self.vision_tower( + pixel_values=pixel_values, + pixel_position_ids=image_position_ids, + **kwargs, + ) + last_hidden_state = vision_outputs.last_hidden_state + vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state) + return vision_outputs + + @can_return_tuple + @auto_docstring(custom_intro="Projects the last hidden state from the vision encoder into language model space.") + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_position_ids: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + r""" + video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*): + 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding. + Passed through to the vision encoder for positional embedding computation. + """ + pixel_values_videos = pixel_values_videos.flatten(0, 1) + video_position_ids = video_position_ids.flatten(0, 1) + vision_outputs = self.vision_tower( + pixel_values=pixel_values_videos, + pixel_position_ids=video_position_ids, + **kwargs, + ) + last_hidden_state = vision_outputs.last_hidden_state + vision_outputs.pooler_output = self.embed_vision(inputs_embeds=last_hidden_state) + return vision_outputs + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + ) -> tuple[torch.BoolTensor, torch.BoolTensor, torch.BoolTensor]: + """ + Obtains mask for multimodal placeholders (replaced by soft tokens) and hard text tokens. + + Masks will be obtained from `mm_token_type_ids`, `input_ids`, or `inputs_embeds` as available and in that + precedence order. If passing `input_ids` or `inputs_embeds`, the image mask will be derived using + `config.image_token_id`. Same goes for audio and video masks + + Args: + input_ids: A tensor containing the hard token IDs from the text tokenizer. + inputs_embeds: A tensor containing the embeddings for all hard text tokens. 
+ + Returns: + image_mask, video_mask, audio_mask + """ + if input_ids is not None: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + special_audio_mask = input_ids == self.config.audio_token_id + else: + special_image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + ).all(-1) + special_video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + ).all(-1) + special_audio_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + ).all(-1) + + return special_image_mask, special_video_mask, special_audio_mask + + @merge_with_config_defaults + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + input_features_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + mm_token_type_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + image_position_ids: torch.LongTensor | None = None, + video_position_ids: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Gemma4ModelOutputWithPast: + r""" + input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`): + The attention mask for the input audio. + image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*): + 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding. + Passed through to the vision encoder for positional embedding computation. + video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*): + 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding. + Passed through to the vision encoder for positional embedding computation. 
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + image_mask, video_mask, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds) + multimodal_mask = image_mask | video_mask | audio_mask + + # Replace image id with PAD if the image token if OOV, to avoid index-errors + llm_input_ids = None + if inputs_embeds is None: + llm_input_ids = input_ids.clone() + llm_input_ids[multimodal_mask] = self.config.text_config.pad_token_id + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if self.config.get_text_config().hidden_size_per_layer_input: + pad_embedding = self.language_model.embed_tokens.weight[self.config.text_config.pad_token_id, :] + llm_inputs_embeds = torch.where(multimodal_mask[..., None], pad_embedding.view(1, 1, -1), inputs_embeds) + per_layer_inputs = self.language_model.get_per_layer_inputs(llm_input_ids, llm_inputs_embeds) + else: + per_layer_inputs = None + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values, image_position_ids, return_dict=True).pooler_output + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + + # Confirm the number of soft tokens from the vision tower matches the number of slots in the embeddings. + n_image_tokens = image_mask.sum() + image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[image_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features:" + f" {image_features.shape[0]}", + ) + + inputs_embeds = inputs_embeds.masked_scatter( + image_mask.to(inputs_embeds.device), image_features.to(inputs_embeds.device) + ) + + if pixel_values_videos is not None: + video_features = self.get_video_features( + pixel_values_videos, video_position_ids, return_dict=True + ).pooler_output + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + + # Confirm the number of soft tokens from the vision tower matches the number of slots in the embeddings. + n_video_tokens = video_mask.sum() + video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[video_mask].numel() == video_features.numel(), + f"Video features and video tokens do not match, tokens: {n_video_tokens}, features:" + f" {video_features.shape[0]}", + ) + + inputs_embeds = inputs_embeds.masked_scatter( + video_mask.to(inputs_embeds.device), video_features.to(inputs_embeds.device) + ) + + # Merge text and audio + if input_features is not None and input_features_mask is not None: + audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True) + audio_features = audio_output.pooler_output + audio_mask_from_encoder = audio_output.attention_mask # True = valid + + # Strip padding tokens: only keep real (non-padding) audio soft tokens. + # audio_mask_from_encoder is True for valid positions, False for padding tokens. + # This mirrors the vision encoder's padding stripping (see Gemma4VisionEncoder.forward). 
+ audio_features = audio_features[audio_mask_from_encoder] + + n_audio_tokens = audio_mask.sum() + audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[audio_mask].numel() == audio_features.numel(), + f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features:" + f" {audio_features.shape[0] * audio_features.shape[1]}", + ) + + inputs_embeds = inputs_embeds.masked_scatter( + audio_mask.to(inputs_embeds.device), audio_features.to(inputs_embeds.device) + ) + + # It may already have been prepared by, e.g., `generate` + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + if not isinstance(causal_mask_mapping := attention_mask, dict): + if self.config.get_text_config().use_bidirectional_attention == "vision": + # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs + causal_mask_mapping = create_causal_mask_mapping( + self.config, + inputs_embeds, + attention_mask, + past_key_values, + position_ids, + mm_token_type_ids, + pixel_values, + is_training=self.training, + ) + else: + # Smaller Gemma models use a conventional casual attention mask + causal_mask_mapping = create_masks_for_generate( + self.config, + inputs_embeds, + attention_mask, + past_key_values, + position_ids, + ) + + outputs = self.language_model( + per_layer_inputs=per_layer_inputs, + attention_mask=causal_mask_mapping, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + return_dict=True, + **kwargs, + ) + + return Gemma4ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + audio_hidden_states=audio_features if input_features is not None else None, + ) + + @can_return_tuple + @auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.") + def get_audio_features( + self, + input_features: torch.Tensor, + input_features_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Gemma4AudioModelOutput: + r""" + input_features (`torch.FloatTensor]` of shape `(num_images, seq_length, num_features)`): + The tensors corresponding to the input audio. + input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`): + The attention mask for the input audio. + """ + if self.audio_tower is None: + raise ValueError( + "Audio features were requested, but the model was initialized without an audio_config. " + "Cannot process audio without an audio tower and audio embedder." + ) + + audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True, **kwargs) + audio_outputs.pooler_output = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state) + + return audio_outputs + + +@auto_docstring( + custom_intro=""" + The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling + head. 
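The padding strip at the start of the audio merge above is plain boolean-mask indexing: indexing the encoder output with the validity mask keeps only the real audio soft tokens and flattens them to `(num_valid, hidden)`, ready for `masked_scatter`. A small sketch with assumed toy shapes:

```py
# A small sketch of the audio padding strip, with toy tensors (shapes assumed
# for illustration only).
import torch

audio_features = torch.arange(2 * 3 * 4, dtype=torch.float).reshape(2, 3, 4)  # (batch, seq, hidden)
valid = torch.tensor([[True, True, False],
                      [True, False, False]])                                  # True = real soft token

audio_features = audio_features[valid]
print(audio_features.shape)  # torch.Size([3, 4]), only the valid tokens remain
```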
+ """ +) +class Gemma4ForConditionalGeneration(Gemma3nForConditionalGeneration): + base_model_prefix = "model" + + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + input_features_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + image_position_ids: torch.LongTensor | None = None, + video_position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + mm_token_type_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Gemma4CausalLMOutputWithPast: + r""" + input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`): + The attention mask for the input audio. + image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*): + 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding. + Passed through to the vision encoder for positional embedding computation. + video_position_ids (`torch.LongTensor` of shape `(num_videos, num_frames, max_patches, 2)`, *optional*): + 2D patch position coordinates from the video processor, with `(-1, -1)` indicating padding. + Passed through to the vision encoder for positional embedding computation. + """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + input_features=input_features, + attention_mask=attention_mask, + input_features_mask=input_features_mask, + position_ids=position_ids, + past_key_values=past_key_values, + mm_token_type_ids=mm_token_type_ids, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + image_position_ids=image_position_ids, + video_position_ids=video_position_ids, + return_dict=True, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None: + logits = logits / final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * final_logit_softcapping + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.get_text_config().vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + + return Gemma4CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + audio_hidden_states=outputs.audio_hidden_states, + ) + + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_position_ids: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ): + r""" + image_position_ids (`torch.LongTensor` of shape `(batch_size, max_patches, 2)`, *optional*): + 2D patch position coordinates from the image processor, with `(-1, -1)` indicating padding. + Passed through to the vision encoder for positional embedding computation. + """ + return self.model.get_image_features(pixel_values, image_position_ids, **kwargs) + + @staticmethod + def create_masks_for_generate( + config: PreTrainedConfig, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None, + position_ids: torch.Tensor | None, + mm_token_type_ids: torch.Tensor | None = None, + is_first_iteration: bool | None = False, + **kwargs, + ) -> dict: + if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision": + # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs + return create_causal_mask_mapping( + config, + inputs_embeds, + attention_mask, + past_key_values, + position_ids, + mm_token_type_ids, + is_first_iteration=is_first_iteration, + **{k: v for k, v in kwargs.items() if k != "pixel_values"}, + ) + else: + # Smaller Gemma models use a conventional casual attention mask + return create_masks_for_generate( + config, inputs_embeds, attention_mask, past_key_values, position_ids, **kwargs + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + position_ids=None, + pixel_values=None, + pixel_values_videos=None, + input_features=None, + attention_mask=None, + input_features_mask=None, + token_type_ids=None, + use_cache=True, + logits_to_keep=None, + labels=None, + is_first_iteration=False, + **kwargs, + ): + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + is_first_iteration=is_first_iteration, + **kwargs, + ) + + # If we're in cached decoding stage, multimodal inputs are already cached and can be dropped + if is_first_iteration or not use_cache: + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_values_videos"] = pixel_values_videos + 
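A compact sketch of the final-logit soft-capping and the shifted, attention-mask-filtered cross-entropy loss computed in `forward` above, using toy values (the soft cap of 30.0 and the vocabulary size of 5 are assumed here purely for illustration):

```py
# Toy reproduction of the soft-capped, label-shifted loss above.
import torch
import torch.nn as nn

final_logit_softcapping = 30.0
logits = torch.randn(1, 4, 5)                  # (batch, seq, vocab)
labels = torch.tensor([[2, 3, 1, 4]])
attention_mask = torch.tensor([[0, 1, 1, 1]])  # left padding on the first position

# Soft cap: squashes logits smoothly into (-cap, cap) instead of hard clipping.
logits = torch.tanh(logits / final_logit_softcapping) * final_logit_softcapping

# Shift so position t predicts token t+1, then drop padded positions.
shift_logits = logits[..., :-1, :].float()
shift_labels = labels[..., 1:]
shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:]
shift_logits = shift_logits[shift_attention_mask != 0]
shift_labels = shift_labels[shift_attention_mask != 0]

loss = nn.CrossEntropyLoss()(shift_logits.view(-1, 5), shift_labels.view(-1))
```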
model_inputs["input_features"] = input_features + model_inputs["input_features_mask"] = input_features_mask + + return model_inputs + + +__all__ = [ + "Gemma4AudioModel", + "Gemma4ForCausalLM", + "Gemma4ForConditionalGeneration", + "Gemma4Model", + "Gemma4PreTrainedModel", + "Gemma4TextModel", + "Gemma4VisionModel", +] diff --git a/src/transformers/models/gemma4/processing_gemma4.py b/src/transformers/models/gemma4/processing_gemma4.py new file mode 100644 index 000000000000..d688250d0b36 --- /dev/null +++ b/src/transformers/models/gemma4/processing_gemma4.py @@ -0,0 +1,366 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import numpy as np + +from ...audio_utils import AudioInput +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput, make_nested_list_of_images +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import auto_docstring, is_vision_available, logging +from ...utils.import_utils import requires +from ...video_utils import VideoInput + + +if is_vision_available(): + from .image_processing_pil_gemma4 import Gemma4ImageProcessorKwargs, get_aspect_ratio_preserving_size + + +logger = logging.get_logger(__name__) + + +class Gemma4ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Gemma4ImageProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": True, + "return_mm_token_type_ids": True, + }, + "images_kwargs": { + "do_convert_rgb": True, + }, + "audio_kwargs": {}, + "videos_kwargs": {"return_metadata": True}, + } + + +@auto_docstring +@requires(backends=("vision",)) +class Gemma4Processor(ProcessorMixin): + def __init__( + self, + feature_extractor, + image_processor, + tokenizer, + video_processor, + chat_template=None, + image_seq_length: int = 280, + audio_seq_length: int = 750, + audio_ms_per_token: int = 40, + **kwargs, + ): + r""" + image_seq_length (`int`, *optional*, defaults to 280): + The number of soft tokens per image used for placeholder expansion. + audio_seq_length (`int`, *optional*, defaults to 750): + The maximum number of audio soft tokens per audio segment. Serves as an + upper-bound cap when dynamic audio token counts are computed. + audio_ms_per_token (`int`, *optional*, defaults to 40): + Milliseconds of audio per output soft token. Used to dynamically compute + the number of audio placeholder tokens as ``ceil(duration_ms / audio_ms_per_token)``. + The default of 40 comes from the SSCP convolution's 4× time reduction on 10ms frames. 
+ """ + self.image_seq_length = image_seq_length + self.image_token_id = tokenizer.image_token_id + self.boi_token = tokenizer.boi_token + self.eoi_token = tokenizer.eoi_token + self.image_token = tokenizer.image_token + + # FIXME: add the token to config and ask Ryan to re-upload + tokenizer.add_special_tokens({"additional_special_tokens": ["<|video|>"]}) + self.video_token = "<|video|>" + self.video_token_id = tokenizer.convert_tokens_to_ids(self.video_token) + + # Audio token handling, mirroring the vision pattern. + # audio_seq_length serves as the maximum cap on the number of audio soft tokens + # any single audio segment can produce. With dynamic audio tokens, the actual + # number of placeholders inserted per audio is computed from the audio duration. + self.audio_seq_length = audio_seq_length + # Milliseconds of audio per output soft token. The default of 40 comes from the + # SSCP convolution's 4× time reduction applied to 10ms mel spectrogram frames. + self.audio_ms_per_token = audio_ms_per_token + self.audio_token_id = getattr(tokenizer, "audio_token_id", None) + self.audio_token = getattr(tokenizer, "audio_token", None) + self.boa_token = getattr(tokenizer, "boa_token", None) + self.eoa_token = getattr(tokenizer, "eoa_token", None) + + super().__init__( + feature_extractor=feature_extractor, + image_processor=image_processor, + tokenizer=tokenizer, + video_processor=video_processor, + chat_template=chat_template, + **kwargs, + ) + + @auto_docstring + def __call__( + self, + images: ImageInput | None = None, + text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None, + audio: AudioInput | None = None, + videos: VideoInput | None = None, + **kwargs: Unpack[Gemma4ProcessorKwargs], + ) -> BatchFeature: + if text is None and images is None and audio is None and videos is None: + raise ValueError("Provide at least one of `text`, `images`, `audio`, or `videos`.") + + output_kwargs = self._merge_kwargs( + Gemma4ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise TypeError("Invalid input text. Please provide a string, or a list of strings") + + image_inputs = {} + if images is not None: + images = self.image_processor.fetch_images(images) + batched_images = make_nested_list_of_images(images) + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + + num_soft_tokens = image_inputs.pop("num_soft_tokens_per_image") + + # Create empty text to be replaced with placeholders + if not text: + text = [" ".join([self.image_token] * len(images)) for images in batched_images] + + if len(batched_images) != len(text): + raise ValueError( + f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})." + ) + + replacements = [f"{self.boi_token}{self.image_token * n}{self.eoi_token}" for n in num_soft_tokens] + replacements_iter = iter(replacements) + + # Expand image_token placeholders to per-image soft token sequences. 
+            # re.sub never re-scans replaced text, so it is safe
+            pattern = re.escape(self.image_token)
+            text = [re.sub(pattern, lambda _: next(replacements_iter), prompt) for prompt in text]
+
+        # Process video inputs in the same way
+        video_inputs = {}
+        if videos is not None:
+            video_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
+            num_video_tokens = video_inputs.pop("num_soft_tokens_per_video")
+
+            # If the user has not requested video metadata, pop it so it isn't returned
+            if not kwargs.get("return_metadata"):
+                video_metadata = video_inputs.pop("video_metadata")
+            else:
+                video_metadata = video_inputs["video_metadata"]
+
+            video_replacements = []
+            for metadata, n_tokens in zip(video_metadata, num_video_tokens):
+                if metadata.fps is None:
+                    logger.warning_once(
+                        "Gemma 4 requires frame timestamps to construct prompts, but the `fps` of the input video "
+                        "could not be inferred. Probably `video_metadata` was missing from inputs and you passed "
+                        "pre-sampled frames. Defaulting to `fps=24`. Please provide `video_metadata` for more "
+                        "accurate results."
+                    )
+                metadata.fps = 24 if metadata.fps is None else metadata.fps
+                # mm:ss format for timestamps
+                timestamp_str = [
+                    f"{int(seconds // 60):02d}:{int(seconds % 60):02d}" for seconds in metadata.timestamps
+                ]
+                video_replacements.append(
+                    " ".join(
+                        [f"{t} {self.boi_token}{self.video_token * n_tokens}{self.eoi_token}" for t in timestamp_str]
+                    )
+                )
+
+            video_replacements = iter(video_replacements)
+            pattern = re.escape(self.video_token)
+            text = [re.sub(pattern, lambda _: next(video_replacements), prompt) for prompt in text]
+
+        # Process audio inputs
+        audio_inputs = {}
+        if audio is not None:
+            if self.audio_token is None or self.boa_token is None or self.eoa_token is None:
+                raise ValueError(
+                    "Audio inputs were provided, but the tokenizer does not have an `audio_token` defined."
+                )
+
+            # Normalize audio input to a list of waveforms
+            if isinstance(audio, np.ndarray) and audio.ndim == 1:
+                audio = [audio]
+
+            # TODO: Add tests for audio-only processor inputs.
+            if not text:
+                text = [self.audio_token] * len(audio)
+
+            # Dynamic audio token expansion without padding:
+            # * Extract audio features with the feature extractor;
+            # * Compute precise per-audio token counts from the waveform duration;
+            # * Generate the full audio token sequence for each computed audio length;
+            # * Expand text prompts with the full audio token sequences.
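The placeholder-expansion idiom used throughout `__call__` above is an iterator of per-item replacements consumed by `re.sub` with a lambda, so each occurrence of the placeholder token gets its own expansion. A small sketch with made-up tokens:

```py
# A small sketch of the re.sub iterator idiom, with toy placeholder tokens.
import re

image_token, boi, eoi = "<img>", "<boi>", "<eoi>"
num_soft_tokens = [2, 3]  # e.g. one small and one larger image

replacements = [f"{boi}{image_token * n}{eoi}" for n in num_soft_tokens]
replacements_iter = iter(replacements)

prompt = "First: <img> Second: <img>"
pattern = re.escape(image_token)
expanded = re.sub(pattern, lambda _: next(replacements_iter), prompt)
print(expanded)
# First: <boi><img><img><eoi> Second: <boi><img><img><img><eoi>
```

Because `re.sub` never re-scans replaced text, the image tokens inside each replacement are not expanded again.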
+ audio_kwargs = output_kwargs.get("audio_kwargs", {}) + audio_inputs = self.feature_extractor(audio, **audio_kwargs) + sampling_rate = self.feature_extractor.sampling_rate + num_audio_tokens = [self._compute_audio_num_tokens(a, sampling_rate) for a in audio] + replacements = [f"{self.boa_token}{self.audio_token * n}{self.eoa_token}" for n in num_audio_tokens] + replacements_iter = iter(replacements) + audio_pattern = re.escape(self.audio_token) + text = [re.sub(audio_pattern, lambda _: next(replacements_iter), prompt) for prompt in text] + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + + # Check special tokens for all active modalities + active_modalities = [] + if images is not None: + active_modalities.append("image") + if videos is not None: + active_modalities.append("video") + if audio is not None: + active_modalities.append("audio") + if active_modalities: + self._check_special_mm_tokens(text, text_inputs, modalities=active_modalities) + + if return_mm_token_type_ids: + text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"]) + + return BatchFeature( + data={**text_inputs, **image_inputs, **audio_inputs, **video_inputs}, + tensor_type=return_tensors, + ) + + def _compute_audio_num_tokens(self, audio_waveform, sampling_rate: int) -> int: + """Compute the number of audio soft tokens for a single waveform. + + Replicates the exact sequence-length arithmetic of the audio encoder + so that the processor inserts the correct number of placeholder tokens. + The computation mirrors: + + 1. Mel framing via ``_unfold`` in ``Gemma4AudioFeatureExtractor`` + 2. Two ``Conv2d`` subsampling layers in ``Gemma4AudioSubSampleConvProjection`` + (each: kernel=3, stride=2, semicausal padding top=1, bottom=1) + + The result is capped at ``self.audio_seq_length`` (the configured maximum). + + Args: + audio_waveform: A 1-D numpy array or list containing the raw audio samples. + sampling_rate: The sampling rate of the audio waveform in Hz. + + Returns: + The number of audio soft tokens to insert as placeholders. + """ + num_samples = len(audio_waveform) + + # Step 1: Mel frames (matches feature_extraction_gemma4.py _unfold) + frame_length = int(round(sampling_rate * 20.0 / 1000.0)) # 320 @ 16kHz + hop_length = int(round(sampling_rate * 10.0 / 1000.0)) # 160 @ 16kHz + frame_size_for_unfold = frame_length + 1 # 321 + + # The feature extractor prepends (frame_length // 2) zero samples as + # semicausal time-padding before the unfold. We must include this to + # match the actual number of mel frames it produces. + pad_left = frame_length // 2 # 160 @ 16kHz + padded_samples = num_samples + pad_left + num_mel_frames = (padded_samples - frame_size_for_unfold) // hop_length + 1 + + if num_mel_frames <= 0: + return 0 + + # Step 2: Two SSCP conv layers (kernel=3, stride=2, semicausal pad top=1, bottom=1) + # Each layer: T_out = (T_in + pad_top + pad_bottom - kernel) // stride + 1 + t = num_mel_frames + for _ in range(2): + t_padded = t + 2 # pad_top=1, pad_bottom=1 + t = (t_padded - 3) // 2 + 1 + + # Cap at the configured maximum + return min(t, self.audio_seq_length) + + def _get_num_multimodal_tokens(self, image_sizes=None, audio_lengths=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. 
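As a sanity check of the arithmetic in `_compute_audio_num_tokens`, a 2-second clip at 16 kHz works out to 50 soft tokens, i.e. one token per 40 ms, matching the `audio_ms_per_token` default:

```py
# A worked example of the token-count arithmetic above for a 2-second clip at 16 kHz.
sampling_rate = 16_000
num_samples = 2 * sampling_rate                       # 32_000 samples

frame_length = round(sampling_rate * 20.0 / 1000.0)   # 320
hop_length = round(sampling_rate * 10.0 / 1000.0)     # 160
pad_left = frame_length // 2                          # 160
num_mel_frames = (num_samples + pad_left - (frame_length + 1)) // hop_length + 1  # 199

t = num_mel_frames
for _ in range(2):                                    # two SSCP conv layers
    t = (t + 2 - 3) // 2 + 1                          # 199 -> 100 -> 50

print(t)  # 50 soft tokens, i.e. one token per 40 ms of audio
```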
+ + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + audio_lengths (`list[int]`, *optional*): + The lengths of audio inputs in number of samples. Used to dynamically + compute per-audio token counts. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + + images_kwargs = Gemma4ProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + patch_size = images_kwargs.get("patch_size", None) or self.image_processor.patch_size + pooling_kernel_size = ( + images_kwargs.get("pooling_kernel_size", None) or self.image_processor.pooling_kernel_size + ) + max_soft_tokens = images_kwargs.get("max_soft_tokens", None) or self.image_processor.max_soft_tokens + + max_patches = max_soft_tokens * pooling_kernel_size**2 + + vision_data = {} + if image_sizes is not None: + num_image_tokens = [] + for image_size in image_sizes: + target_h, target_w = get_aspect_ratio_preserving_size( + height=image_size[0], + width=image_size[1], + patch_size=patch_size, + max_patches=max_patches, + pooling_kernel_size=pooling_kernel_size, + ) + patch_height = target_h // patch_size + patch_width = target_w // patch_size + num_image_tokens.append(patch_height * patch_width // pooling_kernel_size**2) + + num_image_patches = [1] * len(image_sizes) + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + if audio_lengths is not None: + # Dynamically compute per-audio token counts from sample lengths. + # audio_lengths are in number of samples; assume default sampling rate. + sampling_rate = getattr(self.feature_extractor, "sampling_rate", 16_000) + num_audio_tokens = [ + self._compute_audio_num_tokens(np.zeros(length), sampling_rate) for length in audio_lengths + ] + vision_data.update({"num_audio_tokens": num_audio_tokens}) + + return MultiModalData(**vision_data) + + @property + def model_input_names(self): + model_input_names = super().model_input_names + model_input_names = [ + name + for name in model_input_names + if name not in ["num_soft_tokens_per_image", "num_soft_tokens_per_video"] + ] + + # Include audio feature extractor input names if available + if self.feature_extractor is not None: + feature_extractor_input_names = self.feature_extractor.model_input_names + model_input_names.extend([name for name in feature_extractor_input_names if name not in model_input_names]) + + return model_input_names + ["mm_token_type_ids"] + + +__all__ = ["Gemma4Processor"] diff --git a/src/transformers/models/gemma4/video_processing_gemma4.py b/src/transformers/models/gemma4/video_processing_gemma4.py new file mode 100644 index 000000000000..d867d31fcd7e --- /dev/null +++ b/src/transformers/models/gemma4/video_processing_gemma4.py @@ -0,0 +1,237 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
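For images, `_get_num_multimodal_tokens` follows the same patch arithmetic. The sketch below reuses one of the resize cases verified in the tests further down (patch size 16, a 256-patch budget, no pooling, a 512×512 input resized to 256×256), so the numbers are illustrative rather than the model defaults:

```py
# A hedged sketch of the image soft-token count, using a resize case from the
# tests below rather than the default Gemma 4 configuration.
patch_size, pooling_kernel_size = 16, 1
target_h, target_w = 256, 256            # output of get_aspect_ratio_preserving_size for a 512x512 input

patch_height = target_h // patch_size    # 16
patch_width = target_w // patch_size     # 16
num_image_tokens = patch_height * patch_width // pooling_kernel_size**2
print(num_image_tokens)                  # 256 soft tokens for this configuration
```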
+ + +import torch + +from ...image_processing_utils import BatchFeature +from ...processing_utils import Unpack, VideosKwargs +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + is_vision_available, + logging, +) +from ...video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor +from ...video_utils import VideoInput +from .image_processing_gemma4 import _SUPPORTED_SOFT_TOKENS, get_aspect_ratio_preserving_size + + +if is_vision_available(): + from ...image_utils import PILImageResampling + +if is_torch_available(): + import torch + +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F +elif is_torchvision_available(): + from torchvision.transforms import functional as F + + +logger = logging.get_logger(__name__) + + +class Gemma4VideoProcessorKwargs(VideosKwargs, total=False): + """ + patch_size (`int`, *optional*): + Size of each image patch in pixels. + max_soft_tokens (`int`, *optional*): + Maximum number of soft (vision) tokens per video frame. + Must be one of {70, 140, 280, 560, 1120}. + pooling_kernel_size (`int`, *optional*): + Spatial pooling kernel size applied after patchification. + """ + + patch_size: int + max_soft_tokens: int + pooling_kernel_size: int + + +def convert_video_to_patches(video: "torch.Tensor", patch_size: int) -> "torch.Tensor": + """ + Convert 4D tensor video of shape (num_frames, num_channels, height, width) into 3D tensor of patches of shape + (num_frames, num_patches_height * num_patches_width, patch_size * patch_size * num_channels). + """ + num_frames, num_channels, height, width = video.shape + num_patches_height = height // patch_size + num_patches_width = width // patch_size + patched_video = video.reshape( + num_frames, num_channels, num_patches_height, patch_size, num_patches_width, patch_size + ) + patched_video = patched_video.permute(0, 2, 4, 3, 5, 1) + patched_video = patched_video.reshape(num_frames, num_patches_height * num_patches_width, -1) + return patched_video + + +def pad_to_max_patches( + video: "torch.Tensor", positions: "torch.Tensor", target_length: int +) -> tuple["torch.Tensor", "torch.Tensor"]: + """ + Pad the video along to max number of patches + """ + current_length = video.shape[1] + padding_length = target_length - current_length + if padding_length > 0: + padding = [0, 0, 0, padding_length, 0, 0] + pos_padding = (0, 0, 0, padding_length, 0, 0) + video = torch.nn.functional.pad(video, padding, mode="constant", value=0) + positions = torch.nn.functional.pad(positions, pos_padding, mode="constant", value=-1) + return video, positions + + +@add_start_docstrings( + "Constructs a Gemma4 video processor that samples frames from videos for use with the Gemma4 model.", + BASE_VIDEO_PROCESSOR_DOCSTRING, +) +class Gemma4VideoProcessor(BaseVideoProcessor): + resample = PILImageResampling.BICUBIC + image_mean = [0.0, 0.0, 0.0] + image_std = [1.0, 1.0, 1.0] + size = None + default_to_square = True + do_convert_rgb = True + do_resize = True + do_rescale = True + do_normalize = True + num_frames = 32 + do_sample_frames = True + patch_size = 16 + max_soft_tokens = 70 + pooling_kernel_size = 3 + valid_kwargs = Gemma4VideoProcessorKwargs + model_input_names = ["pixel_values_videos", "video_position_ids"] + + def __init__(self, **kwargs: Unpack[Gemma4VideoProcessorKwargs]): + super().__init__(**kwargs) + + if self.max_soft_tokens not in _SUPPORTED_SOFT_TOKENS: + raise 
ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {self.max_soft_tokens}.") + + def _validate_preprocess_kwargs(self, **kwargs): + # Gemma4 uses aspect_ratio_preserving_resize driven by patch_size, + # max_soft_tokens, and pooling_kernel_size — not the standard `size` + # parameter. Temporarily disable do_resize so the base validation + # doesn't require `size` to be set. + kwargs["do_resize"] = False + super()._validate_preprocess_kwargs(**kwargs) + + def aspect_ratio_preserving_resize( + self, + video: torch.Tensor, + patch_size: int, + max_patches: int, + pooling_kernel_size: int, + resample: F.InterpolationMode, + ) -> torch.Tensor: + height, width = video.shape[-2], video.shape[-1] + target_height, target_width = get_aspect_ratio_preserving_size( + height=height, + width=width, + patch_size=patch_size, + max_patches=max_patches, + pooling_kernel_size=pooling_kernel_size, + ) + + if target_height == height and target_width == width: + return video + + return F.resize( + video, + size=[target_height, target_width], + interpolation=resample, + antialias=True, + ) + + def preprocess( + self, + videos: VideoInput, + **kwargs: Unpack[Gemma4VideoProcessorKwargs], + ) -> BatchFeature: + return super().preprocess(videos, **kwargs) + + def _preprocess( + self, + videos: list["torch.Tensor"], + do_resize: bool, + resample: "F.InterpolationMode | int | None", + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: float | list[float] | None, + image_std: float | list[float] | None, + return_tensors: str | TensorType | None, + patch_size: int | None = None, + max_soft_tokens: int | None = None, + pooling_kernel_size: int | None = None, + **kwargs, + ) -> BatchFeature: + if max_soft_tokens not in _SUPPORTED_SOFT_TOKENS: + raise ValueError(f"`max_soft_tokens` must be one of {_SUPPORTED_SOFT_TOKENS}, got {max_soft_tokens}.") + + max_patches = max_soft_tokens * pooling_kernel_size**2 + + pixel_values = [] + position_ids = [] + num_soft_tokens_per_video = [] + num_frames = 1 + + for video in videos: + if do_resize: + video = self.aspect_ratio_preserving_resize( + video=video, + patch_size=patch_size, + max_patches=max_patches, + pooling_kernel_size=pooling_kernel_size, + resample=resample, + ) + + video = self.rescale_and_normalize(video, do_rescale, rescale_factor, do_normalize, image_mean, image_std) + + num_frames = video.shape[0] + patch_height = video.shape[-2] // patch_size + patch_width = video.shape[-1] // patch_size + patches = convert_video_to_patches(video, patch_size) + num_soft_tokens_per_video.append(patches.shape[1] // pooling_kernel_size**2) + + device = video.device + patch_grid = torch.meshgrid( + torch.arange(patch_width, device=device), + torch.arange(patch_height, device=device), + indexing="xy", + ) + stacked_grid = torch.stack(patch_grid, dim=-1) + real_positions = stacked_grid.reshape(patches.shape[1], 2) + real_positions = real_positions[None, ...].repeat(num_frames, 1, 1) + + patches, positions = pad_to_max_patches(patches, real_positions, max_patches) + pixel_values.append(patches) + position_ids.append(positions) + + # Stack into batch tensors + pixel_values = torch.stack(pixel_values, dim=0) # (num_videos, num_frames, max_patches, patch_pixels) + position_ids = torch.stack(position_ids, dim=0) # (num_videos, num_frames, max_patches, 2) + + data = { + "pixel_values_videos": pixel_values, + "video_position_ids": position_ids, + "num_soft_tokens_per_video": num_soft_tokens_per_video, + } + return BatchFeature(data=data, 
tensor_type=return_tensors) + + +__all__ = ["Gemma4VideoProcessor"] diff --git a/src/transformers/pipelines/any_to_any.py b/src/transformers/pipelines/any_to_any.py index bed637cbb381..4ae91d5a176f 100644 --- a/src/transformers/pipelines/any_to_any.py +++ b/src/transformers/pipelines/any_to_any.py @@ -13,6 +13,7 @@ # limitations under the License. import enum +import re from typing import Any, Union, overload import numpy as np @@ -159,6 +160,7 @@ def _sanitize_parameters( continue_final_message=None, skip_special_tokens=None, generation_mode=None, + processor_kwargs=None, **kwargs: Unpack[ProcessingKwargs], ): forward_kwargs = {} @@ -171,12 +173,15 @@ def _sanitize_parameters( preprocess_params["timeout"] = timeout if continue_final_message is not None: preprocess_params["continue_final_message"] = continue_final_message + if processor_kwargs is not None: + preprocess_params["processor_kwargs"] = processor_kwargs # Forward kwargs forward_kwargs["generate_kwargs"] = generate_kwargs or {} if generation_mode is not None and generation_mode != "text": forward_kwargs["generate_kwargs"]["generation_mode"] = generation_mode - if kwargs.get("load_audio_from_video"): + # Qwen-Omni models need to know the origin of audio, to align mm position ids + if kwargs.get("load_audio_from_video") and re.search(r"qwen\domni", self.model.__class__.__name__.lower()): forward_kwargs["generate_kwargs"]["use_audio_in_video"] = True if stop_sequence is not None: if isinstance(stop_sequence, str): @@ -359,8 +364,21 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **p if continue_final_message is None: continue_final_message = inputs.messages[-1]["role"] == "assistant" + # Processor kwargs are passed separately from jinja kwargs to chat template + # but it was added only in https://github.com/huggingface/transformers/pull/44881 + processor_kwargs = processing_kwargs.pop("processor_kwargs", None) or {} + + chat_template_kwargs = { + "continue_final_message": continue_final_message, + "return_tensors": "pt", + "tokenize": True, + "return_dict": True, + "add_generation_prompt": not continue_final_message, + "processor_kwargs": processor_kwargs, + **processing_kwargs, + } + # Handle Mistral tokenizer which does not accept processing kwargs - chat_template_kwargs = {"add_generation_prompt": not continue_final_message, **processing_kwargs} if self.processor.tokenizer.__class__.__name__ == "MistralCommonBackend": chat_template_kwargs = { k: v for k, v in chat_template_kwargs.items() if k in ["padding", "truncation", "max_length"] @@ -368,10 +386,6 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **p model_inputs = self.processor.apply_chat_template( inputs.messages, - continue_final_message=continue_final_message, - return_tensors="pt", - tokenize=True, - return_dict=True, **chat_template_kwargs, ).to(dtype=self.dtype) model_inputs["text"] = inputs @@ -385,16 +399,15 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **p inputs = inputs.copy() # avoid in-place changes if users passed dict text = inputs.pop("text") - # Feature extractor do not load audio files and expect a decode array + # Feature extractor do not load audio files and expect a decoded array if inputs.get("audio", None) is not None and hasattr(self.processor, "feature_extractor"): inputs["audio"] = self.processor.feature_extractor.fetch_audio(inputs["audio"]) # If batched text inputs, we set padding to True unless specified otherwise + processor_kwargs = 
processing_kwargs.pop("processor_kwargs", None) or processing_kwargs if isinstance(text, (list, tuple)) and len(text) > 1: - processing_kwargs.setdefault("padding", True) - - # Multimodal data is loaded in preprocessors so we pass all ipnuts directly to `self.processor` - model_inputs = self.processor(text=text, **inputs, return_tensors="pt", **processing_kwargs).to( + processor_kwargs.setdefault("padding", True) + model_inputs = self.processor(text=text, **inputs, return_tensors="pt", **processor_kwargs).to( dtype=self.dtype ) model_inputs["text"] = text @@ -432,6 +445,8 @@ def postprocess( # Decode inputs and outputs the same way to remove input text from generated text if present skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else True + if getattr(self.tokenizer, "response_schema", False): + skip_special_tokens = False generation_mode = postprocess_kwargs["generation_mode"] or "text" if generation_mode == "image" and hasattr(self.model, "decode_image_tokens"): generated_sequence = self.model.decode_image_tokens(generated_sequence.to(self.model.device)) diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index 793f23a78a7d..7d28b91ab2ab 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -140,6 +140,7 @@ def _sanitize_parameters( stop_sequence=None, continue_final_message=None, skip_special_tokens=None, + processor_kwargs=None, **kwargs: Unpack[ProcessingKwargs], ): forward_kwargs = {} @@ -152,6 +153,8 @@ def _sanitize_parameters( preprocess_params["timeout"] = timeout if continue_final_message is not None: preprocess_params["continue_final_message"] = continue_final_message + if processor_kwargs is not None: + preprocess_params["processor_kwargs"] = processor_kwargs # Forward kwargs if generate_kwargs is not None: @@ -327,16 +330,34 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **p # because very few models support multiple separate, consecutive assistant messages if continue_final_message is None: continue_final_message = inputs.messages[-1]["role"] == "assistant" + + # Processor kwargs are passed separately from jinja kwargs to chat template + # but it was added only in https://github.com/huggingface/transformers/pull/44881 + processor_kwargs = processing_kwargs.pop("processor_kwargs", None) or {} + + chat_template_kwargs = { + "continue_final_message": continue_final_message, + "return_tensors": "pt", + "tokenize": True, + "return_dict": True, + "add_generation_prompt": not continue_final_message, + "processor_kwargs": processor_kwargs, + **processing_kwargs, + } + + # Handle Mistral tokenizer which does not accept processing kwargs + if self.processor.tokenizer.__class__.__name__ == "MistralCommonBackend": + chat_template_kwargs = { + k: v for k, v in chat_template_kwargs.items() if k in ["padding", "truncation", "max_length"] + } + model_inputs = self.processor.apply_chat_template( inputs.messages, - add_generation_prompt=not continue_final_message, - continue_final_message=continue_final_message, - return_tensors="pt", - tokenize=True, - return_dict=True, + **chat_template_kwargs, ).to(dtype=self.dtype) model_inputs["text"] = inputs return model_inputs + # In case we only have text inputs if isinstance(inputs, (list, tuple, str)): images = None @@ -348,9 +369,10 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **p inputs_text = inputs["text"] # if batched text 
inputs, we set padding to True unless specified otherwise + processor_kwargs = processing_kwargs.pop("processor_kwargs", None) or processing_kwargs if isinstance(text, (list, tuple)) and len(text) > 1: - processing_kwargs.setdefault("padding", True) - model_inputs = self.processor(images=images, text=text, return_tensors="pt", **processing_kwargs).to( + processor_kwargs.setdefault("padding", True) + model_inputs = self.processor(images=images, text=text, return_tensors="pt", **processor_kwargs).to( dtype=self.dtype ) @@ -393,6 +415,8 @@ def postprocess( # Decode inputs and outputs the same way to remove input text from generated text if present skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else True + if getattr(self.tokenizer, "response_schema", False): + skip_special_tokens = False generated_texts = self.processor.post_process_image_text_to_text( generated_sequence, skip_special_tokens=skip_special_tokens, **postprocess_kwargs ) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index a7a25e234e09..6a0b2966d07c 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -450,6 +450,8 @@ def postprocess( split_keys[k] = v.numpy().tolist() skip_special_tokens = skip_special_tokens if skip_special_tokens is not None else True + if getattr(self.tokenizer, "response_schema", False): + skip_special_tokens = False for idx, sequence in enumerate(generated_sequence): if return_type == ReturnType.TENSORS: record = {"generated_token_ids": sequence} diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index b238b8b17031..8a73a68b0cf5 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1832,6 +1832,9 @@ def apply_chat_template( batch_audios.append(load_audio(fname, sampling_rate=sampling_rate)) else: for fname in video_fnames: + # This updates the template in-place and adds audio entry + # to ensure `audio` token is added by jinja + message["content"].append({"type": "audio"}) batch_audios.append(load_audio(fname, sampling_rate=sampling_rate)) # Currently all processors can accept nested list of batches, but not flat list of visuals @@ -1839,10 +1842,8 @@ def apply_chat_template( batch_images.append(images) batch_videos.append(videos) - template_kwargs = { - **self.tokenizer.special_tokens_map, - **kwargs, - } # kwargs overwrite special tokens if both are present + # `kwargs` overwrite special tokens if both are present + template_kwargs = {**self.tokenizer.special_tokens_map, **kwargs} prompt, generation_indices = render_jinja_template( conversations=conversations, tools=tools, diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index fdc730df55c8..371e80168820 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1617,9 +1617,9 @@ def set_config_for_less_flaky_test(config): # norm layers (layer/group norm, etc.) could cause flaky tests when the tensors have very small variance. 
# (We don't need the original epsilon values to check eager/sdpa matches) - attrs = ["text_config", "vision_config", "text_encoder", "audio_encoder", "decoder"] + attrs = ["text_config", "vision_config", "audio_config", "text_encoder", "audio_encoder", "decoder"] for attr in attrs: - if hasattr(config, attr): + if hasattr(config, attr) and getattr(config, attr) is not None: for target_attr in target_attrs: setattr(getattr(config, attr), target_attr, 1.0) diff --git a/src/transformers/utils/chat_parsing_utils.py b/src/transformers/utils/chat_parsing_utils.py index 1035d65b094a..44983cb33a25 100644 --- a/src/transformers/utils/chat_parsing_utils.py +++ b/src/transformers/utils/chat_parsing_utils.py @@ -26,6 +26,26 @@ jmespath = None +def _gemma4_json_to_json(text: str) -> str: + """Convert Gemma4 tool call format (unquoted keys, ``<|"|>`` string delimiters) to valid JSON.""" + strings = [] + + def _capture(m): + strings.append(m.group(1)) + return f"\x00{len(strings) - 1}\x00" + + # Grab the inside of gemma-quotes and store them for later + text = re.sub(r'<\|"\|>(.*?)<\|"\|>', _capture, text, flags=re.DOTALL) + # Add quotes to the bare keys elsewhere + text = re.sub(r"(?<=[{,])(\w+):", r'"\1":', text) + + # Put the inside of the quotes back afterwards + for i, s in enumerate(strings): + text = text.replace(f"\x00{i}\x00", json.dumps(s)) + + return text + + def _parse_re_match(node_match: re.Match) -> dict | str: # If the regex has named groups, return a dict of those groups if node_match.groupdict(): @@ -126,6 +146,13 @@ def recursive_parse( # a substring to parse, if needed. if "x-parser" in node_schema: parser = node_schema["x-parser"] + if parser == "gemma4-tool-call": + if not isinstance(node_content, str): + raise TypeError( + f"Node has Gemma4 tool call parser but got non-string input: {node_content}\nSchema: {node_schema}" + ) + node_content = _gemma4_json_to_json(node_content) + parser = "json" # fall through to the JSON parser below - don't add an elif! 
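An illustrative round trip through `_gemma4_json_to_json` above; the exact whitespace conventions of the real Gemma 4 tool-call format may differ, but this input exercises both rewrites (gemma-quoted strings and bare keys):

```py
# Illustrative input for the converter above; the string assumes the function
# defined in this file is in scope.
import json

raw = '{name: <|"|>get_weather<|"|>,arguments: {city: <|"|>Paris<|"|>}}'
converted = _gemma4_json_to_json(raw)
print(converted)              # {"name": "get_weather","arguments": {"city": "Paris"}}
print(json.loads(converted))  # {'name': 'get_weather', 'arguments': {'city': 'Paris'}}
```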
if parser == "json": if not isinstance(node_content, str): raise TypeError( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 055b852be0b3..15df7036eb35 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -997,32 +997,35 @@ def _prepare_model_kwargs(model_inputs, signature): # Prepare padding on common inputs (pad length 32) input_ids = inputs_dict["input_ids"] attention_mask = inputs_dict["attention_mask"] - token_type_ids = inputs_dict.get("token_type_ids", None) pad_token_id = getattr(config.get_text_config(decoder=True), "pad_token_id", None) or 0 pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:]) padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id - padded_input_ids = torch.cat((padding, input_ids), dim=1) - padded_attention_mask = torch.cat( + + padded_inputs_dict = copy.deepcopy(inputs_dict) + padded_inputs_dict["input_ids"] = torch.cat((padding, input_ids), dim=1) + padded_inputs_dict["attention_mask"] = torch.cat( (torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1 ) - if token_type_ids is not None: - padded_token_type_ids = torch.cat( + if inputs_dict.get("token_type_ids") is not None: + padded_inputs_dict["token_type_ids"] = torch.cat( ( # Assumption: `0` is a good default value for padding token type ids torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), - token_type_ids, + inputs_dict["token_type_ids"], + ), + dim=1, + ) + + if inputs_dict.get("mm_token_type_ids") is not None: + padded_inputs_dict["mm_token_type_ids"] = torch.cat( + ( + # `0` is a default value of text-modality type ids + torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), + inputs_dict["mm_token_type_ids"], ), dim=1, ) - else: - padded_token_type_ids = None - # Get output logits from inputs with left-padding (pad length 32) - padded_inputs_dict = copy.deepcopy(inputs_dict) - padded_inputs_dict["input_ids"] = padded_input_ids - padded_inputs_dict["attention_mask"] = padded_attention_mask - if padded_token_type_ids is not None: - padded_inputs_dict["token_type_ids"] = padded_token_type_ids padded_inputs_dict.update(padded_custom_inputs) model_kwargs_with_padding = _prepare_model_kwargs(padded_inputs_dict, signature) diff --git a/tests/models/gemma4/__init__.py b/tests/models/gemma4/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/gemma4/test_image_processing_gemma4.py b/tests/models/gemma4/test_image_processing_gemma4.py new file mode 100644 index 000000000000..c114025f0a63 --- /dev/null +++ b/tests/models/gemma4/test_image_processing_gemma4.py @@ -0,0 +1,247 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
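The left-padding scheme exercised by the generation test above can be summarized with a toy batch: ids are padded with the pad token, while the attention mask and `mm_token_type_ids` are padded with zeros, 0 being the text-modality type id:

```py
# A minimal sketch of the left-padding used in the generation test, with toy values.
import torch

pad_token_id, pad_len = 0, 4
input_ids = torch.tensor([[5, 6, 7]])
attention_mask = torch.ones_like(input_ids)
mm_token_type_ids = torch.tensor([[0, 1, 1]])  # last two tokens are image soft tokens

pad_ids = torch.full((1, pad_len), pad_token_id)
zeros = torch.zeros((1, pad_len), dtype=torch.long)

padded_input_ids = torch.cat((pad_ids, input_ids), dim=1)
padded_attention_mask = torch.cat((zeros, attention_mask), dim=1)
padded_mm_token_type_ids = torch.cat((zeros, mm_token_type_ids), dim=1)
```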
+ +import unittest + +import numpy as np +from parameterized import parameterized + +from transformers.models.gemma4.image_processing_pil_gemma4 import get_aspect_ratio_preserving_size +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + if is_torchvision_available(): + pass + + +class Gemma4ImageProcessingTester: + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + min_resolution=30, + max_resolution=400, + do_resize=True, + do_normalize=False, + image_mean=None, + image_std=None, + do_convert_rgb=True, + patch_size=6, + max_soft_tokens=70, + pooling_kernel_size=1, + ): + super().__init__() + image_mean = image_mean if image_mean is not None else [0.0, 0.0, 0.0] + image_std = image_std if image_std is not None else [1.0, 1.0, 1.0] + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + self.patch_size = patch_size + self.max_soft_tokens = max_soft_tokens + self.pooling_kernel_size = pooling_kernel_size + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, + "patch_size": self.patch_size, + "max_soft_tokens": self.max_soft_tokens, + "pooling_kernel_size": self.pooling_kernel_size, + } + + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + def expected_output_image_shape(self, images=None): + """Return the expected per-image output shape: (max_patches, patch_pixels).""" + max_patches = self.max_soft_tokens * self.pooling_kernel_size**2 + # Images are always converted to RGB (3 channels) before patchification + patch_pixels = self.patch_size**2 * 3 + return max_patches, patch_pixels + + +@require_torch +@require_vision +class Gemma4ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + def setUp(self): + super().setUp() + self.image_processor_tester = Gemma4ImageProcessingTester(self) + + @unittest.skip("Gemma4 patchification requires RGB (3-channel) images; 4-channel inputs are unsupported.") + def test_call_numpy_4_channels(self): + pass + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + """Test that all expected attributes are present.""" + for image_processing_class in self.image_processing_classes.values(): + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + 
self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + self.assertTrue(hasattr(image_processing, "patch_size")) + self.assertTrue(hasattr(image_processing, "max_soft_tokens")) + self.assertTrue(hasattr(image_processing, "pooling_kernel_size")) + + def test_image_processor_defaults(self): + """Test default parameter values for Gemma4 matching VARASP_SL280_K3.""" + for image_processing_class in self.image_processing_classes.values(): + proc = image_processing_class() + self.assertEqual(proc.patch_size, 16) + self.assertEqual(proc.max_soft_tokens, 280) + self.assertEqual(proc.pooling_kernel_size, 3) + self.assertFalse(proc.do_normalize) + self.assertEqual(list(proc.image_mean), [0.0, 0.0, 0.0]) + self.assertEqual(list(proc.image_std), [1.0, 1.0, 1.0]) + self.assertEqual(proc.resample, 3) + + def test_image_processor_from_dict_with_kwargs(self): + for image_processing_class in self.image_processing_classes.values(): + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.patch_size, 6) + self.assertEqual(image_processor.max_soft_tokens, 70) + + image_processor = image_processing_class.from_dict(self.image_processor_dict, patch_size=18) + self.assertEqual(image_processor.patch_size, 18) + + def test_output_keys(self): + """Test that the output contains pixel_values, image_position_ids, and num_soft_tokens_per_image.""" + for image_processing_class in self.image_processing_classes.values(): + image_processing = image_processing_class(**self.image_processor_dict) + image = Image.fromarray(np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)) + result = image_processing(image, return_tensors="pt") + self.assertIn("pixel_values", result) + self.assertIn("image_position_ids", result) + self.assertIn("num_soft_tokens_per_image", result) + + def test_aspect_ratio_preserving_resize_dimensions(self): + """Test resize dimension calculations match C++ source of truth VisionAspectRatioTests.""" + for patch_size, max_patches, pooling_kernel_size, height, width, expectation in [ + (16, 256, 1, 256, 256, (256, 256)), + (16, 256, 1, 512, 512, (256, 256)), + (10, 200, 1, 50, 10000, (10, 2000)), + (10, 200, 1, 25, 10000, (10, 2000)), + (16, 2304, 6, 2785, 34, (6144, 96)), + (10, 200, 1, 25, 20000, (10, 2000)), + (4, 64, 2, 50, 1000, (8, 128)), + (5, 100, 3, 100, 100, (45, 45)), + (5, 20, 3, 5, 100, (15, 30)), + ]: + target_h, target_w = get_aspect_ratio_preserving_size( + height=height, + width=width, + patch_size=patch_size, + max_patches=max_patches, + pooling_kernel_size=pooling_kernel_size, + ) + side_mult = patch_size * pooling_kernel_size + + self.assertEqual((target_h, target_w), expectation) + self.assertEqual(target_h % side_mult, 0, f"Resized height {target_h} not divisible by {side_mult}") + self.assertEqual(target_w % side_mult, 0, f"Resized width {target_w} not divisible by {side_mult}") + + @parameterized.expand([(70), (140), (280), (560), (1120)]) + def test_max_soft_tokens_values(self, max_soft_tokens): + """Test that the processor produces valid patchified output for each supported max_soft_tokens value.""" + for image_processing_class in self.image_processing_classes.values(): + processor = image_processing_class(patch_size=16, max_soft_tokens=max_soft_tokens, pooling_kernel_size=3) + image = Image.fromarray(np.random.randint(0, 255, (200, 300, 3), dtype=np.uint8)) + result = processor(image, 
return_tensors="pt") + + max_patches = max_soft_tokens * 3**2 + patch_pixels = 16 * 16 * 3 + self.assertEqual(result.pixel_values.shape, (1, max_patches, patch_pixels)) + self.assertEqual(result.image_position_ids.shape, (1, max_patches, 2)) + + # Verify real patches don't exceed the budget + real_mask = result.image_position_ids[0, :, 0] >= 0 + num_real = real_mask.sum().item() + self.assertLessEqual(num_real, max_patches) + + def test_position_ids_structure(self): + """Test that image_position_ids has correct real and padding structure.""" + for image_processing_class in self.image_processing_classes.values(): + image_processing = image_processing_class(**self.image_processor_dict) + image = Image.fromarray(np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)) + result = image_processing(image, return_tensors="pt") + + position_ids = result.image_position_ids[0] # (max_patches, 2) + max_patches = ( + self.image_processor_tester.max_soft_tokens * self.image_processor_tester.pooling_kernel_size**2 + ) + + # Real positions should be non-negative + real_mask = position_ids[:, 0] >= 0 + num_real = real_mask.sum().item() + self.assertGreater(num_real, 0) + self.assertLessEqual(num_real, max_patches) + + # Padding positions should be (-1, -1) + pad_mask = ~real_mask + if pad_mask.any(): + pad_positions = position_ids[pad_mask] + self.assertTrue((pad_positions == -1).all()) + + # Real positions should come before padding positions + if pad_mask.any(): + last_real_idx = torch.where(real_mask)[0][-1].item() + first_pad_idx = torch.where(pad_mask)[0][0].item() + self.assertEqual(last_real_idx + 1, first_pad_idx) + + def test_padding_patches_are_zero(self): + """Test that padding patches in pixel_values are filled with zeros.""" + for image_processing_class in self.image_processing_classes.values(): + image_processing = image_processing_class(**self.image_processor_dict) + image = Image.fromarray(np.random.randint(1, 255, (100, 100, 3), dtype=np.uint8)) + result = image_processing(image, return_tensors="pt") + + position_ids = result.image_position_ids[0] + pad_mask = position_ids[:, 0] < 0 + if pad_mask.any(): + pad_patches = result.pixel_values[0, pad_mask] + self.assertTrue((pad_patches == 0).all()) diff --git a/tests/models/gemma4/test_modeling_gemma4.py b/tests/models/gemma4/test_modeling_gemma4.py new file mode 100644 index 000000000000..c63e9ba20165 --- /dev/null +++ b/tests/models/gemma4/test_modeling_gemma4.py @@ -0,0 +1,866 @@ +# Copyright 2026 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch Gemma4 model.""" + +import logging +import unittest + +import pytest +from parameterized import parameterized + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + Gemma4Config, + Gemma4TextConfig, + is_torch_available, +) +from transformers.testing_utils import ( + Expectations, + cleanup, + is_flash_attn_2_available, + require_deterministic_for_xpu, + require_flash_attn, + require_torch, + require_torch_accelerator, + require_torch_large_accelerator, + slow, + torch_device, +) + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import ( + Gemma4ForCausalLM, + Gemma4ForConditionalGeneration, + Gemma4Model, + Gemma4Processor, + Gemma4TextModel, + ) + from transformers.pytorch_utils import is_torch_greater_or_equal + + +class Gemma4TextModelTester(CausalLMModelTester): + if is_torch_available(): + config_class = Gemma4TextConfig + base_model_class = Gemma4TextModel + causal_lm_class = Gemma4ForCausalLM + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_hidden_layers = 4 # override to correctly test sharing cache pattern + self.num_kv_shared_layers = 2 # important to override + self.layer_types = [ + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + ] # similarly we want to test sharing on both types + self.global_head_dim = self.head_dim # gemma4 use a different head_dim for full and sliding layers + + # To make model small + self.vocab_size_per_layer_input = 99 + self.hidden_size_per_layer_input = 16 + + # To activate moe blocks + self.enable_moe_block = True + self.moe_intermediate_size = 16 + self.top_k_experts = 2 + + # Test if bidirectional image mask path works + self.use_bidirectional_attention = "vision" + + +@require_torch +class Gemma4TextModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = Gemma4TextModelTester + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = Gemma4ForCausalLM if is_torch_available() else None + + @unittest.skip("We need 4 layers to correctly test cache sharing.") + def test_num_layers_is_small(self): + pass + + @unittest.skip("Gemma4 uses different rope per layer type, which is not compatible with this test") + def test_model_rope_scaling_frequencies(self): + pass + + @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) + @unittest.skip("Gemma4 uses different rope per layer type, which is not compatible with this test") + def test_model_rope_scaling_from_config(self): + pass + + @unittest.skip( + "Gemma4 cannot use random inputs_embeds, as it needs to reverse them when input_ids is not provided" + ) + def test_generate_from_random_inputs_embeds(self): + pass + + @unittest.skip( + "Flaky on CI, but not locally on Mac. If model is set to fp32 instead of bf16, not flaky anymore." + "TODO Cyril: investigate where the loss of precision between bf16 and fp32 comes from." 
+ ) + def test_sdpa_padding_matches_padding_free_with_position_ids(self): + pass + + +class Gemma4Audio2TextModelTester: + def __init__( + self, + parent, + image_token_id=4, + boi_token_id=5, + eoi_token_id=6, + audio_token_id=7, + boa_token_id=8, + eoa_token_index=9, + video_token_id=10, + seq_length=50, + audio_seq_length=96, + audio_num_channels=16, + is_training=True, + audio_config={ + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "hidden_act": "silu", + "subsampling_conv_channels": [16, 8], + "conv_kernel_size": 3, + "attention_chunk_size": 4, + "attention_context_left": 5, + "attention_context_right": 0, + "output_proj_dims": 32, + # Clipped linears register inf/-inf buffers which cause NaN in test_torch_save_load's + # comparison logic (inf - inf = NaN). Disable for testing. + "use_clipped_linears": False, + }, + ): + self.parent = parent + self.image_token_id = image_token_id + self.boi_token_id = boi_token_id + self.eoi_token_id = eoi_token_id + self.audio_token_id = audio_token_id + self.boa_token_id = boa_token_id + self.eoa_token_index = eoa_token_index + self.video_token_id = video_token_id + self.llm_tester = Gemma4TextModelTester(self.parent) + self.llm_tester.use_bidirectional_attention = None + self.text_config = self.llm_tester.get_config() + self.audio_config = audio_config + self.seq_length = seq_length + self.audio_seq_length = audio_seq_length + self.audio_num_channels = audio_num_channels + self.pad_token_id = self.text_config.pad_token_id + + self.num_hidden_layers = self.text_config.num_hidden_layers + self.vocab_size = self.text_config.vocab_size + self.hidden_size = self.text_config.hidden_size + self.num_attention_heads = self.text_config.num_attention_heads + self.is_training = is_training + + self.batch_size = 3 + self.encoder_seq_length = seq_length + + def get_config(self): + return Gemma4Config( + text_config=self.text_config, + vision_config=None, + audio_config=self.audio_config, + image_token_id=self.image_token_id, + boi_token_id=self.boi_token_id, + eoi_token_id=self.eoi_token_id, + audio_token_id=self.audio_token_id, + boa_token_id=self.boa_token_id, + eoa_token_index=self.eoa_token_index, + video_token_id=self.video_token_id, + ) + + def prepare_config_and_inputs(self): + input_features = floats_tensor([self.batch_size, self.audio_seq_length, self.audio_num_channels]) + input_features_mask = torch.ones(self.batch_size, self.audio_seq_length, dtype=torch.bool) + config = self.get_config() + return config, input_features, input_features_mask + + def prepare_config_and_inputs_for_common(self): + config, input_features, input_features_mask = self.prepare_config_and_inputs() + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 + attention_mask = input_ids.ne(self.pad_token_id).to(torch_device) + + # Ensure no tokens accidentally match special token IDs + for token_id in [config.image_token_id, config.video_token_id, config.audio_token_id]: + input_ids[input_ids == token_id] = self.pad_token_id + + # The audio encoder produces audio_seq_length / 4 tokens per audio sample after subsampling. + # We need that many audio placeholder tokens per sequence in input_ids. 
+ num_audio_tokens = self.audio_seq_length // 4 + input_ids[:, :num_audio_tokens] = config.audio_token_id + + inputs_dict = { + "input_features": input_features, + "input_features_mask": input_features_mask, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class Gemma4Audio2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (Gemma4Model, Gemma4ForConditionalGeneration) if is_torch_available() else () + all_generative_model_classes = (Gemma4ForConditionalGeneration,) if is_torch_available() else () + + def setUp(self): + self.model_tester = Gemma4Audio2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=Gemma4Config, hidden_size=37) + + @unittest.skip("The tester has no image in input dict") + def test_get_image_features_hidden_states(self): + pass + + @unittest.skip("The tester has no image in input dict") + def test_get_image_features_attentions(self): + pass + + @unittest.skip("The tester has no image in input dict") + @parameterized.expand([True, False, None]) + def test_get_image_features_output(self, return_dict: bool | None): + pass + + @unittest.skip("The tester has no videos in input dict") + def test_get_video_features_hidden_states(self): + pass + + @unittest.skip("The tester has no videos in input dict") + def test_get_video_features_attentions(self): + pass + + @unittest.skip("The tester has no videos in input dict") + @parameterized.expand([True, False, None]) + def test_get_video_features_output(self, return_dict: bool | None): + pass + + @unittest.skip("We need 4 layers to correctly test cache sharing.") + def test_num_layers_is_small(self): + pass + + @unittest.skip("Gemma4 needs correct embeddings for per-layer-input computation, random won't work!") + def test_generate_from_random_inputs_embeds(self): + pass + + +class Gemma4Vision2TextModelTester: + def __init__( + self, + parent, + mm_tokens_per_image=2, + image_token_id=4, + video_token_id=7, + audio_token_id=8, + boi_token_id=5, + eoi_token_id=6, + seq_length=25, + is_training=True, + vision_config={ + "use_labels": True, + "image_size": 20, + "patch_size": 5, + "num_channels": 3, + "is_training": True, + "hidden_size": 32, + "num_key_value_heads": 1, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "dropout": 0.1, + "attention_dropout": 0.1, + "initializer_range": 0.02, + }, + ): + self.parent = parent + # `image_token_id` is set to 0 to pass "resize_embeddings" test, do not modify + self.mm_tokens_per_image = mm_tokens_per_image + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.audio_token_id = audio_token_id + self.boi_token_id = boi_token_id + self.eoi_token_id = eoi_token_id + self.llm_tester = Gemma4TextModelTester(self.parent) + self.text_config = self.llm_tester.get_config() + self.vision_config = vision_config + self.seq_length = seq_length + self.pad_token_id = self.text_config.pad_token_id + + self.num_hidden_layers = self.text_config.num_hidden_layers + self.vocab_size = self.text_config.vocab_size + self.hidden_size = self.text_config.hidden_size + self.num_attention_heads = self.text_config.num_attention_heads + self.is_training = is_training + + self.batch_size = 3 + self.num_channels = vision_config["num_channels"] + self.image_size = vision_config["image_size"] + self.encoder_seq_length = seq_length + + def get_config(self): + return Gemma4Config( + text_config=self.text_config, + 
vision_config=self.vision_config,
+            image_token_id=self.image_token_id,
+            video_token_id=self.video_token_id,
+            audio_token_id=self.audio_token_id,
+            boi_token_id=self.boi_token_id,
+            eoi_token_id=self.eoi_token_id,
+            mm_tokens_per_image=self.mm_tokens_per_image,
+        )
+
+    def prepare_config_and_inputs(self):
+        config = self.get_config()
+        config.vision_config.pooling_kernel_size = 2
+
+        # (num_images, max_num_patches, patch_size * patch_size * num_channels)
+        patch_size = config.vision_config.patch_size
+        pixel_values = floats_tensor(
+            [
+                self.batch_size,
+                self.vision_config["image_size"],
+                patch_size * patch_size * self.vision_config["num_channels"],
+            ]
+        )
+        # (num_images, max_num_patches, 2) for height/width positions. Let it be all ones for testing
+        pixel_position_ids = torch.ones(self.vision_config["image_size"], device=torch_device, dtype=torch.long)
+        pixel_position_ids = pixel_position_ids[None, :, None].repeat(self.batch_size, 1, 2)
+
+        return config, pixel_values, pixel_position_ids
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, pixel_values, pixel_position_ids = config_and_inputs
+        input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
+        attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
+
+        # Ensure no tokens accidentally match special token IDs
+        for token_id in [config.image_token_id, config.video_token_id, config.audio_token_id]:
+            input_ids[input_ids == token_id] = self.pad_token_id
+        input_ids[:, :1] = config.image_token_id
+
+        mm_token_type_ids = torch.zeros_like(input_ids)
+        mm_token_type_ids[input_ids == config.image_token_id] = 1
+
+        inputs_dict = {
+            "pixel_values": pixel_values,
+            "image_position_ids": pixel_position_ids,
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "mm_token_type_ids": mm_token_type_ids,
+        }
+        return config, inputs_dict
+
+
+@require_torch
+class Gemma4Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+    all_model_classes = (Gemma4Model, Gemma4ForConditionalGeneration) if is_torch_available() else ()
+    all_generative_model_classes = (Gemma4ForConditionalGeneration,) if is_torch_available() else ()
+    additional_model_inputs = ["mm_token_type_ids"]
+
+    def setUp(self):
+        self.model_tester = Gemma4Vision2TextModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=Gemma4Config, hidden_size=37)
+
+    @unittest.skip("The tester has no audios in input dict")
+    def test_get_audio_features_hidden_states(self):
+        pass
+
+    @unittest.skip("The tester has no audios in input dict")
+    def test_get_audio_features_attentions(self):
+        pass
+
+    @unittest.skip("The tester has no audios in input dict")
+    @parameterized.expand([True, False, None])
+    def test_get_audio_features_output(self, return_dict: bool | None):
+        pass
+
+    @unittest.skip("The tester has no videos in input dict")
+    def test_get_video_features_hidden_states(self):
+        pass
+
+    @unittest.skip("The tester has no videos in input dict")
+    def test_get_video_features_attentions(self):
+        pass
+
+    @unittest.skip("The tester has no videos in input dict")
+    @parameterized.expand([True, False, None])
+    def test_get_video_features_output(self, return_dict: bool | None):
+        pass
+
+    @unittest.skip("We need 4 layers to correctly test cache sharing.")
+    def test_num_layers_is_small(self):
+        pass
+
+    @unittest.skip("Gemma4 needs correct embeddings for per-layer-input computation, random won't work!")
+    def
test_generate_from_random_inputs_embeds(self): + pass + + +@unittest.skip("Integration Tests are not up-to-date yet! TODO Cyril: update me pretty pretty please!") +@slow +@require_torch_accelerator +class Gemma4IntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = Gemma4Processor.from_pretrained("google/gemma-4-e2b-it", padding_side="left") + + url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png" + self.messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + {"type": "image", "url": url}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @require_deterministic_for_xpu + def test_model_4b_bf16(self): + model_id = "google/gemma-4-e2b-it" + + model = Gemma4ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device) + + inputs = self.processor.apply_chat_template( + self.messages, + tokenize=True, + return_dict=True, + return_tensors="pt", + add_generation_prompt=True, + ).to(torch_device) + + # cache_implementation="hybrid" an in the original transformers implementation + output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid") + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + + EXPECTED_TEXTS = Expectations( + { + ("xpu", 3): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water in the background. It looks like a lovely,'], + ("cuda", (8, 0)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like'], + ("cuda", (8, 6)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'], + ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with turquoise water and a blue sky in the background. It looks like a'], + ("rocm", (9, 5)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant coastline in the background. 
It looks'], + } + ) # fmt: skip + EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation() + self.assertEqual(output_text, EXPECTED_TEXT) + + @require_torch_large_accelerator + @require_deterministic_for_xpu + def test_model_4b_batch(self): + model_id = "google/gemma-4-e2b-it" + + model = Gemma4ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device) + + messages_2 = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png", + }, + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "text", "text": "Are these images identical?"}, + ], + }, + ] + + inputs = self.processor.apply_chat_template( + [self.messages, messages_2], + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + add_generation_prompt=True, + ).to(torch_device) + + # cache_implementation="hybrid" an in the original transformers implementation + output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid") + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + + EXPECTED_TEXTS = Expectations( + { + ("xpu", 3): + [ + 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and', + 'user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. They depict very different scenes:\n\n* **Image 1** shows a cow standing on a beach.', + ], + ("cuda", (8,0)): + [ + 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like', + "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a brown" + ], + ("cuda", (8,6)): + [ + 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like', + "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a brown" + ], + ("rocm", (9, 4)): + [ + 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like', + "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow" + ], + ("rocm", (9, 5)): + [ + 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. 
There are some clouds in the blue', + 'user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. They depict very different scenes. \n\n* **Image 1** shows a cow standing on a beach', + ], + } + ) # fmt: skip + EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation() + self.assertEqual(output_text, EXPECTED_TEXT) + + @require_torch_large_accelerator + def test_model_4b_crops(self): + model_id = "google/gemma-4-e2b-it" + + model = Gemma4ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device) + + crop_config = { + "images_kwargs": { + "do_pan_and_scan": True, + "pan_and_scan_max_num_crops": 448, + "pan_and_scan_min_crop_size": 32, + "pan_and_scan_min_ratio_to_activate": 0.3, + } + } + + inputs = self.processor.apply_chat_template( + self.messages, + tokenize=True, + return_dict=True, + return_tensors="pt", + add_generation_prompt=True, + **crop_config, + ).to(torch_device) + + # cache_implementation="hybrid" an in the original transformers implementation + output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid") + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + + EXPECTED_NUM_IMAGES = 3 # one for the origin image and two crops of images + EXPECTED_TEXTS = Expectations( + { + ("xpu", 3): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'], + ("cuda", 7): [], + ("cuda", (8, 6)): ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a clear blue sky with some white clouds above."], + ("cuda", (8, 0)): ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a blue sky with some white clouds in the background"], + ("rocm", (9, 4)): ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the"], + ("rocm", (9, 5)): ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. 
There's a blue sky with some white clouds in the background"] + } + ) # fmt: skip + EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation() + self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES) + print(f"Generated text: {output_text}") + self.assertEqual(output_text, EXPECTED_TEXT) + + @require_torch_large_accelerator + @require_deterministic_for_xpu + def test_model_4b_batch_crops(self): + model_id = "google/gemma-4-e2b-it" + + model = Gemma4ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device) + crop_config = { + "images_kwargs": { + "do_pan_and_scan": True, + "pan_and_scan_max_num_crops": 448, + "pan_and_scan_min_crop_size": 32, + "pan_and_scan_min_ratio_to_activate": 0.3, + } + } + messages_2 = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png", + }, + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "text", "text": "Are these images identical?"}, + ], + }, + ] + + inputs = self.processor.apply_chat_template( + [self.messages, messages_2], + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + add_generation_prompt=True, + **crop_config, + ).to(torch_device) + + # cache_implementation="hybrid" an in the original transformers implementation + output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid") + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + EXPECTED_NUM_IMAGES = 9 # 3 * (one for the origin image and two crops of images) = 9 + EXPECTED_TEXTS = Expectations( + { + ("xpu", 3): [ + 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.', + 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a', + ], + ("cuda", 7): [], + ("cuda", (8,0)): [ + "user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a blue sky with some white clouds in the background", + 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. 
\n\nThe first image shows a cow on a beach, while the second image shows a street scene with a' + ], + ("cuda", (8, 6)): [ + "user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the", + 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a' + ], + ("rocm", (9, 4)) : [ + "user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the", + 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a' + ], + ("rocm", (9, 5)) : [ + 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.', + 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. 
\n\nThe first image shows a cow on a beach, while the second image shows a street scene with a', + ], + } + ) # fmt: skip + EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation() + self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES) + self.assertEqual(output_text, EXPECTED_TEXT) + + @require_torch_large_accelerator + def test_model_4b_multiimage(self): + model_id = "google/gemma-4-e2b-it" + + model = Gemma4ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device) + + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "text", "text": "What do you see here?"}, + ], + }, + ] + + inputs = self.processor.apply_chat_template( + messages, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + add_generation_prompt=True, + ).to(torch_device) + + # cache_implementation="hybrid" an in the original transformers implementation + output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid") + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + EXPECTED_TEXTS = Expectations( + { + ("xpu", 3): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image!\n\nHere's a description of the scene:\n\n* **Chinese Arch"], + ("cuda", 7): [], + ("cuda", (8, 0)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"], + ("cuda", (8, 6)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt appears to be a street scene in a city"], + ("rocm", (9, 4)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt appears to be a street scene in a vibrant"], + ("rocm", (9, 5)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Main Features:**\n\n* **Chinese Archway:** The most prominent"], + } + ) # fmt: skip + EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation() + self.assertEqual(output_text, EXPECTED_TEXT) + + @require_deterministic_for_xpu + def test_model_1b_text_only(self): + model_id = "google/gemma-3-1b-it" + + model = Gemma4ForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") + inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device) + + # cache_implementation="hybrid" an in the original transformers implementation + output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid") + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + EXPECTED_TEXTS = Expectations( + { + ("xpu", 3): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'], + ("cuda", 7): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a silent stream,\nInto the neural net, a waking dream.\nAlgorithms hum, a coded grace,\n'], + ("cuda", 8): ['Write a 
poem about Machine Learning.\n\n---\n\nThe data flows, a silent stream,\nInto the neural net, a waking dream.\nAlgorithms hum, a coded grace,\n'], + ("rocm", (9, 4)): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a silent stream,\nInto the neural net, a waking dream.\nAlgorithms hum, a coded grace,\n'], + ("rocm", (9, 5)): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'], + } + ) # fmt: skip + EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation() + self.assertEqual(output_text, EXPECTED_TEXT) + + # TODO: raushan FA2 generates gibberish for no reason, check later + @require_flash_attn + @require_torch_large_accelerator + @pytest.mark.flash_attn_test + def test_model_4b_flash_attn(self): + model_id = "google/gemma-4-e2b-it" + + model = Gemma4ForConditionalGeneration.from_pretrained( + model_id, dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ).to(torch_device) + + inputs = self.processor.apply_chat_template( + self.messages, + tokenize=True, + return_dict=True, + return_tensors="pt", + add_generation_prompt=True, + ).to(torch_device) + + # cache_implementation="hybrid" an in the original transformers implementation + output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid") + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + + EXPECTED_TEXTS = Expectations( + { + ("xpu", 3): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks like a sunny day'], + ("cuda", 7): [], + ("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks like a sunny day'], + ("rocm", (9, 5)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with a turquoise ocean and a distant island in the background. It looks like a sunny'], + } + ) # fmt: skip + EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation() + self.assertEqual(output_text, EXPECTED_TEXT) + + @parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)]) + def test_generation_beyond_sliding_window(self, attn_implementation: str): + """Test that we can correctly generate beyond the sliding window. This is non trivial as + we need to correctly slice the attention mask in all cases (because we use a hybrid cache). + Outputs for every attention functions should be coherent and identical. + """ + model_id = "google/gemma-3-1b-it" + + if attn_implementation == "flash_attention_2" and not is_flash_attn_2_available(): + self.skipTest("FlashAttention2 is required for this test.") + + input_text = [ + "This is a nice place. 
" * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens + "A list of colors: red, blue", # This will almost all be padding tokens + ] + tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left") + inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device) + + model = AutoModelForCausalLM.from_pretrained( + model_id, attn_implementation=attn_implementation, dtype=torch.float16 + ).to(torch_device) + + # Make sure prefill is larger than sliding window + input_size = inputs.input_ids.shape[-1] + self.assertTrue(input_size > model.config.sliding_window) + + out = model.generate(**inputs, max_new_tokens=20, do_sample=False, cache_implementation="static")[ + :, input_size: + ] + output_text = tokenizer.batch_decode(out) + + EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip + self.assertEqual(output_text, EXPECTED_COMPLETIONS) + + @pytest.mark.torch_export_test + def test_export_text_only_with_hybrid_cache(self): + if not is_torch_greater_or_equal("2.6.0"): + self.skipTest(reason="This test requires torch >= 2.6 to run.") + + from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM + + model_id = "google/gemma-3-1b-it" + model = AutoModelForCausalLM.from_pretrained(model_id) + self.assertEqual(model.config.cache_implementation, "hybrid") + + # Export + hybrid cache + model.eval() + exportable_module = TorchExportableModuleForDecoderOnlyLM(model, batch_size=1, max_cache_len=1024) + exported_program = exportable_module.export( + input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), + ) + logging.info(f"\nExported program: {exported_program}") + + # Test generation with the exported model + prompt = "What is the capital of France?" + max_new_tokens_to_generate = 20 + # Generate text with the exported model + tokenizer = AutoTokenizer.from_pretrained(model_id) + export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate( + exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate + ) + logging.info(f"\nExport generated texts: '{export_generated_text}'") + + input_text = tokenizer(prompt, return_tensors="pt") + with torch.no_grad(): + eager_outputs = model.generate( + **input_text, + max_new_tokens=max_new_tokens_to_generate, + do_sample=False, # Use greedy decoding to match the exported model + cache_implementation="hybrid", + ) + + eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True) + logging.info(f"\nEager generated texts: '{eager_generated_text}'") + + self.assertEqual(export_generated_text, eager_generated_text) + + def test_dynamic_sliding_window_is_default(self): + """ + Test that the dynamic sliding window cache (added in #40039) is the default cache implementation for Gemma4 + models, despite the fact that Hub checkpoints may have `cache_implementation="hybrid"` (static sliding window). + """ + model_id = "google/gemma-3-1b-it" + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") + + # the default cache is static sliding window + self.assertEqual(model.config.cache_implementation, "hybrid") + self.assertEqual(model.generation_config.cache_implementation, "hybrid") + + tokenizer = AutoTokenizer.from_pretrained(model_id) + prompt = "What is the capital of France?" 
+        model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+        forward_outputs = model(**model_inputs)
+        self.assertIn("DynamicSlidingWindowLayer", str(forward_outputs.past_key_values))
+
+        generate_outputs = model.generate(
+            **model_inputs, max_new_tokens=2, do_sample=False, return_dict_in_generate=True
+        )
+        self.assertIn("DynamicSlidingWindowLayer", str(generate_outputs.past_key_values))
+
+        # If we manually specify cache_implementation="hybrid", it will use the static sliding window cache
+        generate_outputs = model.generate(
+            **model_inputs,
+            max_new_tokens=2,
+            do_sample=False,
+            return_dict_in_generate=True,
+            cache_implementation="hybrid",
+        )
+        self.assertNotIn("DynamicSlidingWindowLayer", str(generate_outputs.past_key_values))
diff --git a/tests/models/gemma4/test_processing_gemma4.py b/tests/models/gemma4/test_processing_gemma4.py
new file mode 100644
index 000000000000..347f7d2bfda0
--- /dev/null
+++ b/tests/models/gemma4/test_processing_gemma4.py
@@ -0,0 +1,315 @@
+# Copyright 2026 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from transformers import Gemma4Processor
+from transformers.testing_utils import get_tests_dir, require_vision
+from transformers.utils import is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_vision_available():
+    pass
+
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
+
+
+@require_vision
+class Gemma4ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = Gemma4Processor
+    video_unstructured_max_length = 570
+    video_text_kwargs_max_length = 570
+    video_text_kwargs_override_max_length = 570
+
+    @classmethod
+    def _setup_test_attributes(cls, processor):
+        cls.image_token = processor.image_token
+        cls.video_token = processor.video_token
+
+    @classmethod
+    def _setup_video_processor(cls):
+        video_processor_class = cls._get_component_class_from_processor("video_processor")
+        gemma4_video_processor_kwargs = {
+            "patch_size": 28,
+            "max_soft_tokens": 70,
+            "pooling_kernel_size": 3,
+            "num_frames": 2,
+        }
+        return video_processor_class(**gemma4_video_processor_kwargs)
+
+    @classmethod
+    def _setup_feature_extractor(cls):
+        feature_extractor_class = cls._get_component_class_from_processor("feature_extractor")
+        gemma4_feature_extractor_kwargs = {}
+        return feature_extractor_class(**gemma4_feature_extractor_kwargs)
+
+    @classmethod
+    def _setup_image_processor(cls):
+        image_processor_class = cls._get_component_class_from_processor("image_processor")
+        gemma4_image_processor_kwargs = {
+            "patch_size": 28,
+            "max_soft_tokens": 70,
+            "pooling_kernel_size": 3,
+        }
+        return image_processor_class(**gemma4_image_processor_kwargs)
+
+    @classmethod
+    def _setup_tokenizer(cls):
+        tokenizer_class = cls._get_component_class_from_processor("tokenizer")
+        extra_special_tokens = {
+            "image_token": "",
+            "boi_token": "",
+            "eoi_token": "",
+ "audio_token": "", + "boa_token": "", + "eoa_token": "", + } + tokenizer = tokenizer_class.from_pretrained( + SAMPLE_VOCAB, keep_accents=True, extra_special_tokens=extra_special_tokens + ) + tokenizer.pad_token_id = tokenizer.eos_token_id + return tokenizer + + # Copied from tests.models.llava.test_processing_llava.LlavaProcessorTest.test_get_num_vision_tokens + def test_get_num_vision_tokens(self): + "Tests general functionality of the helper used internally in vLLM" + + processor = self.get_processor() + + output = processor._get_num_multimodal_tokens(image_sizes=[(100, 100), (300, 100), (500, 30)]) + self.assertTrue("num_image_tokens" in output) + self.assertEqual(len(output["num_image_tokens"]), 3) + + self.assertTrue("num_image_patches" in output) + self.assertEqual(len(output["num_image_patches"]), 3) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdirname, ignore_errors=True) + + # TODO: raushan or arthur: add the real chat template + @staticmethod + def prepare_processor_dict(): + return { + "chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'model\n'}}\n{%- endif -%}\n", "image_seq_length": 3, + } # fmt: skip + + # Override as Gemma4 needs images to be an explicitly nested batch + def prepare_image_inputs(self, batch_size: int | None = None): + """This function prepares a list of PIL images for testing""" + images = super().prepare_image_inputs(batch_size) + if isinstance(images, (list, tuple)): + images = [[image] for image in images] + return images + + def test_text_with_image_tokens(self): + feature_extractor = self.get_component("feature_extractor") + image_processor = self.get_component("image_processor") + video_processor = self.get_component("video_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class( + feature_extractor=feature_extractor, + tokenizer=tokenizer, + image_processor=image_processor, + video_processor=video_processor, + ) + text_multi_images = f"{processor.boi_token}{processor.boi_token}Dummy text!" + text_single_image = f"{processor.boi_token}Dummy text!" 
+
+        image = self.prepare_image_inputs()
+
+        # We can't be sure what the user's intention is: one image per text, OR two images for the first text and no image for the second
+        with self.assertRaises(ValueError):
+            _ = processor(text=[text_single_image, text_single_image], images=[image, image], return_tensors="np")
+
+        # The user is expected to be explicit about which images belong to which text by nesting the images list
+        out_multiimages = processor(text=text_multi_images, images=[image, image], return_tensors="np")
+        out_batch_oneimage = processor(
+            text=[text_single_image, text_single_image], images=[[image], [image]], return_tensors="np"
+        )
+        self.assertListEqual(
+            out_batch_oneimage[self.images_input_name].tolist(), out_multiimages[self.images_input_name].tolist()
+        )
+
+    def test_special_mm_token_truncation(self):
+        """Tests that special vision tokens do not get truncated when `truncation=True` is set."""
+
+        processor = self.get_processor()
+
+        input_str = self.prepare_text_inputs(batch_size=2, modalities="image")
+        image_input = self.prepare_image_inputs(batch_size=2)
+        _ = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            truncation=None,
+            padding=True,
+        )
+
+        with self.assertRaises(ValueError):
+            _ = processor(
+                text=input_str,
+                images=image_input,
+                return_tensors="pt",
+                truncation=True,
+                padding=True,
+                max_length=5,
+            )
+
+    def test_get_num_multimodal_tokens_matches_processor_call(self):
+        "Tests that the helper used internally in vLLM works correctly"
+
+        processor = self.get_processor()
+        if processor.tokenizer.pad_token_id is None:
+            processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
+
+        if not hasattr(processor, "_get_num_multimodal_tokens"):
+            self.skipTest("Processor doesn't support `_get_num_multimodal_tokens` yet")
+
+        image_sizes = [(100, 100), (300, 100), (500, 30), (213, 167)]
+
+        # Overwritten because Gemma4 needs nested image inputs
+        image_inputs = []
+        for h, w in image_sizes:
+            image_inputs.append([np.random.randint(255, size=(h, w, 3), dtype=np.uint8)])
+
+        text = [f"This is an image {getattr(self, 'image_token', '')}"] * len(image_inputs)
+        inputs = processor(
+            text=text, images=image_inputs, padding=True, return_mm_token_type_ids=True, return_tensors="pt"
+        )
+
+        if "mm_token_type_ids" not in inputs:
+            self.skipTest("Processor doesn't support `mm_token_type_ids`")
+
+        num_image_tokens_from_call = inputs.mm_token_type_ids.sum(-1).tolist()
+        num_image_tokens_from_helper = processor._get_num_multimodal_tokens(image_sizes=image_sizes)
+        self.assertListEqual(num_image_tokens_from_call, num_image_tokens_from_helper["num_image_tokens"])
+
+    @unittest.skip("This test seems to be loading a different video, check for all models and fix")
+    def test_apply_chat_template_video_frame_sampling(self):
+        pass
+
+
+class Gemma4AudioTokenCountTest(unittest.TestCase):
+    """Regression tests for _compute_audio_num_tokens.
+
+    The original implementation used ceil(duration_ms / 40), which could overshoot
+    the actual encoder output length by 1 token for ~50% of audio lengths.
+    The fix replicates the exact mel-framing + conv-subsampling arithmetic.
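+
+    Worked example (numbers follow the reference arithmetic in _encoder_output_length below): a
+    16,001-sample clip at 16 kHz pads to 16,161 samples and yields (16161 - 321) // 160 + 1 = 100
+    mel frames, which the two stride-2 conv blocks reduce to 50 and then 25 tokens, whereas
+    ceil(1000.0625 ms / 40 ms) = 26 would overshoot the encoder output by one token.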
+ """ + + @staticmethod + def _encoder_output_length(num_samples: int, sr: int = 16000) -> int: + """Reference implementation of the encoder's actual output length.""" + frame_length = int(round(sr * 20.0 / 1000.0)) + hop_length = int(round(sr * 10.0 / 1000.0)) + frame_size_for_unfold = frame_length + 1 + pad_left = frame_length // 2 + padded_samples = num_samples + pad_left + num_mel_frames = (padded_samples - frame_size_for_unfold) // hop_length + 1 + if num_mel_frames <= 0: + return 0 + t = num_mel_frames + for _ in range(2): + t_padded = t + 2 + t = (t_padded - 3) // 2 + 1 + return t + + @staticmethod + def _compute_tokens(num_samples, sr=16000): + """Call _compute_audio_num_tokens without constructing a full processor.""" + + class _Stub: + audio_seq_length = 1500 + + return Gemma4Processor._compute_audio_num_tokens(_Stub(), np.zeros(num_samples), sr) + + @parameterized.expand( + [ + ("over_1s_boundary", 16001), + ("bug_report_194_vs_193", 123521), + ("over_5s_boundary", 80001), + ("over_10s_boundary", 160001), + ("pad_left_effect_1s", 16161), + ] + ) + def test_audio_token_count_matches_encoder(self, _name, num_samples): + """Verify _compute_audio_num_tokens matches the encoder for edge-case lengths.""" + expected = self._encoder_output_length(num_samples) + actual = self._compute_tokens(num_samples) + self.assertEqual(actual, expected) + + @parameterized.expand( + [ + ("1s", 16000, 25), + ("5s", 80000, 125), + ("10s", 160000, 250), + ("30s", 480000, 750), + ] + ) + def test_audio_token_count_round_boundaries(self, _name, num_samples, expected_tokens): + """Verify exact results at round durations.""" + self.assertEqual(self._compute_tokens(num_samples), expected_tokens) + + def test_audio_token_count_short_audio(self): + """Very short audio that produces zero mel frames should return 0.""" + # With pad_left = 160 and frame_size_for_unfold = 321, anything <= 160 samples => 0 mel frames + self.assertEqual(self._compute_tokens(160), 0) + + @parameterized.expand( + [ + # Lengths where the old naive mask would produce +1 extra token + # after stride-2 conv subsampling. With sr=16000, hop=160, frame_size=321. + ("short_boundary", 641), + ("over_1s", 16001), + ("over_5s", 80001), + ("bug_report_length", 123521), + ("pad_left_effect_1s", 16161), + ] + ) + def test_feature_extractor_mask_matches_processor(self, _name, num_samples): + """Regression: feature extractor mask must agree with processor token count. + + The bug was that ``attention_mask[::hop]`` overcounts real mel frames by +2 + (marks frames as valid even when their window extends into padding). + After two stride-2 conv blocks this becomes +1 extra token ~50% of the time. 
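+
+        For instance, the 641-sample case below gives (801 - 321) // 160 + 1 = 4 real mel frames,
+        which subsample to a single token, while a mask that keeps 6 strided entries (the +2
+        overcount) survives the same two stride-2 blocks as 2 tokens, i.e. one extra.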
+        """
+        from transformers import Gemma4AudioFeatureExtractor
+
+        fe = Gemma4AudioFeatureExtractor()
+
+        # Batch with a longer audio to force padding (the trigger for the bug)
+        target = np.random.randn(num_samples).astype(np.float32)
+        padding_partner = np.random.randn(num_samples + 5000).astype(np.float32)
+
+        features = fe([target, padding_partner], return_tensors="np", padding="longest")
+        mask = features["input_features_mask"][0]  # mask for target audio
+
+        # Simulate two stride-2 conv blocks on the mask
+        T = len(mask)
+        for _ in range(2):
+            T_out = (T + 2 - 3) // 2 + 1
+            mask = mask[::2][:T_out]
+            T = len(mask)
+
+        real_tokens = int(mask.sum())
+        expected = self._compute_tokens(num_samples)
+        self.assertEqual(real_tokens, expected)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index f16f5cff01e2..7bb449c5d956 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -1301,7 +1301,7 @@ def test_init_weights_can_init_buffers(self):
         config.scale = 0
         for sub_key in config.sub_configs:
             subconfig = getattr(config, sub_key)
-            if hasattr(subconfig, "scale"):
+            if subconfig is not None and hasattr(subconfig, "scale"):
                 subconfig.scale = 0
 
         for model_class in self.all_model_classes:
diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py
index 120168734d8b..5c632571b678 100644
--- a/tests/test_processing_common.py
+++ b/tests/test_processing_common.py
@@ -1721,7 +1721,7 @@ def test_apply_chat_template_video_frame_sampling(self):
             tokenize=True,
             return_dict=True,
             return_tensors="pt",
-            processor_kwargs={"num_frames": num_frames},
+            processor_kwargs={"num_frames": num_frames, "fps": None},
         )
         self.assertTrue(self.videos_input_name in out_dict_with_video)
         self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
@@ -1735,7 +1735,7 @@ def test_apply_chat_template_video_frame_sampling(self):
             tokenize=True,
             return_dict=True,
             return_tensors="pt",
-            processor_kwargs={"fps": fps},
+            processor_kwargs={"fps": fps, "num_frames": None},
         )
         self.assertTrue(self.videos_input_name in out_dict_with_video)
         self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
diff --git a/tests/utils/test_chat_parsing_utils.py b/tests/utils/test_chat_parsing_utils.py
index 60bc181d1c88..8a2a6074686c 100644
--- a/tests/utils/test_chat_parsing_utils.py
+++ b/tests/utils/test_chat_parsing_utils.py
@@ -216,6 +216,40 @@
     "x-regex": r"(\<\|channel\>thought\n(?P.*?)\)?(?P(?:(?!\<\|tool_call\>).)+)?(?P\<\|tool_call\>.*\)?",
 }
 
+gemma4_schema = {
+    "type": "object",
+    "properties": {
+        "role": {"const": "assistant"},
+        "thinking": {"type": "string"},
+        "content": {"type": "string"},
+        "tool_calls": {
+            "x-regex-iterator": r"<\|tool_call>(.*?)",
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "type": {"const": "function"},
+                    "function": {
+                        "type": "object",
+                        "x-regex": r"call\:(?P<name>\w+)(?P<arguments>\{.*\})",
+                        "properties": {
+                            "name": {
+                                "type": "string",
+                            },
+                            "arguments": {
+                                "type": "object",
+                                "x-parser": "gemma4-tool-call",
+                                "additionalProperties": {},
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    },
+    "x-regex": r"(\<\|channel\>thought\n(?P<thinking>.*?)\)?(?P<content>(?:(?!\<\|tool_call\>).)+)?(?P<tool_calls>\<\|tool_call\>.*\)?",
+}
+
 prefix_items_schema = {
     # Not intended to be "realistic", just checks that prefixItems can handle a heterogeneous array
     "x-regex-iterator": r"<block>(.*?)<\/block>",
@@ -446,6 +480,58 @@ def test_re_sub_schema(self):
             },
         )
 
+    def test_gemma4_tool_call(self):
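+        # In the raw output below, <|tool_call> opens the call and <|"|> appears to delimit string values (handled by the "gemma4-tool-call" x-parser)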
+        model_out = '<|channel>thought\nThe user is asking for the current temperature in Paris. I should check the available tools to see if there\'s a function that can provide this information.<|tool_call>call:get_current_temperature{detail_level:0,location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>}<|tool_response>'
+        parsed = recursive_parse(model_out, gemma4_schema)
+        self.assertEqual(
+            parsed,
+            {
+                "role": "assistant",
+                "thinking": "The user is asking for the current temperature in Paris. I should check the available tools to see if there's a function that can provide this information.",
+                "tool_calls": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "get_current_temperature",
+                            "arguments": {"detail_level": 0, "location": "Paris, France", "unit": "celsius"},
+                        },
+                    }
+                ],
+            },
+        )
+
+    def test_gemma4_complex_tool_call(self):
+        model_out = (
+            "<|channel>thought\nLet me call the tool."
+            '<|tool_call>call:foo{bool_value:true,list_value:[<|"|>foo<|"|>,<|"|>bar<|"|>],'
+            'null_value:null,number_value:1,string_value:<|"|>foo<|"|>,'
+            'struct_value:{foo:<|"|>bar<|"|>}}'
+        )
+        parsed = recursive_parse(model_out, gemma4_schema)
+        self.assertEqual(
+            parsed,
+            {
+                "role": "assistant",
+                "thinking": "Let me call the tool.",
+                "tool_calls": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "foo",
+                            "arguments": {
+                                "bool_value": True,
+                                "list_value": ["foo", "bar"],
+                                "null_value": None,
+                                "number_value": 1,
+                                "string_value": "foo",
+                                "struct_value": {"foo": "bar"},
+                            },
+                        },
+                    }
+                ],
+            },
+        )
+
     def test_required_fields_present(self):
         """Test that required fields pass validation when present in the output."""
         schema = {
diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py
index 1699a1ad74ba..45d86bfc2c87 100644
--- a/tests/utils/test_modeling_rope_utils.py
+++ b/tests/utils/test_modeling_rope_utils.py
@@ -38,6 +38,9 @@ def test_rope_validation(self):
 
         # If we explicitly set the other RoPE types, then validation should fail
         for rope_type in all_rope_types:
+            # proportional is the same as default with respect to the expected keys
+            if rope_type == "proportional":
+                continue
             config.rope_parameters = {"rope_type": rope_type, "rope_theta": 10000.0}
             with self.assertRaises(KeyError):
                 config.validate_rope()
@@ -52,6 +55,9 @@
             "long_factor": ["longrope"],
         }
         for rope_type in all_rope_types:
+            # proportional is the same as default with respect to the expected keys
+            if rope_type == "proportional":
+                continue
             for param, valid_rope_types in valid_param_mapping.items():
                 # Set `param` with a dummy value -- we want to test the dict key
                 config.rope_parameters = {"rope_type": rope_type, "rope_theta": 10000.0, param: True}
@@ -478,3 +484,136 @@ def test_llama3_rope_numerically(self):
         }
         inv_freq, _ = rope_fn(config=config, device=torch_device)
         torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
+
+    def test_proportional_rope_numerically(self):
+        # fmt: off
+        EXPECTED_INV_FREQ = torch.tensor(
+            [
+                1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
+                4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
+                1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 0.0000e+00, 0.0000e+00,
+                0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
+                0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
+                0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
+                0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
+                0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
+                0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
+                0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
+                0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00
+            ], device=torch_device
+        )
+        # fmt: on
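+        # Snapshot assumes head_dim=128 with partial_rotary_factor=0.25: the first 16 angles are rotated, the remaining 48 stay at zero (NoPE)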
+
+        # input sanity checks: if these change, the output will also change
+        config = LlamaConfig()
+        self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
+        self.assertEqual(config.hidden_size, 4096)
+        self.assertEqual(config.num_attention_heads, 32)
+        self.assertFalse(hasattr(config, "partial_rotary_factor"))
+
+        head_dim = config.hidden_size // config.num_attention_heads  # 128
+
+        rope_fn = ROPE_INIT_FUNCTIONS["proportional"]
+        default_rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
+
+        # Check 1: `attention_factor` is always 1.0, regardless of parameters
+        for partial_rotary_factor in (1.0, 0.5, 0.25):
+            config.rope_parameters = {
+                "rope_type": "proportional",
+                "rope_theta": 10000.0,
+                "partial_rotary_factor": partial_rotary_factor,
+            }
+            _, attention_scale = rope_fn(config=config, device=torch_device)
+            self.assertEqual(attention_scale, 1.0)
+
+        # Check 2: output shape is always head_dim // 2, regardless of partial_rotary_factor
+        for partial_rotary_factor in (1.0, 0.5, 0.25):
+            config.rope_parameters = {
+                "rope_type": "proportional",
+                "rope_theta": 10000.0,
+                "partial_rotary_factor": partial_rotary_factor,
+            }
+            inv_freq, _ = rope_fn(config=config, device=torch_device)
+            self.assertEqual(inv_freq.shape[0], head_dim // 2)
+
+        # Check 3: zero-padding behavior — when partial_rotary_factor < 1.0, the last (head_dim // 2 - rope_angles)
+        # entries must be exactly zero, and the first rope_angles entries must be non-zero
+        for partial_rotary_factor, expected_rope_angles in ((0.5, 32), (0.25, 16)):
+            config.rope_parameters = {
+                "rope_type": "proportional",
+                "rope_theta": 10000.0,
+                "partial_rotary_factor": partial_rotary_factor,
+            }
+            inv_freq, _ = rope_fn(config=config, device=torch_device)
+
+            # First rope_angles entries should be non-zero (rotated frequencies)
+            self.assertTrue(torch.all(inv_freq[:expected_rope_angles] != 0))
+            # Remaining entries should be exactly zero (NoPE angles)
+            expected_nope_angles = head_dim // 2 - expected_rope_angles
+            torch.testing.assert_close(
+                inv_freq[expected_rope_angles:],
+                torch.zeros(expected_nope_angles, device=torch_device),
+            )
+
+        # When partial_rotary_factor = 1.0, no entries should be zero
+        config.rope_parameters = {
+            "rope_type": "proportional",
+            "rope_theta": 10000.0,
+            "partial_rotary_factor": 1.0,
+        }
+        inv_freq, _ = rope_fn(config=config, device=torch_device)
+        self.assertTrue(torch.all(inv_freq != 0))
+
+        # Check 4: factor scaling equivalences with default and linear RoPE
+        # 4a: With partial_rotary_factor=1.0 and factor=1.0, proportional RoPE == default RoPE
+        config.rope_parameters = {
+            "rope_type": "proportional",
+            "rope_theta": 10000.0,
+            "partial_rotary_factor": 1.0,
+            "factor": 1.0,
+        }
+        inv_freq_prop, _ = rope_fn(config=config, device=torch_device)
+        config.rope_parameters = {"rope_type": "default", "rope_theta": 10000.0}
+        default_inv_freq, _ = default_rope_fn(config=config, device=torch_device)
+        torch.testing.assert_close(inv_freq_prop, default_inv_freq)
+
+        # 4b: With partial_rotary_factor=1.0 and factor > 1.0, proportional RoPE == linear RoPE
+        linear_rope_fn = ROPE_INIT_FUNCTIONS["linear"]
+        for factor in (2.0, 10.0):
+            config.rope_parameters = {
+                "rope_type": "proportional",
+                "rope_theta": 10000.0,
+                "partial_rotary_factor": 1.0,
+                "factor": factor,
+            }
+            inv_freq_prop, _ = rope_fn(config=config, device=torch_device)
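+            # Linear scaling divides the default inv_freq by `factor`, so proportional RoPE with a full rotary dim should match it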
+            config.rope_parameters = {"rope_type": "linear", "rope_theta": 10000.0, "factor": factor}
+            inv_freq_linear, _ = linear_rope_fn(config=config, device=torch_device)
+            torch.testing.assert_close(inv_freq_prop, inv_freq_linear)
+
+        # 4c: With partial_rotary_factor=0.5 and factor=2.0, the non-zero portion should be the rotated subspace
+        # frequencies divided by factor
+        config.rope_parameters = {
+            "rope_type": "proportional",
+            "rope_theta": 10000.0,
+            "partial_rotary_factor": 0.5,
+            "factor": 2.0,
+        }
+        inv_freq_scaled, _ = rope_fn(config=config, device=torch_device)
+        config.rope_parameters = {
+            "rope_type": "proportional",
+            "rope_theta": 10000.0,
+            "partial_rotary_factor": 0.5,
+            "factor": 1.0,
+        }
+        inv_freq_unscaled, _ = rope_fn(config=config, device=torch_device)
+        torch.testing.assert_close(inv_freq_scaled, inv_freq_unscaled / 2.0)
+
+        # Check 5: numerical snapshot to avoid regressions (partial_rotary_factor=0.25, factor=1.0)
+        config.rope_parameters = {
+            "rope_type": "proportional",
+            "rope_theta": 10000.0,
+            "partial_rotary_factor": 0.25,
+        }
+        inv_freq, _ = rope_fn(config=config, device=torch_device)
+        torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 1f327cbc7cf0..54901ef4afff 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -278,6 +278,8 @@
         "VibeVoiceAcousticTokenizerDecoderModel",  # Tested through VibeVoiceAcousticTokenizerModel
         "PI0Model",  # special arch, tested through PI0ForConditionalGeneration
         "UVDocBridge",  # Building part of a bigger model, tested implicitly through UVDocModel
+        "Gemma4VisionModel",  # Building part of a bigger model, tested implicitly
+        "Gemma4AudioModel",  # Building part of a bigger model, tested implicitly
     ]
 )