
feat(deepseek_v3): Add grouped GEMM kernel for faster MoE computation #40583

Closed

bzantium wants to merge 3 commits into huggingface:main from bzantium:feature/#40582

Conversation

@bzantium
Contributor

@bzantium bzantium commented Sep 1, 2025

What does this PR do?

This PR introduces an optimization that significantly accelerates the Mixture-of-Experts (MoE) layer computations in the DeepseekV3 model by integrating the grouped_gemm library. This enhances performance for both training and inference.

The original MoE implementation processed expert networks sequentially using a Python loop, which created a performance bottleneck due to high GPU kernel launch overhead. This PR addresses that issue with the following key changes:

  • 🚀 Grouped GEMM Kernel Integration: A new grouped_forward operational path replaces the iterative Python loop over experts with a single, high-performance kernel call from the grouped_gemm library. This minimizes GPU overhead and maximizes throughput.

  • 🧩 Expert Module Fusing: To efficiently leverage the grouped_gemm kernel, this PR implements a fuse_experts() utility and a GroupedDeepseekV3MLP module. These tools combine the weights of multiple experts into a single, contiguous tensor, which is a prerequisite for the kernel (a conceptual sketch follows this list).

  • ⚙️ Configuration and Usability: A use_grouped_gemm flag has been added to DeepseekV3Config to enable this optimization.
    Important: If you set use_grouped_gemm=True directly when loading a model (e.g., in .from_pretrained()), you must provide a state_dict where the expert weights have already been fused. For standard checkpoints, the recommended workflow is to load the model normally and then call the model.fuse_experts() method.

  • ⚠️ Dependency Handling: The model now gracefully handles the optional dependency. If use_grouped_gemm is enabled but the library isn't found, it raises a clear ImportError with installation instructions.

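To make the fusing and grouped-call idea concrete, here is a minimal, illustrative sketch (all tensor names and shapes are hypothetical, and the grouped_gemm entry point shown in the comment is an assumption about that library's API, not code from this PR):

    import torch
    import torch.nn as nn

    # Hypothetical sizes: 4 experts, hidden size 8, FFN size 16.
    num_experts, hidden_dim, ffn_dim = 4, 8, 16
    experts = [nn.Linear(hidden_dim, ffn_dim, bias=False) for _ in range(num_experts)]

    # "Fusing" stacks the per-expert weights into one contiguous tensor of shape
    # (num_experts, hidden_dim, ffn_dim), which is what a grouped kernel consumes.
    fused_weight = torch.stack([e.weight.t() for e in experts])

    # Tokens routed to each expert, already sorted by expert id.
    tokens_per_expert = torch.tensor([3, 1, 2, 2])
    x = torch.randn(int(tokens_per_expert.sum()), hidden_dim)

    # Naive path: one matmul (and one kernel launch) per expert.
    outs, start = [], 0
    for e, n in enumerate(tokens_per_expert.tolist()):
        outs.append(x[start:start + n] @ fused_weight[e])
        start += n
    naive_out = torch.cat(outs)

    # Grouped path: the grouped_gemm library performs all of these matmuls in a
    # single kernel call. The call below uses an assumed signature; check the
    # library for the real one.
    # import grouped_gemm
    # grouped_out = grouped_gemm.ops.gmm(x, fused_weight, tokens_per_expert)

The fuse_experts() utility in this PR performs the weight stacking once, so the per-token hot path reduces to the single grouped call.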

How to use

  1. Install the required library:

    pip install git+https://github.com/fanshiqing/grouped_gemm@main
  2. Load a standard model and fuse the experts after loading (Recommended method):

    from transformers import AutoModelForCausalLM
    
    # Load a standard checkpoint from the Hub
    model = AutoModelForCausalLM.from_pretrained(
        "moonshotai/Moonlight-16B-A3B-Instruct",
        device_map="cuda:0",
        # Do not set use_grouped_gemm=True here
    )
    
    # Fuse the experts in-place
    model.fuse_experts()
    
    # The model is now optimized and ready for faster inference or training

For checkpoints that have already been saved in the fused format, you can load them directly by setting use_grouped_gemm=True in the .from_pretrained() call.
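For example, a sketch of that round trip (the local path is hypothetical; `model` is the fused model from step 2 above):

    from transformers import AutoModelForCausalLM

    # Save the fused model from step 2 above
    model.save_pretrained("./moonlight-16b-fused")  # hypothetical local path

    # Later, load the fused checkpoint directly with the grouped GEMM path enabled
    fused_model = AutoModelForCausalLM.from_pretrained(
        "./moonlight-16b-fused",
        use_grouped_gemm=True,
        device_map="cuda:0",
    )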

Fixes #40582

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @Rocketknight1

@github-actions
Contributor

github-actions Bot commented Sep 1, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: deepseek_v3, dots1, glm4_moe, glm4v_moe

@woct0rdho
Contributor

Great to see progress on this! Previously, a common concern was how to support PEFT/LoRA, such as in #40016. I've written something that may help: https://github.com/woct0rdho/transformers-qwen3-moe-fused

@bzantium
Contributor Author

bzantium commented Sep 4, 2025

@woct0rdho Thanks for sharing your great work! I will check it out.
@ArthurZucker @Rocketknight1 please review this and let me know what more I should do to integrate this kind of work.

@ArthurZucker
Collaborator

Will have a look, thanks for the PR. Glad to see you here again! 🤗

Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks a lot for the PR!
This IS planned, but not this way!

#40132 is taking care of isolating the expert class; it will then be followed up by an on-the-fly weight conversion to avoid having _fuse_experts!

We also don't want to add extra deps now that we have kernels!
But the "naive" path will be using torch's gemm (and probably with a fallback for older torch versions using a naive for loop / bmm).
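To illustrate the bmm fallback idea mentioned above (an illustrative sketch only, not the planned implementation): each expert's routed tokens can be padded to a common length so a single batched matmul covers every expert.

    import torch

    # Hypothetical shapes: 4 experts, hidden size 8, FFN size 16.
    E, H, F = 4, 8, 16
    fused_weight = torch.randn(E, H, F)               # stacked expert weights
    tokens_per_expert = torch.tensor([3, 1, 2, 2])
    x = torch.randn(int(tokens_per_expert.sum()), H)  # tokens sorted by expert id

    # Pad each expert's slice of tokens to the same length ...
    max_t = int(tokens_per_expert.max())
    padded = x.new_zeros(E, max_t, H)
    start = 0
    for e, n in enumerate(tokens_per_expert.tolist()):
        padded[e, :n] = x[start:start + n]
        start += n

    # ... so one torch.bmm replaces the per-expert loop, at the cost of some
    # wasted compute on the padding rows.
    out_padded = torch.bmm(padded, fused_weight)      # (E, max_t, F)
    out = torch.cat([out_padded[e, :n]
                     for e, n in enumerate(tokens_per_expert.tolist())])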

@bzantium
Contributor Author

bzantium commented Sep 20, 2025

Thank you for the review and the clear direction! I appreciate you pointing me to #40132. The plan to isolate the expert class first and then use on-the-fly weight conversion makes a lot of sense.
to: @ArthurZucker

@ArthurZucker
Collaborator

Awesome, the first PR is close to being merged!

@glide-the

Waiting for support!

@litanli

litanli commented Oct 21, 2025

Waiting for support as well, would really appreciate MoE LoRA capability!

@zenyanbo

Will all MoE models benefit from this? Even the Trainer?

@bzantium
Contributor Author

Thanks for all the attention to this feature! I will start working on it now that #40132 has been successfully merged.

@bzantium
Contributor Author

bzantium commented Jan 8, 2026

Since #42697 has been merged, I will close this PR.

@bzantium bzantium closed this Jan 8, 2026
