
optimize qwen3 next inplace copy #645

Merged

valarLip merged 7 commits into main from ganyi/optimize_qwen3_next_copy on Apr 30, 2026

Conversation

@ganyi1996ppo
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings April 24, 2026 11:59

Copilot AI left a comment


Pull request overview

This PR aims to reduce overhead in the Qwen3Next forward path by removing preallocated output-buffer writes and by replacing the compiled PyTorch GemmaRMSNorm path with a dedicated fused Triton kernel.

Changes:

  • Remove output buffer parameters from Qwen3Next attention/linear-attention forwards and switch to return-based outputs.
  • Add a fused Triton kernel for Gemma-style RMSNorm with optional residual add.
  • Route GemmaRMSNorm.forward_cuda to the new Triton implementation.
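For orientation, the Gemma-style RMSNorm that the new kernel fuses can be written in plain PyTorch roughly as below (a minimal reference sketch of out = rmsnorm(x + residual) * (1 + w); it glosses over the fp16-specific upcast behavior discussed in later comments and is not the code added by this PR):

import torch

def gemma_rmsnorm_reference(x: torch.Tensor,
                            weight: torch.Tensor,
                            eps: float,
                            residual: torch.Tensor | None = None):
    # Optional residual add, RMS normalization in float32, then the
    # Gemma-style (1 + w) weight offset; cast back to the input dtype.
    orig_dtype = x.dtype
    if residual is not None:
        x = x + residual
        residual_out = x
    x32 = x.float()
    inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = (x32 * inv_rms * (1.0 + weight.float())).to(orig_dtype)
    return out if residual is None else (out, residual_out)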

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

  • atom/models/qwen3_next.py: Removes output-buffer in-place writes and changes attention/linear-attention to return outputs directly.
  • atom/model_ops/triton_gemma_rmsnorm.py: Adds the fused Triton GemmaRMSNorm kernel (with optional residual add) and its launcher.
  • atom/model_ops/layernorm.py: Switches GemmaRMSNorm.forward_cuda to the new Triton kernel path (replacing the prior torch.compile behavior).


Comment thread: atom/model_ops/layernorm.py (outdated)
        from atom.model_ops.triton_gemma_rmsnorm import (
            gemma_rmsnorm_triton
        )
        return gemma_rmsnorm_triton(x, self.weight.data, self.variance_epsilon, residual)

Copilot AI Apr 24, 2026


GemmaRMSNorm.forward_cuda currently returns the raw result of gemma_rmsnorm_triton(...), but gemma_rmsnorm_triton returns (out, None) when residual is None. That changes the public contract of GemmaRMSNorm.forward* (it should return a Tensor when no residual is provided) and will break callers like hidden_states = self.input_layernorm(hidden_states).

Suggested change
-        return gemma_rmsnorm_triton(x, self.weight.data, self.variance_epsilon, residual)
+        result = gemma_rmsnorm_triton(x, self.weight.data, self.variance_epsilon, residual)
+        if residual is None and isinstance(result, tuple):
+            return result[0]
+        return result
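To make the contract concrete, callers typically use the layer in both forms (hypothetical call sites for illustration, not lines from this repository):

# No residual: the caller expects a plain Tensor back.
hidden_states = self.input_layernorm(hidden_states)

# With residual: the caller expects a (normalized, residual_out) tuple.
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)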

Comment thread: atom/model_ops/layernorm.py (outdated)
Comment on lines +547 to +553
     x: torch.Tensor,
     residual: torch.Tensor | None = None,
 ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-    if torch.compiler.is_compiling():
-        return self.forward_native(x, residual)
-
-    if not getattr(self, "_is_compiled", False):
-        self.forward_static = torch.compile(self.forward_static)  # type: ignore
-        self._is_compiled = True
-    return self.forward_native(x, residual)
+    from atom.model_ops.triton_gemma_rmsnorm import (
+        gemma_rmsnorm_triton
+    )
+    return gemma_rmsnorm_triton(x, self.weight.data, self.variance_epsilon, residual)

Copilot AI Apr 24, 2026


This change removes the torch.compiler.is_compiling() fallback and now always calls into a Triton kernel from forward_cuda. The codebase compiles models with torch.compile(..., fullgraph=True) (see atom/model_engine/model_runner.py), so a non-traceable Python/Triton call here can cause Dynamo graph breaks or fullgraph compilation failures. Consider keeping a compile-time fallback to forward_native (or registering a proper torch custom op) when torch.compiler.is_compiling() is true.
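As a rough illustration of the custom-op route suggested above, a wrapper along these lines could keep the Triton launcher traceable under fullgraph compilation (a sketch only; the op name, the PyTorch >= 2.4 torch.library.custom_op API usage, and the launcher's (out, residual_out) return shape are assumptions, not code from this PR):

import torch
from torch.library import custom_op

@custom_op("atom::fused_gemma_rmsnorm", mutates_args=())
def fused_gemma_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Eager implementation: call the Triton launcher and drop the dummy residual slot.
    from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton
    out, _ = gemma_rmsnorm_triton(x, weight, eps, None)
    return out

@fused_gemma_rmsnorm.register_fake
def _(x, weight, eps):
    # Shape/dtype propagation so Dynamo can trace without launching the kernel.
    return torch.empty_like(x)

forward_cuda could then call torch.ops.atom.fused_gemma_rmsnorm(...) and stay inside the compiled graph instead of breaking it.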

Comment thread atom/models/qwen3_next.py
Comment on lines 968 to 975
 if self.layer_type == "linear_attention":
-    self.linear_attn(
+    hidden_states = self.linear_attn(
         hidden_states=(
             hidden_bf16 if hidden_bf16 is not None else hidden_states
         ),
-        output=self_attention_output,
         x_fp8=hidden_states if x_scale is not None else None,
         x_scale=x_scale,
     )

Copilot AI Apr 24, 2026


Qwen3NextDecoderLayer.forward now calls self.linear_attn(...) expecting a returned Tensor. However, Qwen3_5DecoderLayer reuses this forward implementation while wiring self.linear_attn = Qwen3_5GatedDeltaNet, whose forward signature still requires an output argument (see atom/models/qwen3_5.py). This will raise a TypeError for Qwen3.5 linear-attention layers unless Qwen3_5GatedDeltaNet.forward (or the Qwen3.5 decoder forward path) is updated to match the new return-based API.
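A small standalone example of the mismatch (illustrative classes only, not the actual Qwen modules):

import torch

class ReturnBasedAttn(torch.nn.Module):
    # New-style module: computes and returns its output.
    def forward(self, hidden_states):
        return hidden_states * 2

class BufferBasedAttn(torch.nn.Module):
    # Old-style module: writes into a required preallocated buffer.
    def forward(self, hidden_states, output):
        output.copy_(hidden_states * 2)

h = torch.randn(4, 8)
layer = BufferBasedAttn()
try:
    # The decoder layer now calls `hidden_states = self.linear_attn(hidden_states=h)`,
    # which fails for a buffer-based module because `output` is still required.
    out = layer(hidden_states=h)
except TypeError as err:
    print(err)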

Comment on lines +3 to +13
Replaces the torch.compile'd GemmaRMSNorm.forward_static with a single Triton
kernel that fuses residual add and RMS normalization with the Gemma-style
weight offset: out = rmsnorm(x + residual) * (1 + w).

Based on aiter's _fused_add_rmsnorm_kernel with the (g + 1.0) Gemma offset.

Two custom ops are registered so that torch.compile (Dynamo) can trace through
them without falling back to the PyTorch implementation that contains
x.float() / x.to(orig_dtype) dtype-cast copy kernels:
- ``fused_gemma_rmsnorm`` (no residual)
- ``fused_gemma_add_rmsnorm`` (with residual add)

Copilot AI Apr 24, 2026


The module docstring says “Two custom ops are registered so that torch.compile (Dynamo) can trace through them…”, but this file only defines a Triton kernel + Python launcher and does not register any torch.library custom ops. Either register the stated ops (and update callers to use them) or update the docstring to avoid implying torch.compile support that isn’t actually implemented.

Suggested change
-Replaces the torch.compile'd GemmaRMSNorm.forward_static with a single Triton
-kernel that fuses residual add and RMS normalization with the Gemma-style
-weight offset: out = rmsnorm(x + residual) * (1 + w).
-Based on aiter's _fused_add_rmsnorm_kernel with the (g + 1.0) Gemma offset.
-Two custom ops are registered so that torch.compile (Dynamo) can trace through
-them without falling back to the PyTorch implementation that contains
-x.float() / x.to(orig_dtype) dtype-cast copy kernels:
-- ``fused_gemma_rmsnorm`` (no residual)
-- ``fused_gemma_add_rmsnorm`` (with residual add)
+Defines a Triton kernel and Python launcher that fuse residual add and RMS
+normalization with the Gemma-style weight offset:
+out = rmsnorm(x + residual) * (1 + w).
+Based on aiter's _fused_add_rmsnorm_kernel with the (g + 1.0) Gemma offset.
+This module does not register `torch.library` custom ops; it provides the
+Triton implementation and launcher directly.

        residual_out = x  # dummy, won't be written

    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    NUM_PRGMS = min(n_rows, 304)  # MI355X has 304 CUs

Copilot AI Apr 24, 2026


NUM_PRGMS = min(n_rows, 304) hard-codes MI355X’s CU count, which can under/over-subscribe on other AMD SKUs (and non-AMD platforms) and makes performance tuning device-specific. Consider deriving the program count from torch.cuda.get_device_properties(...).multi_processor_count (similar to atom/model_ops/fla_ops/layernorm_guard.py) or making this a heuristic/configurable parameter.

Suggested change
-    NUM_PRGMS = min(n_rows, 304)  # MI355X has 304 CUs
+    try:
+        multi_processor_count = torch.cuda.get_device_properties(
+            x.device
+        ).multi_processor_count
+    except (AssertionError, AttributeError, RuntimeError, ValueError):
+        multi_processor_count = n_rows
+    NUM_PRGMS = min(n_rows, max(1, multi_processor_count))

Copilot AI review requested due to automatic review settings April 29, 2026 11:13
ganyi1996ppo force-pushed the ganyi/optimize_qwen3_next_copy branch from e772efa to 4317003 on April 29, 2026 11:13
Signed-off-by: ganyi <ygan@amd.com>

Copilot AI left a comment


Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

Comments suppressed due to low confidence (1)

atom/models/qwen3_next.py:984

  • There are now commented-out allocations/assignments for self_attention_output and hidden_states = self_attention_output. Since this code is no longer used, it would be clearer to remove these commented blocks (or replace with a short explanatory comment about why the previous in-place buffer path was removed) to avoid confusion during future maintenance.
        if self.layer_type == "linear_attention":
            hidden_states = self.linear_attn(
                hidden_states=(
                    hidden_bf16 if hidden_bf16 is not None else hidden_states
                ),
                x_fp8=hidden_states if x_scale is not None else None,
                x_scale=x_scale,
            )
        elif self.layer_type == "full_attention":
            hidden_states = self.self_attn(
                hidden_states=hidden_states,
                positions=positions,
                x_scale=x_scale,
            )
        else:
            raise ValueError("Invalid layer_type")

        if self.layer_scale:
            if len(hidden_states.shape) == 2:
                hidden_states = hidden_states * (


Comment on lines 545 to +553
 def forward_cuda(
     self,
     x: torch.Tensor,
     residual: torch.Tensor | None = None,
 ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-    if torch.compiler.is_compiling():
-        return self.forward_native(x, residual)
+    from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton
 
-    if not getattr(self, "_is_compiled", False):
-        self.forward_static = torch.compile(self.forward_static)  # type: ignore
-        self._is_compiled = True
-    return self.forward_native(x, residual)
+    return gemma_rmsnorm_triton(
+        x, self.weight.data, self.variance_epsilon, residual

Copilot AI Apr 29, 2026


GemmaRMSNorm.forward() always dispatches to forward_cuda(). With this change, forward_cuda() unconditionally imports/launches a Triton kernel, which will fail for CPU tensors (where previously it fell back to forward_native). Consider guarding on x.is_cuda (and possibly dtype/device constraints) and falling back to forward_native when Triton can't run.
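A minimal shape for such a guard (a sketch under the assumption that forward_native covers non-CUDA tensors; not the code in this PR):

def forward_cuda(self, x, residual=None):
    # Fall back to the eager PyTorch path off-GPU or while Dynamo is tracing.
    if not x.is_cuda or torch.compiler.is_compiling():
        return self.forward_native(x, residual)
    from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton
    out, residual_out = gemma_rmsnorm_triton(
        x, self.weight.data, self.variance_epsilon, residual
    )
    return out if residual is None else (out, residual_out)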

Comment on lines +552 to +554
        return gemma_rmsnorm_triton(
            x, self.weight.data, self.variance_epsilon, residual
        )

Copilot AI Apr 29, 2026


GemmaRMSNorm.forward_cuda now returns the raw result of gemma_rmsnorm_triton, which is always a 2-tuple (out, residual_out_or_None). This breaks the existing contract of GemmaRMSNorm.forward_static/forward_native, which returns a Tensor when residual is None (and a 2-tuple only when residual is provided). Please unwrap the return so callers still receive a Tensor in the no-residual case (e.g., return just out when residual is None).

Suggested change
-        return gemma_rmsnorm_triton(
-            x, self.weight.data, self.variance_epsilon, residual
-        )
+        out, residual_out = gemma_rmsnorm_triton(
+            x, self.weight.data, self.variance_epsilon, residual
+        )
+        return out if residual is None else (out, residual_out)

Comment on lines +83 to +109
def gemma_rmsnorm_triton(x, weight, eps, residual):
    """Launch the Triton kernel. Returns (out, residual_out | None)."""
    ori_shape = x.shape
    x = x.view(-1, ori_shape[-1])
    n_rows, n_cols = x.shape

    out = torch.empty_like(x)

    has_residual = residual is not None
    if has_residual:
        residual = residual.view(-1, ori_shape[-1])
        residual_out = torch.empty_like(residual)
    else:
        residual_out = x  # dummy, won't be written

    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    NUM_PRGMS = min(n_rows, 304)  # MI355X has 304 CUs

    _gemma_rmsnorm_kernel[(NUM_PRGMS,)](
        x,
        out,
        residual if has_residual else x,  # dummy for res_in when no residual
        residual_out,
        weight,
        x.stride(0),
        out.stride(0),
        n_rows,

Copilot AI Apr 29, 2026


The Triton kernel indexes columns as base + col_offsets and only takes a single row_stride argument, which implicitly assumes the last dimension is contiguous (stride(-1) == 1) for x, residual, and weight. gemma_rmsnorm_triton currently does x = x.view(...) without checking contiguity/stride, so non-contiguous inputs could either error (view) or produce incorrect results. Add an explicit contiguity/stride check (and fallback to forward_native or call .contiguous()/.reshape() as appropriate) before launching the kernel.
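One possible guard, sketched (the exact policy of copying versus falling back to the eager path is an assumption):

# Ensure the last dimension is dense before flattening to 2-D; otherwise the
# single row stride passed to the kernel would index the wrong elements.
if x.stride(-1) != 1:
    x = x.contiguous()
x = x.reshape(-1, x.shape[-1])
if residual is not None:
    if residual.stride(-1) != 1:
        residual = residual.contiguous()
    residual = residual.reshape(-1, residual.shape[-1])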

Comment on lines +55 to +68
    x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")

    if HAS_RESIDUAL:
        res_in_ptrs = res_in_ptr + row_idx * input_row_stride + col_offsets
        res_in_ptrs = tl.multiple_of(res_in_ptrs, (16,))
        res = tl.load(res_in_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
        x = x + res
        # Store residual_out (needed by next layer)
        res_out_ptrs = res_out_ptr + row_idx * input_row_stride + col_offsets
        res_out_ptrs = tl.multiple_of(res_out_ptrs, (16,))
        tl.store(res_out_ptrs, x.to(res_out_ptr.dtype.element_ty), mask=mask)

    x = x.to(tl.float32)

Copilot AI Apr 29, 2026


The fused path currently adds x + residual in the input dtype and only casts to fp32 afterwards (x = x + res then x = x.to(tl.float32)). This is not equivalent to GemmaRMSNorm.forward_static, which upcasts to fp32 before the add when orig_dtype == torch.float16. To match existing numerics, consider casting to fp32 before the add for fp16 inputs (or always performing the add in fp32).

Suggested change
-    x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
-    if HAS_RESIDUAL:
-        res_in_ptrs = res_in_ptr + row_idx * input_row_stride + col_offsets
-        res_in_ptrs = tl.multiple_of(res_in_ptrs, (16,))
-        res = tl.load(res_in_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
-        x = x + res
-        # Store residual_out (needed by next layer)
-        res_out_ptrs = res_out_ptr + row_idx * input_row_stride + col_offsets
-        res_out_ptrs = tl.multiple_of(res_out_ptrs, (16,))
-        tl.store(res_out_ptrs, x.to(res_out_ptr.dtype.element_ty), mask=mask)
-    x = x.to(tl.float32)
+    x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(
+        tl.float32
+    )
+    if HAS_RESIDUAL:
+        res_in_ptrs = res_in_ptr + row_idx * input_row_stride + col_offsets
+        res_in_ptrs = tl.multiple_of(res_in_ptrs, (16,))
+        res = tl.load(
+            res_in_ptrs, mask=mask, other=0.0, cache_modifier=".cg"
+        ).to(tl.float32)
+        x = x + res
+        # Store residual_out (needed by next layer)
+        res_out_ptrs = res_out_ptr + row_idx * input_row_stride + col_offsets
+        res_out_ptrs = tl.multiple_of(res_out_ptrs, (16,))
+        tl.store(res_out_ptrs, x.to(res_out_ptr.dtype.element_ty), mask=mask)
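To make the precision concern concrete, a small standalone illustration of fp16 versus fp32 addition (not code from this PR):

import torch

a = torch.tensor([2048.0], dtype=torch.float16)
b = torch.tensor([1.0], dtype=torch.float16)
print((a + b).item())                  # 2048.0: the fp16 add rounds the +1 away
print((a.float() + b.float()).item())  # 2049.0: the fp32 add preserves it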

valarLip previously approved these changes Apr 29, 2026
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings April 29, 2026 12:56

Copilot AI left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.



    has_residual = residual is not None
    if has_residual:
        residual = residual.view(-1, ori_shape[-1])
        residual_out = torch.empty_like(residual)

Copilot AI Apr 29, 2026


gemma_rmsnorm_triton allocates residual_out = torch.empty_like(residual), which forces residual_out.dtype == residual.dtype. This diverges from GemmaRMSNorm.forward_static, which returns a float32 residual when the original dtype is float16 (due to the x.float() + residual.float() path). To keep behavior consistent, consider allocating/storing residual_out as float32 when x.dtype == torch.float16 (and keeping bf16/fp32 behavior unchanged).

Suggested change
-        residual_out = torch.empty_like(residual)
+        residual_out_dtype = torch.float32 if x.dtype == torch.float16 else residual.dtype
+        residual_out = torch.empty(
+            residual.shape,
+            device=residual.device,
+            dtype=residual_out_dtype,
+        )

Comment thread atom/models/qwen3_5.py
"""
Forward pass with three parts:
1. Input projection
2. Core attention (custom op)

Copilot AI Apr 29, 2026


The forward docstring describes only two parts (input projection + core attention), but the function still performs and returns an output projection (“Part 3”). Please update the docstring (or the section headers) so it matches the actual computation.

Suggested change
-        2. Core attention (custom op)
+        2. Core attention (custom op)
+        3. Output projection

Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings April 29, 2026 14:43

Copilot AI left a comment


Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.



Comment thread atom/models/qwen3_next.py
Comment on lines 771 to +773
         core_attn_out, maybe_scale = self.norm(core_attn_out, z)
-        output[:num_tokens] = self.out_proj(core_attn_out, x_scale=maybe_scale)
+        output = self.out_proj(core_attn_out, x_scale=maybe_scale)
         return output

Copilot AI Apr 29, 2026


num_tokens is now unused after switching from output[:num_tokens] = ... to returning the projected tensor directly. Please remove the dead variable (and any related logic) to avoid confusion and keep the forward path minimal.

valarLip merged commit 8a7416d into main on Apr 30, 2026
46 of 51 checks passed
valarLip deleted the ganyi/optimize_qwen3_next_copy branch on April 30, 2026 02:32