Merged
34 commits
ba9ff92  Add radix-base selection (ClementLinCF, Dec 5, 2025)
7aea9bf  Remove explicit template (ClementLinCF, Dec 5, 2025)
fa16c95  Update the selected k condition (ClementLinCF, Dec 5, 2025)
b138e5f  remove pos < k guard (ClementLinCF, Dec 6, 2025)
65f53b6  code format (ClementLinCF, Dec 6, 2025)
d8b8d2d  Update csrc/include/rocm_ops.hpp (ClementLinCF, Dec 8, 2025)
c34deed  Update csrc/kernels/topk_per_row_kernels.cu (ClementLinCF, Dec 8, 2025)
afe3ff6  Update csrc/kernels/topk_plain_kernels.cu (ClementLinCF, Dec 8, 2025)
c47deb4  Update test_topk_plain.py (ClementLinCF, Dec 8, 2025)
c361340  Update TODO message (ClementLinCF, Dec 8, 2025)
1946eb1  Update csrc/kernels/topk_per_row_kernels.cu (ClementLinCF, Dec 8, 2025)
5046dd3  Update op_tests/test_topk_plain.py (ClementLinCF, Dec 8, 2025)
f0ad619  Merge branch 'main' into adaptive_topk (ClementLinCF, Dec 8, 2025)
2e7791c  Merge branch 'main' into adaptive_topk (ClementLinCF, Dec 9, 2025)
d3cfb82  Merge branch 'main' into adaptive_topk (ClementLinCF, Dec 10, 2025)
090a0cd  format test_topk_plain.py with black (ClementLinCF, Dec 10, 2025)
e525ded  Merge branch 'main' into adaptive_topk (ClementLinCF, Dec 10, 2025)
76df2d2  Merge branch 'main' into adaptive_topk (valarLip, Dec 10, 2025)
cda1276  Merge branch 'main' into adaptive_topk (ClementLinCF, Dec 11, 2025)
f61c11a  Disable triton test for a reasonable execution time (ClementLinCF, Dec 11, 2025)
88fe65a  add explicit template instantiation (ClementLinCF, Dec 12, 2025)
890298d  fix explicit template instantiation (ClementLinCF, Dec 12, 2025)
6ca176e  add explicit template instantiation (ClementLinCF, Dec 12, 2025)
ad8d1d2  Merge branch 'main' into adaptive_topk (ClementLinCF, Dec 12, 2025)
4550051  Add bf16 support (MHYangAMD, Dec 15, 2025)
734643c  Merge branch 'main' into adaptive_topk (MHYangAMD, Dec 16, 2025)
9bf966b  Fix linter (MHYangAMD, Dec 16, 2025)
3d4ec2e  Fix build errors (MHYangAMD, Dec 16, 2025)
aa20858  Fix condition (MHYangAMD, Dec 17, 2025)
f68a8f5  Fix build and test (MHYangAMD, Dec 17, 2025)
f5c82b3  Merge branch 'main' into adaptive_topk (MHYangAMD, Dec 17, 2025)
a827f81  Update conditions (MHYangAMD, Dec 17, 2025)
0c09f49  Merge branch 'main' into adaptive_topk (MHYangAMD, Dec 17, 2025)
18bb2d1  Merge branch 'main' into adaptive_topk (valarLip, Dec 21, 2025)
aiter/jit/optCompilerConfig.json (3 changes: 2 additions & 1 deletion)

@@ -1077,7 +1077,8 @@
         "module_topk_plain": {
             "srcs": [
                 "f'{AITER_CSRC_DIR}/pybind/topk_plain_pybind.cu'",
-                "f'{AITER_CSRC_DIR}/kernels/topk_plain_kernels.cu'"
+                "f'{AITER_CSRC_DIR}/kernels/topk_plain_kernels.cu'",
+                "f'{AITER_CSRC_DIR}/kernels/topk_per_row_kernels.cu'"
             ],
             "flags_extra_cc": [],
             "flags_extra_hip": [],
aiter/ops/topk_plain.py (7 changes: 6 additions & 1 deletion)

@@ -13,7 +13,12 @@
 def topk_plain(
     x: torch.Tensor,
     topk_ids: torch.Tensor,
+    topk_out: torch.Tensor,
     topk: int,
-    largest: bool,
+    largest: bool = True,
+    rowStarts: torch.Tensor = None,
+    rowEnds: torch.Tensor = None,
+    stride0: int = -1,
+    stride1: int = 1,
 ) -> None:
     pass
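The new Python entry point takes a preallocated output buffer and optional per-row bounds. As a quick illustration of how a caller might use it, here is a minimal sketch; the import path matches this diff, but the int32 index dtype, the preallocation requirement, and the bounds semantics are assumptions read off the kernel changes below, not documented API:

import torch
from aiter.ops.topk_plain import topk_plain

x = torch.randn(8, 1024, device="cuda", dtype=torch.float32)
k = 16
topk_ids = torch.empty(8, k, device="cuda", dtype=torch.int32)  # dtype assumed
topk_out = torch.empty(8, k, device="cuda", dtype=x.dtype)

# Full-row top-k; largest now defaults to True.
topk_plain(x, topk_ids, topk_out, k)

# Optional per-row windows: row i is presumably searched over
# [rowStarts[i], rowEnds[i]), based on the kernel changes in this PR.
rowStarts = torch.zeros(8, device="cuda", dtype=torch.int32)
rowEnds = torch.full((8,), 512, device="cuda", dtype=torch.int32)
topk_plain(x, topk_ids, topk_out, k, True, rowStarts, rowEnds)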
csrc/include/opus/opus.hpp (2 changes: 1 addition & 1 deletion)

@@ -907,7 +907,7 @@ template<> OPUS_D float min<float>(const float&a, const float&b) { return

 template<typename T> OPUS_D T med3(const T&a, const T&b, const T&c) { auto max_0 = max(a, b); auto min_0 = max(a, b); return max(max_0, max(min_0, c)); }
 template<> OPUS_D float med3<float>(const float&a, const float&b, const float&c) { return __builtin_amdgcn_fmed3f(a, b, c); }
-template<> OPUS_D __fp16 med3<__fp16>(const __fp16&a, const __fp16&b, const __fp16&c) { return __builtin_amdgcn_fmed3h(a, b, c); }
+template<> OPUS_D _Float16 med3<_Float16>(const _Float16&a, const _Float16&b, const _Float16&c) { return __builtin_amdgcn_fmed3h(a, b, c); }
 /////////////////////////////////////////////////////////////////////////////////////////////////////////
 // buffer load/store related
 OPUS_D constexpr auto buffer_default_config() {
csrc/include/rocm_ops.hpp (19 changes: 12 additions & 7 deletions)

@@ -1635,10 +1635,15 @@ namespace py = pybind11;
           py::arg("final_output"), \
           py::arg("final_lse") = std::nullopt);

-#define TOPK_PLAIN_PYBIND      \
-    m.def("topk_plain",        \
-          &topk_plain,         \
-          py::arg("values"),   \
-          py::arg("topk_ids"), \
-          py::arg("topk"),     \
-          py::arg("largest"));
+#define TOPK_PLAIN_PYBIND                         \
+    m.def("topk_plain",                           \
+          &topk_plain,                            \
+          py::arg("values"),                      \
+          py::arg("topk_ids"),                    \
+          py::arg("topk_out"),                    \
+          py::arg("topk"),                        \
+          py::arg("largest") = true,              \
+          py::arg("rowStarts") = torch::Tensor(), \
+          py::arg("rowEnds") = torch::Tensor(),   \
+          py::arg("stride0") = -1,                \
+          py::arg("stride1") = 1);
csrc/include/topk_plain.h (9 changes: 7 additions & 2 deletions)

@@ -6,5 +6,10 @@

 void topk_plain(torch::Tensor& values,
                 torch::Tensor& topk_ids,
-                int topk_num,
-                bool largest);
+                torch::Tensor& topk_out,
+                int topk,
+                bool largest = true,
+                torch::Tensor rowStarts = torch::Tensor(),
+                torch::Tensor rowEnds = torch::Tensor(),
+                int64_t stride0 = -1,
+                int64_t stride1 = 1);
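For reference, the intended semantics of the optional bounds can be written out in plain PyTorch. This is a hypothetical reference model, not code from the PR; in particular, whether the returned indices are absolute columns or relative to rowStarts[i] is an assumption (absolute is used here):

import torch

def topk_plain_reference(x, k, largest=True, rowStarts=None, rowEnds=None):
    # Hypothetical reference: row i is reduced over columns [lo, hi) only.
    out_vals = torch.empty(x.size(0), k, dtype=x.dtype)
    out_ids = torch.empty(x.size(0), k, dtype=torch.long)
    for i in range(x.size(0)):
        lo = int(rowStarts[i]) if rowStarts is not None else 0
        hi = int(rowEnds[i]) if rowEnds is not None else x.size(1)
        vals, ids = torch.topk(x[i, lo:hi], k, largest=largest)
        out_vals[i] = vals
        out_ids[i] = ids + lo  # shift back to absolute column indices (assumed)
    return out_vals, out_ids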
csrc/kernels/topk_per_row_kernels.cu (84 changes: 72 additions & 12 deletions)

@@ -420,7 +420,8 @@ __device__ void filter_and_histogram(T const* in_buf,
                                      IdxT* histogram,
                                      bool select_min,
                                      int pass,
-                                     bool early_stop)
+                                     bool early_stop,
+                                     IdxT k)
 {
     constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
     __shared__ IdxT histogram_smem[num_buckets];
@@ -893,9 +894,19 @@ __global__ void radix_kernel(T const* in,
                              int const pass)
 {
     const int64_t batch_id = blockIdx.y;
-    const IdxT row_len = phase == Phase::Prefill
-                             ? rowEnds[batch_id] - rowStarts[batch_id]
-                             : rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1;
+
+    IdxT row_len = len;
+    if(phase == Phase::Prefill)
+    {
+        if(rowStarts && rowEnds)
+        {
+            row_len = rowEnds[batch_id] - rowStarts[batch_id];
+        }
+    }
+    else
+    {
+        row_len = rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1;
+    }

     auto counter = counters + batch_id;
     IdxT current_k;
@@ -965,7 +976,8 @@
                              histogram,
                              select_min,
                              pass,
-                             early_stop);
+                             early_stop,
+                             k);
     __threadfence();

     bool isLastBlock = false;
@@ -1187,7 +1199,8 @@ __device__ bool filter_and_histogram_for_one_block(T const* in_buf,
                                                    Counter<T, IdxT>* counter,
                                                    IdxT* histogram,
                                                    bool select_min,
-                                                   int pass)
+                                                   int pass,
+                                                   IdxT k)
 {
     constexpr int num_buckets = calc_num_buckets<BitsPerPass>();
     for(int i = threadIdx.x; i < num_buckets * 2; i += blockDim.x)
@@ -1371,11 +1384,25 @@ __global__ void radix_topk_one_block_kernel(T const* in,
     __shared__ IdxT histogram[num_buckets * 2];

     const int64_t batch_id = blockIdx.x;
-    const IdxT rowStart = phase == Phase::Prefill ? rowStarts[batch_id] : 0;
-    const IdxT rowEnd = phase == Phase::Prefill
-                            ? rowEnds[batch_id]
-                            : rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1;
-    const IdxT row_len = rowEnd - rowStart;
+
+    IdxT rowStart = 0;
+    IdxT rowEnd = len;
+    if(phase == Phase::Prefill)
+    {
+        if(rowStarts && rowEnds)
+        {
+            rowStart = rowStarts[batch_id];
+            rowEnd = rowEnds[batch_id];
+        }
+    }
+    else
+    {
+        rowEnd = rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1;
+        rowStart = 0;
+    }
+
+    const IdxT row_len = rowEnd - rowStart;

     if(threadIdx.x == 0)
     {
         counter.k = k;
@@ -1448,7 +1475,8 @@
                                                         &counter,
                                                         histogram,
                                                         select_min,
-                                                        pass); //@TODO CHECK UPDATE CODE
+                                                        pass,
+                                                        k); //@TODO CHECK UPDATE CODE
     __syncthreads();

     scan<IdxT, BitsPerPass, BlockSize>(histogram + use_one_pass * num_buckets);
@@ -1811,6 +1839,35 @@ void standalone_stable_radix_11bits(void* buf,
     }
 }

+// Explicit template instantiation for standalone_stable_radix_11bits
+template void standalone_stable_radix_11bits<float, int, true, true>(void* buf,
+                                                                     size_t& buf_size,
+                                                                     float const* in,
+                                                                     int batch_size,
+                                                                     int64_t len,
+                                                                     int* rowStarts,
+                                                                     int* rowEnds,
+                                                                     int k,
+                                                                     float* out,
+                                                                     int* out_idx,
+                                                                     bool greater,
+                                                                     hipStream_t stream,
+                                                                     int next_n);
+
+template void standalone_stable_radix_11bits<float, int, false, true>(void* buf,
+                                                                      size_t& buf_size,
+                                                                      float const* in,
+                                                                      int batch_size,
+                                                                      int64_t len,
+                                                                      int* rowStarts,
+                                                                      int* rowEnds,
+                                                                      int k,
+                                                                      float* out,
+                                                                      int* out_idx,
+                                                                      bool greater,
+                                                                      hipStream_t stream,
+                                                                      int next_n);
+
 // AIR TopK end

 static inline __device__ uint32_t floatAsSortableUint(float x)
@@ -2410,6 +2467,9 @@ int64_t invokeComputeTopkLastDimWorkspaceSize(int32_t numRows, int32_t stride0)
     return buf_size;
 }

+// Explicit template instantiation to ensure the symbol is available for linking
+template int64_t invokeComputeTopkLastDimWorkspaceSize<float>(int32_t numRows, int32_t stride0);
+
 void top_k_per_row_prefill(const torch::Tensor& logits,
                            const torch::Tensor& rowStarts,
                            const torch::Tensor& rowEnds,
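One detail worth spelling out from the kernel changes above: in the non-prefill (decode) phase, the effective row length is derived per batch row as rowEnds[batch_id / next_n] - next_n + (batch_id % next_n) + 1, so consecutive rows within a next_n block see progressively longer prefixes (presumably for multi-token or speculative decoding; that reading is inferred, not stated in the diff). A small Python sketch of just that arithmetic:

# Mirrors the kernel's decode-phase length expression; names follow the kernel,
# the surrounding setup is assumed for illustration only.
def decode_row_len(rowEnds, batch_id, next_n):
    seq = batch_id // next_n   # which sequence this row belongs to
    pos = batch_id % next_n    # position within the next_n block
    return rowEnds[seq] - next_n + pos + 1

# Example: one sequence with rowEnds = [10] and next_n = 4 gives its four
# decode rows effective lengths [7, 8, 9, 10].
print([decode_row_len([10], b, 4) for b in range(4)])  # [7, 8, 9, 10]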