csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu (23 changes: 12 additions, 11 deletions)
@@ -93,7 +93,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
         bias_ptr = alibi_slopes.data_ptr();
         stride_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     }
-
+
     return fmha_fwd_args{q.data_ptr(),
                          k.data_ptr(),
                          v.data_ptr(),
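A note on the ALiBi stride logic visible at the top of this hunk: when `alibi_slopes` is 2D (`[batch, heads]`) each batch gets its own row of slopes, so `stride(0)` steps between rows; when it is 1D (`[heads]`), a batch stride of 0 broadcasts the single row to every batch. A minimal sketch of that indexing scheme (the struct and helper names are illustrative, not from this file):

```cpp
// Illustrative sketch, not code from this PR: one indexing formula that
// covers both ALiBi layouts. batch_stride == 0 broadcasts a single
// [heads] row; batch_stride == stride(0) selects a per-batch row.
#include <cstdint>

struct AlibiView {
    const float* ptr;     // base pointer to the slopes tensor
    int64_t batch_stride; // 0 for a shared [heads] row, stride(0) for [batch, heads]
};

// Slope for batch b, head h under either layout.
inline float alibi_slope(const AlibiView& v, int64_t b, int64_t h) {
    return v.ptr[v.batch_stride * b + h];
}
```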
@@ -468,19 +468,11 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
         if (return_dropout_randval) {p.zero_();}
     }
 
-    int num_splits = 0;
-    num_splits = aiter::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size_v, 0, num_splits);
-    TORCH_CHECK(num_splits > 0, "num_splits should greater than 0");
-    TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");
-
-    auto softmax_lse_accum = torch::empty({num_heads, num_splits, total_q}, opts.dtype(at::kFloat));
-    auto out_accum = torch::empty({num_heads, num_splits, total_q, head_size_v}, opts.dtype(at::kFloat));
-
-    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
     auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
     auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
 
     if (p_dropout > 0.0) {
+        int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
             gen_, at::cuda::detail::getDefaultCUDAGenerator());
         // See Note [Acquire lock when using random generators]
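For context on the `counter_offset` move: it is consumed only when dropout is enabled, where it sizes the Philox offset reservation so each launch draws from a disjoint random stream; scoping the declaration inside `if (p_dropout > 0.0)` matches that single use. A hedged sketch of the reservation pattern, assuming PyTorch's standard generator API (the helper name is ours):

```cpp
// Sketch, assuming PyTorch's CUDA/HIP generator API: reserve a Philox
// offset range sized to the randoms this launch will draw, under the
// generator lock, so concurrent launches never overlap streams.
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <mutex>

at::PhiloxCudaState reserve_philox(at::CUDAGeneratorImpl* gen,
                                   int64_t counter_offset) {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    return gen->philox_cuda_state(counter_offset);
}
```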
@@ -490,12 +482,21 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
             aiter::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
     }
     std::optional<const at::Tensor> seqlens_k = std::nullopt;
+
     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};
 
         if (paged_KV)
         {
+            int num_splits = 0;
+            num_splits = aiter::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size_v, 0, num_splits);
+            TORCH_CHECK(num_splits > 0, "num_splits should greater than 0");
+            TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");
+
+            auto softmax_lse_accum = torch::empty({num_heads, num_splits, total_q}, opts.dtype(at::kFloat));
+            auto out_accum = torch::empty({num_heads, num_splits, total_q, head_size_v}, opts.dtype(at::kFloat));
+
             auto args =
                 get_ck_fmha_varlen_fwd_splitkv_args(
                     has_lse,
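This hunk carries the core of the change: `num_splits` and the fp32 `softmax_lse_accum` / `out_accum` scratch tensors are now allocated only on the paged-KV path, their sole consumer, instead of unconditionally on every call. For readers unfamiliar with why per-split LSE accumulators exist: each split attends over a slice of the KV sequence and produces a partial output plus its log-sum-exp, and the splits are then merged with a numerically stable weighted sum. An illustrative host-side sketch of that combine (the actual reduction runs in a CK tile kernel):

```cpp
// Illustrative sketch of the split-KV reduction these buffers feed:
// weight each split's partial output by exp(lse_s - max_lse) and
// renormalize, which is the stable form of the log-sum-exp merge.
#include <algorithm>
#include <cmath>
#include <vector>

float combine_splits(const std::vector<float>& lse,   // [num_splits]
                     const std::vector<float>& outs,  // [num_splits], one channel
                     float& combined_lse) {
    float max_lse = *std::max_element(lse.begin(), lse.end());
    float sum = 0.f, acc = 0.f;
    for (size_t s = 0; s < lse.size(); ++s) {
        float w = std::exp(lse[s] - max_lse); // exponent <= 0, no overflow
        sum += w;
        acc += w * outs[s];
    }
    combined_lse = max_lse + std::log(sum);
    return acc / sum; // combined output for this channel
}
```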
@@ -581,7 +582,7 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d]
         out.zero_();
         softmax_lse.fill_(std::numeric_limits<float>::infinity());
     }
-
+
     return {out, softmax_lse, p, rng_state};
 }
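On the unchanged return value: `rng_state` carries the Philox (seed, offset) pair the forward pass consumed, so the backward pass can regenerate the identical dropout mask instead of materializing it. A minimal sketch of that round-trip (struct and helper names are illustrative):

```cpp
// Illustrative sketch: the forward writes the Philox (seed, offset) pair
// into a 2-element int64 tensor; the backward reads it back and seeds
// Philox identically, reproducing the forward's dropout mask exactly.
#include <cstdint>

struct PhiloxSeedOffset {
    uint64_t seed;
    uint64_t offset;
};

inline PhiloxSeedOffset unpack_rng_state(const uint64_t* rng_state_ptr) {
    return PhiloxSeedOffset{rng_state_ptr[0], rng_state_ptr[1]};
}
```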