
Better Grouped GEMM + EP #45621

Open
IlyasMoutawwakil wants to merge 26 commits into main from deepgemm

Conversation

@IlyasMoutawwakil (Member) commented Apr 24, 2026

What does this PR do?

The idea is that we shouldn't be clamping expert ids at all: clamping makes their tokens get projected as if they were routed to the last expert. Instead, we should let the sentinels be:

  • moved to the end of the queue by the sort (num_experts, i.e. the sentinel, is the largest value)
  • dropped by the histogram (because max=num_experts-1)
  • and then the offsets created from this histogram contain the right mapping: experts[i] takes tokens from offsets[i-1] (or 0 if i == 0) to offsets[i]

I micro-benchmarked the kernel with sentinel tokens and it is faster (it skips their compute as expected):

  • offsets[-1] = 16384 (100%): 22.7 ms/iter, 1.00x
  • offsets[-1] = 8192 (50%): 13.0 ms/iter, 0.57x
  • offsets[-1] = 2048 (12.5%): 4.5 ms/iter, 0.20x

and it stays data-independent and compatible with torch.compile / CUDA graphs.
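
For illustration, a minimal sketch of that sort → histogram → offsets bookkeeping in plain PyTorch (hypothetical shapes and ids, not the actual kernel code):

import torch

num_experts = 4
SENTINEL = num_experts  # the sentinel is num_experts, larger than any valid expert id

# Some tokens routed to local experts, others marked with the sentinel.
expert_ids = torch.tensor([2, SENTINEL, 0, 1, SENTINEL, 2, 0, 3])

# 1. The sort pushes sentinel tokens to the end (the sentinel is the largest value).
sorted_ids, token_order = torch.sort(expert_ids)

# 2. The histogram only counts bins 0..num_experts-1, so sentinels are dropped.
counts = torch.histc(expert_ids.float(), bins=num_experts, min=0, max=num_experts - 1)

# 3. Offsets from the cumulative histogram: experts[i] owns sorted rows
#    offsets[i-1]:offsets[i] (starting at 0 for i == 0).
offsets = counts.cumsum(0).long()

for i in range(num_experts):
    start = 0 if i == 0 else offsets[i - 1].item()
    print(f"expert {i}: sorted rows {start}:{offsets[i].item()}")

Since offsets[-1] is smaller than the number of tokens whenever sentinels are present, the grouped GEMM never touches the sentinel rows, which is where the speedup in the numbers above comes from.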

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@IlyasMoutawwakil (Member Author)

will add bf16 deepgemm to testing as well

@ArthurZucker (Collaborator) left a comment

nice~
Having TP code in moe is fine IMO, let's not separate it into TP since anything could be using sentinels actually!

LGTM otherwise, deepgemm isolation could be another PR 😉

Comment thread src/transformers/integrations/moe.py
Comment thread src/transformers/integrations/tensor_parallel.py Outdated
@IlyasMoutawwakil (Member Author)

> Having TP code in moe is fine IMO, let's not separate it into TP since anything could be using sentinels actually!

yeah, at first I created it because it was going to be used everywhere, but now only a clamp is needed in the batched paths.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@IlyasMoutawwakil (Member Author)

reverted the bf16 deepgemm and isolation

@IlyasMoutawwakil (Member Author)

Great that we have a couple of models with TensorParallelTesterMixin, they pass now!
Note that the output of grouped mm can contain uninitialized values if not all tokens are routed to experts, because it uses torch.empty internally to initialize it.
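
A toy illustration of that pitfall (standalone, not the actual integration code): rows past offsets[-1] are never written by the grouped GEMM, so they hold whatever the allocator left behind unless they are zeroed or masked afterwards.

import torch

total_rows, hidden_dim = 8, 4
offsets_last = 6  # offsets[-1]: sentinel rows live past this index

out = torch.empty(total_rows, hidden_dim)  # uninitialized memory
out[:offsets_last] = 1.0  # stand-in for the grouped GEMM writing the routed rows

# out[offsets_last:] still contains arbitrary garbage; anything that reads the
# full buffer (e.g. the scatter back to token order) must zero or ignore it:
out[offsets_last:].zero_()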

@IlyasMoutawwakil IlyasMoutawwakil marked this pull request as ready for review April 24, 2026 11:26
@IlyasMoutawwakil IlyasMoutawwakil changed the title MoE/DeepGEMM/SonicMoE refactor + better EP Better Grouped GEMM + EP Apr 24, 2026
"finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1},
"deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1},
"sonic-moe": {"repo_id": "kernels-community/sonic-moe", "version": 1},
"sonic-moe": {"repo_id": "IlyasMoutawwakil/sonic-moe", "revision": "main"},
@IlyasMoutawwakil (Member Author)

for testing

@AmineDiro (Member) commented Apr 26, 2026

Quick sanity check, does this redirect mean that the currently-published kernels-community/sonic-moe does not yet have the metadata/sentinel handling?

Asking because I hit cudaErrorIllegalAddress reliably running sonicmoe with EP=8 (Qwen3-30B-A3B, 2 nodes × 8 H100, FSDP2 dp=2 ep=8) with the current hub kernel:

File "...sonic-moe/build/torch-cuda/quack/autotuner.py", line 84, in _gpu_warmup
    a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
torch.AcceleratorError: CUDA error: an illegal memory access was encountered

That's one of those sticky CUDA errors that surface asynchronously, so the actual fault happened well before the error propagated back.

With EP=8, RouterParallel produces sentinel expert_ids ≥ num_local_experts (16, since 128/8). The v1 sonic-moe kernel internally does gate_up_proj[expert_ids[i]], which would be OOB?

When I add expert_ids.clamp(0, num_experts-1) and masked_fill_(invalid_mask, 0.0) in the wrapper before the kernel call, everything works.
Removing the clamp brings the crash back 🥲. So the v1 kernel really does need its inputs in-bounds, and the new build in your fork is what actually fixes it? Just want to confirm: is the plan to republish the fixed build to kernels-community/sonic-moe and revert this redirect once that's done?
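
For reference, a minimal sketch of the workaround described above, as a wrapper around the kernel call (kernel_fn and the tensor shapes are placeholders, not the real sonic-moe API):

import torch

def call_kernel_safely(kernel_fn, hidden_states, expert_ids, num_local_experts):
    # Sentinel ids (>= num_local_experts) would index the expert weights out of
    # bounds inside the v1 kernel, so clamp them into range first.
    invalid_mask = expert_ids >= num_local_experts
    safe_ids = expert_ids.clamp(0, num_local_experts - 1)

    out = kernel_fn(hidden_states, safe_ids)

    # Zero the rows that were only computed because of the clamp.
    out.masked_fill_(invalid_mask.unsqueeze(-1), 0.0)
    return out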

ws_down = self.down_proj_scale_inv
proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False)
- proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16)
+ proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16)
Collaborator

nice, if we write to all (which we do)

@IlyasMoutawwakil (Member Author)

yes!

- kernel = get_kernel(repo_id, revision=revision, version=version)
+ # Entries in `_HUB_KERNEL_MAPPING` are vetted in-tree, so we trust non-`kernels-community`
+ # repos (e.g. user/team forks) without requiring the per-call `allow_all_kernels` flag.
+ kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True)
Collaborator

to remove

@IlyasMoutawwakil (Member Author)

yes, and revert to the kernels-community one?

Collaborator

loads of repetition between moe and fp8, cool if we re-use stuff, fine otherwise haha
