diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index f1f6a0de4e59..0d04571b614c 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit f1f6a0de4e595b777e29cc0dc370c15bd1d736fb
+Subproject commit 0d04571b614c944b5831d080882107a98b9c6e65
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 6dec511f2f88..fb22d20fcfc7 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -242,7 +242,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   //-------------------------------------------
   /*!
    * \brief A boolean flag indicating if the auxiliary arrays are dirty.
-   * If it is dirty, an explicit "SyncAuxArrayToDevice" should be invoked.
+   * If it is dirty, an explicit "ComputeStreamWaitForCopyStream" should be invoked.
    */
   bool dirty_aux_data_device_ = false;
   /*! \brief The batch size of the current round of forwarding. */
@@ -285,6 +285,20 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   NDArray merged_attn_scores_device_;
   std::vector<NDArray> temp_attn_workspace_;
 
+  //-------------------------------------------
+  // Below are the auxiliary data structures on CPU.
+  // We make them class members to avoid repetitive allocation time in BeginForward.
+  //-------------------------------------------
+  std::vector<std::vector<int32_t>> qo_indptr_on_depths_host_;
+  std::vector<std::vector<int32_t>> page_indptr_on_depths_host_;
+  std::vector<std::vector<int32_t>> page_indices_on_depths_host_;
+  std::vector<std::vector<int32_t>> last_page_len_on_depths_host_;
+  std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths_host_;
+  std::vector<int32_t> k_ragged_rope_pos_offset_host_;
+  std::vector<int32_t> q_rope_position_map_host_;
+  std::vector<int32_t> append_position_map_host_;
+  std::vector<int32_t> cur_append_lengths_indptr_host_;
+
   //-------------------------------------------
   // For efficient memory management, the actual sizes of the arrays
   // above are over allocated.
@@ -328,6 +342,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   std::vector<bool> use_decode_kernel_;
   /*! \brief Whether the attention request is a decode request, set in BeginForwardFunction. */
   bool is_decode_request_;
+  /*! \brief The device this PagedKVCache runs on. */
+  DLDevice device_;
+  /*! \brief The device stream for the default computation operations. */
+  TVMStreamHandle compute_stream_ = nullptr;
+  /*! \brief The device stream for copying auxiliary data structure to GPU. */
+  TVMStreamHandle copy_stream_ = nullptr;
 
  public:
   /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */
@@ -370,7 +390,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         f_merge_inplace_(std::move(f_merge_inplace)),
         f_split_rotary_(std::move(f_split_rotary)),
         f_rotary_inplace_(std::move(f_rotary_inplace)),
-        f_debug_get_kv_(std::move(f_debug_get_kv)) {
+        f_debug_get_kv_(std::move(f_debug_get_kv)),
+        device_(device) {
     pages_.reserve(num_layers);
     for (int i = 0; i < num_layers; ++i) {
       pages_.push_back(
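The `*_host_` vectors introduced above are class members rather than locals precisely so that their heap capacity survives across rounds of forwarding; the `BeginForward` hunks below `clear()` or `resize()` them instead of constructing fresh vectors. A minimal standalone sketch of that reuse pattern (hypothetical `Planner` class, not part of the patch):

```cpp
// Sketch: a member buffer keeps its capacity across calls, so steady-state
// planning does no heap allocation; a local vector would reallocate each call.
#include <cstdint>
#include <vector>

class Planner {
 public:
  const std::vector<int32_t>& Plan(const std::vector<int32_t>& lengths) {
    indptr_.clear();       // size -> 0, capacity retained from earlier calls
    indptr_.push_back(0);
    for (int32_t len : lengths) {
      indptr_.push_back(indptr_.back() + len);  // running prefix sum
    }
    return indptr_;        // valid until the next Plan() call
  }

 private:
  std::vector<int32_t> indptr_;  // member: allocated once, reused every step
};
```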
@@ -417,6 +438,22 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     for (int64_t page_id = num_total_pages - 1; page_id >= 0; --page_id) {
       free_page_ids_.push_back(page_id);
     }
+
+    // The compute stream is the default stream.
+    // If the device is CUDA/ROCm, we create a standalone copy stream in
+    // order to hide the latency of the auxiliary data copy.
+    compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
+    if (device.device_type == DLDeviceType::kDLCUDA ||
+        device.device_type == DLDeviceType::kDLROCM) {
+      copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);
+    }
+  }
+
+  ~PagedAttentionKVCacheObj() {
+    // Free the copy stream if defined.
+    if (copy_stream_ != nullptr) {
+      DeviceAPI::Get(device_)->FreeStream(device_, copy_stream_);
+    }
   }
 
   /*! \brief Reset the KV cache. */
@@ -522,16 +559,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
 
     // - Collect sequence/block/page information for attention.
     std::vector<Sequence*> sequences;
-    std::vector<int32_t> k_ragged_rope_pos_offset;
     is_decode_request_ = true;
     sequences.reserve(cur_batch_size_);
-    k_ragged_rope_pos_offset.reserve(cur_batch_size_);
+    k_ragged_rope_pos_offset_host_.resize(cur_batch_size_);
     for (int i = 0; i < cur_batch_size_; ++i) {
       auto it = seq_map_.find(seq_ids[i]);
       CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i]
                                   << "\" cannot be found in KV cache.";
       sequences.push_back(&it->second);
-      k_ragged_rope_pos_offset.push_back(it->second.seq_length);
+      k_ragged_rope_pos_offset_host_[i] = it->second.seq_length;
       it->second.seq_length += append_lengths[i];
       if (append_lengths[i] != 1) {
         is_decode_request_ = false;
@@ -561,18 +597,25 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       }
     }
 
-    std::vector<std::vector<int32_t>> qo_indptr_on_depths;
-    std::vector<std::vector<int32_t>> page_indptr_on_depths;
-    std::vector<std::vector<int32_t>> page_indices_on_depths;
-    std::vector<std::vector<int32_t>> last_page_len_on_depths;
-    std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths;
+    qo_indptr_on_depths_host_.resize(num_depths_);
+    page_indptr_on_depths_host_.resize(num_depths_);
+    page_indices_on_depths_host_.resize(num_depths_);
+    last_page_len_on_depths_host_.resize(num_depths_);
+    k_rope_pos_offset_on_depths_host_.resize(num_depths_);
 
     for (int d = 0; d < num_depths_; ++d) {
-      std::vector<int32_t> qo_indptr_h{0};
-      std::vector<int32_t> page_indptr_h{0};
-      std::vector<int32_t> page_indices_h;
-      std::vector<int32_t> last_page_len_h;
-      std::vector<int32_t> k_rope_pos_offset_h;
+      std::vector<int32_t>& qo_indptr_h = qo_indptr_on_depths_host_[d];
+      std::vector<int32_t>& page_indptr_h = page_indptr_on_depths_host_[d];
+      std::vector<int32_t>& page_indices_h = page_indices_on_depths_host_[d];
+      std::vector<int32_t>& last_page_len_h = last_page_len_on_depths_host_[d];
+      std::vector<int32_t>& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d];
+      qo_indptr_h.clear();
+      page_indptr_h.clear();
+      page_indices_h.clear();
+      last_page_len_h.clear();
+      k_rope_pos_offset_h.clear();
+      qo_indptr_h.push_back(0);
+      page_indptr_h.push_back(0);
       for (const auto& [block_id, chunk_append_length] : chunked_block_ids_arr[d]) {
         qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length);
         if (block_id == -1) {
@@ -588,11 +631,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
           k_rope_pos_offset_h.push_back(block.start_pos);
         }
       }
-      qo_indptr_on_depths.push_back(qo_indptr_h);
-      page_indptr_on_depths.push_back(page_indptr_h);
-      page_indices_on_depths.push_back(page_indices_h);
-      last_page_len_on_depths.push_back(last_page_len_h);
-      k_rope_pos_offset_on_depths.push_back(k_rope_pos_offset_h);
     }
 
     if (!append_before_attn_) {
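The constructor above pairs the default compute stream with a dedicated copy stream on stream-ordered backends (CUDA/ROCm), and the destructor frees it. A hedged sketch of the same ownership pattern through TVM's `DeviceAPI` (hypothetical `CopyStreamHolder` type; it only assumes the `CreateStream`/`FreeStream` calls used in the patch):

```cpp
#include <tvm/runtime/device_api.h>

// Sketch: RAII ownership of an optional side stream for async copies.
// On backends without stream support the handle simply stays nullptr.
struct CopyStreamHolder {
  DLDevice device;
  TVMStreamHandle stream = nullptr;

  explicit CopyStreamHolder(DLDevice dev) : device(dev) {
    if (dev.device_type == DLDeviceType::kDLCUDA ||
        dev.device_type == DLDeviceType::kDLROCM) {
      stream = tvm::runtime::DeviceAPI::Get(dev)->CreateStream(dev);
    }
  }
  ~CopyStreamHolder() {
    if (stream != nullptr) {
      tvm::runtime::DeviceAPI::Get(device)->FreeStream(device, stream);
    }
  }
};
```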
@@ -606,28 +644,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
 
     // Map each token position in the input batch to the position
     // in the global KV cache. The mapping is used when appending k/v values.
-    std::vector<int32_t> q_rope_position_map;
-    std::vector<int32_t> append_position_map;
+    q_rope_position_map_host_.clear();
+    append_position_map_host_.clear();
     for (int i = 0; i < cur_batch_size_; ++i) {
       int64_t append_length = append_lengths[i];
       const Block& block = global_block_pool_[sequences[i]->last_block_idx];
       for (int64_t pos = 0; pos < append_length; ++pos) {
         int64_t pos_in_block = block.seq_length - append_length + pos;
-        q_rope_position_map.push_back(sequences[i]->seq_length - append_length + pos);
-        append_position_map.push_back(block.page_ids[pos_in_block / page_size_] * page_size_ +
-                                      pos_in_block % page_size_);
+        q_rope_position_map_host_.push_back(sequences[i]->seq_length - append_length + pos);
+        append_position_map_host_.push_back(block.page_ids[pos_in_block / page_size_] * page_size_ +
+                                            pos_in_block % page_size_);
       }
     }
-
-    // - Sync NDArrays to GPU.
-    SyncAuxArrayToDevice(std::move(qo_indptr_on_depths), std::move(page_indptr_on_depths),
-                         std::move(page_indices_on_depths), std::move(last_page_len_on_depths),
-                         std::move(k_rope_pos_offset_on_depths),
-                         std::move(k_ragged_rope_pos_offset), std::move(q_rope_position_map),
-                         std::move(append_position_map));
-
-    // NOTE(Zihao): This logic is problematic ATM because we need a unique split per depth
-    KernelBeginForward();
   }
 
   void EndForward() final {
@@ -635,9 +663,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         !f_attention_prefill_ragged_end_forward_.defined()) {
       return;
     }
-    // Mark the dirty flag as true, so that BeginForward is required
-    // to be invoked before the next round of model forward.
-    dirty_aux_data_device_ = true;
     f_attention_prefill_ragged_end_forward_.value()();
     for (int d = 0; d < num_depths_; ++d) {
       f_attention_prefill_end_forward_.value()(d);
@@ -681,10 +706,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       total_seq_length += cur_append_lengths_[seq_id];
     }
     CHECK_EQ(total_seq_length, q_data->shape[0]);
+    // Sync the copy stream and the compute stream.
+    ComputeStreamWaitForCopyStream();
     // The auxiliary data structure on device must have been synchronized.
-    CHECK(!dirty_aux_data_device_)
-        << "The auxiliary arrays are not synchronized to device. Please call "
-           "`BeginForward` to synchronize before calling `Attention`.";
+    ICHECK(!dirty_aux_data_device_);
 
     if (rope_mode_ == RoPEMode::kNormal) {
       // Apply rotary embedding to q/k data.
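The `append_position_map_host_` entries computed above translate a token's logical position within a block into a physical slot in the paged KV storage: page `page_ids[pos_in_block / page_size_]`, offset `pos_in_block % page_size_`. A small numeric check of that arithmetic (hypothetical values, not TVM code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t page_size = 16;
  // A block that owns three non-contiguous physical pages.
  const std::vector<int64_t> page_ids = {7, 2, 9};
  const int64_t pos_in_block = 35;  // lands in the block's third page, offset 3
  const int64_t slot =
      page_ids[pos_in_block / page_size] * page_size + pos_in_block % page_size;
  assert(slot == 9 * 16 + 3);  // physical page 9, global slot 147
  return 0;
}
```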
Please call " - "`BeginForward` to synchronize before calling `Attention`."; + ICHECK(!dirty_aux_data_device_); NDArray q_data = temp_attn_q_device_.CreateView({total_seq_length, num_qo_heads_, head_dim_}, qkv_data->dtype); @@ -965,11 +990,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_attention_decode_begin_forward_.value()( /*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_view_[0], last_page_len_on_depths_view_[0], num_qo_heads_, num_kv_heads_, head_dim_, page_size_, - /*rotary_mode=*/rope_mode_ == RoPEMode::kInline); + /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_); } else { f_attention_prefill_ragged_begin_forward_.value()( temp_attn_workspace_[0], cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_, - num_kv_heads_); + num_kv_heads_, head_dim_, copy_stream_); for (int d = 0; d < num_depths_; ++d) { if (page_indices_on_depths_view_[d]->shape[0] == 0) { continue; @@ -978,11 +1003,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_attention_decode_begin_forward_.value()( d, temp_attn_workspace_[d + 1], page_indptr_on_depths_view_[d], last_page_len_on_depths_view_[d], num_qo_heads_, num_kv_heads_, head_dim_, page_size_, - /*rotary_mode=*/rope_mode_ == RoPEMode::kInline); + /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_); } else { f_attention_prefill_begin_forward_.value()( /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_view_[d], - last_page_len_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_); + last_page_len_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_, head_dim_, + copy_stream_); } } } @@ -1041,6 +1067,28 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } + /*! \brief Synchronize the copy stream and the compute stream. */ + void ComputeStreamWaitForCopyStream() { + if (!dirty_aux_data_device_) { + // If the auxiliary data is already synced, return and no need to sync again. + return; + } + // - Sync NDArrays to GPU. + SyncAuxArrayToDevice(qo_indptr_on_depths_host_, page_indptr_on_depths_host_, + page_indices_on_depths_host_, last_page_len_on_depths_host_, + k_rope_pos_offset_on_depths_host_, k_ragged_rope_pos_offset_host_, + q_rope_position_map_host_, append_position_map_host_); + KernelBeginForward(); + // - Clear the dirty flag. + dirty_aux_data_device_ = false; + // - If there is no particular copy stream, no action is needed. + if (copy_stream_ == nullptr) { + return; + } + // - Sync two streams. + DeviceAPI::Get(device_)->SyncStreamFromTo(device_, copy_stream_, compute_stream_); + } + /*! * \brief Synchronize auxiliary arrays to device. 
@@ -1061,15 +1109,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     ICHECK_EQ(last_page_len_on_depths.size(), num_depths_);
     int64_t total_append_length = 0;
     int num_sequences = cur_append_lengths_.size();
-    std::vector<int32_t> cur_append_lengths_indptr{0};
-    for (int i = 0; i < static_cast<int>(cur_append_lengths_.size()); ++i) {
-      cur_append_lengths_indptr.push_back(cur_append_lengths_indptr.back() +
-                                          cur_append_lengths_[i]);
+    cur_append_lengths_indptr_host_.resize(num_sequences + 1);
+    cur_append_lengths_indptr_host_[0] = 0;
+    for (int i = 0; i < num_sequences; ++i) {
+      cur_append_lengths_indptr_host_[i + 1] =
+          cur_append_lengths_indptr_host_[i] + cur_append_lengths_[i];
     }
-    total_append_length = cur_append_lengths_indptr.back();
+    total_append_length = cur_append_lengths_indptr_host_.back();
     ICHECK_EQ(total_append_length, append_position_map.size());
 
-    auto fcopy_from_vec = [](NDArray array, int32_t* vec_data) {
+    auto fcopy_from_vec = [copy_stream = this->copy_stream_](NDArray array, int32_t* vec_data) {
       DLTensor copy_dst = *array.operator->();
       DLTensor copy_src;
       copy_src.data = vec_data;
@@ -1079,7 +1128,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       copy_src.shape = array->shape;
       copy_src.strides = nullptr;
       copy_src.byte_offset = 0;
-      NDArray::CopyFromTo(&copy_src, &copy_dst);
+      NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream);
     };
 
     // 1. qo_indptr_on_depths
@@ -1126,7 +1175,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     // 6. cur_append_lengths_indptr
     cur_append_length_indptr_view_ =
         cur_append_length_indptr_device_.CreateView({num_sequences + 1}, dtype_aux_);
-    fcopy_from_vec(cur_append_length_indptr_view_, cur_append_lengths_indptr.data());
+    fcopy_from_vec(cur_append_length_indptr_view_, cur_append_lengths_indptr_host_.data());
 
     // 7. k_ragged_rope_pos_offset
     ICHECK_EQ(k_ragged_rope_pos_offset.size(), num_sequences);
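One consequence of issuing `NDArray::CopyFromTo` on a stream is that the call returns before the transfer finishes, so the host-side source buffer must outlive the enqueued copy; keeping the staging vectors as long-lived class members (rather than locals of `BeginForward`) satisfies this. A standalone sketch of the copy helper above (hypothetical `AsyncCopyToDevice`; assumes the same `NDArray` API):

```cpp
#include <cstdint>

#include <tvm/runtime/ndarray.h>

// Sketch: wrap a raw host buffer in a DLTensor and enqueue an async copy.
// `vec_data` must stay alive until `stream` has executed the copy.
void AsyncCopyToDevice(tvm::runtime::NDArray dst, int32_t* vec_data,
                       TVMStreamHandle stream) {
  DLTensor copy_dst = *dst.operator->();
  DLTensor copy_src;
  copy_src.data = vec_data;
  copy_src.device = DLDevice{kDLCPU, 0};
  copy_src.ndim = 1;
  copy_src.dtype = dst->dtype;
  copy_src.shape = dst->shape;
  copy_src.strides = nullptr;
  copy_src.byte_offset = 0;
  tvm::runtime::NDArray::CopyFromTo(&copy_src, &copy_dst, stream);
}
```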