Fix torch.export compatibility for Mixtral MoE models #40114
akacmazz wants to merge 13 commits into huggingface:main
Conversation
- Replace data-dependent .nonzero() operation with static expert loop - Resolves GuardOnDataDependentSymNode error during torch.export - Maintains identical functionality while enabling export compatibility - Fixes issue introduced in PR huggingface#32429 - Add tests for torch.export compatibility
- Auto-generate modeling_mixtral.py with the same fix - Apply black formatting - Fix repository consistency check
Force-pushed from 9b41625 to c3e3c5e
ArthurZucker left a comment:
Thanks for the PR, but the reason we cannot do that is that it is really a lot less efficient! I would recommend instead having 2 paths, one for training and one for inference; inference can just not loop at all and repeat the inputs.
Thanks for the feedback, I will work on this.
…ility - Training path: Keep efficient .nonzero() for performance - Inference path: Use static loop for torch.export compatibility - Add conditional check to skip empty experts in inference - Update tests to validate inference mode export - Addresses maintainer feedback on performance concerns
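For reference, a minimal sketch of what such a dual-path forward could look like (hypothetical, not the exact diff; `expert_mask`, `hidden_states`, `hidden_dim`, `routing_weights`, and `final_hidden_states` come from the surrounding MoE forward as in the existing code):

```python
# Sketch of a dual-path expert loop inside MixtralSparseMoeBlock.forward.
if self.training:
    # Training path: only visit experts that actually received tokens.
    # Data-dependent, but fast — and export is not needed here.
    expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
    expert_indices = expert_hit.flatten().tolist()
else:
    # Inference path: visit every expert so no data-dependent shapes are
    # created, which keeps torch.export happy.
    expert_indices = range(self.num_experts)

for expert_idx in expert_indices:
    expert_layer = self.experts[expert_idx]
    idx, top_x = torch.where(expert_mask[expert_idx])
    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
    current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
```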
Force-pushed from 952a181 to 0aa9de7
- Apply black formatting to fix code style - Fix import sorting with isort - Address CI code quality checks
- Fix import organization in modeling_mixtral.py - Fix import organization in modular_mixtral.py - Address ruff I001 import sorting warnings
- Remove manually edited modeling_mixtral.py - Auto-generate from modular_mixtral.py using proper tool - Ensure consistency between modular and generated files - Fix check_repository_consistency CI failure
- Remove 'if top_x.shape[0] == 0: continue' check that causes GuardOnDataDependentSymNode error - Empty expert tensors naturally contribute 0, no explicit check needed - Update test error message for clarity - Fixes tests_processors CI failure Co-authored-by: ArthurZucker <arthur.zucker@gmail.com>
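For anyone wondering why dropping that check is safe: with zero routed tokens every intermediate tensor is empty and `index_add_` is a no-op, so the accumulator is untouched. A tiny standalone illustration (hypothetical shapes, unrelated to real model sizes):

```python
import torch

hidden_dim = 8
final_hidden_states = torch.zeros(4, hidden_dim)

# Simulate an expert that received no tokens: the index tensor is empty.
top_x = torch.empty(0, dtype=torch.long)
current_hidden_states = torch.empty(0, hidden_dim)

# index_add_ with an empty index tensor does nothing, so no explicit
# `if top_x.shape[0] == 0: continue` (and hence no data-dependent guard) is needed.
final_hidden_states.index_add_(0, top_x, current_hidden_states)
assert torch.count_nonzero(final_hidden_states) == 0
```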
[For maintainers] Suggested jobs to run (before merge): run-slow: mixtral
@akacmazz, @ArthurZucker Hi, could I ask what the blocker is here? I found this commit from #38518, and I just want to know how things are going.
```python
# Inference path: loop over all experts for torch.export compatibility
for expert_idx in range(self.num_experts):
    expert_layer = self.experts[expert_idx]
    idx, top_x = torch.where(expert_mask[expert_idx])

    # Index the correct hidden states and compute the expert hidden state for
    # the current expert. We need to make sure to multiply the output hidden
    # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
    current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

    # However `index_add_` only support torch tensors for indexing so we'll use
    # the `top_x` tensor here.
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
```
Hey! We should avoid this; it will really slow down "all" use cases. Is there no way to still export but have a different forward? Like using a kernel from the Hub / using the bmm version? Or guarding manually or something?
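In case it helps, one possible shape for the "guarding manually" option (just a sketch; whether `torch.compiler.is_compiling()` is the right signal for the export/compile path is an assumption, and maintainers may prefer a different guard):

```python
# Hypothetical manual guard: keep the data-dependent fast path for eager
# execution, fall back to the static loop only while tracing/exporting.
if torch.compiler.is_compiling():
    expert_indices = range(self.num_experts)           # static, export-friendly
else:
    expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
    expert_indices = [int(i) for i in expert_hit]      # data-dependent, skips idle experts
```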
Ported from huggingface#40114 (head 26ea571).
Hi @ArthurZucker @Cyrilvallez @akacmazz, dropping CPU benchmark data in case it helps unblock review. I needed to export Mixtral via `torch.export`.
Setup
Results — single
| scenario | A `.nonzero()` | B static-loop (this PR) | C dense (no loop) | B/A | C/A |
|---|---|---|---|---|---|
| decode b=1 s=1 | 31.5 ms | 32.2 ms | 123.9 ms | 1.022x | 3.927x |
| decode b=4 s=1 | 93.4 ms | 93.4 ms | 239.5 ms | 1.000x | 2.564x |
| prefill b=1 s=128 | 302.2 ms | 300.0 ms | 487.7 ms | 0.993x | 1.614x |
| prefill b=1 s=512 | 393.1 ms | 394.7 ms | 1593.1 ms | 1.004x | 4.052x |
| prefill b=1 s=2048 | 1388.7 ms | 1371.3 ms | 3491.6 ms | 0.987x | 2.514x |
I re-ran prefill seq=2048 separately with 5 warmups + 25 iters and again with A/B interleaved 10x each (to mitigate cache/thermal effects):
```
prefill seq=2048 confirmation (5 warmups + 25 iters):
  A nonzero:     median=1335.1 ms  IQR=[1333.7-1338.2]
  B static-loop: median=1334.8 ms  IQR=[1331.9-1336.1]  ratio=1.000x

prefill seq=2048 alternating A/B 10x each:
  A median=1328.9 ms  range=[1321.1-1333.0]
  B median=1331.5 ms  range=[1318.8-1338.2]  ratio=1.002x
```
Headline
B (this PR's static-loop) is within ~2% of A (`.nonzero()`) across all scenarios — essentially noise on CPU. The export-compatibility win comes effectively for free.
C (dense / no-loop / repeat inputs) is 1.6–4× slower than A in every scenario. The ~4× decode penalty matches the math: with 8 experts and top-k=2, the dense form computes 8/2 = 4× more expert FLOPs, and on CPU that translates to 2.5–4× wall time.
Suggestion
Given B has no measurable perf cost on CPU and unlocks torch.export, this PR seems like a net positive without needing the dense form. Happy to re-run on GPU or with bf16 if maintainers want to verify the pattern holds there too.
Repro
Bench script (~80 lines) is small and self-contained. I can attach it as a follow-up comment or a gist if useful.
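Until then, here is a minimal sketch of how such a CPU timing run could be set up, in case anyone wants to sanity-check (the single-block scope and config values are illustrative assumptions; the numbers above came from full-size Mixtral layers):

```python
import statistics
import time

import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# Scaled-down config so the sketch runs quickly; real Mixtral uses 4096/14336.
config = MixtralConfig(hidden_size=1024, intermediate_size=3584,
                       num_local_experts=8, num_experts_per_tok=2)
block = MixtralSparseMoeBlock(config).eval()
x = torch.randn(1, 128, config.hidden_size)  # batch=1, seq=128 "prefill"-like shape


def bench(fn, warmup=3, iters=10):
    """Return the median wall time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        times.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(times)


with torch.no_grad():
    print(f"MoE block forward median: {bench(lambda: block(x)):.1f} ms")
```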
What does this PR do?
This PR fixes a torch.export incompatibility (#38518) in Mixtral MoE models that was introduced by PR #32429.
Problem
The optimization in PR #32429 introduced a `.nonzero()` operation that creates data-dependent tensor shapes, causing torch.export to fail with:
```
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression
```
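A minimal way to reproduce the failure mode on just the MoE block (a sketch; #38518 exports the full model, and the tiny config values here are illustrative):

```python
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# Small config: the failure comes from the routing logic, not the model size.
config = MixtralConfig(hidden_size=64, intermediate_size=128,
                       num_local_experts=8, num_experts_per_tok=2)
moe = MixtralSparseMoeBlock(config).eval()
hidden_states = torch.randn(1, 16, config.hidden_size)

# With the .nonzero()-based expert selection, tracing trips
# GuardOnDataDependentSymNode; with the static loop in this PR the export
# is intended to go through (per this PR's tests).
exported = torch.export.export(moe, (hidden_states,))
print(exported.graph_module.graph)
```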
Solution
Replace the dynamic expert selection loop:
```python
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
```
With a static loop over all experts:
```python
for expert_idx in range(self.num_experts):
```
Impact
Testing
Fixes torch.export compatibility issues reported for Mixtral-8x7B models.
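A test along these lines can check the export path (a sketch of the idea; the exact test added in this PR may differ in scope and file location):

```python
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock


def test_mixtral_moe_block_torch_export():
    # Tiny config keeps the test fast; export compatibility does not depend on size.
    config = MixtralConfig(hidden_size=32, intermediate_size=64,
                           num_local_experts=8, num_experts_per_tok=2)
    moe = MixtralSparseMoeBlock(config).eval()
    hidden_states = torch.randn(2, 8, config.hidden_size)

    exported = torch.export.export(moe, (hidden_states,))

    # The exported graph should reproduce the eager outputs.
    eager_out, _ = moe(hidden_states)
    export_out, _ = exported.module()(hidden_states)
    torch.testing.assert_close(eager_out, export_out)
```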
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@Cyrilvallez
@ArthurZucker
@gante