
fix: MoE dispatch for Quark W4A6 models (MXFP4 weights with QuantType.No) #2457

Open
vecheruk-amd wants to merge 1 commit into ROCm:main from vecheruk-amd:fix/quark-w4a6-mxfp4-compat-v2

Conversation

@vecheruk-amd

Motivation

W4A6 models store MoE weights in MXFP4 (fp4x2 dtype) but use MXFP6 for activation quantization. Because the Quark quantization scheme handles activation quantization separately, it passes QuantType.No to the AITER CK-based fused MoE kernel. However, the CK kernel only supports A4W4 (both activations and weights in fp4), and there is no codepath for bf16 activations with fp4x2 weights.

Technical Details

After the existing quant_remap lookup in ck_moe_2stages (and the equivalent in ck_moe_2stages_dp), detect the unsupported combination of QuantType.No with fp4x2 weights and remap to QuantType.per_1x32. This ensures activations are dynamically quantized to fp4x2 at runtime, matching what the CK kernel expects.
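
For reference, a minimal sketch of the remap described above, written as a hypothetical standalone helper (the actual change lives inline in ck_moe_2stages and ck_moe_2stages_dp; the same excerpt appears in the review thread below). QuantType and dtypes are assumed to be importable from aiter as they are used in aiter/fused_moe.py:

```python
# Assumed import path, matching usage elsewhere in aiter/fused_moe.py.
from aiter import QuantType, dtypes


def remap_w4a6_quant_type(quant_type, w1, quant_remap):
    """Hypothetical helper mirroring the fix in ck_moe_2stages / ck_moe_2stages_dp."""
    # Existing lookup: translate aliased quant types first.
    quant_type = quant_remap.get(quant_type, quant_type)
    # W4A6 Quark models carry MXFP4 (fp4x2) weights but report QuantType.No, since
    # MXFP6 activation quantization happens outside the kernel. The CK fused MoE
    # kernel only supports A4W4, so remap to per_1x32 to force dynamic fp4x2
    # activation quantization at runtime.
    if quant_type == QuantType.No and w1.dtype == dtypes.fp4x2:
        quant_type = QuantType.per_1x32
    return quant_type
```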

Test Plan

Verified with ziliangpeng/DeepSeek-V3-Quark-MXFP4-v4-w4a6 on MI355X (gfx950) / ROCm 7.2 / vLLM 0.17.1. MoE layers execute successfully in both eager and compile modes. Results of the experiment can be found here: https://github.com/AMD-AGI/di-recipes/blob/main/tools/prompt_replay/baselines/DeepSeek-V3-0324/MI355/serve_dsr1_0528_mxfp4-v4-w4a6_20260316_smci355-ccs-aus-m15-13.cs-aus.dcgpu.log

Test Result

Submission Checklist

vecheruk-amd requested a review from a team on March 24, 2026, 19:38
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-355 | Run Triton tests on MI355 in addition to MI325 |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2457 --add-label <label>

Comment thread on aiter/fused_moe.py
quant_type = quant_remap.get(quant_type, quant_type)
# W4A6: remap QuantType.No -> per_1x32 for fp4x2 weights so activations
# get quantized to fp4x2 at runtime (CK MoE only supports A4W4).
if quant_type == QuantType.No and w1.dtype == dtypes.fp4x2:
    quant_type = QuantType.per_1x32
Collaborator


Why not just pass QuantType.per_1x32 in, instead of this hack?

sunway513 added a commit that referenced this pull request May 4, 2026
…e.py

- Restore import to match main: use `from aiter import
  fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd` instead of
  importing from internal triton path and fp4_utils
- Replace all fp4_utils.moe_mxfp4_sort() calls with mxfp4_moe_sort_fwd()
  using correct parameter names (cols= instead of block_size=)
- Remove all moe_buf preallocated buffer additions (PR #2687 rejected):
  parameter defaults, if-guards, and pass-throughs in _moe_sorting_impl,
  moe_sorting, fused_moe, fused_moe_fake, and fused_moe_
- Fix moe_sorting_dispatch_policy type annotation: bool -> int in
  fused_moe_fake and fused_moe_
- Remove moe_buf pass-through test from test_moe_sorting.py
- Preserve legitimate fp4_utils usage (mxfp4_to_f32, e8m0_to_f32) with
  local imports in stage1/stage2 fallback functions