diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index ad1a9b3309..52e76078bc 100755 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -1077,7 +1077,8 @@ "module_topk_plain": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/topk_plain_pybind.cu'", - "f'{AITER_CSRC_DIR}/kernels/topk_plain_kernels.cu'" + "f'{AITER_CSRC_DIR}/kernels/topk_plain_kernels.cu'", + "f'{AITER_CSRC_DIR}/kernels/topk_per_row_kernels.cu'" ], "flags_extra_cc": [], "flags_extra_hip": [], diff --git a/aiter/ops/topk_plain.py b/aiter/ops/topk_plain.py index dea2c654b7..cd768b01e9 100644 --- a/aiter/ops/topk_plain.py +++ b/aiter/ops/topk_plain.py @@ -13,7 +13,12 @@ def topk_plain( x: torch.Tensor, topk_ids: torch.Tensor, + topk_out: torch.Tensor, topk: int, - largest: bool, + largest: bool = True, + rowStarts: torch.Tensor = None, + rowEnds: torch.Tensor = None, + stride0: int = -1, + stride1: int = 1, ) -> None: pass diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index bc3631e2a2..f2b96e4483 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -907,7 +907,7 @@ template<> OPUS_D float min(const float&a, const float&b) { return template OPUS_D T med3(const T&a, const T&b, const T&c) { auto max_0 = max(a, b); auto min_0 = max(a, b); return max(max_0, max(min_0, c)); } template<> OPUS_D float med3(const float&a, const float&b, const float&c) { return __builtin_amdgcn_fmed3f(a, b, c); } -template<> OPUS_D __fp16 med3<__fp16>(const __fp16&a, const __fp16&b, const __fp16&c) { return __builtin_amdgcn_fmed3h(a, b, c); } +template<> OPUS_D _Float16 med3<_Float16>(const _Float16&a, const _Float16&b, const _Float16&c) { return __builtin_amdgcn_fmed3h(a, b, c); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// // buffer load/store related OPUS_D constexpr auto buffer_default_config() { diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 908865ae07..c8262eed72 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1635,10 +1635,15 @@ namespace py = pybind11; py::arg("final_output"), \ py::arg("final_lse") = std::nullopt); -#define TOPK_PLAIN_PYBIND \ - m.def("topk_plain", \ - &topk_plain, \ - py::arg("values"), \ - py::arg("topk_ids"), \ - py::arg("topk"), \ - py::arg("largest")); +#define TOPK_PLAIN_PYBIND \ + m.def("topk_plain", \ + &topk_plain, \ + py::arg("values"), \ + py::arg("topk_ids"), \ + py::arg("topk_out"), \ + py::arg("topk"), \ + py::arg("largest") = true, \ + py::arg("rowStarts") = torch::Tensor(), \ + py::arg("rowEnds") = torch::Tensor(), \ + py::arg("stride0") = -1, \ + py::arg("stride1") = 1); diff --git a/csrc/include/topk_plain.h b/csrc/include/topk_plain.h index 5a658e491d..087c157196 100644 --- a/csrc/include/topk_plain.h +++ b/csrc/include/topk_plain.h @@ -6,5 +6,10 @@ void topk_plain(torch::Tensor& values, torch::Tensor& topk_ids, - int topk_num, - bool largest); + torch::Tensor& topk_out, + int topk, + bool largest = true, + torch::Tensor rowStarts = torch::Tensor(), + torch::Tensor rowEnds = torch::Tensor(), + int64_t stride0 = -1, + int64_t stride1 = 1); diff --git a/csrc/kernels/topk_per_row_kernels.cu b/csrc/kernels/topk_per_row_kernels.cu index 14eae78163..89331c52df 100644 --- a/csrc/kernels/topk_per_row_kernels.cu +++ b/csrc/kernels/topk_per_row_kernels.cu @@ -420,7 +420,8 @@ __device__ void filter_and_histogram(T const* in_buf, IdxT* histogram, bool select_min, int pass, - bool early_stop) + bool early_stop, + 
IdxT k) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; @@ -893,9 +894,19 @@ __global__ void radix_kernel(T const* in, int const pass) { const int64_t batch_id = blockIdx.y; - const IdxT row_len = phase == Phase::Prefill - ? rowEnds[batch_id] - rowStarts[batch_id] - : rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1; + + IdxT row_len = len; + if(phase == Phase::Prefill) + { + if(rowStarts && rowEnds) + { + row_len = rowEnds[batch_id] - rowStarts[batch_id]; + } + } + else + { + row_len = rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1; + } auto counter = counters + batch_id; IdxT current_k; @@ -965,7 +976,8 @@ __global__ void radix_kernel(T const* in, histogram, select_min, pass, - early_stop); + early_stop, + k); __threadfence(); bool isLastBlock = false; @@ -1187,7 +1199,8 @@ __device__ bool filter_and_histogram_for_one_block(T const* in_buf, Counter* counter, IdxT* histogram, bool select_min, - int pass) + int pass, + IdxT k) { constexpr int num_buckets = calc_num_buckets(); for(int i = threadIdx.x; i < num_buckets * 2; i += blockDim.x) @@ -1371,11 +1384,25 @@ __global__ void radix_topk_one_block_kernel(T const* in, __shared__ IdxT histogram[num_buckets * 2]; const int64_t batch_id = blockIdx.x; - const IdxT rowStart = phase == Phase::Prefill ? rowStarts[batch_id] : 0; - const IdxT rowEnd = phase == Phase::Prefill - ? rowEnds[batch_id] - : rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1; - const IdxT row_len = rowEnd - rowStart; + + IdxT rowStart = 0; + IdxT rowEnd = len; + if(phase == Phase::Prefill) + { + if(rowStarts && rowEnds) + { + rowStart = rowStarts[batch_id]; + rowEnd = rowEnds[batch_id]; + } + } + else + { + rowEnd = rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1; + rowStart = 0; + } + + const IdxT row_len = rowEnd - rowStart; + if(threadIdx.x == 0) { counter.k = k; @@ -1448,7 +1475,8 @@ __global__ void radix_topk_one_block_kernel(T const* in, &counter, histogram, select_min, - pass); //@TODO CHECK UPDATE CODE + pass, + k); //@TODO CHECK UPDATE CODE __syncthreads(); scan(histogram + use_one_pass * num_buckets); @@ -1811,6 +1839,35 @@ void standalone_stable_radix_11bits(void* buf, } } +// Explicit template instantiation for standalone_stable_radix_11bits +template void standalone_stable_radix_11bits(void* buf, + size_t& buf_size, + float const* in, + int batch_size, + int64_t len, + int* rowStarts, + int* rowEnds, + int k, + float* out, + int* out_idx, + bool greater, + hipStream_t stream, + int next_n); + +template void standalone_stable_radix_11bits(void* buf, + size_t& buf_size, + float const* in, + int batch_size, + int64_t len, + int* rowStarts, + int* rowEnds, + int k, + float* out, + int* out_idx, + bool greater, + hipStream_t stream, + int next_n); + // AIR TopK end static inline __device__ uint32_t floatAsSortableUint(float x) @@ -2410,6 +2467,9 @@ int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0) return buf_size; } +// Explicit template instantiation to ensure the symbol is available for linking +template int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0); + void top_k_per_row_prefill(const torch::Tensor& logits, const torch::Tensor& rowStarts, const torch::Tensor& rowEnds, diff --git a/csrc/kernels/topk_plain_kernels.cu b/csrc/kernels/topk_plain_kernels.cu index 4bf732756c..7c03823ae0 100644 --- a/csrc/kernels/topk_plain_kernels.cu +++ b/csrc/kernels/topk_plain_kernels.cu @@ -49,10 +49,251 @@ 
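For readers unfamiliar with the linking pattern behind the two "Explicit template instantiation" blocks added above: the template definitions live only in topk_per_row_kernels.cu, so that file has to emit the concrete symbols that topk_plain_kernels.cu (which only sees forward declarations) links against. Below is a minimal two-translation-unit sketch of that idea; the function name workspace_bytes and the file labels are illustrative stand-ins, not the aiter sources.

// file 1 (plays the role of topk_per_row_kernels.cu): the template is DEFINED here.
template <typename IdxT>
long long workspace_bytes(int rows, int stride)
{
    return static_cast<long long>(rows) * stride * sizeof(IdxT);
}
// Explicit instantiation definition: forces the IdxT = int symbol to be emitted
// into this object file so that other objects can link against it.
template long long workspace_bytes<int>(int, int);

// file 2 (plays the role of topk_plain_kernels.cu): only a declaration is visible.
template <typename IdxT>
long long workspace_bytes(int rows, int stride);
// Explicit instantiation declaration: "do not instantiate here, the definition
// lives in another object" -- the same role the extern template declaration
// added below in topk_plain_kernels.cu plays for the workspace-size function.
extern template long long workspace_bytes<int>(int, int);

int main() { return workspace_bytes<int>(4, 128) > 0 ? 0 : 1; }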
utils::hip_check_((val), __FILE__, __LINE__); \ } +// Forward declaration of topk_per_row kernel from topk_per_row_kernels.cu +namespace aiter { + +// Phase enum for distinguishing prefill vs decode paths +enum class Phase +{ + Prefill, + Decode, +}; + +template +__global__ void topk_per_row(const float* logits, + const int* rowStarts, + const int* rowEnds, + int* outIndices, + int stride0, + int stride1, + int rowOffset); + +// Forward declaration of standalone_stable_radix_11bits from topk_per_row_kernels.cu +template +void standalone_stable_radix_11bits(void* buf, + size_t& buf_size, + T const* in, + int batch_size, + int64_t len, + IdxT* rowStarts, + IdxT* rowEnds, + IdxT k, + T* out, + IdxT* out_idx, + bool greater, + hipStream_t stream, + int next_n = 0); + +} // namespace aiter + +// Forward declaration of workspace size calculation function (at global scope) +template +int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0); +extern template int64_t +invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, + int32_t stride0); + +// Forward declaration of helper function to call topk_per_row kernel +template +void topk_per_row_kernel_launcher(const float* in, + const IdxT* rowStarts, + const IdxT* rowEnds, + IdxT* out_idx, + const float* out, + int batch_size, + int stride0, + int stride1, + int k, + hipStream_t stream); + +// Helper function to determine if topk_per_row kernel should be used +// Based on: n + K log²K ≥ 3 × Factor(n) × n +// where Factor(n) = 1/3 + 1.6/(log₂(n) - 9.5) +// Simplifies to: K log²K ≥ 4.8n/(log₂(n) - 9.5) +// TODO: We need to confirm whether, when n <= 2048, we might choose +// radix sort because the denominator becomes very small; does that +// still yield the best performance? +template +__forceinline__ __host__ bool should_use_topk_radix(IdxT len, IdxT k) +{ + const double n = static_cast(len); + const double K = static_cast(k); + + if(K <= 1.0) + { + return false; + } + + const double log_n = std::log2(n); + + const double denom = std::max(0.0001, log_n - 9.5); + + const double rhs = (4.8 * n) / denom; + + const double log_k = std::log2(K); + const double lhs = K * log_k * log_k; + + return lhs >= rhs; +} + +// Gather kernel to extract values based on indices (uniform length) +template +__global__ void gather_topk_values_kernel(const T* __restrict__ in, + const IdxT* __restrict__ indices, + T* __restrict__ out, + int batch_size, + int len, + int k) +{ + int batch_id = blockIdx.x; + if(batch_id >= batch_size) + return; + + const T* in_row = in + batch_id * len; + const IdxT* idx_row = indices + batch_id * k; + T* out_row = out + batch_id * k; + + for(int i = threadIdx.x; i < k; i += blockDim.x) + { + IdxT idx = idx_row[i]; + if(idx >= 0 && idx < len) + { + out_row[i] = in_row[idx]; + } + } +} + +// Gather kernel for variable length with strides +template +__global__ void gather_topk_values_strided_kernel(const T* __restrict__ in, + const IdxT* __restrict__ indices, + T* __restrict__ out, + const IdxT* __restrict__ rowStarts, + int batch_size, + int stride0, + int stride1, + int k) +{ + int batch_id = blockIdx.x; + if(batch_id >= batch_size) + return; + + IdxT start = rowStarts[batch_id]; + const T* in_row = in + batch_id * stride0; + const IdxT* idx_row = indices + batch_id * k; + T* out_row = out + batch_id * k; + + for(int i = threadIdx.x; i < k; i += blockDim.x) + { + IdxT idx = idx_row[i]; + if(idx >= 0) + { + // idx is relative to rowStart, need to add start and apply stride1 + out_row[i] = in_row[(start + idx) * stride1]; + 
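The crossover rule encoded in should_use_topk_radix above can be sanity-checked on the host without any GPU code. The following self-contained sketch mirrors the test as written in this file; the sample sizes are arbitrary and only chosen to show both sides of the threshold.

// Host-only sketch; replicates the K*log2(K)^2 >= 4.8*n/(log2(n) - 9.5) test above.
#include <algorithm>
#include <cmath>
#include <cstdio>

static bool use_radix(double n, double k)
{
    if(k <= 1.0)
        return false;
    const double denom = std::max(0.0001, std::log2(n) - 9.5);
    const double rhs   = 4.8 * n / denom;
    const double lhs   = k * std::log2(k) * std::log2(k);
    return lhs >= rhs;
}

int main()
{
    // n = 131072 (log2 n = 17): k = 64 gives lhs = 64 * 6^2 = 2304, well below
    // rhs ~= 83886, so the bitonic-sort path is kept; k = 2048 gives
    // lhs = 2048 * 11^2 = 247808, so the radix path wins.
    std::printf("n=131072 k=64   -> use radix? %d\n", use_radix(131072, 64));
    std::printf("n=131072 k=2048 -> use radix? %d\n", use_radix(131072, 2048));
    // Near n = 2048 the denominator is small and radix is chosen eagerly --
    // the situation flagged in the TODO above.
    std::printf("n=2048   k=256  -> use radix? %d\n", use_radix(2048, 256));
    return 0;
}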
} + } +} + namespace topk { + +// ============================================================================ +// TYPE TRAITS FOR DATA/COMPUTE TYPE SEPARATION +// ============================================================================ +// +// Design Philosophy: +// - DataType (DataT): The storage/I/O type for memory operations +// - ComputeType (ComputeT): The type used for internal computations +// +// Mapping: +// - fp16, bf16, float -> compute as float (better precision, consistent ops) +// - int -> compute as int +// +// This separation allows: +// 1. Memory-efficient storage with compact types (fp16, bf16) +// 2. High-precision computation with float +// 3. Easy extension for new types (e.g., fp8, int8) +// +// Usage: +// using ComputeT = compute_t; +// ComputeT val = type_convert::to_compute(data_val); +// DataT result = type_convert::to_data(compute_val); +// ============================================================================ + +namespace type_traits { + +// Primary template: maps DataType -> ComputeType +template +struct ComputeTypeTraits +{ + static_assert(sizeof(DataT) == 0, + "ComputeTypeTraits not specialized for this type. " + "Supported types: _Float16, __bf16, float, int"); +}; + +// Specializations for floating-point types -> float +template <> +struct ComputeTypeTraits<_Float16> +{ + using type = float; +}; + +template <> +struct ComputeTypeTraits<__bf16> +{ + using type = float; +}; + +template <> +struct ComputeTypeTraits +{ + using type = float; +}; + +// Specialization for integer types -> int +template <> +struct ComputeTypeTraits +{ + using type = int; +}; + +// Convenience alias +template +using compute_t = typename ComputeTypeTraits::type; + +} // namespace type_traits + +// Bring compute_t into topk namespace for convenience +using type_traits::compute_t; + +// ============================================================================ +// TYPE CONVERSION UTILITIES +// ============================================================================ + +namespace type_convert { + +// Convert from DataType to ComputeType +template +__device__ __host__ __forceinline__ type_traits::compute_t to_compute(DataT val) +{ + return static_cast>(val); +} + +// Convert from ComputeType to DataType +template +__device__ __host__ __forceinline__ DataT to_data(type_traits::compute_t val) +{ + return static_cast(val); +} + +} // namespace type_convert + namespace utils { -// Supported types +// Supported types (for validation) template struct is_supported_type { @@ -198,60 +439,62 @@ __inline__ __host__ __device__ constexpr int calc_capacity(int k) namespace numeric { +// ============================================================================ +// BOUNDS AND SENTINEL VALUES +// ============================================================================ +// These functions now work with ComputeType for internal operations. +// The sentinel values are defined in ComputeType space (float for floating-point +// DataTypes, int for integer DataTypes). +// ============================================================================ + /** - * @brief Gets the absolute lowest possible value for a numeric type T. + * @brief Gets the absolute lowest possible value for a compute type. + * + * Uses -infinity for floating-point compute types, and the lowest finite + * value for integer compute types. * - * Uses -infinity for signed floating-point types, and the lowest finite - * value for all other arithmetic types. + * @tparam ComputeT The compute type (float or int). 
*/ -template -__inline__ constexpr T get_lower_bound() +template +__inline__ __device__ __host__ constexpr ComputeT get_lower_bound() { - static_assert(utils::is_supported_type_v, - "Unsupported type T: only _Float16, __bf16, float, and int are implemented"); - if constexpr(std::is_floating_point_v && std::is_signed_v) - { - return -std::numeric_limits::infinity(); - } - else if constexpr(std::is_integral_v) + if constexpr(std::is_same_v) { - return std::numeric_limits::lowest(); + return -std::numeric_limits::infinity(); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { - return -__bf16(0x7F80); + return std::numeric_limits::lowest(); } else { + static_assert(sizeof(ComputeT) == 0, "Unsupported compute type"); __builtin_unreachable(); } } /** - * @brief Gets the absolute highest possible value for a numeric type T. + * @brief Gets the absolute highest possible value for a compute type. + * + * Uses +infinity for floating-point compute types, and the maximum finite + * value for integer compute types. * - * Uses +infinity for floating-point types, and the maximum finite - * value for all other arithmetic types. + * @tparam ComputeT The compute type (float or int). */ -template -__inline__ constexpr T get_upper_bound() +template +__inline__ __device__ __host__ constexpr ComputeT get_upper_bound() { - static_assert(utils::is_supported_type_v, - "Unsupported type T: only _Float16, __bf16, float, and int are implemented"); - if constexpr(std::is_floating_point_v) - { - return std::numeric_limits::infinity(); - } - else if constexpr(std::is_integral_v) + if constexpr(std::is_same_v) { - return std::numeric_limits::max(); + return std::numeric_limits::infinity(); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { - return __bf16(0x7F80); + return std::numeric_limits::max(); } else { + static_assert(sizeof(ComputeT) == 0, "Unsupported compute type"); __builtin_unreachable(); } } @@ -259,42 +502,56 @@ __inline__ constexpr T get_upper_bound() /** * @brief Gets a sentinel value for a search algorithm (e.g., Top-K). * - * @tparam FindLargest A compile-time boolean. If true, returns the lowest possible - * value (the starting point for finding a maximum). If false, returns the - * highest possible value (the starting point for finding a minimum). - * @tparam T The numeric type. + * The sentinel is defined in ComputeType space. For finding the largest values, + * we use the lowest possible value as sentinel (so any real value will be preferred). + * For finding the smallest values, we use the highest possible value. + * + * @tparam FindLargest If true, returns lowest value. If false, returns highest value. + * @tparam ComputeT The compute type (float or int). */ -template -__inline__ constexpr T get_sentinel_value() +template +__inline__ __device__ __host__ constexpr ComputeT get_sentinel_value() { if constexpr(FindLargest) { - static_assert( - !std::is_unsigned_v, - "Cannot determine a meaningful lower bound for finding the 'largest' unsigned value. " - "The lowest value is 0, which is a poor sentinel."); - return get_lower_bound(); + return get_lower_bound(); } else { - return get_upper_bound(); + return get_upper_bound(); } } /** - * @brief A generic comparison function for search algorithms. 💡 + * @brief Gets sentinel value based on DataType (converts to appropriate ComputeType). + * + * This is a convenience overload that deduces the ComputeType from DataType. + * + * @tparam FindLargest If true, returns lowest value. 
If false, returns highest value. + * @tparam DataT The data type (fp16, bf16, float, int). + */ +template +__inline__ __device__ __host__ constexpr compute_t get_sentinel_value_for_data() +{ + return get_sentinel_value>(); +} + +/** + * @brief A generic comparison function for search algorithms. * * Compares `val` against `baseline` according to the search direction * specified by the `FindLargest` template parameter. + * Works with ComputeType values. * * @tparam FindLargest If true, checks if `val` is greater than `baseline`. - * If false, checks if `val` is less than `baseline`. + * If false, checks if `val` is less than `baseline`. + * @tparam ComputeT The compute type (float or int). * @param val The new value to check. * @param baseline The current best value. * @return True if `val` is "preferred" over `baseline`. */ -template -__device__ __host__ constexpr bool is_preferred(T val, T baseline) +template +__device__ __host__ __forceinline__ constexpr bool is_preferred(ComputeT val, ComputeT baseline) { if constexpr(FindLargest) { @@ -310,6 +567,19 @@ __device__ __host__ constexpr bool is_preferred(T val, T baseline) namespace sorting { +// ============================================================================ +// SORTING OPERATIONS (Work with ComputeType) +// ============================================================================ +// All sorting operations in this namespace work with ComputeType values. +// The template parameter T should be the compute type (float or int). +// The idxT parameter is the index type (typically int32_t). +// +// The sorting algorithms use: +// - DPP (Data Parallel Primitives) for small-stride shuffles (≤8) +// - Wave intrinsics (__ballot, __popcll, __shfl) for larger operations +// - Bitonic sort/merge for efficient parallel sorting +// ============================================================================ + template struct BitonicMerge { @@ -492,26 +762,30 @@ __forceinline__ __device__ T shfl_xor(T val, int stride) } } -template -__forceinline__ __device__ constexpr T get_guard(const bool x) +/** + * @brief Gets guard value for bitonic sort comparisons. + * + * This function returns boundary values used in bitonic sorting. + * Works with ComputeType (float or int). + * + * @tparam ComputeT The compute type (float or int). + * @param x If true, returns lowest value; if false, returns highest value. + */ +template +__forceinline__ __device__ constexpr ComputeT get_guard(const bool x) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { - auto inf = _Float16(0x7C00); - return x ? -inf : inf; + return x ? -std::numeric_limits::infinity() : std::numeric_limits::infinity(); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { - auto inf = __bf16(0x7F80); - return x ? -inf : inf; - } - else if constexpr(!std::is_floating_point_v) - { - return x ? std::numeric_limits::lowest() : std::numeric_limits::max(); + return x ? std::numeric_limits::lowest() : std::numeric_limits::max(); } else { - return x ? 
-std::numeric_limits::infinity() : std::numeric_limits::infinity(); + static_assert(sizeof(ComputeT) == 0, "get_guard only supports float and int compute types"); + __builtin_unreachable(); } } @@ -709,14 +983,27 @@ struct BitonicMerge<64, ascending, T, idxT> namespace buffer_load_helpers { -constexpr int MAX_CAPACITY = 512; +constexpr int MAX_CAPACITY = 2048; using int32x4_t = int __attribute__((ext_vector_type(4))); using floatx4_t = float __attribute__((ext_vector_type(4))); -using bf16x8_t = uint16_t __attribute__((ext_vector_type(8))); +using bf16x8_t = __bf16 __attribute__((ext_vector_type(8))); using halfx8_t = _Float16 __attribute__((ext_vector_type(8))); using index_t = uint32_t; +__device__ __forceinline__ static int32x4_t +asm_buffer_load_dwordx4(int32x4_t srsrc, + int32_t voffset, + int32_t soffset, + int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +template +__device__ __forceinline__ VecType +buffer_load_dwordx4(int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) +{ + return __builtin_bit_cast(VecType, asm_buffer_load_dwordx4(srsrc, voffset, soffset, aux)); +} + } // namespace buffer_load_helpers // --- Wave-Level Priority Selection Primitives (AMD/HIP Optimized) --- @@ -766,21 +1053,39 @@ struct BlockTopkSort; template struct BlockTopkMerge; -// WaveBuffer: Manages per-wave register storage for priority candidates -template +// ============================================================================ +// WAVE BUFFER (Stores priorities in ComputeType) +// ============================================================================ +// +// WaveBuffer manages per-wave register storage for priority candidates. +// Key design: +// - DataT: The I/O type for loading/storing data +// - ComputeT: The internal type for priorities (float or int) +// - Priorities are stored as ComputeType for consistent computation +// - Conversion happens at I/O boundaries +// +// Template parameters: +// - capacity: Power-of-2 buffer capacity (>= wave size) +// - DataT: Data type for I/O (fp16, bf16, float, int) +// - IdxT: Index type (typically int32_t) +// ============================================================================ + +template struct WaveBuffer { + using ComputeT = compute_t; + static constexpr int slots_per_lane = capacity / opus::get_warp_size(); static_assert(capacity >= opus::get_warp_size() && utils::is_power_of_2(capacity), "Capacity must be power-of-2 and >= wave size"); - T priorities[slots_per_lane]; + ComputeT priorities[slots_per_lane]; IdxT positions[slots_per_lane]; int lane_id; IdxT target_count; - T sentinel; + ComputeT sentinel; - __device__ WaveBuffer(IdxT k, T sentinel_value) + __device__ WaveBuffer(IdxT k, ComputeT sentinel_value) : lane_id(threadIdx.x & (opus::get_warp_size() - 1)), target_count(k), sentinel(sentinel_value) @@ -792,13 +1097,16 @@ struct WaveBuffer } } - __device__ inline void reset_slot(int slot, T val = {}, IdxT pos = {}) + __device__ inline void reset_slot(int slot, ComputeT val = {}, IdxT pos = {}) { priorities[slot] = val; positions[slot] = pos; } - __device__ inline void flush_results(T* __restrict__ out_vals, + // Flush results to output buffer + // OutT can be DataT (for final output) or ComputeT (for LDS operations) + template + __device__ inline void flush_results(OutT* __restrict__ out_vals, IdxT* __restrict__ out_indices) const { #pragma unroll @@ -807,7 +1115,7 @@ struct WaveBuffer const IdxT global_slot = i * opus::get_warp_size() + lane_id; if(global_slot < target_count) { - out_vals[global_slot] = 
priorities[i]; + out_vals[global_slot] = static_cast(priorities[i]); out_indices[global_slot] = positions[i]; } } @@ -815,10 +1123,14 @@ struct WaveBuffer }; // Helper for merging sorted sequences (used by multiple strategies) -template +// Works with ComputeType internally, reads from ComputeType buffers +template struct WaveMergeHelper { + using ComputeT = compute_t; + // Merges a sorted k-element chunk with the buffer's existing Top-K + // Input is in ComputeType (from LDS or previous computation) // EXAMPLE (finding Top-4 largest, capacity=64, k=4): // Wave-distributed storage (64 lanes, each lane holds slots_per_lane=1 value): // Lanes 0-3: [80, 85, 90, 95] (current top-4, in ascending order) @@ -843,8 +1155,8 @@ struct WaveMergeHelper // // Extract top-k=4 (last 4 in ascending order): // Lanes 60-63 now contain: [85, 90, 95, 100] - __device__ static void merge_sorted_range(WaveBuffer& buffer, - const T* __restrict__ in, + __device__ static void merge_sorted_range(WaveBuffer& buffer, + const ComputeT* __restrict__ in, const IdxT* __restrict__ in_idx, IdxT start) { @@ -854,56 +1166,64 @@ struct WaveMergeHelper { if(idx < start + buffer.target_count) { - T candidate = in[idx]; - if(numeric::is_preferred(candidate, buffer.priorities[i])) + ComputeT candidate = in[idx]; + if(numeric::is_preferred(candidate, buffer.priorities[i])) { buffer.priorities[i] = candidate; buffer.positions[i] = in_idx[idx]; } } } - sorting::BitonicMerge::merge(buffer.priorities, - buffer.positions); + sorting::BitonicMerge::merge(buffer.priorities, + buffer.positions); } }; // Forward declarations for kernel wrapper functions -template -__global__ void __launch_bounds__(512, 2) topk_filter_kernel(const T* __restrict__ in, +// Note: Kernels use DataT for I/O and compute_t for sentinel/internal computation +template +__global__ void __launch_bounds__(512, 2) topk_filter_kernel(const DataT* __restrict__ in, const IdxT* __restrict__ in_idx, int batch_size, IdxT len, IdxT k, - T* __restrict__ out, + DataT* __restrict__ out, IdxT* __restrict__ out_idx, - T sentinel); + compute_t sentinel); -template -__global__ void __launch_bounds__(512, 2) topk_sort_kernel(const T* __restrict__ in, +template +__global__ void __launch_bounds__(512, 2) topk_sort_kernel(const DataT* __restrict__ in, const IdxT* __restrict__ in_idx, int batch_size, IdxT len, IdxT k, - T* __restrict__ out, + DataT* __restrict__ out, IdxT* __restrict__ out_idx, - T sentinel); + compute_t sentinel); -template -__global__ void __launch_bounds__(512, 2) topk_merge_kernel(const T* __restrict__ in, +template +__global__ void __launch_bounds__(512, 2) topk_merge_kernel(const DataT* __restrict__ in, const IdxT* __restrict__ in_idx, int batch_size, IdxT len, IdxT k, - T* __restrict__ out, + DataT* __restrict__ out, IdxT* __restrict__ out_idx, - T sentinel); + compute_t sentinel); -// Kernel function pointer type alias -template -using KernelFuncPtr = void (*)(const T*, const IdxT*, int, IdxT, IdxT, T*, IdxT*, T); +template +using KernelFuncPtr = + void (*)(const DataT*, const IdxT*, int, IdxT, IdxT, DataT*, IdxT*, compute_t); // Helper: Map block-level strategy class to its corresponding kernel function template -template
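The keep-preferred-then-bitonic-merge step documented in WaveMergeHelper can be reasoned about with a plain scalar analogue. The sketch below is not the device code: it collapses the wave-distributed, 64-lane version into two k-element arrays, uses std::sort in place of BitonicMerge, and the incoming chunk values are made up so that the result matches the [85, 90, 95, 100] end state quoted in the example comment above.

// Scalar mock of the Top-K merge rule (find-largest case), for intuition only.
#include <algorithm>
#include <array>
#include <cstdio>

int main()
{
    // Current Top-4 held by the buffer, ascending (the example's [80, 85, 90, 95]).
    std::array<float, 4> kept = {80.f, 85.f, 90.f, 95.f};
    // One hypothetical incoming sorted chunk, paired against the buffer in
    // reverse order (largest element first).
    std::array<float, 4> incoming = {100.f, 84.f, 82.f, 81.f};

    // Pairwise is_preferred pass: each slot keeps the better of its current
    // value and the chunk element it is paired with. Pairing an ascending run
    // against a descending run keeps the overall Top-4 inside 'kept'.
    for(int i = 0; i < 4; ++i)
        kept[i] = std::max(kept[i], incoming[i]);

    // BitonicMerge restores sorted order on the device; std::sort stands in here.
    std::sort(kept.begin(), kept.end());

    // Prints: 85 90 95 100
    for(float v : kept)
        std::printf("%g ", v);
    std::printf("\n");
    return 0;
}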