Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions csrc/kernels/topk_per_row_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,18 @@ __global__ void radix_topk_one_block_kernel(T const* in,
return;
}

// Long-row path: kernel internally treats in[0..row_len) as the valid
// window. Shift `in` (and `in_idx`) up by `rowStart` so that the radix
// pipeline reads the actual valid columns rather than the masked-out
// [0, rowStart) prefix that fp8_mqa_logits fills with -inf. Internal
// indices i are then relative to rowStart; we add rowStart back to
// out_idx at the end of this branch to get absolute column indices.
in += rowStart;
if(in_idx)
{
in_idx += rowStart;
}
Comment on lines +1438 to +1448

const IdxT buf_len = calc_buf_len<T, IdxT, unsigned>(len);
bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT));

Expand Down Expand Up @@ -1522,6 +1534,23 @@ __global__ void radix_topk_one_block_kernel(T const* in,
break;
}
}

// Long-row path was using rowStart-relative indices inside the radix
// pipeline (because we shifted `in` by rowStart above). Translate them
// back to absolute column indices for downstream consumers. Sentinels
// (-1, written when fewer than k valid candidates exist) are preserved.
if(rowStart > 0)
{
__syncthreads();
for(int i = threadIdx.x; i < k; i += BlockSize)
{
IdxT v = out_idx[i];
if(v >= 0)
{
out_idx[i] = v + rowStart;
}
Comment on lines +1542 to +1551
}
}
}

inline size_t calc_aligned_size(std::vector<size_t> const& sizes)
Expand Down
Loading