
fix: correct MXFP4 MoE weight shuffle for Quark models (MiniMax-M2.1) #632

Draft

thpereir wants to merge 1 commit into main from thpereir/minimax_21_qmoe

Conversation

@thpereir
Contributor

Fix the Quark branch of Mxfp4MoEMethod.process_weights_after_loading() to use the correct pre-shuffle format expected by CK JIT MoE kernels (DeviceMoeGemmMXBPreShuffle).

Previously the Quark path called shuffle_weights() (a16w4 interleave format) for weights and e8m0_shuffle() for scales. The CK JIT a4w4 kernels (ck_moe_stage1/ck_moe_stage2_fwd) expect shuffle_weight_a16w4(gate_up=False) format for both w1 and w2, and shuffle_scale_a16w4() for scales.

Changes (see the sketch after this list):

  • w13 (gate+up): shuffle_weight_a16w4(w13, 16, gate_up=False) + shuffle_scale_a16w4(scale, E, gate_up=False). The a4w4 kernel uses gate_up=False even for w1 which has 2*N rows (unlike the a16w4 path which uses gate_up=True for w1).
  • w2 (down-proj): shuffle_weight_a16w4(w2, 16, gate_up=False) + shuffle_scale_a16w4(scale, E, gate_up=False).
  • Set is_shuffled=True on both w13_weight and w2_weight so that ck_moe_stage2_fwd selects the preshuffle_on module variant.
  • Also set is_shuffled=True on w2_weight in the existing gpt_oss/Swiglu branch (latent bug: weight was shuffled but attribute was missing).
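
A minimal sketch of the corrected Quark branch follows. The import path and the layer attribute names (w13_weight, w2_weight, and their *_weight_scale counterparts) are illustrative assumptions, not the exact ATOM source; the helper calls mirror the shapes described above.

```python
# Sketch only: the aiter import path and layer attribute names are
# assumptions; shuffle_weight_a16w4 / shuffle_scale_a16w4 come from the
# ROCm/aiter branch this PR depends on.
from aiter import shuffle_weight_a16w4, shuffle_scale_a16w4  # assumed path

def _process_quark_mxfp4(layer, num_experts: int) -> None:
    # w13 (gate+up): the CK JIT a4w4 kernels want gate_up=False even
    # though w1 carries 2*N rows (the a16w4 path uses gate_up=True here).
    layer.w13_weight.data = shuffle_weight_a16w4(
        layer.w13_weight.data, 16, gate_up=False)
    layer.w13_weight_scale.data = shuffle_scale_a16w4(
        layer.w13_weight_scale.data, num_experts, gate_up=False)

    # w2 (down-proj): same pre-shuffle format for weight and scale.
    layer.w2_weight.data = shuffle_weight_a16w4(
        layer.w2_weight.data, 16, gate_up=False)
    layer.w2_weight_scale.data = shuffle_scale_a16w4(
        layer.w2_weight_scale.data, num_experts, gate_up=False)

    # Flag both tensors so ck_moe_stage2_fwd picks the preshuffle_on variant.
    layer.w13_weight.is_shuffled = True
    layer.w2_weight.is_shuffled = True
```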

Scope: Only the Quark branch is new code. Currently only MiniMax-M2.1-MXFP4 uses quant_method="quark" with fp4x2 MoE. The gpt_oss is_shuffled fix is a one-line correctness fix for DeepSeek-R1 FP4 and similar models.
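
The gpt_oss fix matters because the kernel wrapper selects its module variant from the is_shuffled attribute, so a weight that was physically shuffled but never flagged gets consumed as if it were in the original layout. A sketch of that dispatch pattern (illustrative only, not the exact aiter source):

```python
def _stage2_variant(w2_weight) -> str:
    # Sketch of the selection logic described in this PR: a missing
    # attribute silently falls back to the non-preshuffled path, which
    # is the latent bug in the gpt_oss/Swiglu branch.
    if getattr(w2_weight, "is_shuffled", False):
        return "preshuffle_on"   # weight laid out by shuffle_weight_a16w4
    return "preshuffle_off"      # assumes the unshuffled layout
```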

Depends on: ROCm/aiter thpereir/cktile_a4w4 branch (CK JIT a4w4 dispatch)

Tested: MiniMax-M2.1-MXFP4 TP=8 server starts and produces correct output.


Test Plan

Serving with ATOM:

```
python -m atom.entrypoints.openai_server \
  --model amd/MiniMax-M2.1-MXFP4 \
  --trust-remote-code \
  -tp 8
```

```
lm_eval \
  --model local-completions \
  --model_args model=amd/MiniMax-M2.1-MXFP4,base_url=http://localhost:8000/v1/completions,num_concurrent=32,max_retries=3,tokenized_requests=False \
  --tasks gsm8k \
  --num_fewshot 5 \
  --batch_size 1
```

Test Result

lm-eval before changes:

|Tasks|Version|Filter          |n-shot|Metric     |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|    0|±  |     0|
|     |       |strict-match    |     5|exact_match|    0|±  |     0|

lm-eval after changes:

|Tasks|Version|Filter          |n-shot|Metric     |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|0.9371|±  |0.0067|
|     |       |strict-match    |     5|exact_match|0.9348|±  |0.0068|

