[WIP] Fix naive for loops for MoE models resulting in sub 20% downstream MFU for training with trl, etc. (Qwen3, Deepseek V3, Ernie 4.5, GLM 4.5, Dots1) #40016
Conversation
Interesting! One question I have is - will this be compatible with PEFT? I uploaded a few of those Qwen 3 ScatterMoE conversions based on Charles Goddard's original remote-code implementation, but the problem I got stuck on was that the fused MoE layers were not compatible with PEFT, and we could only target the attention tensors and router during LoRA training.
This PR in its current state should be fully compatible with PEFT, and more specifically LoRA, even without the target_parameters PR in peft, as it does not currently fuse the experts.
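For example, something along these lines should still work, since the experts remain separate nn.Linear modules rather than fused parameters (a minimal sketch; the checkpoint and module names below are illustrative and not confirmed against the final module layout):

```python
# Hedged sketch: name-based PEFT targeting of attention and expert projections.
# Module names (q_proj, ..., gate_proj, up_proj, down_proj) are assumptions
# based on the usual Qwen-style layout, not taken from this PR.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")  # example checkpoint
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # expert projections should show up as trainable LoRA targets
```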
Hey!
Happy to have a better version than what we currently have, and also making sure it is TP compatible. For the best performance we can also use https://huggingface.co/kernels-community/megablocks/tree/main/torch-ext/megablocks
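As a rough sketch (assuming the kernels library is installed; how the kernel gets wired into the MoE forward is not decided here), that compiled kernel can be fetched from the Hub like this:

```python
# Hedged sketch: pull the megablocks kernel from the Hub with the `kernels` library.
# Only the loading step is shown; the MoE forward integration is out of scope here.
from kernels import get_kernel

megablocks = get_kernel("kernels-community/megablocks")
print(dir(megablocks))  # inspect what the kernel module exposes
```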
[For maintainers] Suggested jobs to run (before merge) run-slow: deepseek_v3, ernie4_5_moe, qwen3_moe |
MB; I kinda realized that without a custom kernel the performance would likely still be poor, that there aren't any good torch ops for fused experts, and that I was maybe a bit out of my depth.
Thanks a lot for the detailed explanation |
Great to see progress on this! @DocShotgun I've written something that may help: https://github.com/woct0rdho/transformers-qwen3-moe-fused
#41580 fixed this :) |
What does this PR do?
Fixes the longstanding issues with MoE training being bottlenecked by naive for loops for models with > 8 experts.
This can result in sub-20% MFU in downstream training frameworks such as unsloth and trl (Qwen3 30B on an H800).
There have been several downstream issues already in training frameworks, such as unslothai/unsloth#2582, and open-source community members have made custom patches such as https://huggingface.co/Doctor-Shotgun/Qwen3-235B-A22B-Instruct-2507-ScatterMoE. Although not publicly visible, I've also heard several complaints about this issue in the Axolotl and BeaverAI Discords.
This PR mainly replaces the moe() method from Deepseek V3 with a mathematically equivalent but faster ScatterMoE-style implementation, makes the other sparse MoE blocks inherit from DeepseekV3MoE, and modifies the forward and init of those modules accordingly to use moe().
Also, from modular_deepseek_v3.py:
"""
CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
to not have to do a loop here (deepseek has 256 experts soooo yeah).
"""