diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc index b1572bf4091a..b730a4eb07ce 100644 --- a/src/runtime/relax_vm/kv_state.cc +++ b/src/runtime/relax_vm/kv_state.cc @@ -40,13 +40,26 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence") .set_body_method(&KVStateObj::ForkSequence); TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method(&KVStateObj::PopN); TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward") - .set_body_method(&KVStateObj::BeginForward); + .set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK(args.size() == 3 || args.size() == 4) + << "KVState BeginForward only accepts 3 or 4 arguments"; + KVState kv_state = args[0]; + IntTuple seq_ids = args[1]; + IntTuple append_lengths = args[2]; + Optional token_tree_parent_ptr{nullptr}; + if (args.size() == 4) { + token_tree_parent_ptr = args[3].operator Optional(); + } + kv_state->BeginForward(seq_ids, append_lengths, token_tree_parent_ptr); + }); TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward") .set_body_method(&KVStateObj::EndForward); // Attention KV Cache methods TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq") .set_body_method(&AttentionKVCacheObj::EnableSlidingWindowForSeq); +TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes") + .set_body_method(&AttentionKVCacheObj::CommitAcceptedTokenTreeNodes); TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty") .set_body_method(&AttentionKVCacheObj::Empty); TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages") diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index 12a18ba89502..8de560f12266 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -89,8 +89,12 @@ class KVStateObj : public Object { * in the model forward function. * \param seq_ids The ids of the sequence to run in the incoming model forward. * \param append_lengths The sequence lengths to run forward for for each sequence. + * \param token_tree_parent_ptr The parent idx array of the token trees. Its length + * is the sum of "append_lengths". Nullptr means the token tree of each sequence + * is a chain. */ - virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths) = 0; + virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths, + const Optional& token_tree_parent_ptr = NullOpt) = 0; /*! * \brief Mark the start of the forward function. @@ -142,6 +146,15 @@ class AttentionKVCacheObj : public KVStateObj { virtual void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size, int32_t attn_sink_size) = 0; + /*! + * \brief Committed the accepted token tree nodes to KV cache. + * The commit will update the KV cache, by compacting the KV data and discard + * the KV data of rejected tokens. + * This is a mandatory step when the BeginForward is given with a token tree. + * \param leaf_indices The leaf token tree node index of each sequence. + */ + virtual void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) = 0; + /************** Attention **************/ /*! diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 4ab0f3f0c686..a5b970e81716 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -26,6 +26,8 @@ #include #include +#include +#include #include #include #include @@ -52,6 +54,8 @@ namespace relax_vm { * prefixes) in paged KV cache. 
*/ constexpr const int kPagedKVCacheMaxBlockDepth = 5; +/*! \brief The maximum tree size of a single sequence in tree attention. */ +constexpr const int kTreeAttnMaxTreeSize = 256; /*! \brief The 8MB workspace size for attention auxiliary data. */ constexpr const int kAttnWorkspaceByte = 8 * 1024 * 1024; /*! \brief The id of the temporary logical page, which is useful for sliding window. */ @@ -250,14 +254,14 @@ class HostMemoryVector { * This class manages all the int32 auxiliary data on GPU device, such as * page table, position arrays, etc.. * - * The core functions of this class is `CopyXXXAsync` and `CommitCopy`. + * The core functions of this class is `CopyXXXAsync` and `CommitAttnAuxDataCopy`. * `CopyXXXAsync` takes the input data on CPU host, and copy the input data * to GPU in an asynchronous way, and returns the NDArray view of the data * on GPU device. * * Being asynchronous here means the `CopyXXXAsync` function may not perform * data copy from CPU to GPU at the time of being called. Therefore, the - * returned NDArray view may have wrong result, until `CommitCopy` is + * returned NDArray view may have wrong result, until `CommitAttnAuxDataCopy` is * explicitly invoked and the data copy stream is synchronized. * * We design this manager class in order to reduce the data copy overhead. @@ -274,8 +278,8 @@ class PagedKVCacheAuxDataManager { } virtual ~PagedKVCacheAuxDataManager() = default; - /*! \brief Reset the status of copy manager. */ - virtual void ResetCopy() = 0; + /*! \brief Reset the attention auxiliary data status of copy manager. */ + virtual void ResetAttnAuxDataCopy() = 0; /*! \brief Copy the indptr array of append lengths after coalescing. (see GetChunkedBlockIds) */ virtual NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; /*! \brief Copy the indptr array of page table. */ @@ -315,8 +319,22 @@ class PagedKVCacheAuxDataManager { * appending new K/V data. */ virtual NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) = 0; - /*! \brief Commit all the copy operations since the last commit. */ - virtual void CommitCopy() = 0; + /*! \brief Copy the tree attention mask. */ + virtual NDArray CopyTreeAttnMaskAsync(HostMemoryVector* data) = 0; + /*! \brief Copy the mn indptr of the tree attention mask. */ + virtual NDArray CopyTreeAttnMNIndptrAsync(HostMemoryVector* data) = 0; + /*! \brief Commit all the attention auxiliary data copy operations since the last commit. */ + virtual void CommitAttnAuxDataCopy() = 0; + + /*! \brief Reset the compact KV auxiliary data status of copy manager. */ + virtual void ResetCompactKVAuxDataCopy() = 0; + /*! \brief Copy the length indptr array of KV data copy for each sequence. */ + virtual NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) = 0; + /*! \brief Copy the src/dst position arrays for each sequence. */ + virtual NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, + HostMemoryVector* dst_data) = 0; + /*! \brief Commit all the compact KV auxiliary data copy operations since the last commit. */ + virtual void CommitCompactKVAuxDataCopy() = 0; protected: /*! \brief The dtype of the auxiliary data. It is expected to be int32. 
*/ @@ -356,10 +374,18 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device); q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); append_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); + tree_attn_mask_device_ = NDArray::Empty( + {kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs}, dtype_aux_, device); + tree_attn_mn_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); + + commit_copy_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); + commit_copy_src_dst_pos_in_page_table_device_ = + NDArray::Empty({2, std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size)}, + dtype_aux_, device); } // The reset of the plain auxiliary data manager is no-op. - void ResetCopy() final {} + void ResetAttnAuxDataCopy() final {} NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { NDArray view = qo_indptr_on_depths_device_[depth].CreateView( {static_cast(data->size())}, dtype_aux_); @@ -414,6 +440,18 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { CopyVecDataToArray(view, data->data()); return view; } + NDArray CopyTreeAttnMaskAsync(HostMemoryVector* data) final { + NDArray view = + tree_attn_mask_device_.CreateView({static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyTreeAttnMNIndptrAsync(HostMemoryVector* data) final { + NDArray view = + tree_attn_mn_indptr_device_.CreateView({static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, HostMemoryVector* sliding_window_offset, @@ -431,7 +469,32 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { } // The commit of the plain auxiliary data manager is no-op. - void CommitCopy() final {} + void CommitAttnAuxDataCopy() final {} + + // The reset of the plain auxiliary data manager is no-op. + void ResetCompactKVAuxDataCopy() final {} + + NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) final { + NDArray view = commit_copy_length_indptr_device_.CreateView( + {static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, + HostMemoryVector* dst_data) final { + int n_elem = src_data->size(); + ICHECK_GT(n_elem, 0); + NDArray view = + commit_copy_src_dst_pos_in_page_table_device_.CreateView({2, n_elem}, dtype_aux_); + ShapeTuple copy_shape{n_elem}; + CopyVecDataToArray(view, src_data->data(), copy_shape); + CopyVecDataToArray(view, dst_data->data(), copy_shape, + /*dst_elem_offset=*/n_elem); + return view; + } + + // The commit of the plain auxiliary data manager is no-op. + void CommitCompactKVAuxDataCopy() final {} private: /*! @@ -488,81 +551,136 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { NDArray k_ragged_rope_pos_offset_device_; NDArray q_rope_position_map_device_; NDArray append_position_map_device_; + NDArray tree_attn_mask_device_; + NDArray tree_attn_mn_indptr_device_; + NDArray commit_copy_length_indptr_device_; + NDArray commit_copy_src_dst_pos_in_page_table_device_; }; /*! * \brief The cached auxiliary data manager class. 
* It allocates a large on-device array to store all the auxiliary data. * For each `CopyXXXAsync`, it copies the input data to a local cache on host. - * In `CommitCopy`, it copies all the data in the local cache to the device + * In `CommitAttnAuxDataCopy`, it copies all the data in the local cache to the device * array for a single time, and thus reduce the number of host-to-device copies needed. */ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { public: explicit CachedPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size, DLDataType dtype_aux, - DLDevice device, Device preferred_host_device, + Device device, Device preferred_host_device, TVMStreamHandle copy_stream) : PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream), elem_byte_size_((dtype_aux.bits * dtype_aux.lanes + 7) / 8), offset_alignment_(cuda_byte_alignment_ / elem_byte_size_) { - // - Calculate cache size of all the auxiliary arrays in + // - Calculate cache size of all the attention auxiliary arrays in // local cache and the large on-device array. - int64_t cache_size = CalculateCacheSize(reserved_num_seqs, num_total_pages, prefill_chunk_size); + int64_t attn_aux_data_cache_size = + CalculateAttnAuxDataCacheSize(reserved_num_seqs, num_total_pages, prefill_chunk_size); // - Initialize the host auxiliary data buffer. - merged_aux_data_host_ = HostMemoryVector(cache_size, dtype_aux, preferred_host_device); + merged_attn_aux_data_host_ = + HostMemoryVector(attn_aux_data_cache_size, dtype_aux, preferred_host_device); // - Initialize the device auxiliary data buffer. - memory::Allocator* allocator = - memory::MemoryManager::GetOrCreateAllocator(device, memory::AllocatorType::kNaive); - ICHECK_NOTNULL(allocator); - merged_aux_data_device_ = - memory::Storage(allocator->Alloc(device, {cache_size}, dtype_aux), allocator); + merged_attn_aux_data_device_ = NDArray::Empty({attn_aux_data_cache_size}, dtype_aux, device); + + // - Calculate cache size of all the compact KV auxiliary arrays in + // local cache and the large on-device array. + int64_t compact_kv_aux_data_cache_size = + CalculateCompactKVAuxDataCacheSize(reserved_num_seqs, prefill_chunk_size); + // - Initialize the host auxiliary data buffer. 
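As a rough illustration of the cached manager's strategy, here is a minimal Python sketch with made-up names (not the TVM API): every CopyXXXAsync call appends its host vector into one staging buffer at an offset rounded up to the element alignment and hands back a view, and the commit step flushes the filled prefix to the device with a single copy.

ALIGN = 16  # elements; stands in for cuda_byte_alignment_ / elem_byte_size_

class MergedStagingBuffer:
    """Toy model of the cached aux-data manager's host-side staging buffer."""

    def __init__(self, capacity):
        self.host = [0] * capacity
        self.offset = 0

    def reset(self):
        self.offset = 0

    def copy_async(self, data):
        n = len(data)
        assert self.offset + n <= len(self.host)
        self.host[self.offset:self.offset + n] = data
        view = (self.offset, n)                          # caller gets a view into the merged buffer
        self.offset += (n + ALIGN - 1) // ALIGN * ALIGN  # keep successive offsets aligned
        return view

    def commit(self, copy_to_device):
        copy_to_device(self.host[:self.offset])          # one host-to-device transfer for everything

buf = MergedStagingBuffer(capacity=1024)
print(buf.copy_async([0, 3, 7]))                         # (0, 3)   e.g. an indptr array
print(buf.copy_async(list(range(10))))                   # (16, 10) e.g. a position map
buf.commit(lambda prefix: print("flushed", len(prefix), "elements"))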
+ merged_compact_kv_aux_data_host_ = + HostMemoryVector(compact_kv_aux_data_cache_size, dtype_aux, preferred_host_device); + merged_compact_kv_aux_data_device_ = + NDArray::Empty({compact_kv_aux_data_cache_size}, dtype_aux, device); } - void ResetCopy() final { copy_offset_ = 0; } + void ResetAttnAuxDataCopy() final { attn_aux_data_copy_offset_ = 0; } NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { - return CopyVecToCache(data); + return CopyAttnAuxVecToCache(data); } NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { - return CopyVecToCache(data); + return CopyAttnAuxVecToCache(data); } NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final { - return CopyVecToCache(data); + return CopyAttnAuxVecToCache(data); } NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final { - return CopyVecToCache(data); + return CopyAttnAuxVecToCache(data); } NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final { - return CopyVecToCache(data); + return CopyAttnAuxVecToCache(data); } NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final { - return CopyVecToCache(data); + return CopyAttnAuxVecToCache(data); } NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final { - return CopyVecToCache(data); + return CopyAttnAuxVecToCache(data); + } + NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } + NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { + return CopyAttnAuxVecToCache(data); + } + NDArray CopyTreeAttnMaskAsync(HostMemoryVector* data) final { + return CopyAttnAuxVecToCache(data); + } + NDArray CopyTreeAttnMNIndptrAsync(HostMemoryVector* data) final { + return CopyAttnAuxVecToCache(data); } - NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { return CopyVecToCache(data); } - NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { return CopyVecToCache(data); } NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, HostMemoryVector* sliding_window_offset, HostMemoryVector* sink_size, int depth) final { int64_t n_elem = last_page_len->size(); - std::memcpy(merged_aux_data_host_.data() + copy_offset_, last_page_len->data(), - n_elem * elem_byte_size_); - std::memcpy(merged_aux_data_host_.data() + copy_offset_ + n_elem, sliding_window_offset->data(), - n_elem * elem_byte_size_); - std::memcpy(merged_aux_data_host_.data() + copy_offset_ + 2 * n_elem, sink_size->data(), - n_elem * elem_byte_size_); - NDArray view = merged_aux_data_device_->AllocNDArray(copy_offset_ * elem_byte_size_, - {3, n_elem}, dtype_aux_); - copy_offset_ += CeilDivElemAlignment(3 * n_elem); + std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_, + last_page_len->data(), n_elem * elem_byte_size_); + std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ + n_elem, + sliding_window_offset->data(), n_elem * elem_byte_size_); + std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ + 2 * n_elem, + sink_size->data(), n_elem * elem_byte_size_); + NDArray view = merged_attn_aux_data_device_.CreateView( + {3, n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_); + attn_aux_data_copy_offset_ += CeilDivElemAlignment(3 * n_elem); + return view; + } + + void CommitAttnAuxDataCopy() final { + std::vector copy_shape{attn_aux_data_copy_offset_}; + DLTensor copy_dst; + copy_dst.data = merged_attn_aux_data_device_->data; + copy_dst.device = 
device_; + copy_dst.ndim = 1; + copy_dst.dtype = dtype_aux_; + copy_dst.shape = copy_shape.data(); + copy_dst.strides = nullptr; + copy_dst.byte_offset = 0; + + DLTensor copy_src = copy_dst; + copy_src.data = merged_attn_aux_data_host_.data(); + copy_src.device = Device{kDLCPU, 0}; + NDArray::CopyFromTo(©_src, ©_dst, copy_stream_); + } + + void ResetCompactKVAuxDataCopy() final { compact_kv_aux_data_copy_offset_ = 0; } + + NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) final { + return CopyCompactKVAuxVecToCache(data); + } + NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, + HostMemoryVector* dst_data) final { + int64_t n_elem = src_data->size(); + std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_, + src_data->data(), n_elem * elem_byte_size_); + std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_ + n_elem, + dst_data->data(), n_elem * elem_byte_size_); + NDArray view = merged_compact_kv_aux_data_device_.CreateView( + {2, n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * elem_byte_size_); + compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(2 * n_elem); return view; } - void CommitCopy() final { - std::vector copy_shape{copy_offset_}; + void CommitCompactKVAuxDataCopy() final { + std::vector copy_shape{compact_kv_aux_data_copy_offset_}; DLTensor copy_dst; - copy_dst.data = merged_aux_data_device_->buffer.data; + copy_dst.data = merged_compact_kv_aux_data_device_->data; copy_dst.device = device_; copy_dst.ndim = 1; copy_dst.dtype = dtype_aux_; @@ -571,7 +689,7 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { copy_dst.byte_offset = 0; DLTensor copy_src = copy_dst; - copy_src.data = merged_aux_data_host_.data(); + copy_src.data = merged_compact_kv_aux_data_host_.data(); copy_src.device = Device{kDLCPU, 0}; NDArray::CopyFromTo(©_src, ©_dst, copy_stream_); } @@ -581,8 +699,8 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { * \brief Calculate the start element offsets of the auxiliary arrays in the local cache. * \return Return the local cache size (total number of elements in the local cache). */ - int64_t CalculateCacheSize(int64_t reserved_num_seqs, int64_t num_total_pages, - int64_t prefill_chunk_size) { + int64_t CalculateAttnAuxDataCacheSize(int64_t reserved_num_seqs, int64_t num_total_pages, + int64_t prefill_chunk_size) { int64_t cache_size = 0; // - Array size of the arrays that every depth has. 
// Corresponding to the following arrays respectively @@ -604,10 +722,28 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { // - k_ragged_rope_pos_offset // - q_rope_position_map // - append_position_map + // - tree_attn_mask + // - tree_attn_mn_indptr cache_size += CeilDivElemAlignment(reserved_num_seqs + 1); cache_size += CeilDivElemAlignment(reserved_num_seqs); cache_size += CeilDivElemAlignment(prefill_chunk_size); cache_size += CeilDivElemAlignment(prefill_chunk_size); + cache_size += + CeilDivElemAlignment(kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs); + cache_size += CeilDivElemAlignment(reserved_num_seqs + 1); + + return cache_size; + } + + int64_t CalculateCompactKVAuxDataCacheSize(int64_t reserved_num_seqs, + int64_t prefill_chunk_size) { + int64_t cache_size = 0; + // Corresponding to the following arrays respectively + // - commit_copy_length_indptr + // - commit_copy_src_dst_pos_in_page_table + cache_size += CeilDivElemAlignment(reserved_num_seqs + 1); + cache_size += CeilDivElemAlignment( + 2 * std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size)); return cache_size; } @@ -616,13 +752,23 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { * \brief Copy the input data to the cache at the given offset. * And return the NDArray view of the cache starting at the offset. */ - NDArray CopyVecToCache(HostMemoryVector* data) { + NDArray CopyAttnAuxVecToCache(HostMemoryVector* data) { int64_t n_elem = data->size(); - std::memcpy(merged_aux_data_host_.data() + copy_offset_, data->data(), + std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_, data->data(), n_elem * elem_byte_size_); - NDArray view = - merged_aux_data_device_->AllocNDArray(copy_offset_ * elem_byte_size_, {n_elem}, dtype_aux_); - copy_offset_ += CeilDivElemAlignment(n_elem); + NDArray view = merged_attn_aux_data_device_.CreateView( + {n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_); + attn_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem); + return view; + } + + NDArray CopyCompactKVAuxVecToCache(HostMemoryVector* data) { + int64_t n_elem = data->size(); + std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_, + data->data(), n_elem * elem_byte_size_); + NDArray view = merged_compact_kv_aux_data_device_.CreateView( + {n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * elem_byte_size_); + compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem); return view; } @@ -635,9 +781,12 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { const int64_t elem_byte_size_; const int64_t offset_alignment_; - int64_t copy_offset_ = 0; - HostMemoryVector merged_aux_data_host_; - memory::Storage merged_aux_data_device_; + int64_t attn_aux_data_copy_offset_ = 0; + int64_t compact_kv_aux_data_copy_offset_ = 0; + HostMemoryVector merged_attn_aux_data_host_; + HostMemoryVector merged_compact_kv_aux_data_host_; + NDArray merged_attn_aux_data_device_; + NDArray merged_compact_kv_aux_data_device_; }; /*! @@ -726,8 +875,24 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { bool dirty_aux_data_device_ = false; /*! \brief The batch size of the current round of forwarding. */ int64_t cur_batch_size_; + /*! \brief The ids of the sequences in the current round of forwarding. */ + IntTuple cur_seq_ids_; /*! \brief The append lengths of the sequences in the current round of forwarding. */ IntTuple cur_append_lengths_; + /*! 
\brief The token tree parent array of the sequences in the current round of forwarding. */ + IntTuple cur_token_tree_parent_ptr_{nullptr}; + /*! \brief The depth of each node in the token tree, for the sequences in the current batch. */ + std::vector> cur_token_tree_node_depths_; + /*! \brief Whether the current batch of sequences are token chains (not token trees). */ + bool is_chain_; + /*! \brief Number of fork depth in the current round of forward. */ + int num_depths_; + /*! \brief Whether to compute attention after appending KV into cache or not. */ + bool append_before_attn_; + /*! \brief Whether to use decode kernel for each depth. (see GetChunkedBlockIds) */ + std::vector use_decode_kernel_; + /*! \brief Whether the attention request is a decode request, set in BeginForwardFunction. */ + bool is_decode_request_; /*! \brief The auxiliary data manager for attention. */ std::unique_ptr aux_data_manager_; @@ -755,6 +920,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector q_rope_position_map_host_; HostMemoryVector append_position_map_host_; HostMemoryVector cur_append_lengths_indptr_host_; + HostMemoryVector tree_attn_mask_host_; + HostMemoryVector tree_attn_mn_indptr_host_; + HostMemoryVector commit_copy_length_indptr_host_; + HostMemoryVector commit_copy_src_pos_in_page_table_host_; + HostMemoryVector commit_copy_dst_pos_in_page_table_host_; //------------------------------------------- // For efficient memory management, the actual sizes of the arrays @@ -767,6 +937,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray k_ragged_rope_pos_offset_view_; NDArray q_rope_position_map_view_; NDArray append_position_map_view_; + NDArray tree_attn_mask_view_; + NDArray tree_attn_mn_indptr_view_; NDArray temp_attn_output_view_; NDArray temp_attn_scores_view_; NDArray merged_attn_scores_view_; @@ -777,11 +949,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector k_rope_pos_offset_view_; PackedFunc f_transpose_append_; + PackedFunc f_compact_copy_; PackedFunc f_attention_prefill_; PackedFunc f_attention_decode_; PackedFunc f_attention_prefill_sliding_window_; PackedFunc f_attention_decode_sliding_window_; PackedFunc f_attention_prefill_ragged_; + PackedFunc f_attention_prefill_with_tree_mask_; Optional f_attention_prefill_ragged_begin_forward_; Optional f_attention_prefill_ragged_end_forward_; Optional f_attention_prefill_begin_forward_; @@ -793,16 +967,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { PackedFunc f_copy_single_page_; Optional f_debug_get_kv_; - /*! \brief Number of fork depth in the current round of forward. */ - int num_depths_; - /*! \brief Whether to compute attention after appending KV into cache or not. */ - bool append_before_attn_; - /*! \brief Whether to use decode kernel for each depth. (see GetChunkedBlockIds) */ - std::vector use_decode_kernel_; - /*! \brief Whether the attention request is a decode request, set in BeginForwardFunction. */ - bool is_decode_request_; /*! \brief The device this PagedKVCache runs on. */ - DLDevice device_; + Device device_; /*! \brief The device stream for the default computation operations. */ TVMStreamHandle compute_stream_ = nullptr; /*! \brief The device stream for copying auxiliary data structure to GPU. 
*/ @@ -815,10 +981,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta, - DLDataType dtype, DLDevice device, PackedFunc f_transpose_append, + DLDataType dtype, Device device, PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill, PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window, - PackedFunc f_attention_prefill_ragged, + PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask, Optional f_attention_prefill_ragged_begin_forward, Optional f_attention_prefill_ragged_end_forward, Optional f_attention_prefill_begin_forward, @@ -839,11 +1005,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { rotary_scale_(rotary_scale), rotary_theta_(rotary_theta), f_transpose_append_(std::move(f_transpose_append)), + f_compact_copy_(std::move(f_compact_copy)), f_attention_prefill_(std::move(f_attention_prefill)), f_attention_decode_(std::move(f_attention_decode)), f_attention_prefill_sliding_window_(std::move(f_attention_prefill_sliding_window)), f_attention_decode_sliding_window_(std::move(f_attention_decode_sliding_window)), f_attention_prefill_ragged_(std::move(f_attention_prefill_ragged)), + f_attention_prefill_with_tree_mask_(std::move(f_attention_prefill_with_tree_mask)), f_attention_prefill_ragged_begin_forward_( std::move(f_attention_prefill_ragged_begin_forward)), f_attention_prefill_ragged_end_forward_(std::move(f_attention_prefill_ragged_end_forward)), @@ -887,6 +1055,19 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device); cur_append_lengths_indptr_host_ = HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device); + tree_attn_mask_host_ = + HostMemoryVector(kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs, + dtype_aux_, preferred_host_device); + tree_attn_mn_indptr_host_ = + HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device); + commit_copy_length_indptr_host_ = + HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device); + commit_copy_src_pos_in_page_table_host_ = + HostMemoryVector(std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size), + dtype_aux_, preferred_host_device); + commit_copy_dst_pos_in_page_table_host_ = + HostMemoryVector(std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size), + dtype_aux_, preferred_host_device); for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { temp_attn_workspace_.push_back( @@ -1108,6 +1289,42 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } + void CompactKVCopy() { + int total_copy_length = commit_copy_length_indptr_host_.back(); + ICHECK_GE(total_copy_length, 0); + if (total_copy_length == 0) { + return; + } + + // Copy indptr/src/dst arrays to GPU. 
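The arrays staged by CompactKVCopy have simple semantics; below is a pure-Python stand-in for what the per-layer compact-copy kernel is expected to do with them (illustrative only, not the actual TIR kernel): for each sequence i, the entries j in [length_indptr[i], length_indptr[i+1]) name a source and a destination position in the page table, and the KV data at the source is written over the destination.

def compact_copy(pages, length_indptr, src_pos, dst_pos, batch_size):
    # pages: flat list of KV entries indexed by their position in the page table.
    for seq in range(batch_size):
        for j in range(length_indptr[seq], length_indptr[seq + 1]):
            pages[dst_pos[j]] = pages[src_pos[j]]

pages = list("abcdefgh")
# One sequence, moving the KV data at positions 5 and 7 down to positions 3 and 4.
compact_copy(pages, [0, 2], [5, 7], [3, 4], batch_size=1)
print(pages)  # ['a', 'b', 'c', 'f', 'h', 'f', 'g', 'h']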
+ aux_data_manager_->ResetCompactKVAuxDataCopy(); + NDArray commit_copy_length_indptr_view = + aux_data_manager_->CopyCommitLengthIndptrAsync(&commit_copy_length_indptr_host_); + NDArray commit_copy_src_dst_pos_in_page_table_view = + aux_data_manager_->CopyCommitSrcDstPosInPageTableAsync( + &commit_copy_src_pos_in_page_table_host_, &commit_copy_dst_pos_in_page_table_host_); + aux_data_manager_->CommitCompactKVAuxDataCopy(); + + // Invoke the copy kernel on copy stream. + if (copy_stream_ != compute_stream_) { + // Set the copy stream for copy. + DeviceAPI::Get(device_)->SetStream(device_, copy_stream_); + } + ICHECK(f_compact_copy_.defined()) << "Function \"f_compact_copy\" is not defined."; + for (int layer = 0; layer < num_layers_; ++layer) { + f_compact_copy_(pages_[layer], commit_copy_length_indptr_view, + commit_copy_src_dst_pos_in_page_table_view, cur_batch_size_); + } + if (copy_stream_ != compute_stream_) { + // Set the compute stream back. + DeviceAPI::Get(device_)->SetStream(device_, compute_stream_); + } + + // Note: We do not explicitly synchronize the copy stream here. + // The safety is guaranteed by the synchronization pushed by the next round + // of BeginForward, which also copies auxiliary data structure to GPU. + } + void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size, int32_t attn_sink_size) final { CHECK(support_sliding_window_) << "The KV cache does not support sliding window."; @@ -1143,6 +1360,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { CHECK_LE(n, it->second.seq_length) << "The sequence only has length " << it->second.seq_length << ", while the length of pop is " << n << " which exceeds the whole sequence length."; + if (n == 0) { + return; + } + int32_t block_idx = it->second.last_block_idx; // The block should have at least one reference, which comes from the sequence. ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1); @@ -1211,13 +1432,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /************** Attention **************/ - void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths) final { + void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths, + const Optional& opt_token_tree_parent_ptr) final { + CHECK(!cur_token_tree_parent_ptr_.defined()) + << "The last round of forward which involves token tree has not been committed. Please " + "call \"CommitAcceptedTreeNodes\" to commit the accepted tokens."; + CHECK_EQ(seq_ids.size(), append_lengths.size()) << "The seq_ids size (" << seq_ids.size() << ") and append_lengths size (" << append_lengths.size() << ") mismatch."; cur_batch_size_ = seq_ids.size(); + cur_seq_ids_ = seq_ids; cur_append_lengths_ = append_lengths; + // - Check token tree validity and process the token tree. + is_chain_ = true; + tree_attn_mask_host_.clear(); + tree_attn_mn_indptr_host_.clear(); + if (opt_token_tree_parent_ptr.defined()) { + is_chain_ = ConstructTokenTreeMask(opt_token_tree_parent_ptr.value()); + } + // - Collect sequence/block/page information for attention. 
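For intuition, a plain-Python version of the validation that ConstructTokenTreeMask performs on the flattened token_tree_parent_ptr (illustrative only): the array is split by append_lengths, every parent must be -1 or an earlier node of the same tree, node depths follow from the parents, and the batch only counts as a chain when each node's parent is exactly the previous node.

def split_and_check_token_trees(append_lengths, flat_parent_ptr):
    assert len(flat_parent_ptr) == sum(append_lengths)
    trees, depths, is_chain = [], [], True
    pos = 0
    for length in append_lengths:
        parents = flat_parent_ptr[pos:pos + length]
        pos += length
        depth = []
        for n, parent in enumerate(parents):
            assert -1 <= parent < n          # parent must come earlier in the tree (or be the root)
            if parent != n - 1:
                is_chain = False             # any non-chain edge forces the tree-mask kernel
            depth.append(0 if parent == -1 else depth[parent] + 1)
        trees.append(parents)
        depths.append(depth)
    return trees, depths, is_chain

trees, depths, is_chain = split_and_check_token_trees(
    [7, 3], [-1, 0, 0, 1, 1, 2, 2, -1, 0, 1])
print(depths)    # [[0, 1, 1, 2, 2, 2, 2], [0, 1, 2]]
print(is_chain)  # False (the first tree branches)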
std::vector sequences; std::vector last_block_length_before_append; @@ -1322,7 +1557,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int64_t append_length = append_lengths[i]; const Block& block = global_block_pool_[sequences[i]->last_block_idx]; for (int64_t pos = 0; pos < append_length; ++pos) { - q_rope_position_map_host_.push_back(k_ragged_rope_pos_offset_host_[i] + pos); + q_rope_position_map_host_.push_back( + k_ragged_rope_pos_offset_host_[i] + + (is_chain_ ? pos : cur_token_tree_node_depths_[i][pos])); int32_t pos_in_block = block.seq_length - append_length + pos; if (last_block_length_before_append[i] + pos < block.sink_length) { @@ -1412,6 +1649,81 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } + void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) final { + CHECK_NE(cur_batch_size_, -1) + << "Cannot commit accepted token tree nodes since BeginForward is not invoked."; + CHECK_EQ(leaf_indices.size(), cur_batch_size_) + << "The number of input leaf indices does not equal to the current batch size."; + + for (int i = 0; i < cur_batch_size_; ++i) { + CHECK_GE(leaf_indices[i], 0) + << "Invalid tree index " << leaf_indices[i] << " which is negative"; + CHECK_LT(leaf_indices[i], cur_append_lengths_[i]) + << "Invalid tree index " << leaf_indices[i] + << " which is larger than or equals to the append length " << cur_append_lengths_[i] + << " of the sequence"; + } + + if (!is_chain_) { + commit_copy_length_indptr_host_.clear(); + commit_copy_src_pos_in_page_table_host_.clear(); + commit_copy_dst_pos_in_page_table_host_.clear(); + commit_copy_length_indptr_host_.push_back(0); + + for (int i = 0; i < cur_batch_size_; ++i) { + // Get the accepted node path on the token tree. + std::vector path_on_tree; + path_on_tree.reserve(cur_token_tree_node_depths_[i][leaf_indices[i]] + 1); + int node = leaf_indices[i]; + while (node != -1) { + path_on_tree.push_back(node); + node = cur_token_tree_parent_ptr_[cur_append_lengths_indptr_host_[i] + node]; + } + ICHECK_EQ(path_on_tree.size(), cur_token_tree_node_depths_[i][leaf_indices[i]] + 1); + // Get the destination array (range [0, path_length - 1)) of KV cache copy. + std::vector copy_dst_pos_in_seq; + copy_dst_pos_in_seq.resize(path_on_tree.size()); + std::iota(copy_dst_pos_in_seq.rbegin(), copy_dst_pos_in_seq.rend(), /*value=*/0); + // Remove the positions whose KV data do not need copy. + while (!path_on_tree.empty() && path_on_tree.back() == copy_dst_pos_in_seq.back()) { + path_on_tree.pop_back(); + copy_dst_pos_in_seq.pop_back(); + } + // Reverse the position arrays so that they are in ascending order. + std::reverse(path_on_tree.begin(), path_on_tree.end()); + std::reverse(copy_dst_pos_in_seq.begin(), copy_dst_pos_in_seq.end()); + + // Convert the in-sequence src/dst positions to src/dst positions in page table + // by looking up "append_position_map". + for (int p = 0; p < static_cast(path_on_tree.size()); ++p) { + commit_copy_src_pos_in_page_table_host_.push_back( + append_position_map_host_[cur_append_lengths_indptr_host_[i] + path_on_tree[p]]); + commit_copy_dst_pos_in_page_table_host_.push_back( + append_position_map_host_[cur_append_lengths_indptr_host_[i] + + copy_dst_pos_in_seq[p]]); + } + commit_copy_length_indptr_host_.push_back(commit_copy_length_indptr_host_.back() + + path_on_tree.size()); + } + + // Compact the KV data for each sequence by copying KV data. + CompactKVCopy(); + } + + // - Update the KV cache page data structure. 
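A small Python sketch of the per-sequence bookkeeping in the commit step above (a hypothetical helper mirroring the logic rather than the exact code): walk from the accepted leaf to the root to get the kept nodes, compact them into slots 0..depth(leaf), drop src/dst pairs whose data is already in place (the runtime only trims the leading run of such pairs, which is equivalent since the extra copies are no-ops), and pop everything past the accepted path.

def plan_commit(parents, depths, leaf):
    # Walk leaf -> root to collect the accepted nodes, then reverse to ascending order.
    path, node = [], leaf
    while node != -1:
        path.append(node)
        node = parents[node]
    path.reverse()
    dst = list(range(len(path)))             # accepted nodes are compacted to slots 0..k-1
    moves = [(s, d) for s, d in zip(path, dst) if s != d]
    length_to_pop = len(parents) - depths[leaf] - 1
    return moves, length_to_pop

parents = [-1, 0, 0, 1, 1, 2, 2]             # complete binary tree of height 3
depths = [0, 1, 1, 2, 2, 2, 2]
print(plan_commit(parents, depths, leaf=6))  # ([(2, 1), (6, 2)], 4)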
+ // Note: Function "PopN" only changes the page table structure and does not + // change the KV cache data. Therefore, we can directly use it, since + // we have already launched all copies. + for (int i = 0; i < cur_batch_size_; ++i) { + int64_t length_to_pop = + cur_append_lengths_[i] - cur_token_tree_node_depths_[i][leaf_indices[i]] - 1; + PopN(cur_seq_ids_[i], length_to_pop); + } + + // Reset the token tree. + cur_token_tree_parent_ptr_ = IntTuple{nullptr}; + } + NDArray GetQueryPositions() final { // Sync the copy stream and the compute stream. ComputeStreamWaitForCopyStream(); @@ -1502,6 +1814,73 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { return block_idx; } + bool ConstructTokenTreeMask(const IntTuple& token_tree_parent_ptr) { + // We check if the token tree deteriorates to a chain, + // because chain cases can have simplified attention work flow. + bool is_chain = true; + cur_token_tree_parent_ptr_ = token_tree_parent_ptr; + cur_token_tree_node_depths_.clear(); + cur_token_tree_node_depths_.reserve(cur_batch_size_); + + int64_t sum_append_length = 0; + // - Construct the mn indptr array, which is the indptr of the mask size of each sequence. + tree_attn_mn_indptr_host_.push_back(0); + for (int64_t append_length : cur_append_lengths_) { + sum_append_length += append_length; + tree_attn_mn_indptr_host_.push_back(tree_attn_mn_indptr_host_.back() + + static_cast(append_length * append_length)); + } + CHECK_EQ(token_tree_parent_ptr.size(), sum_append_length) + << "Invalid token tree size. The sum of \"append_lengths\" is " << sum_append_length + << " while there are " << token_tree_parent_ptr.size() + << " elements in \"token_tree_parent_ptr\"."; + + // - Construct the mask of each sequence. + int processed_pos = 0; + for (int i = 0; i < cur_batch_size_; ++i) { + int64_t append_length = cur_append_lengths_[i]; + std::vector> mask; + std::vector depth; + mask.reserve(append_length); + depth.reserve(append_length); + for (int64_t n = 0; n < append_length; ++n) { + CHECK_LT(token_tree_parent_ptr[processed_pos], n) + << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " + << token_tree_parent_ptr[processed_pos] << ", which is not smaller than " << n; + CHECK_GE(token_tree_parent_ptr[processed_pos], -1) + << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " + << token_tree_parent_ptr[processed_pos]; + if (token_tree_parent_ptr[processed_pos] != n - 1) { + // The parent of the current node is not the last node. + // Therefore the tree is not a chain. + is_chain = false; + } + + std::vector single_pos_mask; + if (token_tree_parent_ptr[processed_pos] != -1) { + // The current node has a parent in the token tree. + single_pos_mask = {mask[token_tree_parent_ptr[processed_pos]].begin(), + mask[token_tree_parent_ptr[processed_pos]].end()}; + depth.push_back(depth[token_tree_parent_ptr[processed_pos]] + 1); + } else { + // The current node is root in the token tree. + single_pos_mask.resize(append_length, /*value=*/0); + depth.push_back(0); + } + single_pos_mask[n] = 1; + mask.push_back(single_pos_mask); + for (int32_t mask_val : single_pos_mask) { + tree_attn_mask_host_.push_back(mask_val); + } + + ++processed_pos; + } + cur_token_tree_node_depths_.push_back(std::move(depth)); + } + + return is_chain; + } + /*! * \brief Slide the KV cache window of the given sequence when * it has sliding window enabled. 
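The buffers produced by ConstructTokenTreeMask have a simple layout; a short Python sketch for reference (illustrative, mirroring the loop above): tree_attn_mn_indptr grows by append_length squared per sequence, and the slice of tree_attn_mask belonging to a sequence is its append_length x append_length mask stored row by row, with row n marking node n and all of its ancestors.

def build_tree_attn_aux(token_trees):
    # token_trees: one parent-index array per sequence; the parent of a root is -1.
    mn_indptr, mask = [0], []
    for parents in token_trees:
        n = len(parents)
        rows = []
        for node, parent in enumerate(parents):
            row = list(rows[parent]) if parent != -1 else [0] * n
            row[node] = 1                    # a node attends to itself and its ancestors
            rows.append(row)
            mask.extend(row)
        mn_indptr.append(mn_indptr[-1] + n * n)
    return mn_indptr, mask

print(build_tree_attn_aux([[-1, 0, 0]]))
# ([0, 9], [1, 0, 0, 1, 1, 0, 1, 0, 1])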
@@ -1766,12 +2145,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { attn_score_scaling_factor); } else { // Compute appended text self-attention - f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data, - cur_append_length_indptr_view_, q_rope_position_map_view_, - k_ragged_rope_pos_offset_view_, output, merged_attn_scores_view_, - /*causal=*/1, - /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, - rotary_theta_, attn_score_scaling_factor); + if (is_chain_) { + // If the batch does not form a tree, use raggedness prefill kernel. + f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data, + cur_append_length_indptr_view_, q_rope_position_map_view_, + k_ragged_rope_pos_offset_view_, output, + merged_attn_scores_view_, + /*causal=*/1, + /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, + rotary_theta_, attn_score_scaling_factor); + } else { + // The batch requires tree attention. + ICHECK(tree_attn_mask_view_.defined()); + ICHECK(tree_attn_mn_indptr_view_.defined()); + ICHECK(f_attention_prefill_with_tree_mask_.defined()) + << "Function \"f_attention_prefill_with_tree_mask_\" is not defined."; + f_attention_prefill_with_tree_mask_( + q_data, cur_append_length_indptr_view_, k_data, v_data, cur_append_length_indptr_view_, + q_rope_position_map_view_, tree_attn_mn_indptr_view_, tree_attn_mask_view_, output, + merged_attn_scores_view_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, + rotary_scale_, rotary_theta_, attn_score_scaling_factor, cur_batch_size_); + } for (int d = 0; d < num_depths_; ++d) { if (page_indices_on_depths_view_[d]->shape[0] == 0) { @@ -1840,7 +2234,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { ICHECK_EQ(total_append_length, append_position_map_host_.size()); // - Reset the copy. - aux_data_manager_->ResetCopy(); + aux_data_manager_->ResetAttnAuxDataCopy(); // 1. q_rope_position_map // q_rope_position_map has to be synced first so that it has a 0 byte offset @@ -1900,7 +2294,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // 9. append_position_map append_position_map_view_ = aux_data_manager_->CopyAppendPositionMapAsync(&append_position_map_host_); - // 10. Create view for temporary arrays for attention computation. + // 10. tree_attn_mask and tree_attn_mn_indptr + if (!is_chain_) { + tree_attn_mask_view_ = aux_data_manager_->CopyTreeAttnMaskAsync(&tree_attn_mask_host_); + tree_attn_mn_indptr_view_ = + aux_data_manager_->CopyTreeAttnMNIndptrAsync(&tree_attn_mn_indptr_host_); + } else { + tree_attn_mask_view_ = NDArray{nullptr}; + tree_attn_mn_indptr_view_ = NDArray{nullptr}; + } + // 11. Create view for temporary arrays for attention computation. temp_attn_output_view_ = temp_attn_output_device_.CreateView( {total_append_length, num_qo_heads_, head_dim_}, temp_attn_output_device_->dtype); temp_attn_scores_view_ = temp_attn_scores_device_.CreateView( @@ -1909,7 +2312,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { {total_append_length, num_qo_heads_}, merged_attn_scores_device_->dtype); // - Commit the copy. - aux_data_manager_->CommitCopy(); + aux_data_manager_->CommitAttnAuxDataCopy(); // - Reset the dirty flag to false. 
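To make the chain-versus-tree dispatch above concrete, here is a plain-Python statement of the masking rule the tree kernel enforces (illustrative; the unit test checks the same rule with numpy): an appended query token can attend to every position already in the cache, and within the freshly appended block only to itself and its ancestors in the token tree; when every tree is a chain this reduces to the usual causal mask, which is why the ragged prefill kernel can still be used.

def allowed(q_node, kv_pos, num_cached, parents):
    # kv positions [0, num_cached) are tokens already in the cache;
    # kv positions [num_cached, num_cached + len(parents)) are the appended tree nodes.
    if kv_pos < num_cached:
        return True                          # history is always visible
    kv_node = kv_pos - num_cached
    node = q_node
    while node != -1:                        # visible iff kv_node is q_node or an ancestor of it
        if node == kv_node:
            return True
        node = parents[node]
    return False

# A chain [-1, 0, 1] gives the ordinary causal mask over the appended block.
print([[allowed(q, kv, 2, [-1, 0, 1]) for kv in range(5)] for q in range(3)])
# In the branching tree [-1, 0, 0], the two sibling nodes cannot see each other.
print(allowed(1, 2 + 2, 2, [-1, 0, 0]), allowed(2, 2 + 1, 2, [-1, 0, 0]))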
dirty_aux_data_device_ = false; } @@ -1922,21 +2325,44 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj); //------------------------------------------------- TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") - .set_body_typed([](ShapeTuple cache_config, int64_t num_layers, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t head_dim, int rope_mode, double rotary_scale, - double rotary_theta, NDArray init, PackedFunc f_transpose_append, - PackedFunc f_attention_prefill, PackedFunc f_attention_decode, - PackedFunc f_attention_prefill_sliding_window, // - PackedFunc f_attention_decode_sliding_window, - PackedFunc f_attention_prefill_ragged, - PackedFunc f_attention_prefill_ragged_begin_forward, - PackedFunc f_attention_prefill_ragged_end_forward, - PackedFunc f_attention_prefill_begin_forward, - PackedFunc f_attention_prefill_end_forward, - PackedFunc f_attention_decode_begin_forward, - PackedFunc f_attention_decode_end_forward, PackedFunc f_merge_inplace, - PackedFunc f_split_rotary, PackedFunc f_copy_single_page, - Optional f_debug_get_kv) { + .set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27) + << "Invalid number of KV cache constructor args."; + ShapeTuple cache_config = args[0]; + int64_t num_layers = args[1]; + int64_t num_qo_heads = args[2]; + int64_t num_kv_heads = args[3]; + int64_t head_dim = args[4]; + int rope_mode = args[5]; + double rotary_scale = args[6]; + double rotary_theta = args[7]; + NDArray init = args[8]; + PackedFunc f_transpose_append = args[9]; + PackedFunc f_attention_prefill = args[10]; + PackedFunc f_attention_decode = args[11]; + PackedFunc f_attention_prefill_sliding_window = args[12]; + PackedFunc f_attention_decode_sliding_window = args[13]; + PackedFunc f_attention_prefill_ragged = args[14]; + PackedFunc f_attention_prefill_ragged_begin_forward = args[15]; + PackedFunc f_attention_prefill_ragged_end_forward = args[16]; + PackedFunc f_attention_prefill_begin_forward = args[17]; + PackedFunc f_attention_prefill_end_forward = args[18]; + PackedFunc f_attention_decode_begin_forward = args[19]; + PackedFunc f_attention_decode_end_forward = args[20]; + PackedFunc f_merge_inplace = args[21]; + PackedFunc f_split_rotary = args[22]; + PackedFunc f_copy_single_page = args[23]; + Optional f_debug_get_kv = args[24]; + PackedFunc f_compact_copy{nullptr}; + PackedFunc f_attention_prefill_with_tree_mask{nullptr}; + + if (args.size() >= 26) { + f_compact_copy = args[25].AsObjectRef(); + } + if (args.size() >= 27) { + f_attention_prefill_with_tree_mask = args[26].AsObjectRef(); + } + CHECK_EQ(cache_config.size(), 5); int64_t reserved_num_seqs = cache_config[0]; int64_t total_token_capacity = cache_config[1]; @@ -1952,28 +2378,52 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append), - std::move(f_attention_prefill), std::move(f_attention_decode), + std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), + std::move(f_attention_prefill_with_tree_mask), std::move(f_attention_prefill_ragged_begin_forward), std::move(f_attention_prefill_ragged_end_forward), 
std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward), std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward), std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page), std::move(f_debug_get_kv)); - return AttentionKVCache(std::move(n)); + *rv = AttentionKVCache(std::move(n)); }); TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") - .set_body_typed([](ShapeTuple cache_config, int64_t num_layers, int64_t num_qo_heads, - int64_t num_kv_heads, int64_t head_dim, int rope_mode, double rotary_scale, - double rotary_theta, NDArray init, PackedFunc f_transpose_append, - PackedFunc f_attention_prefill, PackedFunc f_attention_decode, - PackedFunc f_attention_prefill_sliding_window, - PackedFunc f_attention_decode_sliding_window, - PackedFunc f_attention_prefill_ragged, PackedFunc f_merge_inplace, - PackedFunc f_split_rotary, PackedFunc f_copy_single_page, - Optional f_debug_get_kv) { + .set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21) + << "Invalid number of KV cache constructor args."; + ShapeTuple cache_config = args[0]; + int64_t num_layers = args[1]; + int64_t num_qo_heads = args[2]; + int64_t num_kv_heads = args[3]; + int64_t head_dim = args[4]; + int rope_mode = args[5]; + double rotary_scale = args[6]; + double rotary_theta = args[7]; + NDArray init = args[8]; + PackedFunc f_transpose_append = args[9]; + PackedFunc f_attention_prefill = args[10]; + PackedFunc f_attention_decode = args[11]; + PackedFunc f_attention_prefill_sliding_window = args[12]; + PackedFunc f_attention_decode_sliding_window = args[13]; + PackedFunc f_attention_prefill_ragged = args[14]; + PackedFunc f_merge_inplace = args[15]; + PackedFunc f_split_rotary = args[16]; + PackedFunc f_copy_single_page = args[17]; + Optional f_debug_get_kv = args[18]; + PackedFunc f_compact_copy{nullptr}; + PackedFunc f_attention_prefill_with_tree_mask{nullptr}; + + if (args.size() >= 20) { + f_compact_copy = args[19].AsObjectRef(); + } + if (args.size() >= 21) { + f_attention_prefill_with_tree_mask = args[20].AsObjectRef(); + } + CHECK_EQ(cache_config.size(), 5); int64_t reserved_num_seqs = cache_config[0]; int64_t total_token_capacity = cache_config[1]; @@ -1989,13 +2439,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append), - std::move(f_attention_prefill), std::move(f_attention_decode), + std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), - std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), // - NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // + std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), + std::move(f_attention_prefill_with_tree_mask), // + NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page), std::move(f_debug_get_kv)); - return AttentionKVCache(std::move(n)); + *rv = AttentionKVCache(std::move(n)); }); } // namespace relax_vm diff --git a/src/runtime/relax_vm/rnn_state.cc b/src/runtime/relax_vm/rnn_state.cc index 
69225d6b2c47..16fe6791b88d 100644 --- a/src/runtime/relax_vm/rnn_state.cc +++ b/src/runtime/relax_vm/rnn_state.cc @@ -205,10 +205,24 @@ class RNNStateImpObj : public RNNStateObj { /************** Interaction **************/ - void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths) { + void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths, + const Optional& opt_token_tree_parent_ptr) final { CHECK_EQ(seq_ids.size(), append_lengths.size()) << "The seq_ids size (" << seq_ids.size() << ") and append_lengths size (" << append_lengths.size() << ") mismatch."; + + if (opt_token_tree_parent_ptr.defined()) { + IntTuple token_tree_parent_ptr = opt_token_tree_parent_ptr.value(); + int matched_pos = 0; + for (int64_t append_length : append_lengths) { + for (int64_t i = 0; i < append_length; ++i) { + CHECK_EQ(token_tree_parent_ptr[matched_pos], i - 1) + << "Unexpected token tree for RNN state. RNN state only supports chains as token " + "trees."; + ++matched_pos; + } + } + } cur_batch_size_ = seq_ids.size(); cur_append_lengths_ = append_lengths; cur_seq_ids_ = seq_ids; diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 6504175b5680..0a69d184e5a9 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -53,6 +53,7 @@ fpopn = None fbegin_forward = None fend_forward = None +fcommit_accepted_token_tree_nodes = None fattention_with_fuse_qkv = None fis_empty = None fdebug_get_kv = None @@ -64,18 +65,22 @@ fattn_prefill_sliding_window = None fattn_decode_sliding_window = None fattn_prefill_ragged = None +fattn_prefill_with_tree_mask = None fmerge_state = None fsplit_rotary = None fattention_rotary = None fcopy_single_page = None +fcompact_copy = None def set_global_func(head_dim, dtype): global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq - global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fis_empty, fdebug_get_kv - global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode, fattn_prefill_ragged + global fpopn, fbegin_forward, fend_forward, fcommit_accepted_token_tree_nodes + global fattention_with_fuse_qkv, fis_empty, fdebug_get_kv + global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode + global fattn_prefill_ragged, fattn_prefill_with_tree_mask global fattn_prefill_sliding_window, fattn_decode_sliding_window - global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page + global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page, fcompact_copy fclear = tvm.get_global_func("vm.builtin.kv_state_clear") fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence") @@ -87,6 +92,9 @@ def set_global_func(head_dim, dtype): fpopn = tvm.get_global_func("vm.builtin.kv_state_popn") fbegin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward") fend_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward") + fcommit_accepted_token_tree_nodes = tvm.get_global_func( + "vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes" + ) fattention_with_fuse_qkv = tvm.get_global_func( "vm.builtin.attention_kv_cache_attention_with_fused_qkv" ) @@ -103,11 +111,13 @@ def set_global_func(head_dim, dtype): _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, True, target), _attention_decode(num_kv_heads, num_qo_heads, 
head_dim, dtype, True, target), _attention_prefill_ragged(num_kv_heads, num_qo_heads, head_dim, dtype, target), + _attention_prefill_with_tree_mask(num_kv_heads, num_qo_heads, head_dim, dtype, target), _merge_state_inplace(num_qo_heads, head_dim, dtype, target), llama_rope_with_position_map( rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype ), _copy_single_page(num_kv_heads, page_size, head_dim, dtype, target), + _compact_kv_copy(num_kv_heads, head_dim, dtype, target), ]: mod = tvm.IRModule({"main": tir_func}) with target: @@ -123,9 +133,11 @@ def set_global_func(head_dim, dtype): fattn_prefill_sliding_window, fattn_decode_sliding_window, fattn_prefill_ragged, + fattn_prefill_with_tree_mask, fmerge_state, fsplit_rotary, fcopy_single_page, + fcompact_copy, ) = builts @@ -159,6 +171,8 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): fsplit_rotary, fcopy_single_page, fcopy_cache, + fcompact_copy, + fattn_prefill_with_tree_mask, ) return cache @@ -211,7 +225,7 @@ def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v): tvm.testing.assert_allclose(values.numpy(), values_expected, rtol=1e-3, atol=1e-3) -def f_apply_rotary(x, offset, scale, theta): +def f_apply_rotary(x, offset, scale, theta, offset_list: Optional[List[int]] = None): # x: (N, H, D) assert len(x.shape) == 3 nfeat = x.shape[-1] @@ -220,7 +234,11 @@ def f_apply_rotary(x, offset, scale, theta): y = np.concatenate([-x[:, :, nfeat_half:], x[:, :, :nfeat_half]], axis=-1) inv_freq = scale / (theta ** (np.arange(0, nfeat, 2).astype("float32") / nfeat)) - t = np.arange(offset, offset + x.shape[0], dtype=inv_freq.dtype) + t = ( + np.arange(offset, offset + x.shape[0], dtype=inv_freq.dtype) + if offset_list is None + else (np.array(offset_list, dtype=inv_freq.dtype) + offset) + ) freqs = np.einsum("i,j->ij", t, inv_freq) emb = np.concatenate((freqs, freqs), axis=-1) cos_values = np.cos(emb) @@ -237,6 +255,8 @@ def apply_attention( cached_v: Dict[int, np.ndarray], sliding_window_sizes: Optional[List[int]] = None, attn_sink_sizes: Optional[List[int]] = None, + token_tree_parent_ptr_list: Optional[List[List[int]]] = None, + accepted_leaf_indices: Optional[List[int]] = None, ) -> None: seq_ids = [] append_lengths = [] @@ -263,14 +283,42 @@ def apply_attention( cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) - fbegin_forward(kv_cache, ShapeTuple(seq_ids), ShapeTuple(append_lengths)) + assert (token_tree_parent_ptr_list is None) == (accepted_leaf_indices is None) + flattened_token_tree_parent_ptr = None + token_tree_node_depths_list: List[Optional[List[int]]] = [None for _ in batch] + if token_tree_parent_ptr_list: + assert len(token_tree_node_depths_list) == len(seq_ids) + assert len(accepted_leaf_indices) == len(seq_ids) + flattened_token_tree_parent_ptr = [] + for i, (token_tree_parent_ptr, append_length) in enumerate( + zip(token_tree_parent_ptr_list, append_lengths) + ): + assert len(token_tree_parent_ptr) == append_length + flattened_token_tree_parent_ptr += token_tree_parent_ptr + token_tree_node_depths = [] + for parent in token_tree_parent_ptr: + token_tree_node_depths.append( + 0 if parent == -1 else token_tree_node_depths[parent] + 1 + ) + token_tree_node_depths_list[i] = token_tree_node_depths + + fbegin_forward( + kv_cache, + ShapeTuple(seq_ids), + ShapeTuple(append_lengths), + ( + ShapeTuple(flattened_token_tree_parent_ptr) + if flattened_token_tree_parent_ptr is not None + else 
None + ), + ) global_new_q = np.zeros((num_layers, 0, num_qo_heads, head_dim), dtype) global_new_k = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) global_new_v = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) q_array = [] - for seq_id, append_length in batch: + for i, (seq_id, append_length) in enumerate(batch): new_q = np.random.rand(num_layers, append_length, num_qo_heads, head_dim).astype(dtype) new_k = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype) new_v = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype) @@ -285,7 +333,11 @@ def apply_attention( new_k[l] if rope_mode != RopeMode.NORMAL else f_apply_rotary( - new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta + new_k[l], + cached_k[seq_id].shape[1], + rope_scale, + rope_theta, + token_tree_node_depths_list[i], ) ) for l in range(num_layers) @@ -323,12 +375,26 @@ def apply_attention( rope_offset, rope_scale, rope_theta, + token_tree_node_depths_list[i], ) ).transpose(1, 0, 2) k_seq = ( cached_k[seq_id][layer_id] if rope_mode != RopeMode.INLINE - else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta) + else f_apply_rotary( + cached_k[seq_id][layer_id], + 0, + rope_scale, + rope_theta, + ( + ( + list(range(rope_offset)) + + [depth + rope_offset for depth in token_tree_node_depths_list[i]] + ) + if token_tree_node_depths_list[i] is not None + else None + ), + ) ).transpose(1, 2, 0) v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2) @@ -336,11 +402,23 @@ def apply_attention( v_seq = np.repeat(v_seq, num_qo_heads // num_kv_heads, axis=0) softmax_input = (q_seq.astype("float32") @ k_seq.astype("float32")) / np.sqrt(head_dim) softmax_shape = softmax_input.shape + assert softmax_shape[-2] == append_length length_diff = softmax_shape[-1] - softmax_shape[-2] assert length_diff >= 0 mask = np.tril( np.full_like(softmax_input, np.finfo("float32").max), k=length_diff ) + np.triu(np.full_like(softmax_input, np.finfo("float32").min), k=length_diff + 1) + if token_tree_parent_ptr_list is not None: + tree_mask = np.full( + (append_length, append_length), np.finfo("float32").min, dtype="float32" + ) + for i, parent in enumerate(token_tree_parent_ptr_list[i]): + if parent != -1: + tree_mask[i] = tree_mask[parent] + tree_mask[i, i] = np.finfo("float32").max + tree_mask = np.broadcast_to(tree_mask, (num_qo_heads, *tree_mask.shape)) + mask[:, :, length_diff:] = tree_mask + softmax_input = np.minimum(softmax_input, mask) results = np.expand_dims( @@ -359,6 +437,32 @@ def apply_attention( sum_length += append_length fend_forward(kv_cache) + if accepted_leaf_indices is not None: + fcommit_accepted_token_tree_nodes(kv_cache, ShapeTuple(accepted_leaf_indices)) + for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate( + zip(accepted_leaf_indices, batch) + ): + tree_path = [] + node = accepted_leaf_idx + while node != -1: + tree_path.append(node) + node = token_tree_parent_ptr_list[i][node] + offset = cached_k[seq_id].shape[1] - append_length + length_to_pop = append_length - len(tree_path) + assert 0 <= length_to_pop < append_length + for dst_pos, src_pos in enumerate(reversed(tree_path)): + if dst_pos == src_pos: + continue + cached_k[seq_id][:, offset + dst_pos, ...] = cached_k[seq_id][ + :, offset + src_pos, ... + ] + cached_v[seq_id][:, offset + dst_pos, ...] = cached_v[seq_id][ + :, offset + src_pos, ... + ] + if length_to_pop > 0: + cached_k[seq_id] = cached_k[seq_id][:, :-length_to_pop, ...] 
+                cached_v[seq_id] = cached_v[seq_id][:, :-length_to_pop, ...]
+
     for seq_id, _ in batch:
         if sliding_window_sizes is not None and len(sliding_window_sizes) > seq_id:
             sliding_window_size = sliding_window_sizes[seq_id]
@@ -618,6 +722,64 @@ def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
     )


+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if support_sliding_window and rope_mode == RopeMode.NORMAL:
+        # Normal RoPE mode under sliding window settings is not supported.
+        return
+    fclear(kv_cache)
+
+    cached_k = {}
+    cached_v = {}
+    # Prefill 4 sequences
+    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v)
+    # Tree attention
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [(0, 7), (1, 15), (2, 10), (3, 14)],
+        cached_k,
+        cached_v,
+        token_tree_parent_ptr_list=[
+            [-1, 0, 0, 1, 1, 2, 2],  # complete binary tree of height 3
+            [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6],  # complete binary tree of height 4
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8],  # chain of length 10
+            [-1, 0, 0, 1, 1, 2, 2, -1, 7, 7, 8, 8, 9, 9],  # two complete binary trees of height 3
+        ],
+        accepted_leaf_indices=[6, 11, 6, 13],
+    )
+    # Do 5 rounds of decode.
+    for _ in range(5):
+        apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v)
+
+    # Test the cases where all trees are chains.
+    fclear(kv_cache)
+    cached_k = {}
+    cached_v = {}
+    # Prefill 4 sequences
+    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v)
+    # Tree attention
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [(0, 7), (1, 15), (2, 10), (3, 14)],
+        cached_k,
+        cached_v,
+        token_tree_parent_ptr_list=[
+            [-1, 0, 1, 2, 3, 4, 5],  # chain of length 7
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],  # chain of length 15
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8],  # chain of length 10
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],  # chain of length 14
+        ],
+        accepted_leaf_indices=[2, 6, 6, 4],
+    )
+    # Do 5 rounds of decode.
+ for _ in range(5): + apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v) + + def kv_cache_transpose_append(head_dim, dtype): # undefined vars used @T.prim_func(check_well_formed=False) @@ -1843,6 +2005,336 @@ def apply_to_md(sch, block): return sch.mod["main"].with_attr("tir.is_scheduled", 1) +def _tree_mask(row, col, mask_ptr, offset, stride, kv_len): + return tir.all(col < kv_len, mask_ptr[offset + row * stride + col] == 1) + + +def _attention_prefill_with_tree_mask( + h_kv, h_q, d, dtype, target: Target +): # pylint: disable=unused-argument + # pylint: disable=invalid-name,line-too-long + NUM_BLKS = 16 + LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + bdx = 32 + num_warps = 4 + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + L_per_cta = tile_x // group_size + + # Otherwise we would exceed maxComputeWorkgroupStorageSize + if ( + str(target.kind) == "webgpu" + and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 + ): + tile_z = 8 + num_warps = 2 + + # fmt: off + @T.prim_func + def batch_tree_attn( # pylint: disable=too-many-branches + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_k: T.handle, # [total_len, h_kv, d] + var_v: T.handle, # [total_len, h_kv, d] + var_kv_indptr: T.handle, # [batch_size + 1], kv_indptr should be the same as q_indptr in this case + var_q_rope_position: T.handle, # [total_q_len] + var_mn_indptr: T.handle, # [batch_size + 1] + var_mask: T.handle, # [mn_indptr[batch_size]] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + batch_size: T.int32, + ): + qo_len = T.int32(is_size_var=True) + kv_len = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + kv_indptr_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + mn_indptr_elem_offset = T.int32(is_size_var=True) + mask_elem_offset = T.int32(is_size_var=True) + tree_size = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) + k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) + v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) + kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset) + q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset) + mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset) + mask = T.match_buffer(var_mask, (tree_size,), "int32", elem_offset=mask_elem_offset) + output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable + + # kernel code + for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): + for lby in T.thread_binding(h_kv, thread="blockIdx.y"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = _var("int32") + batch_idx = _var("int32") + batch_tiles = 
_var("int32") + batch_rows = _var("int32") + iterator = _var("int32") + kv_chunk_len = _var("int32") + + Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") + K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + + m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + + ## get tile_no, batch_idx, batch_tiles, batch_rows + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + # advance to next tile + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] -= batch_tiles[0] + batch_idx[0] += 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta + H_qo_start: T.int32 = by * group_size + + kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] + T.tvm_storage_sync("shared") + + # init states + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 + + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = L_start + i // group_size + cur_H_qo = H_qo_start + i % group_size + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), + q[cur_L, cur_H_qo, j] + ) + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): + L_kv_start: T.int32 = iterator * tile_z + L_kv_base: T.int32 = kv_indptr[b_idx] + for lz, ly in T.grid(tile_z, tile_y): + with T.block("KV_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_base + L_kv_start + i + if L_kv_start + i < kv_chunk_len[0]: + K_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(k, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, by, j), dtype), + k[cur_L, by, j] + ) + V_smem[i, j] = v[cur_L, by, j] + else: + K_smem[i, j] = 0.0 + V_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] 
+= T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale + T.tvm_storage_sync("shared") + for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + for j in T.serial(tile_z): + if _tree_mask(row=tile_id[0] * L_per_cta + row // group_size, + col=L_kv_start + j, + mask_ptr=mask, + offset=mn_indptr[b_idx], + stride=q_indptr[b_idx + 1] - q_indptr[b_idx], + kv_len=kv_chunk_len[0]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + with T.block("update"): + for j in T.serial(tile_z): + # this is to avoid sync inside condition branch + if row < tile_x: + if _tree_mask(row=tile_id[0] * L_per_cta + row // group_size, + col=L_kv_start + j, + mask_ptr=mask, + offset=mn_indptr[b_idx], + stride=q_indptr[b_idx + 1] - q_indptr[b_idx], + kv_len=kv_chunk_len[0]): + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_y, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) + O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32") + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + if L_start + i // group_size < q_indptr[b_idx + 1]: + output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i] + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + if L_start + i // group_size < q_indptr[b_idx + 1]: + lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i]) + + # move to next tile + tile_id[0] += NUM_BLKS + # fmt: on + # pylint: enable=line-too-long,invalid-name,too-many-branches + sch = tir.Schedule(batch_tree_attn) + + def get_tile_size(x, y, t): + cnt = (x * y) // t + assert (x * y) % t == 0 + tile_y = (int)(math.ceil(math.sqrt(cnt))) + while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + tile_y += 1 + assert tile_y <= cnt + tile_x = cnt // tile_y + return tile_x, tile_y + + def apply_to_qkv_load(sch: tir.Schedule, block): + loop_x, loop_y = sch.get_loops(block)[-2:] + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def apply_to_so_ewise(sch: tir.Schedule, block, tile): + loop_x, loop_y = sch.get_loops(block)[-2:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = 
sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def apply_to_gemm( # pylint: disable=unused-argument + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + ): + loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + ko, ki = sch.split(loop_z, factors=[None, r_len]) + if k_major: + sch.reorder(ko, xi, yi, ki) + else: + sch.reorder(ko, ki, xi, yi) + sch.decompose_reduction(block, ty) + + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) + tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) + apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_block("Q_load")) + apply_to_qkv_load(sch, sch.get_block("KV_load")) + + apply_to_md(sch, sch.get_block("lse_store")) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + def _merge_state_inplace( num_heads, head_dim, v_dtype, target: Target ): # pylint: disable=unused-argument @@ -1960,6 +2452,56 @@ def copy_single_page( return copy_single_page +def _compact_kv_copy(num_heads, head_dim, dtype, target: Target): + tx = 256 if str(target.kind) == "webgpu" else 1024 + + @T.prim_func + def compact_kv_copy( + var_pages: T.handle, + var_copy_length_indptr: T.handle, + var_copy_src_dst_pos: T.handle, + batch_size: T.int32, + ): + T.func_attr({"tir.is_scheduled": 1}) + num_pages = T.int32() + total_copy_length = T.int32() + copy_length_indptr_elem_offset = T.int32() + copy_src_dst_pos_elem_offset = T.int32() + pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype) + copy_length_indptr = T.match_buffer( + var_copy_length_indptr, + (batch_size + 1,), + "int32", + elem_offset=copy_length_indptr_elem_offset, + ) + copy_src_dst_pos = T.match_buffer( + var_copy_src_dst_pos, + (2, total_copy_length), + "int32", + elem_offset=copy_src_dst_pos_elem_offset, + ) + + for bhd_o in T.thread_binding( + (batch_size * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" + ): + for bhd_i in T.thread_binding(tx, thread="threadIdx.x"): + b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim) + h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads + d: T.int32 = (bhd_o * tx + bhd_i) % head_dim + if (bhd_o * tx + bhd_i) < batch_size * num_heads * head_dim: + for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]): + src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i] + dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i] + pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[ + src_pos // 16, 0, h, src_pos % 16, d + ] + pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[ + src_pos // 16, 1, h, src_pos % 16, d + ] + + return compact_kv_copy + + if __name__ == 
"__main__": HEAD_DIMS = [64, 128] DTYPES = ["float16", "float32"] @@ -1976,3 +2518,4 @@ def copy_single_page( test_paged_attention_kv_cache_fork_sequence(cache_and_config) test_paged_attention_kv_cache_popn(cache_and_config) test_paged_attention_kv_cache_sliding_window(cache_and_config) + test_paged_attention_kv_cache_tree_attn(cache_and_config)