From ba9ff920cceb6e2e8f1cd101087102da2fcd2be3 Mon Sep 17 00:00:00 2001
From: Clement Lin
Date: Fri, 5 Dec 2025 18:17:32 +0800
Subject: [PATCH 01/23] Add radix-base selection

---
 aiter/jit/optCompilerConfig.json     |   3 +-
 aiter/ops/topk_plain.py              |   7 +-
 csrc/include/opus/opus.hpp           |   2 +-
 csrc/include/rocm_ops.hpp            |   7 +-
 csrc/include/topk_plain.h            |   9 +-
 csrc/kernels/topk_per_row_kernels.cu | 163 +++++++--
 csrc/kernels/topk_plain_kernels.cu   | 485 ++++++++++++++++++++++++---
 op_tests/test_topk_plain.py          | 128 ++++---
 8 files changed, 655 insertions(+), 149 deletions(-)

diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json
index 6df6ed4da9..7efe53fc46 100755
--- a/aiter/jit/optCompilerConfig.json
+++ b/aiter/jit/optCompilerConfig.json
@@ -1074,7 +1074,8 @@
     "module_topk_plain": {
         "srcs": [
             "f'{AITER_CSRC_DIR}/pybind/topk_plain_pybind.cu'",
-            "f'{AITER_CSRC_DIR}/kernels/topk_plain_kernels.cu'"
+            "f'{AITER_CSRC_DIR}/kernels/topk_plain_kernels.cu'",
+            "f'{AITER_CSRC_DIR}/kernels/topk_per_row_kernels.cu'"
         ],
         "flags_extra_cc": [],
         "flags_extra_hip": [],
diff --git a/aiter/ops/topk_plain.py b/aiter/ops/topk_plain.py
index dea2c654b7..cd768b01e9 100644
--- a/aiter/ops/topk_plain.py
+++ b/aiter/ops/topk_plain.py
@@ -13,7 +13,12 @@ def topk_plain(
     x: torch.Tensor,
     topk_ids: torch.Tensor,
+    topk_out: torch.Tensor,
     topk: int,
-    largest: bool,
+    largest: bool = True,
+    rowStarts: torch.Tensor = None,
+    rowEnds: torch.Tensor = None,
+    stride0: int = -1,
+    stride1: int = 1,
 ) -> None:
     pass
diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp
index bc3631e2a2..f2b96e4483 100644
--- a/csrc/include/opus/opus.hpp
+++ b/csrc/include/opus/opus.hpp
@@ -907,7 +907,7 @@ template<> OPUS_D float min(const float&a, const float&b) { return
 template<typename T> OPUS_D T med3(const T&a, const T&b, const T&c) { auto max_0 = max(a, b); auto min_0 = min(a, b); return min(max_0, max(min_0, c)); }
 template<> OPUS_D float med3(const float&a, const float&b, const float&c) { return __builtin_amdgcn_fmed3f(a, b, c); }
-template<> OPUS_D __fp16 med3<__fp16>(const __fp16&a, const __fp16&b, const __fp16&c) { return __builtin_amdgcn_fmed3h(a, b, c); }
+template<> OPUS_D _Float16 med3<_Float16>(const _Float16&a, const _Float16&b, const _Float16&c) { return __builtin_amdgcn_fmed3h(a, b, c); }
 /////////////////////////////////////////////////////////////////////////////////////////////////////////
 // buffer load/store related
 OPUS_D constexpr auto buffer_default_config() {
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index c21251f29a..9da94a7f9e 100644
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -1470,5 +1470,10 @@ namespace py = pybind11;
           &topk_plain,                              \
           py::arg("values"),                        \
           py::arg("topk_ids"),                      \
+          py::arg("topk_out"),                      \
           py::arg("topk"),                          \
-          py::arg("largest"));
+          py::arg("largest") = true,                \
+          py::arg("rowStarts") = torch::Tensor(),   \
+          py::arg("rowEnds") = torch::Tensor(),     \
+          py::arg("stride0") = -1,                  \
+          py::arg("stride1") = 1);
\ No newline at end of file
diff --git a/csrc/include/topk_plain.h b/csrc/include/topk_plain.h
index 5a658e491d..087c157196 100644
--- a/csrc/include/topk_plain.h
+++ b/csrc/include/topk_plain.h
@@ -6,5 +6,10 @@
 void topk_plain(torch::Tensor& values,
                 torch::Tensor& topk_ids,
-                int topk_num,
-                bool largest);
+                torch::Tensor& topk_out,
+                int topk,
+                bool largest = true,
+                torch::Tensor rowStarts = torch::Tensor(),
+                torch::Tensor rowEnds = torch::Tensor(),
+                int64_t stride0 = -1,
+                int64_t stride1 = 1);
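For reviewers, a minimal usage sketch of the extended Python entry point. This is illustrative only: the import path, tensor shapes, and dtypes are assumptions (the stub above is bound to the JIT-compiled module at runtime); the argument list mirrors the pybind signature in rocm_ops.hpp.

    import torch
    from aiter.ops.topk_plain import topk_plain

    # Hypothetical setup: 4 rows of 1024 candidates each, top-8 per row.
    x        = torch.randn(4, 1024, device="cuda", dtype=torch.float32)
    topk_ids = torch.empty(4, 8, device="cuda", dtype=torch.int32)
    topk_out = torch.empty(4, 8, device="cuda", dtype=torch.float32)

    # Dense rows: rowStarts/rowEnds keep their empty-tensor defaults and the
    # stride defaults (stride0=-1, stride1=1) are passed through unchanged.
    topk_plain(x, topk_ids, topk_out, topk=8, largest=True)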
diff --git a/csrc/kernels/topk_per_row_kernels.cu b/csrc/kernels/topk_per_row_kernels.cu
index 1b50ead6c9..c7bca0f4ca 100644
--- a/csrc/kernels/topk_per_row_kernels.cu
+++ b/csrc/kernels/topk_per_row_kernels.cu
@@ -414,7 +414,8 @@ __device__ void filter_and_histogram(T const* in_buf,
                                      IdxT* histogram,
                                      bool select_min,
                                      int pass,
-                                     bool early_stop)
+                                     bool early_stop,
+                                     IdxT k)
 {
     constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
     __shared__ IdxT histogram_smem[num_buckets];
@@ -464,7 +465,8 @@ __device__ void filter_and_histogram(T const* in_buf,
                        kth_value_bits,
                        p_filter_cnt,
                        p_out_cnt,
-                       early_stop](T value, IdxT i, int&, int&, bool) {
+                       early_stop,
+                       k](T value, IdxT i, int&, int&, bool) {
            const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit)
                                       << previous_start_bit;
            if(previous_bits == kth_value_bits)
@@ -472,12 +474,14 @@ __device__ void filter_and_histogram(T const* in_buf,
                if(early_stop)
                {
                    IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
-                   if(WRITE_TOPK_VALUES)
+                   if (pos < k)
                    {
-                       out[pos] = value;
+                       if(WRITE_TOPK_VALUES)
+                       {
+                           out[pos] = value;
+                       }
+                       out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
                    }
-
-                   out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
                }
                else
                {
@@ -502,11 +506,14 @@ __device__ void filter_and_histogram(T const* in_buf,
            else if((out_buf || early_stop) && previous_bits < kth_value_bits)
            {
                IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
-               if(WRITE_TOPK_VALUES)
+               if (pos < k)
                {
-                   out[pos] = value;
+                   if(WRITE_TOPK_VALUES)
+                   {
+                       out[pos] = value;
+                   }
+                   out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
                }
-               out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
            }
        };
        vectorized_process(static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x,
@@ -652,14 +659,17 @@ __device__ void last_filter(T const* in_buf,
        if(bits < kth_value_bits)
        {
            IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
-           if(WRITE_TOPK_VALUES)
+           if (pos < k)
            {
-               out[pos] = value;
+               if(WRITE_TOPK_VALUES)
+               {
+                   out[pos] = value;
+               }
+               // For one-block version, `in_idx_buf` could be nullptr at pass 0.
+               // For non one-block version, if writing has been skipped, `in_idx_buf`
+               // could be nullptr if `in_buf` is `in`
+               out_idx[pos] = in_idx_buf[i];
            }
-           // For one-block version, `in_idx_buf` could be nullptr at pass 0.
-           // For non one-block version, if writing has been skipped, `in_idx_buf`
-           // could be nullptr if `in_buf` is `in`
-           out_idx[pos] = in_idx_buf[i];
        }
        else if(bits == kth_value_bits)
        {
@@ -691,14 +701,17 @@ __device__ void last_filter(T const* in_buf,
        if(bits < kth_value_bits)
        {
            IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
-           if(WRITE_TOPK_VALUES)
+           if (pos < k)
            {
-               out[pos] = value;
+               if(WRITE_TOPK_VALUES)
+               {
+                   out[pos] = value;
+               }
+               // For one-block version, `in_idx_buf` could be nullptr at pass 0.
+               // For non one-block version, if writing has been skipped, `in_idx_buf`
+               // could be nullptr if `in_buf` is `in`
+               out_idx[pos] = i;
            }
-           // For one-block version, `in_idx_buf` could be nullptr at pass 0.
-           // For non one-block version, if writing has been skipped, `in_idx_buf`
-           // could be nullptr if `in_buf` is `in`
-           out_idx[pos] = i;
        }
        else if(bits == kth_value_bits)
        {
@@ -782,11 +795,14 @@ __global__ void last_filter_kernel(T const* in,
        if(bits < kth_value_bits)
        {
            IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
-           if(WRITE_TOPK_VALUES)
+           if (pos < k)
            {
-               out[pos] = value;
+               if(WRITE_TOPK_VALUES)
+               {
+                   out[pos] = value;
+               }
+               out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
            }
-           out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i;
        }
        else if(bits == kth_value_bits)
        {
@@ -878,7 +894,17 @@ __global__ void radix_kernel(T const* in,
                             int const pass)
 {
     const int64_t batch_id = blockIdx.y;
-    const IdxT row_len = rowEnds[batch_id] - rowStarts[batch_id];
+
+    IdxT rowStart = 0;
+    IdxT rowEnd = len;
+
+    if (rowStarts && rowEnds)
+    {
+        rowStart = rowStarts[batch_id];
+        rowEnd = rowEnds[batch_id];
+    }
+
+    const IdxT row_len = rowEnd - rowStart;
     auto counter = counters + batch_id;
     IdxT current_k;
@@ -948,7 +974,8 @@ __global__ void radix_kernel(T const* in,
                              histogram,
                              select_min,
                              pass,
-                             early_stop);
+                             early_stop,
+                             k);
     __threadfence();

     bool isLastBlock = false;
@@ -992,6 +1019,13 @@ __global__ void radix_kernel(T const* in,
             counter->previous_len = current_len;
             // not necessary for the last pass, but put it here anyway
             counter->filter_cnt = 0;
+
+            counter->finished_block_cnt = 0;
+            if(pass == num_passes - 2) // Before the last pass
+            {
+                counter->out_cnt = 0;
+                counter->out_back_cnt = 0;
+            }
         }

         if(pass == num_passes - 1)
@@ -1165,7 +1199,8 @@ __device__ bool filter_and_histogram_for_one_block(T const* in_buf,
                                                    Counter<T, IdxT>* counter,
                                                    IdxT* histogram,
                                                    bool select_min,
-                                                   int pass)
+                                                   int pass,
+                                                   IdxT k)
 {
     constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
     for(int i = threadIdx.x; i < num_buckets * 2; i += blockDim.x)
@@ -1284,11 +1319,14 @@ __device__ bool filter_and_histogram_for_one_block(T const* in_buf,
             else if(previous_bits < kth_value_bits)
             {
                 IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
-                if(WRITE_TOPK_VALUES)
+                if (pos < k)
                 {
-                    out[pos] = value;
+                    if(WRITE_TOPK_VALUES)
+                    {
+                        out[pos] = value;
+                    }
+                    out_idx[pos] = in_idx_buf[i];
                 }
-                out_idx[pos] = in_idx_buf[i];
             }
         }
     }
@@ -1312,11 +1350,14 @@ __device__ bool filter_and_histogram_for_one_block(T const* in_buf,
             else if(previous_bits < kth_value_bits)
             {
                 IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
-                if(WRITE_TOPK_VALUES)
+                if(pos < k)
                 {
-                    out[pos] = value;
+                    if(WRITE_TOPK_VALUES)
+                    {
+                        out[pos] = value;
+                    }
+                    out_idx[pos] = i;
                 }
-                out_idx[pos] = i;
             }
         }
     }
@@ -1347,9 +1388,18 @@ __global__ void radix_topk_one_block_kernel(T const* in,
     __shared__ IdxT histogram[num_buckets * 2];
     const int64_t batch_id = blockIdx.x;
-    const IdxT rowStart = rowStarts[batch_id];
-    const IdxT rowEnd = rowEnds[batch_id];
-    const IdxT row_len = rowEnd - rowStart;
+
+    IdxT rowStart = 0;
+    IdxT rowEnd = len;
+
+    if (rowStarts && rowEnds)
+    {
+        rowStart = rowStarts[batch_id];
+        rowEnd = rowEnds[batch_id];
+    }
+
+    const IdxT row_len = rowEnd - rowStart;
+
     if(threadIdx.x == 0)
     {
         counter.k = k;
@@ -1422,7 +1472,8 @@ __global__ void radix_topk_one_block_kernel(T const* in,
                                             &counter,
                                             histogram,
                                             select_min,
-                                            pass);
+                                            pass,
+                                            k); //@TODO CHECK UPDATE CODE
         __syncthreads();
         scan(histogram + use_one_pass * num_buckets);
@@ -2534,3 +2585,41 @@ void top_k_per_row_decode(const torch::Tensor& logits,
         }
     }
 }
+
+// Explicit template instantiations for use in topk_plain_kernels.cu
+namespace aiter {
+
+// Instantiate standalone_stable_radix_11bits
+template void standalone_stable_radix_11bits(
+    void* buf,
+    size_t& buf_size,
+    float const* in,
+    int batch_size,
+    int64_t len,
+    int* rowStarts,
+    int* rowEnds,
+    int k,
+    float* out,
+    int* out_idx,
+    bool greater,
+    hipStream_t stream);
+
+template void standalone_stable_radix_11bits(
+    void* buf,
+    size_t& buf_size,
+    float const* in,
+    int batch_size,
+    int64_t len,
+    int* rowStarts,
+    int* rowEnds,
+    int k,
+    float* out,
+    int* out_idx,
+    bool greater,
+    hipStream_t stream);
+
+} // namespace aiter
+
+// Instantiate workspace size calculation function (at global scope)
+template int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0);
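A note for reviewers on the intended semantics of the rowStarts/rowEnds path: the sketch below is a hypothetical torch reference, not part of this patch. The names, shapes, and the relative-index convention are assumptions inferred from the gather kernels added later in this patch (which compute `(start + idx) * stride1`), and it may be useful as a cross-check in op_tests.

    import torch

    def topk_per_row_ref(x, row_starts, row_ends, k, largest=True):
        # Each row r is ranked only over its window [row_starts[r], row_ends[r]);
        # returned indices are relative to the window start, mirroring the
        # `(start + idx) * stride1` addressing in gather_topk_values_strided_kernel.
        out_ids = []
        for r in range(x.shape[0]):
            window = x[r, int(row_starts[r]):int(row_ends[r])]
            out_ids.append(torch.topk(window, k, largest=largest).indices)
        return torch.stack(out_ids)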
diff --git a/csrc/kernels/topk_plain_kernels.cu b/csrc/kernels/topk_plain_kernels.cu
index 4bf732756c..7aa1cf6577 100644
--- a/csrc/kernels/topk_plain_kernels.cu
+++ b/csrc/kernels/topk_plain_kernels.cu
@@ -39,6 +39,7 @@
 #include
 #include
+#include "ck_tile/core.hpp"
 #include "dispatch_utils.h"
 #include "opus/opus.hpp"
 #include "py_itfs_common.h"
@@ -49,6 +50,103 @@
         utils::hip_check_((val), __FILE__, __LINE__); \
     }
+
+// Forward declaration of topk_per_row kernel from topk_per_row_kernels.cu
+namespace aiter {
+template
+__global__ void topk_per_row(const float* logits,
+                             const int* rowStarts,
+                             const int* rowEnds,
+                             int* outIndices,
+                             int stride0,
+                             int stride1,
+                             int rowOffset);
+
+// Forward declaration of standalone_stable_radix_11bits from topk_per_row_kernels.cu
+template
+void standalone_stable_radix_11bits(void* buf,
+                                    size_t& buf_size,
+                                    T const* in,
+                                    int batch_size,
+                                    int64_t len,
+                                    IdxT* rowStarts,
+                                    IdxT* rowEnds,
+                                    IdxT k,
+                                    T* out,
+                                    IdxT* out_idx,
+                                    bool greater,
+                                    hipStream_t stream);
+
+} // namespace aiter
+
+// Forward declaration of workspace size calculation function (at global scope)
+template <typename T>
+int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0);
+
+// Forward declaration of helper function to call topk_per_row kernel
+template <typename IdxT>
+void topk_per_row_kernel_launcher(const float* in,
+                                  const IdxT* rowStarts,
+                                  const IdxT* rowEnds,
+                                  IdxT* out_idx,
+                                  const float* out,
+                                  int batch_size,
+                                  int stride0,
+                                  int stride1,
+                                  int k,
+                                  hipStream_t stream);
+
+// Gather kernel to extract values based on indices (uniform length)
+template <typename T, typename IdxT>
+__global__ void gather_topk_values_kernel(const T* __restrict__ in,
+                                          const IdxT* __restrict__ indices,
+                                          T* __restrict__ out,
+                                          int batch_size,
+                                          int len,
+                                          int k)
+{
+    int batch_id = blockIdx.x;
+    if(batch_id >= batch_size) return;
+
+    const T* in_row = in + batch_id * len;
+    const IdxT* idx_row = indices + batch_id * k;
+    T* out_row = out + batch_id * k;
+
+    for(int i = threadIdx.x; i < k; i += blockDim.x) {
+        IdxT idx = idx_row[i];
+        if(idx >= 0 && idx < len) {
+            out_row[i] = in_row[idx];
+        }
+    }
+}
+
+// Gather kernel for variable length with strides
+template <typename T, typename IdxT>
+__global__ void gather_topk_values_strided_kernel(const T* __restrict__ in,
+                                                  const IdxT* __restrict__ indices,
+                                                  T* __restrict__ out,
+                                                  const IdxT* __restrict__ rowStarts,
+                                                  int batch_size,
+                                                  int stride0,
+                                                  int stride1,
+                                                  int k)
+{
+    int batch_id = blockIdx.x;
+    if(batch_id >= batch_size) return;
+
+    IdxT start = rowStarts[batch_id];
+    const T* in_row = in + batch_id * stride0;
+    const IdxT* idx_row = indices + batch_id * k;
+    T* out_row = out + batch_id * k;
+
+    for(int i = threadIdx.x; i < k; i += blockDim.x) {
+        IdxT idx = idx_row[i];
+        if(idx >= 0) {
+            // idx is relative to rowStart, need to add start and apply stride1
+            out_row[i] = in_row[(start + idx) * stride1];
+        }
+    }
+}
+
 namespace topk {

 namespace utils {
@@ -205,7 +303,7 @@ namespace numeric {
  * value for all other arithmetic types.
  */
 template <typename T>
-__inline__ constexpr T get_lower_bound()
+__inline__ __host__ __device__ constexpr T get_lower_bound()
 {
     static_assert(utils::is_supported_type_v<T>,
                   "Unsupported type T: only _Float16, __bf16, float, and int are implemented");
@@ -219,7 +317,9 @@ __inline__ constexpr T get_lower_bound()
     }
     else if constexpr(std::is_same_v<T, __bf16>)
     {
-        return -__bf16(0x7F80);
+        // Use bit pattern for -inf to avoid __truncsfbf2 calls in debug builds
+        constexpr uint16_t neg_inf_bits = 0xFF80; // -infinity for bfloat16
+        return __builtin_bit_cast(__bf16, neg_inf_bits);
     }
     else
     {
@@ -234,7 +334,7 @@ namespace numeric {
  * value for all other arithmetic types.
  */
 template <typename T>
-__inline__ constexpr T get_upper_bound()
+__inline__ __host__ __device__ constexpr T get_upper_bound()
 {
     static_assert(utils::is_supported_type_v<T>,
                   "Unsupported type T: only _Float16, __bf16, float, and int are implemented");
@@ -248,7 +348,9 @@ __inline__ constexpr T get_upper_bound()
     }
     else if constexpr(std::is_same_v<T, __bf16>)
     {
-        return __bf16(0x7F80);
+        // Use bit pattern for +inf to avoid __truncsfbf2 calls in debug builds
+        constexpr uint16_t pos_inf_bits = 0x7F80; // +infinity for bfloat16
+        return __builtin_bit_cast(__bf16, pos_inf_bits);
     }
     else
     {
@@ -265,7 +367,7 @@ namespace numeric {
  * @tparam T The numeric type.
  */
 template <typename T, bool FindLargest>
-__inline__ constexpr T get_sentinel_value()
+__inline__ __host__ __device__ constexpr T get_sentinel_value()
 {
     if constexpr(FindLargest)
     {
@@ -502,8 +604,11 @@ __forceinline__ __device__ constexpr T get_guard(const bool x)
     }
     else if constexpr(std::is_same_v<T, __bf16>)
     {
-        auto inf = __bf16(0x7F80);
-        return x ? -inf : inf;
+        // Use bit patterns to avoid __truncsfbf2 in debug builds
+        constexpr uint16_t pos_inf_bits = 0x7F80; // +infinity
+        constexpr uint16_t neg_inf_bits = 0xFF80; // -infinity
+        return x ? __builtin_bit_cast(__bf16, neg_inf_bits)
+                 : __builtin_bit_cast(__bf16, pos_inf_bits);
     }
     else if constexpr(!std::is_floating_point_v<T>)
     {
@@ -709,7 +814,7 @@ struct BitonicMerge<64, ascending, T, idxT>

 namespace buffer_load_helpers
 {
-constexpr int MAX_CAPACITY = 512;
+constexpr int MAX_CAPACITY = 2048;

 using int32x4_t = int __attribute__((ext_vector_type(4)));
 using floatx4_t = float __attribute__((ext_vector_type(4)));
@@ -868,7 +973,7 @@ struct WaveMergeHelper
 };

 // Forward declarations for kernel wrapper functions
-template
+template
 __global__ void __launch_bounds__(512, 2) topk_filter_kernel(const T* __restrict__ in,
                                                              const IdxT* __restrict__ in_idx,
                                                              int batch_size,
@@ -903,7 +1008,10 @@ template <typename T, typename IdxT> using KernelFuncPtr = void (*)(const T*, const IdxT*, int, IdxT, IdxT, T*, IdxT*, T);
 // Helper: Map block-level strategy class to its corresponding kernel function template
-template