
Fix torch.export compatibility for Mixtral MoE models #40114

Open
akacmazz wants to merge 13 commits into huggingface:main from akacmazz:fix-mixtral-torch-export-compatibility

Conversation

@akacmazz

@akacmazz akacmazz commented Aug 12, 2025

  • Replace data-dependent .nonzero() operation with static expert loop
  • Resolves GuardOnDataDependentSymNode error during torch.export
  • Maintains identical functionality while enabling export compatibility
  • Fixes issue introduced in PR #32429 (Skip non-selected experts for mixtral and qwen2_moe)
  • Add tests for torch.export compatibility

What does this PR do?

This PR fixes a torch.export compatibility issue (#38518) with Mixtral MoE models that was introduced in PR #32429.

Problem

The optimization in PR #32429 introduced a .nonzero() operation that creates data-dependent tensor shapes, causing torch.export to fail with:
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression
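A plain-Python sketch of the failure mode (no torch; `hit_experts` is a hypothetical stand-in for the `.nonzero()` call): two masks with the same shape can yield loops with different trip counts, which is exactly the data-dependent quantity the export tracer cannot specialize.

```python
# Hypothetical plain-Python analogue of
# torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero():
# the *length* of the result depends on the values, not the shape.
def hit_experts(expert_counts):
    return [e for e, n in enumerate(expert_counts) if n > 0]

# Same input shape [4], different "output shapes" -> unknowable trip count.
sparse_hits = hit_experts([3, 0, 1, 0])   # [0, 2]
dense_hits = hit_experts([3, 2, 1, 4])    # [0, 1, 2, 3]
```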

Solution

Replace the dynamic expert selection loop:

expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:

With a static loop over all experts:
for expert_idx in range(self.num_experts):
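To make the equivalence concrete, here is a toy plain-Python sketch (hypothetical names, not the transformers code) showing that a static loop over all experts matches the dynamic hit-only loop, because an expert with no routed tokens contributes nothing:

```python
# Toy illustration: looping over every expert is equivalent to looping only
# over experts that were actually routed to, since an expert with an empty
# token list is a no-op.

def moe_static(tokens, assignments, experts, num_experts):
    """Static loop over all experts (export-friendly pattern)."""
    out = [0.0] * len(tokens)
    for expert_idx in range(num_experts):           # fixed trip count
        hit = [i for i, e in enumerate(assignments) if e == expert_idx]
        for i in hit:                               # empty list -> no-op
            out[i] += experts[expert_idx](tokens[i])
    return out

def moe_dynamic(tokens, assignments, experts, num_experts):
    """Dynamic loop over only the hit experts (.nonzero() analogue)."""
    out = [0.0] * len(tokens)
    for expert_idx in sorted(set(assignments)):     # data-dependent!
        for i in (i for i, e in enumerate(assignments) if e == expert_idx):
            out[i] += experts[expert_idx](tokens[i])
    return out

experts = [lambda x, k=k: x * (k + 1) for k in range(4)]
tokens = [1.0, 2.0, 3.0]
assignments = [0, 0, 2]                             # experts 1 and 3 unused
```

Both functions produce the same output on this example, with experts 1 and 3 contributing nothing either way.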

Impact

  • ✅ Enables torch.export compatibility for Mixtral models
  • ✅ Maintains identical functionality (empty experts contribute 0 naturally)
  • ✅ Minimal performance impact (same computation, different loop structure)
  • ✅ Consistent with other MoE implementations (Jamba, DBRX)

Testing

  • Verified torch.export works without errors
  • Confirmed functionality preservation with identical outputs
  • Tested with various input configurations

Fixes torch.export compatibility issues reported for Mixtral-8x7B models.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@Cyrilvallez
@ArthurZucker
@gante

- Replace data-dependent .nonzero() operation with static expert loop
- Resolves GuardOnDataDependentSymNode error during torch.export
- Maintains identical functionality while enabling export compatibility
- Fixes issue introduced in PR huggingface#32429
- Add tests for torch.export compatibility
@akacmazz akacmazz changed the title to Fix torch.export compatibility for Mixtral MoE models on Aug 12, 2025
- Auto-generate modeling_mixtral.py with the same fix
- Apply black formatting
- Fix repository consistency check
@akacmazz akacmazz closed this Aug 12, 2025
@akacmazz akacmazz reopened this Aug 12, 2025
@akacmazz akacmazz force-pushed the fix-mixtral-torch-export-compatibility branch 3 times, most recently from 9b41625 to c3e3c5e on August 12, 2025 at 20:23
Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks for the PR, but the reason we cannot do that is because it is really a lot lot less efficient! I would recommend to rather have 2 paths, 1 for training one for inference, inference can just not loop at all and repeat the inputs

@akacmazz
Author

the reason we cannot do that is because it is really a lot lot less efficient! I would recommend to rather have 2 paths, 1 for training one for inference, inference can just not loop at all and repeat the inputs

Thanks for the feedback, I will work on this.

  - Training path: Keep efficient .nonzero() for performance
  - Inference path: Use static loop for torch.export compatibility
  - Add conditional check to skip empty experts in inference
  - Update tests to validate inference mode export
  - Addresses maintainer feedback on performance concerns
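The dual-path idea in these commits can be sketched in plain Python (hypothetical `SparseMoeSketch`, not the actual transformers code): the `self.training` flag selects the loop strategy, so an eval-mode model — the usual target of torch.export — only ever traces the static branch.

```python
# Hedged sketch of the two-path dispatch: training keeps the data-dependent
# fast path, eval mode uses a fixed-trip-count loop that export can trace.

class SparseMoeSketch:
    def __init__(self, experts, training=False):
        self.experts = experts              # one callable per expert
        self.num_experts = len(experts)
        self.training = training

    def forward(self, tokens, assignments):
        out = [0.0] * len(tokens)
        if self.training:
            # Training path: visit only experts that actually received
            # tokens (analogue of the .nonzero() selection).
            active = sorted(set(assignments))
        else:
            # Inference path: static loop, export-friendly.
            active = range(self.num_experts)
        for expert_idx in active:
            for i in (i for i, e in enumerate(assignments) if e == expert_idx):
                out[i] += self.experts[expert_idx](tokens[i])
        return out

experts = [lambda x, k=k: x * (k + 1) for k in range(4)]
train_out = SparseMoeSketch(experts, training=True).forward([1.0, 2.0], [0, 3])
eval_out = SparseMoeSketch(experts, training=False).forward([1.0, 2.0], [0, 3])
```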
@akacmazz akacmazz force-pushed the fix-mixtral-torch-export-compatibility branch from 952a181 to 0aa9de7 on August 13, 2025 at 08:51
akacmazz and others added 9 commits August 13, 2025 12:23
- Apply black formatting to fix code style
- Fix import sorting with isort
- Address CI code quality checks
- Fix import organization in modeling_mixtral.py
- Fix import organization in modular_mixtral.py
- Address ruff I001 import sorting warnings
- Remove manually edited modeling_mixtral.py
- Auto-generate from modular_mixtral.py using proper tool
- Ensure consistency between modular and generated files
- Fix check_repository_consistency CI failure
- Remove 'if top_x.shape[0] == 0: continue' check that causes GuardOnDataDependentSymNode error
- Empty expert tensors naturally contribute 0, no explicit check needed
- Update test error message for clarity
- Fixes tests_processors CI failure

Co-authored-by: ArthurZucker <arthur.zucker@gmail.com>
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: mixtral

@sshonTT

sshonTT commented Sep 15, 2025

@akacmazz, @ArthurZucker Hi, could I ask what the blocker for this is? I found this commit from #38518, and I just want to know how things are going.

Comment on lines +147 to +160
# Inference path: loop over all experts for torch.export compatibility
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
Collaborator


Hey! we should avoid this, it will really slow down "all" use cases. Is there no way to still export but have a different forward? Like using a kernel from the hub / using the bmm version? Or guarding manually or something?

@t8

t8 commented Apr 30, 2026

Hi @ArthurZucker @Cyrilvallez @akacmazz, dropping CPU benchmark data in case it helps unblock review.

I needed to export Mixtral via torch.export and ran the perf comparison between three forms of MixtralSparseMoeBlock.forward:

  • A — .nonzero(): the form on main (and 4.53.3) that this PR is trying to fix.
  • B — static range(num_experts) loop: the inference path proposed by this PR.
  • C — dense (no loop, repeat inputs through all experts, gate by routing weight): the form @ArthurZucker suggested in the first review (inference can just not loop at all and repeat the inputs).
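For reference, form C can be sketched in plain Python (hypothetical `moe_dense`; the weight matrix carries explicit zeros for unselected experts, so no data-dependent control flow remains — at the cost of running every expert on every token):

```python
# Toy sketch of the dense/no-loop form: gate each expert's output by its
# routing weight, which is 0.0 for experts outside the token's top-k.

def moe_dense(tokens, weights, experts):
    # weights[i][e]: routing weight of token i for expert e
    # (0.0 when expert e was not selected for token i).
    out = [0.0] * len(tokens)
    for i, x in enumerate(tokens):
        for e, expert in enumerate(experts):
            out[i] += weights[i][e] * expert(x)
    return out

experts = [lambda x, k=k: x * (k + 1) for k in range(3)]
weights = [
    [0.75, 0.0, 0.25],   # token 0: top-2 = experts 0 and 2
    [0.0, 0.5, 0.5],     # token 1: top-2 = experts 1 and 2
]
dense_out = moe_dense([1.0, 2.0], weights, experts)
```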

Setup

  • transformers==4.53.3 (the .nonzero() inference path matches main for the relevant code), torch==2.7.0+cpu
  • Single MixtralSparseMoeBlock at random init, Mixtral 8x7B config: hidden_size=4096, intermediate_size=14336, num_local_experts=8, num_experts_per_tok=2
  • Linux x86_64, 24-core Xeon Gold 6252, 24 PyTorch threads, fp32
  • Median of 10 iters per scenario, 3 warmup iters
  • Output equality verified: A vs B is bit-identical (max-abs-diff = 0.00e+00); A vs C is 3.05e-07 (different summation order)

Results — single MixtralSparseMoeBlock.forward, fp32, ms

| scenario | A .nonzero() | B static-loop (this PR) | C dense (no loop) | B/A | C/A |
|---|---|---|---|---|---|
| decode b=1 s=1 | 31.5 ms | 32.2 ms | 123.9 ms | 1.022x | 3.927x |
| decode b=4 s=1 | 93.4 ms | 93.4 ms | 239.5 ms | 1.000x | 2.564x |
| prefill b=1 s=128 | 302.2 ms | 300.0 ms | 487.7 ms | 0.993x | 1.614x |
| prefill b=1 s=512 | 393.1 ms | 394.7 ms | 1593.1 ms | 1.004x | 4.052x |
| prefill b=1 s=2048 | 1388.7 ms | 1371.3 ms | 3491.6 ms | 0.987x | 2.514x |

I re-ran prefill seq=2048 separately with 5 warmups + 25 iters and again with A/B interleaved 10x each (to mitigate cache/thermal effects):

prefill seq=2048 confirmation (5 warmups + 25 iters):
  A nonzero:     median=1335.1 ms  IQR=[1333.7-1338.2]
  B static-loop: median=1334.8 ms  IQR=[1331.9-1336.1]   ratio=1.000x

prefill seq=2048 alternating A/B 10x each:
  A median=1328.9 ms  range=[1321.1-1333.0]
  B median=1331.5 ms  range=[1318.8-1338.2]   ratio=1.002x

Headline

B (this PR's static-loop) is within 2% of A (.nonzero()) across all scenarios — essentially noise on CPU. The export-compatibility win comes effectively for free.

C (dense / no-loop / repeat inputs) is 1.6–4× slower than A in every scenario. The decode penalty matches the math: top-k=2 means dense computes 4× more flops, and on CPU that translates to 2.5–4× wall-time.
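The flop arithmetic behind that claim, as a quick sanity check (config values taken from the Setup section above):

```python
# Dense runs all experts per token; the sparse forms run only the top_k
# selected experts, so the expert-FLOP ratio is num_experts / top_k.
num_local_experts = 8        # Mixtral 8x7B config
num_experts_per_tok = 2      # top-k routing
flop_ratio = num_local_experts / num_experts_per_tok   # 4.0
```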

Suggestion

Given B has no measurable perf cost on CPU and unlocks torch.export, this PR seems like a net positive without needing the dense form. Happy to re-run on GPU or with bf16 if maintainers want to verify the pattern holds there too.

Repro

Bench script (~80 lines) is small and self-contained. I can attach it as a follow-up comment or a gist if useful.

