From 6490ffd999d4b4e3d9a3e288175ca11bef6f2bfb Mon Sep 17 00:00:00 2001
From: zufayu
Date: Mon, 22 Dec 2025 15:56:39 +0800
Subject: [PATCH 1/7] add fp32 input

---
 csrc/kernels/activation_kernels.cu | 228 +++++++++++++++++++++--------
 op_tests/test_activation.py        |   4 +-
 2 files changed, 168 insertions(+), 64 deletions(-)

diff --git a/csrc/kernels/activation_kernels.cu b/csrc/kernels/activation_kernels.cu
index 3a685ae1e9..985791162a 100644
--- a/csrc/kernels/activation_kernels.cu
+++ b/csrc/kernels/activation_kernels.cu
@@ -21,18 +21,30 @@ using fp8_type = ck_tile::fp8_t;
 static constexpr int32_t max_vec_size = 8;
 static constexpr int32_t max_wave_num = 8;
 
+// Type trait: fp32 inputs compute in bf16, others use their native type
+template <typename T>
+using compute_type_t = std::conditional_t<std::is_same_v<T, float>, ck_tile::bfloat16_t, T>;
+
 namespace aiter {
-// Activation and gating kernel template.
+// Activation and gating kernel template with fp32 auto-conversion support.
+// If DTYPE_I is float, it will be converted to bf16 for computation using ck_tile::type_convert.
 template <typename DTYPE_I, float (*ACT_FN)(const DTYPE_I&), int32_t VEC_SIZE_I>
 __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out,         // [..., d]
                                    const DTYPE_I* __restrict__ input, // [..., 2, d]
                                    const int d)
 {
+    using DTYPE_COMPUTE = compute_type_t<DTYPE_I>;
+
+    // CK Tile buffer addressing constraint: float supports VEC_SIZE <= 16
+    static_assert(!(std::is_same_v<DTYPE_I, float> && VEC_SIZE_I > 16),
+                  "float type only supports VEC_SIZE up to 16");
+
     const int64_t token_idx = blockIdx.x;
     auto const* ptr_x = (input + token_idx * 2 * d);
     auto const* ptr_y = (input + token_idx * 2 * d + d);
     using vec_i = ck_tile::vec_t<DTYPE_I, VEC_SIZE_I>;
+    using vec_c = ck_tile::vec_t<DTYPE_COMPUTE, VEC_SIZE_I>;
     static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I);
     const int32_t oob_i = (d + ooba_i - 1) / ooba_i * ooba_i;
     auto buffer_x = ck_tile::make_buffer_view<ck_tile::address_space_enum::global>(ptr_x, oob_i);
@@ -116,23 +128,42 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d
     for(int64_t idx = threadIdx.x * VEC_SIZE_I; idx < d; idx += blockDim.x * VEC_SIZE_I)
     {
-        vec_i x{};
-        vec_i y{};
+        vec_i x = buffer_x.template get<vec_i>(idx, 0, true);
+        vec_i y = buffer_y.template get<vec_i>(idx, 0, true);
 
-        x = buffer_x.template get<vec_i>(idx, 0, true);
-        y = buffer_y.template get<vec_i>(idx, 0, true);
+        // Convert fp32→bf16 if needed, otherwise use directly
+        vec_c x_compute{};
+        vec_c y_compute{};
+
+        if constexpr(std::is_same_v<DTYPE_I, float>)
+        {
+#pragma unroll
+            for(size_t j = 0; j < VEC_SIZE_I; j++)
+            {
+                x_compute[j] = ck_tile::type_convert<DTYPE_COMPUTE>(x[j]); // fp32→bf16
+                y_compute[j] = ck_tile::type_convert<DTYPE_COMPUTE>(y[j]);
+            }
+        }
+        else
+        {
+            x_compute = x; // bf16/fp16: zero-copy
+            y_compute = y;
+        }
 
         vec_i r{};
 
 #pragma unroll
         for(size_t j = 0; j < VEC_SIZE_I; j += 2)
         {
-            float ax0 = ACT_FN(x[j]);
-            float y0  = ck_tile::type_convert<float>(y[j]);
+            // Call ACT_FN with appropriate type conversion
+            DTYPE_I x_val0 = ck_tile::type_convert<DTYPE_I>(x_compute[j]);
+            float ax0      = ACT_FN(x_val0);
+            float y0       = ck_tile::type_convert<float>(y_compute[j]);
             if(j + 1 < VEC_SIZE_I)
             {
-                float ax1 = ACT_FN(x[j + 1]);
-                float y1  = ck_tile::type_convert<float>(y[j + 1]);
+                DTYPE_I x_val1 = ck_tile::type_convert<DTYPE_I>(x_compute[j + 1]);
+                float ax1      = ACT_FN(x_val1);
+                float y1       = ck_tile::type_convert<float>(y_compute[j + 1]);
                 ck_tile::fp32x2_t a = {ax0, ax1};
                 ck_tile::fp32x2_t b = {y0, y1};
                 ck_tile::fp32x2_t c;
@@ -158,17 +189,24 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d
     }
 }
 
-// Scaled activation and gating kernel template.
+// Scaled activation and gating kernel template with fp32 auto-conversion support.
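+// A minimal compile-time check of how the compute_type_t trait above resolves
+// (a sketch, assuming the ck_tile::bf16_t / ck_tile::fp16_t aliases used
+// elsewhere in this file): only float is remapped, bf16/fp16 pass through.
+//
+//     static_assert(std::is_same_v<compute_type_t<float>, ck_tile::bfloat16_t>);
+//     static_assert(std::is_same_v<compute_type_t<ck_tile::bf16_t>, ck_tile::bf16_t>);
+//     static_assert(std::is_same_v<compute_type_t<ck_tile::fp16_t>, ck_tile::fp16_t>);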
 template <typename DTYPE_I, float (*ACT_FN)(const DTYPE_I&), int32_t VEC_SIZE_I>
 __global__ void scaled_act_and_mul_kernel(fp8_type* __restrict__ out,        // [..., d]
                                           const DTYPE_I* __restrict__ input, // [..., 2, d]
                                           const int d,
                                           const float scale)
 {
+    using DTYPE_COMPUTE = compute_type_t<DTYPE_I>;
+
+    // CK Tile buffer addressing constraint: float supports VEC_SIZE <= 16
+    static_assert(!(std::is_same_v<DTYPE_I, float> && VEC_SIZE_I > 16),
+                  "float type only supports VEC_SIZE up to 16");
+
     const int64_t token_idx = blockIdx.x;
     auto const* ptr_x = (input + token_idx * 2 * d);
     auto const* ptr_y = (input + token_idx * 2 * d + d);
     using vec_i = ck_tile::vec_t<DTYPE_I, VEC_SIZE_I>;
+    using vec_c = ck_tile::vec_t<DTYPE_COMPUTE, VEC_SIZE_I>;
 
     static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I);
     const int32_t oob_i = (d + ooba_i - 1) / ooba_i * ooba_i;
@@ -179,17 +217,38 @@ __global__ void scaled_act_and_mul_kernel(fp8_type* __restrict__ out, //
     for(int64_t idx = threadIdx.x * VEC_SIZE_I; idx < d; idx += blockDim.x * VEC_SIZE_I)
     {
-        auto x = buffer_x.template get<vec_i>(idx, 0, true);
-        auto y = buffer_y.template get<vec_i>(idx, 0, true);
+        vec_i x = buffer_x.template get<vec_i>(idx, 0, true);
+        vec_i y = buffer_y.template get<vec_i>(idx, 0, true);
+
+        // Convert fp32→bf16 if needed, otherwise use directly
+        vec_c x_compute{};
+        vec_c y_compute{};
+
+        if constexpr(std::is_same_v<DTYPE_I, float>)
+        {
+#pragma unroll
+            for(size_t j = 0; j < VEC_SIZE_I; j++)
+            {
+                x_compute[j] = ck_tile::type_convert<DTYPE_COMPUTE>(x[j]); // fp32→bf16
+                y_compute[j] = ck_tile::type_convert<DTYPE_COMPUTE>(y[j]);
+            }
+        }
+        else
+        {
+            x_compute = x; // bf16/fp16: zero-copy
+            y_compute = y;
+        }
 
         for(size_t j = 0; j < VEC_SIZE_I; j += 2)
         {
             if(j + 1 < VEC_SIZE_I)
             {
-                float act_x0 = ACT_FN(x[j]);
-                float act_x1 = ACT_FN(x[j + 1]);
-                float y0     = ck_tile::type_convert<float>(y[j]);
-                float y1     = ck_tile::type_convert<float>(y[j + 1]);
+                DTYPE_I x_val0 = ck_tile::type_convert<DTYPE_I>(x_compute[j]);
+                DTYPE_I x_val1 = ck_tile::type_convert<DTYPE_I>(x_compute[j + 1]);
+                float act_x0   = ACT_FN(x_val0);
+                float act_x1   = ACT_FN(x_val1);
+                float y0       = ck_tile::type_convert<float>(y_compute[j]);
+                float y1       = ck_tile::type_convert<float>(y_compute[j + 1]);
 
                 float2 act_vals = {act_x0, act_x1};
                 float2 y_vals   = {y0, y1};
@@ -206,7 +265,8 @@ __global__ void scaled_act_and_mul_kernel(fp8_type* __restrict__ out, //
             }
             else
             {
-                float r = ACT_FN(x[j]) * ck_tile::type_convert<float>(y[j]) * scale;
+                DTYPE_I x_val = ck_tile::type_convert<DTYPE_I>(x_compute[j]);
+                float r = ACT_FN(x_val) * ck_tile::type_convert<float>(y_compute[j]) * scale;
                 out[token_idx * d + idx + j] = ck_tile::type_convert<fp8_type>(r);
             }
         }
@@ -257,50 +317,94 @@ static constexpr int nextPow2(unsigned int num)
     return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
 
-// Launch activation and gating kernel.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
-    int d = input.size(-1) / 2; \
-    int64_t num_tokens = input.numel() / input.size(-1); \
-    int vec_size = nextPow2(d / 64); \
-    vec_size = vec_size < 2 ? 2 : vec_size; \
-    vec_size = vec_size > max_vec_size ? max_vec_size : vec_size; \
-    int num_wave = nextPow2(d / 64 / vec_size); \
-    num_wave = num_wave > max_wave_num ? max_wave_num : num_wave; \
-    dim3 grid(num_tokens); \
-    dim3 block(num_wave * 64); \
-    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); \
-    const hipStream_t stream = at::hip::getCurrentHIPStream(); \
-    AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "act_and_mul_kernel", [&] { \
-        using input_dtype = typename t2ck<scalar_t>::type; \
-        AITER_DISPATCH_CASE_VEC_SIZE( \
-            vec_size, \
-            aiter::act_and_mul_kernel<input_dtype, KERNEL<input_dtype>, VEC_SIZE> \
-            <<<grid, block, 0, stream>>>(reinterpret_cast<input_dtype*>(out.data_ptr()), \
-                                         reinterpret_cast<input_dtype*>(input.data_ptr()), \
-                                         d);) \
-    });
-
-#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \
-    int d = input.size(-1) / 2; \
-    int64_t num_tokens = input.numel() / input.size(-1); \
-    int vec_size = nextPow2(d / 64); \
-    vec_size = vec_size < 2 ? 2 : vec_size; \
-    vec_size = vec_size > max_vec_size ? max_vec_size : vec_size; \
-    int num_wave = nextPow2(d / 64 / vec_size); \
-    num_wave = num_wave > max_wave_num ? max_wave_num : num_wave; \
-    dim3 grid(num_tokens); \
-    dim3 block(num_wave * 64); \
-    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); \
-    const hipStream_t stream = at::hip::getCurrentHIPStream(); \
-    AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \
-        using input_dtype = typename t2ck<scalar_t>::type; \
-        AITER_DISPATCH_CASE_VEC_SIZE( \
-            vec_size, \
-            aiter::scaled_act_and_mul_kernel<input_dtype, KERNEL<input_dtype>, VEC_SIZE> \
-            <<<grid, block, 0, stream>>>(reinterpret_cast<fp8_type*>(out.data_ptr()), \
-                                         reinterpret_cast<input_dtype*>(input.data_ptr()), \
-                                         d, \
-                                         1.0f / (*scale.data_ptr<float>()));) \
-    });
+// Common kernel launch parameters computation
+#define COMPUTE_ACTIVATION_KERNEL_PARAMS \
+    int d = input.size(-1) / 2; \
+    int64_t num_tokens = input.numel() / input.size(-1); \
+    int vec_size = nextPow2(d / 64); \
+    vec_size = vec_size < 2 ? 2 : vec_size; \
+    vec_size = vec_size > max_vec_size ? max_vec_size : vec_size; \
+    int num_wave = nextPow2(d / 64 / vec_size); \
+    num_wave = num_wave > max_wave_num ? max_wave_num : num_wave; \
+    dim3 grid(num_tokens); \
+    dim3 block(num_wave * 64); \
+    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); \
+    const hipStream_t stream = at::hip::getCurrentHIPStream();
+
+// Helper macro for fp32 vec_size dispatch (CK Tile only supports VEC_SIZE <= 16 for fp32)
+#define DISPATCH_FP32_VEC_SIZE_CASE(VS, KERNEL_NAME, KERNEL, ...) \
+    case VS: \
+        aiter::KERNEL_NAME<input_dtype, KERNEL<input_dtype>, VS> \
+            <<<grid, block, 0, stream>>>(__VA_ARGS__); \
+        break;
+
+#define DISPATCH_FP32_KERNEL(KERNEL_NAME, KERNEL, ...) \
+    switch(vec_size) \
+    { \
+        DISPATCH_FP32_VEC_SIZE_CASE(16, KERNEL_NAME, KERNEL, __VA_ARGS__) \
+        DISPATCH_FP32_VEC_SIZE_CASE(8, KERNEL_NAME, KERNEL, __VA_ARGS__) \
+        DISPATCH_FP32_VEC_SIZE_CASE(4, KERNEL_NAME, KERNEL, __VA_ARGS__) \
+        DISPATCH_FP32_VEC_SIZE_CASE(2, KERNEL_NAME, KERNEL, __VA_ARGS__) \
+        DISPATCH_FP32_VEC_SIZE_CASE(1, KERNEL_NAME, KERNEL, __VA_ARGS__) \
+    }
+
+#define DISPATCH_FP32_ACT_KERNEL(KERNEL, out_ptr, in_ptr) \
+    DISPATCH_FP32_KERNEL(act_and_mul_kernel, KERNEL, out_ptr, in_ptr, d)
+
+#define DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
+    DISPATCH_FP32_KERNEL(scaled_act_and_mul_kernel, KERNEL, out_ptr, in_ptr, d, inv_scale)
+
+// Launch activation and gating kernel (fp32/bf16/fp16 unified, fp32→bf16 auto-convert)
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
+    COMPUTE_ACTIVATION_KERNEL_PARAMS \
+    if(input.scalar_type() == at::ScalarType::Float) \
+    { \
+        /* fp32: limited VEC_SIZE due to CK Tile constraint */ \
+        using input_dtype = ck_tile::fp32_t; \
+        auto* out_ptr = reinterpret_cast<input_dtype*>(out.data_ptr()); \
+        auto* in_ptr  = reinterpret_cast<input_dtype*>(input.data_ptr()); \
+        DISPATCH_FP32_ACT_KERNEL(KERNEL, out_ptr, in_ptr) \
+    } \
+    else \
+    { \
+        /* bf16/fp16: full VEC_SIZE support */ \
+        AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "act_and_mul_kernel", [&] { \
+            using input_dtype = typename t2ck<scalar_t>::type; \
+            AITER_DISPATCH_CASE_VEC_SIZE( \
+                vec_size, \
+                aiter::act_and_mul_kernel<input_dtype, KERNEL<input_dtype>, VEC_SIZE> \
+                <<<grid, block, 0, stream>>>(reinterpret_cast<input_dtype*>(out.data_ptr()), \
+                                             reinterpret_cast<input_dtype*>(input.data_ptr()), \
+                                             d);) \
+        }); \
+    }
+
+// Launch scaled activation and gating kernel (fp32/bf16/fp16 unified, fp32→bf16 auto-convert)
+#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \
+    COMPUTE_ACTIVATION_KERNEL_PARAMS \
+    if(input.scalar_type() == at::ScalarType::Float) \
+    { \
+        /* fp32: limited VEC_SIZE due to CK Tile constraint */ \
+        using input_dtype = ck_tile::fp32_t; \
+        auto* out_ptr = reinterpret_cast<fp8_type*>(out.data_ptr()); \
+        auto* in_ptr  = reinterpret_cast<input_dtype*>(input.data_ptr()); \
+        float inv_scale = 1.0f / (*scale.data_ptr<float>()); \
+        DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
+    } \
+    else \
+    { \
+        /* bf16/fp16: full VEC_SIZE support */ \
+        AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \
+            using input_dtype = typename t2ck<scalar_t>::type; \
+            AITER_DISPATCH_CASE_VEC_SIZE( \
+                vec_size, \
+                aiter::scaled_act_and_mul_kernel<input_dtype, KERNEL<input_dtype>, VEC_SIZE> \
+                <<<grid, block, 0, stream>>>(reinterpret_cast<fp8_type*>(out.data_ptr()), \
+                                             reinterpret_cast<input_dtype*>(input.data_ptr()), \
+                                             d, \
+                                             1.0f / (*scale.data_ptr<float>()));) \
+        }); \
+    }
 
 namespace aiter {
 
@@ -392,4 +496,4 @@ void gelu_fast(torch::Tensor& out, // [..., d]
     LAUNCH_ACTIVATION_KERNEL(aiter::gelu_fast_kernel);
 }
 
-} // namespace aiter
\ No newline at end of file
+} // namespace aiter
diff --git a/op_tests/test_activation.py b/op_tests/test_activation.py
index dbefdf827a..999c8793f5 100644
--- a/op_tests/test_activation.py
+++ b/op_tests/test_activation.py
@@ -71,7 +71,7 @@ def test_silu_and_mul(m, n, dtype):
     return ret
 
 
-l_dtype = ["fp16", "bf16"]
+l_dtype = ["fp16", "bf16", "fp32"]
 l_m = [1, 32, 64, 128, 256, 512, 1024, 4096, 8192, 163840]
 l_n = [1024, 4096, 6400, 8192]
 
@@ -135,4 +135,4 @@ def test_silu_and_mul(m, n, dtype):
         ret = test_silu_and_mul(m, n, dtype)
         df.append(ret)
 df = pd.DataFrame(df)
-aiter.logger.info(f"silu_and_mul summary:\n{df}")
+aiter.logger.info(f"silu_and_mul summary:\n{df}")
\ No newline at end of file

From 2fdbd0c4fe2a8642bfe3b095dedbfb75b4ef9365 Mon Sep 17 00:00:00 2001
From: zufayu
Date: Mon, 22 Dec 2025 16:01:36 +0800
Subject: [PATCH 2/7] format code

---
 op_tests/test_activation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/op_tests/test_activation.py b/op_tests/test_activation.py
index 999c8793f5..e373c9a750 100644
--- a/op_tests/test_activation.py
+++ b/op_tests/test_activation.py
@@ -135,4 +135,4 @@ def test_silu_and_mul(m, n, dtype):
         ret = test_silu_and_mul(m, n, dtype)
         df.append(ret)
 df = pd.DataFrame(df)
-aiter.logger.info(f"silu_and_mul summary:\n{df}")
\ No newline at end of file
+aiter.logger.info(f"silu_and_mul summary:\n{df}")

From 27138f5f55cb942dbc8aa9bfdff1a76a9fcf1653 Mon Sep 17 00:00:00 2001
From: zufayu
Date: Mon, 22 Dec 2025 16:46:01 +0800
Subject: [PATCH 3/7] perf bug fix

---
 csrc/kernels/activation_kernels.cu | 90 +++++++++++------------------
 1 file changed, 34 insertions(+), 56 deletions(-)

diff --git a/csrc/kernels/activation_kernels.cu b/csrc/kernels/activation_kernels.cu
index 985791162a..a26951149e 100644
--- a/csrc/kernels/activation_kernels.cu
+++ b/csrc/kernels/activation_kernels.cu
@@ -21,20 +21,25 @@ using fp8_type = ck_tile::fp8_t;
 static constexpr int32_t max_vec_size = 8;
 static constexpr int32_t max_wave_num = 8;
 
-// Type trait: fp32 inputs compute in bf16, others use their native type
+// Type trait for computation type (all compute in native type)
 template <typename T>
-using compute_type_t = std::conditional_t<std::is_same_v<T, float>, ck_tile::bfloat16_t, T>;
+using compute_type_t = T;
+
+// Type trait for output type (fp32 outputs as bf16, others keep native type)
+template <typename T>
+using output_type_t = std::conditional_t<std::is_same_v<T, float>, ck_tile::bfloat16_t, T>;
 
 namespace aiter {
-// Activation and gating kernel template with fp32 auto-conversion support.
-// If DTYPE_I is float, it will be converted to bf16 for computation using ck_tile::type_convert.
+// Activation and gating kernel template supporting fp32/bf16/fp16.
+// fp32 inputs compute in fp32 but output as bf16; bf16/fp16 keep native type.
 template <typename DTYPE_I, float (*ACT_FN)(const DTYPE_I&), int32_t VEC_SIZE_I>
 __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out,         // [..., d]
                                    const DTYPE_I* __restrict__ input, // [..., 2, d]
                                    const int d)
 {
     using DTYPE_COMPUTE = compute_type_t<DTYPE_I>;
+    using DTYPE_O       = output_type_t<DTYPE_I>;
 
     // CK Tile buffer addressing constraint: float supports VEC_SIZE <= 16
     static_assert(!(std::is_same_v<DTYPE_I, float> && VEC_SIZE_I > 16),
                   "float type only supports VEC_SIZE up to 16");
@@ -45,6 +50,7 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d
     auto const* ptr_y = (input + token_idx * 2 * d + d);
     using vec_i = ck_tile::vec_t<DTYPE_I, VEC_SIZE_I>;
     using vec_c = ck_tile::vec_t<DTYPE_COMPUTE, VEC_SIZE_I>;
+    using vec_o = ck_tile::vec_t<DTYPE_O, VEC_SIZE_I>;
     static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I);
     const int32_t oob_i = (d + ooba_i - 1) / ooba_i * ooba_i;
     auto buffer_x = ck_tile::make_buffer_view<ck_tile::address_space_enum::global>(ptr_x, oob_i);
@@ -52,15 +58,17 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d
     buffer_x.init_raw();
     buffer_y.init_raw();
 
-    // Output buffer view for wide stores (raw path)
-    DTYPE_I* __restrict__ out_base = out + token_idx * d;
+    // Output buffer view (may have different type than input for fp32→bf16)
+    DTYPE_O* __restrict__ out_base = reinterpret_cast<DTYPE_O*>(out) + token_idx * d;
+    static constexpr int32_t ooba_o = 4 / sizeof(DTYPE_O);
+    const int32_t oob_o = (d + ooba_o - 1) / ooba_o * ooba_o;
     auto buffer_out =
-        ck_tile::make_buffer_view<ck_tile::address_space_enum::global>(out_base, oob_i);
+        ck_tile::make_buffer_view<ck_tile::address_space_enum::global>(out_base, oob_o);
     buffer_out.init_raw();
 
-    constexpr int32_t allowed_max = std::is_same<DTYPE_I, ck_tile::fp32_t>::value ? 8 : 16;
+    constexpr int32_t allowed_max = std::is_same<DTYPE_O, ck_tile::fp32_t>::value ? 8 : 16;
 
-    auto store_vec_segmented = [&](int64_t base_idx, const vec_i& v) __device__ {
+    auto store_vec_segmented = [&](int64_t base_idx, const vec_o& v) __device__ {
         int64_t off = base_idx;
         int32_t rem = VEC_SIZE_I;
         int32_t pos = 0;
@@ -68,7 +76,7 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d
         {
             if(allowed_max >= 16 && rem >= 16)
             {
-                using vec16 = ck_tile::vec_t<DTYPE_I, 16>;
+                using vec16 = ck_tile::vec_t<DTYPE_O, 16>;
                 vec16 t{};
 #pragma unroll
                 for(int i = 0; i < 16; ++i)
@@ -80,7 +88,7 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d
             }
             else if(rem >= 8)
             {
-                using vec8 = ck_tile::vec_t<DTYPE_I, 8>;
+                using vec8 = ck_tile::vec_t<DTYPE_O, 8>;
                 vec8 t{};
 #pragma unroll
                 for(int i = 0; i < 8; ++i)
@@ -92,7 +100,7 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d
             }
            else if(rem >= 4)
             {
-                using vec4 = ck_tile::vec_t<DTYPE_I, 4>;
+                using vec4 = ck_tile::vec_t<DTYPE_O, 4>;
                 vec4 t{};
 #pragma unroll
                 for(int i = 0; i < 4; ++i)
@@ -104,7 +112,7 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d
             }
             else if(rem >= 2)
             {
-                using vec2 = ck_tile::vec_t<DTYPE_I, 2>;
+                using vec2 = ck_tile::vec_t<DTYPE_O, 2>;
                 vec2 t{};
                 t[0] = v[pos + 0];
                 t[1] = v[pos + 1];
@@ -115,7 +123,7 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d
             }
             else
             {
-                using vec1 = ck_tile::vec_t<DTYPE_I, 1>;
+                using vec1 = ck_tile::vec_t<DTYPE_O, 1>;
                 vec1 t{};
                 t[0] = v[pos];
                 buffer_out.template set<vec1>(off, 0, true, t);
@@ -131,26 +139,11 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d
         vec_i x = buffer_x.template get<vec_i>(idx, 0, true);
         vec_i y = buffer_y.template get<vec_i>(idx, 0, true);
 
-        // Convert fp32→bf16 if needed, otherwise use directly
-        vec_c x_compute{};
-        vec_c y_compute{};
-
-        if constexpr(std::is_same_v<DTYPE_I, float>)
-        {
-#pragma unroll
-            for(size_t j = 0; j < VEC_SIZE_I; j++)
-            {
-                x_compute[j] = ck_tile::type_convert<DTYPE_COMPUTE>(x[j]); // fp32→bf16
-                y_compute[j] = ck_tile::type_convert<DTYPE_COMPUTE>(y[j]);
-            }
-        }
-        else
-        {
-            x_compute = x; // bf16/fp16: zero-copy
-            y_compute = y;
-        }
+        // Compute directly in native type (DTYPE_I == DTYPE_COMPUTE)
+        const vec_c& x_compute = x;
+        const vec_c& y_compute = y;
 
-        vec_i r{};
+        vec_o r{};
 
 #pragma unroll
         for(size_t j = 0; j < VEC_SIZE_I; j += 2)
@@ -168,19 +161,19 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d
                 ck_tile::fp32x2_t b = {y0, y1};
                 ck_tile::fp32x2_t c;
                 asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
-                r[j]     = ck_tile::type_convert<DTYPE_I>(c.x);
-                r[j + 1] = ck_tile::type_convert<DTYPE_I>(c.y);
+                r[j]     = ck_tile::type_convert<DTYPE_O>(c.x);
+                r[j + 1] = ck_tile::type_convert<DTYPE_O>(c.y);
             }
             else
             {
-                r[j] = ck_tile::type_convert<DTYPE_I>(ax0 * y0);
+                r[j] = ck_tile::type_convert<DTYPE_O>(ax0 * y0);
             }
         }
 
         if constexpr(VEC_SIZE_I == 1 || VEC_SIZE_I == 2 || VEC_SIZE_I == 4 || VEC_SIZE_I == 8 ||
                      VEC_SIZE_I == 16)
         {
-            buffer_out.template set<vec_i>(idx, 0, true, r);
+            buffer_out.template set<vec_o>(idx, 0, true, r);
         }
         else
         {
@@ -189,7 +182,7 @@ __global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out, // [..., d
     }
 }
 
-// Scaled activation and gating kernel template with fp32 auto-conversion support.
+// Scaled activation and gating kernel template supporting fp32/bf16/fp16.
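+// A worked trace of the segmented-store fallback above (illustrative only:
+// the dispatched VEC_SIZE_I values are powers of two, so this path is
+// defensive). With DTYPE_O = fp32 (allowed_max = 8) and a hypothetical
+// VEC_SIZE_I = 13, the loop issues stores of width 8, then 4, then 1:
+//
+//     rem = 13, pos = 0  -> vec8 store,  off += 8, rem = 5,  pos = 8
+//     rem = 5,  pos = 8  -> vec4 store,  off += 4, rem = 1,  pos = 12
+//     rem = 1,  pos = 12 -> vec1 store,  rem = 0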
 template <typename DTYPE_I, float (*ACT_FN)(const DTYPE_I&), int32_t VEC_SIZE_I>
 __global__ void scaled_act_and_mul_kernel(fp8_type* __restrict__ out,        // [..., d]
                                           const DTYPE_I* __restrict__ input, // [..., 2, d]
                                           const int d,
                                           const float scale)
 {
@@ -220,24 +213,9 @@ __global__ void scaled_act_and_mul_kernel(fp8_type* __restrict__ out, //
         vec_i x = buffer_x.template get<vec_i>(idx, 0, true);
         vec_i y = buffer_y.template get<vec_i>(idx, 0, true);
 
-        // Convert fp32→bf16 if needed, otherwise use directly
-        vec_c x_compute{};
-        vec_c y_compute{};
-
-        if constexpr(std::is_same_v<DTYPE_I, float>)
-        {
-#pragma unroll
-            for(size_t j = 0; j < VEC_SIZE_I; j++)
-            {
-                x_compute[j] = ck_tile::type_convert<DTYPE_COMPUTE>(x[j]); // fp32→bf16
-                y_compute[j] = ck_tile::type_convert<DTYPE_COMPUTE>(y[j]);
-            }
-        }
-        else
-        {
-            x_compute = x; // bf16/fp16: zero-copy
-            y_compute = y;
-        }
+        // Compute directly in native type (DTYPE_I == DTYPE_COMPUTE)
+        const vec_c& x_compute = x;
+        const vec_c& y_compute = y;
 
         for(size_t j = 0; j < VEC_SIZE_I; j += 2)
         {

From 4594cf9cff9c7c6d1ed13e071817275fe43523ab Mon Sep 17 00:00:00 2001
From: zufayu
Date: Mon, 22 Dec 2025 18:22:56 +0800
Subject: [PATCH 4/7] logic fix: out type != input type

---
 csrc/kernels/activation_kernels.cu | 161 +++++++++++++++++++----------
 op_tests/test_activation.py        |  89 ++++++++++++++--
 2 files changed, 192 insertions(+), 58 deletions(-)

diff --git a/csrc/kernels/activation_kernels.cu b/csrc/kernels/activation_kernels.cu
index a26951149e..a4ebc373dc 100644
--- a/csrc/kernels/activation_kernels.cu
+++ b/csrc/kernels/activation_kernels.cu
@@ -25,21 +25,17 @@ static constexpr int32_t max_wave_num = 8;
 template <typename T>
 using compute_type_t = T;
 
-// Type trait for output type (fp32 outputs as bf16, others keep native type)
-template <typename T>
-using output_type_t = std::conditional_t<std::is_same_v<T, float>, ck_tile::bfloat16_t, T>;
-
 namespace aiter {
-// Activation and gating kernel template supporting fp32/bf16/fp16.
-// fp32 inputs compute in fp32 but output as bf16; bf16/fp16 keep native type.
-template <typename DTYPE_I, float (*ACT_FN)(const DTYPE_I&), int32_t VEC_SIZE_I>
-__global__ void act_and_mul_kernel(DTYPE_I* __restrict__ out,         // [..., d]
+// Activation and gating kernel template with flexible input/output types.
+// DTYPE_I: input type (fp32/bf16/fp16), DTYPE_O: output type (fp32/bf16/fp16)
+// Computes in DTYPE_I native precision, converts to DTYPE_O on output.
+template <typename DTYPE_I, typename DTYPE_O, float (*ACT_FN)(const DTYPE_I&), int32_t VEC_SIZE_I>
+__global__ void act_and_mul_kernel(DTYPE_O* __restrict__ out,         // [..., d]
                                    const DTYPE_I* __restrict__ input, // [..., 2, d]
                                    const int d)
 {
     using DTYPE_COMPUTE = compute_type_t<DTYPE_I>;
-    using DTYPE_O       = output_type_t<DTYPE_I>;
 
     // CK Tile buffer addressing constraint: float supports VEC_SIZE <= 16
     static_assert(!(std::is_same_v<DTYPE_I, float> && VEC_SIZE_I > 16),
                   "float type only supports VEC_SIZE up to 16");
@@ -58,8 +54,8 @@ __global__ void act_and_mul_kernel(DTYPE_O* __restrict__ out, // [..., d
     buffer_x.init_raw();
     buffer_y.init_raw();
 
-    // Output buffer view (may have different type than input for fp32→bf16)
-    DTYPE_O* __restrict__ out_base = reinterpret_cast<DTYPE_O*>(out) + token_idx * d;
+    // Output buffer view (independent type from input)
+    DTYPE_O* __restrict__ out_base = out + token_idx * d;
     static constexpr int32_t ooba_o = 4 / sizeof(DTYPE_O);
     const int32_t oob_o = (d + ooba_o - 1) / ooba_o * ooba_o;
     auto buffer_out =
@@ -182,9 +178,10 @@ __global__ void act_and_mul_kernel(DTYPE_O* __restrict__ out, // [..., d
     }
 }
 
-// Scaled activation and gating kernel template supporting fp32/bf16/fp16.
-template <typename DTYPE_I, float (*ACT_FN)(const DTYPE_I&), int32_t VEC_SIZE_I>
-__global__ void scaled_act_and_mul_kernel(fp8_type* __restrict__ out,        // [..., d]
+// Scaled activation and gating kernel template with flexible output type.
+// DTYPE_I: input type, DTYPE_O: output type (typically fp8 for quantization)
+template <typename DTYPE_I, typename DTYPE_O, float (*ACT_FN)(const DTYPE_I&), int32_t VEC_SIZE_I>
+__global__ void scaled_act_and_mul_kernel(DTYPE_O* __restrict__ out,         // [..., d]
                                           const DTYPE_I* __restrict__ input, // [..., 2, d]
                                           const int d,
                                           const float scale)
 {
@@ -238,14 +235,14 @@ __global__ void scaled_act_and_mul_kernel(fp8_type* __restrict__ out, //
                              : "=v"(result)
                              : "v"(act_vals), "v"(y_vals), "v"(scale_vals));
 
-                out[token_idx * d + idx + j]     = ck_tile::type_convert<fp8_type>(result.x);
-                out[token_idx * d + idx + j + 1] = ck_tile::type_convert<fp8_type>(result.y);
+                out[token_idx * d + idx + j]     = ck_tile::type_convert<DTYPE_O>(result.x);
+                out[token_idx * d + idx + j + 1] = ck_tile::type_convert<DTYPE_O>(result.y);
             }
             else
             {
                 DTYPE_I x_val = ck_tile::type_convert<DTYPE_I>(x_compute[j]);
                 float r = ACT_FN(x_val) * ck_tile::type_convert<float>(y_compute[j]) * scale;
-                out[token_idx * d + idx + j] = ck_tile::type_convert<fp8_type>(r);
+                out[token_idx * d + idx + j] = ck_tile::type_convert<DTYPE_O>(r);
             }
         }
     }
@@ -310,10 +307,10 @@ static constexpr int nextPow2(unsigned int num)
     const hipStream_t stream = at::hip::getCurrentHIPStream();
 
 // Helper macro for fp32 vec_size dispatch (CK Tile only supports VEC_SIZE <= 16 for fp32)
-#define DISPATCH_FP32_VEC_SIZE_CASE(VS, KERNEL_NAME, KERNEL, ...) \
-    case VS: \
-        aiter::KERNEL_NAME<input_dtype, KERNEL<input_dtype>, VS> \
-            <<<grid, block, 0, stream>>>(__VA_ARGS__); \
+#define DISPATCH_FP32_VEC_SIZE_CASE(VS, KERNEL_NAME, KERNEL, ...) \
+    case VS: \
+        aiter::KERNEL_NAME<input_dtype, output_dtype, KERNEL<input_dtype>, VS> \
+            <<<grid, block, 0, stream>>>(__VA_ARGS__); \
         break;

#define DISPATCH_FP32_KERNEL(KERNEL_NAME, KERNEL, ...) \
    switch(vec_size) \
    { \
        DISPATCH_FP32_VEC_SIZE_CASE(16, KERNEL_NAME, KERNEL, __VA_ARGS__) \
        DISPATCH_FP32_VEC_SIZE_CASE(8, KERNEL_NAME, KERNEL, __VA_ARGS__) \
        DISPATCH_FP32_VEC_SIZE_CASE(4, KERNEL_NAME, KERNEL, __VA_ARGS__) \
        DISPATCH_FP32_VEC_SIZE_CASE(2, KERNEL_NAME, KERNEL, __VA_ARGS__) \
        DISPATCH_FP32_VEC_SIZE_CASE(1, KERNEL_NAME, KERNEL, __VA_ARGS__) \
    }

#define DISPATCH_FP32_ACT_KERNEL(KERNEL, out_ptr, in_ptr) \
    DISPATCH_FP32_KERNEL(act_and_mul_kernel, KERNEL, out_ptr, in_ptr, d)

#define DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
    DISPATCH_FP32_KERNEL(scaled_act_and_mul_kernel, KERNEL, out_ptr, in_ptr, d, inv_scale)

+// Helper macro to dispatch scaled kernel based on output type
+#define DISPATCH_OUTPUT_TYPE_SCALED(KERNEL, in_ptr, inv_scale) \
+    if(out.scalar_type() == at::ScalarType::BFloat16) \
+    { \
+        using output_dtype = ck_tile::bf16_t; \
+        auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
+        DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
+    } \
+    else if(out.scalar_type() == at::ScalarType::Half) \
+    { \
+        using output_dtype = ck_tile::fp16_t; \
+        auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
+        DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
+    } \
+    else if(out.scalar_type() == at::ScalarType::Float) \
+    { \
+        using output_dtype = ck_tile::fp32_t; \
+        auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
+        DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
+    } \
+    else \
+    { \
+        /* fp8 output */ \
+        using output_dtype = fp8_type; \
+        auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
+        DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
+    }

+// Launch activation and gating kernel with flexible input/output types
+// Input and output types are determined by the tensor dtypes passed from Python
 #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
     COMPUTE_ACTIVATION_KERNEL_PARAMS \
     if(input.scalar_type() == at::ScalarType::Float) \
     { \
+        /* fp32 input: dispatch based on output type */ \
         using input_dtype = ck_tile::fp32_t; \
         auto* in_ptr = reinterpret_cast<input_dtype*>(input.data_ptr()); \
+        if(out.scalar_type() == at::ScalarType::BFloat16) \
+        { \
+            using output_dtype = ck_tile::bf16_t; \
+            auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
+            DISPATCH_FP32_ACT_KERNEL(KERNEL, out_ptr, in_ptr) \
+        } \
+        else if(out.scalar_type() == at::ScalarType::Half) \
+        { \
+            using output_dtype = ck_tile::fp16_t; \
+            auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
+            DISPATCH_FP32_ACT_KERNEL(KERNEL, out_ptr, in_ptr) \
+        } \
+        else if(out.scalar_type() == at::ScalarType::Float) \
+        { \
+            using output_dtype = ck_tile::fp32_t; \
+            auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
+            DISPATCH_FP32_ACT_KERNEL(KERNEL, out_ptr, in_ptr) \
+        } \
+        else \
+        { \
+            TORCH_CHECK(false, "Unsupported output type for fp32 input"); \
+        } \
     } \
     else \
     { \
+        /* bf16/fp16 input: output must match input type */ \
+        TORCH_CHECK(input.scalar_type() == out.scalar_type(), \
+                    "For bf16/fp16 input, output type must match input type"); \
         AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "act_and_mul_kernel", [&] { \
             using input_dtype = typename t2ck<scalar_t>::type; \
+            using output_dtype = input_dtype; \
             AITER_DISPATCH_CASE_VEC_SIZE( \
                 vec_size, \
                 aiter:: \
                     act_and_mul_kernel<input_dtype, output_dtype, KERNEL<input_dtype>, VEC_SIZE> \
                 <<<grid, block, 0, stream>>>(reinterpret_cast<output_dtype*>(out.data_ptr()), \
                                              reinterpret_cast<input_dtype*>(input.data_ptr()), \
                                              d);) \
         }); \
     }

+// Launch scaled activation and gating kernel with flexible input/output types
 #define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \
     COMPUTE_ACTIVATION_KERNEL_PARAMS \
     if(input.scalar_type() == at::ScalarType::Float) \
     { \
-        /* fp32: limited VEC_SIZE due to CK Tile constraint */ \
+        /* fp32 input: dispatch based on output type (fp8/bf16/fp16/fp32) */ \
         using input_dtype = ck_tile::fp32_t; \
-        auto* out_ptr = reinterpret_cast<fp8_type*>(out.data_ptr()); \
         auto* in_ptr = reinterpret_cast<input_dtype*>(input.data_ptr()); \
         float inv_scale = 1.0f / (*scale.data_ptr<float>()); \
-        DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
+        DISPATCH_OUTPUT_TYPE_SCALED(KERNEL, in_ptr, inv_scale) \
     } \
     else \
     { \
-        /* bf16/fp16: full VEC_SIZE support */ \
+        /* bf16/fp16 input: output typically fp8 for quantization */ \
         AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \
-            using input_dtype = typename t2ck<scalar_t>::type; \
+            using input_dtype  = typename t2ck<scalar_t>::type; \
+            using output_dtype = fp8_type; \
             AITER_DISPATCH_CASE_VEC_SIZE( \
                 vec_size, \
-                aiter::scaled_act_and_mul_kernel<input_dtype, KERNEL<input_dtype>, VEC_SIZE> \
-                <<<grid, block, 0, stream>>>(reinterpret_cast<fp8_type*>(out.data_ptr()), \
+                aiter::scaled_act_and_mul_kernel<input_dtype, output_dtype, KERNEL<input_dtype>, \
+                                                 VEC_SIZE> \
+                <<<grid, block, 0, stream>>>(reinterpret_cast<output_dtype*>(out.data_ptr()), \
                                              reinterpret_cast<input_dtype*>(input.data_ptr()), \
                                              d, \
                                              1.0f / (*scale.data_ptr<float>()));) \
         }); \
     }
 
 namespace aiter {
 
+// Flexible type conversion:
+// - fp32 input can output as fp32/bf16/fp16 (determined by out.dtype)
+// - bf16 input must output as bf16
+// - fp16 input must output as fp16
 void silu_and_mul(torch::Tensor& out,   // [..., d]
                   torch::Tensor& input) // [..., 2 * d]
 {
diff --git a/op_tests/test_activation.py b/op_tests/test_activation.py
index e373c9a750..722d2a3c56 100644
--- a/op_tests/test_activation.py
+++ b/op_tests/test_activation.py
@@ -22,13 +22,21 @@ def torch_silu_and_mul(input: torch.Tensor) -> torch.Tensor:
 
 
 @benchmark()
-def test_scaled_silu_and_mul(m, n, dtype):
+def test_scaled_silu_and_mul(m, n, dtype, output_dtype=None):
+    """
+    Test scaled_silu_and_mul with flexible input/output types.
+    If output_dtype is None, defaults to fp8 for quantization.
+    """
     ret = {}
     input = torch.randn(m, n, dtype=dtype, device="cuda")
     scale = torch.max(input).to(torch.float32)
-    out = torch.empty((m, n // 2), dtype=dtypes.fp8, device="cuda")
+    out_dtype = output_dtype if output_dtype is not None else dtypes.fp8
+    out = torch.empty((m, n // 2), dtype=out_dtype, device="cuda")
 
-    ref = torch_scaled_silu_and_mul(input, scale)
+    # Reference: compute, scale, convert to output dtype
+    d = input.shape[-1] // 2
+    x, y = input.split([d, d], dim=-1)
+    ref = (F.silu(x) * y / scale).to(out_dtype)
 
     _, us_aiter = run_perftest(
         aiter.scaled_silu_and_mul,
@@ -39,6 +47,16 @@ def test_scaled_silu_and_mul(m, n, dtype):
 
     # Check if the results are close
     err = checkAllclose(ref.to(torch.float), out.to(torch.float))
+
+    # Record input/output types for clarity
+    dtype_map = {
+        torch.float32: "fp32",
+        torch.float16: "fp16",
+        torch.bfloat16: "bf16",
+        dtypes.fp8: "fp8",
+    }
+    ret["input_dtype"] = dtype_map.get(dtype, str(dtype))
+    ret["output_dtype"] = dtype_map.get(out_dtype, str(out_dtype))
     ret["us"] = us_aiter
     ret["TB/s"] = (input.nbytes + out.nbytes) / us_aiter / 1e6
     ret["RD TB/s"] = (input.nbytes) / us_aiter / 1e6
@@ -48,12 +66,20 @@ def test_scaled_silu_and_mul(m, n, dtype):
 
 
 @benchmark()
-def test_silu_and_mul(m, n, dtype):
+def test_silu_and_mul(m, n, dtype, output_dtype=None):
+    """
+    Test silu_and_mul with flexible input/output types.
+    If output_dtype is None, output matches input dtype.
+    """
     ret = {}
     input = torch.randn(m, n, dtype=dtype, device="cuda")
-    out = torch.empty((m, n // 2), dtype=dtype, device="cuda")
+    out_dtype = output_dtype if output_dtype is not None else dtype
+    out = torch.empty((m, n // 2), dtype=out_dtype, device="cuda")
 
+    # Reference: compute in input dtype, convert to output dtype if needed
     ref = torch_silu_and_mul(input)
+    if output_dtype is not None:
+        ref = ref.to(output_dtype)
 
     _, us_aiter = run_perftest(
         aiter.silu_and_mul,
         out,
         input,
     )
 
+    # Check if the results are close
+    err = checkAllclose(ref, out)
+
+    # Record input/output types for clarity
+    dtype_map = {torch.float32: "fp32", torch.float16: "fp16", torch.bfloat16: "bf16"}
+    ret["input_dtype"] = dtype_map.get(dtype, str(dtype))
+    ret["output_dtype"] = dtype_map.get(out_dtype, str(out_dtype))
+    ret["us"] = us_aiter
+    ret["TB/s"] = (input.nbytes + out.nbytes) / us_aiter / 1e6
+    ret["RD TB/s"] = (input.nbytes) / us_aiter / 1e6
+    ret["WR TB/s"] = (out.nbytes) / us_aiter / 1e6
+    ret["err"] = err
+    return ret
+
+
+@benchmark()
+def test_scaled_silu_and_mul_mixed_dtype(m, n, input_dtype, output_dtype):
+    """Test fp32 input with fp16/bf16 output for scaled activation"""
+    ret = {}
+    input = torch.randn(m, n, dtype=input_dtype, device="cuda")
+    scale = torch.max(input).to(torch.float32)
+    out = torch.empty((m, n // 2), dtype=output_dtype, device="cuda")
+
+    # Reference: compute in fp32, scale, convert to output dtype
+    d = input.shape[-1] // 2
+    x, y = input.split([d, d], dim=-1)
+    ref = (F.silu(x) * y / scale).to(output_dtype)
+
+    _, us_aiter = run_perftest(
+        aiter.scaled_silu_and_mul,
+        out,
+        input,
+        scale,
+    )
+
     # Check if the results are close
     err = checkAllclose(ref, out)
     ret["us"] = us_aiter
@@ -88,7 +149,7 @@ def test_silu_and_mul(m, n, dtype):
     const=None,
     default=None,
     help="""Data type.
-    e.g.: -d bf16""",
+    e.g.: -d bf16, -d fp32""",
 )
 parser.add_argument(
     "-m",
@@ -120,19 +181,35 @@ def test_silu_and_mul(m, n, dtype):
     l_n = [args.n]
 
 df = []
+# Standard same-dtype tests
 for dtype in l_dtype:
     for m in l_m:
         for n in l_n:
             ret = test_scaled_silu_and_mul(m, n, dtype)
             df.append(ret)
+# Add fp32 input with fp16/bf16 output (bandwidth optimization)
+for output_dtype in [torch.float16, torch.bfloat16]:
+    for m in l_m:
+        for n in l_n:
+            ret = test_scaled_silu_and_mul(
+                m, n, torch.float32, output_dtype=output_dtype
+            )
+            df.append(ret)
 df = pd.DataFrame(df)
 aiter.logger.info(f"scaled_silu_and_mul summary:\n{df}")
 
 df = []
+# Standard same-dtype tests
 for dtype in l_dtype:
     for m in l_m:
         for n in l_n:
             ret = test_silu_and_mul(m, n, dtype)
             df.append(ret)
+# Add fp32 input with fp16/bf16 output (bandwidth optimization)
+for output_dtype in [torch.float16, torch.bfloat16]:
+    for m in l_m:
+        for n in l_n:
+            ret = test_silu_and_mul(m, n, torch.float32, output_dtype=output_dtype)
+            df.append(ret)
 df = pd.DataFrame(df)
 aiter.logger.info(f"silu_and_mul summary:\n{df}")

From 5758cf0b4316e525b61fdd78b44938a4038061fe Mon Sep 17 00:00:00 2001
From: zufayu
Date: Tue, 23 Dec 2025 10:54:36 +0800
Subject: [PATCH 5/7] bug fix

---
 csrc/kernels/activation_kernels.cu | 56 +++++++++--------------------
 op_tests/test_activation.py        | 44 +++++++++++++----------
 2 files changed, 41 insertions(+), 59 deletions(-)

diff --git a/csrc/kernels/activation_kernels.cu b/csrc/kernels/activation_kernels.cu
index a4ebc373dc..aeac28c94a 100644
--- a/csrc/kernels/activation_kernels.cu
+++ b/csrc/kernels/activation_kernels.cu
@@ -22,21 +22,17 @@ static constexpr int32_t max_vec_size = 8;
 static constexpr int32_t max_wave_num = 8;
 
 // Type trait for computation type (all compute in native type)
-template <typename T>
-using compute_type_t = T;
 
 namespace aiter {
 // Activation and gating kernel template with flexible input/output types.
 // DTYPE_I: input type (fp32/bf16/fp16), DTYPE_O: output type (fp32/bf16/fp16)
-// Computes in DTYPE_I native precision, converts to DTYPE_O on output.
+// Computes in float, converts to DTYPE_O on output.
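+// A scalar sketch of what the packed multiply in this kernel computes per
+// lane (the inline asm below remains the authoritative form; this is only
+// its meaning):
+//
+//     c.x = a.x * b.x;   // lane 0: ACT_FN(x0) * y0
+//     c.y = a.y * b.y;   // lane 1: ACT_FN(x1) * y1
+//
+// One v_pk_mul_f32 issues both lane multiplies in a single instruction on
+// GPUs with packed fp32 math.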
 template <typename DTYPE_I, typename DTYPE_O, float (*ACT_FN)(const DTYPE_I&), int32_t VEC_SIZE_I>
 __global__ void act_and_mul_kernel(DTYPE_O* __restrict__ out,         // [..., d]
                                    const DTYPE_I* __restrict__ input, // [..., 2, d]
                                    const int d)
 {
-    using DTYPE_COMPUTE = compute_type_t<DTYPE_I>;
-
     // CK Tile buffer addressing constraint: float supports VEC_SIZE <= 16
     static_assert(!(std::is_same_v<DTYPE_I, float> && VEC_SIZE_I > 16),
                   "float type only supports VEC_SIZE up to 16");
@@ -45,7 +41,6 @@ __global__ void act_and_mul_kernel(DTYPE_O* __restrict__ out, // [..., d
     auto const* ptr_x = (input + token_idx * 2 * d);
     auto const* ptr_y = (input + token_idx * 2 * d + d);
     using vec_i = ck_tile::vec_t<DTYPE_I, VEC_SIZE_I>;
-    using vec_c = ck_tile::vec_t<DTYPE_COMPUTE, VEC_SIZE_I>;
     using vec_o = ck_tile::vec_t<DTYPE_O, VEC_SIZE_I>;
     static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I);
     const int32_t oob_i = (d + ooba_i - 1) / ooba_i * ooba_i;
@@ -135,24 +130,20 @@ __global__ void act_and_mul_kernel(DTYPE_O* __restrict__ out, // [..., d
         vec_i x = buffer_x.template get<vec_i>(idx, 0, true);
         vec_i y = buffer_y.template get<vec_i>(idx, 0, true);
 
-        // Compute directly in native type (DTYPE_I == DTYPE_COMPUTE)
-        const vec_c& x_compute = x;
-        const vec_c& y_compute = y;
-
         vec_o r{};
 
 #pragma unroll
         for(size_t j = 0; j < VEC_SIZE_I; j += 2)
         {
             // Call ACT_FN with appropriate type conversion
-            DTYPE_I x_val0 = ck_tile::type_convert<DTYPE_I>(x_compute[j]);
+            DTYPE_I x_val0 = x[j];
             float ax0      = ACT_FN(x_val0);
-            float y0       = ck_tile::type_convert<float>(y_compute[j]);
+            float y0       = ck_tile::type_convert<float>(y[j]);
             if(j + 1 < VEC_SIZE_I)
             {
-                DTYPE_I x_val1 = ck_tile::type_convert<DTYPE_I>(x_compute[j + 1]);
+                DTYPE_I x_val1 = x[j + 1];
                 float ax1      = ACT_FN(x_val1);
-                float y1       = ck_tile::type_convert<float>(y_compute[j + 1]);
+                float y1       = ck_tile::type_convert<float>(y[j + 1]);
                 ck_tile::fp32x2_t a = {ax0, ax1};
                 ck_tile::fp32x2_t b = {y0, y1};
                 ck_tile::fp32x2_t c;
@@ -186,8 +177,6 @@ __global__ void scaled_act_and_mul_kernel(DTYPE_O* __restrict__ out, //
                                           const int d,
                                           const float scale)
 {
-    using DTYPE_COMPUTE = compute_type_t<DTYPE_I>;
-
     // CK Tile buffer addressing constraint: float supports VEC_SIZE <= 16
     static_assert(!(std::is_same_v<DTYPE_I, float> && VEC_SIZE_I > 16),
                   "float type only supports VEC_SIZE up to 16");
@@ -196,7 +185,6 @@ __global__ void scaled_act_and_mul_kernel(DTYPE_O* __restrict__ out, //
     auto const* ptr_x = (input + token_idx * 2 * d);
     auto const* ptr_y = (input + token_idx * 2 * d + d);
     using vec_i = ck_tile::vec_t<DTYPE_I, VEC_SIZE_I>;
-    using vec_c = ck_tile::vec_t<DTYPE_COMPUTE, VEC_SIZE_I>;
 
     static constexpr int32_t ooba_i = 4 / sizeof(DTYPE_I);
     const int32_t oob_i = (d + ooba_i - 1) / ooba_i * ooba_i;
@@ -210,20 +198,16 @@ __global__ void scaled_act_and_mul_kernel(DTYPE_O* __restrict__ out, //
         vec_i x = buffer_x.template get<vec_i>(idx, 0, true);
         vec_i y = buffer_y.template get<vec_i>(idx, 0, true);
 
-        // Compute directly in native type (DTYPE_I == DTYPE_COMPUTE)
-        const vec_c& x_compute = x;
-        const vec_c& y_compute = y;
-
         for(size_t j = 0; j < VEC_SIZE_I; j += 2)
         {
             if(j + 1 < VEC_SIZE_I)
             {
-                DTYPE_I x_val0 = ck_tile::type_convert<DTYPE_I>(x_compute[j]);
-                DTYPE_I x_val1 = ck_tile::type_convert<DTYPE_I>(x_compute[j + 1]);
+                DTYPE_I x_val0 = x[j];
+                DTYPE_I x_val1 = x[j + 1];
                 float act_x0   = ACT_FN(x_val0);
                 float act_x1   = ACT_FN(x_val1);
-                float y0       = ck_tile::type_convert<float>(y_compute[j]);
-                float y1       = ck_tile::type_convert<float>(y_compute[j + 1]);
+                float y0       = ck_tile::type_convert<float>(y[j]);
+                float y1       = ck_tile::type_convert<float>(y[j + 1]);
 
                 float2 act_vals = {act_x0, act_x1};
                 float2 y_vals   = {y0, y1};
@@ -240,8 +224,8 @@ __global__ void scaled_act_and_mul_kernel(DTYPE_O* __restrict__ out, //
             }
             else
             {
-                DTYPE_I x_val = ck_tile::type_convert<DTYPE_I>(x_compute[j]);
-                float r = ACT_FN(x_val) * ck_tile::type_convert<float>(y_compute[j]) * scale;
+                DTYPE_I x_val = x[j];
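+                // Tail element when VEC_SIZE_I is odd: scalar multiply, then scale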
+                float r = ACT_FN(x_val) * ck_tile::type_convert<float>(y[j]) * scale;
                 out[token_idx * d + idx + j] = ck_tile::type_convert<DTYPE_O>(r);
             }
         }
@@ -420,20 +404,12 @@ static constexpr int nextPow2(unsigned int num)
     } \
     else \
     { \
-        /* bf16/fp16 input: output typically fp8 for quantization */ \
+        /* bf16/fp16 input: dispatch based on output type (fp8/bf16/fp16/fp32) */ \
         AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \
-            using input_dtype  = typename t2ck<scalar_t>::type; \
-            using output_dtype = fp8_type; \
-            AITER_DISPATCH_CASE_VEC_SIZE( \
-                vec_size, \
-                aiter::scaled_act_and_mul_kernel<input_dtype, output_dtype, KERNEL<input_dtype>, \
-                                                 VEC_SIZE> \
-                <<<grid, block, 0, stream>>>(reinterpret_cast<output_dtype*>(out.data_ptr()), \
-                                             reinterpret_cast<input_dtype*>(input.data_ptr()), \
-                                             d, \
-                                             1.0f / (*scale.data_ptr<float>()));) \
+            using input_dtype = typename t2ck<scalar_t>::type; \
+            auto* in_ptr = reinterpret_cast<input_dtype*>(input.data_ptr()); \
+            float inv_scale = 1.0f / (*scale.data_ptr<float>()); \
+            DISPATCH_OUTPUT_TYPE_SCALED(KERNEL, in_ptr, inv_scale) \
         }); \
     }
diff --git a/op_tests/test_activation.py b/op_tests/test_activation.py
index 722d2a3c56..fc263ada25 100644
--- a/op_tests/test_activation.py
+++ b/op_tests/test_activation.py
@@ -7,13 +7,6 @@ import argparse
 
 
-def torch_scaled_silu_and_mul(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
-    d = input.shape[-1] // 2
-    x, y = input.split([d, d], dim=-1)
-    out = F.silu(x) * y / scale
-    return out.to(dtypes.fp8)
-
-
 def torch_silu_and_mul(input: torch.Tensor) -> torch.Tensor:
     d = input.shape[-1] // 2
     x, y = input.split([d, d], dim=-1)
@@ -57,6 +50,8 @@ def test_scaled_silu_and_mul(m, n, dtype, output_dtype=None):
     }
     ret["input_dtype"] = dtype_map.get(dtype, str(dtype))
     ret["output_dtype"] = dtype_map.get(out_dtype, str(out_dtype))
+    ret["M"] = m
+    ret["N"] = n
     ret["us"] = us_aiter
     ret["TB/s"] = (input.nbytes + out.nbytes) / us_aiter / 1e6
     ret["RD TB/s"] = (input.nbytes) / us_aiter / 1e6
@@ -71,7 +66,6 @@ def test_silu_and_mul(m, n, dtype, output_dtype=None):
     Test silu_and_mul with flexible input/output types.
     If output_dtype is None, output matches input dtype.
     """
-    ret = {}
     input = torch.randn(m, n, dtype=dtype, device="cuda")
     out_dtype = output_dtype if output_dtype is not None else dtype
     out = torch.empty((m, n // 2), dtype=out_dtype, device="cuda")
@@ -92,8 +86,11 @@ def test_silu_and_mul(m, n, dtype, output_dtype=None):
 
     # Record input/output types for clarity
     dtype_map = {torch.float32: "fp32", torch.float16: "fp16", torch.bfloat16: "bf16"}
+    ret = {}
     ret["input_dtype"] = dtype_map.get(dtype, str(dtype))
     ret["output_dtype"] = dtype_map.get(out_dtype, str(out_dtype))
+    ret["M"] = m
+    ret["N"] = n
     ret["us"] = us_aiter
     ret["TB/s"] = (input.nbytes + out.nbytes) / us_aiter / 1e6
     ret["RD TB/s"] = (input.nbytes) / us_aiter / 1e6
@@ -105,7 +102,6 @@ def test_silu_and_mul(m, n, dtype, output_dtype=None):
 @benchmark()
 def test_scaled_silu_and_mul_mixed_dtype(m, n, input_dtype, output_dtype):
     """Test fp32 input with fp16/bf16 output for scaled activation"""
-    ret = {}
     input = torch.randn(m, n, dtype=input_dtype, device="cuda")
     scale = torch.max(input).to(torch.float32)
     out = torch.empty((m, n // 2), dtype=output_dtype, device="cuda")
@@ -122,8 +118,18 @@ def test_scaled_silu_and_mul_mixed_dtype(m, n, input_dtype, output_dtype):
     )
 
-    # Check if the results are close
-    err = checkAllclose(ref, out)
+    err = checkAllclose(ref.to(torch.float), out.to(torch.float))
+    dtype_map = {
+        torch.float32: "fp32",
+        torch.float16: "fp16",
+        torch.bfloat16: "bf16",
+        dtypes.fp8: "fp8",
+    }
+    ret = {}
+    ret["input_dtype"] = dtype_map.get(input_dtype, str(input_dtype))
+    ret["output_dtype"] = dtype_map.get(output_dtype, str(output_dtype))
+    ret["M"] = m
+    ret["N"] = n
     ret["us"] = us_aiter
     ret["TB/s"] = (input.nbytes + out.nbytes) / us_aiter / 1e6
     ret["RD TB/s"] = (input.nbytes) / us_aiter / 1e6
@@ -187,20 +193,17 @@ def test_scaled_silu_and_mul_mixed_dtype(m, n, input_dtype, output_dtype):
         for n in l_n:
             ret = test_scaled_silu_and_mul(m, n, dtype)
             df.append(ret)
-# Add fp32 input with fp16/bf16 output (bandwidth optimization)
-for output_dtype in [torch.float16, torch.bfloat16]:
-    for m in l_m:
-        for n in l_n:
-            ret = test_scaled_silu_and_mul(
-                m, n, torch.float32, output_dtype=output_dtype
-            )
-            df.append(ret)
 df = pd.DataFrame(df)
+df = df[
+    ["M", "N", "input_dtype", "output_dtype", "us", "TB/s", "RD TB/s", "WR TB/s", "err"]
+]
 aiter.logger.info(f"scaled_silu_and_mul summary:\n{df}")
 
 df = []
 # Standard same-dtype tests
 for dtype in l_dtype:
+    if dtype == torch.float32:
+        continue
     for m in l_m:
         for n in l_n:
             ret = test_silu_and_mul(m, n, dtype)
             df.append(ret)
 # Add fp32 input with fp16/bf16 output (bandwidth optimization)
 for output_dtype in [torch.float16, torch.bfloat16]:
     for m in l_m:
         for n in l_n:
             ret = test_silu_and_mul(m, n, torch.float32, output_dtype=output_dtype)
             df.append(ret)
 df = pd.DataFrame(df)
+df = df[
+    ["M", "N", "input_dtype", "output_dtype", "us", "TB/s", "RD TB/s", "WR TB/s", "err"]
+]
 aiter.logger.info(f"silu_and_mul summary:\n{df}")

From 3ba2d828ea2c6ffaf54f3cf91e83e31a02b3eca3 Mon Sep 17 00:00:00 2001
From: zufayu
Date: Tue, 23 Dec 2025 11:10:46 +0800
Subject: [PATCH 6/7] format code

---
 csrc/kernels/activation_kernels.cu | 45 +++++++++++++-----------------
 1 file changed, 19 insertions(+), 26 deletions(-)

diff --git a/csrc/kernels/activation_kernels.cu b/csrc/kernels/activation_kernels.cu
index aeac28c94a..f5dc3edcfa 100644
--- a/csrc/kernels/activation_kernels.cu
+++ b/csrc/kernels/activation_kernels.cu
@@ -313,32 +313,25 @@ static constexpr int nextPow2(unsigned int num)
 #define DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
     DISPATCH_FP32_KERNEL(scaled_act_and_mul_kernel, KERNEL, out_ptr, in_ptr, d, inv_scale)
 
-// Helper macro to dispatch scaled kernel based on output type
-#define DISPATCH_OUTPUT_TYPE_SCALED(KERNEL, in_ptr, inv_scale) \
-    if(out.scalar_type() == at::ScalarType::BFloat16) \
-    { \
-        using output_dtype = ck_tile::bf16_t; \
-        auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
-        DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
-    } \
-    else if(out.scalar_type() == at::ScalarType::Half) \
-    { \
-        using output_dtype = ck_tile::fp16_t; \
-        auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
-        DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
-    } \
-    else if(out.scalar_type() == at::ScalarType::Float) \
-    { \
-        using output_dtype = ck_tile::fp32_t; \
-        auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
-        DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
-    } \
-    else \
-    { \
-        /* fp8 output */ \
-        using output_dtype = fp8_type; \
-        auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
-        DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
+// Helper macro to dispatch scaled kernel with restricted output types (fp8 or int8)
+#define DISPATCH_OUTPUT_TYPE_SCALED(KERNEL, in_ptr, inv_scale) \
+    if(out.scalar_type() == at::ScalarType::Float8_e4m3fn || \
+       out.scalar_type() == at::ScalarType::Float8_e4m3fnuz || \
+       out.scalar_type() == at::ScalarType::Float8_e5m2) \
+    { \
+        using output_dtype = fp8_type; \
+        auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
+        DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
+    } \
+    else if(out.scalar_type() == at::ScalarType::Char) \
+    { \
+        using output_dtype = ck_tile::int8_t; \
+        auto* out_ptr = reinterpret_cast<output_dtype*>(out.data_ptr()); \
+        DISPATCH_FP32_SCALED_ACT_KERNEL(KERNEL, out_ptr, in_ptr, inv_scale) \
+    } \
+    else \
+    { \
+        TORCH_CHECK(false, "scaled_act_and_mul only supports fp8 or int8 outputs"); \
     }

From 449504c35e4f8ee9079ea41fc79984afdc7e414a Mon Sep 17 00:00:00 2001
From: chenjun
Date: Tue, 23 Dec 2025 09:40:24 +0000
Subject: [PATCH 7/7] remove dtype convert before act_and_mul in fused_moe

---
 aiter/fused_moe.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py
index 3cbbbbf3c8..05a4583c4c 100644
--- a/aiter/fused_moe.py
+++ b/aiter/fused_moe.py
@@ -1132,9 +1132,9 @@ def asm_stage1(
     )
     if ksplit > 0:
         if activation == ActivationType.Silu:
-            aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32).to(dtype))
+            aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32))
         else:
-            aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32).to(dtype))
+            aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32))
 
     return out
 
@@ -1446,9 +1446,9 @@ def ck_moe_stage1(
     )
     if splitk > 1:
         if activation == ActivationType.Silu:
-            aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32).to(out.dtype))
+            aiter.silu_and_mul(out, tmp_out.view(dtypes.fp32))
        else:
-            aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32).to(out.dtype))
+            aiter.gelu_and_mul(out, tmp_out.view(dtypes.fp32))
 
     return out
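
With the series applied, the host-visible contract is: silu_and_mul/gelu_and_mul
accept an fp32 input with fp32/bf16/fp16 output, while the scaled variants keep
fp8 (or int8) outputs. A minimal usage sketch (assumptions: aiter built from
this branch, a ROCm device available; shapes follow op_tests/test_activation.py):

    import torch
    import aiter

    x = torch.randn(32, 8192, dtype=torch.float32, device="cuda")     # [..., 2*d]
    out = torch.empty(32, 4096, dtype=torch.bfloat16, device="cuda")  # [..., d]
    aiter.silu_and_mul(out, x)  # fp32 in, bf16 out: the .to(dtype) copy removed
                                # from fused_moe.py above is no longer needed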