diff --git a/aiter/ops/topk.py b/aiter/ops/topk.py
index b29f8cc27a..1c3666f832 100755
--- a/aiter/ops/topk.py
+++ b/aiter/ops/topk.py
@@ -3,13 +3,13 @@
 # user interface
 
-from typing import Tuple
+from typing import Optional, Tuple
+
 import torch
-from ..jit.core import (
-    compile_ops,
-)
-from ..utility import dtypes
+
+from ..jit.core import compile_ops
 from ..jit.utils.chip_info import get_cu_num
+from ..utility import dtypes
 
 
 @compile_ops("module_moe_asm", fc_name="biased_grouped_topk")
@@ -202,6 +202,7 @@ def top_k_per_row_prefill(
    rowStarts: torch.Tensor,
    rowEnds: torch.Tensor,
    indices: torch.Tensor,
+   values: Optional[torch.Tensor],
    numRows: int,
    stride0: int,
    stride1: int,
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index 6aca386744..b0ed0f3d2e 100644
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -1337,6 +1337,7 @@ namespace py = pybind11;
           py::arg("rowStarts"),                                  \
           py::arg("rowEnds"),                                    \
           py::arg("indices"),                                    \
+          py::arg("values"),                                     \
           py::arg("numRows"),                                    \
           py::arg("stride0"),                                    \
           py::arg("stride1"));                                   \
diff --git a/csrc/include/topk_per_row.h b/csrc/include/topk_per_row.h
index b2894fe54a..e3bae1887d 100644
--- a/csrc/include/topk_per_row.h
+++ b/csrc/include/topk_per_row.h
@@ -6,6 +6,7 @@ void top_k_per_row_prefill(const torch::Tensor& logits,
                            const torch::Tensor& rowStarts,
                            const torch::Tensor& rowEnds,
                            torch::Tensor& indices,
+                           std::optional<torch::Tensor> values,
                            int64_t numRows,
                            int64_t stride0,
                            int64_t stride1);
diff --git a/csrc/kernels/topk_per_row_kernels.cu b/csrc/kernels/topk_per_row_kernels.cu
old mode 100755
new mode 100644
index af59458122..ccef79d35d
--- a/csrc/kernels/topk_per_row_kernels.cu
+++ b/csrc/kernels/topk_per_row_kernels.cu
@@ -4,6 +4,7 @@
 #include
 #include
+#include "aiter_hip_common.h"
 #include "dispatch_utils.h"
 #include
 #include
@@ -25,6 +26,7 @@ static inline __device__ uint16_t extractBinIdx(float x)
 using fp32x1 = __attribute__((__ext_vector_type__(1))) float;
 using fp32x2 = __attribute__((__ext_vector_type__(2))) float;
 using fp32x4 = __attribute__((__ext_vector_type__(4))) float;
+using fp32x8 = __attribute__((__ext_vector_type__(8))) float;
 
 template <int N>
 struct to_vector;
@@ -46,6 +48,1562 @@ struct to_vector<4>
 {
     using type = fp32x4;
 };
 
+template <>
+struct to_vector<8>
+{
+    using type = fp32x8;
+};
+
+// AIR TopK start
+
+using WideT = fp32x4;
+constexpr int VECTORIZED_READ_SIZE = 16;
+constexpr int WARP_SIZE = 64;
+
+template <typename IdxT>
+struct ComputeOffset
+{
+    __host__ __device__ explicit ComputeOffset(IdxT const& cols) : cols_(cols) {}
+
+    __host__ __device__ IdxT operator()(IdxT const& x) const { return cols_ * x; }
+
+    IdxT cols_;
+};
+
+template <int BitsPerPass>
+__host__ __device__ constexpr int calc_num_buckets()
+{
+    return 1 << BitsPerPass;
+}
+
+/**
+ * @brief Provide a ceiling division operation ie. ceil(a / b)
+ * @tparam IntType supposed to be only integers for now!
+ */
+template <typename IntType>
+constexpr __host__ __device__ IntType ceildiv(IntType a, IntType b)
+{
+    return (a + b - 1) / b;
+}
+
+/**
+ * @brief Provide an alignment function ie. ceil(a / b) * b
+ * @tparam IntType supposed to be only integers for now!
+ */ +template +constexpr __host__ __device__ IntType alignTo(IntType a, IntType b) +{ + return ceildiv(a, b) * b; +} + +template +__host__ __device__ constexpr int calc_num_passes() +{ + return ceildiv(sizeof(T) * 8, BitsPerPass); +} + +__host__ __device__ int round(int num, int round_value) +{ + return ((num - 1) / round_value + 1) * round_value; +} + +template +__device__ constexpr int calc_start_bit(int pass) +{ + int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass; + int r = start_bit < 0 ? 0 : start_bit; + return r; +} + +template +__device__ constexpr unsigned calc_mask(int pass) +{ + static_assert(BitsPerPass <= 31); + int num_bits = calc_start_bit(pass - 1) - calc_start_bit(pass); + return (1 << num_bits) - 1; +} + +template +__device__ typename hipcub::Traits::UnsignedBits twiddle_in(T key, bool select_min) +{ + auto bits = reinterpret_cast::UnsignedBits&>(key); + if constexpr (std::is_same_v){ + // TODO: hardcoded for select_min is false! + uint32_t mask = (key < 0) ? 0 : 0x7fffffff; + return bits ^ mask; + } + else { + bits = hipcub::Traits::TwiddleIn(bits); + if(!select_min) + { + bits = ~bits; + } + return bits; + } +} + +template +__device__ T twiddle_out(typename hipcub::Traits::UnsignedBits bits, bool select_min) +{ + if(!select_min) + { + bits = ~bits; + } + bits = hipcub::Traits::TwiddleOut(bits); + return reinterpret_cast(bits); +} + +template +__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) +{ + static_assert(BitsPerPass <= sizeof(int) * 8 - 1, + "BitsPerPass is too large that the result type could not be int"); + return (twiddle_in(x, select_min) >> start_bit) & mask; +} + +template +constexpr inline std::enable_if_t::value, bool> +is_a_power_of_two(I val) noexcept +{ + return ((val - 1) & val) == 0; +} + +template +__host__ __device__ IdxT calc_buf_len(IdxT len) +{ + // When writing is skipped, only read `in`(type T). + // When writing is not skipped, read `in_buf`(T) and `in_idx_buf`(IdxT), and + // write `out_buf`(T) and `out_idx_buf`(IdxT). The ratio between these cases + // determines whether to skip writing and hence the buffer size. + constexpr RATIO_T ratio = 2 + sizeof(IdxT) * 2 / sizeof(T); + // Even such estimation is too conservative, so further decrease buf_len by + // 1/8 + IdxT buf_len = len / (ratio * 8); + + // one-block kernel splits one large buffer into smaller ones, so round buf + // size to 256 bytes to avoid alignment issues + static_assert(is_a_power_of_two(sizeof(T))); + static_assert(is_a_power_of_two(sizeof(IdxT))); + constexpr IdxT aligned = 256 / std::min(sizeof(T), sizeof(IdxT)); + buf_len = buf_len & (~(aligned - 1)); + return buf_len; +} + +/** + * Map a Func over the input data, using vectorized load instructions if + * possible. 
+ * + * NB: in future, we should move this to + * cpp/include/raft/linalg/detail/unary_op.cuh, which currently does not support + * the second lambda argument (index of an element) + * + * @tparam T element type + * @tparam IdxT indexing type + * @tparam Func void (T x, IdxT idx) + * + * @param thread_rank rank of the calling thread among all participating threads + * @param num_threads number of the threads that participate in processing + * @param in the input data + * @param len the number of elements to read + * @param f the lambda taking two arguments (T x, IdxT idx) + */ +template +__device__ void +vectorized_process(size_t thread_rank, size_t num_threads, T const* in, IdxT len, Func f) +{ + if constexpr(sizeof(T) >= sizeof(WideT)) + { + for(IdxT i = thread_rank; i < len; i += num_threads) + { + f(in[i], i); + } + } + else + { + static_assert(sizeof(WideT) % sizeof(T) == 0); + constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); + + // TODO: it's UB + union + { + WideT scalar; + T array[items_per_scalar]; + } wide; + + int skip_cnt = + (reinterpret_cast(in) % sizeof(WideT)) + ? ((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) + : 0; + if(skip_cnt > len) + { + skip_cnt = len; + } + WideT const* in_cast = reinterpret_cast(in + skip_cnt); + const IdxT len_cast = (len - skip_cnt) / items_per_scalar; + + for(IdxT i = thread_rank; i < len_cast; i += num_threads) + { + wide.scalar = in_cast[i]; + const IdxT real_i = skip_cnt + i * items_per_scalar; +#pragma unroll + for(int j = 0; j < items_per_scalar; ++j) + { + f(wide.array[j], real_i + j); + } + } + + static_assert(WARP_SIZE >= items_per_scalar); + // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt + // no need to use loop + if(thread_rank < skip_cnt) + { + f(in[thread_rank], thread_rank); + } + // because len_cast = (len - skip_cnt) / items_per_scalar, + // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt; + // and so + // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= + // WARP_SIZE no need to use loop + const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank; + if(remain_i < len) + { + f(in[remain_i], remain_i); + } + } +} + +// sync_width should >= WARP_SIZE +template +__device__ void vectorized_process(T const* in, IdxT len, Func f, int sync_width) +{ + const IdxT stride = blockDim.x * gridDim.x; + const IdxT tid = blockIdx.x * blockDim.x + threadIdx.x; + if constexpr(sizeof(T) >= sizeof(WideT)) + { + for(IdxT i = tid; i < len; i += stride) + { + f(in[i], i, true); + } + } + else + { + static_assert(sizeof(WideT) % sizeof(T) == 0); + constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); + + union + { + WideT scalar; + T array[items_per_scalar]; + } wide; + + int skip_cnt = + (reinterpret_cast(in) % sizeof(WideT)) + ? 
((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) + : 0; + if(skip_cnt > len) + { + skip_cnt = len; + } + WideT const* in_cast = reinterpret_cast(in + skip_cnt); + const IdxT len_cast = (len - skip_cnt) / items_per_scalar; + + const IdxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; + for(IdxT i = tid; i < len_cast_for_sync; i += stride) + { + bool valid = i < len_cast; + if(valid) + { + wide.scalar = in_cast[i]; + } + const IdxT real_i = skip_cnt + i * items_per_scalar; +#pragma unroll + for(int j = 0; j < items_per_scalar; ++j) + { + f(wide.array[j], real_i + j, valid); + } + } + + static_assert(WARP_SIZE >= items_per_scalar); + // need at most one warp for skipped and remained elements, + // and sync_width >= WARP_SIZE + if(tid < sync_width) + { + bool valid = tid < skip_cnt; + T value = valid ? in[tid] : T(); + f(value, tid, valid); + + const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; + valid = remain_i < len; + value = valid ? in[remain_i] : T(); + f(value, remain_i, valid); + } + } +} + +template +struct alignas(128) Counter +{ + // We are processing the values in multiple passes, from most significant to + // least significant. In each pass, we keep the length of input (`len`) and + // the `k` of current pass, and update them at the end of the pass. + IdxT k; + IdxT len; + + // `previous_len` is the length of input in previous pass. Note that + // `previous_len` rather than `len` is used for the filtering step because + // filtering is indeed for previous pass (see comments before + // `radix_kernel`). + IdxT previous_len; + + // We determine the bits of the k_th value inside the mask processed by the + // pass. The already known bits are stored in `kth_value_bits`. It's used to + // discriminate a element is a result (written to `out`), a candidate for next + // pass (written to `out_buf`), or not useful (discarded). The bits that are + // not yet processed do not matter for this purpose. + typename hipcub::Traits::UnsignedBits kth_value_bits; + + // Record how many elements have passed filtering. It's used to determine the + // position in the `out_buf` where an element should be written. + alignas(128) IdxT filter_cnt; + + // For a row inside a batch, we may launch multiple thread blocks. This + // counter is used to determine if the current block is the last running + // block. If so, this block will execute scan() and choose_bucket(). + alignas(128) unsigned int finished_block_cnt; + + // Record how many elements have been written to the front of `out`. Elements + // less (if select_min==true) than the k-th value are written from front to + // back. + alignas(128) IdxT out_cnt; + + // Record how many elements have been written to the back of `out`. Elements + // equal to the k-th value are written from back to front. We need to keep + // count of them separately because the number of elements that <= the k-th + // value might exceed k. + alignas(128) IdxT out_back_cnt; +}; + +/** + * Fused filtering of the current pass and building histogram for the next pass + * (see steps 4 & 1 in `radix_kernel` description). 
+ */ +template +__device__ void filter_and_histogram(T const* in_buf, + IdxT const* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + IdxT previous_len, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass, + bool early_stop) +{ + constexpr int num_buckets = calc_num_buckets(); + __shared__ IdxT histogram_smem[num_buckets]; + for(IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) + { + histogram_smem[i] = 0; + } + __syncthreads(); + + int const start_bit = calc_start_bit(pass); + unsigned const mask = calc_mask(pass); + + if(pass == 0) + { + // Passed to vectorized_process, this function executes in all blocks in + // parallel, i.e. the work is split along the input (both, in batches and + // chunks of a single row). Later, the histograms are merged using + // atomicAdd. + auto f = [select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); + }; + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); + } + else + { + IdxT* p_filter_cnt = &counter->filter_cnt; + IdxT* p_out_cnt = &counter->out_cnt; + auto const kth_value_bits = counter->kth_value_bits; + int const previous_start_bit = calc_start_bit(pass - 1); + + // See the remark above on the distributed execution of `f` using + // vectorized_process. + auto f = [in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + select_min, + start_bit, + mask, + previous_start_bit, + kth_value_bits, + p_filter_cnt, + p_out_cnt, + early_stop](T value, IdxT i) { + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if(previous_bits == kth_value_bits) + { + if(early_stop) + { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if(WRITE_TOPK_VALUES) + { + out[pos] = value; + } + + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + else + { + if(out_buf) + { + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); + } + } + // the condition `(out_buf || early_stop)` is a little tricky: + // If we skip writing to `out_buf` (when `out_buf` is nullptr), we should + // skip writing to `out` too. So we won't write the same value to `out` + // multiple times in different passes. And if we keep skipping the + // writing, values will be written in `last_filter_kernel()` at last. But + // when `early_stop` is true, we need to write to `out` since it's the + // last chance. + else if((out_buf || early_stop) && previous_bits < kth_value_bits) + { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if(WRITE_TOPK_VALUES) + { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; + } + }; + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); + } + if(early_stop) + { + return; + } + __syncthreads(); + + // merge histograms produced by individual blocks + for(int i = threadIdx.x; i < num_buckets; i += blockDim.x) + { + // if(histogram_smem[i] != 0) + // { + // atomicAdd(histogram + i, histogram_smem[i]); + // } + *(histogram + i) = histogram_smem[i]; + } +} + +/** + * Replace histogram with its own prefix sum + * (step 2 in `radix_kernel` description) + */ +template +__device__ void scan(IdxT volatile* histogram) +{ + constexpr int num_buckets = calc_num_buckets(); + if constexpr(num_buckets >= BlockSize) + { + static_assert(num_buckets % BlockSize == 0); + constexpr int items_per_thread = num_buckets / BlockSize; + typedef hipcub::BlockLoad + BlockLoad; + typedef hipcub::BlockStore + BlockStore; + typedef hipcub::BlockScan BlockScan; + + __shared__ union + { + typename BlockLoad::TempStorage load; + typename BlockScan::TempStorage scan; + typename BlockStore::TempStorage store; + } temp_storage; + + IdxT thread_data[items_per_thread]; + + BlockLoad(temp_storage.load).Load(histogram, thread_data); + __syncthreads(); + + BlockScan(temp_storage.scan).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + BlockStore(temp_storage.store).Store(histogram, thread_data); + } + else + { + typedef hipcub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + IdxT thread_data = 0; + if(threadIdx.x < num_buckets) + { + thread_data = histogram[threadIdx.x]; + } + + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + if(threadIdx.x < num_buckets) + { + histogram[threadIdx.x] = thread_data; + } + } +} + +/** + * Calculate in which bucket the k-th value will fall + * (steps 3 in `radix_kernel` description) + */ +template +__device__ void +choose_bucket(Counter* counter, IdxT const* histogram, const IdxT k, int const pass) +{ + constexpr int num_buckets = calc_num_buckets(); + for(int i = threadIdx.x; i < num_buckets; i += blockDim.x) + { + IdxT prev = (i == 0) ? 0 : histogram[i - 1]; + IdxT cur = histogram[i]; + + // one and only one thread will satisfy this condition, so counter is + // written by only one thread + if(prev < k && cur >= k) + { + counter->k = k - prev; // how many values still are there to find + counter->len = cur - prev; // number of values in next pass + typename hipcub::Traits::UnsignedBits bucket = i; + int start_bit = calc_start_bit(pass); + counter->kth_value_bits |= bucket << start_bit; + } + } +} + +// For one-block version, last_filter() could be called when pass < num_passes +// - 1. 
So `pass` could not be constexpr +template +__device__ void last_filter(T const* in_buf, + IdxT const* in_idx_buf, + T* out, + IdxT* out_idx, + IdxT current_len, + IdxT k, + Counter* counter, + bool const select_min, + int const pass) +{ + auto const kth_value_bits = counter->kth_value_bits; + int const start_bit = calc_start_bit(pass); + + // changed in choose_bucket(); need to reload + const IdxT num_of_kth_needed = counter->k; + IdxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_back_cnt = &counter->out_back_cnt; + IdxT* p_equal = out_idx + k - num_of_kth_needed; + for(IdxT i = threadIdx.x; i < current_len; i += blockDim.x) + { + const T value = in_buf[i]; + auto const bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if(bits < kth_value_bits) + { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if(WRITE_TOPK_VALUES) + { + out[pos] = value; + } + // For one-block version, `in_idx_buf` could be nullptr at pass 0. + // For non one-block version, if writing has been skipped, `in_idx_buf` + // could be nullptr if `in_buf` is `in` + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + else if(bits == kth_value_bits) + { + IdxT new_idx = in_idx_buf ? in_idx_buf[i] : i; + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if(back_pos < num_of_kth_needed) + { + IdxT pos = k - 1 - back_pos; + if(WRITE_TOPK_VALUES) + { + out[pos] = value; + } + if constexpr(!prioritize_smaller_indice) + { + out_idx[pos] = new_idx; + } + } + } + } +} + +template +__global__ void last_filter_kernel(T const* in, + IdxT const* in_idx, + T const* in_buf, + IdxT const* in_idx_buf, + T* out, + IdxT* out_idx, + IdxT len, + IdxT k, + Counter* counters, + bool const select_min) +{ + const int64_t batch_id = blockIdx.y; // size_t to avoid multiplication overflow + + Counter* counter = counters + batch_id; + IdxT previous_len = counter->previous_len; + if(previous_len == 0) + { + return; + } + const IdxT buf_len = calc_buf_len(len); + if(previous_len > buf_len || in_buf == in) + { + in_buf = in + batch_id * len; + in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; + previous_len = len; + } + else + { + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; + } + out += batch_id * k; + out_idx += batch_id * k; + + constexpr int pass = calc_num_passes() - 1; + constexpr int start_bit = calc_start_bit(pass); + + auto const kth_value_bits = counter->kth_value_bits; + const IdxT num_of_kth_needed = counter->k; + IdxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_back_cnt = &counter->out_back_cnt; + IdxT* p_equal = out_idx + k - num_of_kth_needed; + auto f = [k, + select_min, + kth_value_bits, + num_of_kth_needed, + p_out_cnt, + p_out_back_cnt, + in_idx_buf, + out, + out_idx, + p_equal](T value, IdxT i) { + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if(bits < kth_value_bits) + { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if(WRITE_TOPK_VALUES) + { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + else if(bits == kth_value_bits) + { + IdxT new_idx = in_idx_buf ? 
in_idx_buf[i] : i; + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if(back_pos < num_of_kth_needed) + { + IdxT pos = k - 1 - back_pos; + if(WRITE_TOPK_VALUES) + { + out[pos] = value; + } + if constexpr(!prioritize_smaller_indice) + { + out_idx[pos] = new_idx; + } + } + } + }; + + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); +} + +/** + * + * It is expected to call this kernel multiple times (passes), in each pass we + * process a radix, going from the most significant towards the least + * significant bits (MSD). + * + * Conceptually, each pass consists of 4 steps: + * + * 1. Calculate histogram + * First, transform bits into a digit, the value of which is in the range + * [0, 2^{BITS_PER_PASS}-1]. Then count the frequency of each digit value + * and the result is a histogram. That is, histogram[i] contains the count of + * inputs having value i. + * + * 2. Scan the histogram + * Inclusive prefix sum is computed for the histogram. After this step, + * histogram[i] contains the count of inputs having value <= i. + * + * 3. Find the bucket j of the histogram that the k-th value falls into + * + * 4. Filtering + * Input elements whose digit value +__global__ void radix_kernel(T const* in, + IdxT const* in_idx, + T const* in_buf, + IdxT const* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + Counter* counters, + IdxT* histograms, + const IdxT len, + const IdxT* rowStarts, + const IdxT* rowEnds, + const IdxT k, + bool const select_min, + int const pass) +{ + const int64_t batch_id = blockIdx.y; + const IdxT row_len = rowEnds[batch_id] - rowStarts[batch_id]; + + auto counter = counters + batch_id; + IdxT current_k; + IdxT previous_len; + IdxT current_len; + if(pass == 0) + { + current_k = k; + previous_len = row_len; + current_len = row_len; + } + else + { + current_k = counter->k; + current_len = counter->len; + previous_len = counter->previous_len; + } + if(current_len == 0) + { + return; + } + + // When k=len, early_stop will be true at pass 0. It means + // filter_and_histogram() should handle correctly the case that pass=0 and + // early_stop=true. However, this special case of k=len is handled in other + // way in select_k() so such case is not possible here. + bool const early_stop = (current_len == current_k); + const IdxT buf_len = calc_buf_len(row_len); + + // "previous_len > buf_len" means previous pass skips writing buffer + if(pass == 0 || pass == 1 || previous_len > buf_len) + { + in_buf = in + batch_id * len; + in_idx_buf = in_idx ? 
(in_idx + batch_id * len) : nullptr; + previous_len = row_len; + } + else + { + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; + } + // "current_len > buf_len" means current pass will skip writing buffer + if(pass == 0 || current_len > buf_len) + { + out_buf = nullptr; + out_idx_buf = nullptr; + } + else + { + out_buf += batch_id * buf_len; + out_idx_buf += batch_id * buf_len; + } + out += batch_id * k; + out_idx += batch_id * k; + + constexpr int num_buckets = calc_num_buckets(); + auto histogram = histograms + batch_id * num_buckets; + + filter_and_histogram(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + previous_len, + counter, + histogram, + select_min, + pass, + early_stop); + __threadfence(); + + bool isLastBlock = false; + if(threadIdx.x == 0) + { + unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); + isLastBlock = (finished == (gridDim.x - 1)); + } + + if(__syncthreads_or(isLastBlock)) + { + if(early_stop) + { + if(threadIdx.x == 0) + { + // `last_filter_kernel()` requires setting previous_len + counter->previous_len = 0; + counter->len = 0; + } + return; + } + + scan(histogram); + __syncthreads(); + choose_bucket(counter, histogram, current_k, pass); + __syncthreads(); + + constexpr int num_passes = calc_num_passes(); + // reset for next pass + // if(pass != num_passes - 1) + // { + // for(int i = threadIdx.x; i < num_buckets; i += blockDim.x) + // { + // histogram[i] = 0; + // } + // } + if(threadIdx.x == 0) + { + // `last_filter_kernel()` requires setting previous_len even in the last + // pass + counter->previous_len = current_len; + // not necessary for the last pass, but put it here anyway + counter->filter_cnt = 0; + } + + if(pass == num_passes - 1) + { + const volatile IdxT num_of_kth_needed = counter->k; + for(IdxT i = threadIdx.x; i < num_of_kth_needed; i += blockDim.x) + { + out_idx[k - num_of_kth_needed + i] = std::numeric_limits::max(); + } + __syncthreads(); + if constexpr(fused_last_filter) + { + last_filter( + out_buf ? out_buf : in_buf, + out_idx_buf ? out_idx_buf : in_idx_buf, + out, + out_idx, + out_buf ? current_len : row_len, + k, + counter, + select_min, + pass); + } + } + } +} + +template +unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) +{ + static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); + + int active_blocks; + HIP_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &active_blocks, + radix_kernel, + BlockSize, + 0)); + active_blocks *= sm_cnt; + + IdxT best_num_blocks = 0; + float best_tail_wave_penalty = 1.0f; + const IdxT max_num_blocks = ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize); + for(int num_waves = 1;; ++num_waves) + { + IdxT num_blocks = std::min( + max_num_blocks, static_cast(std::max(num_waves * active_blocks / batch_size, 1))); + IdxT items_per_thread = ceildiv(len, num_blocks * BlockSize); + items_per_thread = alignTo(items_per_thread, VECTORIZED_READ_SIZE / sizeof(T)); + num_blocks = ceildiv(len, items_per_thread * BlockSize); + float actual_num_waves = static_cast(num_blocks) * batch_size / active_blocks; + float tail_wave_penalty = + (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves); + + // 0.15 is determined experimentally. It also ensures breaking the loop + // early, e.g. 
when num_waves > 7, tail_wave_penalty will always <0.15 + if(tail_wave_penalty < 0.15) + { + best_num_blocks = num_blocks; + break; + } + else if(tail_wave_penalty < best_tail_wave_penalty) + { + best_num_blocks = num_blocks; + best_tail_wave_penalty = tail_wave_penalty; + } + + if(num_blocks == max_num_blocks) + { + break; + } + } + return best_num_blocks; +} + +template +__host__ __device__ void set_buf_pointers(T const* in, + IdxT const* in_idx, + T* buf1, + IdxT* idx_buf1, + T* buf2, + IdxT* idx_buf2, + int pass, + T const*& in_buf, + IdxT const*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + if(pass == 0) + { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } + else if(pass == 1) + { + in_buf = in; + in_idx_buf = in_idx; + out_buf = buf1; + out_idx_buf = idx_buf1; + } + else if(pass % 2 == 0) + { + in_buf = buf1; + in_idx_buf = idx_buf1; + out_buf = buf2; + out_idx_buf = idx_buf2; + } + else + { + in_buf = buf2; + in_idx_buf = idx_buf2; + out_buf = buf1; + out_idx_buf = idx_buf1; + } +} + +template +__device__ void set_buf_pointers(T const* in, + IdxT const* in_idx, + char* bufs, + IdxT buf_len, + int pass, + T const*& in_buf, + IdxT const*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 + if(pass == 0) + { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } + else if(pass == 1) + { + in_buf = in; + in_idx_buf = in_idx; + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } + else if(pass % 2 == 0) + { + in_buf = reinterpret_cast(bufs); + in_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + out_buf = const_cast(in_buf + buf_len); + out_idx_buf = const_cast(in_idx_buf + buf_len); + } + else + { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + in_buf = out_buf + buf_len; + in_idx_buf = out_idx_buf + buf_len; + } +} + +// The following a few functions are for the one-block version, which uses +// single thread block for each row of a batch. 
+template +__device__ void filter_and_histogram_for_one_block(T const* in_buf, + IdxT const* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + const IdxT previous_len, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass) +{ + constexpr int num_buckets = calc_num_buckets(); + for(int i = threadIdx.x; i < num_buckets; i += blockDim.x) + { + histogram[i] = 0; + } + IdxT* p_filter_cnt = &counter->filter_cnt; + if(threadIdx.x == 0) + { + *p_filter_cnt = 0; + } + __syncthreads(); + + int const start_bit = calc_start_bit(pass); + unsigned const mask = calc_mask(pass); + + if(pass == 0) + { + auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + }; + vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); + } + else if(!out_buf) + { + // not use vectorized_process here because it increases #registers a lot + auto const kth_value_bits = counter->kth_value_bits; + int const previous_start_bit = calc_start_bit(pass - 1); + + for(IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) + { + const T value = in_buf[i]; + auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if(previous_bits == kth_value_bits) + { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } + } + } + else + { + // not use vectorized_process here because it increases #registers a lot + IdxT* p_out_cnt = &counter->out_cnt; + auto const kth_value_bits = counter->kth_value_bits; + int const previous_start_bit = calc_start_bit(pass - 1); + + for(IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) + { + const T value = in_buf[i]; + auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if(previous_bits == kth_value_bits) + { + + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } + else if(previous_bits < kth_value_bits) + { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if(WRITE_TOPK_VALUES) + { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + } + } +} + +template +__global__ void radix_topk_one_block_kernel(T const* in, + IdxT const* in_idx, + const int64_t len, + const IdxT* rowStarts, + const IdxT* rowEnds, + const IdxT k, + T* out, + IdxT* out_idx, + bool const select_min, + char* bufs) +{ + constexpr int num_buckets = calc_num_buckets(); + __shared__ Counter counter; + __shared__ IdxT histogram[num_buckets]; + + const int64_t batch_id = blockIdx.x; + const IdxT rowStart = rowStarts[batch_id]; + const IdxT rowEnd = rowEnds[batch_id]; + const IdxT row_len = rowEnd - rowStart; + if(threadIdx.x == 0) + { + counter.k = k; + counter.len = row_len; + counter.previous_len = row_len; + counter.kth_value_bits = 0; + counter.out_cnt = 0; + counter.out_back_cnt = 0; + } + __syncthreads(); + + in += batch_id * len; + if(in_idx) + { + in_idx += batch_id * len; + } + + out += batch_id * k; + out_idx += batch_id * k; + if(row_len <= k) + { + in += rowStart; + for(int rowIt = threadIdx.x; rowIt < k; rowIt += BlockSize) + { + out_idx[rowIt] = rowIt < row_len ? rowIt + rowStart : -1; + if(WRITE_TOPK_VALUES) + { + out[rowIt] = rowIt < row_len ? 
in[rowIt] : 0; + } + } + return; + } + + const IdxT buf_len = calc_buf_len(row_len); + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + + constexpr int num_passes = calc_num_passes(); + for(int pass = 0; pass < num_passes; ++pass) + { + T const* in_buf = nullptr; + IdxT const* in_idx_buf = nullptr; + T* out_buf = nullptr; + IdxT* out_idx_buf = nullptr; + set_buf_pointers(in, in_idx, bufs, buf_len, pass, in_buf, in_idx_buf, out_buf, out_idx_buf); + + const IdxT current_len = counter.len; + const IdxT current_k = counter.k; + IdxT previous_len = counter.previous_len; + if(previous_len > buf_len) + { + in_buf = in; + in_idx_buf = in_idx; + previous_len = row_len; + } + if(current_len > buf_len) + { + // so "out_buf==nullptr" denotes skipping writing buffer in current pass + out_buf = nullptr; + out_idx_buf = nullptr; + } + + filter_and_histogram_for_one_block( + in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + previous_len, + &counter, + histogram, + select_min, + pass); //@TODO CHECK UPDATE CODE + __syncthreads(); + + scan(histogram); + __syncthreads(); + + choose_bucket(&counter, histogram, current_k, pass); + if(threadIdx.x == 0) + { + counter.previous_len = current_len; + } + __syncthreads(); + + if(pass == num_passes - 1) + { + last_filter( + out_buf ? out_buf : in, + out_buf ? out_idx_buf : in_idx, + out, + out_idx, + out_buf ? current_len : row_len, + k, + &counter, + select_min, + pass); + break; + } + else if(counter.len == counter.k) + { + last_filter( + out_buf ? out_buf : in, + out_buf ? out_idx_buf : in_idx, + out, + out_idx, + out_buf ? current_len : row_len, + k, + &counter, + select_min, + pass); + break; + } + } +} + +inline size_t calc_aligned_size(std::vector const& sizes) +{ + const size_t ALIGN_BYTES = 256; + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + size_t total = 0; + for(auto sz : sizes) + { + total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + } + return total + ALIGN_BYTES - 1; +} + +inline std::vector calc_aligned_pointers(void const* p, std::vector const& sizes) +{ + const size_t ALIGN_BYTES = 256; + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + + char* ptr = + reinterpret_cast((reinterpret_cast(p) + ALIGN_BYTES - 1) & ALIGN_MASK); + + std::vector aligned_pointers; + aligned_pointers.reserve(sizes.size()); + for(auto sz : sizes) + { + aligned_pointers.push_back(ptr); + ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + } + + return aligned_pointers; +} + +template +void standalone_stable_radix_topk_(void* buf, + size_t& buf_size, + T const* in, + IdxT const* in_idx, + int batch_size, + int64_t len, + IdxT* rowStarts, + IdxT* rowEnds, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + bool fused_last_filter, + unsigned grid_dim, + hipStream_t stream, + bool sorted = false) +{ + static_assert(calc_num_passes() > 1); + constexpr int num_buckets = calc_num_buckets(); + + Counter* counters = nullptr; + IdxT* histograms = nullptr; + T* buf1 = nullptr; + IdxT* idx_buf1 = nullptr; + T* buf2 = nullptr; + IdxT* idx_buf2 = nullptr; + + IdxT* topk_out_idx = nullptr; + + { + IdxT len_candidates = calc_buf_len(len); + std::vector sizes = {sizeof(*counters) * batch_size, + sizeof(*histograms) * num_buckets * batch_size, + sizeof(*buf1) * len_candidates * batch_size, + sizeof(*idx_buf1) * len_candidates * batch_size, + sizeof(*buf2) * len_candidates * batch_size, + sizeof(*idx_buf2) * len_candidates * batch_size, + sizeof(*topk_out_idx) * k * batch_size}; + + size_t total_size = calc_aligned_size(sizes); + if(!buf) + { + buf_size = 
total_size; + return; + } + + std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); + counters = static_cast(aligned_pointers[0]); + histograms = static_cast(aligned_pointers[1]); + buf1 = static_cast(aligned_pointers[2]); + idx_buf1 = static_cast(aligned_pointers[3]); + buf2 = static_cast(aligned_pointers[4]); + idx_buf2 = static_cast(aligned_pointers[5]); + topk_out_idx = static_cast(aligned_pointers[6]); + + HIP_CALL(hipMemsetAsync(aligned_pointers[0], + 0, + static_cast(aligned_pointers[2]) - + static_cast(aligned_pointers[0]), + stream)); + } + + T const* in_buf = nullptr; + IdxT const* in_idx_buf = nullptr; + T* out_buf = nullptr; + IdxT* out_idx_buf = nullptr; + + dim3 blocks(grid_dim, batch_size); + + constexpr int num_passes = calc_num_passes(); + + auto kernel = radix_kernel; + + for(int pass = 0; pass < num_passes; ++pass) + { + set_buf_pointers(in, + in_idx, + buf1, + idx_buf1, + buf2, + idx_buf2, + pass, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf); + + if(fused_last_filter && pass == num_passes - 1) + { + kernel = radix_kernel; + } + + kernel<<>>(in, + in_idx, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + counters, + histograms, + len, + rowStarts, + rowEnds, + k, + select_min, + pass); + } + + if(!fused_last_filter) + { + last_filter_kernel + <<>>( + in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters, select_min); + } +} + +template +void standalone_stable_radix_topk_one_block_(void* buf, + size_t& buf_size, + T const* in, + IdxT const* in_idx, + int batch_size, + int64_t len, + IdxT* rowStarts, + IdxT* rowEnds, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + hipStream_t stream, + bool sorted = false) +{ + static_assert(calc_num_passes() > 1); + + char* bufs = nullptr; + IdxT* topk_out_idx = nullptr; + + const IdxT buf_len = calc_buf_len(len); + + { + size_t total_size = 0; + std::vector sizes = {buf_len * 2 * (sizeof(T) + sizeof(IdxT)) * batch_size, + sizeof(*topk_out_idx) * k * batch_size}; + + total_size = calc_aligned_size(sizes); + + if(!buf) + { + buf_size = total_size; + return; + } + + std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); + bufs = static_cast(aligned_pointers[0]); + topk_out_idx = static_cast(aligned_pointers[1]); + } + + radix_topk_one_block_kernel + <<>>( + in, in_idx, len, rowStarts, rowEnds, k, out, out_idx, select_min, bufs); +} + +template +void standalone_stable_radix_11bits(void* buf, + size_t& buf_size, + T const* in, + int batch_size, + int64_t len, + IdxT* rowStarts, + IdxT* rowEnds, + IdxT k, + T* out, + IdxT* out_idx, + bool greater, + hipStream_t stream) +{ + constexpr int items_per_thread = 32; + constexpr int block_dim = 1024; + constexpr bool fused_last_filter = false; + if(len <= block_dim * items_per_thread) + { + standalone_stable_radix_topk_one_block_( + buf, + buf_size, + in, + static_cast(nullptr), + batch_size, + len, + rowStarts, + rowEnds, + k, + out, + out_idx, + !greater, + stream, + sorted); + } + else + { + int sm_cnt = get_num_cu_func(); + + unsigned grid_dim = + calc_grid_dim(batch_size, len, sm_cnt); + + if(grid_dim == 1) + { + standalone_stable_radix_topk_one_block_( + buf, + buf_size, + in, + static_cast(nullptr), + batch_size, + len, + rowStarts, + rowEnds, + k, + out, + out_idx, + !greater, + stream, + sorted); + } + else + { + standalone_stable_radix_topk_( + buf, + buf_size, + in, + static_cast(nullptr), + batch_size, + len, + rowStarts, + rowEnds, + k, + out, + out_idx, + !greater, + fused_last_filter, + grid_dim, + stream, + sorted); 
+ } + } +} + +// AIR TopK end static inline __device__ uint32_t floatAsSortableUint(float x) { @@ -131,7 +1689,7 @@ __device__ bool processHistogramStep(const float* logits, for(int vecIdx = (rowStart / Vector) + threadIdx.x; vecIdx < (rowEnd + Vector - 1) / Vector; vecIdx += kNumThreadsPerBlock) { - auto v = reinterpret_cast(logits)[vecIdx]; + auto v = reinterpret_cast(logits)[vecIdx]; #pragma unroll for(int j = 0; j < Vector; j++) { @@ -271,11 +1829,8 @@ template -__device__ void topk_per_row_kernel(const float* logits, - const int rowStart, - const int rowEnd, - int* outIndices, - int stride1) +__device__ void topk_per_row_kernel( + const float* logits, const int rowStart, const int rowEnd, int* outIndices, int stride1) { // The number of slots for the final pass. static constexpr int kNumFinalItems = 2048; @@ -580,95 +2135,211 @@ static __global__ void topk_per_row_decode( auto logitsLocal = logits + rowIdx * stride0; topk_per_row_kernel( - logitsLocal, rowStart, rowEnd, outIndicesLocal, stride1); + logitsLocal, rowStart, rowEnd, outIndicesLocal, stride1); } } // namespace aiter -void top_k_per_row_prefill(const torch::Tensor& logits, - const torch::Tensor& rowStarts, - const torch::Tensor& rowEnds, - torch::Tensor& indices, - int64_t numRows, - int64_t stride0, - int64_t stride1) +template +int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0) { - constexpr int kSortingAlgorithmThreshold = 12288; + using IdxT = int32_t; - // Compute the results on the device. - constexpr int kNumThreadsPerBlock = 512; + size_t buf_size = 0; + void* workspace = nullptr; + T const* in = nullptr; + T* out_val = nullptr; + IdxT* out_idx = nullptr; - // The top-k width. - static constexpr int kTopK = 2048; + constexpr int block_dim = 1024; + constexpr bool fused_last_filter = false; + constexpr bool sorted = true; + constexpr bool is_largest = true; + constexpr int k = 2048; - const hipStream_t stream = at::hip::getCurrentHIPStream(); - - int numInsertionBlocks = std::min(static_cast(numRows), kSortingAlgorithmThreshold); + int sm_cnt = get_num_cu_func(); + unsigned grid_dim = + aiter::calc_grid_dim(numRows, stride0, sm_cnt); - if(stride0 % 4 == 0) + if(grid_dim == 1) { - aiter::topk_per_row - <<>>(logits.data_ptr(), - rowStarts.data_ptr(), - rowEnds.data_ptr(), - indices.data_ptr(), - static_cast(stride0), - static_cast(stride1), - 0); + aiter::standalone_stable_radix_topk_one_block_( + workspace, + buf_size, + in, + static_cast(nullptr), + numRows, + stride0, + static_cast(nullptr), + static_cast(nullptr), + k, + out_val, + out_idx, + !is_largest, + 0, + sorted); } else { - aiter::topk_per_row - <<>>(logits.data_ptr(), - rowStarts.data_ptr(), - rowEnds.data_ptr(), - indices.data_ptr(), - static_cast(stride0), - static_cast(stride1), - 0); + aiter::standalone_stable_radix_topk_( + workspace, + buf_size, + in, + static_cast(nullptr), + numRows, + stride0, + static_cast(nullptr), + static_cast(nullptr), + k, + out_val, + out_idx, + !is_largest, + fused_last_filter, + grid_dim, + 0, + sorted); } + return buf_size; +} - if(numRows > kSortingAlgorithmThreshold) +void top_k_per_row_prefill(const torch::Tensor& logits, + const torch::Tensor& rowStarts, + const torch::Tensor& rowEnds, + torch::Tensor& indices, + std::optional values, + int64_t numRows, + int64_t stride0, + int64_t stride1) +{ + size_t buf_size = 0; // will be overwritten by the kernel + + static constexpr int kTopK = 2048; + static constexpr bool is_largest = true; + + const hipStream_t stream = 
at::hip::getCurrentHIPStream(); + int64_t workspace_size = invokeComputeTopkLastDimWorkspaceSize(numRows, stride0); + // int64_t workspace_size = int64_t(1024)*1024*1024*2; + auto options = torch::TensorOptions().dtype(torch::kUInt8).device(logits.device()); + torch::Tensor workspace = torch::empty({workspace_size}, options); + + if(values.has_value()) { - int numRadixBlocks = numRows - kSortingAlgorithmThreshold; - if(stride0 % 4 == 0) - { - aiter::topk_per_row - <<>>(logits.data_ptr(), - rowStarts.data_ptr(), - rowEnds.data_ptr(), - indices.data_ptr(), - static_cast(stride0), - static_cast(stride1), - kSortingAlgorithmThreshold); - } - else - { - aiter::topk_per_row - <<>>(logits.data_ptr(), - rowStarts.data_ptr(), - rowEnds.data_ptr(), - indices.data_ptr(), - static_cast(stride0), - static_cast(stride1), - kSortingAlgorithmThreshold); - } + aiter::standalone_stable_radix_11bits( + static_cast(workspace.data_ptr()), + buf_size, + logits.data_ptr(), + static_cast(numRows), + stride0, + rowStarts.data_ptr(), + rowEnds.data_ptr(), + kTopK, + values->data_ptr(), + indices.data_ptr(), + is_largest, + stream); + } + else + { + aiter::standalone_stable_radix_11bits( + static_cast(workspace.data_ptr()), + buf_size, + logits.data_ptr(), + static_cast(numRows), + stride0, + rowStarts.data_ptr(), + rowEnds.data_ptr(), + kTopK, + nullptr, + indices.data_ptr(), + is_largest, + stream); } } +// void top_k_per_row_prefill(const torch::Tensor& logits, +// const torch::Tensor& rowStarts, +// const torch::Tensor& rowEnds, +// torch::Tensor& indices, +// int64_t numRows, +// int64_t stride0, +// int64_t stride1) +// { +// constexpr int kSortingAlgorithmThreshold = 12288; + +// // Compute the results on the device. +// constexpr int kNumThreadsPerBlock = 1024; + +// // The top-k width. +// static constexpr int kTopK = 2048; + +// const hipStream_t stream = at::hip::getCurrentHIPStream(); + +// int numInsertionBlocks = std::min(static_cast(numRows), kSortingAlgorithmThreshold); + +// if(stride0 % 4 == 0) +// { +// aiter::topk_per_row +// <<>>(logits.data_ptr(), +// rowStarts.data_ptr(), +// rowEnds.data_ptr(), +// indices.data_ptr(), +// static_cast(stride0), +// static_cast(stride1), +// 0); +// } +// else +// { +// aiter::topk_per_row +// <<>>(logits.data_ptr(), +// rowStarts.data_ptr(), +// rowEnds.data_ptr(), +// indices.data_ptr(), +// static_cast(stride0), +// static_cast(stride1), +// 0); +// } + +// if(numRows > kSortingAlgorithmThreshold) +// { +// int numRadixBlocks = numRows - kSortingAlgorithmThreshold; +// if(stride0 % 4 == 0) +// { +// aiter::topk_per_row +// <<>>(logits.data_ptr(), +// rowStarts.data_ptr(), +// rowEnds.data_ptr(), +// indices.data_ptr(), +// static_cast(stride0), +// static_cast(stride1), +// kSortingAlgorithmThreshold); +// } +// else +// { +// aiter::topk_per_row +// <<>>(logits.data_ptr(), +// rowStarts.data_ptr(), +// rowEnds.data_ptr(), +// indices.data_ptr(), +// static_cast(stride0), +// static_cast(stride1), +// kSortingAlgorithmThreshold); +// } +// } +// } + void top_k_per_row_decode(const torch::Tensor& logits, - int64_t next_n, - const torch::Tensor& seqLens, - torch::Tensor& indices, - int64_t numRows, - int64_t stride0, - int64_t stride1) + int64_t next_n, + const torch::Tensor& seqLens, + torch::Tensor& indices, + int64_t numRows, + int64_t stride0, + int64_t stride1) { constexpr int kSortingAlgorithmThreshold = 12288; // Compute the results on the device. 
constexpr int kNumThreadsPerBlock = 1024; const hipStream_t stream = at::hip::getCurrentHIPStream(); - const auto numColumns = logits.size(1); + const auto numColumns = logits.size(1); if(numColumns < kSortingAlgorithmThreshold) { @@ -695,23 +2366,25 @@ void top_k_per_row_decode(const torch::Tensor& logits, } else { - if (stride0 % 4 == 0) + if(stride0 % 4 == 0) { aiter::topk_per_row_decode <<>>(logits.data_ptr(), seqLens.data_ptr(), indices.data_ptr(), static_cast(stride0), - static_cast(stride1), - static_cast(next_n)); - } else { - aiter::topk_per_row_decode - <<>>(logits.data_ptr(), - seqLens.data_ptr(), - indices.data_ptr(), - static_cast(stride0), - static_cast(stride1), - static_cast(next_n)); + static_cast(stride1), + static_cast(next_n)); + } + else + { + aiter::topk_per_row_decode + <<>>(logits.data_ptr(), + seqLens.data_ptr(), + indices.data_ptr(), + static_cast(stride0), + static_cast(stride1), + static_cast(next_n)); + } } } -} diff --git a/op_tests/test_topk_per_row.py b/op_tests/test_topk_per_row.py index a4ca0f7a05..30f038dac5 100755 --- a/op_tests/test_topk_per_row.py +++ b/op_tests/test_topk_per_row.py @@ -47,11 +47,13 @@ def create_random_logits( def create_row_boundaries( - num_rows: int, top_k: int = 2048 + num_rows: int, num_prefix: int = 0, top_k: int = 2048 ) -> tuple[torch.Tensor, torch.Tensor]: """Create row start and end indices for testing.""" row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda") - row_ends = torch.arange(1, num_rows + 1, device="cuda", dtype=torch.int32) + row_ends = torch.arange( + num_prefix + 1, num_prefix + num_rows + 1, device="cuda", dtype=torch.int32 + ) return row_starts, row_ends @@ -118,6 +120,7 @@ def run_top_k_per_row_prefill( row_starts: torch.Tensor, row_ends: torch.Tensor, indices: torch.Tensor, + values: torch.Tensor, num_rows: int, stride_row: int, stride_col: int, @@ -130,6 +133,7 @@ def run_top_k_per_row_prefill( row_starts, row_ends, indices, + values, num_rows, stride_row, stride_col, @@ -161,7 +165,7 @@ def run_top_k_per_row_decode( @benchmark() -def test_top_k_per_row_prefill(num_rows: int, top_k: int) -> dict: +def test_top_k_per_row_prefill(num_rows: int, num_prefix: int, top_k: int) -> dict: """ Test topk_per_row_prefill. """ @@ -169,18 +173,22 @@ def test_top_k_per_row_prefill(num_rows: int, top_k: int) -> dict: torch.set_default_device("cuda:0") # Create test data - row_starts, row_ends = create_row_boundaries(num_rows) + row_starts, row_ends = create_row_boundaries(num_rows, num_prefix) logits = create_random_logits(row_starts, row_ends, torch.float32, 42) # Create output tensors indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") + values = torch.empty((num_rows, top_k), dtype=torch.float32, device="cuda").fill_(0) + # Run the kernel _, us = run_top_k_per_row_prefill( logits, row_starts, row_ends, indices, + None, # values + # values, num_rows, logits.stride(0), logits.stride(1), @@ -272,7 +280,7 @@ def test_top_k_per_row_decode( "-c", "--context_len", type=int, - default=[8, 16, 32, 64, 128, 1024, 16384, 65536, 90000, 128000], + default=[8, 128, 1024, 3072, 4096, 8192, 16384, 32768, 65536, 90000, 128000], nargs="+", help="""number of kv. e.g.: -c 64""", @@ -288,6 +296,15 @@ def test_top_k_per_row_decode( e.g.: -k 2048""", ) +parser.add_argument( + "--num_prefix", + type=int, + default=[0], + nargs="+", + help="""top-k elements per row. 
+    e.g.: --num_prefix 8000 16000 24000 32000 40000 48000 56000""",
+)
+
 parser.add_argument(
     "-b",
     "--decode_batch_size",
@@ -325,20 +342,21 @@ def test_top_k_per_row_decode(
 df = []
 for m in args.context_len:
     for k in args.top_k:
-        ret = test_top_k_per_row_prefill(m, k)
-        df.append(ret)
+        for num_prefix in args.num_prefix:
+            ret = test_top_k_per_row_prefill(m, num_prefix, k)
+            df.append(ret)
 
 df = pd.DataFrame(df)
 aiter.logger.info(f"summary for top_k_per_row_prefill kernel:\n{df}")
 
-df = []
-for m in args.decode_batch_size:
-    for ctx in args.context_len:
-        for k in args.top_k:
-            for n in args.next_n:
-                ret = test_top_k_per_row_decode(m, ctx, k, n)
-                df.append(ret)
+# df = []
+# for m in args.decode_batch_size:
+#     for ctx in args.context_len:
+#         for k in args.top_k:
+#             for n in args.next_n:
+#                 ret = test_top_k_per_row_decode(m, ctx, k, n)
+#                 df.append(ret)
 
-df = pd.DataFrame(df)
-aiter.logger.info(f"summary for top_k_per_row_decode kernel:\n{df}")
+# df = pd.DataFrame(df)
+# aiter.logger.info(f"summary for top_k_per_row_decode kernel:\n{df}")
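
For reference, a minimal usage sketch of the updated prefill entry point, mirroring the test changes in op_tests/test_topk_per_row.py above. The import path (assumed here to be aiter.top_k_per_row_prefill) and the sizes are assumptions, not taken verbatim from this diff; kTopK is fixed to 2048 in the kernel. Passing None for values keeps the previous indices-only behaviour, while a float32 (numRows, 2048) tensor additionally receives the top-k logits.

# Hypothetical usage sketch (not part of this patch); import path and sizes are assumed.
import torch
import aiter

num_rows, top_k, vocab = 128, 2048, 131072   # assumed problem sizes
logits = torch.randn(num_rows, vocab, dtype=torch.float32, device="cuda")
row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
row_ends = torch.arange(1, num_rows + 1, dtype=torch.int32, device="cuda")

indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
values = torch.zeros((num_rows, top_k), dtype=torch.float32, device="cuda")

# `values` is optional: pass None to keep the old indices-only behaviour.
aiter.top_k_per_row_prefill(
    logits,
    row_starts,
    row_ends,
    indices,
    values,
    num_rows,
    logits.stride(0),
    logits.stride(1),
)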
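
The long comment above radix_kernel describes the per-pass MSD radix selection: build a digit histogram, prefix-sum it, choose the bucket holding the k-th value, then filter candidates for the next pass. The NumPy sketch below restates that loop on the host for illustration only; it is not the kernel's code path, it omits the candidate buffers, grid-wide counters, and the index-priority tie handling, and the pass width and helper names are assumptions.

# Illustrative host-side restatement of the MSD radix selection described above.
# Not the kernel's implementation; buffers, counters and tie/stability rules are omitted.
import numpy as np

BITS_PER_PASS = 8                   # the kernel uses its own pass width; 8 keeps the sketch simple
NUM_BUCKETS = 1 << BITS_PER_PASS
NUM_PASSES = 32 // BITS_PER_PASS    # float32 keys


def float_as_sortable_uint(x):
    # Order-preserving float32 -> uint32 mapping (same idea as floatAsSortableUint / twiddle_in):
    # after the transform, unsigned integer order equals ascending float order.
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    mask = np.where(bits >> 31 == 1, np.uint32(0xFFFFFFFF), np.uint32(0x80000000))
    return bits ^ mask


def radix_topk_largest_indices(values, k):
    # Step 1: histogram of the current digit; Step 2: (suffix) sums; Step 3: choose the
    # bucket holding the k-th largest key; Step 4: filter candidates for the next pass.
    keys = float_as_sortable_uint(values)
    candidates = np.arange(keys.size)      # indices still in play
    selected = []                          # indices already known to be in the top-k
    for p in range(NUM_PASSES):
        shift = 32 - (p + 1) * BITS_PER_PASS
        digits = (keys[candidates] >> shift) & (NUM_BUCKETS - 1)
        hist = np.bincount(digits, minlength=NUM_BUCKETS)
        larger = np.cumsum(hist[::-1])[::-1] - hist    # candidates with a strictly larger digit
        bucket = int(np.nonzero((larger < k) & (k <= larger + hist))[0][0])
        selected.extend(candidates[digits > bucket].tolist())   # definitely in the top-k
        k -= int(larger[bucket])
        candidates = candidates[digits == bucket]      # ties on this digit go to the next pass
    selected.extend(candidates[:k].tolist())           # remaining keys are all equal; take k of them
    return np.asarray(selected)

For example, radix_topk_largest_indices(np.random.randn(100000).astype(np.float32), 2048) returns the positions of the 2048 largest entries, matching an argsort-based top-k in the absence of ties.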
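
calc_aligned_size and calc_aligned_pointers above carve one flat workspace allocation into 256-byte-aligned sub-buffers (counters, histograms, candidate and index buffers). A small Python restatement of the same arithmetic, offsets only, for illustration:

# Python restatement of the 256-byte workspace carving used by
# calc_aligned_size / calc_aligned_pointers; offsets only, no device memory involved.
ALIGN_BYTES = 256

def round_up(n, align=ALIGN_BYTES):
    return (n + align - 1) // align * align

def calc_aligned_size(sizes):
    # Sum of padded chunk sizes, plus slack so the base address itself can be aligned.
    return sum(round_up(sz) for sz in sizes) + ALIGN_BYTES - 1

def calc_aligned_offsets(sizes):
    # Byte offset of every chunk inside an already-aligned workspace.
    offsets, cur = [], 0
    for sz in sizes:
        offsets.append(cur)
        cur += round_up(sz)
    return offsets

The extra ALIGN_BYTES - 1 bytes mirror the C++ version, which first rounds the caller-provided base pointer up to a 256-byte boundary before handing out the sub-buffers.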