[Triton] optimized decode kernels for Qwen3-Next model #2423

Open: hellozhuo-amd wants to merge 22 commits into main from zhuo/qwen3_triton_gdn
Conversation


@hellozhuo-amd hellozhuo-amd commented Mar 23, 2026

Motivation

On the Qwen3-Next decode path, vLLM runs several Triton-backed steps back-to-back (causal conv1d state update, QKV layout work, gated delta rule / linear attention). Bringing well-tested kernels into aiter improves reuse on ROCm and keeps a single place for Triton tuning and CI.

What this PR adds

Triton code follows the aiter split: @triton.jit in aiter/ops/triton/_triton_kernels/, Python launchers and public APIs in aiter/ops/triton/.

| Area | Launcher / API | Kernel location |
| --- | --- | --- |
| Gated delta rule (decode) | `fused_rearrange_sigmoid_gated_delta_rule` in `aiter/ops/triton/gated_delta_net/` | `_triton_kernels/gated_delta_rule/decode/fused_rearrange_sigmoid_gdr.py` |
| Causal conv1d "update" fast path | `causal_conv1d_update_single_token`, `fused_reshape_causal_conv1d_update_single_token` in `aiter/ops/triton/causal_conv1d_update_single_token.py` | `_triton_kernels/causal_conv1d_update_single_token.py` (uses `PAD_SLOT_ID` from `_triton_kernels/causal_conv1d.py`) |
| RMSNorm + gated + FP8 group quant | `fused_rms_gated_fp8_group_quant`, `get_fp8_min_max_bounds`, `calc_rows_per_block` in `aiter/ops/triton/quant/fused_fp8_quant.py` | `_fused_rms_gated_fp8_group_quant_kernel` in `_triton_kernels/quant/fused_fp8_quant.py` (colocated with other fused FP8 quant kernels) |

Exports are wired through aiter/ops/triton/gated_delta_net/__init__.py and aiter/ops/triton/quant/__init__.py.
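To make the conv1d fast path concrete, here is a minimal unfused NumPy sketch of what a single-token causal conv1d "update" step computes: roll the per-channel state cache, append the new token, and apply a depthwise filter at that one position. The signature, state layout, and SiLU activation are assumptions modeled on the common causal-conv1d update semantics, not the actual aiter API.

```python
import numpy as np

def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
    """Single-token causal conv1d update (reference sketch, layout assumed).

    x:          (batch, dim)             new token's channels
    conv_state: (batch, dim, width - 1)  rolling cache of past inputs, updated in place
    weight:     (dim, width)             depthwise filter
    """
    # Concatenate cached history with the new token along the time axis.
    window = np.concatenate([conv_state, x[:, :, None]], axis=-1)  # (batch, dim, width)
    # Shift the cache: drop the oldest column, keep the newest width-1 columns.
    conv_state[:] = window[:, :, 1:]
    # Depthwise convolution evaluated only at the single new time step.
    out = np.einsum("bdw,dw->bd", window, weight)
    if bias is not None:
        out = out + bias
    if activation == "silu":
        out = out / (1.0 + np.exp(-out))
    return out
```

The fused kernel additionally folds the surrounding reshape (hence `fused_reshape_causal_conv1d_update_single_token`), so the layout change and the convolution happen in one launch instead of several.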

About Gated Delta Rule

paper: https
technical blog: https
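For readers unfamiliar with the gated delta rule, the decode-time recurrence can be sketched in a few lines of NumPy: decay the matrix state with a gate, apply a delta-rule correction toward the new value, and read out with the query. This is a reference sketch of the general recurrence only; the actual kernel also fuses the input rearrange and sigmoid steps (per its name), and the real tensor shapes, batching, and gate parameterization are assumptions.

```python
import numpy as np

def gated_delta_rule_decode_ref(q, k, v, alpha, beta, S):
    """One decode step of a recurrent gated delta rule (reference sketch).

    q, k:  (d_k,)      query / key for the new token
    v:     (d_v,)      value for the new token
    alpha: scalar gate in (0, 1) that decays the state
    beta:  scalar write strength (often a sigmoid of a learned input)
    S:     (d_k, d_v)  recurrent state, updated in place
    """
    S *= alpha                          # gated decay of old memory
    pred = S.T @ k                      # what the current state predicts for key k
    S += beta * np.outer(k, v - pred)   # delta-rule correction toward v
    return S.T @ q                      # read out with the query
```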

Tests

PyTorch reference tests under op_tests/triton_tests/:

  • test_fused_rearrange_sigmoid_gdr.py
  • test_causal_conv1d_update_single_token.py
  • quant/test_fused_rms_gated_fp8_group_quant.py

Test command

```shell
python3 -m pytest \
  op_tests/triton_tests/test_fused_rearrange_sigmoid_gdr.py \
  op_tests/triton_tests/test_causal_conv1d_update_single_token.py \
  op_tests/triton_tests/quant/test_fused_rms_gated_fp8_group_quant.py \
  -v
```

Effect on vLLM Qwen3-Next model

overall effect

Baseline: around 39 us

PR: around 15.4 us (roughly a 2.5x reduction)
effect from fused_rearrange_sigmoid_gated_delta_rule

Baseline with fused_recurrent_gated_delta_rule_packed_decode_kernel: 6.218 us (averaged)

PR with fused_rearrange_sigmoid_gated_delta_rule: 5.840 us (averaged)

effect from fused_reshape_causal_conv1d_update_single_token

Baseline: 4-5 kernels taking 20-24 us

PR: fused to 1 kernel taking 4.9 us (averaged)

effect from rmsnorm_input_quant_fp8

Baseline: 2 kernels taking around 9.4 us

PR: fused to 1 kernel taking 4.5 us (averaged)
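The three steps this last fusion merges can be written as an unfused NumPy reference: RMSNorm, multiply by the activated gate, then dynamic per-group quantization into the FP8 range. The SiLU gate activation, the e4m3 max of 448, and the function signature are assumptions for illustration; the on-device kernel casts to a real FP8 dtype instead of clipping floats.

```python
import numpy as np

F8_MAX = 448.0  # max finite value of float8 e4m3fn (assumed target format)

def rms_gated_fp8_group_quant_ref(x, weight, gate, group_size=128, eps=1e-6):
    """Unfused reference for the fused RMSNorm + gate + FP8 group-quant path.

    x, weight, gate: (n,) with n divisible by group_size.
    Returns (quantized values still as float, per-group scales).
    """
    # 1) RMSNorm with a learned per-channel weight.
    y = x / np.sqrt(np.mean(x * x) + eps) * weight
    # 2) Gating: multiply by SiLU(gate) (activation choice is an assumption).
    y = y * (gate / (1.0 + np.exp(-gate)))
    # 3) Dynamic per-group quantization: scale each group so its max hits F8_MAX.
    g = y.reshape(-1, group_size)
    scales = np.abs(g).max(axis=1, keepdims=True) / F8_MAX
    scales = np.where(scales == 0, 1.0, scales)  # avoid div-by-zero on all-zero groups
    q = np.clip(g / scales, -F8_MAX, F8_MAX)     # would be a cast to fp8 on device
    return q.reshape(-1), scales.reshape(-1)
```

Doing all three in one launch avoids materializing the normalized and gated intermediates in global memory, which is where the ~9.4 us to ~4.5 us improvement above comes from.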

Submission checklist

@hellozhuo-amd hellozhuo-amd requested review from a team and Copilot March 23, 2026 07:40
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| `ci:sglang` | SGLang integration tests |
| `ci:atom` | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| `ci:vllm` | vLLM benchmark |
| `ci:all` | All of the above |

Add labels via the sidebar or `gh pr edit 2423 --add-label <label>`

@hellozhuo-amd hellozhuo-amd marked this pull request as draft March 23, 2026 07:54
@hellozhuo-amd hellozhuo-amd changed the title Zhuo/qwen3 triton gdn Zhuo/qwen3 triton gdn: fused conv1d with recurrent gated delta rule Mar 23, 2026
@hellozhuo-amd hellozhuo-amd changed the title Zhuo/qwen3 triton gdn: fused conv1d with recurrent gated delta rule Zhuo/Performance enhancement for Qwen3-Next model with Triton kernels Apr 10, 2026
@hellozhuo-amd hellozhuo-amd force-pushed the zhuo/qwen3_triton_gdn branch from 43bf452 to da98e37 Compare April 10, 2026 21:30
@hellozhuo-amd hellozhuo-amd force-pushed the zhuo/qwen3_triton_gdn branch from da98e37 to 3216bce Compare April 10, 2026 21:39
@ROCm ROCm deleted a comment from Copilot AI Apr 10, 2026
Remove unused variable in rmsnorm FP8 test ref. Apply Black to
kernels, launchers, tests, and gated_delta_rule decode __init__.

Made-with: Cursor
@hellozhuo-amd hellozhuo-amd self-assigned this Apr 11, 2026
@hellozhuo-amd hellozhuo-amd marked this pull request as ready for review April 13, 2026 11:12
@hellozhuo-amd hellozhuo-amd changed the title Zhuo/Performance enhancement for Qwen3-Next model with Triton kernels [Triton] optimized decode kernels for Qwen3-Next model Apr 14, 2026

@juuso-oskari juuso-oskari dismissed their stale review April 21, 2026 12:06

rereviewing

juuso-oskari
juuso-oskari previously approved these changes Apr 22, 2026
Contributor

@juuso-oskari juuso-oskari left a comment


LGTM

…group_quant

Colocate the gated RMSNorm + FP8 group quant path with the other fused FP8
ops. The Triton kernel is now _fused_rms_gated_fp8_group_quant_kernel in
_triton_kernels/quant/fused_fp8_quant.py; the Python entry point is
fused_rms_gated_fp8_group_quant in quant/fused_fp8_quant.py, with a docstring
that contrasts it with fused_rms_fp8_group_quant. Remove the old
rmsnorm_input_quant_fp8 module and rms_norm_input_quant_fp8 kernel file.
Re-export the new symbol and helpers (get_fp8_min_max_bounds,
calc_rows_per_block) from aiter.ops.triton.quant. Rename the test file to
test_fused_rms_gated_fp8_group_quant.py and update test.sh.

BREAKING CHANGE: rmsnorm_input_quant_fp8 is removed; use
fused_rms_gated_fp8_group_quant instead.

Made-with: Cursor
tpopp added a commit to tpopp/vllm that referenced this pull request Apr 23, 2026
Follow upstream aiter rename (ROCm/aiter#2423). The kernel moved from
aiter.ops.triton.quant.rmsnorm_input_quant_fp8 to
aiter.ops.triton.quant.fused_fp8_quant.fused_rms_gated_fp8_group_quant.
Update the vLLM custom op registration, impl, fake, getter, and fusion
pass references accordingly.

Made-with: Cursor

Signed-off-by: Tres Popp <tres.popp@amd.com>
