fix: correct MXFP4 MoE weight shuffle for Quark models (MiniMax-M2.1) #632
Status: Draft
Fix the Quark branch of Mxfp4MoEMethod.process_weights_after_loading() to use the correct pre-shuffle format expected by the CK JIT MoE kernels (DeviceMoeGemmMXBPreShuffle).

Previously the Quark path called shuffle_weights() (the a16w4 interleave format) and e8m0_shuffle() for scales. The CK JIT a4w4 kernels (ck_moe_stage1 / ck_moe_stage2_fwd) instead expect shuffle_weight_a16w4(gate_up=False) format for both w1 and w2, and shuffle_scale_a16w4() for scales.

Changes:
- w13 (gate+up): shuffle_weight_a16w4(w13, 16, gate_up=False) + shuffle_scale_a16w4(scale, E, gate_up=False). The a4w4 kernel uses gate_up=False even for w1, which has 2*N rows (unlike the a16w4 path, which uses gate_up=True for w1).
- w2 (down-proj): shuffle_weight_a16w4(w2, 16, gate_up=False) + shuffle_scale_a16w4(scale, E, gate_up=False).
- Set is_shuffled=True on both w13_weight and w2_weight so that ck_moe_stage2_fwd selects the preshuffle_on module variant.
- Also set is_shuffled=True on w2_weight in the existing gpt_oss/Swiglu branch (latent bug: the weight was shuffled but the attribute was missing).

Scope: Only the Quark branch is new code. Currently only MiniMax-M2.1-MXFP4 uses quant_method="quark" with fp4x2 MoE. The gpt_oss is_shuffled fix is a one-line correctness fix for DeepSeek-R1 FP4 and similar models.

Depends on: ROCm/aiter thpereir/cktile_a4w4 branch (CK JIT a4w4 dispatch)

Tested: MiniMax-M2.1-MXFP4 TP=8 server starts and produces correct output.
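The corrected Quark branch can be sketched as follows. This is a minimal, self-contained illustration of the control flow described above, not the real implementation: the aiter helpers shuffle_weight_a16w4 / shuffle_scale_a16w4 are stubbed (they just tag their input), and FakeParam / process_quark_branch are hypothetical names introduced here so the logic can be exercised without ROCm/aiter installed.

```python
def shuffle_weight_a16w4(w, layout=16, gate_up=False):
    # Stub: the real aiter helper pre-shuffles the packed fp4x2 weight
    # into the DeviceMoeGemmMXBPreShuffle layout. Here we only record
    # that it was called and with which gate_up flag.
    return {"data": w, "shuffled": True, "gate_up": gate_up}

def shuffle_scale_a16w4(s, num_experts, gate_up=False):
    # Stub for the matching scale shuffle.
    return {"data": s, "shuffled": True, "gate_up": gate_up}

class FakeParam:
    """Stand-in for a weight parameter with a .data attribute."""
    def __init__(self, data):
        self.data = data

def process_quark_branch(layer, num_experts):
    # Both w13 (gate+up, 2*N rows) and w2 (down-proj) use gate_up=False
    # for the a4w4 kernels, unlike the a16w4 path where w1 uses
    # gate_up=True.
    layer.w13_weight.data = shuffle_weight_a16w4(
        layer.w13_weight.data, 16, gate_up=False)
    layer.w13_weight_scale.data = shuffle_scale_a16w4(
        layer.w13_weight_scale.data, num_experts, gate_up=False)
    layer.w2_weight.data = shuffle_weight_a16w4(
        layer.w2_weight.data, 16, gate_up=False)
    layer.w2_weight_scale.data = shuffle_scale_a16w4(
        layer.w2_weight_scale.data, num_experts, gate_up=False)
    # Mark both weights so ck_moe_stage2_fwd selects the preshuffle_on
    # module variant.
    layer.w13_weight.is_shuffled = True
    layer.w2_weight.is_shuffled = True
```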
Motivation
Technical Details
Test Plan
Test Result
Serving with ATOM:
lm-eval before changes:
lm-eval after changes:
Submission Checklist