
Fix double softmax in MoE router load-balancing loss #45111

Closed

ionut-anghelina wants to merge 1 commit into huggingface:main from ionut-anghelina:dev/ionut/FixDoubleSoftmax

Conversation

@ionut-anghelina

Summary

  • Several MoE routers applied softmax inside forward() but returned the result as router_logits. The load_balancing_loss_func then applied softmax again, computing the aux loss on softmax(softmax(logits)) which flattens the distribution toward uniform, rendering the load-balancing loss ineffective.
  • Fix: use a separate router_probs variable for the softmaxed values used in top-k routing, keeping router_logits as raw logits so the loss function's single softmax is correct (see the sketch below).
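To make the bug concrete, here is a minimal before/after sketch of a router forward pass. It is illustrative rather than the exact Transformers code: `gate`, `hidden_states`, and `top_k` stand in for the routers' real attributes.

```python
import torch
import torch.nn.functional as F

def buggy_router(gate: torch.nn.Linear, hidden_states: torch.Tensor, top_k: int):
    # Bug: the softmaxed probabilities overwrite the variable that is
    # eventually handed to load_balancing_loss_func as "router_logits".
    router_logits = F.softmax(gate(hidden_states), dim=-1)
    routing_weights, selected_experts = torch.topk(router_logits, top_k, dim=-1)
    return routing_weights, selected_experts, router_logits  # probs, softmaxed again in the loss

def fixed_router(gate: torch.nn.Linear, hidden_states: torch.Tensor, top_k: int):
    router_logits = gate(hidden_states)               # raw logits, left untouched
    router_probs = F.softmax(router_logits, dim=-1)   # separate variable for routing
    routing_weights, selected_experts = torch.topk(router_probs, top_k, dim=-1)
    return routing_weights, selected_experts, router_logits  # loss applies softmax exactly once
```

The flattening is easy to see numerically: softmax([4, 0, 0]) ≈ [0.96, 0.02, 0.02], but softmaxing that again gives ≈ [0.56, 0.22, 0.22]. Because probabilities lie in [0, 1], the second softmax's inputs differ by at most 1, so its output can never be far from uniform and the aux loss sees an artificially balanced router.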

Source modular files fixed (3):

  • mixtral/modular_mixtral.py (MixtralTopKRouter)
  • qwen2_moe/modular_qwen2_moe.py (Qwen2MoeTopKRouter)
  • qwen3_vl_moe/modular_qwen3_vl_moe.py (Qwen3VLMoeTextTopKRouter)

Downstream models regenerated by make fix-repo (10):

mixtral, minimax, qwen2_moe, olmoe, flex_olmo, qwen3_moe, qwen3_next, qwen3_omni_moe, qwen3_vl_moe, qwen3_5_moe

Test plan

  • Verified with a debug script that router_logits now returns raw logits (row sums ≠ 1.0) instead of probabilities (sketch below)
  • Confirmed load_balancing_loss_func applies softmax exactly once
  • Ran MoE model tests: pytest tests/models/mixtral/ tests/models/qwen2_moe/ tests/models/qwen3_vl_moe/ -v
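The row-sum check from the first bullet can be reproduced along these lines. This is a hypothetical reconstruction of the debug script, not the original: it uses a tiny randomly initialized Mixtral whose sizes are arbitrary.

```python
import torch
from transformers import MixtralConfig, MixtralForCausalLM

# Tiny random Mixtral, just large enough to expose router outputs.
config = MixtralConfig(
    vocab_size=128, hidden_size=32, intermediate_size=64,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
    num_local_experts=4, num_experts_per_tok=2,
)
model = MixtralForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
out = model(input_ids, output_router_logits=True)

# Pre-fix, every row summed to exactly 1.0 (probabilities); post-fix the
# rows are raw logits with arbitrary sums.
row_sums = out.router_logits[0].sum(dim=-1)
print(row_sums)
assert not torch.allclose(row_sums, torch.ones_like(row_sums))
```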

🤖 Generated with Claude Code

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: flex_olmo, minimax, mixtral, olmoe, qwen2_moe, qwen3_5_moe, qwen3_moe, qwen3_next, qwen3_omni_moe, qwen3_vl_moe

