Conversation
Signed-off-by: ganyi <ygan@amd.com>
Pull request overview
This PR aims to reduce overhead in the Qwen3Next forward path by removing preallocated output-buffer writes and by replacing the compiled PyTorch GemmaRMSNorm path with a dedicated fused Triton kernel.
Changes:
- Remove `output` buffer parameters from Qwen3Next attention/linear-attention forwards and switch to return-based outputs.
- Add a fused Triton kernel for Gemma-style RMSNorm with optional residual add.
- Route `GemmaRMSNorm.forward_cuda` to the new Triton implementation.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| atom/models/qwen3_next.py | Removes output-buffer in-place writes and changes attention/linear-attention to return outputs directly. |
| atom/model_ops/triton_gemma_rmsnorm.py | Adds fused Triton GemmaRMSNorm (+ optional residual add) kernel and launcher. |
| atom/model_ops/layernorm.py | Switches GemmaRMSNorm.forward_cuda to the new Triton kernel path (replacing prior torch.compile behavior). |
```python
from atom.model_ops.triton_gemma_rmsnorm import (
    gemma_rmsnorm_triton
)
return gemma_rmsnorm_triton(x, self.weight.data, self.variance_epsilon, residual)
```
GemmaRMSNorm.forward_cuda currently returns the raw result of gemma_rmsnorm_triton(...), but gemma_rmsnorm_triton returns (out, None) when residual is None. That changes the public contract of GemmaRMSNorm.forward* (it should return a Tensor when no residual is provided) and will break callers like hidden_states = self.input_layernorm(hidden_states).
Suggested change:
```diff
-return gemma_rmsnorm_triton(x, self.weight.data, self.variance_epsilon, residual)
+result = gemma_rmsnorm_triton(x, self.weight.data, self.variance_epsilon, residual)
+if residual is None and isinstance(result, tuple):
+    return result[0]
+return result
```
```diff
     x: torch.Tensor,
     residual: torch.Tensor | None = None,
 ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-    if torch.compiler.is_compiling():
-        return self.forward_native(x, residual)
-    if not getattr(self, "_is_compiled", False):
-        self.forward_static = torch.compile(self.forward_static)  # type: ignore
-        self._is_compiled = True
-        return self.forward_native(x, residual)
+    from atom.model_ops.triton_gemma_rmsnorm import (
+        gemma_rmsnorm_triton
+    )
+    return gemma_rmsnorm_triton(x, self.weight.data, self.variance_epsilon, residual)
```
This change removes the torch.compiler.is_compiling() fallback and now always calls into a Triton kernel from forward_cuda. The codebase compiles models with torch.compile(..., fullgraph=True) (see atom/model_engine/model_runner.py), so a non-traceable Python/Triton call here can cause Dynamo graph breaks or fullgraph compilation failures. Consider keeping a compile-time fallback to forward_native (or registering a proper torch custom op) when torch.compiler.is_compiling() is true.
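A minimal sketch of the kind of guard described above, reusing the method names from the excerpt (not the PR's actual code):

```python
def forward_cuda(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    # While Dynamo is tracing (e.g. under torch.compile(..., fullgraph=True)),
    # stay on the traceable PyTorch path instead of launching Triton directly.
    if torch.compiler.is_compiling():
        return self.forward_native(x, residual)
    from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton

    return gemma_rmsnorm_triton(x, self.weight.data, self.variance_epsilon, residual)
```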
```diff
 if self.layer_type == "linear_attention":
-    self.linear_attn(
+    hidden_states = self.linear_attn(
         hidden_states=(
             hidden_bf16 if hidden_bf16 is not None else hidden_states
         ),
-        output=self_attention_output,
         x_fp8=hidden_states if x_scale is not None else None,
         x_scale=x_scale,
     )
```
Qwen3NextDecoderLayer.forward now calls self.linear_attn(...) expecting a returned Tensor. However, Qwen3_5DecoderLayer reuses this forward implementation while wiring self.linear_attn = Qwen3_5GatedDeltaNet, whose forward signature still requires an output argument (see atom/models/qwen3_5.py). This will raise a TypeError for Qwen3.5 linear-attention layers unless Qwen3_5GatedDeltaNet.forward (or the Qwen3.5 decoder forward path) is updated to match the new return-based API.
```
Replaces the torch.compile'd GemmaRMSNorm.forward_static with a single Triton
kernel that fuses residual add and RMS normalization with the Gemma-style
weight offset: out = rmsnorm(x + residual) * (1 + w).

Based on aiter's _fused_add_rmsnorm_kernel with the (g + 1.0) Gemma offset.

Two custom ops are registered so that torch.compile (Dynamo) can trace through
them without falling back to the PyTorch implementation that contains
x.float() / x.to(orig_dtype) dtype-cast copy kernels:
- ``fused_gemma_rmsnorm`` (no residual)
- ``fused_gemma_add_rmsnorm`` (with residual add)
```
The module docstring says “Two custom ops are registered so that torch.compile (Dynamo) can trace through them…”, but this file only defines a Triton kernel + Python launcher and does not register any torch.library custom ops. Either register the stated ops (and update callers to use them) or update the docstring to avoid implying torch.compile support that isn’t actually implemented.
Suggested change:
```diff
-Replaces the torch.compile'd GemmaRMSNorm.forward_static with a single Triton
-kernel that fuses residual add and RMS normalization with the Gemma-style
-weight offset: out = rmsnorm(x + residual) * (1 + w).
-Based on aiter's _fused_add_rmsnorm_kernel with the (g + 1.0) Gemma offset.
-Two custom ops are registered so that torch.compile (Dynamo) can trace through
-them without falling back to the PyTorch implementation that contains
-x.float() / x.to(orig_dtype) dtype-cast copy kernels:
-- ``fused_gemma_rmsnorm`` (no residual)
-- ``fused_gemma_add_rmsnorm`` (with residual add)
+Defines a Triton kernel and Python launcher that fuse residual add and RMS
+normalization with the Gemma-style weight offset:
+out = rmsnorm(x + residual) * (1 + w).
+Based on aiter's _fused_add_rmsnorm_kernel with the (g + 1.0) Gemma offset.
+This module does not register `torch.library` custom ops; it provides the
+Triton implementation and launcher directly.
```
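If the registration route were taken instead, a minimal sketch using `torch.library.custom_op` might look like the following (the op name, wrapper, and fake-tensor registration are illustrative assumptions, not part of this PR):

```python
import torch
from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton


@torch.library.custom_op("atom::fused_gemma_rmsnorm", mutates_args=())
def fused_gemma_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Dynamo sees one opaque op; the Triton launch stays hidden inside it.
    out, _ = gemma_rmsnorm_triton(x, weight, eps, None)
    return out


@fused_gemma_rmsnorm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Shape/dtype-only implementation so FakeTensor tracing can go through the op.
    return torch.empty_like(x)
```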
```python
        residual_out = x  # dummy, won't be written

    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    NUM_PRGMS = min(n_rows, 304)  # MI355X has 304 CUs
```
NUM_PRGMS = min(n_rows, 304) hard-codes MI355X’s CU count, which can under/over-subscribe on other AMD SKUs (and non-AMD platforms) and makes performance tuning device-specific. Consider deriving the program count from torch.cuda.get_device_properties(...).multi_processor_count (similar to atom/model_ops/fla_ops/layernorm_guard.py) or making this a heuristic/configurable parameter.
Suggested change:
```diff
-    NUM_PRGMS = min(n_rows, 304)  # MI355X has 304 CUs
+    try:
+        multi_processor_count = torch.cuda.get_device_properties(
+            x.device
+        ).multi_processor_count
+    except (AssertionError, AttributeError, RuntimeError, ValueError):
+        multi_processor_count = n_rows
+    NUM_PRGMS = min(n_rows, max(1, multi_processor_count))
```
Force-pushed from e772efa to 4317003.
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
Comments suppressed due to low confidence (1)
atom/models/qwen3_next.py:984
- There are now commented-out allocations/assignments for `self_attention_output` and `hidden_states = self_attention_output`. Since this code is no longer used, it would be clearer to remove these commented blocks (or replace them with a short explanatory comment about why the previous in-place buffer path was removed) to avoid confusion during future maintenance.
```python
if self.layer_type == "linear_attention":
    hidden_states = self.linear_attn(
        hidden_states=(
            hidden_bf16 if hidden_bf16 is not None else hidden_states
        ),
        x_fp8=hidden_states if x_scale is not None else None,
        x_scale=x_scale,
    )
elif self.layer_type == "full_attention":
    hidden_states = self.self_attn(
        hidden_states=hidden_states,
        positions=positions,
        x_scale=x_scale,
    )
else:
    raise ValueError("Invalid layer_type")
if self.layer_scale:
    if len(hidden_states.shape) == 2:
        hidden_states = hidden_states * (
```
```diff
 def forward_cuda(
     self,
     x: torch.Tensor,
     residual: torch.Tensor | None = None,
 ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-    if torch.compiler.is_compiling():
-        return self.forward_native(x, residual)
-    if not getattr(self, "_is_compiled", False):
-        self.forward_static = torch.compile(self.forward_static)  # type: ignore
-        self._is_compiled = True
-        return self.forward_native(x, residual)
+    from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton
+
+    return gemma_rmsnorm_triton(
+        x, self.weight.data, self.variance_epsilon, residual
```
GemmaRMSNorm.forward() always dispatches to forward_cuda(). With this change, forward_cuda() unconditionally imports/launches a Triton kernel, which will fail for CPU tensors (where previously it fell back to forward_native). Consider guarding on x.is_cuda (and possibly dtype/device constraints) and falling back to forward_native when Triton can't run.
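A hypothetical helper illustrating the suggested fallback (the name and placement are assumptions, not the actual patch):

```python
def _gemma_rmsnorm_dispatch(norm, x, residual=None):
    """Run the Triton path only when the kernel can actually execute."""
    if not x.is_cuda:
        # CPU (or otherwise unsupported) tensors keep the eager reference behaviour.
        return norm.forward_native(x, residual)
    from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton

    return gemma_rmsnorm_triton(x, norm.weight.data, norm.variance_epsilon, residual)
```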
```python
return gemma_rmsnorm_triton(
    x, self.weight.data, self.variance_epsilon, residual
)
```
GemmaRMSNorm.forward_cuda now returns the raw result of gemma_rmsnorm_triton, which is always a 2-tuple (out, residual_out_or_None). This breaks the existing contract of GemmaRMSNorm.forward_static/forward_native, which returns a Tensor when residual is None (and a 2-tuple only when residual is provided). Please unwrap the return so callers still receive a Tensor in the no-residual case (e.g., return just out when residual is None).
Suggested change:
```diff
-return gemma_rmsnorm_triton(
-    x, self.weight.data, self.variance_epsilon, residual
-)
+out, residual_out = gemma_rmsnorm_triton(
+    x, self.weight.data, self.variance_epsilon, residual
+)
+return out if residual is None else (out, residual_out)
```
```python
def gemma_rmsnorm_triton(x, weight, eps, residual):
    """Launch the Triton kernel. Returns (out, residual_out | None)."""
    ori_shape = x.shape
    x = x.view(-1, ori_shape[-1])
    n_rows, n_cols = x.shape

    out = torch.empty_like(x)

    has_residual = residual is not None
    if has_residual:
        residual = residual.view(-1, ori_shape[-1])
        residual_out = torch.empty_like(residual)
    else:
        residual_out = x  # dummy, won't be written

    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    NUM_PRGMS = min(n_rows, 304)  # MI355X has 304 CUs

    _gemma_rmsnorm_kernel[(NUM_PRGMS,)](
        x,
        out,
        residual if has_residual else x,  # dummy for res_in when no residual
        residual_out,
        weight,
        x.stride(0),
        out.stride(0),
        n_rows,
```
The Triton kernel indexes columns as base + col_offsets and only takes a single row_stride argument, which implicitly assumes the last dimension is contiguous (stride(-1) == 1) for x, residual, and weight. gemma_rmsnorm_triton currently does x = x.view(...) without checking contiguity/stride, so non-contiguous inputs could either error (view) or produce incorrect results. Add an explicit contiguity/stride check (and fallback to forward_native or call .contiguous()/.reshape() as appropriate) before launching the kernel.
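One way to add that check before the launch is a small normalization helper like the sketch below (the helper name is hypothetical; `reshape` copies only when a view is impossible):

```python
import torch


def _as_contiguous_2d(t: torch.Tensor, last_dim: int) -> torch.Tensor:
    """Pre-launch helper: make the last dimension contiguous so the kernel's
    single row-stride indexing is valid, then flatten to (n_rows, n_cols)."""
    if t.stride(-1) != 1:
        t = t.contiguous()
    return t.reshape(-1, last_dim)
```

`gemma_rmsnorm_triton` could call this for `x` (and for `residual` when present) in place of the bare `.view(...)`, or fall back to `forward_native` if the extra copy is undesirable.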
```python
    x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")

    if HAS_RESIDUAL:
        res_in_ptrs = res_in_ptr + row_idx * input_row_stride + col_offsets
        res_in_ptrs = tl.multiple_of(res_in_ptrs, (16,))
        res = tl.load(res_in_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
        x = x + res
        # Store residual_out (needed by next layer)
        res_out_ptrs = res_out_ptr + row_idx * input_row_stride + col_offsets
        res_out_ptrs = tl.multiple_of(res_out_ptrs, (16,))
        tl.store(res_out_ptrs, x.to(res_out_ptr.dtype.element_ty), mask=mask)

    x = x.to(tl.float32)
```
The fused path currently adds x + residual in the input dtype and only casts to fp32 afterwards (x = x + res then x = x.to(tl.float32)). This is not equivalent to GemmaRMSNorm.forward_static, which upcasts to fp32 before the add when orig_dtype == torch.float16. To match existing numerics, consider casting to fp32 before the add for fp16 inputs (or always performing the add in fp32).
Suggested change:
```diff
-    x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
+    x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(
+        tl.float32
+    )
     if HAS_RESIDUAL:
         res_in_ptrs = res_in_ptr + row_idx * input_row_stride + col_offsets
         res_in_ptrs = tl.multiple_of(res_in_ptrs, (16,))
-        res = tl.load(res_in_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
+        res = tl.load(
+            res_in_ptrs, mask=mask, other=0.0, cache_modifier=".cg"
+        ).to(tl.float32)
         x = x + res
         # Store residual_out (needed by next layer)
         res_out_ptrs = res_out_ptr + row_idx * input_row_stride + col_offsets
         res_out_ptrs = tl.multiple_of(res_out_ptrs, (16,))
         tl.store(res_out_ptrs, x.to(res_out_ptr.dtype.element_ty), mask=mask)
-    x = x.to(tl.float32)
```
Signed-off-by: ganyi <ygan@amd.com>
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
```python
    has_residual = residual is not None
    if has_residual:
        residual = residual.view(-1, ori_shape[-1])
        residual_out = torch.empty_like(residual)
```
gemma_rmsnorm_triton allocates residual_out = torch.empty_like(residual), which forces residual_out.dtype == residual.dtype. This diverges from GemmaRMSNorm.forward_static, which returns a float32 residual when the original dtype is float16 (due to the x.float() + residual.float() path). To keep behavior consistent, consider allocating/storing residual_out as float32 when x.dtype == torch.float16 (and keeping bf16/fp32 behavior unchanged).
Suggested change:
```diff
-        residual_out = torch.empty_like(residual)
+        residual_out_dtype = torch.float32 if x.dtype == torch.float16 else residual.dtype
+        residual_out = torch.empty(
+            residual.shape,
+            device=residual.device,
+            dtype=residual_out_dtype,
+        )
```
| """ | ||
| Forward pass with three parts: | ||
| 1. Input projection | ||
| 2. Core attention (custom op) |
The forward docstring describes only two parts (input projection + core attention), but the function still performs and returns an output projection (“Part 3”). Please update the docstring (or the section headers) so it matches the actual computation.
Suggested change:
```diff
 2. Core attention (custom op)
+3. Output projection
```
Signed-off-by: ganyi <ygan@amd.com>
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
```diff
 core_attn_out, maybe_scale = self.norm(core_attn_out, z)
-output[:num_tokens] = self.out_proj(core_attn_out, x_scale=maybe_scale)
+output = self.out_proj(core_attn_out, x_scale=maybe_scale)
+return output
```
num_tokens is now unused after switching from output[:num_tokens] = ... to returning the projected tensor directly. Please remove the dead variable (and any related logic) to avoid confusion and keep the forward path minimal.
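The tail of the forward could then read roughly as follows (a sketch of the cleanup, using the names from the excerpt above):

```python
core_attn_out, maybe_scale = self.norm(core_attn_out, z)
# No preallocated `output` buffer any more, so no `num_tokens` slicing is needed.
return self.out_proj(core_attn_out, x_scale=maybe_scale)
```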
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist