From 4351d36fa8b0230ed8b69ed69a7e702447842411 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Tue, 9 Jul 2024 14:02:34 -0400
Subject: [PATCH] [Runtime] Reorganize PagedKVCache attn kernel invocation

This PR reorganizes the attention kernel invocation logic in the
PagedKVCache, so that in cases of sequence fork, we can effectively
merge one ragged-prefill kernel and a decode kernel into a single
decode kernel.
---
 src/relax/transform/fuse_ops.cc        |   2 +-
 src/runtime/relax_vm/paged_kv_cache.cc | 127 +++++++++++++------------
 2 files changed, 65 insertions(+), 64 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index e791aeab061d..85c739e08353 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -646,7 +646,7 @@ class FunctionCreator : public ExprMutator {
       return tvm::tir::UndefinedVars(prim_value->value).empty();
     } else if (const auto* shape_expr = expr.as<ShapeExprNode>()) {
       return std::all_of(shape_expr->values.begin(), shape_expr->values.end(),
-                         [this](const PrimExpr& e) { return tvm::tir::UndefinedVars(e).empty(); });
+                         [](const PrimExpr& e) { return tvm::tir::UndefinedVars(e).empty(); });
     }
     return false;
   }
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 5aa1411ec154..cf5de97202cc 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -1535,7 +1535,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       CHECK_EQ(chunked_block_ids_arr[num_depths_ - 1].size(), cur_batch_size_);
     }
 
-    append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && use_decode_kernel_[0];
+    append_before_attn_ = !support_sliding_window_ && use_decode_kernel_.back();
     if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
       // When GQA group size is at least 4 and FlashInfer is enabled,
       // we always use prefill kernel for better performance.
@@ -2220,39 +2220,33 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       return;
     }
-    if (append_before_attn_) {
-      if (!support_sliding_window_) {
+    if (!append_before_attn_) {
+      if (is_chain_) {
+        f_attention_prefill_ragged_begin_forward_.value()(
+            temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(),
+            cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_,
+            num_kv_heads_, head_dim_, copy_stream_);
+      } else {
+        LOG(FATAL) << "Kernel BeginForward doesn't support tree attn.";
+      }
+    }
+    for (int d = 0; d < num_depths_; ++d) {
+      if (page_indices_on_depths_view_[d]->shape[0] == 0) {
+        continue;
+      }
+      CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window.";
+      if (use_decode_kernel_[d]) {
         f_attention_decode_begin_forward_.value()(
-            /*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_host_[0].as_ndarray(),
-            last_page_len_on_depths_host_[0].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_,
+            d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(),
+            last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_,
             page_size_,
             /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
-      }
-    } else {
-      f_attention_prefill_ragged_begin_forward_.value()(
-          temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(),
-          cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_,
-          num_kv_heads_, head_dim_, copy_stream_);
-      if (support_sliding_window_) {
-        return;
-      }
-      for (int d = 0; d < num_depths_; ++d) {
-        if (page_indices_on_depths_view_[d]->shape[0] == 0) {
-          continue;
-        }
-        if (use_decode_kernel_[d]) {
-          f_attention_decode_begin_forward_.value()(
-              d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(),
-              last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_,
-              head_dim_, page_size_,
-              /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
-        } else {
-          f_attention_prefill_begin_forward_.value()(
-              /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
-              page_indptr_on_depths_host_[d].as_ndarray(),
-              static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_,
-              num_kv_heads_, head_dim_, page_size_, copy_stream_);
-        }
+      } else {
+        f_attention_prefill_begin_forward_.value()(
+            /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
+            page_indptr_on_depths_host_[d].as_ndarray(),
+            static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_,
+            num_kv_heads_, head_dim_, page_size_, copy_stream_);
       }
     }
   }
@@ -2271,15 +2265,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     PackedFunc f_decode =
         !support_sliding_window_ ? f_attention_decode_ : f_attention_decode_sliding_window_;
     CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1.";
-    if (append_before_attn_) {
-      f_decode(
-          /*depth=*/0, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[0],
-          page_indices_on_depths_view_[0], length_info_on_depths_view_[0],
-          k_rope_pos_offset_view_[0], q_rope_position_map_view_, output, merged_attn_scores_view_,
-          /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
-          attn_score_scaling_factor);
-    } else {
-      // Compute appended text self-attention
+
+    bool is_first_kernel = true;
+    if (!append_before_attn_) {
+      // The first part of attention, which only involves the q and the newly appended k/v.
+      is_first_kernel = false;
       if (is_chain_) {
         // If the batch does not form a tree, use raggedness prefill kernel.
         f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data,
@@ -2301,32 +2291,43 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
             merged_attn_scores_view_,
             /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
             attn_score_scaling_factor, cur_batch_size_);
       }
+    }
-      for (int d = 0; d < num_depths_; ++d) {
-        if (page_indices_on_depths_view_[d]->shape[0] == 0) {
-          continue;
-        }
-        if (use_decode_kernel_[d]) {
-          // Use decode kernel for depth d
-          f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d],
-                   page_indices_on_depths_view_[d], length_info_on_depths_view_[d],
-                   k_rope_pos_offset_view_[d], q_rope_position_map_view_, temp_attn_output_view_,
-                   temp_attn_scores_view_,
-                   /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
-                   attn_score_scaling_factor);
-        } else {
-          // Use prefill kernel for depth d
-          f_prefill(
-              /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id],
-              page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
-              length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_,
-              temp_attn_output_view_, temp_attn_scores_view_,
-              /*causal=*/0,
-              /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
-              attn_score_scaling_factor);
-        }
+    for (int d = 0; d < num_depths_; ++d) {
+      if (page_indices_on_depths_view_[d]->shape[0] == 0) {
+        continue;
+      }
+      NDArray attn_output;
+      NDArray attn_scores;
+      if (is_first_kernel) {
+        attn_output = output;
+        attn_scores = merged_attn_scores_view_;
+      } else {
+        attn_output = temp_attn_output_view_;
+        attn_scores = temp_attn_scores_view_;
+      }
+      if (use_decode_kernel_[d]) {
+        // Use decode kernel for depth d
+        f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d],
+                 page_indices_on_depths_view_[d], length_info_on_depths_view_[d],
+                 k_rope_pos_offset_view_[d], q_rope_position_map_view_, attn_output, attn_scores,
+                 /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
+                 attn_score_scaling_factor);
+      } else {
+        // Use prefill kernel for depth d
+        f_prefill(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id],
+                  page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
+                  length_info_on_depths_view_[d], k_rope_pos_offset_view_[d],
+                  q_rope_position_map_view_, attn_output, attn_scores, /*causal=*/0,
+                  /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
+                  attn_score_scaling_factor);
+      }
+
+      if (!is_first_kernel) {
         f_merge_inplace_(output, merged_attn_scores_view_, temp_attn_output_view_,
                          temp_attn_scores_view_);
+      } else {
+        is_first_kernel = false;
+      }
       }
     }
   }
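
Notes on the new structure: whichever attention kernel runs first (the ragged-prefill kernel when append_before_attn_ is false, otherwise the first per-depth paged kernel) now writes directly into output / merged_attn_scores_view_, and every later kernel writes into temp_attn_output_view_ / temp_attn_scores_view_ and is folded in with f_merge_inplace_. A minimal standalone sketch of that control flow follows; the buffer layout, functor types, and the RunAndMerge name are hypothetical stand-ins for the PagedKVCache views and PackedFuncs, not the actual runtime API.

#include <functional>
#include <vector>

// One attention kernel invocation: it writes its partial output and the
// log-sum-exp of its attention scores into the buffers it is handed.
using AttnKernel = std::function<void(std::vector<float>& out, std::vector<float>& lse)>;

// In-place merge of a second partial result into the first; stands in for
// f_merge_inplace_ in the diff (hypothetical signature).
using MergeFn = std::function<void(std::vector<float>& out, std::vector<float>& lse,
                                   const std::vector<float>& other_out,
                                   const std::vector<float>& other_lse)>;

// Shape of the control flow after this patch: the first kernel that actually
// runs writes straight into the final buffers; every later kernel writes into
// scratch buffers that are immediately merged into the final result.
void RunAndMerge(const std::vector<AttnKernel>& kernels, const MergeFn& merge_inplace,
                 std::vector<float>& out, std::vector<float>& lse,
                 std::vector<float>& scratch_out, std::vector<float>& scratch_lse) {
  bool is_first_kernel = true;
  for (const AttnKernel& kernel : kernels) {
    if (is_first_kernel) {
      kernel(out, lse);  // no separate merge step for the first result
      is_first_kernel = false;
    } else {
      kernel(scratch_out, scratch_lse);
      merge_inplace(out, lse, scratch_out, scratch_lse);
    }
  }
}

Compared with the old code, which special-cased append_before_attn_ into a single depth-0 decode call and always routed the per-depth loop through the temporary buffers plus a merge, this shape saves the merge for the first kernel and lets append_before_attn_ be enabled whenever the deepest level can use the decode kernel (use_decode_kernel_.back()), which is what allows a forked batch to drop the separate ragged-prefill launch mentioned in the commit message.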
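For reference, the merge step itself only needs each partial output together with a per-query log-sum-exp attention score, which is why the kernels emit merged_attn_scores_view_ / temp_attn_scores_view_ alongside the outputs. Below is a rough CPU model of what an in-place merge such as f_merge_inplace_ computes; the semantics are assumed from the buffer names and the helper name is hypothetical, so this is a sketch rather than the actual TVM/FlashInfer kernel.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Merge the partial attention result (other_out, other_lse) into (out, lse)
// in place. Row i holds one query's output over some subset of the KV cache
// plus the log-sum-exp of its attention logits over that subset. Merging two
// results computed over disjoint KV subsets this way reproduces the result of
// attending over their union.
void MergeStateInplace(std::vector<float>& out, std::vector<float>& lse,
                       const std::vector<float>& other_out, const std::vector<float>& other_lse,
                       std::size_t head_dim) {
  std::size_t num_rows = lse.size();
  for (std::size_t i = 0; i < num_rows; ++i) {
    float a = lse[i];
    float b = other_lse[i];
    float m = std::max(a, b);  // subtract the max to keep the exponentials stable
    float wa = std::exp(a - m);
    float wb = std::exp(b - m);
    float denom = wa + wb;
    for (std::size_t j = 0; j < head_dim; ++j) {
      float& o = out[i * head_dim + j];
      o = (o * wa + other_out[i * head_dim + j] * wb) / denom;
    }
    lse[i] = m + std::log(denom);  // log-sum-exp over the union of the two subsets
  }
}

Because this merge is associative, folding the per-depth results into the running output one at a time, as the reorganized loop does, yields the same result as attending over the full KV range in a single kernel.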