From 2bf72304e46874e753d6e3851cb21c38ce7a4fcd Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Mon, 7 Jul 2025 15:43:24 -0400
Subject: [PATCH 1/2] [KVCache] Fix kernel dispatch based on attention kinds

This PR fixes a few kernel dispatch issues due to the recent
introduction of `mha_sliding` as a new attention kind.

Tested on Qwen3 1.7B with MLC-LLM.
---
 python/tvm/relax/frontend/nn/llm/kv_cache.py | 48 ++++++++++----------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index a1d742739aca..5843f6b9f986 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -374,19 +374,18 @@ def __init__(  # pylint: disable=too-many-locals
         if rope_mode == RopeMode.INLINE:
             assert rotary_dim == qk_head_dim, "FlashInfer RoPE does not support partial rotary dim."
+        attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
+        if attn_kind_single == "mha_sliding":
+            attn_kind_single = "mha"
         flashinfer_prefill_mods = rx.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
             dtype_q=dtype,
             dtype_kv=dtype,
             dtype_o=dtype,
             qk_head_dim=(
-                qk_head_dim
-                if (attn_kind == "mha" or isinstance(attn_kind, List))
-                else mla_original_qk_head_dim
+                qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim
             ),
             v_head_dim=(
-                v_head_dim
-                if (attn_kind == "mha" or isinstance(attn_kind, List))
-                else mla_original_v_head_dim
+                v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim
             ),
             target=target,
             enable_inline_rope=rope_mode == RopeMode.INLINE,
         )
@@ -400,7 +399,7 @@ def __init__(  # pylint: disable=too-many-locals
                 v_head_dim=v_head_dim,
                 target=target,
             )
-            if (attn_kind == "mha" or isinstance(attn_kind, List))
+            if attn_kind_single == "mha"
             else []
         )
         flashinfer_mla_mods = (
@@ -412,7 +411,7 @@ def __init__(  # pylint: disable=too-many-locals
                 head_dim_kpe=qk_head_dim - v_head_dim,
                 target=target,
             )
-            if attn_kind == "mla"
+            if attn_kind_single == "mla"
             else []
         )
         self.extern_mods = flashinfer_prefill_mods + flashinfer_decode_mods + flashinfer_mla_mods
@@ -429,21 +428,21 @@ def __init__(  # pylint: disable=too-many-locals
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
             ]
-            if (attn_kind == "mha" or isinstance(attn_kind, List))
+            if attn_kind_single == "mha"
             else [rx.Tuple([]) for _ in range(6)]
         )
-        mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else [])
+        mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind_single == "mla" else [])
         attn_merge_functions = [
             bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
         ]
-        if attn_kind == "mla":
+        if attn_kind_single == "mla":
             attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))
-
         if isinstance(attn_kind, List):
             attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
         else:
             attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
+
         args = [
             rx.ShapeExpr(
                 [
@@ -459,9 +458,7 @@ def __init__(  # pylint: disable=too-many-locals
             rx.PrimValue(num_key_value_heads),
             rx.PrimValue(qk_head_dim),
             rx.PrimValue(v_head_dim),
-            rx.ShapeExpr(
-                [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
-            ),
+            rx.ShapeExpr(attn_kind),
             rx.PrimValue(enable_disaggregation),
             rx.PrimValue(rope_mode),
             rx.PrimValue(rope_scale),
@@ -475,7 +472,7 @@ def __init__(  # pylint: disable=too-many-locals
             mla_function,
             rx.Tuple(attn_merge_functions),
             bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-            bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+            bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
             bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
             bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
             # fmt: on
@@ -567,6 +564,9 @@ def __init__(  # pylint: disable=too-many-locals
         target : Target
             The target to build the model to.
         """
+        attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
+        if attn_kind_single == "mha_sliding":
+            attn_kind_single = "mha"
         if isinstance(attn_kind, List):
             attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
         else:
@@ -605,7 +605,7 @@ def __init__(  # pylint: disable=too-many-locals
         ]

         if str(target.kind) == "llvm":
-            if attn_kind == "mla":
+            if attn_kind_single == "mla":
                 raise ValueError("MLA is not supported in TIR kernels for now.")
             # pylint: disable=line-too-long
             # fmt: off
@@ -631,9 +631,9 @@ def __init__(  # pylint: disable=too-many-locals
         else:
             # pylint: disable=line-too-long
             # fmt: off
-            ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim
-            ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim
-            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
+            ragged_qk_head_dim = qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim
+            ragged_v_head_dim = v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim
+            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind_single == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
             mha_functions = (
                 [
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]),
@@ -643,14 +643,14 @@ def __init__(  # pylint: disable=too-many-locals
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
                 ]
-                if (attn_kind == "mha" or isinstance(attn_kind, List))
+                if attn_kind_single == "mha"
                 else [rx.Tuple([]) for _ in range(6)]
             )
-            mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else [])
+            mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind_single == "mla" else [])
             attn_merge_functions = [
                 bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
             ]
-            if attn_kind == "mla":
+            if attn_kind_single == "mla":
                 attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))
             args.extend(mha_functions)
             args.append(mla_function)
@@ -658,7 +658,7 @@ def __init__(  # pylint: disable=too-many-locals
                 [
                     rx.Tuple(attn_merge_functions),
                     bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-                    bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+                    bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
                     bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
                     bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
                 ]

From 6564cf03e4a696c3c6845ddbad226f42d5c82745 Mon Sep 17 00:00:00 2001
From: Yong Wu
Date: Mon, 7 Jul 2025 16:59:07 -0500
Subject: [PATCH 2/2] Fix lint

---
 python/tvm/relax/frontend/nn/llm/kv_cache.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 5843f6b9f986..e6e171da9903 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -381,12 +381,8 @@ def __init__(  # pylint: disable=too-many-locals
             dtype_q=dtype,
             dtype_kv=dtype,
             dtype_o=dtype,
-            qk_head_dim=(
-                qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim
-            ),
-            v_head_dim=(
-                v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim
-            ),
+            qk_head_dim=(qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim),
+            v_head_dim=(v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim),
             target=target,
             enable_inline_rope=rope_mode == RopeMode.INLINE,
         )
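
For readers following the dispatch change outside the TVM codebase, the sketch below distills the rule both patches implement: a single representative kind (`attn_kind_single`) decides which kernel family gets built, with `mha_sliding` folded into `mha`, while the per-layer kinds are still forwarded as `AttnKind` enum integers. This is a minimal standalone sketch under stated assumptions: the `AttnKind` values and the `normalize_attn_kind` helper are illustrative stand-ins, not the actual TVM API.

from enum import IntEnum
from typing import List, Tuple, Union


class AttnKind(IntEnum):
    """Illustrative stand-in for the AttnKind enum used by the KV cache; values are assumed."""
    MHA = 0
    MLA = 1
    MHA_SLIDING = 2


def normalize_attn_kind(
    attn_kind: Union[str, List[str]], num_hidden_layers: int
) -> Tuple[str, List[int]]:
    """Mirror the patch's dispatch logic: pick one representative kind for
    kernel selection, and convert the per-layer kinds to integer enum values."""
    # A list means per-layer kinds (e.g. sliding-window layers mixed with
    # full-attention layers); the first entry decides the kernel family.
    attn_kind_single = attn_kind[0] if isinstance(attn_kind, list) else attn_kind
    # Sliding-window MHA reuses the regular MHA kernels for dispatch purposes.
    if attn_kind_single == "mha_sliding":
        attn_kind_single = "mha"

    # The per-layer kinds (sliding marker included) are passed on as enum ints.
    if isinstance(attn_kind, list):
        per_layer = [int(getattr(AttnKind, kind.upper())) for kind in attn_kind]
    else:
        per_layer = [int(getattr(AttnKind, attn_kind.upper()))] * num_hidden_layers
    return attn_kind_single, per_layer


if __name__ == "__main__":
    # A per-layer configuration mixing sliding-window and full attention.
    single, per_layer = normalize_attn_kind(
        ["mha_sliding", "mha", "mha_sliding", "mha"], num_hidden_layers=4
    )
    assert single == "mha"            # MHA kernels are built, not MLA
    assert per_layer == [2, 0, 2, 0]  # layer-wise kinds keep the sliding marker

With the kinds normalized once up front, the per-layer list can be handed to the cache constructor directly (the `rx.ShapeExpr(attn_kind)` change in the first patch) instead of being re-derived from a single string at every dispatch site.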