fix(moe): normalize auxiliary loss by top_k for correct load balancing #43775
Open
Mr-Neutr0n wants to merge 3 commits into huggingface:main from
Conversation
The auxiliary load balancing loss in MoE models was not correctly normalized when top_k > 1. The tokens_per_expert distribution (f_i) was summing to K instead of 1, while router_prob_per_expert (P_i) sums to 1, making the loss calculation incorrect. According to the DeepSeek-MoE and megablocks implementations, f_i should be normalized by K so that both distributions represent the same scale:

Before: sum(f_i) = K, sum(P_i) = 1
After: sum(f_i) = 1, sum(P_i) = 1

This ensures the load balancing loss correctly penalizes unbalanced routing when using top-k routing with k > 1.

Fixes huggingface#43688

Signed-off-by: Harikrishna KP <harikp2002@gmail.com>
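To illustrate the scale mismatch described above, here is a small standalone check (not code from this PR; the tensor shapes and variable names are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

num_tokens, num_experts, top_k = 1024, 8, 2

# Random router logits for a batch of tokens.
gate_logits = torch.randn(num_tokens, num_experts)
routing_weights = F.softmax(gate_logits, dim=-1)

# Each token is routed to its top-k experts.
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = F.one_hot(selected_experts, num_experts).float()  # [tokens, top_k, experts]

# f_i: fraction of routing assignments per expert, averaged over tokens only.
f_i = expert_mask.sum(dim=(0, 1)) / num_tokens
# P_i: mean router probability per expert.
p_i = routing_weights.mean(dim=0)

print(f_i.sum().item())  # == top_k (2.0), not 1.0
print(p_i.sum().item())  # == 1.0
```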
Apply the same top_k normalization fix to the generated modeling file so that it matches the modular source file and passes the CI consistency check.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The top_k normalization fix in modular_mixtral.py propagates to all MoE models that inherit load_balancing_loss_func from mixtral. Regenerated modeling files for:

- dbrx, ernie4_5_moe, ernie4_5_vl_moe, flex_olmo, glm4v_moe
- gpt_oss, granitemoe, granitemoehybrid, granitemoeshared
- jamba, jetmoe, minimax, minimax_m2, olmoe, phimoe
- qwen2_moe, qwen3_moe, qwen3_next, qwen3_omni_moe, qwen3_vl_moe

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
[For maintainers] Suggested jobs to run (before merge)

run-slow: dbrx, ernie4_5_moe, ernie4_5_vl_moe, flex_olmo, glm4v_moe, gpt_oss, granitemoe, granitemoehybrid, granitemoeshared, jamba, jetmoe, minimax, minimax_m2, mixtral, olmoe, phimoe
Summary
Fixes #43688
The auxiliary load balancing loss in MoE models was not correctly normalized when top_k > 1. The tokens_per_expert distribution (f_i) was summing to K instead of 1, while router_prob_per_expert (P_i) sums to 1, making the loss calculation incorrect.

Before (incorrect): sum(f_i) = K, sum(P_i) = 1

After (correct): sum(f_i) = 1, sum(P_i) = 1
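A minimal sketch of the corrected computation for the unmasked case (simplified for illustration; the variable names and structure are approximations, not the verbatim transformers code):

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(gate_logits: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
    # gate_logits: [num_tokens, num_experts] router logits.
    routing_weights = F.softmax(gate_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = F.one_hot(selected_experts, num_experts).float()  # [tokens, top_k, experts]

    # f_i: fraction of tokens routed to each expert, divided by top_k so it sums to 1.
    tokens_per_expert = expert_mask.sum(dim=(0, 1)) / (gate_logits.shape[0] * top_k)

    # P_i: mean router probability per expert; already sums to 1.
    router_prob_per_expert = routing_weights.mean(dim=0)

    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)
```

With this normalization, perfectly balanced routing yields a loss of 1 regardless of top_k, matching the k = 1 case.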
Mathematical Background
From the Switch Transformer paper, the load balancing loss is:

$$
\mathcal{L}_{\text{aux}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i
$$

Where:

- $N$ is the number of experts
- $f_i$ is the fraction of tokens routed to expert $i$
- $P_i$ is the mean router probability assigned to expert $i$
- $\alpha$ is a scaling coefficient
For this dot product to work correctly, both $f_i$ and $P_i$ should represent the same scale (probability distributions that sum to 1).
When using top-k routing:

- each token is dispatched to $k$ experts, so the raw $f_i$ counts $k$ assignments per token and sums to $k$
- $P_i$ is still a softmax over experts and sums to 1
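For example, with $N = 8$ experts, $k = 2$, and perfectly balanced routing ($P_i = 1/N$):

$$
\text{unnormalized: } N \sum_{i=1}^{N} f_i P_i = 8 \cdot 8 \cdot \tfrac{2}{8} \cdot \tfrac{1}{8} = 2 = k,
\qquad
\text{normalized: } 8 \cdot 8 \cdot \tfrac{1}{8} \cdot \tfrac{1}{8} = 1
$$

so the balanced-routing baseline is 1 for any $k$ only after dividing $f_i$ by $k$.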
The fix divides tokens_per_expert by top_k to normalize the distribution. This matches the megablocks implementation and the approach described in DeepSeek-MoE.
Changes
- Added / top_k normalization in load_balancing_loss_func for both the attention mask and non-attention mask branches
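For the attention-mask branch, the same division by top_k applies after restricting the statistics to real (non-padded) tokens; a simplified sketch, assuming a flat [num_tokens] padding mask rather than the expanded masks used internally:

```python
import torch
import torch.nn.functional as F

def masked_load_balancing_loss(gate_logits, num_experts, top_k, attention_mask):
    # gate_logits: [num_tokens, num_experts]; attention_mask: [num_tokens], 1 = real token.
    routing_weights = F.softmax(gate_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = F.one_hot(selected_experts, num_experts).float()  # [tokens, top_k, experts]

    mask = attention_mask.float()
    num_real_tokens = mask.sum()

    # f_i over real tokens only, divided by top_k so it sums to 1 (same fix as the unmasked branch).
    tokens_per_expert = (expert_mask * mask[:, None, None]).sum(dim=(0, 1)) / (num_real_tokens * top_k)

    # P_i over real tokens only; sums to 1.
    router_prob_per_expert = (routing_weights * mask[:, None]).sum(dim=0) / num_real_tokens

    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)
```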
Affected Models

This fix is in modular_mixtral.py, which propagates to all MoE models using the mixtral-style load balancing loss: dbrx, ernie4_5_moe, ernie4_5_vl_moe, flex_olmo, glm4v_moe, gpt_oss, granitemoe, granitemoehybrid, granitemoeshared, jamba, jetmoe, minimax, minimax_m2, mixtral, olmoe, phimoe, qwen2_moe, qwen3_moe, qwen3_next, qwen3_omni_moe, qwen3_vl_moe.