diff --git a/tests/models/gemma4/test_modeling_gemma4.py b/tests/models/gemma4/test_modeling_gemma4.py
index e49232b5fed4..355b836b7639 100644
--- a/tests/models/gemma4/test_modeling_gemma4.py
+++ b/tests/models/gemma4/test_modeling_gemma4.py
@@ -270,6 +270,34 @@ def test_num_layers_is_small(self):
     def test_generate_from_random_inputs_embeds(self):
         pass
 
+    def test_audio_rel_pos_encoding_uses_context_size_from_config(self):
+        from transformers.models.gemma4.configuration_gemma4 import Gemma4AudioConfig
+        from transformers.models.gemma4.modeling_gemma4 import Gemma4AudioRelPositionalEncoding
+
+        config = Gemma4AudioConfig(
+            hidden_size=32,
+            attention_chunk_size=6,
+            attention_context_left=5,
+            attention_context_right=1,
+            use_clipped_linears=False,
+        )
+
+        module = Gemma4AudioRelPositionalEncoding(config)
+        hidden_states = torch.zeros(1, 3, config.hidden_size)
+
+        pos = module(hidden_states)
+
+        context_size = config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right
+        expected_len = context_size // 2 + 1
+
+        self.assertEqual(pos.shape, (1, expected_len, config.hidden_size))
+
+        position_ids = torch.arange(context_size // 2, -1, -1, device=hidden_states.device)[..., None]
+        scaled_time = position_ids * module.inv_timescales.to(device=hidden_states.device)
+        expected = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1).to(hidden_states.dtype)
+
+        torch.testing.assert_close(pos, expected)
+
 class Gemma4Vision2TextModelTester:
     def __init__(
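
Reviewer note (hedged): the test's `expected` block effectively re-derives a standard sinusoidal relative positional encoding over the non-negative half of the attention context window. Below is a minimal, self-contained sketch of that computation for cross-checking the asserted shape and values. The `rel_pos_encoding` helper and its `min_timescale`/`max_timescale` defaults are assumptions following the classic transformer sinusoid scheme, not the actual `Gemma4AudioRelPositionalEncoding` implementation, which may construct `inv_timescales` differently.

```python
# Hedged sketch only: a standalone reference for the values the test asserts.
# The helper name and the timescale defaults are assumptions; the real
# Gemma4AudioRelPositionalEncoding may buffer inv_timescales differently.
import math

import torch


def rel_pos_encoding(
    context_size: int,
    hidden_size: int,
    min_timescale: float = 1.0,
    max_timescale: float = 1.0e4,
) -> torch.Tensor:
    """Sinusoidal encodings for relative offsets context_size // 2 down to 0."""
    num_timescales = hidden_size // 2
    # Geometric progression of inverse timescales, as in the standard
    # transformer sinusoidal embedding.
    log_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float32) * -log_increment
    )
    # Descending offsets, mirroring the test's expected computation.
    position_ids = torch.arange(context_size // 2, -1, -1, dtype=torch.float32)[..., None]
    scaled_time = position_ids * inv_timescales  # (context_size // 2 + 1, hidden_size // 2)
    pos = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
    return pos.unsqueeze(0)  # (1, context_size // 2 + 1, hidden_size)


# With the config used in the test: context_size = 6 + 5 - 1 + 1 = 11,
# so the encoding covers 11 // 2 + 1 = 6 relative positions.
print(rel_pos_encoding(context_size=11, hidden_size=32).shape)  # torch.Size([1, 6, 32])
```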