diff --git a/csrc/kernels/mla/metadata/v1_0_device.cuh b/csrc/kernels/mla/metadata/v1_0_device.cuh index a3a9fe2e6f..f7d0d5ea5b 100644 --- a/csrc/kernels/mla/metadata/v1_0_device.cuh +++ b/csrc/kernels/mla/metadata/v1_0_device.cuh @@ -47,20 +47,19 @@ void kn_get_mla_metadata_v1_0(MlaMetadataV1KernelParameter params) const int32_t bid_ori = bid / params.qk_batch_ratio; const int32_t kv_begin = params.p_seqlens_kv_indptr[bid_ori]; - const int32_t kv_end = params.p_seqlens_kv_indptr[bid_ori + 1]; - int32_t kv_tail = [&](){ if constexpr(DP_MODE) { - // max(*, 0) for cuda graph capture: kvlen < mtp+1 - return max(bid % params.ori_seqlen_qo - params.ori_seqlen_qo + 1, 0); + return bid % params.ori_seqlen_qo - params.ori_seqlen_qo + 1; } else { return 0; } }(); - const int32_t seqlen_kv = kv_end - kv_begin + kv_tail; + const int32_t kv_end = max(params.p_seqlens_kv_indptr[bid_ori + 1] + kv_tail, kv_begin + 1); + + const int32_t seqlen_kv = kv_end - kv_begin; const int32_t num_blocks = integer_divide_ceil_power2( seqlen_kv, params.kv_granularity, params.kv_granularity_log2); @@ -98,19 +97,17 @@ void kn_get_mla_metadata_v1_0(MlaMetadataV1KernelParameter params) const int32_t bid_ori = bid / params.qk_batch_ratio; const int32_t kv_begin = p_lds_kv_seqlen[bid_ori]; - int32_t kv_end = p_lds_kv_seqlen[bid_ori + 1]; int32_t kv_tail = [&](){ if constexpr(DP_MODE) { - // max(*, 0) for cuda graph capture: kvlen < mtp+1 - return max(bid % params.ori_seqlen_qo - params.ori_seqlen_qo + 1, 0); + return bid % params.ori_seqlen_qo - params.ori_seqlen_qo + 1; } else { return 0; } }(); - kv_end += kv_tail; + const int32_t kv_end = max(p_lds_kv_seqlen[bid_ori + 1] + kv_tail, kv_begin + 1); MlaWorkInfo work_info{}; const int32_t split_start = p_lds_shift[bid]; const int32_t split_local = p_lds_split[bid];