Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions csrc/kernels/mla/metadata/v1_0_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,19 @@ void kn_get_mla_metadata_v1_0(MlaMetadataV1KernelParameter params)
const int32_t bid_ori = bid / params.qk_batch_ratio;

const int32_t kv_begin = params.p_seqlens_kv_indptr[bid_ori];
const int32_t kv_end = params.p_seqlens_kv_indptr[bid_ori + 1];

int32_t kv_tail = [&](){
if constexpr(DP_MODE)
{
// max(*, 0) for cuda graph capture: kvlen < mtp+1
return max(bid % params.ori_seqlen_qo - params.ori_seqlen_qo + 1, 0);
return bid % params.ori_seqlen_qo - params.ori_seqlen_qo + 1;
}
else
{
return 0;
}
}();
const int32_t seqlen_kv = kv_end - kv_begin + kv_tail;
const int32_t kv_end = max(params.p_seqlens_kv_indptr[bid_ori + 1] + kv_tail, kv_begin + 1);

const int32_t seqlen_kv = kv_end - kv_begin;

const int32_t num_blocks = integer_divide_ceil_power2(
seqlen_kv, params.kv_granularity, params.kv_granularity_log2);
Expand Down Expand Up @@ -98,19 +97,17 @@ void kn_get_mla_metadata_v1_0(MlaMetadataV1KernelParameter params)
const int32_t bid_ori = bid / params.qk_batch_ratio;

const int32_t kv_begin = p_lds_kv_seqlen[bid_ori];
int32_t kv_end = p_lds_kv_seqlen[bid_ori + 1];
int32_t kv_tail = [&](){
if constexpr(DP_MODE)
{
// max(*, 0) for cuda graph capture: kvlen < mtp+1
return max(bid % params.ori_seqlen_qo - params.ori_seqlen_qo + 1, 0);
return bid % params.ori_seqlen_qo - params.ori_seqlen_qo + 1;
}
else
{
return 0;
}
}();
kv_end += kv_tail;
const int32_t kv_end = max(p_lds_kv_seqlen[bid_ori + 1] + kv_tail, kv_begin + 1);
MlaWorkInfo work_info{};
const int32_t split_start = p_lds_shift[bid];
const int32_t split_local = p_lds_split[bid];
Expand Down