From de9d2f47edb14491b882a071552826d7eb1661f4 Mon Sep 17 00:00:00 2001
From: owwll
Date: Sat, 11 Apr 2026 02:33:20 +0530
Subject: [PATCH] Fix OLMoE routing and Mistral4 RoPE dimensions

OlmoeTopKRouter overwrote its raw logits with the softmax output, so the
returned router_logits were actually probabilities. Keep the logits intact
and compute top-k from a separate router_probs tensor.

The OLMoE q_norm/k_norm sizes were derived from config.hidden_size, which
no longer matches the q/k projection widths when config.head_dim differs
from hidden_size // num_attention_heads. Size both norms from head_dim.

Mistral4's default RoPE parameters ignored partial_rotary_factor and always
built inv_freq over the full head_dim; scale the rotary dimension by
partial_rotary_factor instead.
---
 .../models/mistral4/modeling_mistral4.py      |  3 +-
 .../models/olmoe/modeling_olmoe.py            | 10 ++---
 .../models/olmoe/modular_olmoe.py             |  6 +--
 .../models/mistral4/test_modeling_mistral4.py | 14 +++++++
 tests/models/olmoe/test_modeling_olmoe.py     | 42 +++++++++++++++++++
 5 files changed, 64 insertions(+), 11 deletions(-)

