From 9c178071c36e9ff0639d9e0b0568fd000b4d8458 Mon Sep 17 00:00:00 2001 From: chenjun Date: Fri, 24 Apr 2026 02:48:57 -0500 Subject: [PATCH] Fix top_k_per_row_prefill err when batched_token_numm > 4096 --- csrc/kernels/topk_per_row_kernels.cu | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/csrc/kernels/topk_per_row_kernels.cu b/csrc/kernels/topk_per_row_kernels.cu index 89331c52df..6edf377ca8 100644 --- a/csrc/kernels/topk_per_row_kernels.cu +++ b/csrc/kernels/topk_per_row_kernels.cu @@ -1435,6 +1435,18 @@ __global__ void radix_topk_one_block_kernel(T const* in, return; } + // Long-row path: kernel internally treats in[0..row_len) as the valid + // window. Shift `in` (and `in_idx`) up by `rowStart` so that the radix + // pipeline reads the actual valid columns rather than the masked-out + // [0, rowStart) prefix that fp8_mqa_logits fills with -inf. Internal + // indices i are then relative to rowStart; we add rowStart back to + // out_idx at the end of this branch to get absolute column indices. + in += rowStart; + if(in_idx) + { + in_idx += rowStart; + } + const IdxT buf_len = calc_buf_len(len); bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); @@ -1522,6 +1534,23 @@ __global__ void radix_topk_one_block_kernel(T const* in, break; } } + + // Long-row path was using rowStart-relative indices inside the radix + // pipeline (because we shifted `in` by rowStart above). Translate them + // back to absolute column indices for downstream consumers. Sentinels + // (-1, written when fewer than k valid candidates exist) are preserved. + if(rowStart > 0) + { + __syncthreads(); + for(int i = threadIdx.x; i < k; i += BlockSize) + { + IdxT v = out_idx[i]; + if(v >= 0) + { + out_idx[i] = v + rowStart; + } + } + } } inline size_t calc_aligned_size(std::vector const& sizes)