From f21083ff933566f4c75337703e39b1f57acc2739 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 18 Apr 2024 10:05:07 -0400 Subject: [PATCH] [Runtime] Use preferred host memory (pinned memory) in KV cache This PR updates the PagedKVCache with the pinned memory support, which can reduce the copy overhead between CPU and GPU. This PR also bumps FlashInfer version, which now supports * specifying kernels to build via cmake, * pinned memory as host memory. We also update CMakeLists.txt and config.cmake to include the FlashInfer compile options. Prior to this PR, the kernels being built is hardcoded in FlashInfer header files. --- 3rdparty/flashinfer | 2 +- CMakeLists.txt | 6 +- cmake/config.cmake | 13 ++ include/tvm/runtime/ndarray.h | 17 ++ src/runtime/relax_vm/paged_kv_cache.cc | 265 ++++++++++++++++--------- 5 files changed, 205 insertions(+), 98 deletions(-) diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer index f978e02565d7..7e9cc7ff42ca 160000 --- a/3rdparty/flashinfer +++ b/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit f978e02565d7157d57803eb4153369e046fc4106 +Subproject commit 7e9cc7ff42ca283c317061a877305d09a395fad2 diff --git a/CMakeLists.txt b/CMakeLists.txt index 683ce819dbdb..7575d6c2b4d6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -960,13 +960,13 @@ option(USE_FLASHINFER "Build TVM with FlashInfer" OFF) if (USE_FLASHINFER STREQUAL "ON") message(STATUS "Build with FlashInfer") set(FLASHINFER_TVM_BINDING ON) - set(FLASHINFER_TVM_HOME ${PROJECT_SOURCE_DIR}) - set(FLASHINFER_ENABLE_FP8 OFF) - set(FLASHINFER_ENABLE_BF16 OFF) + set(FLASHINFER_TVM_SOURCE_DIR ${PROJECT_SOURCE_DIR}) set(FLASHINFER_PREFILL OFF) set(FLASHINFER_DECODE OFF) set(FLASHINFER_PAGE OFF) set(FLASHINFER_CASCADE OFF) + set(FLASHINFER_SAMPLING OFF) + set(FLASHINFER_NORM OFF) add_subdirectory(3rdparty/flashinfer) else () message(STATUS "Build without FlashInfer") diff --git a/cmake/config.cmake b/cmake/config.cmake index ccb449fe2b23..5847acc298b1 100644 --- 
a/cmake/config.cmake +++ b/cmake/config.cmake @@ -444,6 +444,19 @@ set(USE_GTEST AUTO) # Need to have USE_CUDA=ON set(USE_CUTLASS OFF) +# Whether to enable FlashInfer or not. +set(USE_FLASHINFER OFF) +# Options for FlashInfer kernel compilation. +set(FLASHINFER_ENABLE_FP8 OFF) +set(FLASHINFER_ENABLE_BF16 OFF) +set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8) +set(FLASHINFER_GEN_PAGE_SIZES 16) +set(FLASHINFER_GEN_HEAD_DIMS 128) +set(FLASHINFER_GEN_KV_LAYOUTS 0 1) +set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1) +set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false") +set(FLASHINFER_GEN_CASUALS "false" "true") + # Enable to show a summary of TVM options set(SUMMARIZE OFF) diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 5bdc883649c9..3eb225fccffe 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -534,6 +534,23 @@ inline bool NDArray::Load(dmlc::Stream* strm) { return true; } +/*! + * \brief Get the preferred host device from the input device. + * - For CUDA and ROCm, CUDAHost and ROCMHost will be returned for pinned memory, + * since pinned memory reduces copy overhead. + * - For other devices, CPU is returned as a fallback. + */ +inline Device GetPreferredHostDevice(Device device) { + if (device.device_type == DLDeviceType::kDLCUDA) { + return Device{DLDeviceType::kDLCUDAHost, 0}; + } else if (device.device_type == DLDeviceType::kDLROCM) { + return Device{DLDeviceType::kDLROCMHost, 0}; + } else { + // Fallback to CPU. + return Device{DLDeviceType::kDLCPU, 0}; + } +} + } // namespace runtime } // namespace tvm diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index a5d2d9f41554..62750d6d7daa 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -194,6 +194,56 @@ enum class RoPEMode : int { kInline = 2, }; +/*! + * \brief The class of host memory int32 vector in "std::vector" interface. 
+ * This vector allocates static memory on the specified host memory
+ * at the time of construction.
+ */
+class HostMemoryVector {
+ public:
+  HostMemoryVector() = default;
+  HostMemoryVector(const HostMemoryVector&) = delete;
+  HostMemoryVector(HostMemoryVector&& other) = default;
+  HostMemoryVector& operator=(const HostMemoryVector&) = delete;
+  HostMemoryVector& operator=(HostMemoryVector&& other) = default;
+
+  explicit HostMemoryVector(int64_t reserved_size, DLDataType dtype, Device device)
+      : reserved_size_(reserved_size) {
+    ICHECK(DataType(dtype) == DataType::Int(32));
+    data_ = NDArray::Empty({reserved_size}, dtype, device);
+  }
+
+  void push_back(int32_t value) {
+    ICHECK_LT(current_size_, reserved_size_);
+    static_cast<int32_t*>(data_->data)[current_size_++] = value;
+  }
+
+  const int32_t& operator[](int64_t idx) const {
+    ICHECK_GE(idx, 0) << "Index " << idx << " is negative.";
+    ICHECK_LT(idx, current_size_) << "Index " << idx << " out of bounds " << current_size_;
+    return static_cast<int32_t*>(data_->data)[idx];
+  }
+
+  int32_t back() const {
+    ICHECK_GT(current_size_, 0) << "Vector is empty";
+    return static_cast<int32_t*>(data_->data)[current_size_ - 1];
+  }
+
+  size_t size() const { return static_cast<size_t>(current_size_); }
+
+  int32_t* data() const { return static_cast<int32_t*>(data_->data); }
+
+  void clear() { current_size_ = 0; }
+
+  /*! \brief Return the vector as an NDArray. */
+  NDArray as_ndarray() { return data_.CreateView({current_size_}, data_->dtype); }
+
+ private:
+  int64_t reserved_size_ = 0;
+  int64_t current_size_ = 0;
+  NDArray data_{nullptr};
+};
+
 /*!
  * \brief The paged attention auxiliary data manager class.
* This class manages all the int32 auxiliary data on GPU device, such as @@ -213,8 +263,12 @@ enum class RoPEMode : int { */ class PagedKVCacheAuxDataManager { public: - PagedKVCacheAuxDataManager(DLDataType dtype_aux, Device device, TVMStreamHandle copy_stream) - : dtype_aux_(dtype_aux), device_(device), copy_stream_(copy_stream) { + PagedKVCacheAuxDataManager(DLDataType dtype_aux, Device device, Device preferred_host_device, + TVMStreamHandle copy_stream) + : dtype_aux_(dtype_aux), + device_(device), + preferred_host_device_(preferred_host_device), + copy_stream_(copy_stream) { ICHECK(DataType(dtype_aux) == DataType::Int(32)); } @@ -222,13 +276,13 @@ class PagedKVCacheAuxDataManager { /*! \brief Reset the status of copy manager. */ virtual void ResetCopy() = 0; /*! \brief Copy the indptr array of append lengths after coalescing. (see GetChunkedBlockIds) */ - virtual NDArray CopyQOIndptrOnDepthAsync(std::vector* data, int depth) = 0; + virtual NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Copy the indptr array of page table. */ - virtual NDArray CopyPageIndptrOnDepthAsync(std::vector* data, int depth) = 0; + virtual NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Copy the indices array of page table. */ - virtual NDArray CopyPageIndicesOnDepthAsync(std::vector* data, int depth) = 0; + virtual NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Copy the array of KV slot number used in the last page of the seq. */ - virtual NDArray CopyLastPageLenOnDepthAsync(std::vector* data, int depth) = 0; + virtual NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! * \brief Copy the length information of the sequences. * Each NDArray is in shape `(3, n)`. "n" is the number of sequences. 
@@ -239,27 +293,27 @@ class PagedKVCacheAuxDataManager { * \note When sliding window is not enabled, only the * "last_page_len" (a.k.a., the first "n" elements) will be effectively used. */ - virtual NDArray CopyLengthInfoOnDepthAsync(std::vector* last_page_len, - std::vector* sliding_window_offset, - std::vector* sink_size, int depth) = 0; + virtual NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, + HostMemoryVector* sliding_window_offset, + HostMemoryVector* sink_size, int depth) = 0; /*! \brief Copy the k position offset of applying RoPE for each sequence. */ - virtual NDArray CopyKRoPEPosOffsetOnDepthAsync(std::vector* data, int depth) = 0; + virtual NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! * \brief Copy the append length indptr array on device. * \note Since the Q/K/V data may have raggedness in terms of lengths, * we represent the append lengths in CSR format. */ - virtual NDArray CopyCurAppendLengthIndptrAsync(std::vector* data) = 0; + virtual NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) = 0; /*! \brief Copy the k position offset of applying RoPE for each sequence. */ - virtual NDArray CopyKRaggedRoPEPosOffsetAsync(std::vector* data) = 0; + virtual NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) = 0; /*! \brief Copy the q position mapping of applying RoPE for each sequence. */ - virtual NDArray CopyQRoPEPosMapAsync(std::vector* data) = 0; + virtual NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) = 0; /*! * \brief Copy the corresponding position in global KV cache (pages) * for each position along the length dimension of K/V data when * appending new K/V data. */ - virtual NDArray CopyAppendPositionMapAsync(std::vector* data) = 0; + virtual NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) = 0; /*! \brief Commit all the copy operations since the last commit. 
*/ virtual void CommitCopy() = 0; @@ -268,6 +322,8 @@ class PagedKVCacheAuxDataManager { const DLDataType dtype_aux_; /*! \brief The device this PagedKVCache runs on. */ const Device device_; + /*! \brief The preferred host device. */ + const Device preferred_host_device_; /*! \brief The device stream for copying auxiliary data structure to GPU. */ const TVMStreamHandle copy_stream_; }; @@ -280,8 +336,9 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { public: explicit PlainPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size, DLDataType dtype_aux, - DLDevice device, TVMStreamHandle copy_stream) - : PagedKVCacheAuxDataManager(dtype_aux, device, copy_stream) { + Device device, Device preferred_host_device, + TVMStreamHandle copy_stream) + : PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream) { for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { qo_indptr_on_depths_device_.push_back( NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); @@ -302,64 +359,64 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { // The reset of the plain auxiliary data manager is no-op. 
void ResetCopy() final {} - NDArray CopyQOIndptrOnDepthAsync(std::vector* data, int depth) final { + NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { NDArray view = qo_indptr_on_depths_device_[depth].CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyPageIndptrOnDepthAsync(std::vector* data, int depth) final { + NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { NDArray view = page_indptr_on_depths_device_[depth].CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyPageIndicesOnDepthAsync(std::vector* data, int depth) final { + NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final { NDArray view = page_indices_on_depths_device_[depth].CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyLastPageLenOnDepthAsync(std::vector* data, int depth) final { + NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final { NDArray view = length_info_on_depths_device_[depth].CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyKRoPEPosOffsetOnDepthAsync(std::vector* data, int depth) final { + NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final { NDArray view = k_rope_pos_offset_on_depths_device_[depth].CreateView( {static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyCurAppendLengthIndptrAsync(std::vector* data) final { + NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final { NDArray view = cur_append_length_indptr_device_.CreateView({static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyKRaggedRoPEPosOffsetAsync(std::vector* data) final { + NDArray 
CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final { NDArray view = k_ragged_rope_pos_offset_device_.CreateView({static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyQRoPEPosMapAsync(std::vector* data) final { + NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { NDArray view = q_rope_position_map_device_.CreateView({static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyAppendPositionMapAsync(std::vector* data) final { + NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { NDArray view = append_position_map_device_.CreateView({static_cast(data->size())}, dtype_aux_); CopyVecDataToArray(view, data->data()); return view; } - NDArray CopyLengthInfoOnDepthAsync(std::vector* last_page_len, - std::vector* sliding_window_offset, - std::vector* sink_size, int depth) final { + NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, + HostMemoryVector* sliding_window_offset, + HostMemoryVector* sink_size, int depth) final { int n_elem = last_page_len->size(); ICHECK_GT(n_elem, 0); NDArray view = length_info_on_depths_device_[depth].CreateView({3, n_elem}, dtype_aux_); @@ -412,7 +469,7 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { DLTensor copy_src; copy_src.data = vec_data; - copy_src.device = Device{kDLCPU, 0}; + copy_src.device = preferred_host_device_; copy_src.ndim = 1; copy_src.dtype = array->dtype; copy_src.shape = copy_dst.shape; @@ -443,15 +500,16 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { public: explicit CachedPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size, DLDataType dtype_aux, - DLDevice device, TVMStreamHandle copy_stream) - : PagedKVCacheAuxDataManager(dtype_aux, device, copy_stream), + DLDevice device, Device preferred_host_device, + TVMStreamHandle copy_stream) + : 
PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream), elem_byte_size_((dtype_aux.bits * dtype_aux.lanes + 7) / 8), offset_alignment_(cuda_byte_alignment_ / elem_byte_size_) { // - Calculate cache size of all the auxiliary arrays in // local cache and the large on-device array. int64_t cache_size = CalculateCacheSize(reserved_num_seqs, num_total_pages, prefill_chunk_size); // - Initialize the host auxiliary data buffer. - merged_aux_data_host_.resize(cache_size); + merged_aux_data_host_ = HostMemoryVector(cache_size, dtype_aux, preferred_host_device); // - Initialize the device auxiliary data buffer. memory::Allocator* allocator = memory::MemoryManager::GetOrCreateAllocator(device, memory::AllocatorType::kNaive); @@ -461,34 +519,32 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { } void ResetCopy() final { copy_offset_ = 0; } - NDArray CopyQOIndptrOnDepthAsync(std::vector* data, int depth) final { + NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyVecToCache(data); } - NDArray CopyPageIndptrOnDepthAsync(std::vector* data, int depth) final { + NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyVecToCache(data); } - NDArray CopyPageIndicesOnDepthAsync(std::vector* data, int depth) final { + NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyVecToCache(data); } - NDArray CopyLastPageLenOnDepthAsync(std::vector* data, int depth) final { + NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyVecToCache(data); } - NDArray CopyKRoPEPosOffsetOnDepthAsync(std::vector* data, int depth) final { + NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final { return CopyVecToCache(data); } - NDArray CopyCurAppendLengthIndptrAsync(std::vector* data) final { + NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final { return 
CopyVecToCache(data); } - NDArray CopyKRaggedRoPEPosOffsetAsync(std::vector* data) final { + NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final { return CopyVecToCache(data); } - NDArray CopyQRoPEPosMapAsync(std::vector* data) final { return CopyVecToCache(data); } - NDArray CopyAppendPositionMapAsync(std::vector* data) final { - return CopyVecToCache(data); - } - NDArray CopyLengthInfoOnDepthAsync(std::vector* last_page_len, - std::vector* sliding_window_offset, - std::vector* sink_size, int depth) final { + NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { return CopyVecToCache(data); } + NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { return CopyVecToCache(data); } + NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, + HostMemoryVector* sliding_window_offset, + HostMemoryVector* sink_size, int depth) final { int64_t n_elem = last_page_len->size(); std::memcpy(merged_aux_data_host_.data() + copy_offset_, last_page_len->data(), n_elem * elem_byte_size_); @@ -559,7 +615,7 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { * \brief Copy the input data to the cache at the given offset. * And return the NDArray view of the cache starting at the offset. */ - NDArray CopyVecToCache(std::vector* data) { + NDArray CopyVecToCache(HostMemoryVector* data) { int64_t n_elem = data->size(); std::memcpy(merged_aux_data_host_.data() + copy_offset_, data->data(), n_elem * elem_byte_size_); @@ -579,7 +635,7 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { const int64_t offset_alignment_; int64_t copy_offset_ = 0; - std::vector merged_aux_data_host_; + HostMemoryVector merged_aux_data_host_; memory::Storage merged_aux_data_device_; }; @@ -687,17 +743,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Below are the auxiliary data structure on CPU. // We make them class members to avoid repetitive allocation time in BeginForward. 
   //-------------------------------------------
-  std::vector<std::vector<int32_t>> qo_indptr_on_depths_host_;
-  std::vector<std::vector<int32_t>> page_indptr_on_depths_host_;
-  std::vector<std::vector<int32_t>> page_indices_on_depths_host_;
-  std::vector<std::vector<int32_t>> last_page_len_on_depths_host_;
-  std::vector<std::vector<int32_t>> sliding_window_offset_on_depths_host_;
-  std::vector<std::vector<int32_t>> sink_size_on_depths_host_;
-  std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths_host_;
-  std::vector<int32_t> k_ragged_rope_pos_offset_host_;
-  std::vector<int32_t> q_rope_position_map_host_;
-  std::vector<int32_t> append_position_map_host_;
-  std::vector<int32_t> cur_append_lengths_indptr_host_;
+  std::vector<HostMemoryVector> qo_indptr_on_depths_host_;
+  std::vector<HostMemoryVector> page_indptr_on_depths_host_;
+  std::vector<HostMemoryVector> page_indices_on_depths_host_;
+  std::vector<HostMemoryVector> last_page_len_on_depths_host_;
+  std::vector<HostMemoryVector> sliding_window_offset_on_depths_host_;
+  std::vector<HostMemoryVector> sink_size_on_depths_host_;
+  std::vector<HostMemoryVector> k_rope_pos_offset_on_depths_host_;
+  HostMemoryVector k_ragged_rope_pos_offset_host_;
+  HostMemoryVector q_rope_position_map_host_;
+  HostMemoryVector append_position_map_host_;
+  HostMemoryVector cur_append_lengths_indptr_host_;
 
   //-------------------------------------------
   // For efficient memory management, the actual sizes of the arrays
@@ -804,6 +860,33 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       pages_.push_back(
           NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, head_dim}, dtype, device));
     }
+    // Allocate the host memory.
+ Device preferred_host_device = GetPreferredHostDevice(device); + for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { + qo_indptr_on_depths_host_.push_back( + HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device)); + page_indptr_on_depths_host_.push_back( + HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device)); + page_indices_on_depths_host_.push_back( + HostMemoryVector(num_total_pages, dtype_aux_, preferred_host_device)); + last_page_len_on_depths_host_.push_back( + HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); + sliding_window_offset_on_depths_host_.push_back( + HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); + sink_size_on_depths_host_.push_back( + HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); + k_rope_pos_offset_on_depths_host_.push_back( + HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); + } + k_ragged_rope_pos_offset_host_ = + HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device); + q_rope_position_map_host_ = + HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device); + append_position_map_host_ = + HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device); + cur_append_lengths_indptr_host_ = + HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device); + for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { temp_attn_workspace_.push_back( NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device)); @@ -847,10 +930,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // operations may have issues on other platforms. 
if (device_.device_type == DLDeviceType::kDLCUDA) { aux_data_manager_ = std::make_unique( - reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device, copy_stream_); + reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device, + preferred_host_device, copy_stream_); } else { aux_data_manager_ = std::make_unique( - reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device, copy_stream_); + reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device, + preferred_host_device, copy_stream_); } } @@ -1124,7 +1209,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { is_decode_request_ = true; sequences.reserve(cur_batch_size_); last_block_length_before_append.reserve(cur_batch_size_); - k_ragged_rope_pos_offset_host_.resize(cur_batch_size_); + k_ragged_rope_pos_offset_host_.clear(); for (int i = 0; i < cur_batch_size_; ++i) { auto it = seq_map_.find(seq_ids[i]); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i] @@ -1132,7 +1217,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { sequences.push_back(&it->second); last_block_length_before_append.push_back( global_block_pool_[it->second.last_block_idx].seq_length); - k_ragged_rope_pos_offset_host_[i] = it->second.seq_length; + k_ragged_rope_pos_offset_host_.push_back(it->second.seq_length); it->second.seq_length += append_lengths[i]; if (append_lengths[i] != 1) { is_decode_request_ = false; @@ -1162,22 +1247,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - qo_indptr_on_depths_host_.resize(num_depths_); - page_indptr_on_depths_host_.resize(num_depths_); - page_indices_on_depths_host_.resize(num_depths_); - last_page_len_on_depths_host_.resize(num_depths_); - sliding_window_offset_on_depths_host_.resize(num_depths_); - sink_size_on_depths_host_.resize(num_depths_); - k_rope_pos_offset_on_depths_host_.resize(num_depths_); - for (int d = 0; d < num_depths_; ++d) { - std::vector& qo_indptr_h = 
qo_indptr_on_depths_host_[d]; - std::vector& page_indptr_h = page_indptr_on_depths_host_[d]; - std::vector& page_indices_h = page_indices_on_depths_host_[d]; - std::vector& last_page_len_h = last_page_len_on_depths_host_[d]; - std::vector& sliding_window_offset_h = sliding_window_offset_on_depths_host_[d]; - std::vector& sink_size_h = sink_size_on_depths_host_[d]; - std::vector& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d]; + HostMemoryVector& qo_indptr_h = qo_indptr_on_depths_host_[d]; + HostMemoryVector& page_indptr_h = page_indptr_on_depths_host_[d]; + HostMemoryVector& page_indices_h = page_indices_on_depths_host_[d]; + HostMemoryVector& last_page_len_h = last_page_len_on_depths_host_[d]; + HostMemoryVector& sliding_window_offset_h = sliding_window_offset_on_depths_host_[d]; + HostMemoryVector& sink_size_h = sink_size_on_depths_host_[d]; + HostMemoryVector& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d]; qo_indptr_h.clear(); page_indptr_h.clear(); page_indices_h.clear(); @@ -1198,7 +1275,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } else { const Block& block = global_block_pool_[block_id]; page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size()); - page_indices_h.insert(page_indices_h.end(), block.page_ids.begin(), block.page_ids.end()); + for (int32_t page_id : block.page_ids) { + page_indices_h.push_back(page_id); + } last_page_len_h.push_back(block.seq_length == 0 ? 
0 : (block.seq_length - block.sink_length + block.sliding_window_offset - 1) % @@ -1620,14 +1699,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (append_before_attn_) { if (!support_sliding_window_) { f_attention_decode_begin_forward_.value()( - /*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_view_[0], - length_info_on_depths_view_[0], num_qo_heads_, num_kv_heads_, head_dim_, page_size_, + /*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_host_[0].as_ndarray(), + last_page_len_on_depths_host_[0].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_, + page_size_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_); } } else { f_attention_prefill_ragged_begin_forward_.value()( - temp_attn_workspace_[0], cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_, - num_kv_heads_, head_dim_, copy_stream_); + temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, + num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_); if (support_sliding_window_) { return; } @@ -1637,12 +1717,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } if (use_decode_kernel_[d]) { f_attention_decode_begin_forward_.value()( - d, temp_attn_workspace_[d + 1], page_indptr_on_depths_view_[d], - length_info_on_depths_view_[d], num_qo_heads_, num_kv_heads_, head_dim_, page_size_, + d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(), + last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, + head_dim_, page_size_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_); } else { f_attention_prefill_begin_forward_.value()( - /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_view_[d], + /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(), length_info_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_); } @@ -1732,17 +1813,13 @@ class PagedAttentionKVCacheObj : 
public AttentionKVCacheObj { */ void SyncAuxArrayToDevice() { ICHECK(dtype_aux_.bits == 32 && dtype_aux_.code == kDLInt); - ICHECK_EQ(qo_indptr_on_depths_host_.size(), num_depths_); - ICHECK_EQ(page_indptr_on_depths_host_.size(), num_depths_); - ICHECK_EQ(page_indices_on_depths_host_.size(), num_depths_); - ICHECK_EQ(last_page_len_on_depths_host_.size(), num_depths_); int64_t total_append_length = 0; int num_sequences = cur_append_lengths_.size(); - cur_append_lengths_indptr_host_.resize(num_sequences + 1); - cur_append_lengths_indptr_host_[0] = 0; + cur_append_lengths_indptr_host_.clear(); + cur_append_lengths_indptr_host_.push_back(0); for (int i = 0; i < num_sequences; ++i) { - cur_append_lengths_indptr_host_[i + 1] = - cur_append_lengths_indptr_host_[i] + cur_append_lengths_[i]; + cur_append_lengths_indptr_host_.push_back(cur_append_lengths_indptr_host_.back() + + cur_append_lengths_[i]); } total_append_length = cur_append_lengths_indptr_host_.back(); ICHECK_EQ(total_append_length, append_position_map_host_.size());