
Fix: Remove double softmax in MoE router load-balancing loss (Mixtral, Qwen2MoE, Qwen3VLMoE)#45132

Closed
akintunero wants to merge 6 commits into huggingface:main from akintunero:fix/issue-45120-double-softmax-moe

Conversation

@akintunero
Contributor

Summary

This PR fixes GitHub issue #45120: "Double softmax in MoE router load-balancing loss". The MoE routers in Mixtral, Qwen2MoE, and Qwen3VLMoE applied softmax inside forward(), and load_balancing_loss_func then applied softmax again, computing the auxiliary loss on softmax(softmax(logits)), which flattened the routing probabilities toward uniform.

Root Cause

Three routers overwrote router_logits with softmaxed values and then returned them to load_balancing_loss_func, which expects raw logits.
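A quick numeric illustration of why the double application matters (a standalone sketch, not code from this PR): applying softmax a second time squashes a sharply peaked routing distribution toward uniform, so the auxiliary loss no longer sees which expert the router actually prefers.

```python
import math

def softmax(xs):
    # numerically stable softmax over a plain list of floats
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

logits = [4.0, 1.0, 0.0, 0.0]   # router strongly prefers expert 0
once = softmax(logits)          # correct: peaked distribution, expert 0 dominates
twice = softmax(once)           # the bug: inputs are already in [0, 1], output is near-uniform

print([round(p, 3) for p in once])
print([round(p, 3) for p in twice])
```

With these example logits, the single softmax assigns expert 0 over 90% of the mass, while the double softmax drops it below 50%, with the other experts pulled up toward 1/4 each.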

Solution

Separated concepts: keep raw logits in router_logits, use router_probs for softmaxed values during routing.
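The shape of the fix can be sketched as follows (illustrative class and variable names, not the actual transformers implementation): the softmax output goes into a separate router_probs variable used only for the top-k routing decision, while the raw logits are what the forward pass returns for the load-balancing loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKRouterSketch(nn.Module):
    """Minimal sketch of the corrected router pattern (hypothetical names)."""

    def __init__(self, hidden_dim=8, num_experts=4, top_k=2):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, hidden_states):
        router_logits = self.gate(hidden_states)         # raw logits: kept for the aux loss
        router_probs = F.softmax(router_logits, dim=-1)  # probabilities: used only for routing
        topk_weights, topk_idx = torch.topk(router_probs, self.top_k, dim=-1)
        # Return raw logits, NOT router_probs, so the downstream
        # load-balancing loss applies softmax exactly once.
        return topk_weights, topk_idx, router_logits

x = torch.randn(3, 8)
weights, topk_idx, router_logits = TopKRouterSketch()(x)
```

The buggy version differed only in reassigning router_logits = softmax(router_logits) and returning that, which is what made the loss see softmax(softmax(logits)).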

Changes

  • MixtralTopKRouter: assign the softmax output to a new router_probs variable instead of overwriting router_logits
  • Qwen2MoeTopKRouter: assign the softmax output to a new router_probs variable instead of overwriting router_logits
  • Qwen3VLMoeTextTopKRouter: assign the softmax output to a new router_probs variable instead of overwriting router_logits
  • Added comprehensive unit tests verifying router_logits are raw logits

Impact

  • Fine-tuning: Load-balancing loss now receives correct raw logits
  • Inference: No changes
  • Backward compatible: Only internal computation affected

Fixes #45120

…, Qwen2MoE, Qwen3VLMoE)

Fixes issue huggingface#45120: MoE routers were applying softmax to raw logits inside
their forward() method, then returning the softmaxed values as 'router_logits'.
The load_balancing_loss_func then applied softmax AGAIN, computing the
auxiliary loss on softmax(softmax(logits)). This flattened the routing
probability distribution toward uniform, making the load-balancing loss
ineffective for fine-tuning.

Solution: Keep raw logits as the return value, use separate variable for
softmaxed probabilities during routing decision.

Changes:
- MixtralTopKRouter: Renamed router_logits reassignment to router_probs
- Qwen2MoeTopKRouter: Renamed router_logits reassignment to router_probs
- Qwen3VLMoeTextTopKRouter: Renamed router_logits reassignment to router_probs
- Added comprehensive unit tests in all three model test files

Impact:
- Fine-tuning: Load-balancing loss now receives correct raw logits
- Inference: No changes (routing logic uses probabilities internally)
- Backward compatibility: Fully maintained

See: huggingface#45120
Moves the set_seed import from inside the test method to the top-level imports
to comply with PEP8 E402 (module-import-not-at-top-of-file) and avoid potential
linting violations flagged by CircleCI.
…delTest

The test was incorrectly placed in MixtralIntegrationTest which doesn't have
the model_tester fixture. Moving it to MixtralModelTest ensures it has access
to self.model_tester.prepare_config_and_inputs_for_common().
Removes disconnected test code that was left over from moving the
test_router_logits method between test classes.
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: mixtral, qwen2_moe, qwen3_vl_moe

@olumayowa-chaoshi

Closing this PR; the issue will be re-fixed with a cleaner approach.



Development

Successfully merging this pull request may close these issues.

Double softmax in MoE router load-balancing loss (mixtral, qwen2_moe, qwen3_vl_moe families)
