-
Notifications
You must be signed in to change notification settings - Fork 166
feat: Adaptive topk algorithm selection based on input characteristics #1578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR implements adaptive Top-K algorithm selection based on input characteristics (primarily K value) to optimize performance across different workload scenarios. The implementation intelligently chooses between bitonic sort for small K values (≤128) and 11-bit radix sort for larger K values, maximizing performance across the entire spectrum.
Key Changes:
- Adaptive algorithm selection heuristic based on K value and input length
- Extended capacity support from 512 to 2048 elements
- Variable-length input support via rowStarts/rowEnds parameters
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| op_tests/test_topk_plain.py | Updated test to handle new function signatures and expanded test parameter ranges; added exception handling for Triton fallback |
| csrc/kernels/topk_plain_kernels.cu | Core implementation of adaptive selection logic; added support for variable-length inputs and increased MAX_CAPACITY to 2048 |
| csrc/kernels/topk_per_row_kernels.cu | Added null-safe handling for rowStarts/rowEnds parameters and extended k parameter to filtering functions |
| csrc/include/topk_plain.h | Updated function signature to include topk_out tensor and variable-length support parameters |
| csrc/include/rocm_ops.hpp | Updated Python bindings with new parameters and default values |
| csrc/include/opus/opus.hpp | Fixed type inconsistency: changed __fp16 to _Float16 for med3 template specialization |
| aiter/ops/topk_plain.py | Updated Python wrapper signature with additional parameters for variable-length support |
| aiter/jit/optCompilerConfig.json | Added topk_per_row_kernels.cu to build sources |
Comments suppressed due to low confidence (1)
csrc/kernels/topk_per_row_kernels.cu:1340
- Similar to the other function, the parameter
kis added tofilter_and_histogram_for_one_blockat line 1182 but is never used in the function body (lines 1183-1340). This indicates either an incomplete implementation or an unnecessary parameter addition. Review ifkshould be used for any logic within this function.
IdxT k)
{
constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
for(int i = threadIdx.x; i < num_buckets * 2; i += blockDim.x)
{
histogram[i] = 0;
}
IdxT* p_filter_cnt = &counter->filter_cnt;
if(threadIdx.x == 0)
{
*p_filter_cnt = 0;
}
__syncthreads();
int const start_bit = calc_start_bit<T, BitsPerPass>(pass);
unsigned const mask = calc_mask<T, BitsPerPass>(pass);
if(pass == 0)
{
T local_min = std::numeric_limits<T>::max();
T local_max = std::numeric_limits<T>::lowest();
auto f = [histogram, select_min, start_bit, mask, &local_min, &local_max](
T value, IdxT, int& acc, int& prev_bin_idx, bool is_last) {
int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
// atomicAdd(histogram + bucket, static_cast<IdxT>(1));
if(bucket == prev_bin_idx)
{
acc++;
}
else
{
if(acc > 0)
{
atomicAdd(histogram + prev_bin_idx, static_cast<IdxT>(acc));
}
acc = 1;
prev_bin_idx = bucket;
}
if(is_last)
{
return;
}
int bucket_low =
calc_bucket<T, BitsPerPass>(value, 0, (1 << BitsPerPass) - 1, select_min);
atomicAdd(histogram + num_buckets + bucket_low, static_cast<IdxT>(1));
local_min = fminf(local_min, value);
local_max = fmaxf(local_max, value);
};
vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f);
using BlockReduceT =
hipcub::BlockReduce<T, BlockSize, hipcub::BLOCK_REDUCE_WARP_REDUCTIONS>;
__shared__ typename BlockReduceT::TempStorage temp_storage;
__shared__ bool use_one_pass;
T global_min = BlockReduceT(temp_storage).Reduce(local_min, hipcub::Min());
T global_max = BlockReduceT(temp_storage).Reduce(local_max, hipcub::Max());
if(threadIdx.x == 0)
{
auto global_min_bits = twiddle_in(global_min, select_min);
auto global_max_bits = twiddle_in(global_max, select_min);
uint32_t diff = global_min_bits ^ global_max_bits;
use_one_pass = diff < (1u << BitsPerPass);
}
__syncthreads();
return use_one_pass;
}
else if(!out_buf)
{
// not use vectorized_process here because it increases #registers a lot
auto const kth_value_bits = counter->kth_value_bits;
int const previous_start_bit = calc_start_bit<T, BitsPerPass>(pass - 1);
for(IdxT i = threadIdx.x; i < previous_len; i += blockDim.x)
{
const T value = in_buf[i];
auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit)
<< previous_start_bit;
if(previous_bits == kth_value_bits)
{
int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
atomicAdd(histogram + bucket, static_cast<IdxT>(1));
}
}
}
else
{
// not use vectorized_process here because it increases #registers a lot
IdxT* p_out_cnt = &counter->out_cnt;
auto const kth_value_bits = counter->kth_value_bits;
int const previous_start_bit = calc_start_bit<T, BitsPerPass>(pass - 1);
if(in_idx_buf)
{
for(IdxT i = threadIdx.x; i < previous_len; i += blockDim.x)
{
const T value = in_buf[i];
auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit)
<< previous_start_bit;
if(previous_bits == kth_value_bits)
{
IdxT pos = atomicAdd(p_filter_cnt, static_cast<IdxT>(1));
out_buf[pos] = value;
out_idx_buf[pos] = in_idx_buf[i];
int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
atomicAdd(histogram + bucket, static_cast<IdxT>(1));
}
else if(previous_bits < kth_value_bits)
{
IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
if(WRITE_TOPK_VALUES)
{
out[pos] = value;
}
out_idx[pos] = in_idx_buf[i];
}
}
}
else
{
for(IdxT i = threadIdx.x; i < previous_len; i += blockDim.x)
{
const T value = in_buf[i];
auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit)
<< previous_start_bit;
if(previous_bits == kth_value_bits)
{
IdxT pos = atomicAdd(p_filter_cnt, static_cast<IdxT>(1));
out_buf[pos] = value;
out_idx_buf[pos] = i;
int bucket = calc_bucket<T, BitsPerPass>(value, start_bit, mask, select_min);
atomicAdd(histogram + bucket, static_cast<IdxT>(1));
}
else if(previous_bits < kth_value_bits)
{
IdxT pos = atomicAdd(p_out_cnt, static_cast<IdxT>(1));
if(WRITE_TOPK_VALUES)
{
out[pos] = value;
}
out_idx[pos] = i;
}
}
}
}
return false;
}
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
#1578) * Add radix-base selection * Remove explicit template * Update the selected k condition * remove pos < k guard * code format * Update csrc/include/rocm_ops.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update csrc/kernels/topk_per_row_kernels.cu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update csrc/kernels/topk_plain_kernels.cu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update test_topk_plain.py * Update TODO message * Update csrc/kernels/topk_per_row_kernels.cu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update op_tests/test_topk_plain.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * format test_topk_plain.py with black * Disable triton test for a resonalbe execution time * add explicit template instantiation * fix explicit template instantiation * add explicit template instantiation * Add bf16 support * Fix linter * Fix build errors * Fix condition * Fix build and test * Update conditions --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Co-authored-by: MHYang <meng-hsuan.yang@amd.com>
* fix sink error for asm fmha (#1652) Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * add guard in case pynccl init failed (#1671) * One shot pa (#1670) * add one shot pa kernel * fix buffer load in sliding window kernel * fix typo * revert --------- Co-authored-by: root <root@hjbog-srdc-24.amd.com> * fix(pa_ps): fix pa_ps_asm .co for gfx950 (#1669) Signed-off-by: Double Young <yang.yang2@amd.com> * modify test_bf16gemm_test (#1678) * Fix Ruff command in pre-checks (#1675) * fix mha bwd golden perf issue (#1666) * topk uplift v1 (#1662) /lgtm The customer has tested the code. It can work. * topk uplift v1 * topk add api for choose topk_v1 or topk_v2 --------- Co-authored-by: yonshuai <yonshuai@amd.com> Co-authored-by: yongshuai <yongshuai@amd.com> * fix missing return in mha_bwd (#1688) * Remove the input parameter "out" in gemm_a4w4 (#1679) * Remove the input parameter "out" in gemm_a4w4 * update * format --------- Co-authored-by: valarLip <Lingpeng.Jin@amd.com> * fwd v3 hd192 optimize inst alignment for causal mode (#1663) Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> * fix swa case mismatch (#1694) * fixing the fp4 gemm tune script Exception caused by tile_m name inconsistency (#1686) * CI: Migrate Triton tests to aiter-1gpu-runner (#1690) * add ntile 128 for a8 blkQ moe 1 stage (#1695) * add fmoe co with tilesize 32x128 * add ps co * fix pertoken co bug * add co to csv * add 128ntile logic for one stage asm * fix mem fault during perf turn * en vs for pertoken kernel --------- Co-authored-by: feifei14119 <feiw@amd.com> Co-authored-by: zufayu <zufayu@amd.com> * Optimize RoPE in the cases that hdim is small. (#1698) * Introduce new grid config strategy for compatibility with cases that hdim is small. * add launch bound to make sure that occu is always 8 * follow Copilot the suggestions * rm garbage from whl (#1696) * enhance prebuild logic (#1672) * enhance prebuild logic * ATen.h build issues * bug fix * bug fix II * bug fix III --------- Co-authored-by: zufayu <zufayu@amd.com> Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> * LLfp4 qr cap for atom (#1673) * QR cap implemented to limit QR to prefill * test git config * Fix to genericize qr comm cap * Incorrect cap number * [MLA] MLA conditions rewrite (#1665) * open mla mtp and remove some logs * fix qlen dense 128,N * fix hint * support sparse qlen input = 1 * change default splits * fix dp causal (#1677) * add two fp4 tune shapes and tuned config (#1687) * add two fp4 tune shapes and tuned config * change 32800 to 65536 to cover all cases between 32768 to 65536 as per feedback * Dev/a8w4 and a8w8splitk (#1667) * support moe a8w8 splitk (#1654) * Add support to a8w8_ck_moe_blk_gemm1 splitk * add switch and add some logging * tiny fix * update ck 3rd party and add some logging * add AITER_HEURISTIC_ONLY env * update ck * add condition to bypass tuned cfg * change bypass type * fix * fix removed log * upate ck submodule * fix lint * force to run tests --------- Co-authored-by: oscar <huaiguxu@amd.com> * Zan/moe a8w4 (#1655) * update * update * update quant * ut ready * update quant type * compile pass * python3 op_tests/test_moe_2stage.py -t 16 -e 1 -k 1 -dim 256,256 ready * update aiter dipatcher for bf16&fp8 * support a16 a8 dispatch * finish quant & sort * update aiter framework for a8w4 moe * update ck * update * update * update for atom * update --------- Co-authored-by: Zzz9990 <Zzz9990> Co-authored-by: root <root@hjbog-srdc-24.amd.com> * update ck * fix dispatch * fix too much logging * update * update ck * update ck * fix ruff code style * revert aiter-test yaml * fix ci * fix ci * fix ci * add mocked tuned result and decoding cfg token to next power of 2 * Update tuned_fmoe.csv remove duplicate * remove hack dtype * fix black * unique index * add empty arg to ck_moe_stage1 * resolve bias into lru cache * rename bypass cfg to AITER_BYPASS_TUNE_CONFIG --------- Co-authored-by: oscar <huaiguxu@amd.com> Co-authored-by: Zzz9990 <zanzhang@amd.com> Co-authored-by: root <root@hjbog-srdc-24.amd.com> Co-authored-by: felix <felix.li@amd.com> Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> * bf16_gemm_clean_in_kl (#1700) * bf16_gemm_clean_in_kl * update * update * update * update * fix tuner (#1701) * fix tuner * Update gradlib/gradlib/GemmTuner.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: amd-ruitang3 <145657428+amd-ruitang3@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * add gen_fake for 4 gemm operators (#1456) Co-authored-by: Lin, Soga <soga.lin@amd.com> Co-authored-by: sogalin <39478626+sogalin@users.noreply.github.com> * fix llvm issue (#1703) * fix llvm issue * fix copilot * feat: Adaptive topk algorithm selection based on input characteristics (#1578) * Add radix-base selection * Remove explicit template * Update the selected k condition * remove pos < k guard * code format * Update csrc/include/rocm_ops.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update csrc/kernels/topk_per_row_kernels.cu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update csrc/kernels/topk_plain_kernels.cu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update test_topk_plain.py * Update TODO message * Update csrc/kernels/topk_per_row_kernels.cu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update op_tests/test_topk_plain.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * format test_topk_plain.py with black * Disable triton test for a resonalbe execution time * add explicit template instantiation * fix explicit template instantiation * add explicit template instantiation * Add bf16 support * Fix linter * Fix build errors * Fix condition * Fix build and test * Update conditions --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Co-authored-by: MHYang <meng-hsuan.yang@amd.com> * fix mha bwd build error (#1705) * fix moe bug when pipever=v1 and nblk=64 (#1707) * fix bug * update * fix (#1710) * fix * update lint * [PA] Optimize PA Decode Gluon Performance for BF16/FP16 with KV_BLOCK_SIZE=64 and Fix ROCm 7.0 AOT Compilation (#1691) * Optimize pa_decode_gluon f16/bf16 perf for KV_BLOCK_SIZE=64 & fix ROCm 7.0 AOT - Add dedicated blocked layouts for f16/bf16 compute types - Add local AOT compile tool to fix ROCm 7.0 compatibility * black format file * format file to pass the ruff check * fix error in gfx950 * Fix argument parsing logic when AITER_JIT_DIR is set (#1715) When AITER_JIT_DIR is defined the enum module is loaded as "module_aiter_enum" rather than "aiter.jit.module_aiter_enum". This caused the docstring cleanup of enums to not work properly, causing a NameError exception in check_args. * fix topk deocde bug in logit value is same (#1716) Co-authored-by: yonshuai <yonshuai@amd.com> * add fp32 input (#1706) * add fp32 input * format code * perf bug fix * logic fix : out type != input type * bug fix * format code * remove dtype convert before act_and_mul in fused_moe --------- Co-authored-by: zufayu <zufayu@amd.com> Co-authored-by: chenjun <junchen2@amd.com> * add sampling aot (#1711) * add sampling aot * simple compile * fix compile bugs * fix a bug * revert changes --------- Co-authored-by: root <root@hjbog-srdc-24.amd.com> * update * bugfix * update * update --------- Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> Signed-off-by: Double Young <yang.yang2@amd.com> Co-authored-by: Linjun-AMD <Jun.Lin@amd.com> Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Co-authored-by: who who who <fsx950223@outlook.com> Co-authored-by: root <root@hjbog-srdc-24.amd.com> Co-authored-by: Double Young <yang.yang2@amd.com> Co-authored-by: amd-ruitang3 <145657428+amd-ruitang3@users.noreply.github.com> Co-authored-by: Satya Nikhil Kodukula <nikhil.kodukula@gmail.com> Co-authored-by: JaxChen29 <jichen@amd.com> Co-authored-by: steamedMantou <82486092+steamedMantou@users.noreply.github.com> Co-authored-by: yonshuai <yonshuai@amd.com> Co-authored-by: yongshuai <yongshuai@amd.com> Co-authored-by: Yu Guo <82124926+yuguo68@users.noreply.github.com> Co-authored-by: la <46212055+junhaha666@users.noreply.github.com> Co-authored-by: valarLip <Lingpeng.Jin@amd.com> Co-authored-by: shay-li77 <xiangxli@amd.com> Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Co-authored-by: Xin Huang <Xin.Huang@amd.com> Co-authored-by: zufayu <zufa.yu@amd.com> Co-authored-by: feifei14119 <feiw@amd.com> Co-authored-by: zufayu <zufayu@amd.com> Co-authored-by: ruanjm <jiming.ruan@amd.com> Co-authored-by: amirumoAMD <Amelia.Moore@amd.com> Co-authored-by: yadaish <yadai@amd.com> Co-authored-by: oscar <huaiguxu@amd.com> Co-authored-by: felix <felix.li@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: mqhc2020 <marvin.tsai@amd.com> Co-authored-by: Lin, Soga <soga.lin@amd.com> Co-authored-by: sogalin <39478626+sogalin@users.noreply.github.com> Co-authored-by: ClementLinCF <162283536+ClementLinCF@users.noreply.github.com> Co-authored-by: MHYang <meng-hsuan.yang@amd.com> Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: yanguahe <yanguahe@amd.com> Co-authored-by: omoisis-dn <omoisis@drivenets.com> Co-authored-by: chenjun <junchen2@amd.com>
#1578) * Add radix-base selection * Remove explicit template * Update the selected k condition * remove pos < k guard * code format * Update csrc/include/rocm_ops.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update csrc/kernels/topk_per_row_kernels.cu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update csrc/kernels/topk_plain_kernels.cu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update test_topk_plain.py * Update TODO message * Update csrc/kernels/topk_per_row_kernels.cu Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update op_tests/test_topk_plain.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * format test_topk_plain.py with black * Disable triton test for a resonalbe execution time * add explicit template instantiation * fix explicit template instantiation * add explicit template instantiation * Add bf16 support * Fix linter * Fix build errors * Fix condition * Fix build and test * Update conditions --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Co-authored-by: MHYang <meng-hsuan.yang@amd.com>
Motivation
We need a selection mechanism to choose the optimal top-K algorithm.
Technical Details
Adaptive algorithm selection based on K value:
Test Plan
python test_topk_plain.py
Test Result
This figures illustrates the performance comparison across different Top-K strategies. The dashed line represents BlockTopK (our bitonic sort-based implementation), which demonstrates competitive performance for K ≤ 128 but begins to lag behind the radix sort-based approach (radix_11bits) as K increases beyond this threshold.
Our AdaptiveTopK strategy intelligently selects the optimal algorithm based on K: leveraging BlockTopK for K ≤ 128 and switching to radix_11bits for K > 128. This adaptive selection ensures peak performance across the entire K spectrum, as evidenced by the solid line consistently matching or exceeding the best-performing algorithm at each K value.
Submission Checklist