From 46237a1c182cc7817589064f74dcf50f3c36f9ac Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Wed, 29 May 2024 17:27:02 -0400
Subject: [PATCH] [Runtime] Fix PagedKVCache for PopN and enhance tests

This PR fixes a bug in PagedKVCache that may be triggered when the
sequence removal order is not consistent with the reverse of the
sequence add/fork order. With this fix, the PagedKVCache now supports
removing sequences in any order without breaking.

This PR also adds an `empty` function to PagedKVCache to check if the
KV cache is empty. Right now this function is only used for testing
purposes, where we check that everything in the KV cache is freed
after removing all sequences.
---
 src/runtime/relax_vm/kv_state.cc              |  2 +
 src/runtime/relax_vm/kv_state.h               |  2 +
 src/runtime/relax_vm/paged_kv_cache.cc        | 49 ++++++++++++-------
 ...me_builtin_paged_attention_kv_cache_tir.py | 30 ++++++++++--
 4 files changed, 62 insertions(+), 21 deletions(-)

diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc
index 05ba7c96506a..b1572bf4091a 100644
--- a/src/runtime/relax_vm/kv_state.cc
+++ b/src/runtime/relax_vm/kv_state.cc
@@ -47,6 +47,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
 // Attention KV Cache methods
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
     .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty")
+    .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::Empty);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")
     .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetNumAvailablePages);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_total_sequence_length")
diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index 7b90ffce50b2..12a18ba89502 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -117,6 +117,8 @@ class AttentionKVCacheObj : public KVStateObj {
  public:
   /************** Raw Info Query **************/
 
+  /*! \brief Check if the KV cache is empty. */
+  virtual bool Empty() const = 0;
   /*!
    * \brief Get the number of available pages in the KV cache.
    * When the underlying KV cache implementation is not
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 62750d6d7daa..4ab0f3f0c686 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -147,13 +147,14 @@ struct Sequence {
    */
   int last_block_attn_sink_size = 0;
 
-  explicit Sequence(const std::vector<Block>& global_block_pool, int32_t last_block_idx) {
+  explicit Sequence(std::vector<Block>* global_block_pool, int32_t last_block_idx) {
+    ++global_block_pool->at(last_block_idx).external_ref_cnt;
     this->last_block_idx = last_block_idx;
     int32_t block_ptr = last_block_idx;
     // Go through each block in the sequence, sum up the length.
     int depth = 0;
     while (true) {
-      const Block& block = global_block_pool[block_ptr];
+      const Block& block = global_block_pool->at(block_ptr);
       this->seq_length += block.seq_length;
       ++depth;
       if (block.parent_idx == -1) {
@@ -965,7 +966,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     CHECK(seq_map_.find(seq_id) == seq_map_.end())
         << "The sequence \"" << seq_id << "\" is already in the KV cache.";
     int32_t block_idx = GetFreeBlock();
-    seq_map_.insert({seq_id, Sequence(global_block_pool_, block_idx)});
+    seq_map_.insert({seq_id, Sequence(&global_block_pool_, block_idx)});
     dirty_aux_data_device_ = true;
   }
 
@@ -973,9 +974,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     auto it = seq_map_.find(seq_id);
     CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
     int32_t block_idx = it->second.last_block_idx;
-    CHECK_EQ(global_block_pool_[block_idx].external_ref_cnt, 0)
-        << "The sequence is currently referenced by other sequence and thus cannot be removed.";
-    while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
+    // The block should have at least one reference, which comes from the sequence.
+    ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1);
+    while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1) {
       // - Free pages in the last block.
       for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
         free_page_ids_.push_back(page_id);
@@ -985,7 +986,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
     // - Decrease the external reference of the parent block.
     if (block_idx != -1) {
-      ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 0);
+      ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 1);
       --global_block_pool_[block_idx].external_ref_cnt;
     }
     seq_map_.erase(it);
@@ -1018,11 +1019,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       // Update child block start position and parent index
       global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
       global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
-      if (global_block_pool_[parent_block_idx].seq_length) {
-        // If parent is not empty, append a new block
+      if (parent_block_idx == parent_it->second.last_block_idx &&
+          global_block_pool_[parent_block_idx].seq_length) {
+        // To enable the parent sequence to continue decode after the fork,
+        // we add a new empty block at the end of the parent sequence.
+        // So the new decoded KV data will go into the new block.
         int32_t new_parent_block_idx = GetFreeBlock();
         global_block_pool_[new_parent_block_idx].start_pos = parent_it->second.seq_length;
         global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx;
+        global_block_pool_[new_parent_block_idx].external_ref_cnt = 1;
         parent_it->second.last_block_idx = new_parent_block_idx;
       }
     } else {
@@ -1055,7 +1060,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
           global_block_pool_[forked_block_idx].parent_idx;
       global_block_pool_[forked_block_idx].parent_idx = parent_block_idx;
       global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
-      global_block_pool_[parent_block_idx].external_ref_cnt = 1;
+      global_block_pool_[parent_block_idx].external_ref_cnt = 2;
 
       // Move common leading pages to new parent block
       auto first_page = global_block_pool_[forked_block_idx].page_ids.begin();
@@ -1085,7 +1090,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       }
     }
     // Create the child sequence with the child block.
-    seq_map_.insert({child_seq_id, Sequence(global_block_pool_, child_block_idx)});
+    seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)});
     dirty_aux_data_device_ = true;
   }
 
@@ -1119,7 +1124,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         << "A sequence cannot be enabled twice for sliding window.";
 
     // Compute the total length of the prefix blocks of this sequence.
-    Block& last_block = global_block_pool_[it->second.last_block_idx];
+    const Block& last_block = global_block_pool_[it->second.last_block_idx];
     int32_t prefix_length = it->second.seq_length - last_block.seq_length;
     ICHECK_GE(prefix_length, 0);
     // Since the prefix blocks cannot sliding, they are natural
@@ -1139,7 +1144,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         << "The sequence only has length " << it->second.seq_length
         << ", while the length of pop is " << n << " which exceeds the whole sequence length.";
     int32_t block_idx = it->second.last_block_idx;
-    while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
+    // The block should have at least one reference, which comes from the sequence.
+    ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1);
+    while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1) {
       if (n > global_block_pool_[block_idx].seq_length) {
         n -= global_block_pool_[block_idx].seq_length;
         it->second.seq_length -= global_block_pool_[block_idx].seq_length;
@@ -1168,14 +1175,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
 
     if (n) {
-      int32_t temp_seq_id = -1 - seq_id;
+      // We use a temporary sequence id for fork.
+      // This temporary seq id will immediately end its effect outside this function.
+      int64_t temp_seq_id = -1 - seq_id;
       CHECK(seq_map_.find(temp_seq_id) == seq_map_.end());
       ForkSequence(seq_id, temp_seq_id, it->second.seq_length - n);
       CHECK(seq_map_.find(temp_seq_id) != seq_map_.end());
       RemoveSequence(seq_id);
       CHECK(seq_map_.find(seq_id) == seq_map_.end());
       auto it = seq_map_.find(temp_seq_id);
-      seq_map_.insert({seq_id, Sequence(global_block_pool_, it->second.last_block_idx)});
+      seq_map_.insert({seq_id, it->second});
       seq_map_.erase(temp_seq_id);
     }

@@ -1184,6 +1193,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
 
   /************** Raw Info Query **************/
 
+  bool Empty() const final {
+    return seq_map_.empty() &&                                     //
+           free_block_idx_.size() == global_block_pool_.size() &&  //
+           free_page_ids_.size() == static_cast<size_t>(num_total_pages_);
+  }
+
   int32_t GetNumAvailablePages() const final { return free_page_ids_.size(); }
 
   int32_t GetTotalSequenceLength() const final {
@@ -1565,8 +1580,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     int32_t block_idx = seq->last_block_idx;
     Block& block = global_block_pool_[block_idx];
     CHECK_GT(append_length, 0) << "Append with length 0 is not allowed.";
-    CHECK_EQ(block.external_ref_cnt, 0)
-        << "The block is " << block.external_ref_cnt
+    CHECK_EQ(block.external_ref_cnt, 1)
+        << "The block is " << block.external_ref_cnt - 1
         << "-time referenced by other blocks, thus cannot accept new KV values.";
 
     // ==================== Reserve ====================
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index f7b01bb84066..6504175b5680 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -54,6 +54,7 @@
 fbegin_forward = None
 fend_forward = None
 fattention_with_fuse_qkv = None
+fis_empty = None
 fdebug_get_kv = None
 
 ftranspose_append = None
@@ -71,7 +72,7 @@ def set_global_func(head_dim, dtype):
     global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq
-    global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fdebug_get_kv
+    global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fis_empty, fdebug_get_kv
     global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode, fattn_prefill_ragged
     global fattn_prefill_sliding_window, fattn_decode_sliding_window
     global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page
@@ -89,6 +90,7 @@ def set_global_func(head_dim, dtype):
     fattention_with_fuse_qkv = tvm.get_global_func(
         "vm.builtin.attention_kv_cache_attention_with_fused_qkv"
     )
+    fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
     fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
 
     target = tvm.target.Target("cuda")
@@ -489,11 +491,19 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
     for batch in operation_seq:
         apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
 
-    for i in range(19, -1, -1):
+    num_sequence = 20
+    for i in range(num_sequence):
         fremove_sequence(kv_cache, i)
         cached_k.pop(i)
         cached_v.pop(i)
-        verify_cached_kv(kv_cache, seq_ids=list(range(i)), expected_k=cached_k, expected_v=cached_v)
+        verify_cached_kv(
+            kv_cache,
+            seq_ids=list(range(i + 1, num_sequence)),
+            expected_k=cached_k,
+            expected_v=cached_v,
+        )
+
+    assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"
 
 
 @tvm.testing.requires_gpu
@@ -510,7 +520,7 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_config):
         apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
     apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v)
 
-    popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0)]
+    popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 37)]
     for seq_id, pop_length in popn_operations:
         fpopn(kv_cache, seq_id, pop_length)
         if pop_length != 0:
@@ -518,6 +528,18 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_config):
             cached_v[seq_id] = cached_v[seq_id][:, :-pop_length, ...]
         verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_k=cached_k, expected_v=cached_v)
 
+    num_sequence = 5
+    for seq_id in range(num_sequence):
+        fremove_sequence(kv_cache, seq_id)
+        verify_cached_kv(
+            kv_cache,
+            seq_ids=list(range(seq_id + 1, num_sequence)),
+            expected_k=cached_k,
+            expected_v=cached_v,
+        )
+
+    assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"
+
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
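
Note on the reference-counting scheme: the sketch below is a self-contained toy model (plain Python, no TVM APIs; all class and method names are invented for illustration) of the bookkeeping this patch moves to. Each block's external_ref_cnt now counts the sequence that owns the block in addition to any child blocks forked from it, so removal frees a block only while its count is exactly 1 and merely decrements the count of a shared ancestor otherwise. That is what makes the removal order irrelevant.

# Toy model of the paged-KV-cache block bookkeeping after this patch
# (illustration only; it tracks block metadata, not actual KV pages).
from dataclasses import dataclass


@dataclass
class Block:
    parent: int = -1           # index of the parent block, -1 for a root block
    external_ref_cnt: int = 0  # owning sequence + child blocks referencing this block


class BlockPool:
    def __init__(self):
        self.blocks = []           # the "global block pool"
        self.seq_last_block = {}   # seq_id -> index of that sequence's last block

    def _new_block(self, parent=-1):
        self.blocks.append(Block(parent=parent))
        return len(self.blocks) - 1

    def add_sequence(self, seq_id):
        # Like the patched Sequence constructor: the new sequence itself
        # holds one reference on its last block.
        idx = self._new_block()
        self.blocks[idx].external_ref_cnt = 1
        self.seq_last_block[seq_id] = idx

    def fork_sequence(self, parent_seq_id, child_seq_id):
        # Simplified full-length fork: the child's new block points at the
        # parent's last block, which therefore gains one more reference.
        parent_last = self.seq_last_block[parent_seq_id]
        self.blocks[parent_last].external_ref_cnt += 1
        child_idx = self._new_block(parent=parent_last)
        self.blocks[child_idx].external_ref_cnt = 1
        self.seq_last_block[child_seq_id] = child_idx

    def remove_sequence(self, seq_id):
        # Mirrors the patched RemoveSequence: walk towards the root, freeing
        # blocks while this sequence is their only referent (count == 1);
        # a shared ancestor only loses one reference and stays alive.
        idx = self.seq_last_block.pop(seq_id)
        assert self.blocks[idx].external_ref_cnt >= 1
        while idx != -1 and self.blocks[idx].external_ref_cnt == 1:
            freed, idx = idx, self.blocks[idx].parent
            self.blocks[freed].external_ref_cnt = 0  # returned to the free list
        if idx != -1:
            assert self.blocks[idx].external_ref_cnt > 1
            self.blocks[idx].external_ref_cnt -= 1

    def empty(self):
        # Analogue of the new Empty() query: no sequences and no live blocks.
        return not self.seq_last_block and all(
            b.external_ref_cnt == 0 for b in self.blocks
        )


if __name__ == "__main__":
    pool = BlockPool()
    pool.add_sequence(0)
    pool.fork_sequence(0, 1)
    pool.fork_sequence(0, 2)
    # Removing the parent before its children used to be the problematic
    # order; with per-sequence references it now works in any order.
    for seq_id in (0, 1, 2):
        pool.remove_sequence(seq_id)
    assert pool.empty()
    print("all blocks freed")

The real ForkSequence does more work than this model (a full-length fork appends a fresh empty block to the parent so it can keep decoding, and a mid-block fork splits the block and sets the split parent's count to 2), but the invariant exercised by the new tests is the same: every live block is referenced at least once by its owning sequence, and the cache reports Empty() only after all of those references are gone.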