diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index d6c1f61dc43f..2c3ac27a28ff 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -1015,7 +1015,10 @@ def __init__(self, config): frequencies_x = img_idx % idx # get the coordinates of the 2d matrix along x frequencies_y = img_idx // idx # get the coordinates of the 2d matrix along y freq_dim = config.hidden_size // config.num_attention_heads // 2 - rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim)) + rope_freq = 1.0 / ( + config.rope_parameters["rope_theta"] + ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim) + ) freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1) freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1) freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]