fix: enable MXFP4 MoE at TP=4/8 via CKTile a4w4 kernels and quant fixes #2642

Open
thpereir wants to merge 1 commit into main from thpereir/cktile_a4w4

Conversation

@thpereir thpereir commented Apr 7, 2026

Enable tensor-parallel MXFP4 MoE for non-power-of-2 scale dimensions (TP=4: scaleN=12, TP=8: scaleN=6) by fixing quant padding and sort kernel reshaping, and by adding CKTile a4w4 kernel dispatch.

Fix bugs and add two performance optimizations to the MXFP4 a4w4 (fp4x2 activations + fp4x2 weights) two-stage MoE path, validated on MiniMax-M2.1-MXFP4 at TP=8 on gfx950.

Fixes:

  1. Wrong stage1 kernel for a4w4:
    cktile_moe_stage1 routes fp4x2 activations through the fp8 pipeline
    (AQUANT_Pipeline), which misinterprets packed fp4 as fp8 -> garbage.
    Switch stage1 to ck_moe_stage1 (JIT CK DeviceMoeGemmMXBPreShuffle).

  2. ksplit>1 steals a4w4 cases for shared experts:
    MiniMax shared experts have inter_dim=256 (per TP=8 rank).
    get_ksplit() returns 2, and the "ksplit>1 and is_shuffled" elif
    fires before the a4w4 branch because it only checks q_dtype_w.
    Fix: add a "q_dtype_a not in [dtypes.fp4x2]" guard (see the sketch after this list).

  3. ksplit must be 0 for a4w4 in fused_moe_2stages:
    When metadata.ksplit>1, fused_moe_2stages skips fp4x2 activation
    quantization. Force ksplit=0 in the a4w4 MOEMetadata.
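
A minimal sketch of the three fixes above, assuming placeholder names for everything except the quoted guard; the real dispatch in aiter's fused_moe code is structured differently.

```python
# Hedged sketch of fixes 1-3. Every name here is a stand-in, not aiter's real
# API; only the added guard `q_dtype_a not in [dtypes.fp4x2]` is quoted from
# this PR.
from types import SimpleNamespace

dtypes = SimpleNamespace(fp4x2="fp4x2", fp8="fp8")

def pick_stage1(q_dtype_a, q_dtype_w, ksplit, is_shuffled):
    # Fix 2: before this PR the split-k branch checked only q_dtype_w, so the
    # MiniMax shared-expert case (inter_dim=256 per TP=8 rank -> get_ksplit()==2)
    # fell in here instead of reaching the a4w4 branch.
    if ksplit > 1 and is_shuffled and q_dtype_a not in [dtypes.fp4x2]:
        return "cktile_moe_stage1 (split-k)"
    # Fix 1: a4w4 goes to JIT CK (DeviceMoeGemmMXBPreShuffle), not the fp8
    # AQUANT pipeline, which would misread packed fp4 as fp8.
    if q_dtype_a == dtypes.fp4x2 and q_dtype_w == dtypes.fp4x2:
        return "ck_moe_stage1"
    return "cktile_moe_stage1"

# Fix 3: when building the a4w4 MOEMetadata, force ksplit = 0 so that
# fused_moe_2stages still performs the fp4x2 activation quantization.
print(pick_stage1("fp4x2", "fp4x2", ksplit=2, is_shuffled=True))  # -> ck_moe_stage1
```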

Performance optimizations:

  1. Skip intermediate bf16->fp4x2 quantization between stages:
    Stage2 now uses CKTile AOT with bf16 activations (a16w4 path)
    instead of JIT CK with fp4x2 activations, eliminating a full
    quantization kernel launch + memory round-trip between stages.
    A new skip_inter_quant flag on MOEMetadata controls this (sketched after this list).

  2. block_m 32->64 for fp4x2:
    A sweep across all MiniMax batch sizes shows block_m=64 is
    consistently faster than 32 on gfx950 (up to 1.54x for prefill).
    block_m=128 is not supported by the current CKTile stage2 instances.
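
A hedged sketch of where the skipped quantization sat. All names below are stub placeholders rather than aiter's real kernels; only skip_inter_quant and MOEMetadata are taken from the PR text.

```python
from dataclasses import dataclass

@dataclass
class MOEMetadata:                 # field name from the PR; the rest is illustrative
    skip_inter_quant: bool = False

def stage1(x):          return x               # stub: stage1 GEMM, bf16 output
def quant_fp4x2(x):     return x, None         # stub: bf16 -> fp4x2 + e8m0 scales
def stage2_a4w4(x, s):  return x               # stub: JIT CK, fp4x2 activations
def stage2_a16w4(x):    return x               # stub: CKTile AOT, bf16 activations

def fused_moe_2stages_sketch(x, meta: MOEMetadata):
    h = stage1(x)
    if meta.skip_inter_quant:
        # New path: stage2 consumes bf16 directly (a16w4), saving one quant
        # kernel launch and one memory round-trip between the two stages.
        return stage2_a16w4(h)
    h_q, h_scale = quant_fp4x2(h)              # old path: extra bf16 -> fp4x2 pass
    return stage2_a4w4(h_q, h_scale)
```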

Additional changes:

  • CKTile AOT: add pk_fp4 x pk_fp4 codegen for both gemm1 and gemm2,
    generate stage1+stage2 instances for a4w4, skip split_k for pk_fp4.
  • cktile_moe_stage1 split_k>1: fix tmp_out buffer size (use sorted_size
    not token_num*topk), use ck_moe_stage2_fwd for correct dispatch.
  • Torch fallback: catch RuntimeError/AssertionError in fused_moe_ and
    fall back to torch_moe_stage1/stage2 for unsupported configurations.
  • shuffle_weight_a16w4: set is_shuffled=True on returned tensor.
  • Add e8m0_unshuffle utility (inverse of e8m0_shuffle); a round-trip
    check is sketched after this list.
  • test_moe_2stage: add a4w4 test branch.
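
A hedged round-trip check for the new e8m0_unshuffle utility, assuming it exactly inverts e8m0_shuffle as stated above. The import path and the scale-tensor shape are assumptions for illustration only.

```python
# Assumed import path and shape -- adjust to wherever e8m0_shuffle /
# e8m0_unshuffle actually live in aiter and to the real scale layout.
import torch
from aiter import e8m0_shuffle, e8m0_unshuffle  # assumption, not a confirmed export

scales = torch.randint(0, 255, (128, 12), dtype=torch.uint8)  # e.g. scaleN=12 at TP=4
roundtrip = e8m0_unshuffle(e8m0_shuffle(scales))
assert torch.equal(roundtrip, scales), "e8m0_unshuffle should invert e8m0_shuffle"
```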

Benchmark (gfx950, 256 CUs, MiniMax-M2.1-MXFP4 TP=8 dims):

| Tokens | OLD (bm32 + quant + CK) | NEW (bm64 + skip + CKTile) | Speedup |
|---|---|---|---|
| 1 | 45.3 us | 40.1 us | 1.13x |
| 4 | 53.5 us | 45.4 us | 1.18x |
| 8 | 100.0 us | 72.2 us | 1.38x |
| 16 | 157.5 us | 108.8 us | 1.45x |
| 32 | 233.5 us | 160.2 us | 1.46x |
| 64 | 309.1 us | 216.6 us | 1.43x |
| 128 | GPU fault | 233.1 us | NEW only |
| 512 | GPU fault | 265.4 us | NEW only |

Motivation

Technical Details

Test Plan

One of the models requiring this fix is AMD MiniMax-M2.1-MXFP4 when used with TP=4 or TP=8. Both ATOM and vLLM need these changes to run that model at TP=4/8.

Test Result

Serving with ATOM:

python -m atom.entrypoints.openai_server \
  --model amd/MiniMax-M2.1-MXFP4 \
  --trust-remote-code \
  -tp 8

lm_eval \
  --model local-completions \
  --model_args model=amd/MiniMax-M2.1-MXFP4,base_url=http://localhost:8000/v1/completions,num_concurrent=32,max_retries=3,tokenized_requests=False \
  --tasks gsm8k \
  --num_fewshot 5 \
  --batch_size 1

lm-eval before changes:

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0 | ± 0 |
| | | strict-match | 5 | exact_match | 0 | ± 0 |

lm-eval after changes:

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9371 | ± 0.0067 |
| | | strict-match | 5 | exact_match | 0.9348 | ± 0.0068 |

Submission Checklist

github-actions Bot commented Apr 7, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|---|---|
| ci:triton-355 | Run Triton tests on MI355 in addition to MI325 |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or `gh pr edit 2642 --add-label <label>`

@thpereir thpereir force-pushed the thpereir/cktile_a4w4 branch 3 times, most recently from 207a3a9 to dbb2268 on April 9, 2026 00:39
@thpereir thpereir marked this pull request as ready for review April 9, 2026 00:39
@thpereir thpereir requested a review from a team April 9, 2026 00:39
@thpereir thpereir added the ci:all label Apr 9, 2026
@azaidy azaidy requested review from k50112113 and valarLip April 9, 2026 16:13

azaidy commented Apr 9, 2026

@k50112113 can you review the quant file

@lburzawa

Unfortunately I'm not familiar enough with CK to review this PR. For the MoE part we need a CK developer to review.

@thpereir thpereir force-pushed the thpereir/cktile_a4w4 branch 2 times, most recently from 668745e to cd0d359 on April 22, 2026 18:06
@valarLip valarLip requested a review from lalala-sh April 23, 2026 03:29
Comment thread on aiter/fused_moe.py (outdated)
Comment thread on csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py (outdated)
Comment thread on csrc/ck_tile_gemm_moe_2stages/gen_instances.py (outdated)
@thpereir thpereir force-pushed the thpereir/cktile_a4w4 branch 2 times, most recently from a99c298 to 438a09f on April 23, 2026 16:43
- Add CKTile AOT stage2 kernel path for a4w4 (fp4x2) MoE
- Add get_2stage_cfgs() heuristic routing a4w4 to cktile_moe_stage2
- Add inner CK→CKTile stage2 fallback for tuning-CSV edge cases where
  CK JIT device_gemm doesn't cover the TP-sharded shape on gfx950
  (a hedged sketch follows below)
- Remove outer torch fallback (try/except in fused_moe_()) that silently
  degraded to slow PyTorch path; unsupported configs now raise directly,
  consistent with all other quant types (a8w8, fp8, etc.)
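
A hedged sketch of the inner fallback described in that commit; the function names are placeholders, and only the CK-then-CKTile ordering comes from the commit description.

```python
def run_stage2(args, ck_jit_stage2, cktile_aot_stage2):
    # Placeholder control flow: prefer the CK JIT device_gemm instance and
    # fall back to the CKTile AOT instance only when the TP-sharded shape
    # isn't covered (e.g. missing tuning-CSV entry on gfx950).
    try:
        return ck_jit_stage2(args)
    except RuntimeError:
        return cktile_aot_stage2(args)
```
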
@thpereir thpereir force-pushed the thpereir/cktile_a4w4 branch from 438a09f to d52cd22 on April 24, 2026 20:54