From 1ee0695daf0e153c633ab8767b6385bd344e3d2c Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sat, 7 Sep 2024 15:15:51 -0400
Subject: [PATCH] [Fix][Relax] Add the missing tree-attn func arg for KV cache creation

This PR fixes the TIRPagedKVCache construction issue, which was caused by
the missing tree-attention-with-paged-KV-cache kernel.
---
 python/tvm/relax/frontend/nn/llm/kv_cache.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 7b14c67a2e57..ae0537f0d9af 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -375,6 +375,7 @@ def __init__(  # pylint: disable=too-many-locals
             bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
             bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"),
             bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
+            bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
             rope_ext_factors,
             # fmt: on
             # pylint: enable=line-too-long
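
For reviewers, a minimal illustrative Python sketch (not part of the patch) of the bb.add_func registration pattern that this change completes inside TIRPagedKVCache.__init__. The kernel builders and the registered function names are taken from the diff above; the import location and the helper's name/signature are assumptions for illustration only.

# Illustrative sketch only (not part of the patch). It mirrors the
# bb.add_func registration pattern inside TIRPagedKVCache.__init__ that the
# patch extends. The import path and the helper name below are assumptions.
from tvm.relax import BlockBuilder
from tvm.relax.frontend.nn.llm.kv_cache import (  # assumed import location
    tree_attn,
    tree_attn_with_paged_kv_cache,
)


def register_tree_attention_kernels(
    bb: BlockBuilder,
    num_key_value_heads: int,
    num_attention_heads: int,
    head_dim: int,
    dtype: str,
    rope_scaling: dict,
    target,
):
    # Both tree-attention kernels must be registered when constructing the
    # paged KV cache. Before this fix, the second bb.add_func call was
    # missing, which broke TIRPagedKVCache construction.
    return [
        bb.add_func(
            tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target),
            "tir_attention_prefill_with_tree_mask",
        ),
        bb.add_func(
            tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target),
            "tir_attention_prefill_with_tree_mask_with_paged_kv_cache",
        ),
    ]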