
[WIP] Fix naive for loops for MoE models resulting in sub 20% downstream MFU for training with trl, etc. (Qwen3, Deepseek V3, Ernie 4.5, GLM 4.5, Dots1) #40016

Closed
perinmclaughlin wants to merge 5 commits into huggingface:main from perinmclaughlin:V3ScatterMoE

Conversation

@perinmclaughlin

What does this PR do?

Fixes the longstanding issue of MoE training being bottlenecked by naive for loops in models with more than 8 experts.
This can result in sub-20% MFU in downstream training frameworks such as unsloth and trl (measured with Qwen3 30B on an H800).

There have already been several downstream issues from training frameworks, such as unslothai/unsloth#2582, and open-source community members have made custom patches such as https://huggingface.co/Doctor-Shotgun/Qwen3-235B-A22B-Instruct-2507-ScatterMoE. I've also heard several complaints about this issue in the Axolotl and BeaverAI Discords, though those aren't publicly visible.

This PR mainly replaces the moe() method from Deepseek V3 with a mathematically equivalent but faster Scatter MoE implementation, makes the other sparse MoE blocks inherit from DeepseekV3MoE, and modifies the forward and __init__ of those modules accordingly so they use moe().

Also, from modular_deepseek_v3.py:
"""
CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
to not have to do a loop here (deepseek has 256 experts soooo yeah).
"""

Before submitting

  • [N] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [Y] Did you read the contributor guideline,
    Pull Request section?
  • [Y] Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?


@perinmclaughlin perinmclaughlin changed the title from "[WIP] Fix naive for loops for MoE models resulting in sub 30% downstream MFU for training with trl, etc. (Qwen3, Deepseek V3, Ernie 4.5, GLM 4.5, Dots1)" to "[WIP] Fix naive for loops for MoE models resulting in sub 20% downstream MFU for training with trl, etc. (Qwen3, Deepseek V3, Ernie 4.5, GLM 4.5, Dots1)" on Aug 7, 2025
@DocShotgun

Interesting!

One question I have: will this be compatible with PEFT?

I uploaded a few of those Qwen 3 ScatterMoE conversions based on Charles Goddard's original remote-code implementation, but the problem I got stuck on was that the fused MoE layers were not compatible with PEFT, so we could only target the attention tensors and the router during LoRA training.

@perinmclaughlin
Author

This PR in its current state should be fully compatible with PEFT, and more specifically LoRA, even without the target_parameters PR in peft, as it does not currently fuse the experts.
I initially believed there was little point in fusing the experts, since most fused-expert implementations use bmm, which requires either significant wasted computation and memory access for padding, or token dropping.
However, I realized about a day after posting this PR that you can avoid padding by grouping experts by the number of assigned tokens and running bmm on each group, so I may switch to fused experts if that method turns out to be significantly faster. A sketch of that idea follows.
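As a rough illustration of that grouping idea (hypothetical names and shapes; each expert is reduced to a single weight matrix for brevity, whereas real experts are gated MLPs):

```python
import torch
from collections import defaultdict

def grouped_bmm_experts(sorted_tokens, counts, w):
    # sorted_tokens: (total_slots, d_in), already sorted by expert id
    # counts: tokens assigned to each expert; w: (num_experts, d_in, d_out)
    offsets = [0]
    for n in counts:
        offsets.append(offsets[-1] + n)

    # Bucket experts by load: experts with identical token counts can share
    # a single bmm call with zero padding rows.
    by_load = defaultdict(list)
    for eid, n in enumerate(counts):
        if n:
            by_load[n].append(eid)

    out = sorted_tokens.new_empty(sorted_tokens.shape[0], w.shape[-1])
    for n, eids in by_load.items():
        x = torch.stack([sorted_tokens[offsets[e]:offsets[e] + n] for e in eids])
        y = torch.bmm(x, w[eids])              # (len(eids), n, d_out), no padding
        for i, e in enumerate(eids):
            out[offsets[e]:offsets[e] + n] = y[i]
    return out
```

How many distinct load buckets exist (and therefore how many bmm launches) depends on how evenly the router spreads tokens, which is presumably why the author hedges on whether this ends up significantly faster.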

@perinmclaughlin perinmclaughlin deleted the V3ScatterMoE branch August 12, 2025 15:56
@ArthurZucker
Collaborator

Hey!
I was about to review!

@ArthurZucker
Collaborator

ArthurZucker commented Aug 12, 2025

Happy to have a better version than what we currently have, and also to make sure it is TP-compatible. For the best performance we can also use https://huggingface.co/kernels-community/megablocks/tree/main/torch-ext/megablocks

@perinmclaughlin perinmclaughlin restored the V3ScatterMoE branch August 12, 2025 16:09
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: deepseek_v3, ernie4_5_moe, qwen3_moe

@perinmclaughlin
Author

perinmclaughlin commented Aug 12, 2025

My bad; I realized that without a custom kernel the performance would likely still be poor, that there aren't any good torch ops for fused experts, and that I was maybe a bit out of my depth.
Torch bmm has the aforementioned issue of requiring significant padding if expert load is not uniform.
I looked into the torchtune MoE implementation and they're using an undocumented torch function, which also seems to have some quirky compatibility issues, as usual.
Currently I've mostly just pulled some ops out of the per-expert for loop and batched them, but I couldn't find a good way in pure torch to get rid of the kernel launch overhead for each expert.
Megablocks does look promising, though.
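To make the bmm padding cost concrete, here is a hypothetical sketch of the padded variant being argued against: every expert's slice is padded to the largest load, so wasted FLOPs grow with router imbalance. `padded_bmm_experts` is an illustrative name, with the same single-matrix-per-expert simplification as before.

```python
import torch

def padded_bmm_experts(sorted_tokens, counts, w):
    # sorted_tokens: (total_slots, d_in), sorted by expert id
    # counts: per-expert token counts; w: (num_experts, d_in, d_out)
    num_experts, d_in, d_out = w.shape
    max_n = max(counts)
    x = sorted_tokens.new_zeros(num_experts, max_n, d_in)
    start = 0
    for eid, n in enumerate(counts):
        x[eid, :n] = sorted_tokens[start:start + n]   # real rows; the rest stay zero
        start += n
    y = torch.bmm(x, w)     # one launch, but computes num_experts * max_n rows
    return torch.cat([y[eid, :n] for eid, n in enumerate(counts)], dim=0)
```

With 256 experts and a skewed router, max_n can be many times the mean load, so most of that single bmm is multiplying zeros.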

@ArthurZucker
Collaborator

> get rid of the kernel launch overhead for each expert

Does cudagraph not help (compile with reduce-overhead)?
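For reference, a minimal sketch of what that suggestion looks like in practice; the MoE block here is a hypothetical stand-in module, as the point is only the compile mode:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for an MoE block; the point is the compile mode.
moe_block = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64)).cuda()

# "reduce-overhead" records CUDA graphs, so after warmup the many small kernel
# launches replay as a single graph launch instead of Python-driven calls.
compiled = torch.compile(moe_block, mode="reduce-overhead")

x = torch.randn(8, 64, device="cuda")
for _ in range(3):      # first iterations warm up / capture, later ones replay
    out = compiled(x)
```

Note that CUDA graph capture assumes static shapes, while per-expert token counts vary from step to step, so mileage may vary for MoE dispatch.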

@ArthurZucker
Collaborator

Thanks a lot for the detailed explanation

@woct0rdho
Contributor

woct0rdho commented Sep 4, 2025

> One question I have: will this be compatible with PEFT?

Great to see progress on this! @DocShotgun, I've written something that may help: https://github.com/woct0rdho/transformers-qwen3-moe-fused

@ArthurZucker
Collaborator

#41580 fixed this :)