diff --git a/src/transformers/models/mistral4/modeling_mistral4.py b/src/transformers/models/mistral4/modeling_mistral4.py
index 006ddad187bf..5884288491eb 100644
--- a/src/transformers/models/mistral4/modeling_mistral4.py
+++ b/src/transformers/models/mistral4/modeling_mistral4.py
@@ -106,7 +106,8 @@ def compute_default_rope_parameters(
             post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
     base = config.rope_parameters["rope_theta"]
-    dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+    head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+    dim = int(head_dim * config.rope_parameters.get("partial_rotary_factor", 1.0))
 
     attention_factor = 1.0  # Unused in this type of RoPE
 
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index 8a83315a5820..540558c8b8b7 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -243,10 +243,8 @@ def __init__(self, config: OlmoeConfig, layer_idx: int | None = None):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.q_norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.k_norm = OlmoeRMSNorm(
-            (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps
-        )
+        self.q_norm = OlmoeRMSNorm(config.num_attention_heads * self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = OlmoeRMSNorm(config.num_key_value_heads * self.head_dim, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -350,8 +348,8 @@ def __init__(self, config):
     def forward(self, hidden_states):
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts)
-        router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
+        router_probs = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
+        router_top_value, router_indices = torch.topk(router_probs, self.top_k, dim=-1)  # (seq_len, top_k)
         if self.norm_topk_prob:
             router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
         router_top_value = router_top_value.to(router_logits.dtype)
diff --git a/src/transformers/models/olmoe/modular_olmoe.py b/src/transformers/models/olmoe/modular_olmoe.py
index 9fee40493496..4970fd3ca0f3 100644
--- a/src/transformers/models/olmoe/modular_olmoe.py
+++ b/src/transformers/models/olmoe/modular_olmoe.py
@@ -58,10 +58,8 @@ class OlmoeMLP(GemmaMLP):
 class OlmoeAttention(LlamaAttention):
     def __init__(self, config: OlmoeConfig, layer_idx: int | None = None):
         super().__init__(config, layer_idx)
-        self.q_norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.k_norm = OlmoeRMSNorm(
-            (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps
-        )
+        self.q_norm = OlmoeRMSNorm(config.num_attention_heads * self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = OlmoeRMSNorm(config.num_key_value_heads * self.head_dim, eps=config.rms_norm_eps)
 
     def forward(
         self,
diff --git a/tests/models/mistral4/test_modeling_mistral4.py b/tests/models/mistral4/test_modeling_mistral4.py
index 449e13461264..a90ab23b1557 100644
--- a/tests/models/mistral4/test_modeling_mistral4.py
+++ b/tests/models/mistral4/test_modeling_mistral4.py
@@ -39,6 +39,9 @@
         Mistral4Model,
     )
 
+    from transformers.models.mistral4.configuration_mistral4 import Mistral4Config
+    from transformers.models.mistral4.modeling_mistral4 import Mistral4RotaryEmbedding
+
 from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
 
@@ -81,6 +84,17 @@ class Mistral4IntegrationTest(unittest.TestCase):
     def tearDown(self):
         cleanup(torch_device, gc_collect=True)
 
+    def test_default_rope_uses_rotary_head_dim(self):
+        config = Mistral4Config(
+            rope_parameters={"type": "default", "rope_theta": 10000.0},
+            qk_nope_head_dim=96,
+            qk_rope_head_dim=32,
+        )
+
+        rotary_embedding = Mistral4RotaryEmbedding(config)
+
+        self.assertEqual(rotary_embedding.inv_freq.shape[0] * 2, config.qk_rope_head_dim)
+
     @slow
     def test_mistral_small_4_logits(self):
         input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
diff --git a/tests/models/olmoe/test_modeling_olmoe.py b/tests/models/olmoe/test_modeling_olmoe.py
index 1f3aa0c76b10..db62e961047f 100644
--- a/tests/models/olmoe/test_modeling_olmoe.py
+++ b/tests/models/olmoe/test_modeling_olmoe.py
@@ -39,6 +39,8 @@
         OlmoeForCausalLM,
         OlmoeModel,
     )
+    from transformers.models.olmoe.modeling_olmoe import OlmoeAttention
+    from transformers.models.olmoe.modeling_olmoe import OlmoeTopKRouter
 
 
 class OlmoeModelTester:
@@ -193,6 +195,46 @@ class OlmoeModelTest(
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.7, 0.8]
 
+    def test_router_returns_raw_logits(self):
+        config = OlmoeConfig(hidden_size=4, num_experts=3, num_experts_per_tok=2, norm_topk_prob=False)
+        router = OlmoeTopKRouter(config)
+
+        with torch.no_grad():
+            router.weight.copy_(
+                torch.tensor([[1.0, 0.0, -1.0, 2.0], [0.5, 1.0, 0.0, -0.5], [-1.0, 0.5, 1.5, 0.0]])
+            )
+
+        hidden_states = torch.tensor([[[2.0, -1.0, 0.5, 3.0]]])
+        router_logits, router_scores, router_indices = router(hidden_states)
+
+        expected_logits = torch.nn.functional.linear(hidden_states.reshape(-1, 4), router.weight)
+        expected_probs = torch.nn.functional.softmax(expected_logits, dim=-1)
+        expected_top_values, expected_top_indices = torch.topk(expected_probs, 2, dim=-1)
+
+        self.assertTrue(torch.equal(router_logits, expected_logits))
+        self.assertTrue(torch.equal(router_indices, expected_top_indices))
+        self.assertTrue(torch.allclose(router_scores, expected_top_values.to(router_scores.dtype)))
+
+    def test_attention_uses_head_dim_in_norms(self):
+        config = OlmoeConfig(
+            hidden_size=8,
+            num_attention_heads=2,
+            num_key_value_heads=1,
+            num_hidden_layers=1,
+            intermediate_size=16,
+            vocab_size=32,
+        )
+        config.head_dim = 2
+        attention = OlmoeAttention(config, layer_idx=0)
+
+        hidden_states = torch.randn(1, 1, 8)
+        position_embeddings = (torch.ones(1, 1, 2), torch.zeros(1, 1, 2))
+
+        output, attn_weights = attention(hidden_states, position_embeddings, attention_mask=None)
+
+        self.assertEqual(output.shape, (1, 1, 8))
+        self.assertIsNone(attn_weights)
+
     def setUp(self):
         self.model_tester = OlmoeModelTester(self)
         self.config_tester = ConfigTester(self, config_class=OlmoeConfig, hidden_size=32)