From 535353cd97d08ada0c833ddff94c68a6c2550f5d Mon Sep 17 00:00:00 2001 From: omar zoloev Date: Thu, 23 Apr 2026 18:06:00 +0300 Subject: [PATCH 1/4] Update test_modeling_gemma4.py --- tests/models/gemma4/test_modeling_gemma4.py | 28 +++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/models/gemma4/test_modeling_gemma4.py b/tests/models/gemma4/test_modeling_gemma4.py index e694fae48362..2b5d046d8941 100644 --- a/tests/models/gemma4/test_modeling_gemma4.py +++ b/tests/models/gemma4/test_modeling_gemma4.py @@ -270,6 +270,34 @@ def test_num_layers_is_small(self): def test_generate_from_random_inputs_embeds(self): pass + def test_audio_rel_pos_encoding_uses_context_size_from_config(self): + from transformers.models.gemma4.configuration_gemma4 import Gemma4AudioConfig + from transformers.models.gemma4.modeling_gemma4 import Gemma4AudioRelPositionalEncoding + + config = Gemma4AudioConfig( + hidden_size=32, + attention_chunk_size=6, + attention_context_left=5, + attention_context_right=1, + use_clipped_linears=False, + ) + + module = Gemma4AudioRelPositionalEncoding(config) + hidden_states = torch.zeros(1, 3, config.hidden_size) + + pos = module(hidden_states) + + context_size = config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right + expected_len = context_size // 2 + 1 + + self.assertEqual(pos.shape, (1, expected_len, config.hidden_size)) + + position_ids = torch.arange(context_size // 2, -1, -1, device=hidden_states.device)[..., None] + scaled_time = position_ids * module.inv_timescales.to(device=hidden_states.device) + expected = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1).to(hidden_states.dtype) + + torch.testing.assert_close(pos, expected) + class Gemma4Vision2TextModelTester: def __init__( From 8aa98afa55478502100025f94b3adf7034d0fddc Mon Sep 17 00:00:00 2001 From: omar zoloev Date: Thu, 23 Apr 2026 19:38:45 +0300 Subject: [PATCH 2/4] Update modeling_gemma4.py --- src/transformers/models/gemma4/modeling_gemma4.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py index 978a0bda8cff..5e9f720fe800 100644 --- a/src/transformers/models/gemma4/modeling_gemma4.py +++ b/src/transformers/models/gemma4/modeling_gemma4.py @@ -189,7 +189,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Gemma4AudioRelPositionalEncoding(nn.Module): """Sinusoidal relative positional encoding for the audio encoder. - Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with + Produces position embeddings of shape [1, context_size // 2 + 1, hidden_size] with concatenated [sin..., cos...] layout matching the original Gemma4 convention. """ @@ -210,7 +210,7 @@ def __init__(self, config: Gemma4AudioConfig): @torch.no_grad() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - position_ids = torch.arange(12, -1, -1, device=hidden_states.device) + position_ids = torch.arange(self.context_size // 2, -1, -1, device=hidden_states.device) position_ids = position_ids[..., None] scaled_time = position_ids * self.inv_timescales.to(device=hidden_states.device) pos_embed = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1) @@ -1133,6 +1133,7 @@ def forward(self, x, position_ids, layer_type=None): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +@use_kernelized_func(apply_rotary_pos_emb) class Gemma4TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1235,9 +1236,9 @@ def forward( if self.store_full_length_kv: shared_kv_states[self.layer_idx] = key_states, value_states - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( - self.config._attn_implementation, eager_attention_forward - ) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, From 3a6ec3948b218476f9c5159da806b736cb1ca6d6 Mon Sep 17 00:00:00 2001 From: omar zoloev Date: Thu, 23 Apr 2026 19:42:24 +0300 Subject: [PATCH 3/4] Update modeling_gemma4.py From 80f2d35737785c7e694fb8a20046edebbfc7a690 Mon Sep 17 00:00:00 2001 From: omar zoloev Date: Thu, 23 Apr 2026 19:43:11 +0300 Subject: [PATCH 4/4] Update modeling_gemma4.py --- src/transformers/models/gemma4/modeling_gemma4.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py index 5e9f720fe800..978a0bda8cff 100644 --- a/src/transformers/models/gemma4/modeling_gemma4.py +++ b/src/transformers/models/gemma4/modeling_gemma4.py @@ -189,7 +189,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Gemma4AudioRelPositionalEncoding(nn.Module): """Sinusoidal relative positional encoding for the audio encoder. - Produces position embeddings of shape [1, context_size // 2 + 1, hidden_size] with + Produces position embeddings of shape [1, 2*context_size - 1, hidden_size] with concatenated [sin..., cos...] layout matching the original Gemma4 convention. """ @@ -210,7 +210,7 @@ def __init__(self, config: Gemma4AudioConfig): @torch.no_grad() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - position_ids = torch.arange(self.context_size // 2, -1, -1, device=hidden_states.device) + position_ids = torch.arange(12, -1, -1, device=hidden_states.device) position_ids = position_ids[..., None] scaled_time = position_ids * self.inv_timescales.to(device=hidden_states.device) pos_embed = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1) @@ -1133,7 +1133,6 @@ def forward(self, x, position_ids, layer_type=None): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -@use_kernelized_func(apply_rotary_pos_emb) class Gemma4TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1236,9 +1235,9 @@ def forward( if self.store_full_length_kv: shared_kv_states[self.layer_idx] = key_states, value_states - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) attn_output, attn_weights = attention_interface( self,