From 47f4db5a29f0c1885ca69c7d58ef1c332a4289b7 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Wed, 29 Oct 2025 16:45:56 +0000 Subject: [PATCH 1/7] implemented the use of kernel repr helper to standardize kernel metadata representation --- .../_triton_kernels/extend_attention.py | 30 +++++- .../_triton_kernels/moe_align_block_size.py | 48 ++++++++-- aiter/ops/triton/_triton_kernels/moe_op.py | 92 +++++++++++++++++- .../ops/triton/_triton_kernels/moe_op_e2e.py | 51 +++++++++- .../ops/triton/_triton_kernels/moe_op_gelu.py | 47 +++++++++- .../triton/_triton_kernels/moe_op_mxfp4.py | 23 ++++- .../moe_op_mxfp4_silu_fused.py | 22 ++++- .../_triton_kernels/moe_op_silu_fused.py | 94 ++++++++++++++++++- .../moe_routing_sigmoid_top1_fused.py | 15 ++- .../_triton_kernels/prefill_attention.py | 16 +++- aiter/ops/triton/_triton_kernels/softmax.py | 11 ++- aiter/ops/triton/_triton_kernels/topk.py | 40 +++++++- 12 files changed, 460 insertions(+), 29 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/extend_attention.py b/aiter/ops/triton/_triton_kernels/extend_attention.py index e5f7e778a8..a4524fef9d 100644 --- a/aiter/ops/triton/_triton_kernels/extend_attention.py +++ b/aiter/ops/triton/_triton_kernels/extend_attention.py @@ -31,9 +31,33 @@ from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH from ..utils.device_info import get_num_xcds - - -@triton.jit +from ..utils._triton.kernel_repr import make_kernel_repr + + +_fwd_kernel_extend_repr = make_kernel_repr( + "_fwd_kernel", + [ + "logit_cap", + "Lq", + "Lv", + "BLOCK_DMODEL", + "BLOCK_DPE", + "BLOCK_DV", + "BLOCK_M", + "BLOCK_N", + "USE_CUSTOM_MASK", + "IS_CAUSAL", + "SKIP_PREFIX_CUSTOM_MASK", + "STORE_TRANSPOSE", + "NUM_Q_HEADS", + "NUM_BLOCKS", + "BATCH", + "NUM_XCDS", + ], +) + + +@triton.jit(repr=_fwd_kernel_extend_repr) def _fwd_kernel( Q_Extend, K_Extend, diff --git a/aiter/ops/triton/_triton_kernels/moe_align_block_size.py b/aiter/ops/triton/_triton_kernels/moe_align_block_size.py index 736c608e9d..953df47aa7 100644 --- a/aiter/ops/triton/_triton_kernels/moe_align_block_size.py +++ b/aiter/ops/triton/_triton_kernels/moe_align_block_size.py @@ -3,9 +3,45 @@ import triton import triton.language as tl - - -@triton.jit +from ..utils._triton.kernel_repr import make_kernel_repr + + +_moe_align_block_size_stage1_repr = make_kernel_repr( + "_moe_align_block_size_stage1_kernel", + [ + "num_experts", + "numel", + "tokens_per_thread", + ], +) + +_moe_align_block_size_stage2_repr = make_kernel_repr( + "_moe_align_block_size_stage2_kernel", + [ + "num_experts", + ], +) + +_moe_align_block_size_stage3_repr = make_kernel_repr( + "_moe_align_block_size_stage3_kernel", + [ + "num_experts", + "block_size", + ], +) + +_moe_align_block_size_stage4_repr = make_kernel_repr( + "_moe_align_block_size_stage4_kernel", + [ + "num_experts", + "block_size", + "numel", + "tokens_per_thread", + ], +) + + +@triton.jit(repr=_moe_align_block_size_stage1_repr) def _moe_align_block_size_stage1_kernel( topk_ids_ptr, tokens_cnts_ptr, @@ -26,7 +62,7 @@ def _moe_align_block_size_stage1_kernel( tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) -@triton.jit +@triton.jit(repr=_moe_align_block_size_stage2_repr) def _moe_align_block_size_stage2_kernel( tokens_cnts_ptr, num_experts: tl.constexpr, @@ -40,7 +76,7 @@ def _moe_align_block_size_stage2_kernel( tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) -@triton.jit +@triton.jit(repr=_moe_align_block_size_stage3_repr) def _moe_align_block_size_stage3_kernel( 
total_tokens_post_pad_ptr, tokens_cnts_ptr, @@ -57,7 +93,7 @@ def _moe_align_block_size_stage3_kernel( tl.store(total_tokens_post_pad_ptr, last_cumsum) -@triton.jit +@triton.jit(repr=_moe_align_block_size_stage4_repr) def _moe_align_block_size_stage4_kernel( topk_ids_ptr, sorted_token_ids_ptr, diff --git a/aiter/ops/triton/_triton_kernels/moe_op.py b/aiter/ops/triton/_triton_kernels/moe_op.py index 9b683a5e7a..2da8805a41 100644 --- a/aiter/ops/triton/_triton_kernels/moe_op.py +++ b/aiter/ops/triton/_triton_kernels/moe_op.py @@ -5,18 +5,102 @@ import triton.language as tl from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton.moe_common import _write_zeros_to_output +from ..utils._triton.kernel_repr import make_kernel_repr # Source: # MoE Kernel adapted from VLLM +_fused_moe_kernel_gptq_awq_repr = make_kernel_repr( + "_fused_moe_kernel_gptq_awq", + [ + "N", + "K", + "group_size", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "MUL_ROUTED_WEIGHT", + "top_k", + "compute_type", + "has_zp", + "use_int4_w4a16", + "use_int8_w8a16", + "NUM_XCDS", + ], +) + +_fused_moe_persistent_kernel_gptq_awq_repr = make_kernel_repr( + "_fused_moe_persistent_kernel_gptq_awq", + [ + "N", + "K", + "group_size", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "NUM_SMS", + "MUL_ROUTED_WEIGHT", + "top_k", + "compute_type", + "has_zp", + "use_int4_w4a16", + "use_int8_w8a16", + "NUM_XCDS", + ], +) + +_fused_moe_kernel_repr = make_kernel_repr( + "_fused_moe_kernel", + [ + "group_n", + "group_k", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "MUL_ROUTED_WEIGHT", + "top_k", + "compute_type", + "use_fp8_w8a8", + "use_int8_w8a16", + "NUM_XCDS", + ], +) + +_fused_moe_persistent_kernel_repr = make_kernel_repr( + "_fused_moe_persistent_kernel", + [ + "group_n", + "group_k", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "NUM_SMS", + "MUL_ROUTED_WEIGHT", + "top_k", + "compute_type", + "use_fp8_w8a8", + "use_int8_w8a16", + "NUM_XCDS", + ], +) + + @triton.heuristics( { "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, } ) -@triton.jit +@triton.jit(repr=_fused_moe_kernel_gptq_awq_repr) def _fused_moe_kernel_gptq_awq( # Pointers to matrices a_ptr, @@ -254,7 +338,7 @@ def _fused_moe_kernel_gptq_awq( "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, } ) -@triton.jit +@triton.jit(repr=_fused_moe_persistent_kernel_gptq_awq_repr) def _fused_moe_persistent_kernel_gptq_awq( # Pointers to matrices a_ptr, @@ -483,7 +567,7 @@ def _fused_moe_persistent_kernel_gptq_awq( "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, } ) -@triton.jit +@triton.jit(repr=_fused_moe_kernel_repr) def _fused_moe_kernel( # Pointers to matrices a_ptr, @@ -691,7 +775,7 @@ def _fused_moe_kernel( "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, } ) -@triton.jit +@triton.jit(repr=_fused_moe_persistent_kernel_repr) def _fused_moe_persistent_kernel( # Pointers to matrices a_ptr, diff --git a/aiter/ops/triton/_triton_kernels/moe_op_e2e.py b/aiter/ops/triton/_triton_kernels/moe_op_e2e.py index 8b32302590..d8e184c048 100644 --- a/aiter/ops/triton/_triton_kernels/moe_op_e2e.py +++ b/aiter/ops/triton/_triton_kernels/moe_op_e2e.py @@ -5,18 +5,65 @@ import triton.language as tl from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd +from ..utils._triton.kernel_repr import make_kernel_repr # Source: # MoE Kernel adapted from VLLM 
+_e2e_moe_kernel_repr = make_kernel_repr( + "e2e_moe_kernel", + [ + "top_k", + "EM", + "N", + "K", + "EVEN_K", + "MUL_ROUTED_WEIGHT", + "use_fp8_w8a8", + "use_int8_w8a16", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K1", + "BLOCK_SIZE_K2", + "GROUP_SIZE_M", + "GRID_MN", + "atomic_num_stages", + "dtype", + "NUM_XCDS", + ], +) + +_e2e_moe_persistent_kernel_repr = make_kernel_repr( + "e2e_moe_persistent_kernel", + [ + "top_k", + "EM", + "N", + "K", + "EVEN_K", + "EVEN_N", + "MUL_ROUTED_WEIGHT", + "use_fp8_w8a8", + "use_int8_w8a16", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N1", + "BLOCK_SIZE_N2", + "BLOCK_SIZE_K1", + "BLOCK_SIZE_K2", + "NUM_SMS", + "NUM_XCDS", + ], +) + + @triton.heuristics( { "GRID_MN": lambda args: triton.cdiv(args["EM"], args["BLOCK_SIZE_M"]) * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]) } ) -@triton.jit +@triton.jit(repr=_e2e_moe_kernel_repr) def e2e_moe_kernel( A, W1, @@ -316,7 +363,7 @@ def e2e_moe_kernel( # tl.store(out_ptrs + k * BLOCK_SIZE_K2, out, mask=c_mask) -@triton.jit +@triton.jit(repr=_e2e_moe_persistent_kernel_repr) def e2e_moe_persistent_kernel( A, W1, diff --git a/aiter/ops/triton/_triton_kernels/moe_op_gelu.py b/aiter/ops/triton/_triton_kernels/moe_op_gelu.py index e9c94b3323..c92ff656da 100644 --- a/aiter/ops/triton/_triton_kernels/moe_op_gelu.py +++ b/aiter/ops/triton/_triton_kernels/moe_op_gelu.py @@ -8,18 +8,61 @@ from .activation import _gelu_tanh from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton.moe_common import _write_zeros_to_output +from ..utils._triton.kernel_repr import make_kernel_repr # Source: # MoE Kernel adapted from VLLM +_fused_moe_kernel_gelu_repr = make_kernel_repr( + "_fused_moe_kernel", + [ + "BLOCK_SCALE", + "group_n", + "group_k", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "MUL_ROUTED_WEIGHT", + "top_k", + "compute_type", + "use_fp8_w8a8", + "use_int8_w8a16", + "NUM_XCDS", + ], +) + +_fused_moe_persistent_kernel_gelu_repr = make_kernel_repr( + "_fused_moe_persistent_kernel", + [ + "BLOCK_SCALE", + "group_n", + "group_k", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "NUM_SMS", + "MUL_ROUTED_WEIGHT", + "top_k", + "compute_type", + "use_fp8_w8a8", + "use_int8_w8a16", + "NUM_XCDS", + ], +) + + @triton.heuristics( { "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, } ) -@triton.jit +@triton.jit(repr=_fused_moe_kernel_gelu_repr) def _fused_moe_kernel( # Pointers to matrices a_ptr, @@ -238,7 +281,7 @@ def _fused_moe_kernel( "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, } ) -@triton.jit +@triton.jit(repr=_fused_moe_persistent_kernel_gelu_repr) def _fused_moe_persistent_kernel( # Pointers to matrices a_ptr, diff --git a/aiter/ops/triton/_triton_kernels/moe_op_mxfp4.py b/aiter/ops/triton/_triton_kernels/moe_op_mxfp4.py index 9229f186eb..25850709d6 100644 --- a/aiter/ops/triton/_triton_kernels/moe_op_mxfp4.py +++ b/aiter/ops/triton/_triton_kernels/moe_op_mxfp4.py @@ -5,6 +5,7 @@ import triton.language as tl from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton.moe_common import _write_zeros_to_output +from ..utils._triton.kernel_repr import make_kernel_repr def get_scaled_dot_format_string(dtype: tl.dtype): @@ -18,12 +19,32 @@ def get_scaled_dot_format_string(dtype: tl.dtype): return mapping[dtype] +_fused_moe_kernel_mxfp4_repr = make_kernel_repr( + "_fused_moe_kernel_mxfp4", + [ + "A_DTYPE_FORMAT", + "B_DTYPE_FORMAT", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + 
"BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "MUL_ROUTED_WEIGHT", + "top_k", + "compute_type", + "SWIZZLE_MX_A", + "SWIZZLE_MX_B", + "NUM_XCDS", + ], +) + + @triton.heuristics( { "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, } ) -@triton.jit +@triton.jit(repr=_fused_moe_kernel_mxfp4_repr) def _fused_moe_kernel_mxfp4( # Pointers to matrices a_ptr, diff --git a/aiter/ops/triton/_triton_kernels/moe_op_mxfp4_silu_fused.py b/aiter/ops/triton/_triton_kernels/moe_op_mxfp4_silu_fused.py index 2915737048..266621fa49 100644 --- a/aiter/ops/triton/_triton_kernels/moe_op_mxfp4_silu_fused.py +++ b/aiter/ops/triton/_triton_kernels/moe_op_mxfp4_silu_fused.py @@ -6,6 +6,7 @@ from .activation import _silu_exp2 from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton.moe_common import _write_zeros_to_output +from ..utils._triton.kernel_repr import make_kernel_repr def get_scaled_dot_format_string(dtype: tl.dtype): @@ -19,12 +20,31 @@ def get_scaled_dot_format_string(dtype: tl.dtype): return mapping[dtype] +_fused_moe_kernel_mxfp4_silu_repr = make_kernel_repr( + "_fused_moe_kernel_mxfp4_silu", + [ + "A_DTYPE_FORMAT", + "B_DTYPE_FORMAT", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "MUL_ROUTED_WEIGHT", + "top_k", + "compute_type", + "SWIZZLE_MX_A", + "SWIZZLE_MX_B", + ], +) + + @triton.heuristics( { "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, } ) -@triton.jit +@triton.jit(repr=_fused_moe_kernel_mxfp4_silu_repr) def _fused_moe_kernel_mxfp4_silu( # Pointers to matrices a_ptr, diff --git a/aiter/ops/triton/_triton_kernels/moe_op_silu_fused.py b/aiter/ops/triton/_triton_kernels/moe_op_silu_fused.py index 4fe9620a6f..a37ff2e530 100644 --- a/aiter/ops/triton/_triton_kernels/moe_op_silu_fused.py +++ b/aiter/ops/triton/_triton_kernels/moe_op_silu_fused.py @@ -7,18 +7,104 @@ from .activation import _silu_exp2 from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton.moe_common import _write_zeros_to_output +from ..utils._triton.kernel_repr import make_kernel_repr # Source: # MoE Kernel adapted from VLLM +_fused_moe_silu_kernel_gptq_awq_repr = make_kernel_repr( + "_fused_moe_silu_kernel_gptq_awq", + [ + "N", + "K", + "block_k_diviable", + "group_size", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "MUL_ROUTED_WEIGHT", + "top_k", + "compute_type", + "has_zp", + "use_int4_w4a16", + "use_int8_w8a16", + "NUM_XCDS", + ], +) + +_fused_moe_persistent_silu_kernel_gptq_awq_repr = make_kernel_repr( + "_fused_moe_persistent_silu_kernel_gptq_awq", + [ + "N", + "K", + "block_k_diviable", + "group_size", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "NUM_SMS", + "MUL_ROUTED_WEIGHT", + "top_k", + "compute_type", + "has_zp", + "use_int4_w4a16", + "use_int8_w8a16", + "NUM_XCDS", + ], +) + +_fused_moe_silu_kernel_repr = make_kernel_repr( + "_fused_moe_silu_kernel", + [ + "group_n", + "group_k", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "MUL_ROUTED_WEIGHT", + "top_k", + "compute_type", + "use_fp8_w8a8", + "use_int8_w8a16", + "NUM_XCDS", + ], +) + +_fused_moe_persistent_silu_kernel_repr = make_kernel_repr( + "_fused_moe_persistent_silu_kernel", + [ + "group_n", + "group_k", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "NUM_SMS", + "MUL_ROUTED_WEIGHT", + "top_k", + "compute_type", + "use_fp8_w8a8", + "use_int8_w8a16", + "NUM_XCDS", + ], +) + + 
@triton.heuristics( { "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, } ) -@triton.jit +@triton.jit(repr=_fused_moe_silu_kernel_gptq_awq_repr) def _fused_moe_silu_kernel_gptq_awq( # Pointers to matrices a_ptr, @@ -279,7 +365,7 @@ def _fused_moe_silu_kernel_gptq_awq( "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, } ) -@triton.jit +@triton.jit(repr=_fused_moe_persistent_silu_kernel_gptq_awq_repr) def _fused_moe_persistent_silu_kernel_gptq_awq( # Pointers to matrices a_ptr, @@ -526,7 +612,7 @@ def _fused_moe_persistent_silu_kernel_gptq_awq( "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, } ) -@triton.jit +@triton.jit(repr=_fused_moe_silu_kernel_repr) def _fused_moe_silu_kernel( # Pointers to matrices a_ptr, @@ -757,7 +843,7 @@ def _fused_moe_silu_kernel( "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, } ) -@triton.jit +@triton.jit(repr=_fused_moe_persistent_silu_kernel_repr) def _fused_moe_persistent_silu_kernel( # Pointers to matrices a_ptr, diff --git a/aiter/ops/triton/_triton_kernels/moe_routing_sigmoid_top1_fused.py b/aiter/ops/triton/_triton_kernels/moe_routing_sigmoid_top1_fused.py index 31cdb76771..90027f4f81 100644 --- a/aiter/ops/triton/_triton_kernels/moe_routing_sigmoid_top1_fused.py +++ b/aiter/ops/triton/_triton_kernels/moe_routing_sigmoid_top1_fused.py @@ -8,9 +8,22 @@ import triton.language as tl from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH +from ..utils._triton.kernel_repr import make_kernel_repr -@triton.jit +_routing_sigmoid_top1_repr = make_kernel_repr( + "_routing_sigmoid_top1_kernel", + [ + "BLOCK_M", + "BLOCK_K", + "BLOCK_N", + "TOPK", + "FUSED_SHARED_EXPERTS", + ], +) + + +@triton.jit(repr=_routing_sigmoid_top1_repr) def _routing_sigmoid_top1_kernel( X_ptr, W_ptr, diff --git a/aiter/ops/triton/_triton_kernels/prefill_attention.py b/aiter/ops/triton/_triton_kernels/prefill_attention.py index 9d422dbcfa..ef45e9f8d6 100644 --- a/aiter/ops/triton/_triton_kernels/prefill_attention.py +++ b/aiter/ops/triton/_triton_kernels/prefill_attention.py @@ -24,9 +24,23 @@ # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 import triton import triton.language as tl +from ..utils._triton.kernel_repr import make_kernel_repr -@triton.jit +_fwd_kernel_repr = make_kernel_repr( + "_fwd_kernel", + [ + "kv_group_num", + "BLOCK_M", + "BLOCK_DMODEL", + "BLOCK_N", + "IS_CAUSAL", + "Lk", + ], +) + + +@triton.jit(repr=_fwd_kernel_repr) def _fwd_kernel( Q, K, diff --git a/aiter/ops/triton/_triton_kernels/softmax.py b/aiter/ops/triton/_triton_kernels/softmax.py index 5ee7f0f08c..5cd0860084 100644 --- a/aiter/ops/triton/_triton_kernels/softmax.py +++ b/aiter/ops/triton/_triton_kernels/softmax.py @@ -1,8 +1,17 @@ import triton import triton.language as tl +from ..utils._triton.kernel_repr import make_kernel_repr -@triton.jit +_softmax_kernel_online_repr = make_kernel_repr( + "_softmax_kernel_online", + [ + "BLOCK_SIZE", + ], +) + + +@triton.jit(repr=_softmax_kernel_online_repr) def _softmax_kernel_online( output_ptr, input_ptr, diff --git a/aiter/ops/triton/_triton_kernels/topk.py b/aiter/ops/triton/_triton_kernels/topk.py index 17935c4c60..f01203e9db 100644 --- a/aiter/ops/triton/_triton_kernels/topk.py +++ b/aiter/ops/triton/_triton_kernels/topk.py @@ -10,10 +10,44 @@ import triton.language as tl import triton.language.core as core from triton.language.standard import _log2, zeros_like +from 
..utils._triton.kernel_repr import make_kernel_repr + + +_topk_kernel_repr = make_kernel_repr( + "_topk_kernel", + [ + "M", + "K", + "BLOCK", + "FILL_VALUE", + ], +) + +_topk_stage1_kernel_repr = make_kernel_repr( + "topk_stage1_kernel", + [ + "N", + "CHUNK_SIZE", + "DESCENDING", + "FILL_VALUE", + ], +) + +_topk_stage2_kernel_repr = make_kernel_repr( + "topk_stage2_kernel", + [ + "k", + "N", + "BLOCK_SIZE", + "DESCENDING", + "FILL_VALUE", + "MASK_INDEX_VAL", + ], +) # 1-STAGE KERNEL (tiny rows) -@triton.jit +@triton.jit(repr=_topk_kernel_repr) def _topk_kernel( X, OUT_V, @@ -53,7 +87,7 @@ def _topk_kernel( # 2-STAGE KERNEL (large rows) -@triton.jit +@triton.jit(repr=_topk_stage1_kernel_repr) def topk_stage1_kernel( y_ptr, index_ptr, @@ -211,7 +245,7 @@ def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr): return x, ids -@triton.jit +@triton.jit(repr=_topk_stage2_kernel_repr) def topk_stage2_kernel( y_ptr, index_ptr, From ce3580d7b64512ca9c0e1cacedf5294f14a7fe7b Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Wed, 29 Oct 2025 16:49:07 +0000 Subject: [PATCH 2/7] fix indentation error --- aiter/ops/triton/_triton_kernels/moe_align_block_size.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/triton/_triton_kernels/moe_align_block_size.py b/aiter/ops/triton/_triton_kernels/moe_align_block_size.py index 953df47aa7..6dda7aa628 100644 --- a/aiter/ops/triton/_triton_kernels/moe_align_block_size.py +++ b/aiter/ops/triton/_triton_kernels/moe_align_block_size.py @@ -19,7 +19,7 @@ "_moe_align_block_size_stage2_kernel", [ "num_experts", - ], + ], ) _moe_align_block_size_stage3_repr = make_kernel_repr( From a11129b3c895da460f86838c3acaa940c5c74c20 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Wed, 5 Nov 2025 20:25:13 +0000 Subject: [PATCH 3/7] Add Missing API documentation --- .../_triton_kernels/extend_attention.py | 4 +- aiter/ops/triton/extend_attention.py | 48 +++++++++++++++++-- aiter/ops/triton/moe_align_block_size.py | 18 ++++++- aiter/ops/triton/moe_op.py | 35 ++++++++++++-- aiter/ops/triton/moe_op_e2e.py | 28 ++++++++++- aiter/ops/triton/moe_op_gelu.py | 29 +++++++++-- aiter/ops/triton/moe_op_mxfp4.py | 28 ++++++++++- aiter/ops/triton/moe_op_mxfp4_silu_fused.py | 28 ++++++++++- aiter/ops/triton/moe_op_silu_fused.py | 35 ++++++++++++-- .../triton/moe_routing_sigmoid_top1_fused.py | 17 ++++++- aiter/ops/triton/prefill_attention.py | 20 ++++++-- aiter/ops/triton/softmax.py | 16 +++---- aiter/ops/triton/topk.py | 19 +++++++- 13 files changed, 279 insertions(+), 46 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/extend_attention.py b/aiter/ops/triton/_triton_kernels/extend_attention.py index a4524fef9d..262b80be4b 100644 --- a/aiter/ops/triton/_triton_kernels/extend_attention.py +++ b/aiter/ops/triton/_triton_kernels/extend_attention.py @@ -51,7 +51,7 @@ "STORE_TRANSPOSE", "NUM_Q_HEADS", "NUM_BLOCKS", - "BATCH", + # "BATCH", "NUM_XCDS", ], ) @@ -98,7 +98,7 @@ def _fwd_kernel( STORE_TRANSPOSE: tl.constexpr, NUM_Q_HEADS: tl.constexpr, NUM_BLOCKS: tl.constexpr, - BATCH: tl.constexpr, + # BATCH: tl.constexpr, NUM_XCDS: tl.constexpr, ): workgroup_id = tl.program_id(0) # workgroup index diff --git a/aiter/ops/triton/extend_attention.py b/aiter/ops/triton/extend_attention.py index 4cb163c06c..cb08519fd6 100644 --- a/aiter/ops/triton/extend_attention.py +++ b/aiter/ops/triton/extend_attention.py @@ -20,7 +20,7 @@ from typing import Optional import torch import triton -import triton.language as tl +# import triton.language as tl 
from aiter.ops.triton.prefill_attention import context_attention_fwd @@ -51,9 +51,30 @@ def extend_attention_fwd( config: Optional[dict[str, any]] = None, ): """ - q_extend, k_extend, v_extend, o_extend: contiguous tensors - - k_buffer, v_buffer: (prefix + extend) tensors in mem_manager + Attention for prefill with KV cache (extend phase). + Supports page size = 1 and variable-length sequences with prefix caching. + + Args: + q_extend (torch.Tensor): Query tensor for extend tokens with shape (total_extend_tokens, num_q_heads, head_dim). + k_extend (torch.Tensor): Key tensor for extend tokens with shape (total_extend_tokens, num_kv_heads, head_dim). + v_extend (torch.Tensor): Value tensor for extend tokens with shape (total_extend_tokens, num_kv_heads, head_dim). + o_extend (torch.Tensor): Output tensor for extend tokens with shape (total_extend_tokens, num_q_heads, head_dim). + k_buffer (torch.Tensor): KV cache buffer containing prefix + extend keys with shape (total_tokens, num_kv_heads, head_dim). + v_buffer (torch.Tensor): KV cache buffer containing prefix + extend values with shape (total_tokens, num_kv_heads, head_dim). + qo_indptr (torch.Tensor): Index pointer for query/output sequences with shape (batch_size + 1,). + kv_indptr (torch.Tensor): Index pointer for KV cache sequences with shape (batch_size + 1,). + kv_indices (torch.Tensor): Indices mapping into KV cache buffer. + custom_mask (Optional[torch.Tensor]): Custom attention mask tensor. + is_causal (bool): Apply causal masking. + mask_indptr (torch.Tensor): Index pointer for custom mask. + max_len_extend (int): Maximum extend sequence length in batch. + sm_scale (Optional[float]): Softmax scale, defaults to 1/sqrt(head_dim). + logit_cap (float): Cap logits to prevent overflow. + skip_prefix_custom_mask (bool): Skip custom mask for prefix portion. + config (Optional[dict]): Kernel tuning parameters (BLOCK_M, BLOCK_N). + + Returns: + None. Results written in-place to o_extend. """ _LOGGER.info( f"EXTEND_ATTENTION_FWD: q_extend={tuple(q_extend.shape)} k_extend={tuple(k_extend.shape)} v_extend={tuple(v_extend.shape)} " @@ -133,7 +154,7 @@ def extend_attention_fwd( STORE_TRANSPOSE=True, NUM_Q_HEADS=head_num, NUM_BLOCKS=num_blocks, - BATCH=batch_size, + # BATCH=batch_size, NUM_XCDS=get_num_xcds(), # num_warps=num_warps, # num_stages=num_stages, @@ -152,6 +173,23 @@ def redundant_attention( b_seq_len_prefix, max_len_in_batch, ): + """ + Alternative attention computation for extend tokens using full buffer reconstruction. + + Args: + q_extend (torch.Tensor): Query tensor for extend tokens with shape (total_extend_tokens, num_q_heads, head_dim). + o_extend (torch.Tensor): Output tensor for extend tokens with shape (total_extend_tokens, num_q_heads, head_dim). + k_buffer (torch.Tensor): KV cache buffer for keys with shape (total_tokens, num_kv_heads, head_dim). + v_buffer (torch.Tensor): KV cache buffer for values with shape (total_tokens, num_kv_heads, head_dim). + b_req_idx (torch.Tensor): Batch request indices with shape (batch_size,). + b_start_loc (torch.Tensor): Start locations for each sequence with shape (batch_size,). + b_seq_len (torch.Tensor): Total sequence lengths (prefix + extend) with shape (batch_size,). + b_seq_len_prefix (torch.Tensor): Prefix sequence lengths with shape (batch_size,). + max_len_in_batch (int): Maximum sequence length in the batch. + + Returns: + None. Results written in-place to o_extend. 
+ """ _LOGGER.info( f"REDUNDANT_ATTENTION: q_extend={tuple(q_extend.shape)} o_extend={tuple(o_extend.shape)} \ k_buffer={tuple(k_buffer.shape)} v_buffer={tuple(v_buffer.shape)}" diff --git a/aiter/ops/triton/moe_align_block_size.py b/aiter/ops/triton/moe_align_block_size.py index b605e83869..395e34fdaa 100644 --- a/aiter/ops/triton/moe_align_block_size.py +++ b/aiter/ops/triton/moe_align_block_size.py @@ -2,8 +2,8 @@ # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. import torch -import triton -import triton.language as tl +# import triton +# import triton.language as tl from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.moe_align_block_size import ( _moe_align_block_size_stage1_kernel, @@ -27,6 +27,20 @@ def moe_align_block_size_triton( expert_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor, ) -> None: + """ + Aligns and sorts MoE tokens by expert assignment with block-size padding for efficient computation. + + Args: + topk_ids (torch.Tensor): Top-k expert assignments per token with shape (num_tokens, topk). + num_experts (int): Total number of experts. + block_size (int): Block size for alignment and padding. + sorted_token_ids (torch.Tensor): Output tensor for sorted token indices. + expert_ids (torch.Tensor): Output tensor for expert ID per sorted token. + num_tokens_post_pad (torch.Tensor): Output tensor for total tokens after padding with shape (1,). + + Returns: + None. Results written in-place to sorted_token_ids, expert_ids, and num_tokens_post_pad. + """ _LOGGER.info( f"MOE_ALIGN_BLOCK_SIZE_TRITON: topk_ids={tuple(topk_ids.shape)} num_experts={num_experts} sorted_token_ids={tuple(sorted_token_ids.shape)} " + "block_size={block_size} expert_ids={tuple(expert_ids.shape)} num_tokens_post_pad={tuple(num_tokens_post_pad.shape)}" diff --git a/aiter/ops/triton/moe_op.py b/aiter/ops/triton/moe_op.py index d670e2096a..2c8c5fa8a1 100644 --- a/aiter/ops/triton/moe_op.py +++ b/aiter/ops/triton/moe_op.py @@ -72,7 +72,32 @@ def fused_moe( config: Optional[Dict[str, Any]] = None, ) -> None: """ - #TODO: Add doc + Fused Mixture-of-Experts (MoE) computation with top-k expert routing and optional quantization. + + Args: + A (torch.Tensor): Input activations with shape (num_tokens, hidden_dim). + B (torch.Tensor): Expert weights with shape (num_experts, hidden_dim, intermediate_dim). + C (torch.Tensor): Output tensor with shape (num_tokens, top_k, intermediate_dim). + A_scale (Optional[torch.Tensor]): Scale for A in FP8 mode with shape (1,) or (num_tokens, num_groups). + B_scale (Optional[torch.Tensor]): Scale for B with shape (num_experts, ...) for quantized modes. + B_zp (Optional[torch.Tensor]): Zero point for B in INT4/INT8 modes. + topk_weights (torch.Tensor): Routing weights for top-k experts with shape (num_tokens, top_k). + topk_ids (torch.Tensor): Top-k expert IDs per token with shape (num_tokens, top_k). + sorted_token_ids (torch.Tensor): Token IDs sorted by expert assignment. + expert_ids (torch.Tensor): Expert ID for each sorted token. + num_tokens_post_padded (torch.Tensor): Total tokens after block-size padding with shape (1,). + mul_routed_weight (bool): Multiply output by routing weights. + top_k (int): Number of experts per token. + compute_type (tl.dtype): Computation dtype for accumulation. + use_fp8_w8a8 (bool): Use FP8 quantization for weights and activations. + use_int8_w8a16 (bool): Use INT8 weights with higher precision activations. 
+ use_int4_w4a16 (bool): Use INT4 weights with higher precision activations. + block_shape (Optional[List[int]]): Block shape [block_n, block_k] for grouped quantization. + config (Optional[Dict[str, Any]]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M). + + Returns: + None. Results written in-place to C. """ _LOGGER.info( @@ -143,7 +168,7 @@ def fused_moe( num_tokens_post_padded, B.shape[1], A.shape[1], - EM, + EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -186,7 +211,7 @@ def fused_moe( num_tokens_post_padded, B.shape[1], A.shape[1], - EM, + EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -236,7 +261,7 @@ def fused_moe( num_tokens_post_padded, B.shape[1], A.shape[1] - _PADDING_SIZE, - sorted_token_ids.shape[0], + sorted_token_ids.shape[0], # (EM) it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -278,7 +303,7 @@ def fused_moe( num_tokens_post_padded, B.shape[1], A.shape[1] - _PADDING_SIZE, - EM, + EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), diff --git a/aiter/ops/triton/moe_op_e2e.py b/aiter/ops/triton/moe_op_e2e.py index 4af754ae6c..fda4aef65c 100644 --- a/aiter/ops/triton/moe_op_e2e.py +++ b/aiter/ops/triton/moe_op_e2e.py @@ -3,7 +3,7 @@ import torch import triton -import triton.language as tl +# import triton.language as tl from typing import Any, Dict, Optional from aiter.ops.triton.quant import dynamic_per_tensor_quant_fp8_i8 @@ -70,7 +70,31 @@ def e2e_moe( config: Optional[Dict[str, Any]] = None, ) -> None: """ - #TODO: Add doc + End-to-end fused MoE computation with up-projection (W1) and down-projection (W2) in single kernel. + + Args: + A (torch.Tensor): Input activations with shape (num_tokens, hidden_dim). + W1 (torch.Tensor): Up-projection expert weights with shape (num_experts, hidden_dim, intermediate_dim). + W2 (torch.Tensor): Down-projection expert weights with shape (num_experts, intermediate_dim, hidden_dim). + Intermediate (torch.Tensor): Intermediate buffer for up-projection results. + C (torch.Tensor): Output tensor with shape (num_tokens, hidden_dim). + A_scale (Optional[torch.Tensor]): Scale for A in FP8 mode. + W1_scale (Optional[torch.Tensor]): Scale for W1 in quantized modes. + W2_scale (Optional[torch.Tensor]): Scale for W2 in quantized modes. + topk_weights (torch.Tensor): Routing weights for top-k experts with shape (num_tokens, top_k). + sorted_token_ids (torch.Tensor): Token IDs sorted by expert assignment. + topk_ids: Top-k expert IDs per token with shape (num_tokens, top_k). + expert_ids (torch.Tensor): Expert ID for each sorted token. + num_tokens_post_padded (torch.Tensor): Total tokens after block-size padding with shape (1,). + mul_routed_weight (bool): Multiply output by routing weights. + top_k (int): Number of experts per token. + use_fp8_w8a8 (bool): Use FP8 quantization for weights and activations. + use_int8_w8a16 (bool): Use INT8 weights with higher precision activations. + config (Optional[Dict[str, Any]]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K1, BLOCK_SIZE_K2, GROUP_SIZE_M). + + Returns: + None. Results written in-place to C. 
""" _LOGGER.info( f"MOE_E2E: A={tuple(A.shape)} W1={tuple(W1.shape)} W2={tuple(W2.shape)} topk_weights={tuple(topk_weights.shape)}" diff --git a/aiter/ops/triton/moe_op_gelu.py b/aiter/ops/triton/moe_op_gelu.py index 146f50d091..2b398efac4 100644 --- a/aiter/ops/triton/moe_op_gelu.py +++ b/aiter/ops/triton/moe_op_gelu.py @@ -68,7 +68,30 @@ def fused_moe_gelu( config: Optional[Dict[str, Any]] = None, ) -> None: """ - #TODO: Add doc + Fused MoE computation with GELU activation and optional quantization. + + Args: + A (torch.Tensor): Input activations with shape (num_tokens, hidden_dim). + B (torch.Tensor): Expert weights with shape (num_experts, hidden_dim, intermediate_dim). + C (torch.Tensor): Output tensor with shape (num_tokens, top_k, intermediate_dim). + A_scale (Optional[torch.Tensor]): Scale for A in FP8 mode with shape (1,) or (num_tokens, num_groups). + B_scale (Optional[torch.Tensor]): Scale for B with shape (num_experts, ...) for quantized modes. + topk_weights (torch.Tensor): Routing weights for top-k experts with shape (num_tokens, top_k). + topk_ids (torch.Tensor): Top-k expert IDs per token with shape (num_tokens, top_k). + sorted_token_ids (torch.Tensor): Token IDs sorted by expert assignment. + expert_ids (torch.Tensor): Expert ID for each sorted token. + num_tokens_post_padded (torch.Tensor): Total tokens after block-size padding with shape (1,). + mul_routed_weight (bool): Multiply output by routing weights. + top_k (int): Number of experts per token. + compute_type (tl.dtype): Computation dtype for accumulation. + use_fp8_w8a8 (bool): Use FP8 quantization for weights and activations. + use_int8_w8a16 (bool): Use INT8 weights with higher precision activations. + block_shape (Optional[List[int]]): Block shape [block_n, block_k] for grouped quantization. + config (Optional[Dict[str, Any]]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M). + + Returns: + None. Results written in-place to C with GELU activation applied. """ _LOGGER.info( f"FUSED_MOE_GELU: A={tuple(A.shape)} B={tuple(B.shape)} C={tuple(C.shape)} topk_weights-{tuple(topk_weights.shape)}" @@ -131,7 +154,7 @@ def fused_moe_gelu( num_tokens_post_padded, B.shape[1], A.shape[1] - _PADDING_SIZE, - sorted_token_ids.shape[0], + sorted_token_ids.shape[0], # (EM) it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -174,7 +197,7 @@ def fused_moe_gelu( num_tokens_post_padded, B.shape[1], A.shape[1] - _PADDING_SIZE, - EM, + EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), diff --git a/aiter/ops/triton/moe_op_mxfp4.py b/aiter/ops/triton/moe_op_mxfp4.py index 929c69c3c4..cdea3738b6 100644 --- a/aiter/ops/triton/moe_op_mxfp4.py +++ b/aiter/ops/triton/moe_op_mxfp4.py @@ -37,7 +37,31 @@ def fused_moe_mxfp4( compute_type: tl.dtype, ) -> None: """ - #TODO: Add doc + Fused MoE computation with MXFP4 (microscale FP4) quantization. + + Args: + A (torch.Tensor): Input activations with shape (num_tokens, hidden_dim). FP4 or higher precision. + B (torch.Tensor): Expert weights with shape (num_experts, hidden_dim, intermediate_dim). MXFP4 format. + C (torch.Tensor): Output tensor with shape (num_tokens, top_k, intermediate_dim). + A_scale (torch.Tensor): Per-tensor or per-group scale for A. + B_scale (torch.Tensor): Per-group scale for B with shape (num_experts, ...). + A_mx_scale (torch.Tensor): Microscale (E8M0) scale for A if A is MXFP4. + B_mx_scale (torch.Tensor): Microscale (E8M0) scale for B. 
+ topk_weights (torch.Tensor): Routing weights for top-k experts with shape (num_tokens, top_k). + topk_ids (torch.Tensor): Top-k expert IDs per token with shape (num_tokens, top_k). + sorted_token_ids (torch.Tensor): Token IDs sorted by expert assignment. + expert_ids (torch.Tensor): Expert ID for each sorted token. + num_tokens_post_padded (torch.Tensor): Total tokens after block-size padding with shape (1,). + mul_routed_weight (bool): Multiply output by routing weights. + top_k (int): Number of experts per token. + swizzle_mx_a (bool): Enable swizzled layout for A microscales. + swizzle_mx_b (bool): Enable swizzled layout for B microscales. + config (Dict[str, Any]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M). + compute_type (tl.dtype): Computation dtype for accumulation. + + Returns: + None. Results written in-place to C. """ _LOGGER.info( f"MOE_OP_MXFP4: A={tuple(A.shape)} B={tuple(B.shape)} C={tuple(C.shape)} " @@ -92,7 +116,7 @@ def fused_moe_mxfp4( num_tokens_post_padded, B.shape[1], A.shape[1], - EM, + EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), diff --git a/aiter/ops/triton/moe_op_mxfp4_silu_fused.py b/aiter/ops/triton/moe_op_mxfp4_silu_fused.py index eae6a11fd0..8dd5918d69 100644 --- a/aiter/ops/triton/moe_op_mxfp4_silu_fused.py +++ b/aiter/ops/triton/moe_op_mxfp4_silu_fused.py @@ -36,7 +36,31 @@ def fused_moe_mxfp4_silu( compute_type: tl.dtype, ) -> None: """ - #TODO: Add doc + Fused MoE computation with MXFP4 quantization and SiLU activation. + + Args: + A (torch.Tensor): Input activations with shape (num_tokens, hidden_dim). FP4 or higher precision. + B (torch.Tensor): Expert weights with shape (num_experts, hidden_dim, intermediate_dim). MXFP4 format. + C (torch.Tensor): Output tensor with shape (num_tokens, intermediate_dim). + A_scale (torch.Tensor): Per-tensor or per-group scale for A. + B_scale (torch.Tensor): Per-group scale for B with shape (num_experts, ...). + A_mx_scale (torch.Tensor): Microscale (E8M0) scale for A if A is MXFP4. + B_mx_scale (torch.Tensor): Microscale (E8M0) scale for B. + topk_weights (torch.Tensor): Routing weights for top-k experts with shape (num_tokens, top_k). + topk_ids (torch.Tensor): Top-k expert IDs per token with shape (num_tokens, top_k). + sorted_token_ids (torch.Tensor): Token IDs sorted by expert assignment. + expert_ids (torch.Tensor): Expert ID for each sorted token. + num_tokens_post_padded (torch.Tensor): Total tokens after block-size padding with shape (1,). + mul_routed_weight (bool): Multiply output by routing weights. + top_k (int): Number of experts per token. + swizzle_mx_a (bool): Enable swizzled layout for A microscales. + swizzle_mx_b (bool): Enable swizzled layout for B microscales. + config (Dict[str, Any]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M). + compute_type (tl.dtype): Computation dtype for accumulation. + + Returns: + None. Results written in-place to C with SiLU activation applied. 
""" _LOGGER.info( f"MOE_OP_MXFP4: A={tuple(A.shape)} B={tuple(B.shape)} C={tuple(C.shape)} " @@ -91,7 +115,7 @@ def fused_moe_mxfp4_silu( num_tokens_post_padded, B.shape[1], A.shape[1], - EM, + EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), diff --git a/aiter/ops/triton/moe_op_silu_fused.py b/aiter/ops/triton/moe_op_silu_fused.py index c2af791b20..b6dcd6bd61 100644 --- a/aiter/ops/triton/moe_op_silu_fused.py +++ b/aiter/ops/triton/moe_op_silu_fused.py @@ -72,7 +72,32 @@ def fused_moe_silu( config: Optional[Dict[str, Any]] = None, ) -> None: """ - #TODO: Add doc + Fused MoE computation with SiLU activation and optional quantization. + + Args: + A (torch.Tensor): Input activations with shape (num_tokens, hidden_dim). + B (torch.Tensor): Expert weights with shape (num_experts, hidden_dim, intermediate_dim). + C (torch.Tensor): Output tensor with shape (num_tokens, top_k, intermediate_dim). + A_scale (Optional[torch.Tensor]): Scale for A in FP8 mode with shape (1,) or (num_tokens, num_groups). + B_scale (Optional[torch.Tensor]): Scale for B with shape (num_experts, ...) for quantized modes. + B_zp (Optional[torch.Tensor]): Zero point for B in INT4/INT8 modes. + topk_weights (torch.Tensor): Routing weights for top-k experts with shape (num_tokens, top_k). + topk_ids (torch.Tensor): Top-k expert IDs per token with shape (num_tokens, top_k). + sorted_token_ids (torch.Tensor): Token IDs sorted by expert assignment. + expert_ids (torch.Tensor): Expert ID for each sorted token. + num_tokens_post_padded (torch.Tensor): Total tokens after block-size padding with shape (1,). + mul_routed_weight (bool): Multiply output by routing weights. + top_k (int): Number of experts per token. + compute_type (tl.dtype): Computation dtype for accumulation. + use_fp8_w8a8 (bool): Use FP8 quantization for weights and activations. + use_int8_w8a16 (bool): Use INT8 weights with higher precision activations. + use_int4_w4a16 (bool): Use INT4 weights with higher precision activations. + block_shape (Optional[List[int]]): Block shape [block_n, block_k] for grouped quantization. + config (Optional[Dict[str, Any]]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M). + + Returns: + None. Results written in-place to C with SiLU activation applied. 
""" _LOGGER.info( f"FUSED_MOE_SILU: A={tuple(A.shape)} B={tuple(B.shape)} C={tuple(C.shape)} " @@ -141,7 +166,7 @@ def fused_moe_silu( num_tokens_post_padded, B.shape[1], A.shape[1], - EM, + EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -185,7 +210,7 @@ def fused_moe_silu( num_tokens_post_padded, B.shape[1], A.shape[1], - EM, + EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -235,7 +260,7 @@ def fused_moe_silu( num_tokens_post_padded, B.shape[1], A.shape[1] - _PADDING_SIZE, - sorted_token_ids.shape[0], + sorted_token_ids.shape[0], # (EM) it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -277,7 +302,7 @@ def fused_moe_silu( num_tokens_post_padded, B.shape[1], A.shape[1] - _PADDING_SIZE, - EM, + EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), diff --git a/aiter/ops/triton/moe_routing_sigmoid_top1_fused.py b/aiter/ops/triton/moe_routing_sigmoid_top1_fused.py index 989a1d5645..23b4d20251 100644 --- a/aiter/ops/triton/moe_routing_sigmoid_top1_fused.py +++ b/aiter/ops/triton/moe_routing_sigmoid_top1_fused.py @@ -4,7 +4,7 @@ from typing import Optional import torch import triton -import triton.language as tl +# import triton.language as tl from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.moe_routing_sigmoid_top1_fused import ( _routing_sigmoid_top1_kernel, @@ -17,6 +17,21 @@ def routing_sigmoid_top1( x, w, topk, fused_shared_experts=False, config: Optional[dict[str, any]] = None ): + """ + Computes top-1 MoE routing with sigmoid activation for expert selection. + + Args: + x (torch.Tensor): Input activations with shape (batch_size, seq_len, hidden_dim) or (M, K). + w (torch.Tensor): Routing weights with shape (hidden_dim, num_experts). + topk (int): Number of experts to select. Must be 1. + fused_shared_experts (bool): Include shared expert (always selected) alongside top-1. + config (Optional[dict]): Kernel tuning parameters (BLOCK_M, BLOCK_K). + + Returns: + tuple: (topk_ids, topk_weights) + - topk_ids (torch.Tensor): Selected expert IDs with shape (M, topk) or (M, topk+1) if fused_shared_experts. + - topk_weights (torch.Tensor): Routing weights (sigmoid scores) with shape (M, topk) or (M, topk+1). + """ _LOGGER.info( f"ROUTING_SIGMOID_TOP1: x={tuple(x.shape)} w={tuple(w.shape)} topk={topk} " ) diff --git a/aiter/ops/triton/prefill_attention.py b/aiter/ops/triton/prefill_attention.py index 0f15a384eb..a7f8a246fc 100644 --- a/aiter/ops/triton/prefill_attention.py +++ b/aiter/ops/triton/prefill_attention.py @@ -23,7 +23,7 @@ # Adapted from # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 import triton -import triton.language as tl +# import triton.language as tl from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.prefill_attention import _fwd_kernel @@ -34,10 +34,20 @@ def context_attention_fwd( q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True ): """ - q, k, v: [b * s, head, head_dim] - b_start_loc: [b] - b_seq_len: [b] - out: [b * s, head, head_dim] + Memory-efficient attention for prefill with page size = 1. + + Args: + q (torch.Tensor): Query tensor with shape (total_tokens, num_q_heads, head_dim). + k (torch.Tensor): Key tensor with shape (total_tokens, num_kv_heads, head_dim). 
+ v (torch.Tensor): Value tensor with shape (total_tokens, num_kv_heads, head_dim). + o (torch.Tensor): Output tensor with shape (total_tokens, num_q_heads, head_dim). + b_start_loc (torch.Tensor): Start location for each sequence with shape (batch_size,). + b_seq_len (torch.Tensor): Sequence length for each batch with shape (batch_size,). + max_input_len (int): Maximum sequence length in the batch. + is_causal (bool): Apply causal masking. + + Returns: + None. Results written in-place to o. """ _LOGGER.info( f"PREFILL_ATTENTION: q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}" diff --git a/aiter/ops/triton/softmax.py b/aiter/ops/triton/softmax.py index 5cc275c5a7..1533517953 100644 --- a/aiter/ops/triton/softmax.py +++ b/aiter/ops/triton/softmax.py @@ -1,6 +1,6 @@ import torch import triton -import triton.language as tl +# import triton.language as tl from aiter.ops.triton._triton_kernels.softmax import _softmax_kernel_online from aiter.ops.triton.utils.logger import AiterTritonLogger @@ -9,17 +9,13 @@ def softmax(x): """ - Computes the row-wise softmax of a 2D input tensor. + Computes row-wise softmax of a 2D input tensor. - Key parameters: - x (torch.Tensor): A 2D input tensor. + Args: + x (torch.Tensor): Input tensor with shape (n_rows, n_cols). Must be on GPU. Returns: - torch.Tensor: A tensor of the same shape as 'x', where softmax has been - applied along the last dimension (row-wise). - - Note: - - The input tensor 'x' must reside on the GPU. + torch.Tensor: Output with same shape as x, softmax applied along last dimension. """ _LOGGER.info(f"SOFTMAX: x={tuple(x.shape)}") n_rows, n_cols = x.shape @@ -40,7 +36,7 @@ def softmax(x): x, x.stride(0), y.stride(0), - n_rows, + n_rows, # it's not being used in the kernel n_cols, BLOCK_SIZE, waves_per_eu=waves_per_eu, diff --git a/aiter/ops/triton/topk.py b/aiter/ops/triton/topk.py index ad0b0fea47..11f681ceea 100644 --- a/aiter/ops/triton/topk.py +++ b/aiter/ops/triton/topk.py @@ -11,8 +11,9 @@ import torch import triton import triton.language as tl -import triton.language.core as core -from triton.language.standard import _log2, zeros_like + +# import triton.language.core as core +# from triton.language.standard import _log2, zeros_like from aiter.ops.triton._triton_kernels.topk import ( _topk_kernel, topk_stage1_kernel, @@ -173,6 +174,20 @@ def topk( sorted: bool = True, tiny_row_thresh: int = MAX_TINY_ROW, ): + """ + Selects k largest elements along last dimension using 1-stage or 2-stage algorithm. + + Args: + x (torch.Tensor): Input tensor with shape (B, M). Must be 2D. + k (int): Number of top elements to select. + dim (int): Dimension to reduce. Must be -1 (last dimension). + largest (bool): Select largest elements. Must be True. + sorted (bool): Return sorted results. Must be True. + tiny_row_thresh (int): Threshold for choosing 1-stage vs 2-stage algorithm. + + Returns: + tuple: (values, indices) both with shape (B, k), sorted in descending order. 
+ """ _LOGGER.info(f"TOPK: x={tuple(x.shape)}, k={k}, largest={largest}, sorted={sorted}") if dim < 0: dim += x.ndim From 92359fec01be6470af79c031b9261b2073095828 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Wed, 29 Oct 2025 16:45:56 +0000 Subject: [PATCH 4/7] implemented the use of kernel repr helper to standardize kernel metadata representation --- aiter/ops/triton/_triton_kernels/extend_attention.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/extend_attention.py b/aiter/ops/triton/_triton_kernels/extend_attention.py index 262b80be4b..756231c58a 100644 --- a/aiter/ops/triton/_triton_kernels/extend_attention.py +++ b/aiter/ops/triton/_triton_kernels/extend_attention.py @@ -17,7 +17,6 @@ It supports page size = 1 and prefill with KV cache (i.e. extend). """ -from typing import Optional import functools import json import torch @@ -25,12 +24,10 @@ import triton.language as tl -# from .prefill_attention import context_attention_fwd from .activation import _tanh -from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd +from ..utils._triton.pid_preprocessing import remap_xcd from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH -from ..utils.device_info import get_num_xcds from ..utils._triton.kernel_repr import make_kernel_repr @@ -51,7 +48,6 @@ "STORE_TRANSPOSE", "NUM_Q_HEADS", "NUM_BLOCKS", - # "BATCH", "NUM_XCDS", ], ) @@ -98,7 +94,6 @@ def _fwd_kernel( STORE_TRANSPOSE: tl.constexpr, NUM_Q_HEADS: tl.constexpr, NUM_BLOCKS: tl.constexpr, - # BATCH: tl.constexpr, NUM_XCDS: tl.constexpr, ): workgroup_id = tl.program_id(0) # workgroup index From f644b7d30ef96fcd0dbcd474a60db40bb946821e Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Thu, 6 Nov 2025 18:44:28 +0000 Subject: [PATCH 5/7] remove commented out code and fix formatting --- aiter/ops/triton/_triton_kernels/moe_op.py | 3 --- aiter/ops/triton/_triton_kernels/moe_op_e2e.py | 4 ---- aiter/ops/triton/_triton_kernels/moe_op_gelu.py | 2 -- aiter/ops/triton/_triton_kernels/moe_op_mxfp4.py | 1 - aiter/ops/triton/_triton_kernels/moe_op_mxfp4_silu_fused.py | 1 - aiter/ops/triton/_triton_kernels/moe_op_silu_fused.py | 4 ---- aiter/ops/triton/_triton_kernels/softmax.py | 1 - aiter/ops/triton/_triton_kernels/topk.py | 1 - aiter/ops/triton/extend_attention.py | 6 ------ aiter/ops/triton/moe_align_block_size.py | 2 -- aiter/ops/triton/moe_op.py | 3 --- aiter/ops/triton/moe_op_e2e.py | 3 --- aiter/ops/triton/moe_op_gelu.py | 2 -- aiter/ops/triton/moe_op_mxfp4.py | 1 - aiter/ops/triton/moe_op_mxfp4_silu_fused.py | 1 - aiter/ops/triton/moe_op_silu_fused.py | 4 ---- aiter/ops/triton/moe_routing_sigmoid_top1_fused.py | 1 - aiter/ops/triton/prefill_attention.py | 1 - aiter/ops/triton/softmax.py | 2 -- 19 files changed, 43 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/moe_op.py b/aiter/ops/triton/_triton_kernels/moe_op.py index 2da8805a41..0ff45df383 100644 --- a/aiter/ops/triton/_triton_kernels/moe_op.py +++ b/aiter/ops/triton/_triton_kernels/moe_op.py @@ -353,7 +353,6 @@ def _fused_moe_persistent_kernel_gptq_awq( # Matrix dimensions N: tl.constexpr, K: tl.constexpr, - EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. 
`stride_am` is @@ -582,7 +581,6 @@ def _fused_moe_kernel( # Matrix dimensions N, K, - EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. `stride_am` is @@ -790,7 +788,6 @@ def _fused_moe_persistent_kernel( # Matrix dimensions N, K, - EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. `stride_am` is diff --git a/aiter/ops/triton/_triton_kernels/moe_op_e2e.py b/aiter/ops/triton/_triton_kernels/moe_op_e2e.py index d8e184c048..20fee72f2b 100644 --- a/aiter/ops/triton/_triton_kernels/moe_op_e2e.py +++ b/aiter/ops/triton/_triton_kernels/moe_op_e2e.py @@ -38,7 +38,6 @@ "e2e_moe_persistent_kernel", [ "top_k", - "EM", "N", "K", "EVEN_K", @@ -52,7 +51,6 @@ "BLOCK_SIZE_K1", "BLOCK_SIZE_K2", "NUM_SMS", - "NUM_XCDS", ], ) @@ -393,7 +391,6 @@ def e2e_moe_persistent_kernel( expert_ids_ptr, num_tokens_post_padded_ptr, num_valid_tokens, - EM: tl.constexpr, N: tl.constexpr, K: tl.constexpr, EVEN_K: tl.constexpr, @@ -407,7 +404,6 @@ def e2e_moe_persistent_kernel( BLOCK_SIZE_K1: tl.constexpr, # original block_size_k BLOCK_SIZE_K2: tl.constexpr, # outputs (EM, BLOCK_SIZE_K2) NUM_SMS: tl.constexpr, - NUM_XCDS: tl.constexpr, ): start_m = tl.program_id(axis=0) num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) diff --git a/aiter/ops/triton/_triton_kernels/moe_op_gelu.py b/aiter/ops/triton/_triton_kernels/moe_op_gelu.py index c92ff656da..a7258cd8c7 100644 --- a/aiter/ops/triton/_triton_kernels/moe_op_gelu.py +++ b/aiter/ops/triton/_triton_kernels/moe_op_gelu.py @@ -77,7 +77,6 @@ def _fused_moe_kernel( # Matrix dimensions N, K, - EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. `stride_am` is @@ -296,7 +295,6 @@ def _fused_moe_persistent_kernel( # Matrix dimensions N, K, - EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. `stride_am` is diff --git a/aiter/ops/triton/_triton_kernels/moe_op_mxfp4.py b/aiter/ops/triton/_triton_kernels/moe_op_mxfp4.py index 25850709d6..8c2018b032 100644 --- a/aiter/ops/triton/_triton_kernels/moe_op_mxfp4.py +++ b/aiter/ops/triton/_triton_kernels/moe_op_mxfp4.py @@ -61,7 +61,6 @@ def _fused_moe_kernel_mxfp4( # Matrix dimensions N, K, - EM, num_valid_tokens, # Strides stride_am, diff --git a/aiter/ops/triton/_triton_kernels/moe_op_mxfp4_silu_fused.py b/aiter/ops/triton/_triton_kernels/moe_op_mxfp4_silu_fused.py index 266621fa49..326ce397b0 100644 --- a/aiter/ops/triton/_triton_kernels/moe_op_mxfp4_silu_fused.py +++ b/aiter/ops/triton/_triton_kernels/moe_op_mxfp4_silu_fused.py @@ -61,7 +61,6 @@ def _fused_moe_kernel_mxfp4_silu( # Matrix dimensions N, K, - EM, num_valid_tokens, # Strides stride_am, diff --git a/aiter/ops/triton/_triton_kernels/moe_op_silu_fused.py b/aiter/ops/triton/_triton_kernels/moe_op_silu_fused.py index a37ff2e530..2b263b4698 100644 --- a/aiter/ops/triton/_triton_kernels/moe_op_silu_fused.py +++ b/aiter/ops/triton/_triton_kernels/moe_op_silu_fused.py @@ -119,7 +119,6 @@ def _fused_moe_silu_kernel_gptq_awq( # Matrix dimensions N: tl.constexpr, K: tl.constexpr, - EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. 
`stride_am` is @@ -380,7 +379,6 @@ def _fused_moe_persistent_silu_kernel_gptq_awq( # Matrix dimensions N: tl.constexpr, K: tl.constexpr, - EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. `stride_am` is @@ -627,7 +625,6 @@ def _fused_moe_silu_kernel( # Matrix dimensions N, K, - EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. `stride_am` is @@ -858,7 +855,6 @@ def _fused_moe_persistent_silu_kernel( # Matrix dimensions N, K, - EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. `stride_am` is diff --git a/aiter/ops/triton/_triton_kernels/softmax.py b/aiter/ops/triton/_triton_kernels/softmax.py index 5cd0860084..5d924f4878 100644 --- a/aiter/ops/triton/_triton_kernels/softmax.py +++ b/aiter/ops/triton/_triton_kernels/softmax.py @@ -17,7 +17,6 @@ def _softmax_kernel_online( input_ptr, input_row_stride, output_row_stride, - n_rows, n_cols, BLOCK_SIZE: tl.constexpr, ): diff --git a/aiter/ops/triton/_triton_kernels/topk.py b/aiter/ops/triton/_triton_kernels/topk.py index f01203e9db..3a5b6db1f0 100644 --- a/aiter/ops/triton/_triton_kernels/topk.py +++ b/aiter/ops/triton/_triton_kernels/topk.py @@ -5,7 +5,6 @@ # https://github.com/FlagOpen/FlagGems/blob/master/src/flag_gems/ops/topk.py # Top-K on GPU: 1-stage (tiny rows) + 2-stage (large rows) Triton kernels, -import math import triton import triton.language as tl import triton.language.core as core diff --git a/aiter/ops/triton/extend_attention.py b/aiter/ops/triton/extend_attention.py index cb08519fd6..7d43a44bab 100644 --- a/aiter/ops/triton/extend_attention.py +++ b/aiter/ops/triton/extend_attention.py @@ -20,7 +20,6 @@ from typing import Optional import torch import triton -# import triton.language as tl from aiter.ops.triton.prefill_attention import context_attention_fwd @@ -144,8 +143,6 @@ def extend_attention_fwd( BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DPE=BLOCK_DPE, BLOCK_DV=BLOCK_DV, - # BLOCK_M=BLOCK_M, - # BLOCK_N=BLOCK_N, Lq=Lq, Lv=Lv, USE_CUSTOM_MASK=USE_CUSTOM_MASK, @@ -154,10 +151,7 @@ def extend_attention_fwd( STORE_TRANSPOSE=True, NUM_Q_HEADS=head_num, NUM_BLOCKS=num_blocks, - # BATCH=batch_size, NUM_XCDS=get_num_xcds(), - # num_warps=num_warps, - # num_stages=num_stages, **config, ) diff --git a/aiter/ops/triton/moe_align_block_size.py b/aiter/ops/triton/moe_align_block_size.py index 395e34fdaa..f8e733dbda 100644 --- a/aiter/ops/triton/moe_align_block_size.py +++ b/aiter/ops/triton/moe_align_block_size.py @@ -2,8 +2,6 @@ # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
import torch -# import triton -# import triton.language as tl from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.moe_align_block_size import ( _moe_align_block_size_stage1_kernel, diff --git a/aiter/ops/triton/moe_op.py b/aiter/ops/triton/moe_op.py index 2c8c5fa8a1..17f8b8c50c 100644 --- a/aiter/ops/triton/moe_op.py +++ b/aiter/ops/triton/moe_op.py @@ -168,7 +168,6 @@ def fused_moe( num_tokens_post_padded, B.shape[1], A.shape[1], - EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -261,7 +260,6 @@ def fused_moe( num_tokens_post_padded, B.shape[1], A.shape[1] - _PADDING_SIZE, - sorted_token_ids.shape[0], # (EM) it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -303,7 +301,6 @@ def fused_moe( num_tokens_post_padded, B.shape[1], A.shape[1] - _PADDING_SIZE, - EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), diff --git a/aiter/ops/triton/moe_op_e2e.py b/aiter/ops/triton/moe_op_e2e.py index fda4aef65c..755dc955df 100644 --- a/aiter/ops/triton/moe_op_e2e.py +++ b/aiter/ops/triton/moe_op_e2e.py @@ -3,7 +3,6 @@ import torch import triton -# import triton.language as tl from typing import Any, Dict, Optional from aiter.ops.triton.quant import dynamic_per_tensor_quant_fp8_i8 @@ -184,7 +183,6 @@ def e2e_moe( expert_ids, num_tokens_post_padded, topk_ids.numel(), - EM, N, K, EVEN_K, @@ -193,7 +191,6 @@ def e2e_moe( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, NUM_SMS=NUM_SMS, - NUM_XCDS=get_num_xcds(), **config, ) diff --git a/aiter/ops/triton/moe_op_gelu.py b/aiter/ops/triton/moe_op_gelu.py index 2b398efac4..a15e281977 100644 --- a/aiter/ops/triton/moe_op_gelu.py +++ b/aiter/ops/triton/moe_op_gelu.py @@ -154,7 +154,6 @@ def fused_moe_gelu( num_tokens_post_padded, B.shape[1], A.shape[1] - _PADDING_SIZE, - sorted_token_ids.shape[0], # (EM) it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -197,7 +196,6 @@ def fused_moe_gelu( num_tokens_post_padded, B.shape[1], A.shape[1] - _PADDING_SIZE, - EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), diff --git a/aiter/ops/triton/moe_op_mxfp4.py b/aiter/ops/triton/moe_op_mxfp4.py index cdea3738b6..d4da5d999d 100644 --- a/aiter/ops/triton/moe_op_mxfp4.py +++ b/aiter/ops/triton/moe_op_mxfp4.py @@ -116,7 +116,6 @@ def fused_moe_mxfp4( num_tokens_post_padded, B.shape[1], A.shape[1], - EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), diff --git a/aiter/ops/triton/moe_op_mxfp4_silu_fused.py b/aiter/ops/triton/moe_op_mxfp4_silu_fused.py index 8dd5918d69..33e369033b 100644 --- a/aiter/ops/triton/moe_op_mxfp4_silu_fused.py +++ b/aiter/ops/triton/moe_op_mxfp4_silu_fused.py @@ -115,7 +115,6 @@ def fused_moe_mxfp4_silu( num_tokens_post_padded, B.shape[1], A.shape[1], - EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), diff --git a/aiter/ops/triton/moe_op_silu_fused.py b/aiter/ops/triton/moe_op_silu_fused.py index b6dcd6bd61..bdd9309c68 100644 --- a/aiter/ops/triton/moe_op_silu_fused.py +++ b/aiter/ops/triton/moe_op_silu_fused.py @@ -166,7 +166,6 @@ def fused_moe_silu( num_tokens_post_padded, B.shape[1], A.shape[1], - EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -210,7 +209,6 @@ def fused_moe_silu( num_tokens_post_padded, B.shape[1], A.shape[1], - EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -260,7 +258,6 
@@ def fused_moe_silu( num_tokens_post_padded, B.shape[1], A.shape[1] - _PADDING_SIZE, - sorted_token_ids.shape[0], # (EM) it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), @@ -302,7 +299,6 @@ def fused_moe_silu( num_tokens_post_padded, B.shape[1], A.shape[1] - _PADDING_SIZE, - EM, # it's not being used in the kernel topk_ids.numel(), A.stride(0), A.stride(1), diff --git a/aiter/ops/triton/moe_routing_sigmoid_top1_fused.py b/aiter/ops/triton/moe_routing_sigmoid_top1_fused.py index 23b4d20251..7468bd8cdd 100644 --- a/aiter/ops/triton/moe_routing_sigmoid_top1_fused.py +++ b/aiter/ops/triton/moe_routing_sigmoid_top1_fused.py @@ -4,7 +4,6 @@ from typing import Optional import torch import triton -# import triton.language as tl from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.moe_routing_sigmoid_top1_fused import ( _routing_sigmoid_top1_kernel, diff --git a/aiter/ops/triton/prefill_attention.py b/aiter/ops/triton/prefill_attention.py index a7f8a246fc..f4b805ba87 100644 --- a/aiter/ops/triton/prefill_attention.py +++ b/aiter/ops/triton/prefill_attention.py @@ -23,7 +23,6 @@ # Adapted from # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 import triton -# import triton.language as tl from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.prefill_attention import _fwd_kernel diff --git a/aiter/ops/triton/softmax.py b/aiter/ops/triton/softmax.py index 1533517953..5b9370339b 100644 --- a/aiter/ops/triton/softmax.py +++ b/aiter/ops/triton/softmax.py @@ -1,6 +1,5 @@ import torch import triton -# import triton.language as tl from aiter.ops.triton._triton_kernels.softmax import _softmax_kernel_online from aiter.ops.triton.utils.logger import AiterTritonLogger @@ -36,7 +35,6 @@ def softmax(x): x, x.stride(0), y.stride(0), - n_rows, # it's not being used in the kernel n_cols, BLOCK_SIZE, waves_per_eu=waves_per_eu, From 4acf71e61dce823d7553b71afc665fd85664d834 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Thu, 6 Nov 2025 18:58:04 +0000 Subject: [PATCH 6/7] remove commented code --- aiter/ops/triton/topk.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aiter/ops/triton/topk.py b/aiter/ops/triton/topk.py index 11f681ceea..20fd7343e5 100644 --- a/aiter/ops/triton/topk.py +++ b/aiter/ops/triton/topk.py @@ -12,8 +12,7 @@ import triton import triton.language as tl -# import triton.language.core as core -# from triton.language.standard import _log2, zeros_like + from aiter.ops.triton._triton_kernels.topk import ( _topk_kernel, topk_stage1_kernel, From 2932122e22223bde9434d87a7c76f53951965ee0 Mon Sep 17 00:00:00 2001 From: Satya Nikhil Kodukula Date: Thu, 6 Nov 2025 19:28:24 +0000 Subject: [PATCH 7/7] remove FILL_VALUE to keep kernels name meaningful --- aiter/ops/triton/_triton_kernels/topk.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/topk.py b/aiter/ops/triton/_triton_kernels/topk.py index 3a5b6db1f0..0122976a0e 100644 --- a/aiter/ops/triton/_triton_kernels/topk.py +++ b/aiter/ops/triton/_triton_kernels/topk.py @@ -18,7 +18,6 @@ "M", "K", "BLOCK", - "FILL_VALUE", ], ) @@ -28,7 +27,6 @@ "N", "CHUNK_SIZE", "DESCENDING", - "FILL_VALUE", ], ) @@ -39,8 +37,6 @@ "N", "BLOCK_SIZE", "DESCENDING", - "FILL_VALUE", - "MASK_INDEX_VAL", ], )
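
Note on the helper this series depends on: every kernel file now imports `make_kernel_repr` from `..utils._triton.kernel_repr`, but that module is not part of these hunks. As a rough, hypothetical sketch of the intended pattern only (not the helper's actual implementation), a factory of this shape would let `triton.jit(repr=...)` emit a stable, metadata-bearing name built from the listed constexpr parameters; it assumes the specialization object Triton hands to the repr callable exposes a `constants` mapping from parameter names to compile-time values.

# Hypothetical sketch only -- the real helper lives in ..utils._triton.kernel_repr
# and is not shown in this diff. Assumes triton.jit(repr=...) calls the given
# function with a specialization object carrying a `constants` dict of
# constexpr names -> compile-time values.
def make_kernel_repr(kernel_name, constexpr_names):
    def _repr(specialization):
        constants = getattr(specialization, "constants", {}) or {}
        parts = [
            f"{name}={constants[name]}"
            for name in constexpr_names
            if name in constants
        ]
        # e.g. "_softmax_kernel_online_BLOCK_SIZE=1024"
        return "_".join([kernel_name] + parts) if parts else kernel_name

    return _repr

Under that assumption, `@triton.jit(repr=_softmax_kernel_online_repr)` would label each compiled variant of `_softmax_kernel_online` with its `BLOCK_SIZE`, which is presumably the standardized kernel metadata the series aims to surface in caches and profiler traces.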