15 changes: 15 additions & 0 deletions include/tvm/runtime/disco/disco_worker.h
@@ -93,6 +93,21 @@ class DiscoWorker {
struct Impl;
friend struct DiscoWorker::Impl;
};
+ /*!
+  * \brief A threadlocal wrapper of DiscoWorker.
+  */
+ struct ThreadLocalDiscoWorker {
+   /*! \brief The Disco worker */
+   DiscoWorker* worker;
+
+   /*!
+    * \brief Get the threadlocal Disco worker.
+    */
+   static ThreadLocalDiscoWorker* Get() {
+     thread_local static ThreadLocalDiscoWorker worker;
+     return &worker;
+   }
+ };

} // namespace runtime
} // namespace tvm
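Note on the header change above: exposing ThreadLocalDiscoWorker in disco_worker.h lets other runtime components check whether the current thread is a Disco worker without tripping the CHECK inside DiscoWorker::ThreadLocal(). A minimal sketch of that pattern follows; the helper WorkerIdOrZero is hypothetical, not part of this PR.

```cpp
#include <tvm/runtime/disco/disco_worker.h>

namespace tvm {
namespace runtime {

// Hypothetical helper: return the Disco worker id, or 0 when the calling
// thread is not a Disco worker. Reading the thread-local wrapper directly
// yields nullptr on non-worker threads, whereas DiscoWorker::ThreadLocal()
// would CHECK-fail.
inline int WorkerIdOrZero() {
  if (DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker) {
    return worker->worker_id;  // running inside a Disco worker thread
  }
  return 0;  // plain thread: behave as a single-worker setup
}

}  // namespace runtime
}  // namespace tvm
```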
9 changes: 0 additions & 9 deletions src/runtime/disco/disco_worker.cc
@@ -28,15 +28,6 @@
namespace tvm {
namespace runtime {

- struct ThreadLocalDiscoWorker {
-   DiscoWorker* worker;
-
-   static ThreadLocalDiscoWorker* Get() {
-     thread_local static ThreadLocalDiscoWorker worker;
-     return &worker;
-   }
- };

TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() {
DiscoWorker* ret = ThreadLocalDiscoWorker::Get()->worker;
CHECK(ret) << "ValueError: The current thread is not a DiscoWorker thread";
82 changes: 56 additions & 26 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -21,6 +21,7 @@
* \brief Runtime paged KV cache object for language models.
*/
#include <tvm/runtime/device_api.h>
+ #include <tvm/runtime/disco/disco_worker.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/memory/memory_manager.h>
#include <tvm/runtime/ndarray.h>
@@ -825,6 +826,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
const int64_t page_size_;
/*! \brief The number of layers in the model. */
const int64_t num_layers_;
+ /*! \brief The beginning layer id offset. */
+ const int64_t layer_id_begin_offset_;
/*! \brief The number of query/output heads in the model. */
const int64_t num_qo_heads_;
/*! \brief The number of key/value heads in the model. */
@@ -981,14 +984,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
public:
/*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */
explicit PagedAttentionKVCacheObj(
- int64_t page_size, //
- int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
- int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size,
- bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta,
- DLDataType dtype, Device device, PackedFunc f_transpose_append, PackedFunc f_compact_copy,
- PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
- PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window,
- PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask,
+ int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset, //
+ int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t reserved_num_seqs,
+ int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window,
+ RoPEMode rope_mode, double rotary_scale, double rotary_theta, DLDataType dtype, Device device,
+ PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill,
+ PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window,
+ PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged,
+ PackedFunc f_attention_prefill_with_tree_mask,
Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
Optional<PackedFunc> f_attention_prefill_begin_forward,
@@ -998,6 +1001,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional<PackedFunc> f_debug_get_kv)
: page_size_(page_size),
num_layers_(num_layers),
+ layer_id_begin_offset_(layer_id_begin_offset),
num_qo_heads_(num_qo_heads),
num_kv_heads_(num_kv_heads),
head_dim_(head_dim),
@@ -1672,7 +1676,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
NDArray o_data, double attn_score_scaling_factor) final {
// Part 1. Shape and dtype check.
- NDArray pages = pages_[layer_id];
+ int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+ CHECK_GE(local_layer_id, 0);
+ CHECK_LT(local_layer_id, num_layers_);
+ NDArray pages = pages_[local_layer_id];
CHECK(qkv_data.DataType() == pages.DataType());
CHECK(o_data.DataType() == pages.DataType());

@@ -1713,13 +1720,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {

// Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set.
if (append_before_attn_) {
- f_transpose_append_(pages_[layer_id], k_data, v_data, append_position_map_view_);
+ f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_);
}
// Part 4: perform attention
AttentionInternal(layer_id, q_data, k_data, v_data, o_data, attn_score_scaling_factor);
// Part 5. Append k/v data to kv-cache if flag "append_before_attn" is not set.
if (!append_before_attn_) {
- f_transpose_append_(pages_[layer_id], k_data, v_data, append_position_map_view_);
+ f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_);
}
}
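Since the cache for a pipeline group now stores only that group's layers, the global layer_id passed by the model is translated to a local index before pages_ is touched. Below is a minimal standalone sketch of that mapping, assuming an example split of a 32-layer model into two groups (layer_indptr = [0, 16, 32]); the helper name is ours, not from the PR.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the bounds-checked offset subtraction above: map a global layer id
// to the local index used for pages_ on a pipeline group's workers.
int64_t ToLocalLayerId(int64_t layer_id, int64_t layer_id_begin_offset, int64_t num_layers) {
  int64_t local_layer_id = layer_id - layer_id_begin_offset;
  assert(local_layer_id >= 0 && local_layer_id < num_layers);
  return local_layer_id;
}

int main() {
  // Second group of the example split owns layers [16, 32): offset 16, 16 layers.
  // Global layer 20 therefore lives in pages_[4] on that group's workers.
  assert(ToLocalLayerId(/*layer_id=*/20, /*layer_id_begin_offset=*/16, /*num_layers=*/16) == 4);
  return 0;
}
```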

@@ -2238,14 +2245,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
*/
void AttentionInternal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
NDArray output, double attn_score_scaling_factor) {
+ int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+ CHECK_GE(local_layer_id, 0);
+ CHECK_LT(local_layer_id, num_layers_);
PackedFunc f_prefill =
!support_sliding_window_ ? f_attention_prefill_ : f_attention_prefill_sliding_window_;
PackedFunc f_decode =
!support_sliding_window_ ? f_attention_decode_ : f_attention_decode_sliding_window_;
CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1.";
if (append_before_attn_) {
f_decode(
- /*depth=*/0, q_data, pages_[layer_id], page_indptr_on_depths_view_[0],
+ /*depth=*/0, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[0],
page_indices_on_depths_view_[0], length_info_on_depths_view_[0],
k_rope_pos_offset_view_[0], q_rope_position_map_view_, output, merged_attn_scores_view_,
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
@@ -2280,7 +2290,7 @@
}
if (use_decode_kernel_[d]) {
// Use decode kernel for depth d
- f_decode(/*depth=*/d, q_data, pages_[layer_id], page_indptr_on_depths_view_[d],
+ f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d],
page_indices_on_depths_view_[d], length_info_on_depths_view_[d],
k_rope_pos_offset_view_[d], q_rope_position_map_view_, temp_attn_output_view_,
temp_attn_scores_view_,
@@ -2289,7 +2299,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
} else {
// Use prefill kernel for depth d
f_prefill(
- /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[layer_id],
+ /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id],
page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_,
temp_attn_output_view_, temp_attn_scores_view_,
@@ -2436,7 +2446,17 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
- int64_t num_layers = args[1];
+ ShapeTuple layer_indptr_tuple = args[1];
+ int num_groups = 1;
+ int group_id = 0;
+ if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) {
+   // In the Disco worker thread
+   num_groups = disco_worker->num_groups;
+   group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups);
+ }
+ CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1);
+ int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id];
+ int64_t layer_id_begin_offset = layer_indptr_tuple[group_id];
int64_t num_qo_heads = args[2];
int64_t num_kv_heads = args[3];
int64_t head_dim = args[4];
@@ -2482,11 +2502,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
num_total_pages += reserved_num_seqs * 2;
}
ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
- page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
- num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
- rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
- std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
- std::move(f_attention_prefill_sliding_window),
+ page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim,
+ reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window,
+ RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device,
+ std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill),
+ std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask),
std::move(f_attention_prefill_ragged_begin_forward),
@@ -2503,7 +2523,17 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
- int64_t num_layers = args[1];
+ ShapeTuple layer_indptr_tuple = args[1];
+ int num_groups = 1;
+ int group_id = 0;
+ if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) {
+   // In the Disco worker thread
+   num_groups = disco_worker->num_groups;
+   group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups);
+ }
+ CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1);
+ int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id];
+ int64_t layer_id_begin_offset = layer_indptr_tuple[group_id];
int64_t num_qo_heads = args[2];
int64_t num_kv_heads = args[3];
int64_t head_dim = args[4];
@@ -2543,11 +2573,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
num_total_pages += reserved_num_seqs * 2;
}
ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
- page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
- num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
- rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
- std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
- std::move(f_attention_prefill_sliding_window),
+ page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim,
+ reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window,
+ RoPEMode(rope_mode), rotary_scale, rotary_theta, init->dtype, init->device,
+ std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill),
+ std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask), //
NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, //
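Both registration hunks above replace the plain num_layers argument with a layer "indptr" (a prefix sum of layers per pipeline group) and derive each worker's slice from its Disco worker id. A standalone sketch of that bookkeeping with made-up numbers (4 workers, 2 groups, 32 layers split as [0, 16, 32]):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Example values only; in the PR these come from the ShapeTuple argument and
  // from disco_worker->num_groups / num_workers / worker_id.
  std::vector<int64_t> layer_indptr = {0, 16, 32};
  int num_workers = 4;
  int num_groups = 2;
  for (int worker_id = 0; worker_id < num_workers; ++worker_id) {
    int group_id = worker_id / (num_workers / num_groups);
    int64_t num_layers = layer_indptr[group_id + 1] - layer_indptr[group_id];
    int64_t layer_id_begin_offset = layer_indptr[group_id];
    std::cout << "worker " << worker_id << " -> group " << group_id << ", layers ["
              << layer_id_begin_offset << ", " << layer_id_begin_offset + num_layers << ")\n";
  }
  // Output:
  //   worker 0 -> group 0, layers [0, 16)
  //   worker 1 -> group 0, layers [0, 16)
  //   worker 2 -> group 1, layers [16, 32)
  //   worker 3 -> group 1, layers [16, 32)
  return 0;
}
```

The single-group tests below follow the same convention by passing tvm.runtime.ShapeTuple([0, num_layers]) instead of a bare num_layers.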
Changes to a Python KV cache test (file name not shown):
@@ -354,7 +354,7 @@ def create_kv_cache(rope_mode):
support_sliding_window,
]
),
- num_layers,
+ tvm.runtime.ShapeTuple([0, num_layers]),
num_qo_heads,
num_kv_heads,
head_dim,
Changes to a second Python KV cache test (file name not shown):
@@ -153,7 +153,7 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
int(support_sliding_window),
]
),
- num_layers,
+ tvm.runtime.ShapeTuple([0, num_layers]),
num_qo_heads,
num_kv_heads,
head_dim,