[Silo] Bulk merge: kernel fixes and features (SplitK, MoE fixes, Qwen3-Next, pa_mqa OOB) #3005

Open

sunway513 wants to merge 60 commits into main from silo/v0.1.13-kernels

Conversation

@sunway513 (Collaborator)

Summary

Bulk merge of 8 Silo kernel PRs for the v0.1.13 deadline (vLLM 0.21 freeze on 2026-05-08).

Included PRs

Risk

Medium — includes kernel code changes, CK GEMM extensions, MoE dispatch fixes, and new Triton kernels. Auto-resolved merge conflicts in fused_moe.py and test_gemm_a8w8_blockscale.py need manual verification.

Test Plan

  • CI passes (pre-checks, op tests gfx942+gfx950, triton tests)
  • Manual review of auto-resolved conflicts
  • ATOM benchmark regression check

vecheruk-amd and others added 30 commits March 24, 2026 19:33
Propagate the splitK parameter (as KBatch = 2^splitK) through the
block-scale GEMM kernel infrastructure so that the tuning scripts
can sweep split-K values to improve occupancy on small-M shapes.

CK path: add KBatch parameter to gemm_a8w8_blockscale_impl and call
SetKBatch on the device argument. The CK invoker handles output
zeroing and atomic accumulation internally.

CKTile path: add k_batch parameter to gemm_a8w8_blockscale_cktile_impl,
remove the "split-k is not supported yet" runtime guard, and add
hipMemsetAsync to zero the output buffer before atomic accumulation.

Non-tune entry points pass KBatch=1 (no split-K) to preserve existing
behavior. Code generation scripts (gen_instances.py, gen_instances_cktile.py)
updated to include the new parameter in generated wrappers and manifests.

Made-with: Cursor
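For orientation, a minimal Python-level sketch of the mapping; the real wrappers are generated C++, so the function name, signature, and the explicit zero_() standing in for hipMemsetAsync are all illustrative:

    # Minimal sketch, not the generated AITER wrapper.
    import torch

    def gemm_a8w8_blockscale_cktile_sketch(a, b, out, split_k: int = 0):
        k_batch = 1 << split_k       # KBatch = 2^splitK; split_k=0 means no split-K
        if k_batch > 1:
            out.zero_()              # stand-in for hipMemsetAsync: each of the
                                     # k_batch partial K-chunks accumulates
                                     # atomically into `out`
        # kernel launch elided: K is partitioned into k_batch chunks, each
        # computing a partial GEMM that is added into `out`
        return out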
The tuning infrastructure already sweeps splitK and writes it to the CSV,
but the production dispatch ignored it and hardcoded KBatch=1. Add splitK
as a runtime parameter to the non-tune entry points so tuned split-K
values are used without compiling the full _tune instance set.

Made-with: Cursor
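To illustrate how a tuned value could flow back into dispatch, a hypothetical CSV lookup (the column names are assumptions, not the tuner's actual schema):

    import csv

    def load_tuned_splitk(csv_path: str, m: int, n: int, k: int) -> int:
        # Hypothetical: scan the tuning CSV for a matching (M, N, K) shape.
        with open(csv_path) as f:
            for row in csv.DictReader(f):
                if (int(row["M"]), int(row["N"]), int(row["K"])) == (m, n, k):
                    return int(row["splitK"])
        return 0  # no tuned entry: splitK=0, i.e. KBatch=1 (the old behavior)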
The CK kernel scatters output via sorted_token_ids using:
  token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

Padding entries use the sentinel value (topk << 24 | token_num),
which decodes to scatter position (token_num * topk + topk) -- beyond
the valid output range [0, token_num * topk). The original buffer
(token_num, topk, w1.shape[1]) only has token_num * topk rows, so
the padding scatter writes out of bounds, causing "HIP runtime error:
invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode
with token_num=1, topk=8, block_m=16).

Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum
needed to absorb all padding scatter writes. After the kernel, slice
only the valid [0, token_num * topk) rows for the activation.

Related: #2508
Made-with: Cursor
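The arithmetic, worked through for the DeepSeek-R1 decode case mentioned above:

    token_num, topk = 1, 8

    fused = (topk << 24) | token_num                 # padding sentinel
    offset = (fused & 0xFFFFFF) * topk + (fused >> 24)
    assert offset == token_num * topk + topk         # 16, but valid rows are [0, 8)

    rows = token_num * topk + topk + 1               # 17 rows absorb row index 16
    # after the kernel, only out[: token_num * topk] feeds the activation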
…ue, add correctness test

Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e3b37b0f-e151-4935-ad89-fd72436d41e2

Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>
Remove unused variable in rmsnorm FP8 test ref. Apply Black to
kernels, launchers, tests, and gated_delta_rule decode __init__.

Made-with: Cursor
Add an optional `moe_buf` parameter through the moe_sorting and
fused_moe call chain. When provided, the sorting kernel writes
directly into the caller's buffer instead of allocating a new one,
eliminating a redundant copy on the output path.

Made-with: Cursor
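A hedged usage sketch: the argument order mirrors the moe_sorting test in this PR, but the import path and exact signature should be treated as assumptions.

    import torch
    from aiter import moe_sorting  # assumed import path

    token_num, model_dim, num_experts, topk = 8, 4096, 64, 2
    dtype = torch.bfloat16
    topk_ids = torch.randint(0, num_experts, (token_num, topk), device="cuda")
    topk_weights = torch.rand((token_num, topk), device="cuda")

    pre_buf = torch.empty((token_num, model_dim), dtype=dtype, device="cuda")
    *_, moe_buf = moe_sorting(topk_ids, topk_weights, num_experts, model_dim,
                              dtype, 32, moe_buf=pre_buf)  # 32 = block_size_M
    assert moe_buf.data_ptr() == pre_buf.data_ptr()        # reused, no copy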
@sunway513 sunway513 requested review from a team and Copilot May 1, 2026 23:30

github-actions Bot commented May 1, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

  • ci:triton-300x: runs an additional Triton test job on MI300X for PRs (the main branch always runs both MI35X and MI300X)
  • ci:sglang: SGLang integration tests
  • ci:atom: ATOM benchmark (DeepSeek-R1 + GPT-OSS)
  • ci:vllm: vLLM benchmark
  • ci:all: all of the above

Add labels via the sidebar or gh pr edit 3005 --add-label <label>


Copilot AI left a comment


Pull request overview

Bulk merge of multiple kernel-focused PRs for the v0.1.13 deadline, spanning new Triton decode/quant kernels, CK/CKTile Split-K enablement, MoE dispatch/sorting fixes, and a pa_mqa_logits OOB store fix.

Changes:

  • Add new Triton kernels + launchers and corresponding tests (gated delta rule decode, causal conv1d single-token update, fused RMS+gate FP8 group quant).
  • Enable Split-K for CK/CKTile a8w8 blockscale GEMMs end-to-end (Python → pybind → C++ wrappers → manifests/instances), including output zeroing for atomic accumulation.
  • Fix/extend MoE behavior (preallocated sorting buffer plumbing, Quark W4A6 quant remap, split-K stage1 buffer sizing) and harden tuner/JIT loading; mask pa_mqa_logits OOB stores.

Reviewed changes

Copilot reviewed 39 out of 40 changed files in this pull request and generated 7 comments.

Summary per file:

  • op_tests/triton_tests/test_fused_rearrange_sigmoid_gdr.py: adds a correctness sweep test for the fused rearrange+sigmoid gated delta rule decode kernel.
  • op_tests/triton_tests/test_causal_conv1d_update_single_token.py: adds reference-based tests for the single-token causal conv1d update and the fused reshape launcher.
  • op_tests/triton_tests/quant/test_fused_rms_gated_fp8_group_quant.py: adds reference, sweep, and error-path tests for fused RMS+gate FP8 group quant.
  • op_tests/test_moe_sorting.py: extends the MoE sorting test to validate preallocated moe_buf pass-through.
  • op_tests/test_gemm_a8w8_blockscale.py: updates the benchmark harness and adds split-K correctness checks for blockscale GEMM.
  • csrc/include/rocm_ops.hpp: extends pybind signatures to accept a splitK parameter for the blockscale GEMM APIs.
  • csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh: adds KBatch (split-K) plumbing to the CK blockscale GEMM implementation.
  • csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh: enables split-K for the CKTile path; zeros the output buffer when k_batch > 1 to support atomic accumulation.
  • csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile.h: updates the CKTile header API to include the splitK parameter.
  • csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale.h: updates the CK header API to include the splitK parameter.
  • csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py: updates the CKTile instance generator to thread k_batch through to the impl.
  • csrc/ck_gemm_a8w8_blockscale/gen_instances.py: updates the CK instance generator to thread KBatch through to the impl.
  • csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu: propagates KBatch into tuned-kernel dispatch calls.
  • csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_tune.cu: propagates KBatch into tuned CKTile-kernel dispatch calls.
  • csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu: adds splitK argument validation and maps splitK to KBatch for CKTile dispatch.
  • csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu: adds splitK argument validation and maps splitK to KBatch for CK dispatch.
  • aiter/utility/mp_tuner.py: ensures failed tasks always record dummy results to avoid a downstream IndexError during result reconstruction.
  • aiter/ops/triton/quant/fused_fp8_quant.py: adds the fused RMS+gate FP8 group quant API plus helper utilities for fp8 bounds and launch heuristics.
  • aiter/ops/triton/quant/__init__.py: exports the new FP8 quant helpers and fused kernel API.
  • aiter/ops/triton/gluon/pa_mqa_logits.py: adds missing upper-bound masks to OutLogits_buffer stores to prevent OOB writes.
  • aiter/ops/triton/gated_delta_net/fused_rearrange_sigmoid_gdr.py: adds a new high-level launcher for fused rearrange+sigmoid gated delta rule decode.
  • aiter/ops/triton/gated_delta_net/__init__.py: wires the new gated-delta-rule decode launcher into the module exports.
  • aiter/ops/triton/causal_conv1d_update_single_token.py: adds Python launchers for the single-token causal conv1d update kernels.
  • aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py: adds the underlying Triton kernel for fused RMS+gate FP8 group quant.
  • aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/fused_rearrange_sigmoid_gdr.py: adds the Triton decode kernel for fused rearrange+sigmoid gated delta rule.
  • aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/__init__.py: exports the new decode kernel symbol.
  • aiter/ops/triton/_triton_kernels/causal_conv1d_update_single_token.py: adds Triton kernels for the single-token causal conv1d update and fused reshape path.
  • aiter/ops/gemm_op_a8w8.py: threads splitK from the config into gemm_a8w8_blockscale_ck/cktile calls.
  • aiter/jit/core.py: stops forcing torch_exclude=True for ctypes-loaded modules; relies on per-module config to ensure correct linkage.
  • aiter/fused_moe_dp_shared_expert.py: remaps QuantType.No to per_1x32 for fp4x2 weights (Quark W4A6) to match CK kernel requirements.
  • aiter/fused_moe.py: adds an optional preallocated moe_buf, refactors MXFP4 sorting calls, adjusts heuristics, and fixes CK split-K stage1 tmp_out sizing.
  • aiter/configs/model_configs/a8w8_blockscale_untuned_gemm_qwen3_next_80b_a3b.csv: adds an untuned shape-list CSV for Qwen3-Next-80B-A3B blockscale GEMM coverage.


Comment on lines +309 to +312:

    # Correctness check: verify split-K produces matching results
    print("\nRunning split-K correctness checks ...")
    for splitK in [1, 2]:
        test_splitk_correctness(m=4, n=512, k=16384, splitK=splitK)
Comment thread on aiter/utility/mp_tuner.py, lines +508 to +512:

    dummy_results = []
    add_dummy_result(k, dummy_results)
    result_dict[k] = (
        dummy_results if shape_grouped else [dummy_results[0]]
    )
Comment thread on aiter/fused_moe.py:

    block_size_M: int = -1,
    num_local_tokens: Optional[torch.Tensor] = None,
    moe_sorting_dispatch_policy: int = 0,
    moe_sorting_dispatch_policy: bool = 0,
Comment thread on aiter/fused_moe.py, lines 232 to 236:

    block_size_M: int = -1,
    num_local_tokens: Optional[torch.Tensor] = None,
    moe_sorting_dispatch_policy: int = 0,
    moe_sorting_dispatch_policy: bool = 0,
    dtype: Optional[torch.dtype] = None,
    hidden_pad: int = 0,
Comment on lines +177 to +201:

    # Verify moe_buf pass-through: pre-allocated buffer should be reused
    pre_buf = torch.empty((token, model_dim), dtype=dtype, device="cuda")
    pre_buf_ptr = pre_buf.data_ptr()
    (
        (
            sorted_ids_c,
            sorted_weights_c,
            sorted_expert_ids_c,
            num_tokens_post_padded_c,
            moe_buf_c,
        ),
        _,
    ) = run_perftest(
        moe_sorting,
        topk_ids,
        topk_weights,
        E,
        model_dim,
        dtype,
        BLOCK_SIZE_M,
        expert_mask,
        num_local_tokens,
        dispatch_policy,
        moe_buf=pre_buf,
        num_warmup=1,
Comment on lines +81 to +99:

    # Shapes aligned with ``test_gated_delta_rule.test_fused_recurrent``; dtypes are
    # half-precision only — long packed ``T`` with float32 activations tends to blow
    # up the recurrent reference / kernel without tighter dynamic-range clamps.
    # Each row ends with ``use_qk_l2norm_in_kernel`` (True for stable long-T sweep).
    # One small bf16 row uses False to cover the no–L2-norm path (replaces former ``basic``).
    _FUSED_GDR_SWEEP = [
        (63, 1, 1, 64, 1, 1, torch.float16, True),
        (500, 4, 4, 60, 1, 1, torch.float16, True),
        (1000, 2, 8, 128, 1, 0.1, torch.float16, True),
        (1024, 2, 2, 128, 0.1, 1, torch.float16, True),
        (1024, 3, 3, 128, 1, 10, torch.float16, True),
        (2048, 4, 4, 64, 0.1, 1, torch.float16, True),
        (1024, 4, 4, 128, 1, 0.1, torch.float16, True),
        (1024, 4, 8, 128, 1, 10, torch.float16, True),
        (1024, 4, 4, 128, 1, 0.1, torch.bfloat16, True),
        (1024, 4, 8, 128, 1, 1, torch.bfloat16, True),
        (2048, 4, 8, 64, 0.1, 1, torch.bfloat16, True),
        (8, 4, 4, 16, 16**-0.5, 1, torch.bfloat16, False),
    ]
Comment on lines +415 to +416:

    assert x.is_contiguous() and z.is_contiguous()
    assert x.shape == z.shape, "x and z must have the same shape"

valarLip commented May 2, 2026

CI test failed:
https://github.com/ROCm/aiter/actions/runs/25237936284/job/74019356117?pr=3005#step:13:40990

File "/opt/venv/lib/python3.12/site-packages/aiter/fused_moe.py", line 1301, in fused_moe_2stages
a2 = a2.view(-1, inter_dim)
^^^^^^^
AttributeError: 'tuple' object has no attribute 'view'


❌ Test failed: op_tests/test_moe_2stage.py

@sunway513 sunway513 force-pushed the silo/v0.1.13-kernels branch 2 times, most recently from ae4ea19 to 9e6d4b4 Compare May 2, 2026 17:40
sunway513 (Collaborator, Author) commented:

@LingpengJin — This PR is ready for your review. 29/30 CI checks pass. The 1 failure (Triton Shard 7, test_moe_routing with 128 experts top-4) is a pre-existing issue on main, not introduced by this PR.

Included PRs (6):

Excluded PRs (2) — need rework before they can be merged:

@vecheruk-amd @ChuanLi1101 — Your PRs were excluded from this bulk merge because the fused_moe.py changes are too invasive to auto-merge safely. The specific issue: removing fuse_quant/xbf16 codepaths and changing the MoE quant dispatch causes test_moe_2stage to fail with AttributeError: 'tuple' object has no attribute 'view'. Please consider submitting the bug fixes separately from the refactor.

sunway513 added 4 commits May 3, 2026 14:29
The `zmq` meta-package fails to install on some CI runners because
it cannot resolve the `pyzmq` dependency. Use `pyzmq` directly,
which is the actual package providing ZeroMQ bindings for Python.

Fixes Triton Test Shard 7 setup failures.
Set pip global retries=15 and timeout=120s in build_aiter_triton.sh
to handle transient PyPI network failures on self-hosted runners.
Shard 5/7 failures were caused by RemoteDisconnected during pip install.
pyzmq is only used by aiter.dist.shm_broadcast, not by any triton
test. When PyPI is unreachable on self-hosted runners, the pyzmq
install failure should not block the entire CI shard.

Split pyzmq into a separate pip install with || fallback so triton
tests can proceed even when PyPI connectivity is degraded.
When batch pip install fails (e.g., PyPI connectivity issues on
self-hosted runners), retry each package individually. Only pyzmq
is allowed to fail silently since it's only used by
aiter.dist.shm_broadcast and not required by any CI test suite.

Critical packages (pandas, einops, numpy) must still succeed.
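The fallback logic, rendered in Python for clarity; the actual change is bash in build_aiter_triton.sh:

    import subprocess, sys

    PACKAGES = ["pandas", "einops", "numpy", "pyzmq"]
    OPTIONAL = {"pyzmq"}  # only needed by aiter.dist.shm_broadcast, not CI tests

    def pip_install(pkgs):
        cmd = [sys.executable, "-m", "pip", "install", *pkgs]
        return subprocess.run(cmd).returncode == 0

    if not pip_install(PACKAGES):        # batch install failed (PyPI flake, etc.)
        for pkg in PACKAGES:             # retry each package individually
            if not pip_install([pkg]) and pkg not in OPTIONAL:
                raise SystemExit(f"critical package {pkg} failed to install")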
@sunway513 sunway513 force-pushed the silo/v0.1.13-kernels branch from 92c88fa to e1aa5eb Compare May 3, 2026 14:29
frida-andersson and others added 4 commits May 3, 2026 15:02
Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64,
qseqlen=1 (non-persistent)") zeroed ptr_RP and out_16_nosplit for all
non-persistent dispatch. The legacy QH16 ASM kernel used for nhead=32
(MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) still writes
directly to the output buffer via ptr_RP when kv_split==1.
Dereferencing nullptr causes a GPU memory access fault during CUDA
graph capture on MI355X (gfx950) with DeepSeek-V3.2 at TP4.

Fix:
- Conditionally restore ptr_RP and out_16_nosplit in the non-persistent
  path for legacy kernels (gqa_ratio * max_seqlen_q <= 64) while
  keeping nullptr for newer kernels (e.g. gqa_ratio=64).
- Restore the bf16 nhead in [32,64] early-return after stage1 when
  num_kv_splits==1 to prevent stage2 from overwriting the kernel's
  direct output.

Tested on MI355X TP4 with deepseek-ai/DeepSeek-V3.2 (nhead=32):
- No crash during CUDA graph capture
- Correct GSM8K accuracy

Made-with: Cursor
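In rough Python terms (the actual dispatch is C++/HIP), the restored rule is:

    def legacy_kernel_writes_direct(gqa_ratio: int, max_seqlen_q: int) -> bool:
        # Legacy QH16 ASM kernels write output through ptr_RP when kv_split == 1,
        # so ptr_RP and out_16_nosplit must stay valid for them; newer kernels
        # (e.g. gqa_ratio=64) keep nullptr.
        return gqa_ratio * max_seqlen_q <= 64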
Reverting cherry-pick of #2983 from this bulk merge. The MLA nhead=32
non-persistent decode fix causes deterministic test_mla k_cache and
mla_decode-absorb precision failures on CI MI35X runners (Shard 1 & 2).

#2983 should go through its own PR with proper CI validation by the
original author (frida-andersson).
flydsl_moe_stage1 returns (out, out_scale_sorted) when the kernel uses
fused fp4/fp8 quantization. The tuple unpack logic was removed during
earlier refactoring but the kernel behavior was not changed, causing
fused_moe_2stages to crash with:
  AttributeError: 'tuple' object has no attribute 'view'

Restore the unpack: detect tuple return, extract tensor and scale,
handle fp4 byte-packing trim, and skip redundant Python-side requant
when the kernel already produced sorted scales.
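A sketch of the restored unpack, following the commit message rather than the exact fused_moe.py code:

    def _unpack_stage1(ret, inter_dim: int):
        if isinstance(ret, tuple):      # fused fp4/fp8 quant path
            a2, a2_scale = ret          # scales already sorted by the kernel
            # (fp4 packs two values per byte; trim the packed dim here if needed)
        else:
            a2, a2_scale = ret, None    # requantize on the Python side as before
        return a2.view(-1, inter_dim), a2_scale  # the .view that used to crash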