
fix: ck_moe_stage1 split-K buffer overflow from padding scatter (alternative to #2508) #2547

Closed

ChuanLi1101 wants to merge 1 commit into main from chuan/fix-ck-moe-stage1-splitk-scatter

Conversation


@ChuanLi1101 ChuanLi1101 commented Mar 31, 2026

Summary

Fix an out-of-bounds buffer overflow in `ck_moe_stage1` when splitK is enabled.

Root cause

The CK MoE kernel uses `sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])` as its M dimension. The kernel launches tile-based blocks covering the entire M range and scatters results to the output buffer. The output buffer must be large enough to accommodate `sorted_size` rows.

The original code allocated only `(token_num, topk, w1.shape[1])` = `token_num * topk` rows (a 3D tensor). For the padding entries in `sorted_token_ids`, the sentinel value `(topk << 24 | token_num)` decodes to scatter position `token_num * topk + topk`, which exceeds the allocated buffer. Additionally, the kernel expects the output buffer to span at least `sorted_size` rows to match its tile-based computation grid.
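The overflow can be reproduced arithmetically. Below is a minimal sketch of the scatter-index decode described above; the 24-bit field split is taken from the `(topk << 24 | token_num)` sentinel encoding, and the helper name is illustrative:

```python
# Sketch of the CK scatter-index decode; the 24-bit split is assumed
# from the sentinel encoding (topk << 24 | token_num).
def scatter_pos(fused_token: int, topk: int) -> int:
    token_id = fused_token & 0xFFFFFF  # low 24 bits: token index
    slot = fused_token >> 24           # high bits: topk slot
    return token_id * topk + slot

token_num, topk = 1, 8               # e.g. DeepSeek-R1 decode step
valid_rows = token_num * topk        # rows the original 3D buffer held

sentinel = (topk << 24) | token_num  # padding entry in sorted_token_ids
pos = scatter_pos(sentinel, topk)    # decodes to token_num * topk + topk
assert pos == token_num * topk + topk
assert pos >= valid_rows             # out of bounds for the old buffer
```

With `token_num=1, topk=8`, the padding entry scatters to row 16 while the old buffer only had 8 rows.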

Fix

  • Compute `sorted_size = min(token_num * topk * block_m, sorted_token_ids.shape[0])` (matching the C++ wrapper logic)
  • Allocate a 2D fp32 buffer of shape `(sorted_size, w1.shape[1])` instead of the undersized 3D `(token_num, topk, w1.shape[1])`
  • After the kernel, slice only the valid rows `tmp_out[:token_num*topk, :]` before passing to `silu_and_mul` / `gelu_and_mul`
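The three steps above can be sketched as follows. This is a simplified illustration using NumPy in place of torch; `sorted_len` and `inter_dim` are made-up example values, not numbers from the PR:

```python
import numpy as np

def stage1_scratch(token_num, topk, block_m, sorted_len, inter_dim):
    # Step 1: M dimension, matching the C++ wrapper's computation.
    sorted_size = min(token_num * topk * block_m, sorted_len)
    # Step 2: 2D fp32 scratch spanning the full tile grid, so padding
    # scatters (up to row token_num*topk + topk) stay in bounds.
    tmp_out = np.zeros((sorted_size, inter_dim), dtype=np.float32)
    return sorted_size, tmp_out

token_num, topk, block_m = 1, 8, 16
sorted_size, tmp_out = stage1_scratch(token_num, topk, block_m,
                                      sorted_len=128, inter_dim=256)
# Step 3: after the kernel, only the valid rows feed the activation.
valid = tmp_out[:token_num * topk, :]
assert sorted_size >= token_num * topk + topk  # padding scatter fits
assert valid.shape == (token_num * topk, 256)
```

The key property is that `sorted_size` always covers the highest scatter position, so the padding writes land inside the allocation.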

Verification

Tested on MI355X (gfx950) with multiple token/topk/expert configurations:
```
SplitK Scatter Fix Verification

[tok=1 topk=8 E=256] OK shape=torch.Size([1, 8, 256]) nan=False inf=False
[tok=2 topk=8 E=256] OK shape=torch.Size([2, 8, 256]) nan=False inf=False
[tok=4 topk=8 E=256] OK shape=torch.Size([4, 8, 256]) nan=False inf=False
[tok=16 topk=8 E=256] OK shape=torch.Size([16, 8, 256]) nan=False inf=False
[tok=1 topk=4 E=64] OK shape=torch.Size([1, 4, 256]) nan=False inf=False
[tok=3 topk=6 E=128] OK shape=torch.Size([3, 6, 256]) nan=False inf=False

Results: 6 passed, 0 failed out of 6
ALL TESTS PASSED!
```

Comparison to PR #2508

PR #2508 uses `sorted_token_ids.shape[0]` rows (safe but over-allocates). This PR uses `sorted_size` (the exact M dimension the C++ wrapper computes), which is the minimal correct size. Both are valid; this PR is tighter on memory.
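The memory difference can be illustrated with one example configuration; the `sorted_token_ids` length of 4096 here is a made-up example value (it grows with the expert count due to per-expert block padding):

```python
# Illustrative comparison of the two allocation strategies.
token_num, topk, block_m = 1, 8, 16
sorted_ids_len = 4096  # assumed sorted_token_ids.shape[0] for many experts

rows_pr2508 = sorted_ids_len                                    # over-allocates
rows_this_pr = min(token_num * topk * block_m, sorted_ids_len)  # exact M

assert rows_this_pr <= rows_pr2508  # both are safe; this PR is tighter
```

In this small-batch decode case, the tighter bound allocates 128 rows instead of 4096; the two strategies only coincide when `token_num * topk * block_m` exceeds the sorted-ids length.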

Test plan

  • Verified on MI355X with 6 different token/topk/expert configs
  • All tests pass with no NaN/Inf in output
  • Non-splitK path is unchanged (`tmp_out = out`)

@ChuanLi1101 ChuanLi1101 requested a review from a team March 31, 2026 06:20
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-355 | Run Triton tests on MI355 in addition to MI325 |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or `gh pr edit 2547 --add-label <label>`.

The CK kernel scatters output via sorted_token_ids using:
  token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

Padding entries use the sentinel value (topk << 24 | token_num),
which decodes to scatter position (token_num * topk + topk) -- beyond
the valid output range [0, token_num * topk). The original buffer
(token_num, topk, w1.shape[1]) only has token_num * topk rows, so
the padding scatter writes out of bounds, causing "HIP runtime error:
invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode
with token_num=1, topk=8, block_m=16).

Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum
needed to absorb all padding scatter writes. After the kernel, slice
only the valid [0, token_num * topk) rows for the activation.

Related: #2508
Made-with: Cursor
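The minimum row count stated in the commit message can be sanity-checked arithmetically (a quick sketch, not code from the PR):

```python
# The padding sentinel scatters to row token_num * topk + topk, so
# token_num * topk + topk + 1 rows is the smallest safe allocation.
token_num, topk = 1, 8
pad_row = token_num * topk + topk        # highest row the kernel writes
min_rows = token_num * topk + topk + 1   # rows needed to absorb it

assert pad_row >= token_num * topk       # beyond the valid output range
assert pad_row < min_rows                # but inside the new allocation
```

For `token_num=1, topk=8` this gives 17 rows, of which only the first 8 (`token_num * topk`) are sliced for the activation.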
@ChuanLi1101 ChuanLi1101 force-pushed the chuan/fix-ck-moe-stage1-splitk-scatter branch from 1ff7e70 to ab58051 Compare March 31, 2026 07:46
frida-andersson added a commit to frida-andersson/aiter that referenced this pull request Apr 1, 2026
Align split-K tmp_out allocation with CK sorted_size and scatter padding
so tile writes stay in bounds; slice valid rows for silu/gelu_and_mul.

Upstream: ROCm#2547
Made-with: Cursor
frida-andersson added a commit to frida-andersson/aiter that referenced this pull request Apr 1, 2026
Allow callers to supply a pre-allocated (M, model_dim) buffer for
moe_sorting instead of torch.empty each forward, for DSv32/vLLM integration.

Keeps ck_moe_stage1 split-K fix from ROCm#2547.

docs: update dsv32-opt-branch provenance (moe_buf + ROCm#2547).
Made-with: Cursor
frida-andersson added a commit to frida-andersson/aiter that referenced this pull request Apr 30, 2026
Align split-K tmp_out allocation with CK sorted_size and scatter padding
so tile writes stay in bounds; slice valid rows for silu/gelu_and_mul.

Upstream: ROCm#2547
Made-with: Cursor
frida-andersson added a commit to frida-andersson/aiter that referenced this pull request Apr 30, 2026
Allow callers to supply a pre-allocated (M, model_dim) buffer for
moe_sorting instead of torch.empty each forward, for DSv32/vLLM integration.

Keeps ck_moe_stage1 split-K fix from ROCm#2547.

docs: update dsv32-opt-branch provenance (moe_buf + ROCm#2547).
Made-with: Cursor
sunway513 added a commit that referenced this pull request May 1, 2026
@ChuanLi1101
Author

Superseded by #2551 (merged to main on 2026-03-31 by @rbrugaro, commit e47cc0e). #2551 implements the same fix with two improvements over this PR: (1) it uses `torch.empty` instead of `torch.zeros` to avoid double-zeroing (the CK kernel zeros the buffer via hipMemsetAsync when KBatch > 1), and (2) it keeps the `.view(dtypes.fp32)` call on the sliced `valid_out`. Closing as duplicate.
A follow-up PR will address the same pattern in cktile_moe_stage1, which currently has a WARNING comment on main flagging the same undersized-buffer bug.

@ChuanLi1101 ChuanLi1101 closed this May 1, 2026
sunway513 added a commit that referenced this pull request May 4, 2026
…e.py

- Restore import to match main: use `from aiter import
  fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd` instead of
  importing from internal triton path and fp4_utils
- Replace all fp4_utils.moe_mxfp4_sort() calls with mxfp4_moe_sort_fwd()
  using correct parameter names (cols= instead of block_size=)
- Remove all moe_buf preallocated buffer additions (PR #2687 rejected):
  parameter defaults, if-guards, and pass-throughs in _moe_sorting_impl,
  moe_sorting, fused_moe, fused_moe_fake, and fused_moe_
- Fix moe_sorting_dispatch_policy type annotation: bool -> int in
  fused_moe_fake and fused_moe_
- Remove moe_buf pass-through test from test_moe_sorting.py
- Preserve legitimate fp4_utils usage (mxfp4_to_f32, e8m0_to_f32) with
  local imports in stage1/stage2 fallback functions
