Fix: Remove double softmax in MoE router load-balancing loss (Mixtral, Qwen2MoE, Qwen3VLMoE)#45132
Closed
akintunero wants to merge 6 commits into huggingface:main
Conversation
…, Qwen2MoE, Qwen3VLMoE)
Fixes issue huggingface#45120: MoE routers were applying softmax to the raw logits inside their forward() method, then returning the softmaxed values as router_logits. load_balancing_loss_func then applied softmax again, computing the auxiliary loss on softmax(softmax(logits)). This flattened the routing probability distribution toward uniform, making the load-balancing loss ineffective for fine-tuning.
Solution: keep the raw logits as the return value and use a separate variable for the softmaxed probabilities during the routing decision.
Changes:
- MixtralTopKRouter: renamed the router_logits reassignment to router_probs
- Qwen2MoeTopKRouter: renamed the router_logits reassignment to router_probs
- Qwen3VLMoeTextTopKRouter: renamed the router_logits reassignment to router_probs
- Added comprehensive unit tests in all three model test files
Impact:
- Fine-tuning: the load-balancing loss now receives the correct raw logits
- Inference: no changes (routing logic uses the probabilities internally)
- Backward compatibility: fully maintained
See: huggingface#45120
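A minimal sketch of the pattern the fix applies — a simplified router, not the exact transformers source; the `self.gate` name follows the Mixtral convention, but the surrounding class is illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopKRouter(nn.Module):
    """Simplified MoE router illustrating the fix (not the exact transformers code)."""

    def __init__(self, hidden_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        # (batch * seq_len, num_experts) raw gating scores
        router_logits = self.gate(hidden_states)

        # Before the fix: `router_logits = F.softmax(router_logits, dim=-1)`
        # overwrote the raw logits, so the caller's loss saw probabilities.
        # After the fix: the softmax goes into a *separate* variable.
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)

        routing_weights, selected_experts = torch.topk(router_probs, self.top_k, dim=-1)
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

        # Raw logits are returned so the load-balancing loss can apply
        # its own (single) softmax.
        return routing_weights, selected_experts, router_logits
```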
Moves the set_seed import from inside the test method to the top-level imports to comply with E402 (module-level import not at top of file) and avoid linting violations flagged by CircleCI.
…delTest
The test was incorrectly placed in MixtralIntegrationTest, which doesn't have the model_tester fixture. Moving it to MixtralModelTest gives it access to self.model_tester.prepare_config_and_inputs_for_common().
Removes disconnected test code that was left over from moving the test_router_logits method between test classes.
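A sketch of the kind of regression test the commits describe, following the transformers model-test conventions; the exact test body in the PR may differ, and MixtralForCausalLM stands in for any of the three affected models:

```python
import torch

from transformers import MixtralForCausalLM


# Meant to live inside MixtralModelTest (hence `self.model_tester`).
def test_router_logits_are_raw(self):
    """The router output fed to the aux loss should be raw logits, not probabilities."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = MixtralForCausalLM(config).eval()

    with torch.no_grad():
        outputs = model(**inputs_dict, output_router_logits=True)

    for layer_logits in outputs.router_logits:
        # Applying softmax ourselves must yield rows that sum to 1...
        row_sums = torch.softmax(layer_logits, dim=-1).sum(dim=-1)
        torch.testing.assert_close(row_sums, torch.ones_like(row_sums))
        # ...but the returned tensor itself must NOT already be a probability
        # distribution (softmaxed rows would all sum to ~1 before any softmax).
        assert not torch.allclose(layer_logits.sum(dim=-1), torch.ones_like(row_sums))
```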
[For maintainers] Suggested jobs to run (before merge): run-slow: mixtral, qwen2_moe, qwen3_vl_moe
Closing this PR to re-fix the issue with a cleaner approach.
Summary
This PR fixes GitHub issue #45120, "Double softmax in MoE router load-balancing loss". MoE routers in Mixtral, Qwen2MoE, and Qwen3VLMoE applied softmax inside forward(), and load_balancing_loss_func then applied softmax again, so the auxiliary loss was computed on softmax(softmax(logits)), flattening the routing probabilities toward uniform.
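To see why the second softmax is harmful: softmax outputs live in [0, 1], so the second pass operates on a very narrow input range and squashes the distribution toward uniform. A small self-contained demonstration (printed values approximate):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 1.0, 0.0, -2.0])

once = F.softmax(logits, dim=-1)
twice = F.softmax(once, dim=-1)

print(once)   # ≈ [0.934, 0.047, 0.017, 0.002] — sharply peaked
print(twice)  # ≈ [0.454, 0.187, 0.181, 0.179] — nearly uniform
```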
Root Cause
Three routers reassigned `router_logits` with softmaxed values, then returned them to `load_balancing_loss_func`, which expected raw logits.
Solution
Separated the two concepts: keep the raw logits in `router_logits` and use `router_probs` for the softmaxed values during routing.
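For context, here is a simplified Switch-Transformer-style auxiliary loss — a sketch, not the actual `load_balancing_loss_func` from transformers — showing that the loss itself applies the one and only softmax, which is why it must be fed raw logits:

```python
import torch
import torch.nn.functional as F


def load_balancing_loss(router_logits: torch.Tensor, num_experts: int, top_k: int):
    """Simplified Switch-style aux loss (a sketch, not the transformers implementation).

    `router_logits` must be RAW gate outputs of shape (tokens, num_experts):
    the single softmax below is the only one that should be applied.
    """
    routing_probs = F.softmax(router_logits, dim=-1)         # (tokens, num_experts)
    _, selected = torch.topk(routing_probs, top_k, dim=-1)   # (tokens, top_k)
    expert_mask = F.one_hot(selected, num_experts).float()   # (tokens, top_k, num_experts)

    # Fraction of routing slots dispatched to each expert, and the mean
    # router probability per expert; their product rewards balanced routing.
    tokens_per_expert = expert_mask.mean(dim=(0, 1))
    prob_per_expert = routing_probs.mean(dim=0)
    return num_experts * torch.sum(tokens_per_expert * prob_per_expert)
```

Feeding this function values that were already softmaxed would apply softmax twice, reproducing the bug this PR removes.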
Changes
- MixtralTopKRouter, Qwen2MoeTopKRouter, Qwen3VLMoeTextTopKRouter: the softmax result is no longer written back into `router_logits`; it goes into a separate `router_probs`
- Unit tests added in all three model test files
Impact
- Fine-tuning: the load-balancing loss now receives the correct raw logits
- Inference: no change (routing already used the probabilities internally)
- Backward compatibility: fully maintained
Fixes #45120