2 changes: 2 additions & 0 deletions src/transformers/modeling_utils.py
@@ -2771,6 +2771,7 @@ def _get_resized_embeddings(
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()

if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
old_embeddings.num_embeddings = new_num_tokens  # weights may be tied, which would leave this attribute stale
return old_embeddings

if not isinstance(old_embeddings, nn.Embedding):
@@ -2910,6 +2911,7 @@ def _get_resized_lm_head(
)

if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
old_lm_head.out_features = new_num_tokens  # weights may be tied, which would leave this attribute stale
return old_lm_head

if not isinstance(old_lm_head, nn.Linear):
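# Why these two hunks are needed -- a minimal sketch of the failure mode in
# plain PyTorch (no transformers-specific names): with tied weights, swapping
# in a resized embedding matrix updates the shared tensor but leaves the
# Linear head's bookkeeping attribute stale.
#
#     import torch.nn as nn
#
#     embed = nn.Embedding(10, 4)
#     lm_head = nn.Linear(4, 10, bias=False)
#     lm_head.weight = embed.weight    # weight tying: one shared Parameter
#
#     embed = nn.Embedding(12, 4)      # "resize": allocate a larger table
#     lm_head.weight = embed.weight    # re-tie to the resized tensor
#     lm_head.weight.shape             # torch.Size([12, 4]) -- data is correct
#     lm_head.out_features             # still 10 -- the attribute went stale
#
# When the head shares its Parameter with already-resized input embeddings,
# the weight's row count already equals new_num_tokens, so the resize helper
# takes the early-return path; the added assignments refresh the stale
# attribute on exactly that path.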
6 changes: 0 additions & 6 deletions src/transformers/models/gemma3/modeling_gemma3.py
@@ -960,12 +960,6 @@ def __init__(self, config: Gemma3Config):
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()

def get_input_embeddings(self):
return self.model.get_input_embeddings()

def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)

Comment on lines -963 to -968 (Member Author):
same as base class, so no need to override
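# The removed overrides were redundant -- a paraphrased sketch of the behavior
# already inherited from PreTrainedModel, which resolves the accessor through
# `base_model_prefix` ("model" here):
#
#     def get_input_embeddings(self):
#         base_model = getattr(self, self.base_model_prefix, self)
#         return base_model.get_input_embeddings()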

@auto_docstring
def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]):
return self.model.get_image_features(pixel_values, **kwargs)
1 change: 0 additions & 1 deletion src/transformers/models/gemma3n/configuration_gemma3n.py
@@ -30,7 +30,6 @@
if is_timm_available():
from timm.data import ImageNetInfo, infer_imagenet_subset


logger = logging.get_logger(__name__)


61 changes: 55 additions & 6 deletions src/transformers/models/gemma3n/modeling_gemma3n.py
@@ -43,6 +43,7 @@
TransformersKwargs,
auto_docstring,
can_return_tuple,
is_accelerate_available,
torch_compilable_check,
)
from ...utils.generic import maybe_autocast, merge_with_config_defaults
@@ -51,6 +52,10 @@
from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig


if is_accelerate_available():
from accelerate.hooks import add_hook_to_module


@dataclass
@auto_docstring
class Gemma3nAudioEncoderModelOutput(BaseModelOutputWithPooling):
@@ -1406,6 +1411,44 @@ def _init_weights(self, module):
if hasattr(module, "gradient_clipping"):
init.constant_(module.gradient_clipping, self.config.gradient_clipping)

def get_per_layer_input_embeddings(self):
return self.base_model.embed_tokens_per_layer

def set_per_layer_input_embeddings(self, value):
self.base_model.embed_tokens_per_layer = value

def resize_token_embeddings(
self,
new_num_tokens: int | None = None,
pad_to_multiple_of: int | None = None,
mean_resizing: bool = True,
) -> nn.Embedding:
inputs_embeds = super().resize_token_embeddings(
new_num_tokens=new_num_tokens,
pad_to_multiple_of=pad_to_multiple_of,
mean_resizing=mean_resizing,
)
self._resize_per_layer_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
return inputs_embeds

def _resize_per_layer_embeddings(
self,
new_num_tokens: int | None = None,
pad_to_multiple_of: int | None = None,
mean_resizing: bool = True,
):
self.config.get_text_config().vocab_size_per_layer_input = self.vocab_size
if self.config.get_text_config().hidden_size_per_layer_input:
embed_tokens_per_layer = self.get_per_layer_input_embeddings()
new_embeddings_per_layer = self._get_resized_embeddings(
embed_tokens_per_layer, new_num_tokens, pad_to_multiple_of, mean_resizing
)
if hasattr(embed_tokens_per_layer, "_hf_hook"):
hook = embed_tokens_per_layer._hf_hook
add_hook_to_module(new_embeddings_per_layer, hook)
new_embeddings_per_layer.requires_grad_(embed_tokens_per_layer.weight.requires_grad)
self.set_per_layer_input_embeddings(new_embeddings_per_layer)
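# A hedged usage sketch of the override above (the checkpoint id is
# illustrative): resizing now keeps the main table, the per-layer table, and
# the config's `vocab_size_per_layer_input` in sync.
#
#     from transformers import Gemma3nForConditionalGeneration
#
#     model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-E2B-it")
#     model.resize_token_embeddings(new_num_tokens=262500)
#     model.get_input_embeddings().num_embeddings                 # 262500
#     model.get_per_layer_input_embeddings().num_embeddings       # 262500
#     model.config.get_text_config().vocab_size_per_layer_input   # 262500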


class Gemma3nAudioEncoder(Gemma3nPreTrainedModel):
"""
@@ -2128,6 +2171,12 @@ def forward(
audio_hidden_states=audio_features if input_features is not None else None,
)

def get_per_layer_input_embeddings(self):
return self.language_model.embed_tokens_per_layer

def set_per_layer_input_embeddings(self, value):
self.language_model.embed_tokens_per_layer = value

@can_return_tuple
@auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.")
def get_audio_features(
@@ -2167,12 +2216,6 @@ def __init__(self, config: Gemma3nConfig):
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()

def get_input_embeddings(self):
return self.model.get_input_embeddings()

def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)

@auto_docstring
def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]):
return self.model.get_image_features(pixel_values, **kwargs)
@@ -2323,6 +2366,12 @@ def prepare_inputs_for_generation(

return model_inputs

def get_per_layer_input_embeddings(self):
return self.model.get_per_layer_input_embeddings()

def set_per_layer_input_embeddings(self, value):
self.model.set_per_layer_input_embeddings(value)
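# Together with the accessors added on Gemma3nModel above, this completes a
# simple delegation chain (a sketch of the call path using names from this
# diff):
#
#     Gemma3nForConditionalGeneration.get_per_layer_input_embeddings()
#         -> Gemma3nModel.get_per_layer_input_embeddings()
#             -> Gemma3nTextModel.embed_tokens_per_layer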


__all__ = [
"Gemma3nAudioEncoder",
62 changes: 61 additions & 1 deletion src/transformers/models/gemma3n/modular_gemma3n.py
@@ -31,7 +31,14 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
from ...utils import (
TransformersKwargs,
auto_docstring,
can_return_tuple,
is_accelerate_available,
logging,
torch_compilable_check,
)
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..auto import AutoModel
@@ -59,6 +66,9 @@
from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig


if is_accelerate_available():
from accelerate.hooks import add_hook_to_module

logger = logging.get_logger(__name__)


@@ -1653,6 +1663,44 @@ def _init_weights(self, module):
if hasattr(module, "gradient_clipping"):
init.constant_(module.gradient_clipping, self.config.gradient_clipping)

def get_per_layer_input_embeddings(self):
return self.base_model.embed_tokens_per_layer

def set_per_layer_input_embeddings(self, value):
self.base_model.embed_tokens_per_layer = value

def resize_token_embeddings(
self,
new_num_tokens: int | None = None,
pad_to_multiple_of: int | None = None,
mean_resizing: bool = True,
) -> nn.Embedding:
inputs_embeds = super().resize_token_embeddings(
new_num_tokens=new_num_tokens,
pad_to_multiple_of=pad_to_multiple_of,
mean_resizing=mean_resizing,
)
self._resize_per_layer_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
return inputs_embeds

def _resize_per_layer_embeddings(
self,
new_num_tokens: int | None = None,
pad_to_multiple_of: int | None = None,
mean_resizing: bool = True,
):
self.config.get_text_config().vocab_size_per_layer_input = self.vocab_size
if self.config.get_text_config().hidden_size_per_layer_input:
embed_tokens_per_layer = self.get_per_layer_input_embeddings()
new_embeddings_per_layer = self._get_resized_embeddings(
embed_tokens_per_layer, new_num_tokens, pad_to_multiple_of, mean_resizing
)
if hasattr(embed_tokens_per_layer, "_hf_hook"):
hook = embed_tokens_per_layer._hf_hook
add_hook_to_module(new_embeddings_per_layer, hook)
new_embeddings_per_layer.requires_grad_(embed_tokens_per_layer.weight.requires_grad)
self.set_per_layer_input_embeddings(new_embeddings_per_layer)
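# A small runnable sketch (plain PyTorch) of the `requires_grad_` handoff
# above: a freshly constructed nn.Embedding defaults to trainable, so without
# it a frozen per-layer table would silently become trainable after a resize.
#
#     import torch.nn as nn
#
#     old = nn.Embedding(8, 4)
#     old.weight.requires_grad_(False)              # frozen, e.g. adapter-style tuning
#     new = nn.Embedding(10, 4)                     # replacement starts trainable
#     new.requires_grad_(old.weight.requires_grad)  # carry the frozen state over
#     new.weight.requires_grad                      # False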


class Gemma3nAudioEncoder(Gemma3nPreTrainedModel):
"""
@@ -1995,6 +2043,12 @@ def __init__(self, config: Gemma3nConfig):
self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)

def get_per_layer_input_embeddings(self):
return self.language_model.embed_tokens_per_layer

def set_per_layer_input_embeddings(self, value):
self.language_model.embed_tokens_per_layer = value

@can_return_tuple
@auto_docstring(custom_intro="Projects the last hidden state from the vision model into language model space.")
def get_image_features(
@@ -2230,6 +2284,12 @@ def get_audio_features(
class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration):
accepts_loss_kwargs = False

def get_per_layer_input_embeddings(self):
return self.model.get_per_layer_input_embeddings()

def set_per_layer_input_embeddings(self, value):
self.model.set_per_layer_input_embeddings(value)

@can_return_tuple
@auto_docstring
def forward(
80 changes: 68 additions & 12 deletions src/transformers/models/gemma4/modeling_gemma4.py
@@ -46,13 +46,24 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
from ...utils import (
ModelOutput,
TransformersKwargs,
auto_docstring,
can_return_tuple,
is_accelerate_available,
torch_compilable_check,
)
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ..auto.modeling_auto import AutoModel
from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig


if is_accelerate_available():
from accelerate.hooks import add_hook_to_module


@dataclass
@auto_docstring(
custom_intro="""
@@ -1425,19 +1436,20 @@ def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)


# ---- Model Classes ----


@auto_docstring
class Gemma4PreTrainedModel(PreTrainedModel):
config: Gemma4Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
_skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True

_can_compile_fullgraph = True
_supports_attention_backend = True
_no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
_skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
_can_record_outputs = None # override
input_modalities = ("image", "text", "video", "audio")

@torch.no_grad()
@@ -1493,6 +1505,44 @@ def _init_weights(self, module):
init.zeros_(module.std_bias)
init.ones_(module.std_scale)

def get_per_layer_input_embeddings(self):
return self.base_model.embed_tokens_per_layer

def set_per_layer_input_embeddings(self, value):
self.base_model.embed_tokens_per_layer = value

def resize_token_embeddings(
self,
new_num_tokens: int | None = None,
pad_to_multiple_of: int | None = None,
mean_resizing: bool = True,
) -> nn.Embedding:
inputs_embeds = super().resize_token_embeddings(
new_num_tokens=new_num_tokens,
pad_to_multiple_of=pad_to_multiple_of,
mean_resizing=mean_resizing,
)
self._resize_per_layer_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
return inputs_embeds

def _resize_per_layer_embeddings(
self,
new_num_tokens: int | None = None,
pad_to_multiple_of: int | None = None,
mean_resizing: bool = True,
):
self.config.get_text_config().vocab_size_per_layer_input = self.vocab_size
if self.config.get_text_config().hidden_size_per_layer_input:
embed_tokens_per_layer = self.get_per_layer_input_embeddings()
new_embeddings_per_layer = self._get_resized_embeddings(
embed_tokens_per_layer, new_num_tokens, pad_to_multiple_of, mean_resizing
)
if hasattr(embed_tokens_per_layer, "_hf_hook"):
hook = embed_tokens_per_layer._hf_hook
add_hook_to_module(new_embeddings_per_layer, hook)
new_embeddings_per_layer.requires_grad_(embed_tokens_per_layer.weight.requires_grad)
self.set_per_layer_input_embeddings(new_embeddings_per_layer)
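# Why the `_hf_hook` handoff above matters -- a hedged sketch of the accelerate
# mechanics: loading with a `device_map` attaches an AlignDevicesHook to each
# module as `module._hf_hook`, and a freshly constructed replacement module has
# no hook, so the resized table would otherwise be ignored by device dispatch.
#
#     from accelerate.hooks import add_hook_to_module
#
#     old_hook = getattr(old_module, "_hf_hook", None)  # old_module / new_module
#     if old_hook is not None:                          # stand in for the swap pair
#         add_hook_to_module(new_module, old_hook)      # replacement keeps placement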


@auto_docstring(custom_intro="The base Gemma 4 language model without a language modeling head.")
class Gemma4TextModel(Gemma4PreTrainedModel):
@@ -2331,6 +2381,12 @@ def forward(
audio_hidden_states=audio_features if input_features is not None else None,
)

def get_per_layer_input_embeddings(self):
return self.language_model.embed_tokens_per_layer

def set_per_layer_input_embeddings(self, value):
self.language_model.embed_tokens_per_layer = value

@can_return_tuple
@auto_docstring(custom_intro="Projects the last hidden state from the audio encoder into language model space.")
def get_audio_features(
@@ -2402,12 +2458,6 @@ def __init__(self, config: Gemma4Config):
]
self.post_init()

def get_input_embeddings(self):
return self.model.get_input_embeddings()

def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)

@auto_docstring
def get_image_features(
self,
@@ -2536,6 +2586,12 @@ def prepare_inputs_for_generation(

return model_inputs

def get_per_layer_input_embeddings(self):
return self.model.get_per_layer_input_embeddings()

def set_per_layer_input_embeddings(self, value):
self.model.set_per_layer_input_embeddings(value)

@staticmethod
def create_masks_for_generate(
config: PreTrainedConfig,