fix: enable MXFP4 MoE at TP=4/8 via CKTile a4w4 kernels and quant fixes #2642
Contributor: @k50112113 can you review the quant file?
Contributor: Unfortunately I'm not familiar enough with CK to review this PR. For the MoE part we need a CK developer to review.
- Add CKTile AOT stage2 kernel path for a4w4 (fp4x2) MoE
- Add get_2stage_cfgs() heuristic routing a4w4 to cktile_moe_stage2
- Add inner CK→CKTile stage2 fallback for tuning-CSV edge cases where CK JIT device_gemm doesn't cover the TP-sharded shape on gfx950 (see the sketch below)
- Remove outer torch fallback (try/except in fused_moe_()) that silently degraded to the slow PyTorch path; unsupported configs now raise directly, consistent with all other quant types (a8w8, fp8, etc.)
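A hypothetical sketch of that inner CK→CKTile stage2 fallback; the launcher callables here are stand-ins, not aiter's real signatures:

```python
# Sketch only: ck_stage2 / cktile_stage2 stand in for the real JIT CK and
# CKTile AOT stage2 launchers.
def stage2_with_fallback(ck_stage2, cktile_stage2, *args, **kwargs):
    try:
        # Preferred path: the JIT CK device_gemm instance for this shape.
        return ck_stage2(*args, **kwargs)
    except RuntimeError:
        # Tuning-CSV edge case: no CK JIT instance covers the TP-sharded
        # shape on gfx950, so fall back to the CKTile AOT stage2 kernel.
        return cktile_stage2(*args, **kwargs)
```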
Enable tensor-parallel MXFP4 MoE for non-power-of-2 scale dimensions (TP=4: scaleN=12, TP=8: scaleN=6) by fixing quant padding, sort kernel reshaping, and adding CKTile a4w4 kernel dispatch.
Fix bugs and add two performance optimizations to the MXFP4 a4w4 (fp4x2 activations + fp4x2 weights) two-stage MoE path, validated on MiniMax-M2.1-MXFP4 at TP=8 on gfx950.
Fixes:
Wrong stage1 kernel for a4w4: cktile_moe_stage1 routes fp4x2 activations through the fp8 pipeline (AQUANT_Pipeline), which misinterprets packed fp4 as fp8 and produces garbage. Switch stage1 to ck_moe_stage1 (JIT CK DeviceMoeGemmMXBPreShuffle), as sketched below.
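A minimal sketch of the corrected stage1 routing; the kernel names come from the PR, while the dispatch function and dtype plumbing are assumptions:

```python
# Sketch only: returns the launcher name rather than calling real kernels.
def pick_stage1_kernel(q_dtype_a, q_dtype_w, fp4x2="fp4x2"):
    if q_dtype_a == fp4x2 and q_dtype_w == fp4x2:
        # a4w4: the JIT CK kernel (DeviceMoeGemmMXBPreShuffle) understands
        # packed-fp4 activations.
        return "ck_moe_stage1"
    # Other quant types keep the CKTile path; its fp8 AQUANT_Pipeline must
    # never see fp4x2 activations, or packed fp4 gets misread as fp8.
    return "cktile_moe_stage1"
```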
ksplit>1 steals a4w4 cases for shared experts: MiniMax shared experts have inter_dim=256 (per TP=8 rank), so get_ksplit() returns 2 and the "ksplit>1 and is_shuffled" elif fires before the a4w4 branch, because that branch only checks q_dtype_w. Fix: add a "q_dtype_a not in [dtypes.fp4x2]" guard.
ksplit must be 0 for a4w4 in fused_moe_2stages: when metadata.ksplit>1, fused_moe_2stages skips fp4x2 activation quantization, so force ksplit=0 in the a4w4 MOEMetadata. Both ksplit fixes are sketched below.
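Both ksplit fixes in one hypothetical dispatch sketch; MOEMetadata here is a simplified stand-in, and only the guard condition and the ksplit=0 rule come from the PR text:

```python
from dataclasses import dataclass

FP4X2 = "fp4x2"  # stands in for dtypes.fp4x2

@dataclass
class MOEMetadata:  # simplified stand-in for aiter's metadata
    kernel: str
    ksplit: int = 0

def choose_cfg(q_dtype_a, q_dtype_w, ksplit, is_shuffled):
    if ksplit > 1 and is_shuffled and q_dtype_a not in [FP4X2]:
        # The added q_dtype_a guard: without it, MiniMax shared experts
        # (inter_dim=256 per TP=8 rank, so get_ksplit() == 2) were stolen
        # by this branch, which previously only checked q_dtype_w.
        return MOEMetadata(kernel="ksplit_path", ksplit=ksplit)
    if q_dtype_a == FP4X2 and q_dtype_w == FP4X2:
        # Force ksplit=0: fused_moe_2stages skips the fp4x2 activation
        # quantization whenever metadata.ksplit > 1.
        return MOEMetadata(kernel="a4w4_path", ksplit=0)
    return MOEMetadata(kernel="default_path", ksplit=ksplit)
```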
Performance optimizations:
Skip intermediate bf16->fp4x2 quantization between stages: stage2 now uses the CKTile AOT kernel with bf16 activations (the a16w4 path) instead of JIT CK with fp4x2 activations, eliminating a full quantization kernel launch and memory round-trip between stages. A new skip_inter_quant flag on MOEMetadata controls this; a sketch follows.
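A sketch of the two-stage flow with the new flag; only skip_inter_quant and MOEMetadata are named in the PR, the rest is assumed plumbing:

```python
def fused_moe_2stages_sketch(x_bf16, meta, stage1, stage2, quant_fp4x2):
    inter = stage1(x_bf16)  # stage1 writes bf16 intermediate activations
    if meta.skip_inter_quant:
        # a16w4 stage2: feed bf16 straight into the CKTile AOT kernel;
        # no intermediate quantization launch, no extra memory round-trip.
        return stage2(inter)
    # Legacy path: quantize bf16 -> fp4x2, then run the JIT CK stage2.
    return stage2(quant_fp4x2(inter))
```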
block_m 32->64 for fp4x2: a sweep across all MiniMax batch sizes shows block_m=64 is consistently faster than 32 on gfx950 (up to 1.54x for prefill); block_m=128 is not supported by the current CKTile stage2 instances.
Additional changes:
- Generate stage1+stage2 instances for a4w4; skip split_k for pk_fp4.
- …(not token_num*topk); use ck_moe_stage2_fwd for correct dispatch.
- Fall back to torch_moe_stage1/stage2 for unsupported configurations.
Benchmark (gfx950, 256 CUs, MiniMax-M2.1-MXFP4 TP=8 dims):

| Tokens | OLD (bm32 + quant + CK) | NEW (bm64 + skip + CKTile) | Speedup  |
|-------:|------------------------:|---------------------------:|----------|
|      1 |                 45.3 us |                     40.1 us | 1.13x    |
|      4 |                 53.5 us |                     45.4 us | 1.18x    |
|      8 |                100.0 us |                     72.2 us | 1.38x    |
|     16 |                157.5 us |                    108.8 us | 1.45x    |
|     32 |                233.5 us |                    160.2 us | 1.46x    |
|     64 |                309.1 us |                    216.6 us | 1.43x    |
|    128 |               GPU fault |                    233.1 us | NEW only |
|    512 |               GPU fault |                    265.4 us | NEW only |
Motivation
Technical Details
Test Plan
One of the models requiring this fix is AMD MiniMax-M2.1-MXFP4 when used with TP=4 or TP=8. Both ATOM and vLLM need these changes to run the MiniMax model above at TP=4/8.
Test Result
Serving with ATOM:
lm-eval before changes:
lm-eval after changes:
Submission Checklist