From 5526deee7174819ff443ac773a8435c957663821 Mon Sep 17 00:00:00 2001
From: fsx950223
Date: Wed, 16 Jul 2025 09:11:23 +0000
Subject: [PATCH] refactor

---
 csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu
index cf968b7fff..7449a64988 100644
--- a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu
+++ b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu
@@ -93,7 +93,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
         bias_ptr = alibi_slopes.data_ptr();
         stride_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     }
-    
+
     return fmha_fwd_args{q.data_ptr(),
                          k.data_ptr(),
                          v.data_ptr(),
@@ -468,19 +468,11 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
         if (return_dropout_randval) {p.zero_();}
     }
 
-    int num_splits = 0;
-    num_splits = aiter::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size_v, 0, num_splits);
-    TORCH_CHECK(num_splits > 0, "num_splits should greater than 0");
-    TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");
-
-    auto softmax_lse_accum = torch::empty({num_heads, num_splits, total_q}, opts.dtype(at::kFloat));
-    auto out_accum = torch::empty({num_heads, num_splits, total_q, head_size_v}, opts.dtype(at::kFloat));
-
-    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
     auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
     auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
 
     if (p_dropout > 0.0) {
+        int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
             gen_, at::cuda::detail::getDefaultCUDAGenerator());
         // See Note [Acquire lock when using random generators]
@@ -490,12 +482,21 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
             aiter::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
     }
     std::optional<at::Tensor> seqlens_k = std::nullopt;
+
     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};
 
         if (paged_KV)
         {
+            int num_splits = 0;
+            num_splits = aiter::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size_v, 0, num_splits);
+            TORCH_CHECK(num_splits > 0, "num_splits should greater than 0");
+            TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");
+
+            auto softmax_lse_accum = torch::empty({num_heads, num_splits, total_q}, opts.dtype(at::kFloat));
+            auto out_accum = torch::empty({num_heads, num_splits, total_q, head_size_v}, opts.dtype(at::kFloat));
+
             auto args =
                 get_ck_fmha_varlen_fwd_splitkv_args(
                     has_lse,
@@ -581,7 +582,7 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
         out.zero_();
         softmax_lse.fill_(std::numeric_limits<float>::infinity());
     }
-    
+
     return {out, softmax_lse, p, rng_state};
 }
 
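
Note on the change: the hunks above move the split-KV bookkeeping so that num_splits is derived and bounds-checked, and the softmax_lse_accum/out_accum workspaces are allocated, only on the paged-KV path, and counter_offset is computed only when dropout is active, which avoids doing that work on the other paths. The standalone C++ sketch below only mirrors that control flow under stated assumptions: pick_num_splits is a hypothetical stand-in for aiter::override_num_splits_if_necessary, std::vector<float> stands in for the torch::empty(..., at::kFloat) workspaces, 64 stands in for ck_tile::get_warp_size(), and the problem sizes are made up for illustration.

// Minimal sketch of the refactored control flow; not the aiter implementation.
#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical heuristic standing in for aiter::override_num_splits_if_necessary:
// choose a split count from the problem size.
static int pick_num_splits(int batch, int heads, int max_seqlen_q, int head_size_v) {
    (void)batch; (void)heads;
    return (max_seqlen_q * head_size_v > 4096) ? 8 : 1;  // illustrative only
}

int main() {
    const int batch = 2, heads = 4, max_seqlen_q = 256, head_size_v = 64;
    const int total_q = batch * max_seqlen_q;
    const bool p_dropout_on = false;
    const bool paged_kv = true;

    if (p_dropout_on) {
        // counter_offset is only needed when dropout is active (second hunk);
        // 64 stands in for ck_tile::get_warp_size().
        const long long counter_offset = 1LL * batch * heads * 64;
        std::printf("counter_offset=%lld\n", counter_offset);
    }

    if (paged_kv) {
        // Split-KV workspaces are sized and validated only on the paged-KV path (third hunk).
        const int num_splits = pick_num_splits(batch, heads, max_seqlen_q, head_size_v);
        assert(num_splits > 0 && num_splits <= 128);  // same bounds the patch enforces

        // Shapes follow the patch: [heads, num_splits, total_q] and
        // [heads, num_splits, total_q, head_size_v].
        std::vector<float> softmax_lse_accum(static_cast<size_t>(heads) * num_splits * total_q);
        std::vector<float> out_accum(static_cast<size_t>(heads) * num_splits * total_q * head_size_v);
        std::printf("num_splits=%d lse_accum=%zu out_accum=%zu\n",
                    num_splits, softmax_lse_accum.size(), out_accum.size());
    }
    return 0;
}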