From d7d48d98ee5a496f57609a1066cda57d56f00949 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 22 Aug 2024 15:17:36 -0400 Subject: [PATCH] [Runtime] Support KV cache with RoPE extension factor array This PR enhances the KV cache with the RoPE extensio factor support. With this PR, the KV cache can support models like Phi3.5 which comes with the extension factor. --- src/runtime/relax_vm/kv_state.h | 1 + src/runtime/relax_vm/paged_kv_cache.cc | 63 +++++++++++-------- ...tin_paged_attention_kv_cache_flashinfer.py | 3 + ...me_builtin_paged_attention_kv_cache_tir.py | 1 + 4 files changed, 43 insertions(+), 25 deletions(-) diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index f4d6036b9638..6d30ce998add 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -167,6 +167,7 @@ class AttentionKVCacheObj : public KVStateObj { * `(total_length, num_qo_heads + 2 * num_kv_heads, head_dim)`. * \param mask The input mask data, in layout `(total_sqr_length)`. * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`. + * \param attn_score_scaling_factor The additional attention scaling factor. * \sa AttentionKVCache::Attention */ virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional mask, diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 6bf3dc7ce609..591187ab5fe7 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -848,6 +848,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { const double rotary_scale_; /*! \brief The RoPE theta. */ const double rotary_theta_; + /*! \brief The optional RoPE extension factors for RoPE scaling. */ + const Optional rope_ext_factors_; /*! \brief We fix int32 to be the index dtype of auxiliary data. */ const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1)); @@ -988,7 +990,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset, // int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window, - RoPEMode rope_mode, double rotary_scale, double rotary_theta, DLDataType dtype, Device device, + RoPEMode rope_mode, double rotary_scale, double rotary_theta, + Optional rope_ext_factors, DLDataType dtype, Device device, PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill, PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged, @@ -1013,6 +1016,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { : rope_mode), rotary_scale_(rotary_scale), rotary_theta_(rotary_theta), + rope_ext_factors_(std::move(rope_ext_factors)), f_transpose_append_(std::move(f_transpose_append)), f_compact_copy_(std::move(f_compact_copy)), f_attention_prefill_(std::move(f_attention_prefill)), @@ -1132,6 +1136,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { reserved_num_seqs, num_total_pages, prefill_chunk_size, dtype_aux_, device, preferred_host_device, copy_stream_); } + + // Right now only the "normal" RoPE mode supports the RoPE extention factors. + if (rope_ext_factors_.defined()) { + CHECK(rope_mode_ == RoPEMode::kNormal) + << "The RoPE mode must be normal to support RoPE extension factors."; + } } ~PagedAttentionKVCacheObj() { @@ -1726,8 +1736,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, head_dim_}, qkv_data->dtype); // Part 2. Split fused qkv and apply rotary embedding to q/k data. - f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data, - static_cast(rope_mode_ == RoPEMode::kNormal)); + if (!rope_ext_factors_.defined()) { + f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data, + static_cast(rope_mode_ == RoPEMode::kNormal)); + } else { + f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data, + rope_ext_factors_.value()); + } // Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set. if (append_before_attn_) { @@ -2462,7 +2477,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj); TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27) + CHECK(args.size() == 27 || args.size() == 28) << "Invalid number of KV cache constructor args."; ShapeTuple cache_config = args[0]; ShapeTuple layer_indptr_tuple = args[1]; @@ -2499,14 +2514,12 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") PackedFunc f_split_rotary = args[22]; PackedFunc f_copy_single_page = args[23]; Optional f_debug_get_kv = args[24]; - PackedFunc f_compact_copy{nullptr}; - PackedFunc f_attention_prefill_with_tree_mask{nullptr}; + PackedFunc f_compact_copy = args[25]; + PackedFunc f_attention_prefill_with_tree_mask = args[26]; + Optional rope_ext_factors = NullOpt; - if (args.size() >= 26) { - f_compact_copy = args[25].AsObjectRef(); - } - if (args.size() >= 27) { - f_attention_prefill_with_tree_mask = args[26].AsObjectRef(); + if (args.size() >= 28 && args[27].IsObjectRef()) { + rope_ext_factors = args[27].AsObjectRef(); } CHECK_EQ(cache_config.size(), 5); @@ -2523,9 +2536,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") ObjectPtr n = make_object( page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, - RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device, - std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill), - std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), + RoPEMode(rope_mode), rotary_scale, rotary_theta, std::move(rope_ext_factors), // + init->dtype, init->device, std::move(f_transpose_append), std::move(f_compact_copy), + std::move(f_attention_prefill), std::move(f_attention_decode), + std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), std::move(f_attention_prefill_with_tree_mask), std::move(f_attention_prefill_ragged_begin_forward), @@ -2539,7 +2553,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21) + CHECK(args.size() == 21 || args.size() == 22) << "Invalid number of KV cache constructor args."; ShapeTuple cache_config = args[0]; ShapeTuple layer_indptr_tuple = args[1]; @@ -2570,14 +2584,12 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") PackedFunc f_split_rotary = args[16]; PackedFunc f_copy_single_page = args[17]; Optional f_debug_get_kv = args[18]; - PackedFunc f_compact_copy{nullptr}; - PackedFunc f_attention_prefill_with_tree_mask{nullptr}; + PackedFunc f_compact_copy = args[19]; + PackedFunc f_attention_prefill_with_tree_mask = args[20]; + Optional rope_ext_factors = NullOpt; - if (args.size() >= 20) { - f_compact_copy = args[19].AsObjectRef(); - } - if (args.size() >= 21) { - f_attention_prefill_with_tree_mask = args[20].AsObjectRef(); + if (args.size() >= 22 && args[21].IsObjectRef()) { + rope_ext_factors = args[21].AsObjectRef(); } CHECK_EQ(cache_config.size(), 5); @@ -2594,9 +2606,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") ObjectPtr n = make_object( page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, - RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device, - std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill), - std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), + RoPEMode(rope_mode), rotary_scale, rotary_theta, std::move(rope_ext_factors), // + init->dtype, init->device, std::move(f_transpose_append), std::move(f_compact_copy), + std::move(f_attention_prefill), std::move(f_attention_decode), + std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), std::move(f_attention_prefill_with_tree_mask), // NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index cab10f84cddf..2252cb8d9c09 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -379,6 +379,9 @@ def create_kv_cache(rope_mode): fsplit_rotary, fcopy_single_page, fcopy_cache, + None, + None, + None, ) return cache diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 96a2438505b2..ff655e141b96 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -180,6 +180,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): fcopy_cache, fcompact_copy, fattn_prefill_with_tree_mask, + None, ) return cache