diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc index 05ba7c96506a..b1572bf4091a 100644 --- a/src/runtime/relax_vm/kv_state.cc +++ b/src/runtime/relax_vm/kv_state.cc @@ -47,6 +47,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward") // Attention KV Cache methods TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq") .set_body_method(&AttentionKVCacheObj::EnableSlidingWindowForSeq); +TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty") + .set_body_method(&AttentionKVCacheObj::Empty); TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages") .set_body_method(&AttentionKVCacheObj::GetNumAvailablePages); TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_total_sequence_length") diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index 7b90ffce50b2..12a18ba89502 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -117,6 +117,8 @@ class AttentionKVCacheObj : public KVStateObj { public: /************** Raw Info Query **************/ + /*! \brief Check if the KV cache is empty. */ + virtual bool Empty() const = 0; /*! * \brief Get the number of available pages in the KV cache. * When the underlying KV cache implementation is not diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 62750d6d7daa..4ab0f3f0c686 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -147,13 +147,14 @@ struct Sequence { */ int last_block_attn_sink_size = 0; - explicit Sequence(const std::vector<Block>& global_block_pool, int32_t last_block_idx) { + explicit Sequence(std::vector<Block>* global_block_pool, int32_t last_block_idx) { + ++global_block_pool->at(last_block_idx).external_ref_cnt; this->last_block_idx = last_block_idx; int32_t block_ptr = last_block_idx; // Go through each block in the sequence, sum up the length. 
int depth = 0; while (true) { - const Block& block = global_block_pool[block_ptr]; + const Block& block = global_block_pool->at(block_ptr); this->seq_length += block.seq_length; ++depth; if (block.parent_idx == -1) { @@ -965,7 +966,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { CHECK(seq_map_.find(seq_id) == seq_map_.end()) << "The sequence \"" << seq_id << "\" is already in the KV cache."; int32_t block_idx = GetFreeBlock(); - seq_map_.insert({seq_id, Sequence(global_block_pool_, block_idx)}); + seq_map_.insert({seq_id, Sequence(&global_block_pool_, block_idx)}); dirty_aux_data_device_ = true; } @@ -973,9 +974,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { auto it = seq_map_.find(seq_id); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; int32_t block_idx = it->second.last_block_idx; - CHECK_EQ(global_block_pool_[block_idx].external_ref_cnt, 0) - << "The sequence is currently referenced by other sequence and thus cannot be removed."; - while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) { + // The block should have at least one reference, which comes from the sequence. + ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1); + while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1) { // - Free pages in the last block. for (int32_t page_id : global_block_pool_[block_idx].page_ids) { free_page_ids_.push_back(page_id); @@ -985,7 +986,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } // - Decrease the external reference of the parent block. 
if (block_idx != -1) { - ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 0); + ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 1); --global_block_pool_[block_idx].external_ref_cnt; } seq_map_.erase(it); @@ -1018,11 +1019,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Update child block start position and parent index global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length; global_block_pool_[child_block_idx].parent_idx = parent_block_idx; - if (global_block_pool_[parent_block_idx].seq_length) { - // If parent is not empty, append a new block + if (parent_block_idx == parent_it->second.last_block_idx && + global_block_pool_[parent_block_idx].seq_length) { + // To enable the parent sequence to continue decode after the fork, + // we add a new empty block at the end of the parent sequence. + // So the new decoded KV data will go into the new block. int32_t new_parent_block_idx = GetFreeBlock(); global_block_pool_[new_parent_block_idx].start_pos = parent_it->second.seq_length; global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx; + global_block_pool_[new_parent_block_idx].external_ref_cnt = 1; parent_it->second.last_block_idx = new_parent_block_idx; } } else { @@ -1055,7 +1060,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { global_block_pool_[forked_block_idx].parent_idx; global_block_pool_[forked_block_idx].parent_idx = parent_block_idx; global_block_pool_[child_block_idx].parent_idx = parent_block_idx; - global_block_pool_[parent_block_idx].external_ref_cnt = 1; + global_block_pool_[parent_block_idx].external_ref_cnt = 2; // Move common leading pages to new parent block auto first_page = global_block_pool_[forked_block_idx].page_ids.begin(); @@ -1085,7 +1090,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } // Create the child sequence with the child block. 
- seq_map_.insert({child_seq_id, Sequence(global_block_pool_, child_block_idx)}); + seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)}); dirty_aux_data_device_ = true; } @@ -1119,7 +1124,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { << "A sequence cannot be enabled twice for sliding window."; // Compute the total length of the prefix blocks of this sequence. - Block& last_block = global_block_pool_[it->second.last_block_idx]; + const Block& last_block = global_block_pool_[it->second.last_block_idx]; int32_t prefix_length = it->second.seq_length - last_block.seq_length; ICHECK_GE(prefix_length, 0); // Since the prefix blocks cannot sliding, they are natural @@ -1139,7 +1144,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { << "The sequence only has length " << it->second.seq_length << ", while the length of pop is " << n << " which exceeds the whole sequence length."; int32_t block_idx = it->second.last_block_idx; - while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) { + // The block should have at least one reference, which comes from the sequence. + ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1); + while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1) { if (n > global_block_pool_[block_idx].seq_length) { n -= global_block_pool_[block_idx].seq_length; it->second.seq_length -= global_block_pool_[block_idx].seq_length; @@ -1168,14 +1175,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } if (n) { - int32_t temp_seq_id = -1 - seq_id; + // We use a temporary sequence id for fork. + // This temporary seq id will immediately end its effect outside this function. 
+ int64_t temp_seq_id = -1 - seq_id; CHECK(seq_map_.find(temp_seq_id) == seq_map_.end()); ForkSequence(seq_id, temp_seq_id, it->second.seq_length - n); CHECK(seq_map_.find(temp_seq_id) != seq_map_.end()); RemoveSequence(seq_id); CHECK(seq_map_.find(seq_id) == seq_map_.end()); auto it = seq_map_.find(temp_seq_id); - seq_map_.insert({seq_id, Sequence(global_block_pool_, it->second.last_block_idx)}); + seq_map_.insert({seq_id, it->second}); seq_map_.erase(temp_seq_id); } @@ -1184,6 +1193,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /************** Raw Info Query **************/ + bool Empty() const final { + return seq_map_.empty() && // + free_block_idx_.size() == global_block_pool_.size() && // + free_page_ids_.size() == static_cast<size_t>(num_total_pages_); + } + int32_t GetNumAvailablePages() const final { return free_page_ids_.size(); } int32_t GetTotalSequenceLength() const final { @@ -1565,8 +1580,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int32_t block_idx = seq->last_block_idx; Block& block = global_block_pool_[block_idx]; CHECK_GT(append_length, 0) << "Append with length 0 is not allowed."; - CHECK_EQ(block.external_ref_cnt, 0) - << "The block is " << block.external_ref_cnt + CHECK_EQ(block.external_ref_cnt, 1) + << "The block is " << block.external_ref_cnt - 1 << "-time referenced by other blocks, thus cannot accept new KV values."; // ==================== Reserve ==================== diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index f7b01bb84066..6504175b5680 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -54,6 +54,7 @@ fbegin_forward = None fend_forward = None fattention_with_fuse_qkv = None +fis_empty = None fdebug_get_kv = None ftranspose_append = None @@ -71,7 +72,7 @@ def 
set_global_func(head_dim, dtype): global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq - global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fdebug_get_kv + global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fis_empty, fdebug_get_kv global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode, fattn_prefill_ragged global fattn_prefill_sliding_window, fattn_decode_sliding_window global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page @@ -89,6 +90,7 @@ def set_global_func(head_dim, dtype): fattention_with_fuse_qkv = tvm.get_global_func( "vm.builtin.attention_kv_cache_attention_with_fused_qkv" ) + fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty") fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv") target = tvm.target.Target("cuda") @@ -489,11 +491,19 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config): for batch in operation_seq: apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) - for i in range(19, -1, -1): + num_sequence = 20 + for i in range(num_sequence): fremove_sequence(kv_cache, i) cached_k.pop(i) cached_v.pop(i) - verify_cached_kv(kv_cache, seq_ids=list(range(i)), expected_k=cached_k, expected_v=cached_v) + verify_cached_kv( + kv_cache, + seq_ids=list(range(i + 1, num_sequence)), + expected_k=cached_k, + expected_v=cached_v, + ) + + assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences" @tvm.testing.requires_gpu @@ -510,7 +520,7 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_config): apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v) - popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0)] + popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 37)] for seq_id, pop_length in popn_operations: fpopn(kv_cache, seq_id, pop_length) if pop_length != 0: 
@@ -518,6 +528,18 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_config): cached_v[seq_id] = cached_v[seq_id][:, :-pop_length, ...] verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_k=cached_k, expected_v=cached_v) + num_sequence = 5 + for seq_id in range(num_sequence): + fremove_sequence(kv_cache, seq_id) + verify_cached_kv( + kv_cache, + seq_ids=list(range(seq_id + 1, num_sequence)), + expected_k=cached_k, + expected_v=cached_v, + ) + + assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences" + @tvm.testing.requires_gpu @tvm.testing.requires_cuda