[Silo] Bulk merge: kernel fixes and features (SplitK, MoE fixes, Qwen3-Next, pa_mqa OOB) #3005
Conversation
Propagate the splitK parameter (as KBatch = 2^splitK) through the block-scale GEMM kernel infrastructure so that the tuning scripts can sweep split-K values to improve occupancy on small-M shapes.

- CK path: add a KBatch parameter to gemm_a8w8_blockscale_impl and call SetKBatch on the device argument. The CK invoker handles output zeroing and atomic accumulation internally.
- CKTile path: add a k_batch parameter to gemm_a8w8_blockscale_cktile_impl, remove the "split-k is not supported yet" runtime guard, and add hipMemsetAsync to zero the output buffer before atomic accumulation.
- Non-tune entry points pass KBatch=1 (no split-K) to preserve existing behavior.
- Code generation scripts (gen_instances.py, gen_instances_cktile.py) are updated to include the new parameter in generated wrappers and manifests.

Made-with: Cursor
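For orientation, a minimal Python sketch of the splitK-to-KBatch mapping described above; the function and argument names are placeholders, not the actual aiter entry points:

```python
import torch

# Minimal sketch (names assumed): the tuned splitK value is an exponent that
# expands to KBatch = 2**splitK before the CK/CKTile instance is dispatched.
def resolve_k_batch(out: torch.Tensor, splitK: int = 0) -> int:
    assert splitK >= 0, "splitK is an exponent; KBatch = 2**splitK"
    k_batch = 1 << splitK
    if k_batch > 1:
        # The CKTile path zeroes the output first because split-K partial results
        # are accumulated atomically; the CK invoker handles this internally.
        out.zero_()
    return k_batch  # forwarded as KBatch / k_batch to the generated instance

out = torch.zeros(4, 512, dtype=torch.bfloat16)
assert resolve_k_batch(out, splitK=2) == 4
```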
The tuning infrastructure already sweeps splitK and writes it to the CSV, but the production dispatch ignored it and hardcoded KBatch=1. Add splitK as a runtime parameter to the non-tune entry points so tuned split-K values are used without compiling the full _tune instance set. Made-with: Cursor
The CK kernel scatters output via sorted_token_ids using:

token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

Padding entries use the sentinel value (topk << 24 | token_num), which decodes to scatter position (token_num * topk + topk), beyond the valid output range [0, token_num * topk). The original buffer of shape (token_num, topk, w1.shape[1]) only has token_num * topk rows, so the padding scatter writes out of bounds, causing "HIP runtime error: invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode with token_num=1, topk=8, block_m=16).

Fix: allocate (token_num * topk + topk + 1) rows, the exact minimum needed to absorb all padding scatter writes. After the kernel, slice only the valid [0, token_num * topk) rows for the activation.

Related: #2508
Made-with: Cursor
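A small host-side sketch of the decode arithmetic and the padded allocation; the names (scatter_row, hidden) and shapes are illustrative, and the actual scatter happens inside the CK kernel:

```python
import torch

# Reproduce the failing shape from the commit message: token_num=1, topk=8.
token_num, topk, hidden = 1, 8, 7168

def scatter_row(fused_token: int) -> int:
    # Same decode as the CK kernel: low 24 bits = token index, high 8 bits = top-k slot.
    return (fused_token & 0xFFFFFF) * topk + (fused_token >> 24)

sentinel = (topk << 24) | token_num                        # padding entry from moe_sorting
assert scatter_row(sentinel) == token_num * topk + topk    # lands past token_num * topk

# Fix: allocate enough rows to absorb the padding writes, then slice the valid range.
out = torch.empty((token_num * topk + topk + 1, hidden), dtype=torch.bfloat16)
valid = out[: token_num * topk].view(token_num, topk, hidden)
```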
…ue, add correctness test
Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e3b37b0f-e151-4935-ad89-fd72436d41e2
Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>
Remove unused variable in rmsnorm FP8 test ref. Apply Black to kernels, launchers, tests, and gated_delta_rule decode __init__. Made-with: Cursor
Add an optional `moe_buf` parameter through the moe_sorting and fused_moe call chain. When provided, the sorting kernel writes directly into the caller's buffer instead of allocating a new one, eliminating a redundant copy on the output path. Made-with: Cursor
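A hedged usage sketch of the new parameter; the import path and the positional-argument order are taken from the moe_sorting test quoted further down this page and should be treated as assumptions:

```python
import torch
from aiter.fused_moe import moe_sorting  # assumed import path for this sketch

token_num, topk, num_experts, model_dim, block_size_m = 4, 2, 8, 128, 32
topk_ids = torch.randint(0, num_experts, (token_num, topk), dtype=torch.int32, device="cuda")
topk_weights = torch.rand((token_num, topk), dtype=torch.float32, device="cuda")

# Caller-owned output buffer: with moe_buf= the sorting kernel writes into it directly
# instead of allocating a new tensor, removing one copy on the output path.
moe_buf = torch.zeros((token_num, model_dim), dtype=torch.bfloat16, device="cuda")
sorted_ids, sorted_weights, sorted_expert_ids, num_post_pad, out_buf = moe_sorting(
    topk_ids, topk_weights, num_experts, model_dim, torch.bfloat16, block_size_m,
    moe_buf=moe_buf,
)
assert out_buf.data_ptr() == moe_buf.data_ptr()  # preallocated buffer was reused
```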
Pull request overview
Bulk merge of multiple kernel-focused PRs for the v0.1.13 deadline, spanning new Triton decode/quant kernels, CK/CKTile Split-K enablement, MoE dispatch/sorting fixes, and a pa_mqa_logits OOB store fix.
Changes:
- Add new Triton kernels + launchers and corresponding tests (gated delta rule decode, causal conv1d single-token update, fused RMS+gate FP8 group quant).
- Enable Split-K for CK/CKTile a8w8 blockscale GEMMs end-to-end (Python → pybind → C++ wrappers → manifests/instances), including output zeroing for atomic accumulation.
- Fix/extend MoE behavior (preallocated sorting buffer plumbing, Quark W4A6 quant remap, split-K stage1 buffer sizing) and harden tuner/JIT loading; mask pa_mqa_logits OOB stores (see the masked-store sketch below).
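The last bullet's pa_mqa_logits fix uses a standard Triton pattern; below is a generic masked-store sketch (not the aiter kernel) showing how an upper-bound mask keeps the tail block from writing out of bounds:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _masked_copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements                      # without this mask the last block writes OOB
    vals = tl.load(src_ptr + offs, mask=mask, other=0.0)
    tl.store(dst_ptr + offs, vals, mask=mask)

x = torch.randn(1000, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
_masked_copy_kernel[grid](x, y, x.numel(), BLOCK=256)
torch.testing.assert_close(x, y)
```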
Reviewed changes
Copilot reviewed 39 out of 40 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| op_tests/triton_tests/test_fused_rearrange_sigmoid_gdr.py | Adds correctness sweep test for fused rearrange+sigmoid gated delta rule decode kernel. |
| op_tests/triton_tests/test_causal_conv1d_update_single_token.py | Adds reference-based tests for single-token causal conv1d update and fused reshape launcher. |
| op_tests/triton_tests/quant/test_fused_rms_gated_fp8_group_quant.py | Adds reference + sweep + error-path tests for fused RMS+gate FP8 group quant. |
| op_tests/test_moe_sorting.py | Extends MoE sorting test to validate preallocated moe_buf pass-through. |
| op_tests/test_gemm_a8w8_blockscale.py | Updates benchmark harness and adds split-K correctness checks for blockscale GEMM. |
| csrc/include/rocm_ops.hpp | Extends pybind signatures to accept splitK parameter for blockscale GEMM APIs. |
| csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh | Adds KBatch (split-K) plumbing into CK blockscale GEMM implementation. |
| csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh | Enables split-K for CKTile path; zeros output buffer when k_batch > 1 to support atomic accumulation. |
| csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile.h | Updates CKTile header API to include splitK parameter. |
| csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale.h | Updates CK header API to include splitK parameter. |
| csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py | Updates CKTile instance generator to thread k_batch through to impl. |
| csrc/ck_gemm_a8w8_blockscale/gen_instances.py | Updates CK instance generator to thread KBatch through to impl. |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu | Propagates KBatch into tuned-kernel dispatch calls. |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_tune.cu | Propagates KBatch into tuned CKTile-kernel dispatch calls. |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu | Adds splitK argument validation and maps splitK→KBatch for CKTile dispatch. |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu | Adds splitK argument validation and maps splitK→KBatch for CK dispatch. |
| aiter/utility/mp_tuner.py | Ensures failed tasks always record dummy results to avoid downstream IndexError during result reconstruction. |
| aiter/ops/triton/quant/fused_fp8_quant.py | Adds fused RMS+gate FP8 group quant API + helper utilities for fp8 bounds and launch heuristics. |
| aiter/ops/triton/quant/__init__.py | Exports new FP8 quant helpers and fused kernel API. |
| aiter/ops/triton/gluon/pa_mqa_logits.py | Adds missing upper-bound masks to OutLogits_buffer stores to prevent OOB writes. |
| aiter/ops/triton/gated_delta_net/fused_rearrange_sigmoid_gdr.py | Adds new high-level launcher for fused rearrange+sigmoid gated delta rule decode. |
| aiter/ops/triton/gated_delta_net/__init__.py | Wires new gated-delta-rule decode launcher into module exports. |
| aiter/ops/triton/causal_conv1d_update_single_token.py | Adds Python launchers for single-token causal conv1d update kernels. |
| aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py | Adds the underlying Triton kernel for fused RMS+gate FP8 group quant. |
| aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/fused_rearrange_sigmoid_gdr.py | Adds Triton decode kernel implementation for fused rearrange+sigmoid gated delta rule. |
| aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/__init__.py | Exports the new decode kernel symbol. |
| aiter/ops/triton/_triton_kernels/causal_conv1d_update_single_token.py | Adds Triton kernels for single-token causal conv1d update + fused reshape path. |
| aiter/ops/gemm_op_a8w8.py | Threads splitK from config into gemm_a8w8_blockscale_ck/cktile calls. |
| aiter/jit/core.py | Stops forcing torch_exclude=True for ctypes-loaded modules; relies on per-module config to ensure correct linkage. |
| aiter/fused_moe_dp_shared_expert.py | Remaps QuantType.No→per_1x32 for fp4x2 weights (Quark W4A6) to match CK kernel requirements. |
| aiter/fused_moe.py | Adds optional preallocated moe_buf, refactors MXFP4 sorting calls, adjusts heuristics, and fixes CK split-K stage1 tmp_out sizing. |
| aiter/configs/model_configs/a8w8_blockscale_untuned_gemm_qwen3_next_80b_a3b.csv | Adds untuned shape list CSV for Qwen3-Next-80B-A3B blockscale GEMM coverage. |
    # Correctness check: verify split-K produces matching results
    print("\nRunning split-K correctness checks ...")
    for splitK in [1, 2]:
        test_splitk_correctness(m=4, n=512, k=16384, splitK=splitK)
    dummy_results = []
    add_dummy_result(k, dummy_results)
    result_dict[k] = (
        dummy_results if shape_grouped else [dummy_results[0]]
    )
    block_size_M: int = -1,
    num_local_tokens: Optional[torch.Tensor] = None,
    moe_sorting_dispatch_policy: int = 0,
    moe_sorting_dispatch_policy: bool = 0,
    block_size_M: int = -1,
    num_local_tokens: Optional[torch.Tensor] = None,
    moe_sorting_dispatch_policy: int = 0,
    moe_sorting_dispatch_policy: bool = 0,
    dtype: Optional[torch.dtype] = None,
    hidden_pad: int = 0,
    # Verify moe_buf pass-through: pre-allocated buffer should be reused
    pre_buf = torch.empty((token, model_dim), dtype=dtype, device="cuda")
    pre_buf_ptr = pre_buf.data_ptr()
    (
        (
            sorted_ids_c,
            sorted_weights_c,
            sorted_expert_ids_c,
            num_tokens_post_padded_c,
            moe_buf_c,
        ),
        _,
    ) = run_perftest(
        moe_sorting,
        topk_ids,
        topk_weights,
        E,
        model_dim,
        dtype,
        BLOCK_SIZE_M,
        expert_mask,
        num_local_tokens,
        dispatch_policy,
        moe_buf=pre_buf,
        num_warmup=1,
    # Shapes aligned with ``test_gated_delta_rule.test_fused_recurrent``; dtypes are
    # half-precision only — long packed ``T`` with float32 activations tends to blow
    # up the recurrent reference / kernel without tighter dynamic-range clamps.
    # Each row ends with ``use_qk_l2norm_in_kernel`` (True for stable long-T sweep).
    # One small bf16 row uses False to cover the no–L2-norm path (replaces former ``basic``).
    _FUSED_GDR_SWEEP = [
        (63, 1, 1, 64, 1, 1, torch.float16, True),
        (500, 4, 4, 60, 1, 1, torch.float16, True),
        (1000, 2, 8, 128, 1, 0.1, torch.float16, True),
        (1024, 2, 2, 128, 0.1, 1, torch.float16, True),
        (1024, 3, 3, 128, 1, 10, torch.float16, True),
        (2048, 4, 4, 64, 0.1, 1, torch.float16, True),
        (1024, 4, 4, 128, 1, 0.1, torch.float16, True),
        (1024, 4, 8, 128, 1, 10, torch.float16, True),
        (1024, 4, 4, 128, 1, 0.1, torch.bfloat16, True),
        (1024, 4, 8, 128, 1, 1, torch.bfloat16, True),
        (2048, 4, 8, 64, 0.1, 1, torch.bfloat16, True),
        (8, 4, 4, 16, 16**-0.5, 1, torch.bfloat16, False),
    ]
    assert x.is_contiguous() and z.is_contiguous()
    assert x.shape == z.shape, "x and z must have the same shape"
CI test failed: File "/opt/venv/lib/python3.12/site-packages/aiter/fused_moe.py", line 1301, in fused_moe_2stages. ❌ Test failed: op_tests/test_moe_2stage.py
Force-pushed from ae4ea19 to 9e6d4b4.
@LingpengJin — This PR is ready for your review. 29/30 CI checks pass. The 1 failure (Triton Shard 7,

Included PRs (6):

Excluded PRs (2) — need rework before they can be merged:

@vecheruk-amd @ChuanLi1101 — Your PRs were excluded from this bulk merge because the
Force-pushed from b2bec91 to 92c88fa.
The `zmq` meta-package fails to install on some CI runners because it cannot resolve the `pyzmq` dependency. Use `pyzmq` directly, which is the actual package providing ZeroMQ bindings for Python. Fixes Triton Test Shard 7 setup failures.
Set pip global retries=15 and timeout=120s in build_aiter_triton.sh to handle transient PyPI network failures on self-hosted runners. Shard 5/7 failures were caused by RemoteDisconnected during pip install.
pyzmq is only used by aiter.dist.shm_broadcast, not by any triton test. When PyPI is unreachable on self-hosted runners, the pyzmq install failure should not block the entire CI shard. Split pyzmq into a separate pip install with || fallback so triton tests can proceed even when PyPI connectivity is degraded.
When batch pip install fails (e.g., PyPI connectivity issues on self-hosted runners), retry each package individually. Only pyzmq is allowed to fail silently since it's only used by aiter.dist.shm_broadcast and not required by any CI test suite. Critical packages (pandas, einops, numpy) must still succeed.
Force-pushed from 92c88fa to e1aa5eb.
Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64, qseqlen=1 (non-persistent)") zeroed ptr_RP and out_16_nosplit for all non-persistent dispatch. The legacy QH16 ASM kernel used for nhead=32 (MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) still writes directly to the output buffer via ptr_RP when kv_split==1. Dereferencing nullptr causes a GPU memory access fault during CUDA graph capture on MI355X (gfx950) with DeepSeek-V3.2 at TP4.

Fix:
- Conditionally restore ptr_RP and out_16_nosplit in the non-persistent path for legacy kernels (gqa_ratio * max_seqlen_q <= 64) while keeping nullptr for newer kernels (e.g. gqa_ratio=64).
- Restore the bf16 nhead in [32,64] early-return after stage1 when num_kv_splits==1 to prevent stage2 from overwriting the kernel's direct output.

Tested on MI355X TP4 with deepseek-ai/DeepSeek-V3.2 (nhead=32):
- No crash during CUDA graph capture
- Correct GSM8K accuracy

Made-with: Cursor
Reverting cherry-pick of #2983 from this bulk merge. The MLA nhead=32 non-persistent decode fix causes deterministic test_mla k_cache and mla_decode-absorb precision failures on CI MI35X runners (Shard 1 & 2). #2983 should go through its own PR with proper CI validation by the original author (frida-andersson).
flydsl_moe_stage1 returns (out, out_scale_sorted) when the kernel uses fused fp4/fp8 quantization. The tuple unpack logic was removed during earlier refactoring but the kernel behavior was not changed, causing fused_moe_2stages to crash with:

AttributeError: 'tuple' object has no attribute 'view'

Restore the unpack: detect tuple return, extract tensor and scale, handle fp4 byte-packing trim, and skip redundant Python-side requant when the kernel already produced sorted scales.
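A hedged Python sketch of the restored unpack, with placeholder names; the real handling (including the fp4 byte-packing trim) lives in fused_moe_2stages:

```python
import torch

# Illustrative only: stage 1 may return either a plain tensor or a
# (out, out_scale_sorted) tuple when fused fp4/fp8 quantization is enabled.
def unpack_stage1(ret):
    if isinstance(ret, tuple):
        out, out_scale_sorted = ret    # kernel already produced sorted scales
        return out, out_scale_sorted   # skip the redundant Python-side requant
    return ret, None                   # fall back to Python-side quantization

out, scales = unpack_stage1((torch.zeros(4, 8), torch.ones(4)))
assert scales is not None              # tuple return no longer hits .view on a tuple
```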
Summary
Bulk merge of 8 Silo kernel PRs for v0.1.13 deadline (vLLM 0.21 freeze 2026-05-08).
Included PRs
Risk
Medium — includes kernel code changes, CK GEMM extensions, MoE dispatch fixes, and new Triton kernels. Auto-resolved merge conflicts in fused_moe.py and test_gemm_a8w8_blockscale.py need manual verification.

Test Plan