
fix(moe): normalize auxiliary loss by top_k for correct load balancing #43775

Open

Mr-Neutr0n wants to merge 3 commits into huggingface:main from Mr-Neutr0n:fix/moe-aux-loss-normalization

Conversation

@Mr-Neutr0n
Contributor

Summary

Fixes #43688

The auxiliary load balancing loss in MoE models was not correctly normalized when top_k > 1. The tokens_per_expert distribution (f_i) summed to K instead of 1, while router_prob_per_expert (P_i) sums to 1, so the two factors in the loss were on different scales and the computed loss was incorrect.

Before (incorrect):

sum(f_i) = K
sum(P_i) = 1

After (correct):

sum(f_i) = 1  
sum(P_i) = 1
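
A quick way to see the mismatch (a minimal sketch with toy shapes, not the transformers implementation; names such as `f_unnormalized` are illustrative):

```python
import torch
import torch.nn.functional as F

num_tokens, num_experts, top_k = 8, 4, 2
router_logits = torch.randn(num_tokens, num_experts)

# Top-k expert assignment per token, as one-hot masks of shape (tokens, top_k, experts)
_, selected_experts = torch.topk(router_logits, top_k, dim=-1)
expert_mask = F.one_hot(selected_experts, num_experts).float()

# f_i without normalization: each token contributes top_k assignments
f_unnormalized = expert_mask.mean(dim=0).sum(dim=0)
print(f_unnormalized.sum())            # tensor(2.) -> sums to top_k
print((f_unnormalized / top_k).sum())  # tensor(1.) -> proper distribution after the fix
```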

Mathematical Background

From the Switch Transformer paper, the load balancing loss is:

$$L = N \cdot \sum_{i=1}^{N} f_i \cdot P_i$$

Where:

  • $f_i$ = fraction of tokens routed to expert $i$
  • $P_i$ = average routing probability to expert $i$

For this dot product to work correctly, both $f_i$ and $P_i$ should represent the same scale (probability distributions that sum to 1).

When using top-k routing:

  • Each token picks K experts, so without normalization: $\sum f_i = K$
  • Router probabilities are softmax: $\sum P_i = 1$

The fix divides tokens_per_expert by top_k to normalize the distribution.
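
Spelled out for $T$ tokens and $K$ = top_k (notation is mine, following the definitions above), the normalized quantities are:

$$f_i = \frac{1}{K\,T}\sum_{t=1}^{T}\sum_{k=1}^{K}\mathbb{1}\left[\text{expert}_k(x_t) = i\right] \qquad\qquad P_i = \frac{1}{T}\sum_{t=1}^{T} p_i(x_t)$$

so that $\sum_i f_i = \sum_i P_i = 1$, and under perfectly uniform routing the loss $N \sum_i f_i P_i$ evaluates to $1$ rather than $K$.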

This matches the megablocks implementation and the approach described in DeepSeek-MoE.

Changes

  • Add /top_k normalization in load_balancing_loss_func, in both the attention-mask and non-attention-mask branches (sketched below)
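
A condensed sketch of where the normalization lands (illustrative only; the real function in modular_mixtral.py also handles tuples of per-layer router logits and the attention-mask branch):

```python
import torch
import torch.nn.functional as F

def load_balancing_loss_sketch(router_logits: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
    """Simplified Switch-style auxiliary loss for a single layer, no attention mask."""
    routing_probs = F.softmax(router_logits, dim=-1)                # (tokens, experts)
    _, selected_experts = torch.topk(routing_probs, top_k, dim=-1)
    expert_mask = F.one_hot(selected_experts, num_experts).float()  # (tokens, top_k, experts)

    # f_i: fraction of routed assignments per expert; the "/ top_k" is the fix,
    # making sum(f_i) == 1 instead of top_k
    tokens_per_expert = expert_mask.mean(dim=0).sum(dim=0) / top_k
    # P_i: mean router probability per expert, already sums to 1
    router_prob_per_expert = routing_probs.mean(dim=0)

    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)
```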

Affected Models

This fix is in modular_mixtral.py which propagates to all MoE models using the mixtral-style load balancing loss:

  • OLMoE, GPT-OSS, Qwen2-MoE, Jamba, JetMoE, Phi-MoE, etc.

Mr-Neutr0n and others added 3 commits February 5, 2026 22:17
The auxiliary load balancing loss in MoE models was not correctly
normalized when top_k > 1. The tokens_per_expert distribution (f_i)
was summing to K instead of 1, while router_prob_per_expert (P_i)
sums to 1, making the loss calculation incorrect.

According to DeepSeek-MoE and megablocks implementations, f_i should
be normalized by K so that both distributions represent the same scale:

Before: sum(f_i) = K, sum(P_i) = 1
After:  sum(f_i) = 1, sum(P_i) = 1

This ensures the load balancing loss correctly penalizes unbalanced
routing when using top-k routing with k > 1.

Fixes huggingface#43688

Signed-off-by: Harikrishna KP <harikp2002@gmail.com>
Apply the same top_k normalization fix to the generated modeling file
so it matches the modular source file and passes CI consistency check.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The top_k normalization fix in modular_mixtral.py propagates to all
MoE models that inherit load_balancing_loss_func from mixtral.

Regenerated modeling files for:
- dbrx, ernie4_5_moe, ernie4_5_vl_moe, flex_olmo, glm4v_moe
- gpt_oss, granitemoe, granitemoehybrid, granitemoeshared
- jamba, jetmoe, minimax, minimax_m2, olmoe, phimoe
- qwen2_moe, qwen3_moe, qwen3_next, qwen3_omni_moe, qwen3_vl_moe

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@github-actions
Contributor

github-actions Bot commented Feb 5, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: dbrx, ernie4_5_moe, ernie4_5_vl_moe, flex_olmo, glm4v_moe, gpt_oss, granitemoe, granitemoehybrid, granitemoeshared, jamba, jetmoe, minimax, minimax_m2, mixtral, olmoe, phimoe



Development

Successfully merging this pull request may close these issues.

Incorrect normalization of auxiliary loss in OLMoE and GPT Oss
