3 changes: 2 additions & 1 deletion src/transformers/models/mistral4/modeling_mistral4.py
@@ -106,7 +106,8 @@ def compute_default_rope_parameters(
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
base = config.rope_parameters["rope_theta"]
-    dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+    head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+    dim = int(head_dim * config.rope_parameters.get("partial_rotary_factor", 1.0))

attention_factor = 1.0 # Unused in this type of RoPE

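Note: the fix above scales the rotary dimension by `partial_rotary_factor` instead of always rotating the full head. A minimal sketch of the corrected computation, with hypothetical values not taken from this PR:

```python
def rotary_dim(head_dim: int, partial_rotary_factor: float = 1.0) -> int:
    # Only a fraction of each head receives RoPE when partial_rotary_factor < 1.0.
    return int(head_dim * partial_rotary_factor)

# Hypothetical example: a 128-wide head where only 32 dims are rotated.
assert rotary_dim(128, 0.25) == 32
```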
10 changes: 4 additions & 6 deletions src/transformers/models/olmoe/modeling_olmoe.py
@@ -243,10 +243,8 @@ def __init__(self, config: OlmoeConfig, layer_idx: int | None = None):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
-        self.q_norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.k_norm = OlmoeRMSNorm(
-            (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps
-        )
+        self.q_norm = OlmoeRMSNorm(config.num_attention_heads * self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = OlmoeRMSNorm(config.num_key_value_heads * self.head_dim, eps=config.rms_norm_eps)

def forward(
self,
@@ -350,8 +348,8 @@ def __init__(self, config):
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
-        router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
+        router_probs = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
+        router_top_value, router_indices = torch.topk(router_probs, self.top_k, dim=-1) # (seq_len, top_k)
if self.norm_topk_prob:
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_top_value = router_top_value.to(router_logits.dtype)
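The behavioral point of the hunk just above: top-k expert selection now runs on the softmax probabilities, while the raw logits are left untouched for downstream use (e.g. the router auxiliary loss). A self-contained sketch of that ordering, with made-up shapes:

```python
import torch
import torch.nn.functional as F

# Hypothetical router output: 5 tokens, 3 experts, top-2 routing.
router_logits = torch.randn(5, 3)

router_probs = F.softmax(router_logits, dtype=torch.float, dim=-1)  # probabilities, not logits
top_values, top_indices = torch.topk(router_probs, k=2, dim=-1)     # pick experts by probability
top_values = top_values / top_values.sum(dim=-1, keepdim=True)      # optional renormalisation (norm_topk_prob)
top_values = top_values.to(router_logits.dtype)

# router_logits itself is never overwritten, so callers still see raw logits.
```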
6 changes: 2 additions & 4 deletions src/transformers/models/olmoe/modular_olmoe.py
@@ -58,10 +58,8 @@ class OlmoeMLP(GemmaMLP):
class OlmoeAttention(LlamaAttention):
def __init__(self, config: OlmoeConfig, layer_idx: int | None = None):
super().__init__(config, layer_idx)
-        self.q_norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.k_norm = OlmoeRMSNorm(
-            (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, eps=config.rms_norm_eps
-        )
+        self.q_norm = OlmoeRMSNorm(config.num_attention_heads * self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = OlmoeRMSNorm(config.num_key_value_heads * self.head_dim, eps=config.rms_norm_eps)

def forward(
self,
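This mirrors the sizing fix in `modeling_olmoe.py`: the q/k RMSNorm widths now follow `head_dim` (which may be set explicitly on the config) instead of being derived from `hidden_size`. A small sketch with hypothetical numbers showing where the two formulas diverge:

```python
# Hypothetical config where head_dim is overridden and no longer equals hidden_size // num_attention_heads.
hidden_size = 8
num_attention_heads = 2
num_key_value_heads = 1
head_dim = 2  # explicit override; hidden_size // num_attention_heads would give 4

# Old sizing (wrong under the override):
old_q_norm_dim = hidden_size                                                  # 8
old_k_norm_dim = (hidden_size // num_attention_heads) * num_key_value_heads   # 4

# New sizing (matches the actual q/k projection widths):
new_q_norm_dim = num_attention_heads * head_dim   # 4
new_k_norm_dim = num_key_value_heads * head_dim   # 2
```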
14 changes: 14 additions & 0 deletions tests/models/mistral4/test_modeling_mistral4.py
@@ -39,6 +39,9 @@
Mistral4Model,
)

+from transformers.models.mistral4.configuration_mistral4 import Mistral4Config
+from transformers.models.mistral4.modeling_mistral4 import Mistral4RotaryEmbedding


from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester

@@ -81,6 +84,17 @@ class Mistral4IntegrationTest(unittest.TestCase):
def tearDown(self):
cleanup(torch_device, gc_collect=True)

+    def test_default_rope_uses_rotary_head_dim(self):
+        config = Mistral4Config(
+            rope_parameters={"type": "default", "rope_theta": 10000.0},
+            qk_nope_head_dim=96,
+            qk_rope_head_dim=32,
+        )
+
+        rotary_embedding = Mistral4RotaryEmbedding(config)
+
+        self.assertEqual(rotary_embedding.inv_freq.shape[0] * 2, config.qk_rope_head_dim)
+
@slow
def test_mistral_small_4_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
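The assertion in the new test relies on the default RoPE construction building one inverse frequency per rotated dimension pair. A quick sketch of that relationship (standard RoPE formula, values only illustrative):

```python
import torch

rotary_dim, base = 32, 10000.0  # e.g. qk_rope_head_dim from the test above
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
assert inv_freq.shape[0] * 2 == rotary_dim
```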
42 changes: 42 additions & 0 deletions tests/models/olmoe/test_modeling_olmoe.py
@@ -39,6 +39,8 @@
OlmoeForCausalLM,
OlmoeModel,
)
+from transformers.models.olmoe.modeling_olmoe import OlmoeAttention
+from transformers.models.olmoe.modeling_olmoe import OlmoeTopKRouter


class OlmoeModelTester:
@@ -193,6 +195,46 @@ class OlmoeModelTest(
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]

+    def test_router_returns_raw_logits(self):
+        config = OlmoeConfig(hidden_size=4, num_experts=3, num_experts_per_tok=2, norm_topk_prob=False)
+        router = OlmoeTopKRouter(config)
+
+        with torch.no_grad():
+            router.weight.copy_(
+                torch.tensor([[1.0, 0.0, -1.0, 2.0], [0.5, 1.0, 0.0, -0.5], [-1.0, 0.5, 1.5, 0.0]])
+            )
+
+        hidden_states = torch.tensor([[[2.0, -1.0, 0.5, 3.0]]])
+        router_logits, router_scores, router_indices = router(hidden_states)
+
+        expected_logits = torch.nn.functional.linear(hidden_states.reshape(-1, 4), router.weight)
+        expected_probs = torch.nn.functional.softmax(expected_logits, dim=-1)
+        expected_top_values, expected_top_indices = torch.topk(expected_probs, 2, dim=-1)
+
+        self.assertTrue(torch.equal(router_logits, expected_logits))
+        self.assertTrue(torch.equal(router_indices, expected_top_indices))
+        self.assertTrue(torch.allclose(router_scores, expected_top_values.to(router_scores.dtype)))
+
+    def test_attention_uses_head_dim_in_norms(self):
+        config = OlmoeConfig(
+            hidden_size=8,
+            num_attention_heads=2,
+            num_key_value_heads=1,
+            num_hidden_layers=1,
+            intermediate_size=16,
+            vocab_size=32,
+        )
+        config.head_dim = 2
+        attention = OlmoeAttention(config, layer_idx=0)
+
+        hidden_states = torch.randn(1, 1, 8)
+        position_embeddings = (torch.ones(1, 1, 2), torch.zeros(1, 1, 2))
+
+        output, attn_weights = attention(hidden_states, position_embeddings, attention_mask=None)
+
+        self.assertEqual(output.shape, (1, 1, 8))
+        self.assertIsNone(attn_weights)
+
def setUp(self):
self.model_tester = OlmoeModelTester(self)
self.config_tester = ConfigTester(self, config_class=OlmoeConfig, hidden_size=32)