From 180a0ddd070c2c6bd4f0baeeca6d07453d84a470 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 8 Apr 2026 18:39:40 +0200 Subject: [PATCH 01/10] gemma4 --- .../models/gemma3n/modeling_gemma3n.py | 6 -- .../models/gemma4/modeling_gemma4.py | 65 +++++++++++++++++-- .../models/gemma4/modular_gemma4.py | 55 ++++++++++++++++ .../models/paligemma/modeling_paligemma.py | 6 -- 4 files changed, 115 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index edca10b4f48e..ca440fa91ef3 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -2166,12 +2166,6 @@ def __init__(self, config: Gemma3nConfig): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - @auto_docstring def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]): return self.model.get_image_features(pixel_values, **kwargs) diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py index f690c0425c8c..b4ff306aaa66 100644 --- a/src/transformers/models/gemma4/modeling_gemma4.py +++ b/src/transformers/models/gemma4/modeling_gemma4.py @@ -46,13 +46,24 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check +from ...utils import ( + ModelOutput, + TransformersKwargs, + auto_docstring, + can_return_tuple, + is_accelerate_available, + torch_compilable_check, +) from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs from ..auto.modeling_auto import AutoModel from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig +if is_accelerate_available(): + from accelerate.hooks import add_hook_to_module + + @dataclass @auto_docstring( custom_intro=""" @@ -1485,6 +1496,44 @@ def _init_weights(self, module): init.zeros_(module.std_bias) init.ones_(module.std_scale) + def get_per_layer_input_embeddings(self): + return self.base_model.embed_tokens_per_layer + + def set_per_layer_input_embeddings(self, value): + self.base_model.embed_tokens_per_layer = value + + def resize_token_embeddings( + self, + new_num_tokens: int | None = None, + pad_to_multiple_of: int | None = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + inputs_embeds = super().resize_token_embeddings( + new_num_tokens=new_num_tokens, + pad_to_multiple_of=pad_to_multiple_of, + mean_resizing=mean_resizing, + ) + self._resize_per_layer_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + return inputs_embeds + + def _resize_per_layer_embeddings( + self, + new_num_tokens: int | None = None, + pad_to_multiple_of: int | None = None, + mean_resizing: bool = True, + ): + self.config.get_text_config().vocab_size_per_layer_input = self.vocab_size + if self.config.get_text_config().hidden_size_per_layer_input: + embed_tokens_per_layer = self.get_per_layer_input_embeddings() + new_embeddings_per_layer = self._get_resized_embeddings( 
+ embed_tokens_per_layer, new_num_tokens, pad_to_multiple_of, mean_resizing + ) + if hasattr(embed_tokens_per_layer, "_hf_hook"): + hook = embed_tokens_per_layer._hf_hook + add_hook_to_module(new_embeddings_per_layer, hook) + new_embeddings_per_layer.requires_grad_(embed_tokens_per_layer.weight.requires_grad) + self.set_per_layer_input_embeddings(new_embeddings_per_layer) + @auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.") class Gemma4TextModel(Gemma4PreTrainedModel): @@ -2080,6 +2129,12 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) + def get_per_layer_input_embeddings(self): + return self.language_model.embed_tokens_per_layer + + def set_per_layer_input_embeddings(self, value): + self.language_model.embed_tokens_per_layer = value + @can_return_tuple @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.") def get_image_features( @@ -2371,11 +2426,11 @@ def __init__(self, config: Gemma4Config): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() + def get_per_layer_input_embeddings(self): + return self.model.get_per_layer_input_embeddings() - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) + def set_per_layer_input_embeddings(self, value): + self.model.set_per_layer_input_embeddings(value) @auto_docstring def get_image_features( diff --git a/src/transformers/models/gemma4/modular_gemma4.py b/src/transformers/models/gemma4/modular_gemma4.py index a97273802213..e3e878c3430a 100644 --- a/src/transformers/models/gemma4/modular_gemma4.py +++ b/src/transformers/models/gemma4/modular_gemma4.py @@ -41,6 +41,7 @@ TransformersKwargs, auto_docstring, can_return_tuple, + is_accelerate_available, logging, torch_compilable_check, ) @@ -72,6 +73,10 @@ from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig +if is_accelerate_available(): + from accelerate.hooks import add_hook_to_module + + logger = logging.get_logger(__name__) @@ -1209,6 +1214,44 @@ def _init_weights(self, module): init.zeros_(module.std_bias) init.ones_(module.std_scale) + def get_per_layer_input_embeddings(self): + return self.base_model.embed_tokens_per_layer + + def set_per_layer_input_embeddings(self, value): + self.base_model.embed_tokens_per_layer = value + + def resize_token_embeddings( + self, + new_num_tokens: int | None = None, + pad_to_multiple_of: int | None = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + inputs_embeds = super().resize_token_embeddings( + new_num_tokens=new_num_tokens, + pad_to_multiple_of=pad_to_multiple_of, + mean_resizing=mean_resizing, + ) + self._resize_per_layer_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + return inputs_embeds + + def _resize_per_layer_embeddings( + self, + new_num_tokens: int | None = None, + pad_to_multiple_of: int | None = None, + mean_resizing: bool = True, + ): + self.config.get_text_config().vocab_size_per_layer_input = self.vocab_size + if self.config.get_text_config().hidden_size_per_layer_input: + embed_tokens_per_layer = self.get_per_layer_input_embeddings() + new_embeddings_per_layer = self._get_resized_embeddings( + embed_tokens_per_layer, new_num_tokens, pad_to_multiple_of, mean_resizing + ) + if hasattr(embed_tokens_per_layer, "_hf_hook"): + hook 
= embed_tokens_per_layer._hf_hook + add_hook_to_module(new_embeddings_per_layer, hook) + new_embeddings_per_layer.requires_grad_(embed_tokens_per_layer.weight.requires_grad) + self.set_per_layer_input_embeddings(new_embeddings_per_layer) + @auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.") class Gemma4TextModel(Gemma3TextModel): @@ -1691,6 +1734,12 @@ def __init__(self, config: Gemma4Config): else None ) + def get_per_layer_input_embeddings(self): + return self.language_model.embed_tokens_per_layer + + def set_per_layer_input_embeddings(self, value): + self.language_model.embed_tokens_per_layer = value + @can_return_tuple @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.") def get_image_features( @@ -1975,6 +2024,12 @@ def get_audio_features( class Gemma4ForConditionalGeneration(Gemma3nForConditionalGeneration): base_model_prefix = "model" + def get_per_layer_input_embeddings(self): + return self.model.get_per_layer_input_embeddings() + + def set_per_layer_input_embeddings(self, value): + self.model.set_per_layer_input_embeddings(value) + def forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 369514a55f76..6eeeaa6bd681 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -398,12 +398,6 @@ def __init__(self, config: PaliGemmaConfig): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - @auto_docstring def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]): return self.model.get_image_features(pixel_values, **kwargs) From 304acd6a91cad3e505946320ff3e69362b5130ca Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 8 Apr 2026 18:44:51 +0200 Subject: [PATCH 02/10] mask out per-layer input correctly! 
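
The per-layer embedding table only covers ids below vocab_size_per_layer_input,
so the multimodal soft tokens (vision and audio ids above the text vocab) have
to be zeroed out before the per-layer lookup instead of being kept. A minimal
sketch of the intended masking, reusing the vision_mask/audio_mask tensors
already computed in forward (names as in modeling_gemma3n.py):

    # zero out multimodal soft tokens, keep regular text ids
    per_layer_inputs_mask = vision_mask | audio_mask
    per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, torch.zeros_like(input_ids), input_ids)
    per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)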
---
 src/transformers/models/gemma3n/modeling_gemma3n.py | 10 +++++-----
 src/transformers/models/gemma3n/modular_gemma3n.py | 10 +++++-----
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py
index ca440fa91ef3..370065c9e1f9 100644
--- a/src/transformers/models/gemma3n/modeling_gemma3n.py
+++ b/src/transformers/models/gemma3n/modeling_gemma3n.py
@@ -2045,11 +2045,6 @@ def forward(
         if input_ids is not None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
-            # Prepare per-layer inputs from inputs_ids
-            per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
-            per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
-            per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
-
             # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
             vision_mask = torch.logical_and(
                 input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
@@ -2069,6 +2064,11 @@
                 audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                 expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
                 inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
+
+            # Prepare per-layer inputs from input_ids by masking out multimodal soft tokens
+            per_layer_inputs_mask = vision_mask | audio_mask
+            per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, torch.zeros_like(input_ids), input_ids)
+            per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
         else:
             per_layer_inputs = None
 
diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py
index d5633a689687..c38f0ebc279d 100644
--- a/src/transformers/models/gemma3n/modular_gemma3n.py
+++ b/src/transformers/models/gemma3n/modular_gemma3n.py
@@ -2115,11 +2115,6 @@ def forward(
         if input_ids is not None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
-            # Prepare per-layer inputs from inputs_ids
-            per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
-            per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
-            per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
-
             # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
             vision_mask = torch.logical_and(
                 input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
@@ -2139,6 +2134,11 @@
                 audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                 expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
                 inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
+
+            # Prepare per-layer inputs from input_ids by masking out multimodal soft tokens
+            per_layer_inputs_mask = vision_mask | audio_mask
+            per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, torch.zeros_like(input_ids), input_ids)
+            per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
         else:
             per_layer_inputs = None
 

From 31bb636c41b619cdc7452fc1fa8ad318323b4e7c Mon Sep 17 00:00:00 2001
From: raushan
Date: Thu, 9 Apr 2026 12:29:19 +0200
Subject: [PATCH 03/10] revert gemma3n and add t5gemma

---
 src/transformers/models/gemma3n/modeling_gemma3n.py | 10 +++++-----
 src/transformers/models/gemma3n/modular_gemma3n.py | 10 +++++-----
 src/transformers/models/t5gemma/modeling_t5gemma.py | 5 +++++
 src/transformers/models/t5gemma/modular_t5gemma.py | 5 +++++
 4 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py
index 370065c9e1f9..ca440fa91ef3 100644
--- a/src/transformers/models/gemma3n/modeling_gemma3n.py
+++ b/src/transformers/models/gemma3n/modeling_gemma3n.py
@@ -2045,6 +2045,11 @@ def forward(
         if input_ids is not None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
+            # Prepare per-layer inputs from inputs_ids
+            per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
+            per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
+            per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
+
             # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
             vision_mask = torch.logical_and(
                 input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
@@ -2064,11 +2069,6 @@
                 audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                 expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
                 inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
-
-            # Prepare per-layer inputs from input_ids by masking out multimodal soft tokens
-            per_layer_inputs_mask = vision_mask | audio_mask
-            per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, torch.zeros_like(input_ids), input_ids)
-            per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
         else:
             per_layer_inputs = None
 
diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py
index c38f0ebc279d..d5633a689687 100644
--- a/src/transformers/models/gemma3n/modular_gemma3n.py
+++ b/src/transformers/models/gemma3n/modular_gemma3n.py
@@ -2115,6 +2115,11 @@ def forward(
         if input_ids is not None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
+            # Prepare per-layer inputs from inputs_ids
+            per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
+            per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
+            per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
+
             # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
             vision_mask = torch.logical_and(
                 input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
@@ -2134,11 +2139,6 @@
                 audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                 expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
                 inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
-
-            # Prepare per-layer inputs from input_ids by masking out multimodal soft tokens
-            per_layer_inputs_mask = vision_mask | audio_mask
-            per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, torch.zeros_like(input_ids), input_ids)
-            per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
         else:
             per_layer_inputs = None
 
diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py
index a6b9b5392194..2c7f8df57571 100644
--- 
a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -961,6 +961,11 @@ def __init__(self, config: T5GemmaConfig): def set_output_embeddings(self, new_embeddings): self.lm_head.out_proj = new_embeddings + # The tying happens from decoder to lm-head, but when resizing + # the resized embed is assigned only to the head. Then tying weights + # again reverts everything back. So we have to update decoder here + if self.config.tie_word_embeddings: + self.model.decoder.embed_tokens = new_embeddings def get_output_embeddings(self): return self.lm_head.out_proj diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index c7d4a4051959..61725cc1e6ac 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -800,6 +800,11 @@ def __init__(self, config: T5GemmaConfig): def set_output_embeddings(self, new_embeddings): self.lm_head.out_proj = new_embeddings + # The tying happens from decoder to lm-head, but when resizing + # the resized embed is assigned only to the head. Then tying weights + # again reverts everything back. So we have to update decoder here + if self.config.tie_word_embeddings: + self.model.decoder.embed_tokens = new_embeddings def get_output_embeddings(self): return self.lm_head.out_proj From 02e4218dfee1e0ec3dfe7673ca4aa320e8c733d8 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 9 Apr 2026 12:36:27 +0200 Subject: [PATCH 04/10] fix repo --- .../models/gemma3/modeling_gemma3.py | 6 ----- .../models/gemma4/modeling_gemma4.py | 24 +++++++++---------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 0dd41d6fd450..f433c54b6b38 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -960,12 +960,6 @@ def __init__(self, config: Gemma3Config): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - @auto_docstring def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]): return self.model.get_image_features(pixel_values, **kwargs) diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py index b4ff306aaa66..63c21e01180a 100644 --- a/src/transformers/models/gemma4/modeling_gemma4.py +++ b/src/transformers/models/gemma4/modeling_gemma4.py @@ -2129,12 +2129,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_per_layer_input_embeddings(self): - return self.language_model.embed_tokens_per_layer - - def set_per_layer_input_embeddings(self, value): - self.language_model.embed_tokens_per_layer = value - @can_return_tuple @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.") def get_image_features( @@ -2385,6 +2379,12 @@ def get_audio_features( return audio_outputs + def get_per_layer_input_embeddings(self): + return self.language_model.embed_tokens_per_layer + + def set_per_layer_input_embeddings(self, value): + self.language_model.embed_tokens_per_layer = value + 
@can_return_tuple @auto_docstring(custom_intro="Projects the last hidden state from the vision encoder into language model space.") def get_video_features( @@ -2426,12 +2426,6 @@ def __init__(self, config: Gemma4Config): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_per_layer_input_embeddings(self): - return self.model.get_per_layer_input_embeddings() - - def set_per_layer_input_embeddings(self, value): - self.model.set_per_layer_input_embeddings(value) - @auto_docstring def get_image_features( self, @@ -2578,6 +2572,12 @@ def prepare_inputs_for_generation( return model_inputs + def get_per_layer_input_embeddings(self): + return self.model.get_per_layer_input_embeddings() + + def set_per_layer_input_embeddings(self, value): + self.model.set_per_layer_input_embeddings(value) + @staticmethod def create_masks_for_generate( config: PreTrainedConfig, From bbbd4340eddffcd5313ef7f7daada3cef2ceb978 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 9 Apr 2026 12:40:20 +0200 Subject: [PATCH 05/10] the test --- tests/test_modeling_common.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f21c3bcef9e9..2eb6f9490ece 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2155,6 +2155,11 @@ def test_resize_tokens_embeddings(self): # Check that the model can still do a forward pass successfully (every parameter should be resized) if not is_deepspeed_zero3_enabled(): + # Input ids should be expanded to the new maximum size of the vocabulary + inputs_dict["input_ids"][:, -1] = new_model_vocab_size - 1 + if "decoder_input_ids" in inputs_dict: + inputs_dict["decoder_input_ids"][:, -1] = new_model_vocab_size - 1 + # A distriputed launcher is needed for the forward pass when deepspeed is enabled model_inputs = self._prepare_for_class(inputs_dict, model_class) model(**model_inputs) From 431da57fb493cc61d47e47a48f1c5d8af3453967 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 9 Apr 2026 14:21:58 +0200 Subject: [PATCH 06/10] fix some tests --- tests/models/blip/test_modeling_blip.py | 4 ++-- .../colmodernvbert/test_modeling_colmodernvbert.py | 9 +++++---- tests/models/lfm2_vl/test_modeling_lfm2_vl.py | 2 +- tests/models/qwen3_vl/test_modeling_qwen3_vl.py | 4 ++-- tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py | 4 ++-- tests/test_modeling_common.py | 4 +--- 6 files changed, 13 insertions(+), 14 deletions(-) diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index 47a30b88db83..0a3561d45db2 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -652,10 +652,10 @@ def prepare_config_and_inputs_for_common(self): config, input_ids, attention_mask, pixel_values = config_and_inputs inputs_dict = { "input_ids": input_ids, - "decoder_input_ids": input_ids, + "decoder_input_ids": input_ids.clone(), "attention_mask": attention_mask, "pixel_values": pixel_values, - "labels": input_ids, + "labels": input_ids.clone(), } return config, inputs_dict diff --git a/tests/models/colmodernvbert/test_modeling_colmodernvbert.py b/tests/models/colmodernvbert/test_modeling_colmodernvbert.py index 2f5134036d52..170ed352b863 100755 --- a/tests/models/colmodernvbert/test_modeling_colmodernvbert.py +++ b/tests/models/colmodernvbert/test_modeling_colmodernvbert.py @@ -49,6 +49,7 @@ def __init__( parent, batch_size=2, num_images=2, + seq_length=7, ignore_index=-100, 
text_config=None, is_training=False, @@ -98,10 +99,11 @@ def __init__( self.pixel_shuffle_factor = pixel_shuffle_factor self.image_token_id = self.text_config["vocab_size"] - 1 self.pad_token_id = text_config["pad_token_id"] - self.seq_length = ( + self.image_seq_length = ( int(((vision_config["image_size"] // vision_config["patch_size"]) ** 2) / (pixel_shuffle_factor**2)) * self.num_images ) + self.seq_length = seq_length + self.image_seq_length self.hidden_size = text_config["hidden_size"] self.num_hidden_layers = text_config["num_hidden_layers"] @@ -136,9 +138,8 @@ def prepare_config_and_inputs_for_common(self): input_ids = ids_tensor([self.batch_size, self.seq_length], config.vlm_config.text_config.vocab_size) attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device) - # For simplicity just set the last n tokens to the image token - n_image_tokens_per_batch = self.seq_length - input_ids[:, -n_image_tokens_per_batch:] = self.image_token_id + # For simplicity just set the first n tokens to the image token + input_ids[:, : self.image_seq_length] = self.image_token_id attention_mask = input_ids.ne(1).to(torch_device) inputs_dict = { "pixel_values": pixel_values, diff --git a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py index c14e3933f77b..6e0576efe3d5 100644 --- a/tests/models/lfm2_vl/test_modeling_lfm2_vl.py +++ b/tests/models/lfm2_vl/test_modeling_lfm2_vl.py @@ -135,7 +135,7 @@ def prepare_config_and_inputs_for_common(self): # For simplicity just set the last n tokens to the image token input_ids[input_ids == self.image_token_id] = self.text_config["pad_token_id"] - input_ids[:, -self.image_seq_length :] = self.image_token_id + input_ids[:, : self.image_seq_length] = self.image_token_id attention_mask = input_ids.ne(1).to(torch_device) inputs_dict = { diff --git a/tests/models/qwen3_vl/test_modeling_qwen3_vl.py b/tests/models/qwen3_vl/test_modeling_qwen3_vl.py index b7e0b9053c25..cb28890602a0 100644 --- a/tests/models/qwen3_vl/test_modeling_qwen3_vl.py +++ b/tests/models/qwen3_vl/test_modeling_qwen3_vl.py @@ -103,8 +103,8 @@ def place_image_tokens(self, input_ids, config): input_ids[input_ids == self.image_token_id] = self.pad_token_id input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id # Place image tokens with vision_start_token_id prefix - input_ids[:, -1] = self.image_token_id - input_ids[:, -2] = self.vision_start_token_id + input_ids[:, 1] = self.image_token_id + input_ids[:, 0] = self.vision_start_token_id return input_ids def get_additional_inputs(self, config, input_ids, pixel_values): diff --git a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py index 2de7b384d075..0b0523de3b71 100644 --- a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py +++ b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py @@ -102,8 +102,8 @@ def place_image_tokens(self, input_ids, config): input_ids[input_ids == self.image_token_id] = self.pad_token_id input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id # Place image tokens with vision_start_token_id prefix - input_ids[:, -1] = self.image_token_id - input_ids[:, -2] = self.vision_start_token_id + input_ids[:, 1] = self.image_token_id + input_ids[:, 0] = self.vision_start_token_id return input_ids def get_additional_inputs(self, config, input_ids, pixel_values): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2eb6f9490ece..3b511e3be41c 
100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2156,9 +2156,7 @@ def test_resize_tokens_embeddings(self): # Check that the model can still do a forward pass successfully (every parameter should be resized) if not is_deepspeed_zero3_enabled(): # Input ids should be expanded to the new maximum size of the vocabulary - inputs_dict["input_ids"][:, -1] = new_model_vocab_size - 1 - if "decoder_input_ids" in inputs_dict: - inputs_dict["decoder_input_ids"][:, -1] = new_model_vocab_size - 1 + inputs_dict["input_ids"][:, -2] = new_model_vocab_size - 1 # A distriputed launcher is needed for the forward pass when deepspeed is enabled model_inputs = self._prepare_for_class(inputs_dict, model_class) From 8cc8f93c98547e84b4c435b6c7e663c5565dd913 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 9 Apr 2026 15:26:57 +0200 Subject: [PATCH 07/10] gemma3n also has a causal model --- .../models/gemma3n/configuration_gemma3n.py | 1 - .../models/gemma3n/modeling_gemma3n.py | 64 +++++++++++++++++ .../models/gemma3n/modular_gemma3n.py | 71 ++++++++++++++++++- .../models/t5gemma/modeling_t5gemma.py | 3 +- .../models/t5gemma/modular_t5gemma.py | 3 +- .../test_modeling_colmodernvbert.py | 1 + 6 files changed, 139 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gemma3n/configuration_gemma3n.py b/src/transformers/models/gemma3n/configuration_gemma3n.py index ee3ed0348de4..0c6729ef9345 100644 --- a/src/transformers/models/gemma3n/configuration_gemma3n.py +++ b/src/transformers/models/gemma3n/configuration_gemma3n.py @@ -30,7 +30,6 @@ if is_timm_available(): from timm.data import ImageNetInfo, infer_imagenet_subset - logger = logging.get_logger(__name__) diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index ca440fa91ef3..4aad6c51b1af 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -43,6 +43,7 @@ TransformersKwargs, auto_docstring, can_return_tuple, + is_accelerate_available, torch_compilable_check, ) from ...utils.generic import maybe_autocast, merge_with_config_defaults @@ -51,6 +52,10 @@ from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig +if is_accelerate_available(): + from accelerate.hooks import add_hook_to_module + + @dataclass @auto_docstring class Gemma3nAudioEncoderModelOutput(BaseModelOutputWithPooling): @@ -1406,6 +1411,37 @@ def _init_weights(self, module): if hasattr(module, "gradient_clipping"): init.constant_(module.gradient_clipping, self.config.gradient_clipping) + def resize_token_embeddings( + self, + new_num_tokens: int | None = None, + pad_to_multiple_of: int | None = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + inputs_embeds = super().resize_token_embeddings( + new_num_tokens=new_num_tokens, + pad_to_multiple_of=pad_to_multiple_of, + mean_resizing=mean_resizing, + ) + self._resize_per_layer_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + return inputs_embeds + + def _resize_per_layer_embeddings( + self, + new_num_tokens: int | None = None, + pad_to_multiple_of: int | None = None, + mean_resizing: bool = True, + ): + self.config.vocab_size_per_layer_input = self.vocab_size + embed_tokens_per_layer = self.base_model.embed_tokens_per_layer + new_embeddings_per_layer = self._get_resized_embeddings( + embed_tokens_per_layer, new_num_tokens, pad_to_multiple_of, mean_resizing + ) + if hasattr(embed_tokens_per_layer, 
"_hf_hook"): + hook = embed_tokens_per_layer._hf_hook + add_hook_to_module(new_embeddings_per_layer, hook) + new_embeddings_per_layer.requires_grad_(embed_tokens_per_layer.weight.requires_grad) + self.base_model.embed_tokens_per_layer = new_embeddings_per_layer + class Gemma3nAudioEncoder(Gemma3nPreTrainedModel): """ @@ -2150,6 +2186,20 @@ def get_audio_features( return audio_outputs + def resize_token_embeddings( + self, + new_num_tokens: int | None = None, + pad_to_multiple_of: int | None = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + inputs_embeds = super().resize_token_embeddings( + new_num_tokens=new_num_tokens, + pad_to_multiple_of=pad_to_multiple_of, + mean_resizing=mean_resizing, + ) + # TODO: fix resizing for embeds per-layer by filtering out mm-soft-tokens + return inputs_embeds + @auto_docstring( custom_intro=""" @@ -2334,6 +2384,20 @@ def prepare_inputs_for_generation( return model_inputs + def resize_token_embeddings( + self, + new_num_tokens: int | None = None, + pad_to_multiple_of: int | None = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + inputs_embeds = super().resize_token_embeddings( + new_num_tokens=new_num_tokens, + pad_to_multiple_of=pad_to_multiple_of, + mean_resizing=mean_resizing, + ) + # TODO: fix resizing for embeds per-layer by filtering out mm-soft-tokens + return inputs_embeds + __all__ = [ "Gemma3nAudioEncoder", diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index d5633a689687..f769084f64a1 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -31,7 +31,14 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check +from ...utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + is_accelerate_available, + logging, + torch_compilable_check, +) from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -59,6 +66,9 @@ from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig +if is_accelerate_available(): + from accelerate.hooks import add_hook_to_module + logger = logging.get_logger(__name__) @@ -1653,6 +1663,37 @@ def _init_weights(self, module): if hasattr(module, "gradient_clipping"): init.constant_(module.gradient_clipping, self.config.gradient_clipping) + def resize_token_embeddings( + self, + new_num_tokens: int | None = None, + pad_to_multiple_of: int | None = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + inputs_embeds = super().resize_token_embeddings( + new_num_tokens=new_num_tokens, + pad_to_multiple_of=pad_to_multiple_of, + mean_resizing=mean_resizing, + ) + self._resize_per_layer_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + return inputs_embeds + + def _resize_per_layer_embeddings( + self, + new_num_tokens: int | None = None, + pad_to_multiple_of: int | None = None, + mean_resizing: bool = True, + ): + self.config.vocab_size_per_layer_input = self.vocab_size + embed_tokens_per_layer = self.base_model.embed_tokens_per_layer + new_embeddings_per_layer = self._get_resized_embeddings( + embed_tokens_per_layer, new_num_tokens, pad_to_multiple_of, mean_resizing + ) + if hasattr(embed_tokens_per_layer, "_hf_hook"): + hook = 
embed_tokens_per_layer._hf_hook + add_hook_to_module(new_embeddings_per_layer, hook) + new_embeddings_per_layer.requires_grad_(embed_tokens_per_layer.weight.requires_grad) + self.base_model.embed_tokens_per_layer = new_embeddings_per_layer + class Gemma3nAudioEncoder(Gemma3nPreTrainedModel): """ @@ -2220,6 +2261,20 @@ def get_audio_features( return audio_outputs + def resize_token_embeddings( + self, + new_num_tokens: int | None = None, + pad_to_multiple_of: int | None = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + inputs_embeds = super().resize_token_embeddings( + new_num_tokens=new_num_tokens, + pad_to_multiple_of=pad_to_multiple_of, + mean_resizing=mean_resizing, + ) + # TODO: fix resizing for embeds per-layer by filtering out mm-soft-tokens + return inputs_embeds + @auto_docstring( custom_intro=""" @@ -2395,6 +2450,20 @@ def prepare_inputs_for_generation( def create_masks_for_generate(self, **super_kwargs): raise AttributeError("Do not inherit create_masks_for_generate from PaliGemma") + def resize_token_embeddings( + self, + new_num_tokens: int | None = None, + pad_to_multiple_of: int | None = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + inputs_embeds = super().resize_token_embeddings( + new_num_tokens=new_num_tokens, + pad_to_multiple_of=pad_to_multiple_of, + mean_resizing=mean_resizing, + ) + # TODO: fix resizing for embeds per-layer by filtering out mm-soft-tokens + return inputs_embeds + __all__ = [ "Gemma3nAudioConfig", diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index 2c7f8df57571..1f41875c5def 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -965,7 +965,8 @@ def set_output_embeddings(self, new_embeddings): # the resized embed is assigned only to the head. Then tying weights # again reverts everything back. So we have to update decoder here if self.config.tie_word_embeddings: - self.model.decoder.embed_tokens = new_embeddings + self.model.decoder.embed_tokens.weight = new_embeddings.weight + self.model.decoder.embed_tokens.num_embeddings = new_embeddings.weight.shape[0] def get_output_embeddings(self): return self.lm_head.out_proj diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index 61725cc1e6ac..1c8846ad74b9 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -804,7 +804,8 @@ def set_output_embeddings(self, new_embeddings): # the resized embed is assigned only to the head. Then tying weights # again reverts everything back. 
So we have to update decoder here if self.config.tie_word_embeddings: - self.model.decoder.embed_tokens = new_embeddings + self.model.decoder.embed_tokens.weight = new_embeddings.weight + self.model.decoder.embed_tokens.num_embeddings = new_embeddings.weight.shape[0] def get_output_embeddings(self): return self.lm_head.out_proj diff --git a/tests/models/colmodernvbert/test_modeling_colmodernvbert.py b/tests/models/colmodernvbert/test_modeling_colmodernvbert.py index 170ed352b863..4b7391407529 100755 --- a/tests/models/colmodernvbert/test_modeling_colmodernvbert.py +++ b/tests/models/colmodernvbert/test_modeling_colmodernvbert.py @@ -139,6 +139,7 @@ def prepare_config_and_inputs_for_common(self): attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device) # For simplicity just set the first n tokens to the image token + input_ids[input_ids == self.image_token_id] = self.pad_token_id input_ids[:, : self.image_seq_length] = self.image_token_id attention_mask = input_ids.ne(1).to(torch_device) inputs_dict = { From a1f1f8799c127742f64932299adae98a63bfeec1 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 9 Apr 2026 16:30:53 +0200 Subject: [PATCH 08/10] fix repo --- .../models/gemma3n/modeling_gemma3n.py | 65 ++++++++---------- .../models/gemma3n/modular_gemma3n.py | 67 ++++++++----------- .../models/gemma4/modeling_gemma4.py | 11 +-- .../models/gemma4/modular_gemma4.py | 54 ++------------- 4 files changed, 68 insertions(+), 129 deletions(-) diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 4aad6c51b1af..f1a09e6ef4d1 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1411,6 +1411,12 @@ def _init_weights(self, module): if hasattr(module, "gradient_clipping"): init.constant_(module.gradient_clipping, self.config.gradient_clipping) + def get_per_layer_input_embeddings(self): + return self.base_model.embed_tokens_per_layer + + def set_per_layer_input_embeddings(self, value): + self.base_model.embed_tokens_per_layer = value + def resize_token_embeddings( self, new_num_tokens: int | None = None, @@ -1431,16 +1437,17 @@ def _resize_per_layer_embeddings( pad_to_multiple_of: int | None = None, mean_resizing: bool = True, ): - self.config.vocab_size_per_layer_input = self.vocab_size - embed_tokens_per_layer = self.base_model.embed_tokens_per_layer - new_embeddings_per_layer = self._get_resized_embeddings( - embed_tokens_per_layer, new_num_tokens, pad_to_multiple_of, mean_resizing - ) - if hasattr(embed_tokens_per_layer, "_hf_hook"): - hook = embed_tokens_per_layer._hf_hook - add_hook_to_module(new_embeddings_per_layer, hook) - new_embeddings_per_layer.requires_grad_(embed_tokens_per_layer.weight.requires_grad) - self.base_model.embed_tokens_per_layer = new_embeddings_per_layer + self.config.get_text_config().vocab_size_per_layer_input = self.vocab_size + if self.config.get_text_config().hidden_size_per_layer_input: + embed_tokens_per_layer = self.get_per_layer_input_embeddings() + new_embeddings_per_layer = self._get_resized_embeddings( + embed_tokens_per_layer, new_num_tokens, pad_to_multiple_of, mean_resizing + ) + if hasattr(embed_tokens_per_layer, "_hf_hook"): + hook = embed_tokens_per_layer._hf_hook + add_hook_to_module(new_embeddings_per_layer, hook) + new_embeddings_per_layer.requires_grad_(embed_tokens_per_layer.weight.requires_grad) + self.set_per_layer_input_embeddings(new_embeddings_per_layer) class 
Gemma3nAudioEncoder(Gemma3nPreTrainedModel): @@ -2164,6 +2171,12 @@ def forward( audio_hidden_states=audio_features if input_features is not None else None, ) + def get_per_layer_input_embeddings(self): + return self.language_model.embed_tokens_per_layer + + def set_per_layer_input_embeddings(self, value): + self.language_model.embed_tokens_per_layer = value + @can_return_tuple @auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.") def get_audio_features( @@ -2186,20 +2199,6 @@ def get_audio_features( return audio_outputs - def resize_token_embeddings( - self, - new_num_tokens: int | None = None, - pad_to_multiple_of: int | None = None, - mean_resizing: bool = True, - ) -> nn.Embedding: - inputs_embeds = super().resize_token_embeddings( - new_num_tokens=new_num_tokens, - pad_to_multiple_of=pad_to_multiple_of, - mean_resizing=mean_resizing, - ) - # TODO: fix resizing for embeds per-layer by filtering out mm-soft-tokens - return inputs_embeds - @auto_docstring( custom_intro=""" @@ -2384,19 +2383,11 @@ def prepare_inputs_for_generation( return model_inputs - def resize_token_embeddings( - self, - new_num_tokens: int | None = None, - pad_to_multiple_of: int | None = None, - mean_resizing: bool = True, - ) -> nn.Embedding: - inputs_embeds = super().resize_token_embeddings( - new_num_tokens=new_num_tokens, - pad_to_multiple_of=pad_to_multiple_of, - mean_resizing=mean_resizing, - ) - # TODO: fix resizing for embeds per-layer by filtering out mm-soft-tokens - return inputs_embeds + def get_per_layer_input_embeddings(self): + return self.model.get_per_layer_input_embeddings() + + def set_per_layer_input_embeddings(self, value): + self.model.set_per_layer_input_embeddings(value) __all__ = [ diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index f769084f64a1..d48086cc3b4a 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1663,6 +1663,12 @@ def _init_weights(self, module): if hasattr(module, "gradient_clipping"): init.constant_(module.gradient_clipping, self.config.gradient_clipping) + def get_per_layer_input_embeddings(self): + return self.base_model.embed_tokens_per_layer + + def set_per_layer_input_embeddings(self, value): + self.base_model.embed_tokens_per_layer = value + def resize_token_embeddings( self, new_num_tokens: int | None = None, @@ -1683,16 +1689,17 @@ def _resize_per_layer_embeddings( pad_to_multiple_of: int | None = None, mean_resizing: bool = True, ): - self.config.vocab_size_per_layer_input = self.vocab_size - embed_tokens_per_layer = self.base_model.embed_tokens_per_layer - new_embeddings_per_layer = self._get_resized_embeddings( - embed_tokens_per_layer, new_num_tokens, pad_to_multiple_of, mean_resizing - ) - if hasattr(embed_tokens_per_layer, "_hf_hook"): - hook = embed_tokens_per_layer._hf_hook - add_hook_to_module(new_embeddings_per_layer, hook) - new_embeddings_per_layer.requires_grad_(embed_tokens_per_layer.weight.requires_grad) - self.base_model.embed_tokens_per_layer = new_embeddings_per_layer + self.config.get_text_config().vocab_size_per_layer_input = self.vocab_size + if self.config.get_text_config().hidden_size_per_layer_input: + embed_tokens_per_layer = self.get_per_layer_input_embeddings() + new_embeddings_per_layer = self._get_resized_embeddings( + embed_tokens_per_layer, new_num_tokens, pad_to_multiple_of, mean_resizing + ) + if hasattr(embed_tokens_per_layer, 
"_hf_hook"): + hook = embed_tokens_per_layer._hf_hook + add_hook_to_module(new_embeddings_per_layer, hook) + new_embeddings_per_layer.requires_grad_(embed_tokens_per_layer.weight.requires_grad) + self.set_per_layer_input_embeddings(new_embeddings_per_layer) class Gemma3nAudioEncoder(Gemma3nPreTrainedModel): @@ -2036,6 +2043,12 @@ def __init__(self, config: Gemma3nConfig): self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config) self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config) + def get_per_layer_input_embeddings(self): + return self.language_model.embed_tokens_per_layer + + def set_per_layer_input_embeddings(self, value): + self.language_model.embed_tokens_per_layer = value + @can_return_tuple @auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.") def get_image_features( @@ -2261,20 +2274,6 @@ def get_audio_features( return audio_outputs - def resize_token_embeddings( - self, - new_num_tokens: int | None = None, - pad_to_multiple_of: int | None = None, - mean_resizing: bool = True, - ) -> nn.Embedding: - inputs_embeds = super().resize_token_embeddings( - new_num_tokens=new_num_tokens, - pad_to_multiple_of=pad_to_multiple_of, - mean_resizing=mean_resizing, - ) - # TODO: fix resizing for embeds per-layer by filtering out mm-soft-tokens - return inputs_embeds - @auto_docstring( custom_intro=""" @@ -2283,6 +2282,12 @@ def resize_token_embeddings( """ ) class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration): + def get_per_layer_input_embeddings(self): + return self.model.get_per_layer_input_embeddings() + + def set_per_layer_input_embeddings(self, value): + self.model.set_per_layer_input_embeddings(value) + @can_return_tuple @auto_docstring def forward( @@ -2450,20 +2455,6 @@ def prepare_inputs_for_generation( def create_masks_for_generate(self, **super_kwargs): raise AttributeError("Do not inherit create_masks_for_generate from PaliGemma") - def resize_token_embeddings( - self, - new_num_tokens: int | None = None, - pad_to_multiple_of: int | None = None, - mean_resizing: bool = True, - ) -> nn.Embedding: - inputs_embeds = super().resize_token_embeddings( - new_num_tokens=new_num_tokens, - pad_to_multiple_of=pad_to_multiple_of, - mean_resizing=mean_resizing, - ) - # TODO: fix resizing for embeds per-layer by filtering out mm-soft-tokens - return inputs_embeds - __all__ = [ "Gemma3nAudioConfig", diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py index 63c21e01180a..686f42297f56 100644 --- a/src/transformers/models/gemma4/modeling_gemma4.py +++ b/src/transformers/models/gemma4/modeling_gemma4.py @@ -1428,19 +1428,20 @@ def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) -# ---- Model Classes ---- - - +@auto_docstring class Gemma4PreTrainedModel(PreTrainedModel): config: Gemma4Config + base_model_prefix = "model" supports_gradient_checkpointing = True + _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"] + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True + _can_compile_fullgraph = True _supports_attention_backend = True - _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"] - _skip_keys_device_placement = ["past_key_values"] + _can_record_outputs = None # override 
input_modalities = ("image", "text", "video", "audio") @torch.no_grad() diff --git a/src/transformers/models/gemma4/modular_gemma4.py b/src/transformers/models/gemma4/modular_gemma4.py index e3e878c3430a..55f3d770125d 100644 --- a/src/transformers/models/gemma4/modular_gemma4.py +++ b/src/transformers/models/gemma4/modular_gemma4.py @@ -63,6 +63,7 @@ Gemma3nModel, Gemma3nModelOutputWithPast, Gemma3nMultimodalEmbedder, + Gemma3nPreTrainedModel, Gemma3nRMSNorm, apply_rotary_pos_emb, eager_attention_forward, @@ -74,7 +75,7 @@ if is_accelerate_available(): - from accelerate.hooks import add_hook_to_module + pass logger = logging.get_logger(__name__) @@ -1149,21 +1150,14 @@ class Gemma4TextScaledWordEmbedding(Gemma3TextScaledWordEmbedding): # ---- Model Classes ---- -class Gemma4PreTrainedModel(PreTrainedModel): - config: Gemma4Config - supports_gradient_checkpointing = True - _supports_flash_attn = True - _supports_sdpa = True - _supports_flex_attn = True - _can_compile_fullgraph = True - _supports_attention_backend = True +class Gemma4PreTrainedModel(Gemma3nPreTrainedModel): _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"] - _skip_keys_device_placement = ["past_key_values"] input_modalities = ("image", "text", "video", "audio") + _can_record_outputs = None # override @torch.no_grad() def _init_weights(self, module): - super()._init_weights(module) + PreTrainedModel._init_weights(module) if isinstance(module, Gemma4VisionPatchEmbedder): init.ones_(module.position_embedding_table) elif isinstance(module, Gemma4AudioRelPositionalEncoding): @@ -1214,44 +1208,6 @@ def _init_weights(self, module): init.zeros_(module.std_bias) init.ones_(module.std_scale) - def get_per_layer_input_embeddings(self): - return self.base_model.embed_tokens_per_layer - - def set_per_layer_input_embeddings(self, value): - self.base_model.embed_tokens_per_layer = value - - def resize_token_embeddings( - self, - new_num_tokens: int | None = None, - pad_to_multiple_of: int | None = None, - mean_resizing: bool = True, - ) -> nn.Embedding: - inputs_embeds = super().resize_token_embeddings( - new_num_tokens=new_num_tokens, - pad_to_multiple_of=pad_to_multiple_of, - mean_resizing=mean_resizing, - ) - self._resize_per_layer_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) - return inputs_embeds - - def _resize_per_layer_embeddings( - self, - new_num_tokens: int | None = None, - pad_to_multiple_of: int | None = None, - mean_resizing: bool = True, - ): - self.config.get_text_config().vocab_size_per_layer_input = self.vocab_size - if self.config.get_text_config().hidden_size_per_layer_input: - embed_tokens_per_layer = self.get_per_layer_input_embeddings() - new_embeddings_per_layer = self._get_resized_embeddings( - embed_tokens_per_layer, new_num_tokens, pad_to_multiple_of, mean_resizing - ) - if hasattr(embed_tokens_per_layer, "_hf_hook"): - hook = embed_tokens_per_layer._hf_hook - add_hook_to_module(new_embeddings_per_layer, hook) - new_embeddings_per_layer.requires_grad_(embed_tokens_per_layer.weight.requires_grad) - self.set_per_layer_input_embeddings(new_embeddings_per_layer) - @auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.") class Gemma4TextModel(Gemma3TextModel): From 295d36d5c65f3b2b370d7d4b909240ea9ac274fe Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 9 Apr 2026 16:26:40 +0000 Subject: [PATCH 09/10] Apply repo consistency fixes --- src/transformers/models/gemma4/modeling_gemma4.py | 14 +++++++------- 
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py
index 51804c523e52..ebc309155af6 100644
--- a/src/transformers/models/gemma4/modeling_gemma4.py
+++ b/src/transformers/models/gemma4/modeling_gemma4.py
@@ -1442,6 +1442,7 @@ class Gemma4PreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
+    _skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
@@ -1449,7 +1450,6 @@ class Gemma4PreTrainedModel(PreTrainedModel):
     _can_compile_fullgraph = True
     _supports_attention_backend = True
     _can_record_outputs = None  # override
-    _skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
     input_modalities = ("image", "text", "video", "audio")
 
     @torch.no_grad()
@@ -2381,6 +2381,12 @@ def forward(
         audio_hidden_states=audio_features if input_features is not None else None,
     )
 
+    def get_per_layer_input_embeddings(self):
+        return self.language_model.embed_tokens_per_layer
+
+    def set_per_layer_input_embeddings(self, value):
+        self.language_model.embed_tokens_per_layer = value
+
     @can_return_tuple
     @auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.")
     def get_audio_features(
@@ -2406,12 +2412,6 @@ def get_audio_features(
 
         return audio_outputs
 
-    def get_per_layer_input_embeddings(self):
-        return self.language_model.embed_tokens_per_layer
-
-    def set_per_layer_input_embeddings(self, value):
-        self.language_model.embed_tokens_per_layer = value
-
     @can_return_tuple
     @auto_docstring(custom_intro="Projects the last hidden state from the vision encoder into language model space.")
     def get_video_features(

From f7efff62ae556ede282348821f7006f90ba807ac Mon Sep 17 00:00:00 2001
From: raushan
Date: Tue, 14 Apr 2026 18:13:22 +0200
Subject: [PATCH 10/10] update attr even if weights are tied

---
 src/transformers/modeling_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 2907b2b987cb..c6d694f68218 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2771,6 +2771,7 @@ def _get_resized_embeddings(
         old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
 
     if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
+        old_embeddings.num_embeddings = new_num_tokens  # weights may have been tied, which leaves this attribute stale
         return old_embeddings
 
     if not isinstance(old_embeddings, nn.Embedding):
@@ -2910,6 +2911,7 @@ def _get_resized_lm_head(
         )
 
     if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
+        old_lm_head.out_features = new_num_tokens  # weights may have been tied, which leaves this attribute stale
        return old_lm_head
 
     if not isinstance(old_lm_head, nn.Linear):