From e76323a45ef74ca627f0038c6633da07872c45b5 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Fri, 8 Mar 2024 00:36:00 -0500
Subject: [PATCH] [Runtime] PagedKVCache: execute data copy on a separate
 stream

This PR enhances PagedKVCache by moving data copies onto a separate
stream. Specifically, for the CUDA and ROCm backends, we create a
standalone copy stream for copying the auxiliary data structures from
CPU to GPU. Furthermore, we move the copy from BeginForward to
Attention, so it is no longer executed eagerly; instead, it runs lazily
when the first Attention computation needs it.

With these changes, the auxiliary data copy (on the copy stream)
overlaps with the model forward computation that runs before the first
Attention, which hides part of the copy latency.

This PR also bumps the FlashInfer version for the copy stream support.
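For reference, the overlap pattern implemented here can be sketched with
the plain CUDA runtime API (a minimal illustration, not TVM's DeviceAPI
abstraction; `SyncAuxAndWait` and its parameters are hypothetical names):
copies are issued on the copy stream, and the compute stream waits on an
event instead of the host blocking.

    #include <cuda_runtime.h>

    #include <cstdint>
    #include <vector>

    // Copy auxiliary data host-to-device on `copy_stream`, then make
    // `compute_stream` wait on an event. Kernels already queued on the
    // compute stream keep running while the copy is in flight; only work
    // submitted after the wait is held back. The host never blocks.
    void SyncAuxAndWait(const std::vector<int32_t>& aux_host, int32_t* aux_device,
                        cudaStream_t copy_stream, cudaStream_t compute_stream) {
      cudaEvent_t copy_done;
      cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming);
      cudaMemcpyAsync(aux_device, aux_host.data(),
                      aux_host.size() * sizeof(int32_t), cudaMemcpyHostToDevice,
                      copy_stream);
      cudaEventRecord(copy_done, copy_stream);
      // Device-side wait: orders the two streams without a host sync.
      cudaStreamWaitEvent(compute_stream, copy_done, 0);
      cudaEventDestroy(copy_done);
    }

In the diff, DeviceAPI::SyncStreamFromTo(device_, copy_stream_,
compute_stream_) plays the role of the record/wait pair, and copy_stream_
is also threaded into the FlashInfer begin-forward functions so that their
auxiliary preprocessing stays off the compute stream.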
---
 3rdparty/flashinfer                    |   2 +-
 src/runtime/relax_vm/paged_kv_cache.cc | 161 ++++++++++++++++--------
 2 files changed, 106 insertions(+), 57 deletions(-)

diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index f1f6a0de4e59..0d04571b614c 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit f1f6a0de4e595b777e29cc0dc370c15bd1d736fb
+Subproject commit 0d04571b614c944b5831d080882107a98b9c6e65
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 6dec511f2f88..fb22d20fcfc7 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -242,7 +242,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   //-------------------------------------------
   /*!
    * \brief A boolean flag indicating if the auxiliary arrays are dirty.
-   * If it is dirty, an explicit "SyncAuxArrayToDevice" should be invoked.
+   * If it is dirty, an explicit "ComputeStreamWaitForCopyStream" should be invoked.
    */
   bool dirty_aux_data_device_ = false;
   /*! \brief The batch size of the current round of forwarding. */
@@ -285,6 +285,20 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   NDArray merged_attn_scores_device_;
   std::vector<NDArray> temp_attn_workspace_;

+  //-------------------------------------------
+  // Below are the auxiliary data structures on CPU.
+  // We make them class members to avoid repetitive allocation in BeginForward.
+  //-------------------------------------------
+  std::vector<std::vector<int32_t>> qo_indptr_on_depths_host_;
+  std::vector<std::vector<int32_t>> page_indptr_on_depths_host_;
+  std::vector<std::vector<int32_t>> page_indices_on_depths_host_;
+  std::vector<std::vector<int32_t>> last_page_len_on_depths_host_;
+  std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths_host_;
+  std::vector<int32_t> k_ragged_rope_pos_offset_host_;
+  std::vector<int32_t> q_rope_position_map_host_;
+  std::vector<int32_t> append_position_map_host_;
+  std::vector<int32_t> cur_append_lengths_indptr_host_;
+
   //-------------------------------------------
   // For efficient memory management, the actual sizes of the arrays
   // above are over allocated.
@@ -328,6 +342,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   std::vector<bool> use_decode_kernel_;
   /*! \brief Whether the attention request is a decode request, set in BeginForwardFunction. */
   bool is_decode_request_;
+  /*! \brief The device this PagedKVCache runs on. */
+  DLDevice device_;
+  /*! \brief The device stream for the default computation operations. */
+  TVMStreamHandle compute_stream_ = nullptr;
+  /*! \brief The device stream for copying auxiliary data structures to GPU. */
+  TVMStreamHandle copy_stream_ = nullptr;

  public:
  /*! \brief Constructor.
   * Take the cache configuration and initialize the NDArrays. */
@@ -370,7 +390,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         f_merge_inplace_(std::move(f_merge_inplace)),
         f_split_rotary_(std::move(f_split_rotary)),
         f_rotary_inplace_(std::move(f_rotary_inplace)),
-        f_debug_get_kv_(std::move(f_debug_get_kv)) {
+        f_debug_get_kv_(std::move(f_debug_get_kv)),
+        device_(device) {
     pages_.reserve(num_layers);
     for (int i = 0; i < num_layers; ++i) {
       pages_.push_back(
@@ -417,6 +438,22 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     for (int64_t page_id = num_total_pages - 1; page_id >= 0; --page_id) {
       free_page_ids_.push_back(page_id);
     }
+
+    // The compute stream is the default stream.
+    // If the device is CUDA/ROCm, we create a standalone copy stream
+    // in order to hide the latency of the auxiliary data copies.
+    compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
+    if (device.device_type == DLDeviceType::kDLCUDA ||
+        device.device_type == DLDeviceType::kDLROCM) {
+      copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);
+    }
+  }
+
+  ~PagedAttentionKVCacheObj() {
+    // Free the copy stream if defined.
+    if (copy_stream_ != nullptr) {
+      DeviceAPI::Get(device_)->FreeStream(device_, copy_stream_);
+    }
   }

   /*! \brief Reset the KV cache. */
@@ -522,16 +559,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {

     // - Collect sequence/block/page information for attention.
     std::vector<Sequence*> sequences;
-    std::vector<int32_t> k_ragged_rope_pos_offset;
     is_decode_request_ = true;
     sequences.reserve(cur_batch_size_);
-    k_ragged_rope_pos_offset.reserve(cur_batch_size_);
+    k_ragged_rope_pos_offset_host_.resize(cur_batch_size_);
     for (int i = 0; i < cur_batch_size_; ++i) {
       auto it = seq_map_.find(seq_ids[i]);
       CHECK(it != seq_map_.end())
           << "The sequence \"" << seq_ids[i] << "\" cannot be found in KV cache.";
       sequences.push_back(&it->second);
-      k_ragged_rope_pos_offset.push_back(it->second.seq_length);
+      k_ragged_rope_pos_offset_host_[i] = it->second.seq_length;
       it->second.seq_length += append_lengths[i];
       if (append_lengths[i] != 1) {
         is_decode_request_ = false;
@@ -561,18 +597,25 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       }
     }

-    std::vector<std::vector<int32_t>> qo_indptr_on_depths;
-    std::vector<std::vector<int32_t>> page_indptr_on_depths;
-    std::vector<std::vector<int32_t>> page_indices_on_depths;
-    std::vector<std::vector<int32_t>> last_page_len_on_depths;
-    std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths;
+    qo_indptr_on_depths_host_.resize(num_depths_);
+    page_indptr_on_depths_host_.resize(num_depths_);
+    page_indices_on_depths_host_.resize(num_depths_);
+    last_page_len_on_depths_host_.resize(num_depths_);
+    k_rope_pos_offset_on_depths_host_.resize(num_depths_);

     for (int d = 0; d < num_depths_; ++d) {
-      std::vector<int32_t> qo_indptr_h{0};
-      std::vector<int32_t> page_indptr_h{0};
-      std::vector<int32_t> page_indices_h;
-      std::vector<int32_t> last_page_len_h;
-      std::vector<int32_t> k_rope_pos_offset_h;
+      std::vector<int32_t>& qo_indptr_h = qo_indptr_on_depths_host_[d];
+      std::vector<int32_t>& page_indptr_h = page_indptr_on_depths_host_[d];
+      std::vector<int32_t>& page_indices_h = page_indices_on_depths_host_[d];
+      std::vector<int32_t>& last_page_len_h = last_page_len_on_depths_host_[d];
+      std::vector<int32_t>& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d];
+      qo_indptr_h.clear();
+      page_indptr_h.clear();
+      page_indices_h.clear();
+      last_page_len_h.clear();
+      k_rope_pos_offset_h.clear();
+      qo_indptr_h.push_back(0);
+      page_indptr_h.push_back(0);
       for (const auto& [block_id, chunk_append_length] : chunked_block_ids_arr[d]) {
         qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length);
         if (block_id == -1) {
@@ -588,11 +631,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
           k_rope_pos_offset_h.push_back(block.start_pos);
         }
       }
-      qo_indptr_on_depths.push_back(qo_indptr_h);
-      page_indptr_on_depths.push_back(page_indptr_h);
-      page_indices_on_depths.push_back(page_indices_h);
-      last_page_len_on_depths.push_back(last_page_len_h);
-      k_rope_pos_offset_on_depths.push_back(k_rope_pos_offset_h);
     }

     if (!append_before_attn_) {
@@ -606,28 +644,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {

     // Map each token position in the input batch to its position
     // in the global KV cache. The mapping is used when appending k/v values.
-    std::vector<int32_t> q_rope_position_map;
-    std::vector<int32_t> append_position_map;
+    q_rope_position_map_host_.clear();
+    append_position_map_host_.clear();
     for (int i = 0; i < cur_batch_size_; ++i) {
       int64_t append_length = append_lengths[i];
       const Block& block = global_block_pool_[sequences[i]->last_block_idx];
       for (int64_t pos = 0; pos < append_length; ++pos) {
         int64_t pos_in_block = block.seq_length - append_length + pos;
-        q_rope_position_map.push_back(sequences[i]->seq_length - append_length + pos);
-        append_position_map.push_back(block.page_ids[pos_in_block / page_size_] * page_size_ +
-                                      pos_in_block % page_size_);
+        q_rope_position_map_host_.push_back(sequences[i]->seq_length - append_length + pos);
+        append_position_map_host_.push_back(block.page_ids[pos_in_block / page_size_] * page_size_ +
+                                            pos_in_block % page_size_);
       }
     }
-
-    // - Sync NDArrays to GPU.
-    SyncAuxArrayToDevice(std::move(qo_indptr_on_depths), std::move(page_indptr_on_depths),
-                         std::move(page_indices_on_depths), std::move(last_page_len_on_depths),
-                         std::move(k_rope_pos_offset_on_depths),
-                         std::move(k_ragged_rope_pos_offset), std::move(q_rope_position_map),
-                         std::move(append_position_map));
-
-    // NOTE(Zihao): This logic is problematic ATM because we need a unique split per depth
-    KernelBeginForward();
   }

   void EndForward() final {
@@ -635,9 +663,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         !f_attention_prefill_ragged_end_forward_.defined()) {
       return;
     }
-    // Mark the dirty flag as true, so that BeginForward is required
-    // to be invoked before the next round of model forward.
-    dirty_aux_data_device_ = true;
     f_attention_prefill_ragged_end_forward_.value()();
     for (int d = 0; d < num_depths_; ++d) {
       f_attention_prefill_end_forward_.value()(d);
@@ -681,10 +706,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       total_seq_length += cur_append_lengths_[seq_id];
     }
     CHECK_EQ(total_seq_length, q_data->shape[0]);
+    // Sync the copy stream and the compute stream.
+    ComputeStreamWaitForCopyStream();
     // The auxiliary data structures on device must have been synchronized.
-    CHECK(!dirty_aux_data_device_)
-        << "The auxiliary arrays are not synchronized to device. Please call "
-           "`BeginForward` to synchronize before calling `Attention`.";
+    ICHECK(!dirty_aux_data_device_);

     if (rope_mode_ == RoPEMode::kNormal) {
       // Apply rotary embedding to q/k data.
@@ -726,10 +751,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       total_seq_length += cur_append_lengths_[seq_id];
     }
     CHECK_EQ(total_seq_length, qkv_data->shape[0]);
+    // Sync the copy stream and the compute stream.
+    ComputeStreamWaitForCopyStream();
     // The auxiliary data structures on device must have been synchronized.
-    CHECK(!dirty_aux_data_device_)
-        << "The auxiliary arrays are not synchronized to device. Please call "
-           "`BeginForward` to synchronize before calling `Attention`.";
+    ICHECK(!dirty_aux_data_device_);

     NDArray q_data =
         temp_attn_q_device_.CreateView({total_seq_length, num_qo_heads_, head_dim_}, qkv_data->dtype);
@@ -965,11 +990,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       f_attention_decode_begin_forward_.value()(
          /*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_view_[0],
          last_page_len_on_depths_view_[0], num_qo_heads_, num_kv_heads_, head_dim_, page_size_,
-         /*rotary_mode=*/rope_mode_ == RoPEMode::kInline);
+         /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
     } else {
       f_attention_prefill_ragged_begin_forward_.value()(
           temp_attn_workspace_[0], cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_,
-          num_kv_heads_);
+          num_kv_heads_, head_dim_, copy_stream_);
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
           continue;
@@ -978,11 +1003,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
           f_attention_decode_begin_forward_.value()(
               d, temp_attn_workspace_[d + 1], page_indptr_on_depths_view_[d],
               last_page_len_on_depths_view_[d], num_qo_heads_, num_kv_heads_, head_dim_, page_size_,
-              /*rotary_mode=*/rope_mode_ == RoPEMode::kInline);
+              /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
         } else {
           f_attention_prefill_begin_forward_.value()(
              /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_view_[d],
-             last_page_len_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_);
+             last_page_len_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_, head_dim_,
+             copy_stream_);
         }
       }
     }
@@ -1041,6 +1067,28 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
   }

+  /*! \brief Synchronize the copy stream and the compute stream. */
+  void ComputeStreamWaitForCopyStream() {
+    if (!dirty_aux_data_device_) {
+      // The auxiliary data is already synced; no need to sync again.
+      return;
+    }
+    // - Sync NDArrays to GPU.
+    SyncAuxArrayToDevice(qo_indptr_on_depths_host_, page_indptr_on_depths_host_,
+                         page_indices_on_depths_host_, last_page_len_on_depths_host_,
+                         k_rope_pos_offset_on_depths_host_, k_ragged_rope_pos_offset_host_,
+                         q_rope_position_map_host_, append_position_map_host_);
+    KernelBeginForward();
+    // - Clear the dirty flag.
+    dirty_aux_data_device_ = false;
+    // - If there is no dedicated copy stream, no further action is needed.
+    if (copy_stream_ == nullptr) {
+      return;
+    }
+    // - Make the compute stream wait until the copy stream finishes.
+    DeviceAPI::Get(device_)->SyncStreamFromTo(device_, copy_stream_, compute_stream_);
+  }
+
   /*!
    * \brief Synchronize auxiliary arrays to device.
   * \note This method resets the dirty flag to false, and needs to be
@@ -1061,15 +1109,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     ICHECK_EQ(last_page_len_on_depths.size(), num_depths_);
     int64_t total_append_length = 0;
     int num_sequences = cur_append_lengths_.size();
-    std::vector<int32_t> cur_append_lengths_indptr{0};
-    for (int i = 0; i < static_cast<int>(cur_append_lengths_.size()); ++i) {
-      cur_append_lengths_indptr.push_back(cur_append_lengths_indptr.back() +
-                                          cur_append_lengths_[i]);
+    cur_append_lengths_indptr_host_.resize(num_sequences + 1);
+    cur_append_lengths_indptr_host_[0] = 0;
+    for (int i = 0; i < num_sequences; ++i) {
+      cur_append_lengths_indptr_host_[i + 1] =
+          cur_append_lengths_indptr_host_[i] + cur_append_lengths_[i];
     }
-    total_append_length = cur_append_lengths_indptr.back();
+    total_append_length = cur_append_lengths_indptr_host_.back();
     ICHECK_EQ(total_append_length, append_position_map.size());

-    auto fcopy_from_vec = [](NDArray array, int32_t* vec_data) {
+    auto fcopy_from_vec = [copy_stream = this->copy_stream_](NDArray array, int32_t* vec_data) {
       DLTensor copy_dst = *array.operator->();
       DLTensor copy_src;
       copy_src.data = vec_data;
@@ -1079,7 +1128,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       copy_src.shape = array->shape;
       copy_src.strides = nullptr;
       copy_src.byte_offset = 0;
-      NDArray::CopyFromTo(&copy_src, &copy_dst);
+      NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream);
     };

     // 1. qo_indptr_on_depths
@@ -1126,7 +1175,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     // 6. cur_append_lengths_indptr
     cur_append_length_indptr_view_ =
         cur_append_length_indptr_device_.CreateView({num_sequences + 1}, dtype_aux_);
-    fcopy_from_vec(cur_append_length_indptr_view_, cur_append_lengths_indptr.data());
+    fcopy_from_vec(cur_append_length_indptr_view_, cur_append_lengths_indptr_host_.data());

     // 7. k_ragged_rope_pos_offset
     ICHECK_EQ(k_ragged_rope_pos_offset.size(), num_sequences);
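A caveat worth noting for the copy path above (a general CUDA property,
not something introduced by this patch): cudaMemcpyAsync from pageable
host memory, such as the plain std::vector members used here, is staged
through an internal pinned buffer and may not overlap fully with host
execution. Below is a hypothetical sketch of a page-locked staging
buffer, in case full overlap is ever needed (`PinnedAuxBuffer` is an
illustrative name, not part of this patch):

    #include <cuda_runtime.h>

    #include <cstddef>
    #include <cstdint>

    // Hypothetical page-locked staging buffer for the auxiliary arrays.
    // Pinned host memory lets cudaMemcpyAsync proceed fully asynchronously
    // with respect to the host on the copy stream.
    struct PinnedAuxBuffer {
      int32_t* data = nullptr;
      size_t capacity = 0;

      explicit PinnedAuxBuffer(size_t num_elems) : capacity(num_elems) {
        // cudaHostAllocDefault allocates portable page-locked memory.
        cudaHostAlloc(reinterpret_cast<void**>(&data),
                      num_elems * sizeof(int32_t), cudaHostAllocDefault);
      }
      ~PinnedAuxBuffer() { cudaFreeHost(data); }
      PinnedAuxBuffer(const PinnedAuxBuffer&) = delete;
      PinnedAuxBuffer& operator=(const PinnedAuxBuffer&) = delete;
    };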