[Triton] optimized decode kernels for Qwen3-Next model#2423
Open
hellozhuo-amd wants to merge 22 commits intomainfrom
Open
[Triton] optimized decode kernels for Qwen3-Next model#2423hellozhuo-amd wants to merge 22 commits intomainfrom
hellozhuo-amd wants to merge 22 commits intomainfrom
Conversation
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
43bf452 to
da98e37
Compare
da98e37 to
3216bce
Compare
Remove unused variable in rmsnorm FP8 test ref. Apply Black to kernels, launchers, tests, and gated_delta_rule decode __init__. Made-with: Cursor
… single token decoding
…group_quant Colocate the gated RMSNorm + FP8 group quant path with the other fused FP8 ops. The Triton kernel is now _fused_rms_gated_fp8_group_quant_kernel in _triton_kernels/quant/fused_fp8_quant.py; the Python entry point is fused_rms_gated_fp8_group_quant in quant/fused_fp8_quant.py, with a docstring that contrasts it with fused_rms_fp8_group_quant. Remove the old rmsnorm_input_quant_fp8 module and rms_norm_input_quant_fp8 kernel file. Re-export the new symbol and helpers (get_fp8_min_max_bounds, calc_rows_per_block) from aiter.ops.triton.quant. Rename the test file to test_fused_rms_gated_fp8_group_quant.py and update test.sh. BREAKING CHANGE: rmsnorm_input_quant_fp8 is removed; use fused_rms_gated_fp8_group_quant instead. Made-with: Cursor
tpopp
added a commit
to tpopp/vllm
that referenced
this pull request
Apr 23, 2026
Follow upstream aiter rename (ROCm/aiter#2423). The kernel moved from aiter.ops.triton.quant.rmsnorm_input_quant_fp8 to aiter.ops.triton.quant.fused_fp8_quant.fused_rms_gated_fp8_group_quant. Update the vLLM custom op registration, impl, fake, getter, and fusion pass references accordingly. Made-with: Cursor Signed-off-by: Tres Popp <tres.popp@amd.com>
3 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
On the Qwen3-Next decode path, vLLM runs several Triton-backed steps back-to-back (causal conv1d state update, QKV layout work, gated delta rule / linear attention). Bringing well-tested kernels into aiter improves reuse on ROCm and keeps a single place for Triton tuning and CI.
What this PR adds
Triton code follows the aiter split:
@triton.jitinaiter/ops/triton/_triton_kernels/, Python launchers and public APIs inaiter/ops/triton/.fused_rearrange_sigmoid_gated_delta_ruleinaiter/ops/triton/gated_delta_net/_triton_kernels/gated_delta_rule/decode/fused_rearrange_sigmoid_gdr.pycausal_conv1d_update_single_token,fused_reshape_causal_conv1d_update_single_tokeninaiter/ops/triton/causal_conv1d_update_single_token.py_triton_kernels/causal_conv1d_update_single_token.py(usesPAD_SLOT_IDfrom_triton_kernels/causal_conv1d.py)fused_rms_gated_fp8_group_quant,get_fp8_min_max_bounds,calc_rows_per_blockinaiter/ops/triton/quant/fused_fp8_quant.py_fused_rms_gated_fp8_group_quant_kernelin_triton_kernels/quant/fused_fp8_quant.py(colocated with other fused FP8 quant kernels)Exports are wired through
aiter/ops/triton/gated_delta_net/__init__.pyandaiter/ops/triton/quant/__init__.py.About Gated Delta Rule
paper: https
technical blog: https
Tests
PyTorch reference tests under
op_tests/triton_tests/:test_fused_rearrange_sigmoid_gdr.pytest_causal_conv1d_update_single_token.pyquant/test_fused_rms_gated_fp8_group_quant.pyTest command
Effect on vllm Qwen3 Next model
overall effect
Baseline: around 39us


PR: around 15.4us
effect from
fused_rearrange_sigmoid_gated_delta_ruleBaseline with
fused_recurrent_gated_delta_rule_packed_decode_kernel: 6.218 us (averaged)PR with
fused_rearrange_sigmoid_gated_delta_rule: 5.840us (averaged)effect from
fused_reshape_causal_conv1d_update_single_tokenBaseline: 4-5 kernels with 20-24us

PR: fused to 1 kernel with 4.9us (averaged)

effect from
rmsnorm_input_quant_fp8Baseline: 2 kernels with around 9.4us

PR: fused to 1 kernel with 4.5us (averaged)

Submission checklist