diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h
index 301b5b8d626b..13f94802c886 100644
--- a/include/tvm/runtime/disco/disco_worker.h
+++ b/include/tvm/runtime/disco/disco_worker.h
@@ -93,6 +93,21 @@ class DiscoWorker {
   struct Impl;
   friend struct DiscoWorker::Impl;
 };
+/*!
+ * \brief A threadlocal wrapper of DiscoWorker.
+ */
+struct ThreadLocalDiscoWorker {
+  /*! \brief The Disco worker */
+  DiscoWorker* worker;
+
+  /*!
+   * \brief Get the threadlocal Disco worker.
+   */
+  static ThreadLocalDiscoWorker* Get() {
+    thread_local static ThreadLocalDiscoWorker worker;
+    return &worker;
+  }
+};

 } // namespace runtime
 } // namespace tvm
diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc
index b281a3aca7da..5e6f401054ea 100644
--- a/src/runtime/disco/disco_worker.cc
+++ b/src/runtime/disco/disco_worker.cc
@@ -28,15 +28,6 @@
 namespace tvm {
 namespace runtime {

-struct ThreadLocalDiscoWorker {
-  DiscoWorker* worker;
-
-  static ThreadLocalDiscoWorker* Get() {
-    thread_local static ThreadLocalDiscoWorker worker;
-    return &worker;
-  }
-};
-
 TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() {
   DiscoWorker* ret = ThreadLocalDiscoWorker::Get()->worker;
   CHECK(ret) << "ValueError: The current thread is not a DiscoWorker thread";
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index ec1cc3593a53..2fb8a72f4279 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -21,6 +21,7 @@
  * \brief Runtime paged KV cache object for language models.
  */
 #include <tvm/runtime/device_api.h>
+#include <tvm/runtime/disco/disco_worker.h>
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/registry.h>
@@ -825,6 +826,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   const int64_t page_size_;
   /*! \brief The number of layers in the model. */
   const int64_t num_layers_;
+  /*! \brief The beginning layer id offset. */
+  const int64_t layer_id_begin_offset_;
   /*! \brief The number of query/output heads in the model. */
   const int64_t num_qo_heads_;
   /*! \brief The number of key/value heads in the model. */
@@ -981,14 +984,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
  public:
  /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */
  explicit PagedAttentionKVCacheObj(
-      int64_t page_size,  //
-      int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
-      int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size,
-      bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta,
-      DLDataType dtype, Device device, PackedFunc f_transpose_append, PackedFunc f_compact_copy,
-      PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
-      PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window,
-      PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask,
+      int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset,  //
+      int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t reserved_num_seqs,
+      int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window,
+      RoPEMode rope_mode, double rotary_scale, double rotary_theta, DLDataType dtype, Device device,
+      PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill,
+      PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window,
+      PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged,
+      PackedFunc f_attention_prefill_with_tree_mask,
       Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
       Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
       Optional<PackedFunc> f_attention_prefill_begin_forward,
@@ -998,6 +1001,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional<PackedFunc> f_debug_get_kv)
       : page_size_(page_size),
         num_layers_(num_layers),
+        layer_id_begin_offset_(layer_id_begin_offset),
         num_qo_heads_(num_qo_heads),
         num_kv_heads_(num_kv_heads),
         head_dim_(head_dim),
@@ -1672,7 +1676,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
                              NDArray o_data, double attn_score_scaling_factor) final {
     // Part 1. Shape and dtype check.
-    NDArray pages = pages_[layer_id];
+    int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+    CHECK_GE(local_layer_id, 0);
+    CHECK_LT(local_layer_id, num_layers_);
+    NDArray pages = pages_[local_layer_id];
     CHECK(qkv_data.DataType() == pages.DataType());
     CHECK(o_data.DataType() == pages.DataType());

@@ -1713,13 +1720,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {

     // Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set.
     if (append_before_attn_) {
-      f_transpose_append_(pages_[layer_id], k_data, v_data, append_position_map_view_);
+      f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_);
     }
     // Part 4: perform attention
     AttentionInternal(layer_id, q_data, k_data, v_data, o_data, attn_score_scaling_factor);
     // Part 5. Append k/v data to kv-cache if flag "append_before_attn" is not set.
     if (!append_before_attn_) {
-      f_transpose_append_(pages_[layer_id], k_data, v_data, append_position_map_view_);
+      f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_);
     }
   }
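To make the layer-id remapping above concrete: a minimal Python sketch of the `local_layer_id = layer_id - layer_id_begin_offset_` bookkeeping, assuming a hypothetical worker group that owns layers [16, 32) of a 32-layer model; the numbers and the helper are illustrative only, not part of the patch.

    # Illustrative only: mirrors `local_layer_id = layer_id - layer_id_begin_offset_`.
    layer_id_begin_offset = 16   # assumed: this worker group owns layers [16, 32)
    num_layers = 16              # layers held locally, i.e. the length of pages_ on this group

    def to_local_layer_id(layer_id):
        local_layer_id = layer_id - layer_id_begin_offset
        assert 0 <= local_layer_id < num_layers, "layer not owned by this KV cache"
        return local_layer_id

    print(to_local_layer_id(16))  # -> 0, first page array on this group
    print(to_local_layer_id(31))  # -> 15, last page array on this group

Callers such as the model forward pass keep passing the global layer id; only the indexing into pages_ uses the local id.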
@@ -2238,6 +2245,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
    */
   void AttentionInternal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
                          NDArray output, double attn_score_scaling_factor) {
+    int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+    CHECK_GE(local_layer_id, 0);
+    CHECK_LT(local_layer_id, num_layers_);
     PackedFunc f_prefill =
         !support_sliding_window_ ? f_attention_prefill_ : f_attention_prefill_sliding_window_;
     PackedFunc f_decode =
@@ -2245,7 +2255,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1.";
     if (append_before_attn_) {
       f_decode(
-          /*depth=*/0, q_data, pages_[layer_id], page_indptr_on_depths_view_[0],
+          /*depth=*/0, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[0],
           page_indices_on_depths_view_[0], length_info_on_depths_view_[0],
           k_rope_pos_offset_view_[0], q_rope_position_map_view_, output, merged_attn_scores_view_,
           /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
@@ -2280,7 +2290,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         }
         if (use_decode_kernel_[d]) {
           // Use decode kernel for depth d
-          f_decode(/*depth=*/d, q_data, pages_[layer_id], page_indptr_on_depths_view_[d],
+          f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d],
                    page_indices_on_depths_view_[d], length_info_on_depths_view_[d],
                    k_rope_pos_offset_view_[d], q_rope_position_map_view_, temp_attn_output_view_,
                    temp_attn_scores_view_,
@@ -2289,7 +2299,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         } else {
           // Use prefill kernel for depth d
           f_prefill(
-              /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[layer_id],
+              /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id],
               page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
               length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_,
               temp_attn_output_view_, temp_attn_scores_view_,
@@ -2436,7 +2446,17 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
       CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27)
           << "Invalid number of KV cache constructor args.";
       ShapeTuple cache_config = args[0];
-      int64_t num_layers = args[1];
+      ShapeTuple layer_indptr_tuple = args[1];
+      int num_groups = 1;
+      int group_id = 0;
+      if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) {
+        // In the Disco worker thread
+        num_groups = disco_worker->num_groups;
+        group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups);
+      }
+      CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1);
+      int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id];
+      int64_t layer_id_begin_offset = layer_indptr_tuple[group_id];
       int64_t num_qo_heads = args[2];
       int64_t num_kv_heads = args[3];
       int64_t head_dim = args[4];
@@ -2482,11 +2502,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
         num_total_pages += reserved_num_seqs * 2;
       }
       ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
-          page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
-          num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
-          rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
-          std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
-          std::move(f_attention_prefill_sliding_window),
+          page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim,
+          reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window,
+          RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device,
+          std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill),
+          std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_with_tree_mask),
           std::move(f_attention_prefill_ragged_begin_forward),
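The group lookup in `vm.builtin.paged_attention_kv_cache_create` above can be sanity-checked with a small Python sketch; the worker/group counts and the layer split below are assumed for illustration, and `layer_slice` is a hypothetical helper mirroring the C++ logic, not an API.

    # Hypothetical helper mirroring the C++ logic above: each worker derives its
    # group id from its position, then takes its layer range from the indptr.
    def layer_slice(worker_id, num_workers, num_groups, layer_indptr):
        assert len(layer_indptr) == num_groups + 1
        group_id = worker_id // (num_workers // num_groups)
        num_layers = layer_indptr[group_id + 1] - layer_indptr[group_id]
        layer_id_begin_offset = layer_indptr[group_id]
        return group_id, num_layers, layer_id_begin_offset

    # Assumed setup: 4 workers split into 2 groups, a 32-layer model split as [0, 16, 32].
    for wid in range(4):
        print(wid, layer_slice(wid, num_workers=4, num_groups=2, layer_indptr=[0, 16, 32]))
    # workers 0-1 -> group 0: 16 layers, offset 0; workers 2-3 -> group 1: 16 layers, offset 16.

Outside a Disco worker thread, `ThreadLocalDiscoWorker::Get()->worker` is null, so the code keeps `num_groups = 1` and the indptr degenerates to `[0, num_layers]`.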
@@ -2503,7 +2523,17 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
       CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21)
           << "Invalid number of KV cache constructor args.";
       ShapeTuple cache_config = args[0];
-      int64_t num_layers = args[1];
+      ShapeTuple layer_indptr_tuple = args[1];
+      int num_groups = 1;
+      int group_id = 0;
+      if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) {
+        // In the Disco worker thread
+        num_groups = disco_worker->num_groups;
+        group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups);
+      }
+      CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1);
+      int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id];
+      int64_t layer_id_begin_offset = layer_indptr_tuple[group_id];
       int64_t num_qo_heads = args[2];
       int64_t num_kv_heads = args[3];
       int64_t head_dim = args[4];
@@ -2543,11 +2573,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
         num_total_pages += reserved_num_seqs * 2;
       }
       ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
-          page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
-          num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
-          rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
-          std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
-          std::move(f_attention_prefill_sliding_window),
+          page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim,
+          reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window,
+          RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device,
+          std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill),
+          std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_with_tree_mask),  //
           NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,  //
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 048cf498067b..bade04a7d753 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -354,7 +354,7 @@ def create_kv_cache(rope_mode):
                 support_sliding_window,
             ]
         ),
-        num_layers,
+        tvm.runtime.ShapeTuple([0, num_layers]),
         num_qo_heads,
         num_kv_heads,
         head_dim,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 34680160c8de..9192bb901ff0 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -153,7 +153,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
                 int(support_sliding_window),
             ]
         ),
-        num_layers,
+        tvm.runtime.ShapeTuple([0, num_layers]),
         num_qo_heads,
         num_kv_heads,
         head_dim,
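As the test updates show, the caller-visible change is the second creation argument: an `int num_layers` becomes a layer-indptr `ShapeTuple`. A brief sketch of constructing that argument, where the two-group split is an assumed example rather than something these tests exercise:

    import tvm

    num_layers = 32

    # Single group (what the updated tests pass): one group owning all layers.
    layer_indptr = tvm.runtime.ShapeTuple([0, num_layers])

    # Assumed two-group pipeline split: group 0 owns layers [0, 16),
    # group 1 owns layers [16, 32); the tuple always has num_groups + 1 entries.
    layer_indptr_two_groups = tvm.runtime.ShapeTuple([0, 16, num_layers])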