From d7fb5e42f0ece65b4409645d0ccbde54293ecf1e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 13 Feb 2025 16:37:36 -0500 Subject: [PATCH] [KVCache] PagedKVCache refactor, FlashInfer JIT and MLA integration This PR consists of the following parts: * We reorganized `paged_kv_cache.cc` by moving some of the utilities to `attn_utils.h`. * To integrate with the JIT kernel compilation in the latest FlashInfer project, while still being able to support attention kernels written with TIR, we introduced `AttnBackendFunc` in `attn_backend.h`, which exposes attention interfaces (e.g., `MHA`, `MLA`) to PagedKVCache. We subclass `AttnBackendFunc` and implement FlashInfer backends and TIR backends respectively. * With `AttnBackendFunc`, we refactored the PagedKVCache constructor. The new constructor is not backward compatible, and will break the existing compiled model libraries. * For both TIR and FlashInfer attention implementations, now we require an explicit attention softmax scale factor `sm_scale` to be passed in. Previously, it has an inlined `sm_scale` of `head_dim ** -0.5`. Due to the recent LLM inference techniques such as MLA weight absorption in DeepSeek models, the inlined `sm_scale` causes confusion and inconvenience. To keep attention interface standard and clear, we now require the explicit passing of `sm_scale`. * We refactored the existing GPU unit tests of the PagedKVCache, by updating from numpy to PyTorch for std calculation. This significantly reduces the test case run time. --- docs/how_to/tutorials/optimize_llm.py | 10 +- python/tvm/relax/__init__.py | 1 + python/tvm/relax/backend/cuda/__init__.py | 1 + python/tvm/relax/backend/cuda/flashinfer.py | 357 ++++ python/tvm/relax/frontend/nn/llm/kv_cache.py | 1345 ++++-------- python/tvm/relax/frontend/nn/llm/tree_attn.py | 56 +- src/runtime/relax_vm/attn_backend.cc | 125 ++ src/runtime/relax_vm/attn_backend.h | 531 +++++ src/runtime/relax_vm/attn_utils.h | 1027 +++++++++ src/runtime/relax_vm/kv_state.cc | 42 +- src/runtime/relax_vm/kv_state.h | 68 +- src/runtime/relax_vm/paged_kv_cache.cc | 1857 ++++------------- .../test_runtime_builtin_kv_cache_transfer.py | 216 +- ...me_builtin_paged_attention_kv_cache_cpu.py | 38 +- ...tin_paged_attention_kv_cache_flashinfer.py | 291 ++- ...paged_attention_kv_cache_mla_flashinfer.py | 593 ++++++ ...uiltin_paged_attention_kv_cache_mla_tir.py | 314 ++- ...me_builtin_paged_attention_kv_cache_tir.py | 223 +- 18 files changed, 4185 insertions(+), 2910 deletions(-) create mode 100644 python/tvm/relax/backend/cuda/flashinfer.py create mode 100644 src/runtime/relax_vm/attn_backend.cc create mode 100644 src/runtime/relax_vm/attn_backend.h create mode 100644 src/runtime/relax_vm/attn_utils.h create mode 100644 tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py diff --git a/docs/how_to/tutorials/optimize_llm.py b/docs/how_to/tutorials/optimize_llm.py index c30b2c381c8b..49855910fc45 100644 --- a/docs/how_to/tutorials/optimize_llm.py +++ b/docs/how_to/tutorials/optimize_llm.py @@ -191,7 +191,9 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) # Attention output = op.reshape( - paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + paged_kv_cache.attention_with_fused_qkv( + layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5 + ), (b, s, h_q * d), ) # Output Projection @@ -285,6 +287,7 @@ def create_tir_paged_kv_cache( page_size: tir.Var, ) -> PagedKVCache: 
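# A minimal sketch (not part of this patch) of the softmax scale that callers are now
# expected to pass explicitly, as in the tutorial hunk above (sm_scale=self.head_dim**-0.5).
# It assumes the usual 1/sqrt(d) convention described in the commit message; for MLA weight
# absorption the scale comes from the pre-absorption query/key head dim rather than the
# absorbed (compressed) dim, which is why an inlined head_dim ** -0.5 no longer suffices.
# The helper names and dimensions below are illustrative only.
def mha_sm_scale(head_dim: int) -> float:
    return head_dim ** -0.5

def mla_absorbed_sm_scale(qk_nope_head_dim: int, qk_rope_head_dim: int) -> float:
    # Scale by the original qk head dim (nope + rope parts), e.g. 128 + 64.
    return (qk_nope_head_dim + qk_rope_head_dim) ** -0.5

assert abs(mha_sm_scale(128) - 1.0 / 128 ** 0.5) < 1e-12
assert abs(mla_absorbed_sm_scale(128, 64) - 1.0 / 192 ** 0.5) < 1e-12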
return TIRPagedKVCache( + attn_kind="mha", max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, @@ -294,7 +297,10 @@ def create_tir_paged_kv_cache( num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads, num_key_value_heads=self.num_key_value_heads, - head_dim=self.head_dim, + qk_head_dim=self.head_dim, + v_head_dim=self.head_dim, + mla_original_qk_head_dim=0, + mla_original_v_head_dim=0, rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 471a4ba9d337..8494bd8e5838 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -91,6 +91,7 @@ ) # pipeline +from .pipeline import get_default_pipeline from .pipeline import get_pipeline from .pipeline import register_pipeline diff --git a/python/tvm/relax/backend/cuda/__init__.py b/python/tvm/relax/backend/cuda/__init__.py index f4458f4b55d1..df1a1f98e376 100644 --- a/python/tvm/relax/backend/cuda/__init__.py +++ b/python/tvm/relax/backend/cuda/__init__.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """The Relax CUDA backend compilation pipeline and other passes.""" +from . import flashinfer from .pipeline import ( finalize_passes, get_default_pipeline, diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py new file mode 100644 index 000000000000..725fd105add0 --- /dev/null +++ b/python/tvm/relax/backend/cuda/flashinfer.py @@ -0,0 +1,357 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
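# A hedged sketch of how the new head-dim arguments above are meant to be filled in.
# For plain MHA (as in the tutorial diff), qk_head_dim and v_head_dim are both the model
# head_dim and the mla_original_* fields are 0. For MLA, qk_head_dim / v_head_dim describe
# the absorbed KV-cache layout, while mla_original_* describe the pre-absorption layout
# used by the ragged prefill kernels. The DeepSeek-style numbers here are illustrative only.
kv_lora_rank, qk_rope_head_dim = 512, 64
qk_nope_head_dim, v_head_dim = 128, 128

mla_dims = dict(
    attn_kind="mla",
    qk_head_dim=kv_lora_rank + qk_rope_head_dim,                    # 576: compressed kv + rope part
    v_head_dim=kv_lora_rank,                                        # 512: output dim after absorption
    mla_original_qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,   # 192: pre-absorption q/k dim
    mla_original_v_head_dim=v_head_dim,                             # 128: pre-absorption v dim
)
# Mirrors how the FlashInfer MLA module is wired: head_dim_kpe = qk_head_dim - v_head_dim.
assert mla_dims["qk_head_dim"] - mla_dims["v_head_dim"] == qk_rope_head_dim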
+ +"""FlashInfer JIT compilation module for CUDA backend""" +import os +import subprocess +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import List + +import tvm +from tvm.target import Target + + +def _compile_flashinfer_kernels( + name: str, source_paths: List[Path], target: Target, num_threads: int +) -> List[Path]: + from flashinfer.jit.env import ( # pylint: disable=import-outside-toplevel + CUTLASS_INCLUDE_DIRS, + FLASHINFER_CSRC_DIR, + FLASHINFER_INCLUDE_DIR, + FLASHINFER_JIT_DIR, + FLASHINFER_TVM_BINDING_DIR, + ) + + # Todo(tvm-team): enable compilation cache + # ------------------------------------------------------------------------ + # 1) Common CUDA compile flags + # ------------------------------------------------------------------------ + cuda_cflags = [ + "-O3", + "-std=c++17", + "--threads", + str(num_threads), + "-g", + "-use_fast_math", + "--expt-relaxed-constexpr", + # DMLC default + "-DDMLC_USE_FOPEN64=0", + "-DDMLC_USE_LOGGING_LIBRARY=", + # Enable `-fPIC` for the host compiler + "-Xcompiler=-fPIC", + "-DFLASHINFER_ENABLE_F16", + "-DFLASHINFER_ENABLE_BF16", + "-DFLASHINFER_ENABLE_FP8_E4M3", + "-DFLASHINFER_ENABLE_FP8_E5M2", + ] + + # Determine compute version + compute_version = "".join(tvm.contrib.nvcc.get_target_compute_version(target).split(".")) + if compute_version in ["90"]: + compute_version += "a" + cuda_cflags += [ + "-gencode", + f"arch=compute_{compute_version},code=sm_{compute_version}", + ] + + # ------------------------------------------------------------------------ + # 2) Include paths + # ------------------------------------------------------------------------ + tvm_home = os.environ["TVM_SOURCE_DIR"] + include_paths = [ + FLASHINFER_INCLUDE_DIR, + FLASHINFER_CSRC_DIR, + FLASHINFER_TVM_BINDING_DIR, + Path(tvm_home).resolve() / "include", + Path(tvm_home).resolve() / "3rdparty" / "dlpack" / "include", + Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include", + ] + CUTLASS_INCLUDE_DIRS + + # Where object files will be placed + build_directory = FLASHINFER_JIT_DIR / name + build_directory.mkdir(parents=True, exist_ok=True) + + # ------------------------------------------------------------------------ + # 3) Function to compile a single source file + # ------------------------------------------------------------------------ + def compile_single_source(src: Path) -> Path: + # Derive the .o filename from the source filename + obj_name = src.stem + ".o" + obj_path = build_directory / obj_name + + # Construct the command + cmd = ( + ["nvcc"] + + cuda_cflags + + [f"-I{inc_path}" for inc_path in include_paths] + + ["-c", "-o", str(obj_path), str(src)] + ) + + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = proc.communicate() + if proc.returncode != 0: + raise RuntimeError( + f"FlashInfer JIT compilation failed for {src}\n" + f"Command: {' '.join(cmd)}\n" + f"stdout:\n{out.decode('utf-8')}\n" + f"stderr:\n{err.decode('utf-8')}" + ) + return obj_path + + # ------------------------------------------------------------------------ + # 4) Compile each source in parallel using ThreadPoolExecutor + # ------------------------------------------------------------------------ + object_files = [] + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(compile_single_source, src) for src in source_paths] + for f in futures: + object_files.append(f.result()) # Will raise if there's a compilation error + + # Return list of generated object files 
for any further linking steps + return object_files + + +def _load_flashinfer_modules(object_files: List[Path]) -> List[tvm.runtime.Module]: + return [ + tvm.runtime.load_static_library(str(obj_path.absolute()), func_names=[]) + for obj_path in object_files + ] + + +def gen_flashinfer_prefill_module( + dtype_q: str, + dtype_kv: str, + dtype_o: str, + qk_head_dim: int, + v_head_dim: int, + target: Target, + enable_inline_rope: bool = True, + num_threads: int = 8, +) -> List[tvm.runtime.Module]: + """Generate a FlashInfer module for prefill. + + Parameters + ---------- + dtype_q : str + The data type of the query tensor. + dtype_kv : str + The data type of the key/value tensors. + dtype_o : str + The data type of the output tensor. + qk_head_dim : int + The head dimension of the query and key tensors. + v_head_dim : int + The head dimension of the value tensor. + target : Target + The target device to compile for. + enable_inline_rope : bool + Whether to enable inline rotary positional embedding. + num_threads : int + The number of threads to use for compilation. + + Returns + ------- + A list of compiled static library modules for FlashInfer prefill kernels. + """ + try: + from flashinfer.jit import ( # pylint: disable=import-outside-toplevel + gen_customize_batch_prefill_tvm_binding, + ) + except ImportError: + raise ImportError( + "FlashInfer is not installed. Please follow instructions " + "in https://docs.flashinfer.ai to install FlashInfer." + ) + try: + import torch # pylint: disable=import-outside-toplevel + except ImportError: + raise ImportError("PyTorch is not installed. Please install PyTorch to use FlashInfer.") + + if enable_inline_rope and qk_head_dim != v_head_dim: + raise ValueError("Inline rope mode is not supported when qk_head_dim == v_head_dim") + + torch_dtype_q = getattr(torch, dtype_q) + torch_dtype_kv = getattr(torch, dtype_kv) + torch_dtype_o = getattr(torch, dtype_o) + # Todo(tvm-team): decide which backend ("fa2/fa3") to use + backend = "fa2" + variant_name = ( + "DefaultAttention" + if backend == "fa2" + else "DefaultAttention" + ) + variant_decl = ( + "#include " + if backend == "fa2" + else "#include " + ) + jit_args = { + "backend": backend, + "uri": "batch_prefill_tvm", + "dtype_q": torch_dtype_q, + "dtype_kv": torch_dtype_kv, + "dtype_o": torch_dtype_o, + "idtype": torch.int32, + "head_dim_qk": qk_head_dim, + "head_dim_vo": v_head_dim, + "additional_tensor_names": [], + "additional_tensor_dtypes": [], + "additional_scalar_names": ["sm_scale", "rope_rcp_scale", "rope_rcp_theta"], + "additional_scalar_dtypes": ["double", "double", "double"], + "variant_name": variant_name, + "variant_decl": variant_decl, + "enable_inline_rope": enable_inline_rope, + } + uri, source_paths = gen_customize_batch_prefill_tvm_binding(**jit_args) + object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) + modules = _load_flashinfer_modules(object_files) + return modules + + +def gen_flashinfer_decode_module( + dtype_q: str, + dtype_kv: str, + dtype_o: str, + qk_head_dim: int, + v_head_dim: int, + target: Target, + num_threads: int = 8, +) -> List[tvm.runtime.Module]: + """Generate a FlashInfer module for decode. + + Parameters + ---------- + dtype_q : str + The data type of the query tensor. + dtype_kv : str + The data type of the key/value tensors. + dtype_o : str + The data type of the output tensor. + qk_head_dim : int + The head dimension of the query and key tensors. + v_head_dim : int + The head dimension of the value tensor. 
+ target : Target + The target device to compile for. + num_threads : int + The number of threads to use for compilation. + + Returns + ------- + A list of compiled static library modules for FlashInfer decode kernels. + """ + try: + from flashinfer.jit import ( # pylint: disable=import-outside-toplevel + gen_customize_batch_decode_tvm_binding, + ) + except ImportError: + raise ImportError( + "FlashInfer is not installed. Please follow instructions " + "in https://docs.flashinfer.ai to install FlashInfer." + ) + try: + import torch # pylint: disable=import-outside-toplevel + except ImportError: + raise ImportError("PyTorch is not installed. Please install PyTorch to use FlashInfer.") + + torch_dtype_q = getattr(torch, dtype_q) + torch_dtype_kv = getattr(torch, dtype_kv) + torch_dtype_o = getattr(torch, dtype_o) + jit_args = { + "uri": "batch_decode_tvm", + "dtype_q": torch_dtype_q, + "dtype_kv": torch_dtype_kv, + "dtype_o": torch_dtype_o, + "idtype": torch.int32, + "head_dim_qk": qk_head_dim, + "head_dim_vo": v_head_dim, + "additional_tensor_names": [], + "additional_tensor_dtypes": [], + "additional_scalar_names": ["sm_scale", "rope_rcp_scale", "rope_rcp_theta"], + "additional_scalar_dtypes": ["double", "double", "double"], + "variant_name": "DefaultAttention", + "variant_decl": "#include ", + } + uri, source_paths = gen_customize_batch_decode_tvm_binding(**jit_args) + object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) + modules = _load_flashinfer_modules(object_files) + return modules + + +def gen_flashinfer_mla_module( + dtype_q: str, + dtype_kv: str, + dtype_o: str, + head_dim_ckv: int, + head_dim_kpe: int, + target: Target, + num_threads: int = 8, +) -> List[tvm.runtime.Module]: + """Generate a FlashInfer module for MLA. + + Parameters + ---------- + dtype_q : str + The data type of the query tensor. + dtype_kv : str + The data type of the key/value tensors. + dtype_o : str + The data type of the output tensor. + head_dim_ckv : int + The head dimension of the compressed key/value tensors. + head_dim_kpe : int + The head dimension of the query/key positional embedding. + target : Target + The target device to compile for. + num_threads : int + The number of threads to use for compilation. + + Returns + ------- + A list of compiled static library modules for FlashInfer MLA kernels. + """ + try: + from flashinfer.jit import ( # pylint: disable=import-outside-toplevel + gen_batch_mla_tvm_binding, + ) + except ImportError: + raise ImportError( + "FlashInfer is not installed. Please follow instructions " + "in https://docs.flashinfer.ai to install FlashInfer." + ) + try: + import torch # pylint: disable=import-outside-toplevel + except ImportError: + raise ImportError("PyTorch is not installed. 
Please install PyTorch to use FlashInfer.") + + torch_dtype_q = getattr(torch, dtype_q) + torch_dtype_kv = getattr(torch, dtype_kv) + torch_dtype_o = getattr(torch, dtype_o) + jit_args = { + "uri": "batch_mla_tvm", + "dtype_q": torch_dtype_q, + "dtype_kv": torch_dtype_kv, + "dtype_o": torch_dtype_o, + "dtype_idx": torch.int32, + "head_dim_ckv": head_dim_ckv, + "head_dim_kpe": head_dim_kpe, + } + uri, source_paths = gen_batch_mla_tvm_binding(**jit_args) + object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) + modules = _load_flashinfer_modules(object_files) + return modules diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index ea6f15331654..1d06bf2f3595 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -20,8 +20,9 @@ # pylint: disable=too-many-statements,too-many-lines,too-many-arguments,invalid-name import enum import math -from typing import Any, Dict, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple +import tvm from tvm import relax as rx from tvm import tir from tvm.relax.frontend.nn import Object, Tensor @@ -102,12 +103,14 @@ class RopeMode(enum.IntEnum): class PagedKVCache(Object): # pylint: disable=too-few-public-methods """The Paged KV Cache used in LLM batching for efficient attention computation.""" + extern_mods: List[tvm.runtime.Module] = [] + def attention_with_fused_qkv( self, layer_id: int, qkv: Tensor, num_qo_heads: int, - attn_score_scaling_factor: float = 1.0, + sm_scale: float, ) -> Tensor: """Compute attention with the given fused q/k/v data and in-cache k/v data on the specified layer. Rotary position embeddings are applied to k/v @@ -131,7 +134,7 @@ def attention_with_fused_qkv( [ self._expr, rx.PrimValue(layer_id), # type: ignore[arg-type] - rx.PrimValue(attn_score_scaling_factor), + rx.PrimValue(sm_scale), qkv._expr, ], out_sinfo=rx.TensorStructInfo((b * s, num_qo_heads, d), qkv.dtype), @@ -139,89 +142,131 @@ def attention_with_fused_qkv( ) ).reshape(b, s, num_qo_heads, d) - def mla_absorbed( + def self_attention( # pylint: disable=too-many-locals self, layer_id: int, q: Tensor, - compressed_kv: Tensor, - k_pe: Tensor, - attn_score_scaling_factor: float = 1.0, - ) -> Tensor: - """Compute multi-head latent attention with the given data - on the specified layer with the weight absorption optimization. - - - For prefill, the input q/kv and output tensor have shape - (1, total_seq_len) for the first two dimensions. - - For decode, the input q/kv and output tensor have shape - (batch_size, 1) for the first two dimensions. 
- """ + k: Tensor, + v: Tensor, + sm_scale: float, + ) -> Tuple[Tensor, Tensor]: + """Fine-grained API that computes ragged self attention with Q/K/V data.""" # pylint: disable=protected-access b, s, h_qo, d_qk = q._expr.struct_info.shape - kv_lora_rank = compressed_kv._expr.struct_info.shape[3] - qk_rope_head_dim = k_pe._expr.struct_info.shape[3] + _, _, h_kv, d_v = v._expr.struct_info.shape q = q.reshape(b * s, h_qo, d_qk) - compressed_kv = compressed_kv.reshape(b * s, kv_lora_rank) - k_pe = k_pe.reshape(b * s, qk_rope_head_dim) - - return Tensor( - _expr=rx.BlockBuilder.current().emit( - rx.call_dps_packed( - "vm.builtin.attention_kv_cache_mla_absorbed", - [ - self._expr, - rx.PrimValue(layer_id), # type: ignore[arg-type] - rx.PrimValue(attn_score_scaling_factor), - q._expr, - compressed_kv._expr, - k_pe._expr, - ], - out_sinfo=rx.TensorStructInfo((b * s, h_qo, kv_lora_rank), q.dtype), - ) + k = k.reshape(b * s, h_kv, d_qk) + v = v.reshape(b * s, h_kv, d_v) + bb = rx.BlockBuilder.current() + attn_results = bb.emit( + rx.call_dps_packed( + "vm.builtin.attention_kv_cache_self_attention", + [ + self._expr, + rx.PrimValue(layer_id), # type: ignore[arg-type] + rx.PrimValue(sm_scale), + q._expr, + k._expr, + v._expr, + ], + out_sinfo=[ + rx.TensorStructInfo((b * s, h_qo, d_v), q.dtype), + rx.TensorStructInfo((b * s, h_qo), "float32"), + ], ) - ).reshape(b, s, h_qo, kv_lora_rank) + ) + assert isinstance(attn_results.struct_info, rx.TupleStructInfo) + assert len(attn_results.struct_info.fields) == 2 + o = Tensor(_expr=bb.emit(rx.TupleGetItem(attn_results, 0))).reshape(b, s, h_qo, d_v) + lse = Tensor(_expr=bb.emit(rx.TupleGetItem(attn_results, 1))).reshape(b, s, h_qo) + return o, lse - def mla_normal( + def cross_attention( self, layer_id: int, q: Tensor, - k: Tensor, - v: Tensor, - compressed_kv: Tensor, - k_pe: Tensor, - attn_score_scaling_factor: float = 1.0, - ) -> Tensor: - """Compute multi-head latent attention with the given data - on the specified layer using the normal flow(WITHOUT weight absorption). 
- """ + v_head_dim: int, + sm_scale: float, + ) -> Tuple[Tensor, Tensor]: + """Fine-grained API that computes paged cross attention with Q and in-cache KV data.""" # pylint: disable=protected-access b, s, h_qo, d_qk = q._expr.struct_info.shape - d_v = v._expr.struct_info.shape[3] - kv_lora_rank = compressed_kv._expr.struct_info.shape[3] - qk_rope_head_dim = k_pe._expr.struct_info.shape[3] q = q.reshape(b * s, h_qo, d_qk) - k = k.reshape(b * s, h_qo, d_qk) - v = v.reshape(b * s, h_qo, d_v) - compressed_kv = compressed_kv.reshape(b * s, kv_lora_rank) - k_pe = k_pe.reshape(b * s, qk_rope_head_dim) + bb = rx.BlockBuilder.current() + attn_results = bb.emit( + rx.call_dps_packed( + "vm.builtin.attention_kv_cache_cross_attention", + [ + self._expr, + rx.PrimValue(layer_id), # type: ignore[arg-type] + rx.PrimValue(sm_scale), + q._expr, + ], + out_sinfo=[ + rx.TensorStructInfo((b * s, h_qo, v_head_dim), q.dtype), + rx.TensorStructInfo((b * s, h_qo), "float32"), + ], + ) + ) + assert isinstance(attn_results.struct_info, rx.TupleStructInfo) + assert len(attn_results.struct_info.fields) == 2 + o = Tensor(_expr=bb.emit(rx.TupleGetItem(attn_results, 0))).reshape(b, s, h_qo, v_head_dim) + lse = Tensor(_expr=bb.emit(rx.TupleGetItem(attn_results, 1))).reshape(b, s, h_qo) + return o, lse + + def append_mla_kv(self, layer_id: int, kv: Tensor) -> "PagedKVCache": + """Fine-grained API that appends the MLA K/V data to KV cache.""" + # pylint: disable=protected-access + b, s, _, d_qk = kv._expr.struct_info.shape + kv = kv.reshape(b * s, d_qk) + return PagedKVCache( + _expr=rx.call_pure_packed( + "vm.builtin.attention_kv_cache_append_mla_kv", + self._expr, + rx.PrimValue(layer_id), # type: ignore[arg-type] + kv._expr, + sinfo_args=rx.ObjectStructInfo(), + ), + _name="paged_kv_cache", + ) - return Tensor( - _expr=rx.BlockBuilder.current().emit( - rx.call_dps_packed( - "vm.builtin.attention_kv_cache_mla_normal", - [ - self._expr, - rx.PrimValue(layer_id), # type: ignore[arg-type] - rx.PrimValue(attn_score_scaling_factor), - q._expr, - k._expr, - v._expr, - compressed_kv._expr, - k_pe._expr, - ], - out_sinfo=rx.TensorStructInfo((b * s, h_qo, d_v), q.dtype), - ) + def merge_attn_output_inplace( + self, + o_self_attn: Tensor, + lse_self_attn: Tensor, + o_cross_attn: Tensor, + lse_cross_attn: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Fine-grained API that merges the attention output from two sources. + The first two tensors will be inplace updated. 
+ """ + # pylint: disable=protected-access + b, s, h_qo, d_v = o_self_attn._expr.struct_info.shape + o_self_attn = o_self_attn.reshape(b * s, h_qo, d_v) + lse_self_attn = lse_self_attn.reshape(b * s, h_qo) + o_cross_attn = o_cross_attn.reshape(b * s, h_qo, d_v) + lse_cross_attn = lse_cross_attn.reshape(b * s, h_qo) + bb = rx.BlockBuilder.current() + merge_results = bb.emit( + rx.call_pure_packed( + "vm.builtin.attention_kv_cache_merge_attn_output_inplace", + self._expr, + o_self_attn._expr, + lse_self_attn._expr, + o_cross_attn._expr, + lse_cross_attn._expr, + sinfo_args=rx.TupleStructInfo( + [o_self_attn._expr.struct_info, lse_self_attn._expr.struct_info] + ), ) - ).reshape(b, s, h_qo, d_v) + ) + assert isinstance(merge_results.struct_info, rx.TupleStructInfo) + assert len(merge_results.struct_info.fields) == 2 + o_self_attn = Tensor(_expr=bb.emit(rx.TupleGetItem(merge_results, 0))).reshape( + b, s, h_qo, d_v + ) + lse_self_attn = Tensor(_expr=bb.emit(rx.TupleGetItem(merge_results, 1))).reshape(b, s, h_qo) + return o_self_attn, lse_self_attn def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor: """Get the in-sequence positions of each slot in the query, @@ -256,6 +301,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me def __init__( # pylint: disable=too-many-locals self, + attn_kind: Literal["mha", "mla"], max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, @@ -265,7 +311,10 @@ def __init__( # pylint: disable=too-many-locals num_hidden_layers: int, num_attention_heads: int, num_key_value_heads: int, - head_dim: int, + qk_head_dim: int, + v_head_dim: int, + mla_original_qk_head_dim: int, + mla_original_v_head_dim: int, rope_mode: RopeMode, rope_scale: int, rope_theta: int, @@ -322,9 +371,65 @@ def __init__( # pylint: disable=too-many-locals Whether to enable disaggregation in the KV cache. """ if rope_mode == RopeMode.INLINE: - assert rotary_dim == head_dim, "FlashInfer RoPE does not support partial rotary dim." + assert rotary_dim == qk_head_dim, "FlashInfer RoPE does not support partial rotary dim." 
+ + flashinfer_prefill_mods = rx.backend.cuda.flashinfer.gen_flashinfer_prefill_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + qk_head_dim=qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim, + v_head_dim=v_head_dim if attn_kind == "mha" else mla_original_v_head_dim, + target=target, + enable_inline_rope=rope_mode == RopeMode.INLINE, + ) + flashinfer_decode_mods = ( + rx.backend.cuda.flashinfer.gen_flashinfer_decode_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + qk_head_dim=qk_head_dim, + v_head_dim=v_head_dim, + target=target, + ) + if attn_kind == "mha" + else [] + ) + flashinfer_mla_mods = ( + rx.backend.cuda.flashinfer.gen_flashinfer_mla_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + head_dim_ckv=v_head_dim, + head_dim_kpe=qk_head_dim - v_head_dim, + target=target, + ) + if attn_kind == "mla" + else [] + ) + self.extern_mods = flashinfer_prefill_mods + flashinfer_decode_mods + flashinfer_mla_mods + # fmt: off + # pylint: disable=line-too-long bb = rx.BlockBuilder.current() + mha_functions = ( + [ + rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_with_paged_kv_cache_run"), rx.ExternFunc("batch_prefill_with_kv_cache_plan")]), + rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_decode_with_paged_kv_cache_run"), rx.ExternFunc("batch_decode_with_paged_kv_cache_plan")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), + ] + if attn_kind == "mha" + else [rx.Tuple([]) for _ in range(6)] + ) + mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else []) + attn_merge_functions = [ + bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"), + ] + if attn_kind == "mla": + attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla")) + args = [ rx.ShapeExpr( [ @@ -338,34 +443,27 @@ def __init__( # pylint: disable=too-many-locals layer_partition, rx.PrimValue(num_attention_heads), rx.PrimValue(num_key_value_heads), - rx.PrimValue(head_dim), + rx.PrimValue(qk_head_dim), + rx.PrimValue(v_head_dim), + rx.ShapeExpr( + [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] + ), + rx.PrimValue(enable_disaggregation), rx.PrimValue(rope_mode), rx.PrimValue(rope_scale), rx.PrimValue(rope_theta), - rx.op.zeros((), dtype), - # pylint: disable=line-too-long - # fmt: off - bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), - 
rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache"), - rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window"), - rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), - rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), - rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), - rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache_begin_forward"), - rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache_end_forward"), - rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward"), - rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward"), - rx.extern("flashinfer.merge_state_in_place"), - bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), - bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), - bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), - bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), - bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"), - bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"), rope_ext_factors, - rx.PrimValue(enable_disaggregation), + rx.op.zeros((), dtype), + bb.add_func(_kv_cache_transpose_append(num_key_value_heads, qk_head_dim, dtype), "kv_cache_transpose_append"), + bb.add_func(_kv_cache_transpose_append_mla(qk_head_dim, dtype), "kv_cache_transpose_append_mla"), + rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_with_ragged_kv_cache_run"), rx.ExternFunc("batch_prefill_with_kv_cache_plan")]), + *mha_functions, + mla_function, + rx.Tuple(attn_merge_functions), + bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), + bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"), + bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"), + bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"), # fmt: on # pylint: enable=line-too-long ] @@ -384,6 +482,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods def __init__( # pylint: disable=too-many-locals self, + attn_kind: Literal["mha", "mla"], max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, @@ -393,8 +492,11 @@ def __init__( # pylint: disable=too-many-locals num_hidden_layers: int, num_attention_heads: int, num_key_value_heads: int, + 
qk_head_dim: int, + v_head_dim: int, + mla_original_qk_head_dim: int, + mla_original_v_head_dim: int, rope_mode: RopeMode, - head_dim: int, rope_scale: int, rope_theta: int, rope_scaling: Dict[str, Any], @@ -466,37 +568,45 @@ def __init__( # pylint: disable=too-many-locals layer_partition, rx.PrimValue(num_attention_heads), rx.PrimValue(num_key_value_heads), - rx.PrimValue(head_dim), + rx.PrimValue(qk_head_dim), + rx.PrimValue(v_head_dim), + rx.ShapeExpr( + [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] + ), + rx.PrimValue(enable_disaggregation), rx.PrimValue(rope_mode), rx.PrimValue(rope_scale), rx.PrimValue(rope_theta), + rope_ext_factors, rx.op.zeros((), dtype), # pylint: disable=line-too-long # fmt: off - bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), + bb.add_func(_kv_cache_transpose_append(num_key_value_heads, qk_head_dim, dtype), "kv_cache_transpose_append"), + bb.add_func(_kv_cache_transpose_append_mla(qk_head_dim, dtype), "kv_cache_transpose_append_mla"), # fmt: on # pylint: enable=line-too-long ] if str(target.kind) == "llvm": + if attn_kind == "mla": + raise ValueError("MLA is not supported in TIR kernels for now.") # pylint: disable=line-too-long # fmt: off args.extend( [ - bb.add_func(_attention_prefill_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling), "tir_attention_prefill_cpu"), - bb.add_func(_attention_decode_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling), "tir_attention_decode_cpu"), - bb.add_func(_attention_prefill_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling), "tir_attention_prefill_cpu_sliding_window"), - bb.add_func(_attention_decode_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling), "tir_attention_decode_cpu_sliding_window"), - bb.add_func(_attention_prefill_ragged_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling), "tir_attention_prefill_ragged_cpu"), - bb.add_func(_merge_state_inplace_cpu(dtype), "tir_attention_merge_state_cpu"), - bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), - bb.add_func(_copy_single_page_cpu(num_key_value_heads, page_size, head_dim, dtype), "kv_cache_copy_single_page_cpu"), - bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), - bb.add_func(_compact_kv_copy_cpu(num_key_value_heads, head_dim, dtype), "kv_cache_compact_kv_copy_cpu"), - bb.add_func(tree_attn_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling), "tir_attention_prefill_with_tree_mask_cpu"), - bb.add_func(tree_attn_with_paged_kv_cache_cpu(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache_cpu"), - rope_ext_factors, - rx.PrimValue(enable_disaggregation), + rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, v_head_dim, dtype, rope_scaling), "tir_attention_prefill_ragged_cpu")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling), "tir_attention_prefill_cpu")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_decode_cpu(num_key_value_heads, 
num_attention_heads, qk_head_dim, dtype, False, rope_scaling), "tir_attention_decode_cpu")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling), "tir_attention_prefill_cpu_sliding_window")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_decode_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling), "tir_attention_decode_cpu_sliding_window")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling), "tir_attention_prefill_with_tree_mask_cpu")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache_cpu")]), + rx.Tuple([]), # f_mla_prefill + rx.Tuple([bb.add_func(_merge_state_inplace_cpu(dtype), "tir_attention_merge_state_cpu")]), + bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), + bb.add_func(_copy_single_page_cpu(num_key_value_heads, page_size, qk_head_dim, dtype), "kv_cache_copy_single_page_cpu"), + bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"), + bb.add_func(_compact_kv_copy_cpu(num_key_value_heads, qk_head_dim, dtype), "kv_cache_compact_kv_copy_cpu"), ] ) # fmt: on @@ -504,22 +614,36 @@ def __init__( # pylint: disable=too-many-locals else: # pylint: disable=line-too-long # fmt: off + ragged_qk_head_dim = qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim + ragged_v_head_dim = v_head_dim if attn_kind == "mha" else mla_original_v_head_dim + args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")])) + mha_functions = ( + [ + rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_decode")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), + rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), + ] + if attn_kind == "mha" + else [rx.Tuple([]) for _ in range(6)] + ) + mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, 
target), "tir_attention_prefill_mla")] if attn_kind == "mla" else []) + attn_merge_functions = [ + bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"), + ] + if attn_kind == "mla": + attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla")) + args.extend(mha_functions) + args.append(mla_function) args.extend( [ - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling, target), "tir_attention_decode"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window"), - bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged"), - bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"), - bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), - bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), - bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), - bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), - bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"), - bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"), - rope_ext_factors, - rx.PrimValue(enable_disaggregation), + rx.Tuple(attn_merge_functions), + bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), + bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"), + bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"), + bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"), ] ) # fmt: on @@ -527,130 +651,7 @@ def __init__( # pylint: disable=too-many-locals super().__init__( _expr=rx.call_pure_packed( - "vm.builtin.paged_attention_kv_cache_create_reduced", - *args, - sinfo_args=rx.ObjectStructInfo(), - ), - _name=name, - ) - - @staticmethod - def create_mla_kv_cache( # pylint: disable=too-many-locals - max_batch_size: tir.Var, - max_total_seq_len: tir.Var, - prefill_chunk_size: tir.Var, - page_size: tir.Var, - support_sliding_window: tir.Var, - layer_partition: rx.ShapeExpr, - num_hidden_layers: int, - num_attention_heads: int, - 
num_key_value_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - kv_lora_rank: int, - enable_disaggregation: bool, - dtype: str, - target: Target, - name: str = "paged_kv_cache", - ) -> PagedKVCache: - """Create a paged KV cache object with TIR kernels with multi-head latent attention. - - Parameters - ---------- - max_batch_size : tir.Var - The maximum allowed batch size of the KV cache. - It is a symbolic variable whose concrete value is specified - at runtime. - max_total_seq_len : tir.Var - The maximum allowed total sequence length of the KV cache. - It is a symbolic variable whose concrete value is specified - at runtime. - prefill_chunk_size : tir.Var - The maximum total sequence length in a prefill. - It is a symbolic variable whose concrete value is specified - at runtime. - page_size : tir.Var - The size (a.k.a. number of tokens) of each page. - It is a symbolic variable whose concrete value is specified - at runtime. - support_sliding_window : tir.Var - 0 or 1, denoting whether the KV cache supports sliding window. - It is a symbolic variable whose concrete value is specified - at runtime. - layer_partition : rx.ShapeExpr - The KV cache layer partition for pipeline stages. - It is an indptr array, denoting the starting layer of each pipeline stage. - qk_nope_head_dim : int - The head dim size (RoPE excluded) for queries and keys in MLA. - qk_rope_head_dim : int - The head dim size (RoPE included) for queries and keys in MLA. - v_head_dim : int - The head dim size for values in MLA. - kv_lora_rank : int - The LoRA rank for keys and values in MLA. - enable_disaggregation : bool - Whether to enable disaggregation in the KV cache. - target : Target - The target to build the model to. - """ - - bb = rx.BlockBuilder.current() - args = [ - rx.ShapeExpr( - [ - max_batch_size, - max_total_seq_len, - prefill_chunk_size, - page_size, - support_sliding_window, - ] - ), - layer_partition, - rx.PrimValue(num_attention_heads), - rx.PrimValue(1), - rx.PrimValue(kv_lora_rank + qk_rope_head_dim), - rx.PrimValue(kv_lora_rank), - rx.PrimValue(qk_rope_head_dim), - rx.ShapeExpr([int(AttnKind.MLA) for _ in range(num_hidden_layers)]), - rx.PrimValue(RopeMode.NONE), - rx.PrimValue(1), - rx.PrimValue(10000), - rx.op.zeros((), dtype), - # pylint: disable=line-too-long - # fmt: off - bb.add_func(_kv_cache_transpose_append(num_key_value_heads, v_head_dim, dtype), "kv_cache_transpose_append"), - bb.add_func(_kv_cache_transpose_append_mla(kv_lora_rank, qk_rope_head_dim, dtype), "kv_cache_transpose_append_mla"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, False, {}, target), "tir_attention_prefill"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, False, {}, target), "tir_attention_decode"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, True, {}, target), "tir_attention_prefill_sliding_window"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, True, {}, target), "tir_attention_decode_sliding_window"), - bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, {}, target), "tir_attention_prefill_ragged"), - rx.PrimValue(0), - rx.PrimValue(0), - rx.PrimValue(0), - rx.PrimValue(0), - rx.PrimValue(0), - rx.PrimValue(0), - bb.add_func(_attention_prefill_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, 
target), "tir_attention_prefill_mla"), - bb.add_func(_attention_decode_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_decode_mla"), - bb.add_func(_attention_prefill_ragged_generic(num_key_value_heads, num_attention_heads, qk_rope_head_dim, v_head_dim, dtype, {}, target), "tir_attention_prefill_ragged_mla_normal"), - bb.add_func(_attention_prefill_ragged_mla_absorbed(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, target), "tir_attention_prefill_ragged_mla_absorbed"), - bb.add_func(_merge_state_inplace(num_attention_heads, kv_lora_rank, dtype, target), "tir_attention_merge_state"), - bb.add_func(llama_rope_with_position_map(10000, 1, qk_rope_head_dim, num_attention_heads, num_key_value_heads, dtype, {}, None), "tir_split_rotary"), - bb.add_func(_copy_single_page_mla(page_size, kv_lora_rank + qk_rope_head_dim, dtype, target), "kv_cache_copy_single_page_mla"), - bb.add_func(_kv_cache_debug_get_kv_mla(num_hidden_layers, kv_lora_rank + qk_rope_head_dim, dtype), "kv_cache_debug_get_kv_mla"), - bb.add_func(_compact_kv_copy(num_key_value_heads, qk_nope_head_dim, dtype, target), "kv_cache_compact_kv_copy"), - bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, {}, target), "tir_attention_prefill_with_tree_mask"), - bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_nope_head_dim, dtype, {}, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"), - rx.PrimValue(0), - rx.PrimValue(enable_disaggregation), - # fmt: on - # pylint: enable=line-too-long - ] - return PagedKVCache( - _expr=rx.call_pure_packed( - "vm.builtin.paged_attention_kv_cache_create_reduced_mla", + "vm.builtin.paged_attention_kv_cache_create", *args, sinfo_args=rx.ObjectStructInfo(), ), @@ -662,7 +663,7 @@ def create_mla_kv_cache( # pylint: disable=too-many-locals # pylint: disable=too-many-locals -def _kv_cache_transpose_append(num_key_value_heads, head_dim, dtype): +def _kv_cache_transpose_append(num_key_value_heads, head_dim, dtype, page_size: int = 16): """Return the TIR function that appends new k/v data to PagedKVCache.""" # pylint: disable=line-too-long @@ -679,7 +680,7 @@ def tir_kv_cache_transpose_append( num_pages = T.int64() pages_elem_offset = T.int64() position_map_elem_offset = T.int32() - pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset) + pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype, elem_offset=pages_elem_offset) k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype) v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), dtype) position_map = T.match_buffer( @@ -690,22 +691,22 @@ def tir_kv_cache_transpose_append( with T.block("k_transpose_append"): vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) - T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) + T.writes(pages[position_map[vgpos] // page_size, 0, vh, position_map[vgpos] % page_size, vf]) position: T.int32 = position_map[vgpos] # type: ignore - pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_data[vgpos, vh, vf] + pages[T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf] = k_data[vgpos, vh, vf] with T.block("v_transpose_append"): vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) 
T.reads(position_map[vgpos], v_data[vgpos, vh, vf]) - T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) + T.writes(pages[position_map[vgpos] // page_size, 1, vh, position_map[vgpos] % page_size, vf]) position: T.int32 = position_map[vgpos] # type: ignore[name-defined,no-redef] - pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_data[vgpos, vh, vf] + pages[T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf] = v_data[vgpos, vh, vf] # fmt: on # pylint: enable=line-too-long return tir_kv_cache_transpose_append -def _kv_cache_transpose_append_mla(kv_lora_rank: int, qk_rope_head_dim: int, dtype): +def _kv_cache_transpose_append_mla(d_qk: int, dtype, page_size: int = 16): """Return the TIR function that appends new compressed KV data to PagedKVCache for MLA.""" # pylint: disable=line-too-long @@ -713,8 +714,7 @@ def _kv_cache_transpose_append_mla(kv_lora_rank: int, qk_rope_head_dim: int, dty @T.prim_func def tir_kv_cache_transpose_append_mla( var_pages: T.handle, - var_compressed_kv_data: T.handle, - var_k_pe_data: T.handle, + var_kv_data: T.handle, var_position_map: T.handle, ): T.func_attr({"tir.noalias": T.bool(True)}) @@ -722,20 +722,19 @@ def tir_kv_cache_transpose_append_mla( num_pages = T.int64() pages_elem_offset = T.int64() position_map_elem_offset = T.int32() - pages = T.match_buffer(var_pages, (num_pages, 16, kv_lora_rank + qk_rope_head_dim), dtype, elem_offset=pages_elem_offset) - compressed_kv_data = T.match_buffer(var_compressed_kv_data, (ntoken, kv_lora_rank), dtype) - k_pe_data = T.match_buffer(var_k_pe_data, (ntoken, qk_rope_head_dim), dtype) + pages = T.match_buffer(var_pages, (num_pages, page_size, d_qk), dtype, elem_offset=pages_elem_offset) + kv_data = T.match_buffer(var_kv_data, (ntoken, d_qk), dtype) position_map = T.match_buffer( var_position_map, (ntoken,), "int32", elem_offset=position_map_elem_offset ) - for global_pos, f in T.grid(ntoken, kv_lora_rank + qk_rope_head_dim): + for global_pos, f in T.grid(ntoken, d_qk): if position_map[global_pos] != T.int32(-1): with T.block("k_transpose_append"): vgpos, vf = T.axis.remap("SS", [global_pos, f]) - T.reads(position_map[vgpos], compressed_kv_data[vgpos, vf], k_pe_data[vgpos, vf - kv_lora_rank]) - T.writes(pages[position_map[vgpos] // 16, position_map[vgpos] % 16, vf]) + T.reads(position_map[vgpos], kv_data[vgpos, vf]) + T.writes(pages[position_map[vgpos] // page_size, position_map[vgpos] % page_size, vf]) position: T.int32 = position_map[vgpos] # type: ignore - pages[T.floordiv(position, 16), T.floormod(position, 16), vf] = T.if_then_else(vf < kv_lora_rank, compressed_kv_data[vgpos, vf], k_pe_data[vgpos, vf - kv_lora_rank]) + pages[T.floordiv(position, page_size), T.floormod(position, page_size), vf] = kv_data[vgpos, vf] # fmt: on # pylint: enable=line-too-long @@ -886,18 +885,18 @@ def _get_seq_offset(pos, seq_id, length_info, sliding_window): ) -def _attention_prefill_cpu(h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any]): +def _attention_prefill_cpu( + h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any], page_size: int = 16 +): global_symbol = "batch_prefill_paged_kv_cpu" if sliding_window: global_symbol += "_sliding_window" group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) # pylint: disable=line-too-long,too-many-branches # fmt: off @T.prim_func def batch_prefill_paged_kv_cpu( - _0: T.int32, # pylint: disable=unused-argument var_q: T.handle, # [total_len, h_q, d] 
var_q_indptr: T.handle, # [batch_size + 1] var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] @@ -912,7 +911,7 @@ def batch_prefill_paged_kv_cpu( rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, - attn_score_scaling_factor: T.float32, + sm_scale: T.float32, ): T.func_attr({"global_symbol": global_symbol}) batch_size = T.int32(is_size_var=True) @@ -928,7 +927,7 @@ def batch_prefill_paged_kv_cpu( q = T.match_buffer(var_q, (total_len, h_q, d), dtype) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) - pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) + pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, page_size, d), dtype) page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) @@ -964,10 +963,10 @@ def batch_prefill_paged_kv_cpu( factor = T.alloc_buffer((1, ), "float32") cur_page_indptr_begin: T.int32 = page_indptr[b_idx] cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] - #max_kv_len: T.int32 = max_num_pages * 16 + #max_kv_len: T.int32 = max_num_pages * page_size kv_chunk_len[0] = T.if_then_else( cur_page_indptr_begin != cur_page_indptr_end, - _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window), + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, page_size, b_idx, length_info, sliding_window), 0 ) @@ -987,12 +986,12 @@ def batch_prefill_paged_kv_cpu( _rope(q, q_rope_position[curl_q], d, rope_theta, rope_scale, (curl_q, h_qo, d_idx), dtype, rope_scaling), q[curl_q, h_qo, d_idx] ) - for row_idx in T.serial(max_num_pages * 16): + for row_idx in T.serial(max_num_pages * page_size): if row_idx < kv_chunk_len[0]: # seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window) #seq_offset: T.int32(is_size_var=True) = row_idx - page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + (_get_seq_offset(row_idx, b_idx, length_info, sliding_window) // 16)] - page_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window) % 16 + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + (_get_seq_offset(row_idx, b_idx, length_info, sliding_window) // page_size)] + page_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window) % page_size # Load KV for d_idx in T.serial(d): @@ -1004,11 +1003,11 @@ def batch_prefill_paged_kv_cpu( V_local[d_idx] = pages[page_no, 1, h_qo // group_size, page_offset, d_idx] # Compute S - # Q[i] * K[i] * attn_score * sm_scale + # Q[i] * K[i] * sm_scale S_val[0] = 0.0 for d_idx in T.serial(d): S_val[0] += Q_local[d_idx] * K_local[d_idx] - S_val[0] *= attn_score_scaling_factor * sm_scale + S_val[0] *= sm_scale * math.log2(math.exp(1)) # update m_val, d_val , O_local if _causal_mask(causal, @@ -1045,7 +1044,6 @@ def _get_prefill_kernel_config(h_kv, h_q, d, dtype, target: Target): NUM_BLKS = 16 LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) bdx = 32 num_warps = 4 @@ -1076,7 +1074,7 @@ def _get_prefill_kernel_config(h_kv, h_q, d, dtype, target: Target): 
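# Why the kernels above multiply scores by sm_scale * math.log2(math.exp(1)): the softmax
# inside the flash-attention style loop is typically evaluated with base-2 exponentials,
# and exp(x) == 2 ** (x * log2(e)). With the inlined 1/sqrt(d) * log2(e) constant removed,
# sm_scale is supplied by the caller and only the log2(e) factor stays inside the kernel.
# A small numeric sanity check with illustrative values:
import math

d, score = 128, 3.7
sm_scale = d ** -0.5                          # now passed in by the caller
log2e = math.log2(math.exp(1))
assert abs(2 ** (score * sm_scale * log2e) - math.exp(score * sm_scale)) < 1e-12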
check_thread_limits(target, bdx=bdx, bdy=num_warps, bdz=1, gdz=1) - return NUM_BLKS, LOAD_VEC, group_size, sm_scale, bdx, num_warps, tile_x, tile_y, tile_z + return NUM_BLKS, LOAD_VEC, group_size, bdx, num_warps, tile_x, tile_y, tile_z def _schedule_prefill_kernel( @@ -1187,13 +1185,19 @@ def apply_to_md(sch, block): def _attention_prefill( - h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any], target: Target + h_kv, + h_q, + d, + dtype, + sliding_window: bool, + rope_scaling: Dict[str, Any], + target: Target, + page_size: int = 16, ): ( NUM_BLKS, LOAD_VEC, group_size, - sm_scale, bdx, num_warps, tile_x, @@ -1209,7 +1213,6 @@ def _attention_prefill( # fmt: off @T.prim_func def batch_prefill_paged_kv( - _0: T.int32, # pylint: disable=unused-argument var_q: T.handle, # [total_len, h_q, d] var_q_indptr: T.handle, # [batch_size + 1] var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] @@ -1224,7 +1227,7 @@ def batch_prefill_paged_kv( rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, - attn_score_scaling_factor: T.float32, + sm_scale: T.float32, ): T.func_attr({"global_symbol": global_symbol}) batch_size = T.int32(is_size_var=True) @@ -1241,7 +1244,7 @@ def batch_prefill_paged_kv( q = T.match_buffer(var_q, (total_len, h_q, d), dtype) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) - pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype, elem_offset=pages_elem_offset) + pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, page_size, d), dtype, elem_offset=pages_elem_offset) page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) @@ -1314,7 +1317,7 @@ def batch_prefill_paged_kv( cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] kv_chunk_len[0] = T.if_then_else( cur_page_indptr_begin != cur_page_indptr_end, - _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window), + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, page_size, b_idx, length_info, sliding_window), 0 ) T.tvm_storage_sync("shared") @@ -1360,8 +1363,8 @@ def batch_prefill_paged_kv( cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore - page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, page_size)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, page_size) # type: ignore K_smem[i, j] = T.if_then_else( rotary_mode == 1, _rope(pages, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, page_offset, j), dtype, rope_scaling), @@ -1378,8 +1381,8 @@ def batch_prefill_paged_kv( cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore - page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore - 
page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, page_size)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, page_size) # type: ignore V_smem[i, j] = pages[page_no, 1, by, page_offset, j] else: V_smem[i, j] = 0.0 @@ -1392,7 +1395,7 @@ def batch_prefill_paged_kv( i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): S_local[i, j] = 0.0 - S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale + S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * sm_scale * math.log2(math.exp(1)) T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): with T.block("S_store"): @@ -1490,8 +1493,8 @@ def _attention_decode_cpu( qkv_dtype, sliding_window: bool, rope_scaling: Dict[str, Any], + page_size: int = 16, ): - log2e = math.log2(math.exp(1)) H_qo = num_qo_heads H_kv = num_kv_heads D = head_dim @@ -1501,9 +1504,10 @@ def _attention_decode_cpu( if sliding_window: global_symbol += "_sliding_window" + # fmt: off + # pylint: disable=line-too-long @T.prim_func(check_well_formed=False) def batch_decode_paged_kv( - _0: T.int32, # pylint: disable=unused-argument Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, @@ -1516,7 +1520,7 @@ def batch_decode_paged_kv( rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, - attn_score_scaling_factor: T.float32, + sm_scale: T.float32, ): T.func_attr({"tir.is_scheduled": 1, "global_symbol": global_symbol}) B = T.int32(is_size_var=True) @@ -1528,8 +1532,8 @@ def batch_decode_paged_kv( q_rope_position_elem_offset = T.int32(is_size_var=True) length_info_elem_offset = T.int32(is_size_var=True) - Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype) # query 值 - pages = T.match_buffer(pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype) + Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype) + pages = T.match_buffer(pages_handle, (max_num_pages, 2, H_kv, page_size, D), qkv_dtype) page_table_indptr = T.match_buffer( page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset ) @@ -1556,8 +1560,6 @@ def batch_decode_paged_kv( var_length_info, B, sliding_window, length_info_elem_offset ) - sm_scale = 1.0 / math.sqrt(float(D)) * log2e - for b in T.serial(B): with T.block("attn"): O_local = T.alloc_buffer((D,), "float32") @@ -1579,13 +1581,7 @@ def batch_decode_paged_kv( kv_chunk_len[0] = T.if_then_else( cur_page_indptr_begin != cur_page_indptr_end, - _get_kv_chunk_len( - cur_page_indptr_end - cur_page_indptr_begin, - 16, - b, - length_info, - sliding_window, - ), + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, page_size, b, length_info, sliding_window), 0, ) @@ -1599,47 +1595,25 @@ def batch_decode_paged_kv( for d in T.serial(D): Q_local[d] = T.if_then_else( rotary_mode == 1, - _rope( - Q, - q_rope_position[b], - head_dim, - rope_theta, - rope_scale, - (b, h_qo, d), - qkv_dtype, - rope_scaling, - ), + _rope(Q, q_rope_position[b], head_dim, rope_theta, rope_scale, (b, h_qo, d), qkv_dtype, rope_scaling), Q[b, h_qo, d], ) for row_idx in T.serial(kv_chunk_len[0]): - seq_offset: T.int32(is_size_var=True) = _get_seq_offset( - row_idx, b, length_info, sliding_window - ) - page_no: T.int32(is_size_var=True) = page_table_values[ - cur_page_indptr_begin + (seq_offset // 16) - ] - page_offset: T.int32(is_size_var=True) = seq_offset % 16 
+ seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b, length_info, sliding_window) + page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + (seq_offset // page_size)] + page_offset: T.int32(is_size_var=True) = seq_offset % page_size for d in T.serial(D): K_local[d] = T.if_then_else( rotary_mode == 1, - _rope( - pages, - k_rope_pos_offset[b] + row_idx, - head_dim, - rope_theta, - rope_scale, - (page_no, 0, h_qo // group_size, page_offset, d), - qkv_dtype, - rope_scaling, - ), + _rope(pages, k_rope_pos_offset[b] + row_idx, head_dim, rope_theta, rope_scale, (page_no, 0, h_qo // group_size, page_offset, d), qkv_dtype, rope_scaling), pages[page_no, 0, h_qo // group_size, page_offset, d], ) S_val[0] = 0.0 for d in T.serial(D): S_val[0] += Q_local[d] * K_local[d] - S_val[0] *= attn_score_scaling_factor * sm_scale + S_val[0] *= sm_scale * math.log2(math.exp(1)) new_m[0] = T.max(m_val[0], S_val[0]) d_val[0] = (d_val[0] * T.exp2(m_val[0] - new_m[0])) + T.exp2( @@ -1662,6 +1636,8 @@ def batch_decode_paged_kv( O_local[d] = O_local[d] / d_val[0] output[b, h_qo, d] = O_local[d] lse[b, h_qo] = m_val[0] + T.log2(d_val[0]) + # fmt: on + # pylint: enable=line-too-long return batch_decode_paged_kv @@ -1674,6 +1650,7 @@ def _attention_decode( sliding_window: bool, rope_scaling: Dict[str, Any], target: Target, + page_size: int = 16, ): qkv_dtype_bytes = 2 H_qo = num_qo_heads @@ -1702,7 +1679,6 @@ def _attention_decode( threads_per_CTA = max(thread_limit, bdx * bdy) bdz = threads_per_CTA // (bdx * bdy) tile_size_per_bdx = TILE_SIZE_PER_BDX if GROUP_SIZE == 1 else 1 - log2e = math.log2(math.exp(1)) check_thread_limits(target, bdx=bdx, bdy=bdy, bdz=bdz, gdz=1) global_symbol = "batch_decode_paged_kv" @@ -1713,7 +1689,6 @@ def _attention_decode( # fmt: off @T.prim_func def batch_decode_paged_kv( - _0: T.int32, # pylint: disable=unused-argument Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, @@ -1726,7 +1701,7 @@ def batch_decode_paged_kv( rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, - attn_score_scaling_factor: T.float32, + sm_scale: T.float32, ): T.func_attr({"tir.is_scheduled": 1, "global_symbol": global_symbol}) B = T.int32(is_size_var=True) @@ -1741,7 +1716,7 @@ def batch_decode_paged_kv( Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype) pages = T.match_buffer( - pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype, elem_offset=pages_elem_offset + pages_handle, (max_num_pages, 2, H_kv, page_size, D), qkv_dtype, elem_offset=pages_elem_offset ) page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset) page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) @@ -1759,8 +1734,6 @@ def batch_decode_paged_kv( # denoting the "last_page_len". 
length_info = _declare_length_info(var_length_info, B, sliding_window, length_info_elem_offset) - sm_scale = 1.0 / math.sqrt(float(D)) * log2e - for bx in T.thread_binding(B, thread="blockIdx.x"): for fused_by_bz in T.thread_binding(H_kv * gdz, thread="blockIdx.y"): for ty in T.thread_binding(bdy, thread="threadIdx.y"): @@ -1797,7 +1770,7 @@ def batch_decode_paged_kv( cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] kv_chunk_len[0] = T.if_then_else( cur_page_indptr_begin != cur_page_indptr_end, - _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, batch_idx, length_info, sliding_window), + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, page_size, batch_idx, length_info, sliding_window), 0 ) @@ -1826,8 +1799,8 @@ def batch_decode_paged_kv( row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore if row_g < kv_chunk_len[0]: seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_g, batch_idx, length_info, sliding_window) # type: ignore - page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(seq_offset, page_size)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, page_size) # type: ignore for vec in T.vectorized(VEC_SIZE): K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else( rotary_mode == 1, @@ -1845,7 +1818,7 @@ def batch_decode_paged_kv( for j in T.serial(bdy * tile_size_per_bdx): # compute S = Q * K * sm_scale for vec in T.vectorized(VEC_SIZE): - QK_local[vec] = T.cast(Q_local[vec], "float32") * T.cast(K_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec], "float32") * attn_score_scaling_factor * sm_scale + QK_local[vec] = T.cast(Q_local[vec], "float32") * T.cast(K_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec], "float32") * sm_scale * math.log2(math.exp(1)) S_reduce_local[0] = 0 for vec in T.unroll(VEC_SIZE): S_reduce_local[0] += QK_local[vec] @@ -1966,7 +1939,9 @@ def merge_state_inplace_cpu( return merge_state_inplace_cpu -def _merge_state_inplace(num_heads, head_dim, v_dtype, target: Target): +def _merge_state_inplace( + num_heads, head_dim, v_dtype, target: Target, global_symbol: Optional[str] = None +): v_dtype_bytes = 2 VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4) bdx = head_dim // VEC_SIZE @@ -2036,17 +2011,19 @@ def merge_state_inplace( # store s S[bx, ty + by * bdy] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] - return merge_state_inplace + func = merge_state_inplace + if global_symbol: + func = func.with_attr("global_symbol", global_symbol) + return func def _attention_sequence_prefill( - h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0 + h_kv, h_q, d, dtype, target: Target, causal=0, sm_scale=1.0 ): # pylint: disable=line-too-long ( _, LOAD_VEC, group_size, - sm_scale, bdx, num_warps, tile_x, @@ -2178,8 +2155,8 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches S_local[i, j] += ( T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") - * attn_score_scaling_factor * sm_scale + * math.log2(math.exp(1)) ) T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): @@ -2293,26 +2270,27 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches return sch.mod["main"].with_attr("tir.is_scheduled", 1) -def 
_attention_prefill_ragged_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]): +def _attention_prefill_ragged_cpu(h_kv, h_q, d_qk, d_v, dtype, rope_scaling: Dict[str, Any]): group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + # fmt: off + # pylint: disable=line-too-long @T.prim_func def batch_prefill_ragged_kv( # pylint: disable=too-many-branches - var_q: T.handle, # [total_len, h_q, d] + var_q: T.handle, # [total_len, h_q, d_qk] var_q_indptr: T.handle, # [batch_size + 1] - var_k: T.handle, # [total_len, h_kv, d] - var_v: T.handle, # [total_len, h_kv, d] + var_k: T.handle, # [total_len, h_kv, d_qk] + var_v: T.handle, # [total_len, h_kv, d_v] var_kv_indptr: T.handle, # [batch_size + 1] var_q_rope_position: T.handle, # [total_q_len] var_k_rope_pos_offset: T.handle, # [b] - var_output: T.handle, # [total_len, h_q, d] + var_output: T.handle, # [total_len, h_q, d_v] var_lse: T.handle, # [total_len, h_q] causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, - attn_score_scaling_factor: T.float32, + sm_scale: T.float32, ): batch_size = T.int32(is_size_var=True) qo_len = T.int32(is_size_var=True) @@ -2322,12 +2300,12 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches q_rope_position_elem_offset = T.int32(is_size_var=True) k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) - q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) + q = T.match_buffer(var_q, (qo_len, h_q, d_qk), dtype) q_indptr = T.match_buffer( var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset ) - k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) - v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) + k = T.match_buffer(var_k, (kv_len, h_kv, d_qk), dtype) + v = T.match_buffer(var_v, (kv_len, h_kv, d_v), dtype) kv_indptr = T.match_buffer( var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset ) @@ -2337,7 +2315,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches k_rope_pos_offset = T.match_buffer( var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset ) - output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) + output = T.match_buffer(var_output, (qo_len, h_q, d_v), dtype) lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable for b in T.serial(batch_size): @@ -2347,34 +2325,14 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches m_new = T.alloc_buffer([h_q], "float32") d_prev = T.alloc_buffer([h_q], "float32") d_new = T.alloc_buffer([h_q], "float32") - p_sum = T.alloc_buffer([d], "float32") + p_sum = T.alloc_buffer([d_v], "float32") max_score = T.alloc_buffer([h_q], "float32") attention_scores = T.alloc_buffer([kv_len, h_q], "float32") exp_scores = T.alloc_buffer([kv_len, h_q], "float32") - attention_score = T.alloc_buffer( - [ - 1, - ], - "float32", - ) - query_val = T.alloc_buffer( - [ - 1, - ], - "float32", - ) - key_val = T.alloc_buffer( - [ - 1, - ], - "float32", - ) - result = T.alloc_buffer( - [ - 1, - ], - "float32", - ) + attention_score = T.alloc_buffer([1], "float32") + query_val = T.alloc_buffer([1], "float32") + key_val = T.alloc_buffer([1], "float32") + result = T.alloc_buffer([1], "float32") for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]): for i in T.serial(h_q): @@ -2394,43 +2352,23 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches qo_len=q_indptr[b + 1] - q_indptr[b], ): result[0] = 0.0 - for d_idx in T.serial(d): + for d_idx in T.serial(d_qk): 
query_val[0] = T.if_then_else( rotary_mode == 1, - _rope( - q, - q_rope_position[q_indptr[b] + q_idx], - d, - rope_theta, - rope_scale, - (q_indptr[b] + q_idx, h, d_idx), - dtype, - rope_scaling, - ), + _rope(q, q_rope_position[q_indptr[b] + q_idx], d_qk, rope_theta, rope_scale, (q_indptr[b] + q_idx, h, d_idx), dtype, rope_scaling), q[q_indptr[b] + q_idx, h, d_idx], ) key_val[0] = T.if_then_else( rotary_mode == 1, - _rope( - k, - k_rope_pos_offset[b] + k_idx, - d, - rope_theta, - rope_scale, - (kv_indptr[b] + k_idx, h_kv_idx, d_idx), - dtype, - rope_scaling, - ), + _rope(k, k_rope_pos_offset[b] + k_idx, d_qk, rope_theta, rope_scale, (kv_indptr[b] + k_idx, h_kv_idx, d_idx), dtype, rope_scaling), k[kv_indptr[b] + k_idx, h_kv_idx, d_idx], ) result[0] += query_val[0] * key_val[0] - attention_score[0] = ( - result[0] * sm_scale * attn_score_scaling_factor - ) + attention_score[0] = result[0] * math.log2(math.exp(1)) * sm_scale else: - attention_score[0] = -5e4 * sm_scale * attn_score_scaling_factor + attention_score[0] = -5e4 * math.log2(math.exp(1)) * sm_scale attention_scores[k_idx, h] = attention_score[0] max_score[h] = T.max(max_score[h], attention_score[0]) m_new[h] = T.max(m_prev[h], max_score[h]) @@ -2449,24 +2387,21 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches for h in T.serial(h_q): h_kv_idx = h // group_size - for i in T.serial(d): + for i in T.serial(d_v): p_sum[i] = 0.0 for v_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]): weight = exp_scores[v_idx, h] / d_new[h] - for i in T.serial(d): + for i in T.serial(d_v): p_sum[i] += v[kv_indptr[b] + v_idx, h_kv_idx, i] * weight - for i in T.serial(d): + for i in T.serial(d_v): output[q_indptr[b] + q_idx, h, i] = p_sum[i] lse[q_indptr[b] + q_idx, h] = m_prev[h] + T.log2(d_prev[h]) - + # fmt: on + # pylint: enable=line-too-long return batch_prefill_ragged_kv -def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target): - return _attention_prefill_ragged_generic(h_kv, h_q, d, d, dtype, rope_scaling, target) - - -def _attention_prefill_ragged_generic( +def _attention_prefill_ragged( h_kv, h_q, d_qk, d_v, dtype, rope_scaling: Dict[str, Any], target: Target ): # pylint: disable=line-too-long @@ -2474,7 +2409,6 @@ def _attention_prefill_ragged_generic( NUM_BLKS, LOAD_VEC, group_size, - sm_scale, bdx, num_warps, tile_x, @@ -2485,20 +2419,20 @@ def _attention_prefill_ragged_generic( # fmt: off @T.prim_func def batch_prefill_ragged_kv( # pylint: disable=too-many-branches - var_q: T.handle, # [total_len, h_q, d] + var_q: T.handle, # [total_len, h_q, d_qk] var_q_indptr: T.handle, # [batch_size + 1] - var_k: T.handle, # [total_len, h_kv, d] - var_v: T.handle, # [total_len, h_kv, d] + var_k: T.handle, # [total_len, h_kv, d_qk] + var_v: T.handle, # [total_len, h_kv, d_v] var_kv_indptr: T.handle, # [batch_size + 1] var_q_rope_position: T.handle, # [total_q_len] var_k_rope_pos_offset: T.handle, # [b] - var_output: T.handle, # [total_len, h_q, d] + var_output: T.handle, # [total_len, h_q, d_v] var_lse: T.handle, # [total_len, h_q] causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, - attn_score_scaling_factor: T.float32 + sm_scale: T.float32 ): batch_size = T.int32(is_size_var=True) qo_len = T.int32(is_size_var=True) @@ -2580,7 +2514,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches m_smem[row] = -5e4 d_smem[row] = 1.0 - for li, lj in T.grid(tile_x, tile_y): + for li, lj in T.grid(tile_x, d_v): with T.block("O_init"): i, j = T.axis.remap("SS", 
[li, lj]) O_local[i, j] = 0.0 @@ -2620,7 +2554,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches else: K_smem[i, j] = 0.0 T.tvm_storage_sync("shared") - for lz, ly in T.grid(tile_z, tile_y): + for lz, ly in T.grid(tile_z, d_v): with T.block("V_load"): i, j = T.axis.remap("SS", [lz, ly]) T.reads() @@ -2639,7 +2573,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): S_local[i, j] = 0.0 - S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale + S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * sm_scale * math.log2(math.exp(1)) T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): with T.block("S_store"): @@ -2694,7 +2628,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches # Update O with T.block(): - for li, lj, lk in T.grid(tile_x, tile_y, tile_z): + for li, lj, lk in T.grid(tile_x, d_v, tile_z): with T.block("O_gemm"): i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): @@ -2702,7 +2636,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32") # Store O from smem to gmem - for li, lj in T.grid(tile_x, tile_y): + for li, lj in T.grid(tile_x, d_v): with T.block("O_store"): i, j = T.axis.remap("SS", [li, lj]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size @@ -2724,9 +2658,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches # fmt: on # pylint: enable=line-too-long,too-many-branches sch = tir.Schedule(batch_prefill_ragged_kv) - sch = _schedule_prefill_kernel( - sch, LOAD_VEC, bdx, num_warps, tile_x, tile_y, tile_z, True, False - ) + sch = _schedule_prefill_kernel(sch, LOAD_VEC, bdx, num_warps, tile_x, d_v, tile_z, True, False) return sch.mod["main"].with_attr("tir.is_scheduled", 1) @@ -2737,13 +2669,13 @@ def _attention_prefill_mla( dtype, sliding_window: bool, target: Target, + page_size: int = 16, ): d_qk = d_latent + d_rope ( NUM_BLKS, LOAD_VEC, group_size, - _, bdx, num_warps, tile_x, @@ -2759,7 +2691,6 @@ def _attention_prefill_mla( # fmt: off @T.prim_func def batch_prefill_paged_kv_mla( - _0: T.int32, # pylint: disable=unused-argument var_q: T.handle, # [total_len, h_q, d_qk] var_q_indptr: T.handle, # [batch_size + 1] var_pages: T.handle, # [max_num_pages, page_size, d_qk] @@ -2769,7 +2700,7 @@ def batch_prefill_paged_kv_mla( var_output: T.handle, # [total_len, h_q, d_latent] var_lse: T.handle, # [total_len, h_q] causal: T.int32, - attn_score_scaling_factor: T.float32, + sm_scale: T.float32, ): T.func_attr({"global_symbol": global_symbol}) batch_size = T.int32(is_size_var=True) @@ -2784,7 +2715,7 @@ def batch_prefill_paged_kv_mla( q = T.match_buffer(var_q, (total_len, h_q, d_qk), dtype) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) - pages = T.match_buffer(var_pages, (max_num_pages, 16, d_qk), dtype, elem_offset=pages_elem_offset) + pages = T.match_buffer(var_pages, (max_num_pages, page_size, d_qk), dtype, elem_offset=pages_elem_offset) page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) output = T.match_buffer(var_output, (total_len, h_q, d_latent), dtype) @@ -2853,7 +2784,7 @@ def batch_prefill_paged_kv_mla( cur_page_indptr_end: 
T.int32 = page_indptr[b_idx + 1] kv_chunk_len[0] = T.if_then_else( cur_page_indptr_begin != cur_page_indptr_end, - _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window), + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, page_size, b_idx, length_info, sliding_window), 0 ) T.tvm_storage_sync("shared") @@ -2895,8 +2826,8 @@ def batch_prefill_paged_kv_mla( cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore - page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, page_size)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, page_size) # type: ignore KV_smem[i, j] = pages[page_no, page_offset, j] else: KV_smem[i, j] = 0.0 @@ -2909,7 +2840,7 @@ def batch_prefill_paged_kv_mla( i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): S_local[i, j] = 0.0 - S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(KV_smem[j, k], "float32") * attn_score_scaling_factor * math.log2(math.exp(1)) + S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(KV_smem[j, k], "float32") * sm_scale * math.log2(math.exp(1)) T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): with T.block("S_store"): @@ -3000,498 +2931,25 @@ def batch_prefill_paged_kv_mla( return sch.mod["main"].with_attr("tir.is_scheduled", 1) -def _attention_prefill_ragged_mla_absorbed(h_q, d_latent, d_rope, dtype, target: Target): - d_qk = d_latent + d_rope - ( - NUM_BLKS, - LOAD_VEC, - group_size, - _, - bdx, - num_warps, - tile_x, - tile_y, - tile_z, - ) = _get_prefill_kernel_config(1, h_q, d_qk, dtype, target) +def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target): + tx = get_max_num_threads_per_block(target) - # pylint: disable=line-too-long,too-many-branches - # fmt: off @T.prim_func - def batch_prefill_ragged_kv_mla_absorbed( # pylint: disable=too-many-branches - var_q: T.handle, # [total_len, h_q, d_qk] - var_q_indptr: T.handle, # [batch_size + 1] - var_compressed_kv: T.handle, # [total_len, d_latent] - var_k_pe: T.handle, # [total_len, d_rope] - var_kv_indptr: T.handle, # [batch_size + 1] - var_output: T.handle, # [total_len, h_q, d_latent] - var_lse: T.handle, # [total_len, h_q] - causal: T.int32, - attn_score_scaling_factor: T.float32 + def copy_single_page( + var_pages: T.handle, + src_page_id: T.int64, + tgt_page_id: T.int64, + copy_length: T.int64, ): - batch_size_plus_1 = T.int32(is_size_var=True) - qo_len = T.int32(is_size_var=True) - kv_len = T.int32(is_size_var=True) - q_indptr_elem_offset = T.int32(is_size_var=True) - kv_indptr_elem_offset = T.int32(is_size_var=True) - - q = T.match_buffer(var_q, (qo_len, h_q, d_qk), dtype) - q_indptr = T.match_buffer(var_q_indptr, (batch_size_plus_1,), "int32", elem_offset=q_indptr_elem_offset) - compressed_kv = T.match_buffer(var_compressed_kv, (kv_len, d_latent), dtype) - k_pe = T.match_buffer(var_k_pe, (kv_len, d_rope), dtype) - kv_indptr = T.match_buffer(var_kv_indptr, (batch_size_plus_1,), "int32", elem_offset=kv_indptr_elem_offset) - output = T.match_buffer(var_output, (qo_len, h_q, d_latent), dtype) - lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable - - # kernel 
code - for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): - for lty in T.thread_binding(num_warps, thread="threadIdx.y"): - for ltx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("attn"): - bx, ty, tx = T.axis.remap("SSS", [lbx, lty, ltx]) - T.reads() - T.writes() - tile_id = _var("int32") - batch_idx = _var("int32") - batch_tiles = _var("int32") - batch_rows = _var("int32") - iterator = _var("int32") - kv_chunk_len = _var("int32") - - Q_smem = T.alloc_buffer((tile_x, d_qk), dtype, scope="shared") - KV_smem = T.alloc_buffer((tile_z, d_qk), dtype, scope="shared") - S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") - - S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") - O_local = T.alloc_buffer((tile_x, d_latent), "float32", scope="local") - - m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - - m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") - m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") - d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") - - ## get tile_no, batch_idx, batch_tiles, batch_rows - tile_id[0] = bx - batch_idx[0] = 0 - batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size - batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) - while T.tvm_thread_invariant(batch_idx[0] < batch_size_plus_1 - 1): - # advance to next tile - while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size_plus_1 - 1: - tile_id[0] -= batch_tiles[0] - batch_idx[0] += 1 - if batch_idx[0] < batch_size_plus_1 - 1: - b_idx: T.int32 = batch_idx[0] - batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size - batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) - - if T.tvm_thread_invariant(batch_idx[0] < batch_size_plus_1 - 1): - b_idx: T.int32 = batch_idx[0] - q_indptr_val: T.int32 = q_indptr[b_idx] - LH_start: T.int32 = tile_id[0] * tile_x - - kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] - T.tvm_storage_sync("shared") - - # init states - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - if row < tile_x: - m_smem[row] = -5e4 - d_smem[row] = 1.0 - - for li, lj in T.grid(tile_x, d_latent): - with T.block("O_init"): - i, j = T.axis.remap("SS", [li, lj]) - O_local[i, j] = 0.0 - T.tvm_storage_sync("shared") - - # Load Q from gmem to smem - for li, lj in T.grid(tile_x, tile_y): - with T.block("Q_load"): - i, j = T.axis.remap("SS", [li, lj]) - T.reads() - T.writes() - cur_L = q_indptr_val + (LH_start + i) // group_size - cur_H_qo = (LH_start + i) % group_size - if cur_L < q_indptr[b_idx + 1]: - Q_smem[i, j] = q[cur_L, cur_H_qo, j] - else: - Q_smem[i, j] = 0.0 - T.tvm_storage_sync("shared") - - for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): - L_kv_start: T.int32 = iterator * tile_z - L_kv_base: T.int32 = kv_indptr[b_idx] - for lz, ly in T.grid(tile_z, d_latent): - with T.block("V_load"): - i, j = T.axis.remap("SS", [lz, ly]) - cur_L = L_kv_start + i - if cur_L < kv_chunk_len[0]: - KV_smem[i, j] = compressed_kv[L_kv_base + cur_L, j] - else: - KV_smem[i, j] = 0.0 - T.tvm_storage_sync("shared") - for lz, ly in T.grid(tile_z, d_rope): - with T.block("K_load"): - i, j = T.axis.remap("SS", [lz, ly]) - T.reads() - T.writes() - cur_L = L_kv_start + i - if cur_L < kv_chunk_len[0]: - 
KV_smem[i, d_latent + j] = k_pe[L_kv_base + cur_L, j] - else: - KV_smem[i, d_latent + j] = 0.0 - T.tvm_storage_sync("shared") - - # Compute S - with T.block(): - for li, lj, lk in T.grid(tile_x, tile_z, tile_y): - with T.block("S_gemm"): - i, j, k = T.axis.remap("SSR", [li, lj, lk]) - with T.init(): - S_local[i, j] = 0.0 - S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(KV_smem[j, k], "float32") * attn_score_scaling_factor * math.log2(math.exp(1)) - T.tvm_storage_sync("shared") - for li, lj in T.grid(tile_x, tile_z): - with T.block("S_store"): - i, j = T.axis.remap("SS", [li, lj]) - S_smem[i, j] = S_local[i, j] - T.tvm_storage_sync("shared") - - # Update S, m, d - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - if row < tile_x: - with T.block("update1"): - m_prev[i] = m_smem[row] - m_new[i] = m_smem[row] - # mask out of kv_chunk_len S - row_: T.int32 = (LH_start + row) // group_size - for j in T.serial(tile_z): - if _causal_mask(causal, - row=row_, - col=L_kv_start + j, - kv_len=kv_chunk_len[0], - qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): - m_new[i] = T.max(m_new[i], S_smem[row, j]) - d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) - - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - with T.block("update"): - for j in T.serial(tile_z): - # this is to avoid sync inside condition branch - if row < tile_x: - row_: T.int32 = (LH_start + row) // group_size - if _causal_mask(causal, - row=row_, - col=L_kv_start + j, - kv_len=kv_chunk_len[0], - qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): - S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) - else: - S_smem[row, j] = T.exp2(-5e4 - m_new[i]) - - for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): - row: T.int32 = i * bdx * num_warps + ty * bdx + tx - if row < tile_x: - with T.block("update"): - for j in T.serial(tile_z): - d_new[i] += S_smem[row, j] - m_smem[row] = m_new[i] - d_smem[row] = d_new[i] - m_prev_smem[row] = m_prev[i] - T.tvm_storage_sync("shared") - - # Update O - with T.block(): - for li, lj, lk in T.grid(tile_x, d_latent, tile_z): - with T.block("O_gemm"): - i, j, k = T.axis.remap("SSR", [li, lj, lk]) - with T.init(): - O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) - O_local[i, j] += S_smem[i, k] * T.cast(KV_smem[k, j], "float32") - - # Store O from smem to gmem - for li, lj in T.grid(tile_x, d_latent): - with T.block("O_store"): - i, j = T.axis.remap("SS", [li, lj]) - cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size - cur_H_qo: T.int32 = (LH_start + i) % group_size - if cur_L < q_indptr[b_idx + 1]: - output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i] - - # Store LSE to gmem - for li in T.grid(tile_x): - with T.block("lse_store"): - i = T.axis.remap("S", [li]) - cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size - cur_H_qo: T.int32 = (LH_start + i) % group_size - if cur_L < q_indptr[b_idx + 1]: - lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) - - # move to next tile - tile_id[0] += NUM_BLKS - # fmt: on - # pylint: enable=line-too-long,too-many-branches - sch = tir.Schedule(batch_prefill_ragged_kv_mla_absorbed) - sch = _schedule_prefill_kernel( - sch, LOAD_VEC, bdx, num_warps, tile_x, tile_y, tile_z, False, False - ) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -def _attention_decode_mla(h_q, d_latent, d_rope, qkv_dtype, sliding_window: bool, target: Target): - d_qk = d_latent + d_rope - qkv_dtype_bytes = 2 - - THREAD_LIMIT = 
512 - TILE_SIZE_PER_BDX = 2 - if target.kind.name == "opencl" and ( - ("android" in str(target.host)) or ("adreno" in str(target.attrs)) - ): - # Keeping lower thread limit for this kernel on adreno target - # to avoid register spill - THREAD_LIMIT = 256 - TILE_SIZE_PER_BDX = 1 - max_num_threads_per_block = get_max_num_threads_per_block(target) - thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) - - GROUP_SIZE = h_q - VEC_SIZE = min(max(8 // qkv_dtype_bytes, d_qk // 32), 4) - bdx = d_qk // VEC_SIZE - bdy = GROUP_SIZE - while bdx * bdy > thread_limit and bdy > 1: - bdy //= 2 - gdy = GROUP_SIZE // bdy - threads_per_CTA = max(thread_limit, bdx * bdy) - bdz = threads_per_CTA // (bdx * bdy) - tile_size_per_bdx = TILE_SIZE_PER_BDX if GROUP_SIZE == 1 else 1 - check_thread_limits(target, bdx=bdx, bdy=bdy, bdz=bdz, gdz=1) - - global_symbol = "batch_decode_paged_kv_mla" - if sliding_window: - global_symbol += "_sliding_window" - - # pylint: disable=line-too-long,too-many-branches - # fmt: off - @T.prim_func - def batch_decode_paged_kv_mla( - _0: T.int32, # pylint: disable=unused-argument - Q_handle: T.handle, - pages_handle: T.handle, - page_table_indptr_handle: T.handle, - page_table_values_handle: T.handle, - var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] - output_handle: T.handle, - lse_handle: T.handle, - attn_score_scaling_factor: T.float32, - ): - T.func_attr({"tir.is_scheduled": 1, "global_symbol": global_symbol}) - B = T.int32(is_size_var=True) - nnz_pages = T.int32(is_size_var=True) - max_num_pages = T.int32(is_size_var=True) - pages_elem_offset = T.int64(is_size_var=True) - page_indptr_elem_offset = T.int32(is_size_var=True) - page_values_elem_offset = T.int32(is_size_var=True) - length_info_elem_offset = T.int32(is_size_var=True) - - Q = T.match_buffer(Q_handle, (B, h_q, d_qk), qkv_dtype) - pages = T.match_buffer( - pages_handle, (max_num_pages, 16, d_qk), qkv_dtype, elem_offset=pages_elem_offset - ) - page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset) - page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) - output = T.match_buffer(output_handle, (B, h_q, d_latent), qkv_dtype) - lse = T.match_buffer(lse_handle, (B, h_q), "float32") # pylint: disable=unused-variable - # The length information of the sequences. - # - It is in shape `(3, batch_size)` when sliding window is enabled. - # For a sequence "i", location - # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), - # - "(1, i)" is the starting offset of the sliding window in the seq, - # - "(2, i)" is the attn sink length of the sequence. - # - It is in shape `(batch_size,)` when sliding window is disabled, - # denoting the "last_page_len". 
- length_info = _declare_length_info(var_length_info, B, sliding_window, length_info_elem_offset) - - for bx in T.thread_binding(B, thread="blockIdx.x"): - for by in T.thread_binding(gdy, thread="blockIdx.y"): - for ty in T.thread_binding(bdy, thread="threadIdx.y"): - for tx in T.thread_binding(bdx, thread="threadIdx.x"): - for tz in T.thread_binding(bdz, thread="threadIdx.z"): - with T.block("attn"): - Q_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") - kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") - KV_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, d_qk), qkv_dtype, scope="shared") - O_allreduce = T.alloc_buffer((bdz, bdy, d_qk), "float32", scope="shared") - md_allreduce = T.alloc_buffer((bdz, bdy, 2), "float32", scope="shared") - S_reduce_local = T.alloc_buffer((1,), "float32", scope="local") - t0 = T.alloc_buffer((1,), "float32", scope="local") - - S_local = T.alloc_buffer((bdy * tile_size_per_bdx), "float32", scope="local") - QK_local = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") - V_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") - m_prev = T.alloc_buffer((1,), "float32", scope="local") - d_prev = T.alloc_buffer((1,), "float32", scope="local") - other_m = T.alloc_buffer((1,), "float32", scope="local") - other_d = T.alloc_buffer((1,), "float32", scope="local") - exp_mprev = T.alloc_buffer((1,), "float32", scope="local") - exp_otherm = T.alloc_buffer((1,), "float32", scope="local") - other_o = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") - st_m = T.alloc_buffer((1,), "float32", scope="local") - st_d = T.alloc_buffer((1,), "float32", scope="local") - O_local = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") - - batch_idx: T.int32 = bx - cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] - cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] - kv_chunk_len[0] = T.if_then_else( - cur_page_indptr_begin != cur_page_indptr_end, - _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, batch_idx, length_info, sliding_window), - 0 - ) - - # init states - st_m[0] = -5e4 - st_d[0] = 1.0 - for vec in T.vectorized(VEC_SIZE): - O_local[vec] = 0.0 - - # load q - for vec in T.vectorized(VEC_SIZE): - Q_local[vec] = Q[bx, by * bdy + ty, tx * VEC_SIZE + vec] - - for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_size_per_bdx * bdy * bdz)): - tile_start_s: T.int32(is_size_var=True) = (tz * bdy + ty) * tile_size_per_bdx # type: ignore - tile_start_g: T.int32(is_size_var=True) = ((iterator * bdz + tz) * bdy + ty) * tile_size_per_bdx # type: ignore - # load KV from global memory to shared memory - for j in T.serial(tile_size_per_bdx): - with T.block("KV_load"): - T.reads() - T.writes() - row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore - if row_g < kv_chunk_len[0]: - seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_g, batch_idx, length_info, sliding_window) # type: ignore - page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore - for vec in T.vectorized(VEC_SIZE): - KV_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, page_offset, tx * VEC_SIZE + vec] - else: - for vec in T.vectorized(VEC_SIZE): - KV_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 - T.tvm_storage_sync("shared") - # compute QK - m_prev[0] = st_m[0] - for j in T.serial(bdy * tile_size_per_bdx): - # compute S = Q * K * sm_scale - for vec in 
T.vectorized(VEC_SIZE): - QK_local[vec] = T.cast(Q_local[vec], "float32") * T.cast(KV_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec], "float32") * attn_score_scaling_factor * math.log2(math.exp(1)) - S_reduce_local[0] = 0 - for vec in T.unroll(VEC_SIZE): - S_reduce_local[0] += QK_local[vec] - - with T.block("block_cross_thread"): - T.reads(S_reduce_local[0]) - T.writes(t0[0]) - T.attr( - T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), - "reduce_scope", - T.reinterpret("handle", T.uint64(0)), - ) - T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], True, t0[0], tx, dtype="handle") - - S_local[j] = -5e4 - if (iterator * bdz + tz) * bdy * tile_size_per_bdx + j < kv_chunk_len[0]: - S_local[j] = t0[0] - # update st_m - st_m[0] = T.max(st_m[0], S_local[j]) - - # update st_d, st_O - o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0]) - st_d[0] *= o_scale - for j in T.serial(bdy * tile_size_per_bdx): - S_local[j] = T.exp2(S_local[j] - st_m[0]) - st_d[0] += S_local[j] - for j in T.vectorized(VEC_SIZE): - O_local[j] *= o_scale - - # load V from shared memory to local memory - # compute O - for j in T.serial(bdy * tile_size_per_bdx): - if tx * VEC_SIZE < d_latent: - for vec in T.vectorized(VEC_SIZE): - V_local[vec] = KV_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec] - else: - for vec in T.vectorized(VEC_SIZE): - V_local[vec] = 0.0 - for vec in T.vectorized(VEC_SIZE): - O_local[vec] += T.cast(V_local[vec], "float32") * S_local[j] - - if bdz > 1: - # allreduce over bdz - for vec in T.vectorized(VEC_SIZE): - O_allreduce[tz, ty, tx * VEC_SIZE + vec] = O_local[vec] - md_allreduce[tz, ty, 0] = st_m[0] - md_allreduce[tz, ty, 1] = st_d[0] - T.tvm_storage_sync("shared") - - st_m[0] = -5e4 - st_d[0] = 1.0 - for vec in T.vectorized(VEC_SIZE): - O_local[vec] = 0.0 - - for j in T.serial(bdz): - m_prev[0] = st_m[0] - d_prev[0] = st_d[0] - other_m[0] = md_allreduce[j, ty, 0] - other_d[0] = md_allreduce[j, ty, 1] - for vec in T.vectorized(VEC_SIZE): - other_o[vec] = O_allreduce[j, ty, tx * VEC_SIZE + vec] - st_m[0] = T.max(st_m[0], other_m[0]) - st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) - exp_mprev[0] = T.exp2(m_prev[0] - st_m[0]) - exp_otherm[0] = T.exp2(other_m[0] - st_m[0]) - for vec in T.vectorized(VEC_SIZE): - O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0] - - # normalize O - for vec in T.vectorized(VEC_SIZE): - O_local[vec] /= st_d[0] - - # store O to global memory - if tx * VEC_SIZE < d_latent: - for vec in T.vectorized(VEC_SIZE): - output[batch_idx, by * bdy + ty, tx * VEC_SIZE + vec] = O_local[vec] - - # store lse to global memory - lse[batch_idx, by * bdy + ty] = st_m[0] + T.log2(st_d[0]) - # fmt: on - # pylint: enable=line-too-long,too-many-branches - return batch_decode_paged_kv_mla - - -def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target): - tx = get_max_num_threads_per_block(target) - - @T.prim_func - def copy_single_page( - var_pages: T.handle, - src_page_id: T.int64, - tgt_page_id: T.int64, - copy_length: T.int64, - ): - T.func_attr({"tir.is_scheduled": 1}) - num_pages = T.int32() - pages_elem_offset = T.int64() - pages = T.match_buffer( - var_pages, - (num_pages, 2, num_heads, page_size, head_dim), - dtype, - elem_offset=pages_elem_offset, - ) + T.func_attr({"tir.is_scheduled": 1}) + num_pages = T.int32() + pages_elem_offset = T.int64() + pages = T.match_buffer( + var_pages, + (num_pages, 2, num_heads, page_size, head_dim), + dtype, + 
elem_offset=pages_elem_offset, + ) for b in T.thread_binding( (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" @@ -3587,7 +3045,7 @@ def copy_single_page_cpu( return copy_single_page_cpu -def _compact_kv_copy(num_heads, head_dim, dtype, target: Target): +def _compact_kv_copy(num_heads, head_dim, dtype, target: Target, page_size: int = 16): tx = get_max_num_threads_per_block(target) @T.prim_func @@ -3604,7 +3062,10 @@ def compact_kv_copy( copy_src_dst_pos_elem_offset = T.int32() pages_elem_offset = T.int64() pages = T.match_buffer( - var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset + var_pages, + (num_pages, 2, num_heads, page_size, head_dim), + dtype, + elem_offset=pages_elem_offset, ) copy_length_indptr = T.match_buffer( var_copy_length_indptr, @@ -3631,17 +3092,17 @@ def compact_kv_copy( for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]): src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i] dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i] - pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[ - src_pos // 16, 0, h, src_pos % 16, d + pages[dst_pos // page_size, 0, h, dst_pos % page_size, d] = pages[ + src_pos // page_size, 0, h, src_pos % page_size, d ] - pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[ - src_pos // 16, 1, h, src_pos % 16, d + pages[dst_pos // page_size, 1, h, dst_pos % page_size, d] = pages[ + src_pos // page_size, 1, h, src_pos % page_size, d ] return compact_kv_copy -def _compact_kv_copy_cpu(num_heads, head_dim, dtype): +def _compact_kv_copy_cpu(num_heads, head_dim, dtype, page_size: int = 16): tx = 8 @T.prim_func @@ -3656,7 +3117,7 @@ def compact_kv_copy_cpu( total_copy_length = T.int32() copy_length_indptr_elem_offset = T.int32() copy_src_dst_pos_elem_offset = T.int32() - pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype) + pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype) copy_length_indptr = T.match_buffer( var_copy_length_indptr, (batch_size + 1,), @@ -3680,11 +3141,11 @@ def compact_kv_copy_cpu( for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]): src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i] dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i] - pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[ - src_pos // 16, 0, h, src_pos % 16, d + pages[dst_pos // page_size, 0, h, dst_pos % page_size, d] = pages[ + src_pos // page_size, 0, h, src_pos % page_size, d ] - pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[ - src_pos // 16, 1, h, src_pos % 16, d + pages[dst_pos // page_size, 1, h, dst_pos % page_size, d] = pages[ + src_pos // page_size, 1, h, src_pos % page_size, d ] return compact_kv_copy_cpu diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index 9aa27ca83d70..36a6e2dab84a 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -112,7 +112,6 @@ def tree_attn_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]): The generated IR module. 
""" group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) # fmt: off @T.prim_func @@ -130,8 +129,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches,line-too-long rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, - attn_score_scaling_factor: T.float32, - batch_size: T.int32, + sm_scale: T.float32, ): qo_len = T.int32(is_size_var=True) kv_len = T.int32(is_size_var=True) @@ -141,27 +139,28 @@ def batch_tree_attn( # pylint: disable=too-many-branches,line-too-long mn_indptr_elem_offset = T.int32(is_size_var=True) mask_elem_offset = T.int32(is_size_var=True) tree_size = T.int32(is_size_var=True) + batch_size_plus_1 = T.int32(is_size_var=True) q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) q_indptr = T.match_buffer( - var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset + var_q_indptr, (batch_size_plus_1,), "int32", elem_offset=q_indptr_elem_offset ) k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) kv_indptr = T.match_buffer( - var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset + var_kv_indptr, (batch_size_plus_1,), "int32", elem_offset=kv_indptr_elem_offset ) q_rope_position = T.match_buffer( var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset ) mn_indptr = T.match_buffer( - var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset + var_mn_indptr, (batch_size_plus_1,), "int32", elem_offset=mn_indptr_elem_offset ) mask = T.match_buffer(var_mask, (tree_size, 2), "int32", elem_offset=mask_elem_offset) output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable - for b in T.serial(batch_size): + for b in T.serial(batch_size_plus_1 - 1): with T.block("attn"): softmax_sum = T.alloc_buffer([h_q], "float32") @@ -252,10 +251,10 @@ def batch_tree_attn( # pylint: disable=too-many-branches,line-too-long result[0] += query_val[0] * key_val[0] attention_score[0] = ( - result[0] * sm_scale * attn_score_scaling_factor + result[0] * math.log2(math.exp(1)) * sm_scale ) else: - attention_score[0] = -5e4 * sm_scale * attn_score_scaling_factor + attention_score[0] = -5e4 * math.log2(math.exp(1)) * sm_scale attention_scores[k_idx, h] = attention_score[0] max_score[h] = T.max(max_score[h], attention_score[0]) m_new[h] = T.max(m_prev[h], max_score[h]) @@ -316,7 +315,6 @@ def tree_attn( NUM_BLKS = 16 LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) bdx = 32 num_warps = 4 @@ -356,8 +354,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, - attn_score_scaling_factor: T.float32, - batch_size: T.int32, + sm_scale: T.float32, ): qo_len = T.int32(is_size_var=True) kv_len = T.int32(is_size_var=True) @@ -367,14 +364,15 @@ def batch_tree_attn( # pylint: disable=too-many-branches mn_indptr_elem_offset = T.int32(is_size_var=True) mask_elem_offset = T.int32(is_size_var=True) tree_size = T.int32(is_size_var=True) + batch_size_plus_1 = T.int32(is_size_var=True) q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) - q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) + q_indptr = T.match_buffer(var_q_indptr, (batch_size_plus_1,), "int32", elem_offset=q_indptr_elem_offset) k = T.match_buffer(var_k, (kv_len, 
h_kv, d), dtype) v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) - kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset) + kv_indptr = T.match_buffer(var_kv_indptr, (batch_size_plus_1,), "int32", elem_offset=kv_indptr_elem_offset) q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset) - mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset) + mn_indptr = T.match_buffer(var_mn_indptr, (batch_size_plus_1,), "int32", elem_offset=mn_indptr_elem_offset) mask = T.match_buffer(var_mask, (tree_size, 2), "int32", elem_offset=mask_elem_offset) output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable @@ -416,17 +414,17 @@ def batch_tree_attn( # pylint: disable=too-many-branches batch_idx[0] = 0 batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) - while T.tvm_thread_invariant(batch_idx[0] < batch_size): + while T.tvm_thread_invariant(batch_idx[0] < batch_size_plus_1 - 1): # advance to next tile - while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size_plus_1 - 1: tile_id[0] -= batch_tiles[0] batch_idx[0] += 1 - if batch_idx[0] < batch_size: + if batch_idx[0] < batch_size_plus_1 - 1: b_idx: T.int32 = batch_idx[0] batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) - if T.tvm_thread_invariant(batch_idx[0] < batch_size): + if T.tvm_thread_invariant(batch_idx[0] < batch_size_plus_1 - 1): b_idx: T.int32 = batch_idx[0] LH_start: T.int32 = tile_id[0] * tile_x q_indptr_val: T.int32 = q_indptr[b_idx] @@ -493,7 +491,7 @@ def batch_tree_attn( # pylint: disable=too-many-branches i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): S_local[i, j] = 0.0 - S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale + S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * sm_scale * math.log2(math.exp(1)) T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): with T.block("S_store"): @@ -676,21 +674,15 @@ def tree_attn_with_paged_kv_cache_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[st The generated IR module. 
""" # pylint: disable=import-outside-toplevel - from .kv_cache import ( - _declare_length_info, - _get_kv_chunk_len, - _get_seq_offset, - ) + from .kv_cache import _declare_length_info, _get_kv_chunk_len, _get_seq_offset global_symbol = "tree_attn_paged_kv_cpu" sliding_window = False group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) # pylint: disable=line-too-long,too-many-branches # fmt: off @T.prim_func(check_well_formed=False) def tree_attn_paged_kv_cpu( - _0: T.int32, # pylint: disable=unused-argument var_q: T.handle, # [total_len, h_q, d] var_q_indptr: T.handle, # [batch_size + 1] var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] @@ -704,7 +696,7 @@ def tree_attn_paged_kv_cpu( rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, - attn_score_scaling_factor: T.float32, + sm_scale: T.float32, tree_order_indptr_handle: T.handle, # [batch_size + 1] tree_order_handle: T.handle, # [total_len, 2] ): @@ -817,7 +809,7 @@ def tree_attn_paged_kv_cpu( S_val[0] = 0.0 for d_idx in T.serial(d): S_val[0] += Q_local[d_idx] * K_local[d_idx] - S_val[0] *= attn_score_scaling_factor * sm_scale + S_val[0] *= sm_scale * math.log2(math.exp(1)) # update m_val, d_val , O_local if _check_tree_order( @@ -889,7 +881,6 @@ def tree_attn_with_paged_kv_cache( NUM_BLKS = 16 LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes group_size = h_q // h_kv - sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) bdx = 32 num_warps = 4 @@ -920,7 +911,6 @@ def tree_attn_with_paged_kv_cache( # fmt: off @T.prim_func def tree_attn_paged_kv( - _0: T.int32, # pylint: disable=unused-argument var_q: T.handle, # [total_len, h_q, d] var_q_indptr: T.handle, # [batch_size + 1] var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] @@ -934,7 +924,7 @@ def tree_attn_paged_kv( rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, - attn_score_scaling_factor: T.float32, + sm_scale: T.float32, tree_order_indptr_handle: T.handle, # [batch_size + 1] tree_order_handle: T.handle, # [total_len, 2] ): @@ -1164,8 +1154,8 @@ def tree_attn_paged_kv( S_local[i, j] += ( T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") - * attn_score_scaling_factor * sm_scale + * math.log2(math.exp(1)) ) T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): diff --git a/src/runtime/relax_vm/attn_backend.cc b/src/runtime/relax_vm/attn_backend.cc new file mode 100644 index 000000000000..0b94d541c2dd --- /dev/null +++ b/src/runtime/relax_vm/attn_backend.cc @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
\file src/runtime/relax_vm/attn_backend.cc */ + +#include "attn_backend.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +std::unique_ptr ConvertPagedPrefillFunc(Array args, + AttnKind attn_kind) { + if (args.empty()) { + return nullptr; + } + String backend_name = Downcast(args[0]); + if (backend_name == "tir") { + CHECK_EQ(args.size(), 2); + PackedFunc attn_func = Downcast(args[1]); + return std::make_unique(std::move(attn_func), attn_kind); + } + if (backend_name == "flashinfer") { + CHECK_EQ(args.size(), 3); + PackedFunc attn_func = Downcast(args[1]); + PackedFunc plan_func = Downcast(args[2]); + return std::make_unique(std::move(attn_func), std::move(plan_func), + attn_kind); + } + LOG(FATAL) << "Cannot reach here"; + throw; +} + +std::unique_ptr ConvertRaggedPrefillFunc(Array args, + AttnKind attn_kind) { + if (args.empty()) { + return nullptr; + } + String backend_name = Downcast(args[0]); + if (backend_name == "tir") { + CHECK_EQ(args.size(), 2); + PackedFunc attn_func = Downcast(args[1]); + return std::make_unique(std::move(attn_func), attn_kind); + } + if (backend_name == "flashinfer") { + CHECK_EQ(args.size(), 3); + PackedFunc attn_func = Downcast(args[1]); + PackedFunc plan_func = Downcast(args[2]); + return std::make_unique(std::move(attn_func), std::move(plan_func), + attn_kind); + } + LOG(FATAL) << "Cannot reach here"; + throw; +} + +std::unique_ptr ConvertPagedDecodeFunc(Array args, AttnKind attn_kind) { + if (args.empty()) { + return nullptr; + } + String backend_name = Downcast(args[0]); + if (backend_name == "tir") { + CHECK_EQ(args.size(), 2); + PackedFunc attn_func = Downcast(args[1]); + return std::make_unique(std::move(attn_func), attn_kind); + } + if (backend_name == "flashinfer") { + CHECK_EQ(args.size(), 3); + PackedFunc attn_func = Downcast(args[1]); + PackedFunc plan_func = Downcast(args[2]); + return std::make_unique(std::move(attn_func), std::move(plan_func), + attn_kind); + } + LOG(FATAL) << "Cannot reach here"; + throw; +} + +std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array args, + AttnKind attn_kind) { + if (args.empty()) { + return nullptr; + } + String backend_name = Downcast(args[0]); + if (backend_name == "tir") { + CHECK_EQ(args.size(), 2); + PackedFunc attn_func = Downcast(args[1]); + return std::make_unique(std::move(attn_func), attn_kind); + } + LOG(FATAL) << "Cannot reach here"; + throw; +} + +std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Array args, + AttnKind attn_kind) { + if (args.empty()) { + return nullptr; + } + String backend_name = Downcast(args[0]); + if (backend_name == "tir") { + CHECK_EQ(args.size(), 2); + PackedFunc attn_func = Downcast(args[1]); + return std::make_unique(std::move(attn_func), attn_kind); + } + LOG(FATAL) << "Cannot reach here"; + throw; +} + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/attn_backend.h b/src/runtime/relax_vm/attn_backend.h new file mode 100644 index 000000000000..4064d2e3de94 --- /dev/null +++ b/src/runtime/relax_vm/attn_backend.h @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/relax_vm/attn_backend.h + * \brief The attention backend classes used by KV cache. + */ + +#ifndef TVM_RUNTIME_RELAX_VM_ATTN_BACKEND_H_ +#define TVM_RUNTIME_RELAX_VM_ATTN_BACKEND_H_ + +#include +#include + +#include +#include +#include +#include + +#include "attn_utils.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! \brief The attention backend kinds. */ +enum class AttnBackendKind : int { + kTIR = 0, + kFlashInfer = 1, +}; + +/*! \brief The base class of attention backends. */ +class AttnBackendFunc { + public: + explicit AttnBackendFunc(PackedFunc attn_func, AttnKind attn_kind, AttnBackendKind backend_kind) + : attn_func_(std::move(attn_func)), attn_kind(attn_kind), backend_kind(backend_kind) {} + + virtual ~AttnBackendFunc() = default; + + protected: + PackedFunc attn_func_; + + public: + AttnKind attn_kind; + AttnBackendKind backend_kind; +}; + +/*! \brief The paged prefill attention function base class. */ +class PagedPrefillFunc : public AttnBackendFunc { + public: + explicit PagedPrefillFunc(PackedFunc attn_func, AttnKind attn_kind, AttnBackendKind backend_kind) + : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} + + virtual void MHA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, + NDArray page_indices, NDArray length_info, NDArray q_rope_position, + NDArray k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, + double rotary_theta, double sm_scale, NDArray attn_output, NDArray attn_lse, + TVMStreamHandle compute_stream) { + LOG(FATAL) << "MHA computation is not supported by the current backend"; + } + + virtual void MLA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, + NDArray page_indices, NDArray length_info, bool causal, double sm_scale, + NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) { + LOG(FATAL) << "MLA computation is not supported by the current backend"; + } + + virtual void BeginForward(int depth, NDArray float_workspace_buffer, NDArray int_workspace_buffer, + NDArray page_locked_int_workspace_buffer, HostMemoryVector* qo_indptr, + HostMemoryVector* page_indptr, HostMemoryVector* last_page_len, + int64_t batch_size, int64_t total_qo_len, int64_t page_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, + int64_t v_head_dim, bool causal, TVMStreamHandle copy_stream) { + // Do nothing. Subclasses can override to customize behavior. + } +}; + +/*! \brief The TIR-based paged prefill attention function class. 
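For orientation, the `Convert*Func` helpers earlier in `attn_backend.cc` expect each attention function to be handed over as a small array whose first element names the backend: TIR backends pass just the kernel, while FlashInfer backends also pass their plan function. A hypothetical Python-side sketch of that packing (the real wiring lives in `kv_cache.py` and `flashinfer.py` and may differ in detail):

    # Hypothetical stand-ins; at runtime these are TVM PackedFuncs.
    def tir_prefill_kernel(*args): ...
    def flashinfer_prefill_run(*args): ...
    def flashinfer_prefill_plan(*args): ...

    # "tir" entries carry [name, attn_func]; "flashinfer" entries carry
    # [name, attn_func, plan_func], matching the CHECK_EQ(args.size(), ...) checks.
    tir_prefill_backend = ["tir", tir_prefill_kernel]
    flashinfer_prefill_backend = ["flashinfer", flashinfer_prefill_run, flashinfer_prefill_plan]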
*/ +class TIRPagedPrefillFunc : public PagedPrefillFunc { + public: + explicit TIRPagedPrefillFunc(PackedFunc attn_func, AttnKind attn_kind) + : PagedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} + + void MHA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, + NDArray page_indices, NDArray length_info, NDArray q_rope_position, + NDArray k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, + double rotary_theta, double sm_scale, NDArray attn_output, NDArray attn_lse, + TVMStreamHandle compute_stream) final { + attn_func_(q, qo_indptr, pages, page_indptr, page_indices, length_info, k_rope_pos_offset, + q_rope_position, attn_output, attn_lse, static_cast(causal), + /*rotary_mode=*/static_cast(rope_mode == RoPEMode::kInline), rotary_scale, + rotary_theta, sm_scale); + } + + void MLA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, + NDArray page_indices, NDArray length_info, bool causal, double sm_scale, + NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) final { + attn_func_(q, qo_indptr, pages, page_indptr, page_indices, length_info, attn_output, attn_lse, + static_cast(causal), sm_scale); + } +}; + +/*! \brief The FlashInfer-based paged prefill attention function class. */ +class FlashInferPagedPrefillFunc : public PagedPrefillFunc { + public: + explicit FlashInferPagedPrefillFunc(PackedFunc attn_func, PackedFunc plan_func, + AttnKind attn_kind) + : PagedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kFlashInfer), + plan_func_(std::move(plan_func)) {} + + void MHA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, + NDArray page_indices, NDArray length_info, NDArray q_rope_position, + NDArray k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, + double rotary_theta, double sm_scale, NDArray attn_output, NDArray attn_lse, + TVMStreamHandle compute_stream) final { + auto [float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, + plan_info_vec] = cached_buffers_[depth]; + double rope_rcp_scale = 1 / rotary_scale; + double rope_rcp_theta = 1 / rotary_theta; + attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages, qo_indptr, + page_indptr, page_indices, length_info, q_rope_position, k_rope_pos_offset, + attn_output, attn_lse, /*mask_mode_code=*/static_cast(causal), + /*pos_encoding_mode_code=*/static_cast(rope_mode == RoPEMode::kInline), + /*layout(HND)=*/1, /*window_left=*/-1, sm_scale, /*rope_rcp_scale=*/rope_rcp_scale, + /*rope_rcp_theta=*/rope_rcp_theta, compute_stream); + } + + void MLA(int depth, NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, + NDArray page_indices, NDArray length_info, bool causal, double sm_scale, + NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) final { + auto [float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, + plan_info_vec] = cached_buffers_[depth]; + attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages, page_indices, + attn_output, attn_lse, /*mask_mode_code=*/static_cast(causal), + /*num_heads=*/q->shape[1], /*page_size=*/pages->shape[1], sm_scale, compute_stream); + } + + void BeginForward(int depth, NDArray float_workspace_buffer, NDArray int_workspace_buffer, + NDArray page_locked_int_workspace_buffer, HostMemoryVector* qo_indptr, + HostMemoryVector* page_indptr, HostMemoryVector* last_page_len, + int64_t batch_size, int64_t 
total_qo_len, int64_t page_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, + int64_t v_head_dim, bool causal, TVMStreamHandle copy_stream) final { + std::vector kv_len; + kv_len.reserve(batch_size); + for (int i = 0; i < static_cast(batch_size); ++i) { + kv_len.push_back((*page_indptr)[i + 1] != (*page_indptr)[i] + ? ((*page_indptr)[i + 1] - (*page_indptr)[i] - 1) * page_size + + (*last_page_len)[i] + : 0); + } + IntTuple plan_info_vec; + if (attn_kind == AttnKind::kMHA) { + // Todo(tvm-team): enable cuda graph + plan_info_vec = plan_func_( + float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, + qo_indptr->as_ndarray(), page_indptr->as_ndarray(), IntTuple(std::move(kv_len)), + total_qo_len, batch_size, num_qo_heads, num_kv_heads, page_size, + /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal, copy_stream); + } else if (attn_kind == AttnKind::kMLA) { + plan_info_vec = + plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, + qo_indptr->as_ndarray(), page_indptr->as_ndarray(), + IntTuple(std::move(kv_len)), num_qo_heads, v_head_dim, causal, copy_stream); + } + + if (cached_buffers_.size() <= static_cast(depth)) { + cached_buffers_.resize(depth + 1); + } + cached_buffers_[depth] = + std::make_tuple(float_workspace_buffer, int_workspace_buffer, + page_locked_int_workspace_buffer, std::move(plan_info_vec)); + } + + private: + PackedFunc plan_func_; + std::vector> cached_buffers_; +}; + +/*! \brief The ragged prefill attention function base class. */ +class RaggedPrefillFunc : public AttnBackendFunc { + public: + explicit RaggedPrefillFunc(PackedFunc attn_func, AttnKind attn_kind, AttnBackendKind backend_kind) + : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} + + virtual void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, + NDArray q_rope_position, NDArray k_rope_pos_offset, bool causal, + RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, + NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) { + LOG(FATAL) << "MHA computation is not supported by the current backend"; + } + + virtual void BeginForward(NDArray float_workspace_buffer, NDArray int_workspace_buffer, + NDArray page_locked_int_workspace_buffer, HostMemoryVector* qo_indptr, + HostMemoryVector* kv_indptr, int64_t batch_size, int64_t total_qo_len, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, + int64_t v_head_dim, bool causal, TVMStreamHandle copy_stream) { + // Do nothing. Subclasses can override to customize behavior. + } +}; + +/*! \brief The TIR-based ragged prefill attention function class. */ +class TIRRaggedPrefillFunc : public RaggedPrefillFunc { + public: + explicit TIRRaggedPrefillFunc(PackedFunc attn_func, AttnKind attn_kind) + : RaggedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} + + void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, + NDArray q_rope_position, NDArray k_rope_pos_offset, bool causal, RoPEMode rope_mode, + double rotary_scale, double rotary_theta, double sm_scale, NDArray attn_output, + NDArray attn_lse, TVMStreamHandle compute_stream) final { + attn_func_(q, qo_indptr, k, v, kv_indptr, q_rope_position, k_rope_pos_offset, attn_output, + attn_lse, static_cast(causal), + /*rotary_mode=*/static_cast(rope_mode == RoPEMode::kInline), rotary_scale, + rotary_theta, sm_scale); + } +}; + +/*! 
\brief The FlashInfer-based ragged prefill attention function class. */ +class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc { + public: + explicit FlashInferRaggedPrefillFunc(PackedFunc attn_func, PackedFunc plan_func, + AttnKind attn_kind) + : RaggedPrefillFunc(std::move(attn_func), attn_kind, AttnBackendKind::kFlashInfer), + plan_func_(std::move(plan_func)) {} + + void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, + NDArray q_rope_position, NDArray k_rope_pos_offset, bool causal, RoPEMode rope_mode, + double rotary_scale, double rotary_theta, double sm_scale, NDArray attn_output, + NDArray attn_lse, TVMStreamHandle compute_stream) final { + double rope_rcp_scale = 1 / rotary_scale; + double rope_rcp_theta = 1 / rotary_theta; + attn_func_(float_workspace_buffer_, int_workspace_buffer_, plan_info_vec_, q, k, v, qo_indptr, + kv_indptr, q_rope_position, k_rope_pos_offset, attn_output, attn_lse, + /*mask_mode_code=*/static_cast(causal), + /*pos_encoding_mode_code=*/static_cast(rope_mode == RoPEMode::kInline), + /*layout(NHD)=*/0, /*window_left=*/-1, sm_scale, + /*rope_rcp_scale=*/rope_rcp_scale, + /*rope_rcp_theta=*/rope_rcp_theta, compute_stream); + } + + void BeginForward(NDArray float_workspace_buffer, NDArray int_workspace_buffer, + NDArray page_locked_int_workspace_buffer, HostMemoryVector* qo_indptr, + HostMemoryVector* kv_indptr, int64_t batch_size, int64_t total_qo_len, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, + int64_t v_head_dim, bool causal, TVMStreamHandle copy_stream) final { + std::vector kv_len; + kv_len.reserve(batch_size); + for (int i = 0; i < static_cast(batch_size); ++i) { + kv_len.push_back((*kv_indptr)[i + 1] - (*kv_indptr)[i]); + } + // Todo(tvm-team): enable cuda graph + float_workspace_buffer_ = float_workspace_buffer; + int_workspace_buffer_ = int_workspace_buffer; + page_locked_int_workspace_buffer_ = page_locked_int_workspace_buffer; + plan_info_vec_ = + plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, + qo_indptr->as_ndarray(), kv_indptr->as_ndarray(), IntTuple(std::move(kv_len)), + total_qo_len, batch_size, num_qo_heads, num_kv_heads, /*page_size=*/1, + /*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal, copy_stream); + } + + private: + PackedFunc plan_func_; + NDArray float_workspace_buffer_; + NDArray int_workspace_buffer_; + NDArray page_locked_int_workspace_buffer_; + IntTuple plan_info_vec_; +}; + +/*! \brief The paged decode attention function base class. 
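A small sketch of the per-sequence KV-length bookkeeping that the FlashInfer prefill `BeginForward` implementations above perform before invoking their plan functions: the ragged variant reads lengths directly from `kv_indptr`, while the paged variant counts full pages plus the occupancy of the last page (all values below are invented):

    page_size = 16

    # Ragged prefill: lengths come straight from the exclusive-prefix-sum indptr.
    kv_indptr = [0, 5, 5, 12]
    ragged_kv_len = [kv_indptr[i + 1] - kv_indptr[i] for i in range(len(kv_indptr) - 1)]
    assert ragged_kv_len == [5, 0, 7]

    # Paged prefill: all pages except the last are full; the last page holds
    # last_page_len slots. Sequences with no pages contribute zero.
    page_indptr = [0, 3, 3, 4]
    last_page_len = [10, 0, 4]
    paged_kv_len = [
        (page_indptr[i + 1] - page_indptr[i] - 1) * page_size + last_page_len[i]
        if page_indptr[i + 1] != page_indptr[i]
        else 0
        for i in range(len(page_indptr) - 1)
    ]
    assert paged_kv_len == [42, 0, 4]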
*/ +class PagedDecodeFunc : public AttnBackendFunc { + public: + explicit PagedDecodeFunc(PackedFunc attn_func, AttnKind attn_kind, AttnBackendKind backend_kind) + : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} + + virtual void MHA(int depth, NDArray q, NDArray pages, NDArray page_indptr, NDArray page_indices, + NDArray length_info, NDArray k_rope_pos_offset, NDArray q_rope_position, + RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, + NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) { + LOG(FATAL) << "MHA computation is not supported by the current backend"; + } + + virtual void MLA(int depth, NDArray q, NDArray pages, NDArray page_indptr, NDArray page_indices, + NDArray length_info, double sm_scale, NDArray attn_output, NDArray attn_lse, + TVMStreamHandle compute_stream) { + LOG(FATAL) << "MLA computation is not supported by the current backend"; + } + + virtual void BeginForward(int depth, NDArray float_workspace_buffer, NDArray int_workspace_buffer, + NDArray page_locked_int_workspace_buffer, HostMemoryVector* page_indptr, + int64_t batch_size, int64_t page_size, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, + RoPEMode rope_mode, DataType q_dtype, DataType kv_dtype, + TVMStreamHandle copy_stream) { + // Do nothing. Subclasses can override to customize behavior. + } +}; + +/*! \brief The TIR-based paged decode attention function class. */ +class TIRPagedDecodeFunc : public PagedDecodeFunc { + public: + explicit TIRPagedDecodeFunc(PackedFunc attn_func, AttnKind attn_kind) + : PagedDecodeFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} + + void MHA(int depth, NDArray q, NDArray pages, NDArray page_indptr, NDArray page_indices, + NDArray length_info, NDArray k_rope_pos_offset, NDArray q_rope_position, + RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, + NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) final { + attn_func_(q, pages, page_indptr, page_indices, length_info, k_rope_pos_offset, q_rope_position, + attn_output, attn_lse, + /*rotary_mode=*/static_cast(rope_mode == RoPEMode::kInline), rotary_scale, + rotary_theta, sm_scale); + } + + void MLA(int depth, NDArray q, NDArray pages, NDArray page_indptr, NDArray page_indices, + NDArray length_info, double sm_scale, NDArray attn_output, NDArray attn_lse, + TVMStreamHandle compute_stream) final { + attn_func_(q, pages, page_indptr, page_indices, length_info, attn_output, attn_lse, sm_scale); + } +}; + +/*! \brief The FlashInfer-based paged decode attention function class. 
*/ +class FlashInferPagedDecodeFunc : public PagedDecodeFunc { + public: + explicit FlashInferPagedDecodeFunc(PackedFunc attn_func, PackedFunc plan_func, AttnKind attn_kind) + : PagedDecodeFunc(std::move(attn_func), attn_kind, AttnBackendKind::kFlashInfer), + plan_func_(std::move(plan_func)) {} + + void MHA(int depth, NDArray q, NDArray pages, NDArray page_indptr, NDArray page_indices, + NDArray length_info, NDArray k_rope_pos_offset, NDArray q_rope_position, + RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, + NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) final { + auto [float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, + plan_info_vec] = cached_buffers_[depth]; + double rope_rcp_scale = 1 / rotary_scale; + double rope_rcp_theta = 1 / rotary_theta; + attn_func_(float_workspace_buffer, int_workspace_buffer, plan_info_vec, q, pages, page_indptr, + page_indices, length_info, q_rope_position, k_rope_pos_offset, attn_output, attn_lse, + /*pos_encoding_mode_code=*/static_cast(rope_mode == RoPEMode::kInline), + /*layout(HND)=*/1, /*window_left=*/-1, sm_scale, /*rope_rcp_scale=*/rope_rcp_scale, + /*rope_rcp_theta=*/rope_rcp_theta, compute_stream); + } + + void BeginForward(int depth, NDArray float_workspace_buffer, NDArray int_workspace_buffer, + NDArray page_locked_int_workspace_buffer, HostMemoryVector* page_indptr, + int64_t batch_size, int64_t page_size, int64_t num_qo_heads, + int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, + RoPEMode rope_mode, DataType q_dtype, DataType kv_dtype, + TVMStreamHandle copy_stream) final { + // Todo(tvm-team): enable cuda graph + IntTuple plan_info_vec = plan_func_( + float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer, + page_indptr->as_ndarray(), batch_size, num_qo_heads, num_kv_heads, page_size, + /*enable_cuda_graph=*/false, static_cast(rope_mode == RoPEMode::kInline), + /*window_left=*/-1, qk_head_dim, v_head_dim, q_dtype, kv_dtype, copy_stream); + + if (cached_buffers_.size() <= static_cast(depth)) { + cached_buffers_.resize(depth + 1); + } + cached_buffers_[depth] = + std::make_tuple(float_workspace_buffer, int_workspace_buffer, + page_locked_int_workspace_buffer, std::move(plan_info_vec)); + } + + private: + PackedFunc plan_func_; + std::vector> cached_buffers_; +}; + +/*! \brief The paged prefill with tree mask attention function base class. 
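Both FlashInfer wrappers above follow the same plan/run split: `BeginForward` calls the plan function once and caches the workspaces together with the returned plan info, keyed by attention depth, and the later `MHA`/`MLA` calls simply look that tuple up. A schematic sketch of the caching pattern (names are illustrative, not the runtime API):

    class PlanCache:
        """Keeps one (workspaces, plan_info) tuple per attention depth."""

        def __init__(self):
            self._cached = []  # indexed by depth

        def begin_forward(self, depth, workspaces, plan_fn, *plan_args):
            plan_info = plan_fn(*plan_args)  # the expensive planning step
            if len(self._cached) <= depth:
                self._cached.extend([None] * (depth + 1 - len(self._cached)))
            self._cached[depth] = (workspaces, plan_info)

        def run(self, depth, run_fn, *run_args):
            workspaces, plan_info = self._cached[depth]  # cheap lookup at attention time
            return run_fn(workspaces, plan_info, *run_args)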
*/ +class PagedPrefillTreeMaskFunc : public AttnBackendFunc { + public: + explicit PagedPrefillTreeMaskFunc(PackedFunc attn_func, AttnKind attn_kind, + AttnBackendKind backend_kind) + : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} + + virtual void MHA(NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, + NDArray page_indices, NDArray length_info, NDArray k_rope_pos_offset, + NDArray q_rope_position, NDArray tree_attn_mn_indptr, NDArray tree_attn_mask, + RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, + NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) { + LOG(FATAL) << "MHA computation is not supported by the current backend"; + } + + virtual void MLA(NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, + NDArray page_indices, NDArray length_info, NDArray tree_attn_mn_indptr, + NDArray tree_attn_mask, double sm_scale, NDArray attn_output, NDArray attn_lse, + TVMStreamHandle compute_stream) { + LOG(FATAL) << "MLA computation is not supported by the current backend"; + } + + virtual void BeginForward(NDArray temp_float_attn_workspace, NDArray temp_int_attn_workspace, + HostMemoryVector* page_indptr, HostMemoryVector* last_page_len, + HostMemoryVector* qo_indptr, int64_t batch_size, int64_t page_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, + int64_t v_head_dim, RoPEMode rope_mode, TVMStreamHandle copy_stream) { + // Do nothing. Subclasses can override to customize behavior. + } +}; + +/*! \brief The TIR-based paged prefill with tree mask attention function class. */ +class TIRPagedPrefillTreeMaskFunc : public PagedPrefillTreeMaskFunc { + public: + explicit TIRPagedPrefillTreeMaskFunc(PackedFunc attn_func, AttnKind attn_kind) + : PagedPrefillTreeMaskFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} + + void MHA(NDArray q, NDArray qo_indptr, NDArray pages, NDArray page_indptr, NDArray page_indices, + NDArray length_info, NDArray k_rope_pos_offset, NDArray q_rope_position, + NDArray tree_attn_mn_indptr, NDArray tree_attn_mask, RoPEMode rope_mode, + double rotary_scale, double rotary_theta, double sm_scale, NDArray attn_output, + NDArray attn_lse, TVMStreamHandle compute_stream) final { + attn_func_(q, qo_indptr, pages, page_indptr, page_indices, length_info, k_rope_pos_offset, + q_rope_position, attn_output, attn_lse, + /*rotary_mode=*/static_cast(rope_mode == RoPEMode::kInline), rotary_scale, + rotary_theta, sm_scale, tree_attn_mn_indptr, tree_attn_mask); + } +}; + +/*! \brief The ragged prefill with tree mask function base class. 
*/ +class RaggedPrefillTreeMaskFunc : public AttnBackendFunc { + public: + explicit RaggedPrefillTreeMaskFunc(PackedFunc attn_func, AttnKind attn_kind, + AttnBackendKind backend_kind) + : AttnBackendFunc(std::move(attn_func), attn_kind, backend_kind) {} + + virtual void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, + NDArray q_rope_position, NDArray tree_attn_mn_indptr, NDArray tree_attn_mask, + RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, + NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) { + LOG(FATAL) << "MHA computation is not supported by the current backend"; + } + + virtual void MLA(NDArray q, NDArray compressed_kv, NDArray k_pe, NDArray qo_indptr, + NDArray kv_indptr, NDArray tree_attn_mn_indptr, NDArray tree_attn_mask, + double sm_scale, NDArray attn_output, NDArray attn_lse, + TVMStreamHandle compute_stream) { + LOG(FATAL) << "MLA computation is not supported by the current backend"; + } + + virtual void BeginForward(NDArray temp_float_attn_workspace, NDArray temp_int_attn_workspace, + HostMemoryVector* page_indptr, HostMemoryVector* last_page_len, + HostMemoryVector* qo_indptr, int64_t batch_size, int64_t page_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, + int64_t v_head_dim, RoPEMode rope_mode, TVMStreamHandle copy_stream) { + // Do nothing. Subclasses can override to customize behavior. + } +}; + +/*! \brief The TIR-based ragged prefill with tree mask attention function class. */ +class TIRRaggedPrefillTreeMaskFunc : public RaggedPrefillTreeMaskFunc { + public: + explicit TIRRaggedPrefillTreeMaskFunc(PackedFunc attn_func, AttnKind attn_kind) + : RaggedPrefillTreeMaskFunc(std::move(attn_func), attn_kind, AttnBackendKind::kTIR) {} + + void MHA(NDArray q, NDArray k, NDArray v, NDArray qo_indptr, NDArray kv_indptr, + NDArray q_rope_position, NDArray tree_attn_mn_indptr, NDArray tree_attn_mask, + RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, + NDArray attn_output, NDArray attn_lse, TVMStreamHandle compute_stream) final { + attn_func_(q, qo_indptr, k, v, kv_indptr, q_rope_position, tree_attn_mn_indptr, tree_attn_mask, + attn_output, attn_lse, + /*rotary_mode=*/static_cast(rope_mode == RoPEMode::kInline), rotary_scale, + rotary_theta, sm_scale); + } +}; + +/*! + * \brief Create a PagedPrefillFunc from the given arguments and the attention kind. + * \param args The arguments that contains the backend kind and the runtime attention PackedFuncs. + * \param attn_kind The attention kind of the function. + * \return The created PagedPrefillFunc pointer. + */ +std::unique_ptr ConvertPagedPrefillFunc(Array args, + AttnKind attn_kind); + +/*! + * \brief Create a PagedDecodeFunc from the given arguments and the attention kind. + * \param args The arguments that contains the backend kind and the runtime attention PackedFuncs. + * \param attn_kind The attention kind of the function. + * \return The created PagedDecodeFunc pointer. + */ +std::unique_ptr ConvertPagedDecodeFunc(Array args, AttnKind attn_kind); + +/*! + * \brief Create a RaggedPrefillFunc from the given arguments and the attention kind. + * \param args The arguments that contains the backend kind and the runtime attention PackedFuncs. + * \param attn_kind The attention kind of the function. + * \return The created RaggedPrefillFunc pointer. + */ +std::unique_ptr ConvertRaggedPrefillFunc(Array args, + AttnKind attn_kind); + +/*! 
+ * \brief Create a PagedPrefillTreeMaskFunc from the given arguments and the attention kind.
+ * \param args The arguments that contain the backend kind and the runtime attention PackedFuncs.
+ * \param attn_kind The attention kind of the function.
+ * \return The created PagedPrefillTreeMaskFunc pointer.
+ */
+std::unique_ptr ConvertPagedPrefillTreeMaskFunc(Array args,
+                                                AttnKind attn_kind);
+
+/*!
+ * \brief Create a RaggedPrefillTreeMaskFunc from the given arguments and the attention kind.
+ * \param args The arguments that contain the backend kind and the runtime attention PackedFuncs.
+ * \param attn_kind The attention kind of the function.
+ * \return The created RaggedPrefillTreeMaskFunc pointer.
+ */
+std::unique_ptr ConvertRaggedPrefillTreeMaskFunc(Array args,
+                                                 AttnKind attn_kind);
+
+} // namespace relax_vm
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_RELAX_VM_ATTN_BACKEND_H_
diff --git a/src/runtime/relax_vm/attn_utils.h b/src/runtime/relax_vm/attn_utils.h
new file mode 100644
index 000000000000..af46073adefe
--- /dev/null
+++ b/src/runtime/relax_vm/attn_utils.h
@@ -0,0 +1,1027 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/runtime/relax_vm/attn_utils.h
+ * \brief Data structure and utilities for KV cache.
+ */
+
+#ifndef TVM_RUNTIME_RELAX_VM_ATTN_UTILS_H_
+#define TVM_RUNTIME_RELAX_VM_ATTN_UTILS_H_
+
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+namespace tvm {
+namespace runtime {
+namespace relax_vm {
+
+/*!
+ * \brief The maximum allowed block depth (a.k.a. number of common
+ * prefixes) in paged KV cache.
+ */
+constexpr const int kPagedKVCacheMaxBlockDepth = 2;
+/*! \brief The maximum tree size of a single sequence in tree attention. */
+constexpr const int kTreeAttnMaxTreeSize = 256;
+/*! \brief The 8MB workspace size for integer attention auxiliary data. */
+constexpr const int kIntAttnWorkspaceByte = 8 * 1024 * 1024;
+/*! \brief The 768MB workspace size for floating-point attention auxiliary data. */
+constexpr const int kFloatAttnWorkspaceByte = 768 * 1024 * 1024;
+/*! \brief The id of the temporary logical page, which is useful for sliding window. */
+constexpr const int kPagedKVCacheTempPageId = -1;
+
+/*!
+ * \brief The supported attention kinds in PagedKVCache.
+ * "MHA" means multi-head attention, multi-query attention and grouped query attention in general.
+ * "MLA" means multi-head latent attention.
+ * "LinearAttn" means linear attention.
+ */
+enum class AttnKind : int {
+  kMHA = 0,
+  kMLA = 1,
+  kLinearAttn = 2,
+};
+
+/*! \brief Given the attention kind and other metadata, return the one-layer KV cache shape.
*/ +inline ShapeTuple GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int num_sequence, + int64_t num_kv_heads, int64_t page_size, int64_t qk_head_dim, + int64_t v_head_dim) { + if (attn_kind == AttnKind::kMHA) { + // Ignore v_head_dim since multi-head attention requires K/V to have the same head dim. + return {num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}; + } else if (attn_kind == AttnKind::kMLA) { + return {num_total_pages, page_size, qk_head_dim}; + } else if (attn_kind == AttnKind::kLinearAttn) { + return {num_sequence, num_kv_heads, qk_head_dim, v_head_dim}; + } + ICHECK(false); + return ShapeTuple(); +} + +/*! + * \brief The block structure in paged KV cache with common prefix support. + * Each block contains a list of pages for cached KV data. + * If a block has `n` pages, the first `n - 1` pages must be + * full, and only the last page can be partially filled. + * + * To support common prefix, each sequence in KV cache is represented + * as one or more blocks, where the common prefix is a standalone + * block among. + * + * Each block has a parent block when it uses a prefix. + */ +struct Block { + /*! + * \brief The ids of the pages in the block. + * Each page can only be used by a unique block (in other + * words, different blocks do not share pages). + */ + std::vector page_ids; + /*! \brief The total sequence length in the block. */ + int32_t seq_length = 0; + /*! + * \brief The start position in sequence of this block. + * This is the absolute position in the sequence for RoPE computation. + */ + int32_t start_pos = 0; + /*! + * \brief The current attention sink length of the block. + * It means the **first** sink size elements will be pinned + * in the KV cache even when sliding window is enabled. + */ + int32_t sink_length = 0; + /*! + * \brief The start offset of the sliding window in the block. + * It is always 0 when sliding window attn is not enabled. + */ + int32_t sliding_window_offset = 0; + + /*! \brief The global index of the block. */ + const int32_t index; + /*! + * \brief The global index of the parent block of this block, or -1 + * if the block does not have a parent. */ + int32_t parent_idx = -1; + /*! + * \brief The external reference counter of the block. + * When a block is externally referred by some block, + * we do not allow appending new KV values to this block. + */ + int external_ref_cnt = 0; + + explicit Block(int32_t index) : index(index) {} + + /*! \brief Reset the block data. */ + void Reset() { + page_ids.clear(); + seq_length = 0; + start_pos = 0; + sink_length = 0; + sliding_window_offset = 0; + parent_idx = -1; + external_ref_cnt = 0; + } +}; + +struct KVTransferMetadata { + int64_t start = std::numeric_limits::max(); + std::vector remote_position_map; + int32_t recver_pe_offset = -1; + std::vector local_position_map; +}; + +/*! + * \brief The sequence structure in paged KV cache with common prefix support. + * Each sequence contains one or more blocks to support common prefix. + */ +struct Sequence { + /*! + * \brief The global index of the last block of the sequence. + * We only store the last block, since all the blocks can be + * tracked with the `parent` field of Block. + */ + int32_t last_block_idx; + /*! + * \brief The total sequence length of the sequence. + * It is the sum of lengths of all its blocks. + */ + int32_t seq_length = 0; + /*! + * \brief The sliding window size of the sequence, or -1 if sliding window is not enabled. + * When a sequence is enabled for sliding window, it can no longer be forked. 
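To make the branches of `GetKVCacheShape` above concrete, the per-layer cache layouts for illustrative sizes (this is a hand-written mirror of the helper, not a call into it):

    num_total_pages, num_sequence, page_size = 512, 8, 16
    num_kv_heads, qk_head_dim, v_head_dim = 8, 128, 128

    kv_cache_shapes = {
        # MHA: K and V share one head dim and are stored as two slabs per page.
        "mha": (num_total_pages, 2, num_kv_heads, page_size, qk_head_dim),
        # MLA: one compressed latent vector per token, no per-head K/V split.
        "mla": (num_total_pages, page_size, qk_head_dim),
        # Linear attention: one recurrent state per sequence instead of paged KV.
        "linear_attn": (num_sequence, num_kv_heads, qk_head_dim, v_head_dim),
    }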
+ */ + int sliding_window_size = -1; + /*! + * \brief The attention sink size of the last block of the sequence. + * The **first** sink size elements of the last block will be pinned + * in the KV cache even when sliding window is enabled. + */ + int last_block_attn_sink_size = 0; + + /*! \brief Whether the current appended tokens form a chain (not a tree). */ + bool is_chain = true; + /*! \brief The token tree parent pointer array of the current appended tokens. */ + std::vector token_tree_parent_ptr; + /*! \brief The depth of each node in the token tree. */ + std::vector token_tree_node_depths; + /*! \brief The metadata of kv transfer*/ + KVTransferMetadata kv_transfer_metadata; + /*! + * \brief A boolean denoting whether the accepted token tree indices of + * this sequence are committed + */ + bool accepted_indices_committed = true; + + explicit Sequence(std::vector* global_block_pool, int32_t last_block_idx) { + ++global_block_pool->at(last_block_idx).external_ref_cnt; + this->last_block_idx = last_block_idx; + int32_t block_ptr = last_block_idx; + // Go through each block in the sequence, sum up the length. + while (true) { + const Block& block = global_block_pool->at(block_ptr); + this->seq_length += block.seq_length; + if (block.parent_idx == -1) { + break; + } + block_ptr = block.parent_idx; + } + } + + std::vector GetBlockTrace(const std::vector& global_block_pool) const { + std::vector trace; + // Get the trace from the last block of the sequence to the root block. + int32_t block_ptr = last_block_idx; + while (block_ptr != -1) { + trace.push_back(block_ptr); + block_ptr = global_block_pool[block_ptr].parent_idx; + } + // Reverse the trace so that it starts from the root block. + std::reverse(trace.begin(), trace.end()); + return trace; + } +}; + +/*! + * \brief For the given list of sequences, check the block trace of + * each sequence, and return the blocks ids used by the sequences + * on each depth. And if the depth is larger than the kPagedKVCacheMaxBlockDepth, + * the exceeding blocks will concatenate and output separately. + * More precisely, the inner returned vector contains the block ids + * used by the sequences on a certain depth (or "-1" if a sequence + * has fewer depth). The outer returned vector contains the inner + * vectors from the lowest depth to the highest depth. + */ +inline std::pair>, std::vector>> +GetBlockIdsOnDepth(const std::vector& sequences, + const std::vector& global_block_pool, int64_t batch_size) { + // - Get the trace of each sequence. + int64_t num_depths = 0; + std::vector> seq_block_traces; + std::vector> trailing_block_traces; + seq_block_traces.reserve(batch_size); + trailing_block_traces.reserve(batch_size); + for (int i = 0; i < batch_size; ++i) { + std::vector trace = sequences[i]->GetBlockTrace(global_block_pool); + if (static_cast(trace.size()) <= kPagedKVCacheMaxBlockDepth) { + seq_block_traces.push_back(std::vector(trace.begin(), trace.end())); + trailing_block_traces.push_back({}); + num_depths = std::max(num_depths, static_cast(trace.size())); + } else { + seq_block_traces.push_back( + std::vector(trace.begin(), trace.begin() + kPagedKVCacheMaxBlockDepth)); + trailing_block_traces.push_back( + std::vector(trace.begin() + kPagedKVCacheMaxBlockDepth, trace.end())); + num_depths = std::max(num_depths, static_cast(kPagedKVCacheMaxBlockDepth)); + } + } + + // "Transpose" the traces, yielding the block ids used on each depth. 
+ std::vector> block_ids_on_depths; + block_ids_on_depths.reserve(num_depths); + for (int d = 0; d < num_depths; ++d) { + std::vector block_ids; + block_ids.reserve(batch_size); + for (int i = 0; i < batch_size; ++i) { + block_ids.push_back(d < static_cast(seq_block_traces[i].size()) ? seq_block_traces[i][d] + : -1); + } + block_ids_on_depths.push_back(std::move(block_ids)); + } + return {block_ids_on_depths, trailing_block_traces}; +} + +/*! + * \brief This function considers an optimization which coalesces + * adjacent decode attention computations into a single prefill + * attention computation if the adjacent decodes attend to the same + * k/v values under certain conditions. + * If it decides to coalesce on a certain depth, we need to know + * the prefill length after coalescing. This function returns + * - a vector of block ids together with the prefill/decode lengths + * that attend to the blocks. + * - a boolean indicating whether to use decode kernel on for the + * input blocks. + */ +inline std::pair>, bool> GetChunkedBlockIds( + const std::vector& block_ids, bool enable_coalesce, const IntTuple& append_lengths, + const std::vector& global_block_pool, bool is_decode_request) { + std::vector> uncoalesced_block_ids; + std::vector> coalesced_block_ids; + + // Gather the number of pages before/after coalescing respectively. + int cur_block_id = block_ids[0]; + int chunk_append_length = append_lengths[0]; + int page_counter_coalesced = 0; + int page_counter_uncoalesced = + block_ids[0] != -1 ? global_block_pool[block_ids[0]].page_ids.size() : 0; + for (int i = 1; i < static_cast(block_ids.size()); ++i) { + if (block_ids[i] != -1) { + page_counter_uncoalesced += global_block_pool[block_ids[i]].page_ids.size(); + } + uncoalesced_block_ids.emplace_back(block_ids[i - 1], append_lengths[i - 1]); + if (block_ids[i] == cur_block_id) { + chunk_append_length += append_lengths[i]; + } else { + coalesced_block_ids.emplace_back(cur_block_id, chunk_append_length); + if (cur_block_id != -1) { + page_counter_coalesced += global_block_pool[cur_block_id].page_ids.size(); + } + cur_block_id = block_ids[i]; + chunk_append_length = append_lengths[i]; + } + } + uncoalesced_block_ids.emplace_back(block_ids.back(), append_lengths.back()); + coalesced_block_ids.emplace_back(cur_block_id, chunk_append_length); + if (cur_block_id != -1) { + page_counter_coalesced += global_block_pool[cur_block_id].page_ids.size(); + } + double coalesce_ratio = + page_counter_coalesced > 0 ? 1.0 * page_counter_uncoalesced / page_counter_coalesced : 0.0; + // Do not coalesce and use batch decode kernel when coalesce ratio is small. + bool use_decode_kernel = is_decode_request && coalesce_ratio < 32; + return {use_decode_kernel || !enable_coalesce ? uncoalesced_block_ids : coalesced_block_ids, + use_decode_kernel}; +} + +/*! + * \brief The rotary embedding mode adopted by the paged KV cache + * when computing attention. + * "None" means RoPE is never applied to q and k. + * "Normal" means RoPE is computed in a standalone kernel. + * "Inline" means RoPE is computed on-the-fly in attention kernels. + */ +enum class RoPEMode : int { + kNone = 0, + kNormal = 1, + kInline = 2, +}; + +/*! + * \brief The class of host memory int32 vector in "std::vector" interface. + * This vector allocates static memory on the specified host memory + * at the time of construction. 
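Before moving on to the auxiliary-data classes, a toy walk-through of the decision implemented in `GetChunkedBlockIds` above, where two decoding sequences share one prefix block (all numbers invented):

    # Two sequences decode one token each and sit on the same block (id 7) at this depth.
    block_ids = [7, 7]
    append_lengths = [1, 1]
    pages_in_block_7 = 4

    uncoalesced_pages = 2 * pages_in_block_7  # block 7 would be visited once per sequence
    coalesced_pages = pages_in_block_7        # adjacent equal block ids collapse into one chunk
    coalesce_ratio = uncoalesced_pages / coalesced_pages  # 2.0

    # The ratio is well below 32, so the batch-decode kernel is kept and the
    # two decodes are not merged into a single prefill over block 7.
    is_decode_request = True
    use_decode_kernel = is_decode_request and coalesce_ratio < 32
    assert use_decode_kernel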
+ */ +class HostMemoryVector { + public: + HostMemoryVector() = default; + HostMemoryVector(const HostMemoryVector&) = delete; + HostMemoryVector(HostMemoryVector&& other) = default; + HostMemoryVector& operator=(const HostMemoryVector&) = delete; + HostMemoryVector& operator=(HostMemoryVector&& other) = default; + + explicit HostMemoryVector(int64_t reserved_size, DLDataType dtype, Device device) + : reserved_size_(reserved_size) { + ICHECK(DataType(dtype) == DataType::Int(32)); + data_ = NDArray::Empty({reserved_size}, dtype, device); + } + + void push_back(int32_t value) { + ICHECK_LE(current_size_, reserved_size_); + if (current_size_ == reserved_size_) { + reserved_size_ *= 2; + NDArray new_data = NDArray::Empty({reserved_size_}, data_->dtype, data_->device); + std::memcpy(new_data->data, data_->data, current_size_ * DataType(data_->dtype).bytes()); + data_ = new_data; + } + static_cast(data_->data)[current_size_++] = value; + } + + const int32_t& operator[](int64_t idx) const { + ICHECK_GE(idx, 0) << "Index " << idx << " is negative."; + ICHECK_LT(idx, current_size_) << "Index " << idx << " out of bounds " << current_size_; + return static_cast(data_->data)[idx]; + } + + int32_t back() const { + ICHECK_GT(current_size_, 0) << "Vector is empty"; + return static_cast(data_->data)[current_size_ - 1]; + } + + size_t size() const { return static_cast(current_size_); } + + int32_t* data() const { return static_cast(data_->data); } + + void clear() { current_size_ = 0; } + + /*! \brief Return the vector as an NDArray. */ + NDArray as_ndarray() { return data_.CreateView({current_size_}, data_->dtype); } + + IntTuple as_int_tuple() const { + std::vector values; + values.reserve(current_size_); + for (int i = 0; i < current_size_; ++i) { + values.push_back(static_cast(data_->data)[i]); + } + return IntTuple(values); + } + + private: + int64_t reserved_size_ = 0; + int64_t current_size_ = 0; + NDArray data_{nullptr}; +}; + +/*! + * \brief The paged attention auxiliary data manager class. + * This class manages all the int32 auxiliary data on GPU device, such as + * page table, position arrays, etc.. + * + * The core functions of this class is `CopyXXXAsync` and `CommitAttnAuxDataCopy`. + * `CopyXXXAsync` takes the input data on CPU host, and copy the input data + * to GPU in an asynchronous way, and returns the NDArray view of the data + * on GPU device. + * + * Being asynchronous here means the `CopyXXXAsync` function may not perform + * data copy from CPU to GPU at the time of being called. Therefore, the + * returned NDArray view may have wrong result, until `CommitAttnAuxDataCopy` is + * explicitly invoked and the data copy stream is synchronized. + * + * We design this manager class in order to reduce the data copy overhead. + */ +class PagedKVCacheAuxDataManager { + public: + PagedKVCacheAuxDataManager(DLDataType dtype_aux, Device device, Device preferred_host_device, + TVMStreamHandle copy_stream) + : dtype_aux_(dtype_aux), + device_(device), + preferred_host_device_(preferred_host_device), + copy_stream_(copy_stream) { + ICHECK(DataType(dtype_aux) == DataType::Int(32)); + } + + virtual ~PagedKVCacheAuxDataManager() = default; + /*! \brief Reset the attention auxiliary data status of copy manager. */ + virtual void ResetAttnAuxDataCopy() = 0; + /*! \brief Copy the indptr array of append lengths after coalescing. (see GetChunkedBlockIds) */ + virtual NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; + /*! \brief Copy the indptr array of page table. 
*/ + virtual NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; + /*! \brief Copy the indices array of page table. */ + virtual NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) = 0; + /*! \brief Copy the array of KV slot number used in the last page of the seq. */ + virtual NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) = 0; + /*! + * \brief Copy the length information of the sequences. + * Each NDArray is in shape `(3, n)`. "n" is the number of sequences. + * For a sequence "i", location + * - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + * - "(1, i)" is the starting offset of the sliding window in the seq, + * - "(2, i)" is the attn sink length of the sequence. + * \note When sliding window is not enabled, only the + * "last_page_len" (a.k.a., the first "n" elements) will be effectively used. + */ + virtual NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, + HostMemoryVector* sliding_window_offset, + HostMemoryVector* sink_size, int depth) = 0; + /*! \brief Copy the k position offset of applying RoPE for each sequence. */ + virtual NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) = 0; + /*! + * \brief Copy the append length indptr array on device. + * \note Since the Q/K/V data may have raggedness in terms of lengths, + * we represent the append lengths in CSR format. + */ + virtual NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) = 0; + /*! \brief Copy the k position offset of applying RoPE for each sequence. */ + virtual NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) = 0; + /*! \brief Copy the q position mapping of applying RoPE for each sequence. */ + virtual NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) = 0; + /*! + * \brief Copy the corresponding position in global KV cache (pages) + * for each position along the length dimension of K/V data when + * appending new K/V data. + */ + virtual NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) = 0; + /*! \brief Copy the remote position map for KV transfer. */ + virtual NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) = 0; + /*! \brief Copy the receiver id for KV transfer. */ + virtual NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) = 0; + /*! \brief Copy the local position map for KV page-to-page transfer. */ + virtual NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) = 0; + /*! \brief Copy the remote position map for KV page-to-page transfer. */ + virtual NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) = 0; + /*! \brief Copy the receiver id for KV page-to-page transfer. */ + virtual NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) = 0; + /*! \brief Copy the tree attention mask. */ + virtual NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) = 0; + /*! \brief Copy the mn indptr of the tree attention mask. */ + virtual NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; + /*! \brief Commit all the attention auxiliary data copy operations since the last commit. */ + virtual void CommitAttnAuxDataCopy() = 0; + + /*! \brief Reset the compact KV auxiliary data status of copy manager. */ + virtual void ResetCompactKVAuxDataCopy() = 0; + /*! \brief Copy the length indptr array of KV data copy for each sequence. 
*/ + virtual NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) = 0; + /*! \brief Copy the src/dst position arrays for each sequence. */ + virtual NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, + HostMemoryVector* dst_data) = 0; + /*! \brief Commit all the compact KV auxiliary data copy operations since the last commit. */ + virtual void CommitCompactKVAuxDataCopy() = 0; + + protected: + /*! \brief The dtype of the auxiliary data. It is expected to be int32. */ + const DLDataType dtype_aux_; + /*! \brief The device this PagedKVCache runs on. */ + const Device device_; + /*! \brief The preferred host device. */ + const Device preferred_host_device_; + /*! \brief The device stream for copying auxiliary data structure to GPU. */ + const TVMStreamHandle copy_stream_; +}; + +/*! + * \brief The plain auxiliary data manager class. + * It simply issues one host-to-device copy operation for each `CopyXXXAsync`. + */ +class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { + public: + explicit PlainPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t num_total_pages, + int64_t prefill_chunk_size, DLDataType dtype_aux, + Device device, Device preferred_host_device, + TVMStreamHandle copy_stream) + : PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream) { + for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { + qo_indptr_on_depths_device_.push_back( + NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); + page_indptr_on_depths_device_.push_back( + NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); + page_indices_on_depths_device_.push_back( + NDArray::Empty({num_total_pages}, dtype_aux_, device)); + length_info_on_depths_device_.push_back( + NDArray::Empty({3, reserved_num_seqs}, dtype_aux_, device)); + k_rope_pos_offset_on_depths_device_.push_back( + NDArray::Empty({reserved_num_seqs}, dtype_aux_, device)); + tree_attn_mask_device_.push_back(NDArray::Empty( + {kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs}, dtype_aux_, device)); + tree_attn_mn_indptr_device_.push_back( + NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); + } + cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); + k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device); + q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); + append_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); + kv_transfer_remote_position_map_device = + NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); + kv_transfer_recver_id_device = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); + kv_transfer_page_to_page_local_position_map_device = + kv_transfer_page_to_page_remote_position_map_device = + NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); + kv_transfer_page_to_page_recver_id_device = + NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); + commit_copy_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); + commit_copy_src_dst_pos_in_page_table_device_ = + NDArray::Empty({2, std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size)}, + dtype_aux_, device); + } + + // The reset of the plain auxiliary data manager is no-op. 
+ void ResetAttnAuxDataCopy() final {} + NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { + NDArray view = qo_indptr_on_depths_device_[depth].CreateView( + {static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { + NDArray view = page_indptr_on_depths_device_[depth].CreateView( + {static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final { + NDArray view = page_indices_on_depths_device_[depth].CreateView( + {static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final { + NDArray view = length_info_on_depths_device_[depth].CreateView( + {static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final { + NDArray view = k_rope_pos_offset_on_depths_device_[depth].CreateView( + {static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final { + NDArray view = cur_append_length_indptr_device_.CreateView({static_cast(data->size())}, + dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final { + NDArray view = k_ragged_rope_pos_offset_device_.CreateView({static_cast(data->size())}, + dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { + NDArray view = + q_rope_position_map_device_.CreateView({static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { + NDArray view = + append_position_map_device_.CreateView({static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final { + NDArray view = kv_transfer_remote_position_map_device.CreateView( + {static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) final { + NDArray view = + kv_transfer_recver_id_device.CreateView({static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) final { + NDArray view = kv_transfer_page_to_page_local_position_map_device.CreateView( + {static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) final { + NDArray view = kv_transfer_page_to_page_remote_position_map_device.CreateView( + {static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final { + NDArray view = kv_transfer_page_to_page_recver_id_device.CreateView( + {static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + + NDArray 
CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final { + NDArray view = + tree_attn_mask_device_[depth].CreateView({static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { + NDArray view = tree_attn_mn_indptr_device_[depth].CreateView( + {static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + + NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, + HostMemoryVector* sliding_window_offset, + HostMemoryVector* sink_size, int depth) final { + int n_elem = last_page_len->size(); + ICHECK_GT(n_elem, 0); + NDArray view = length_info_on_depths_device_[depth].CreateView({3, n_elem}, dtype_aux_); + ShapeTuple copy_shape{n_elem}; + CopyVecDataToArray(view, last_page_len->data(), copy_shape); + CopyVecDataToArray(view, sliding_window_offset->data(), copy_shape, + /*dst_elem_offset=*/n_elem); + CopyVecDataToArray(view, sink_size->data(), copy_shape, + /*dst_elem_offset=*/2 * n_elem); + return view; + } + + // The commit of the plain auxiliary data manager is no-op. + void CommitAttnAuxDataCopy() final {} + + // The reset of the plain auxiliary data manager is no-op. + void ResetCompactKVAuxDataCopy() final {} + + NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) final { + NDArray view = commit_copy_length_indptr_device_.CreateView( + {static_cast(data->size())}, dtype_aux_); + CopyVecDataToArray(view, data->data()); + return view; + } + NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, + HostMemoryVector* dst_data) final { + int n_elem = src_data->size(); + ICHECK_GT(n_elem, 0); + NDArray view = + commit_copy_src_dst_pos_in_page_table_device_.CreateView({2, n_elem}, dtype_aux_); + ShapeTuple copy_shape{n_elem}; + CopyVecDataToArray(view, src_data->data(), copy_shape); + CopyVecDataToArray(view, dst_data->data(), copy_shape, + /*dst_elem_offset=*/n_elem); + return view; + } + + // The commit of the plain auxiliary data manager is no-op. + void CommitCompactKVAuxDataCopy() final {} + + private: + /*! + * \brief Copy a vector of data to the input NDArray. + * It optionally supports specifying the shape of copy and the element + * offset to the destination NDArray. 
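The `length_info` copies above pack three per-sequence arrays into one `(3, n)` buffer by writing them at element offsets `0`, `n`, and `2 * n`. A plain-Python picture of the resulting layout (toy values):

    last_page_len = [10, 0, 4]
    sliding_window_offset = [0, 0, 2]
    sink_size = [0, 0, 1]

    n = len(last_page_len)
    flat = last_page_len + sliding_window_offset + sink_size         # offsets 0, n, 2*n
    length_info = [flat[row * n:(row + 1) * n] for row in range(3)]  # viewed as (3, n)
    assert length_info[0] == last_page_len
    assert length_info[2][2] == 1  # attention sink length of the third sequence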
+   */
+  void CopyVecDataToArray(NDArray array, int32_t* vec_data, Optional<ShapeTuple> shape = NullOpt,
+                          int dst_elem_offset = 0) {
+    if (array->shape[0] == 0) {
+      return;
+    }
+    DLTensor copy_dst = *array.operator->();
+#if defined(OPENCL_ENABLE_HOST_PTR)
+    tvm::runtime::cl::OpenCLWorkspace* workspace = tvm::runtime::cl::OpenCLWorkspace::Global();
+    if (workspace->IsOpenCLDevice(copy_dst.device)) {
+      void* nptr = workspace->GetNativePtr(array);
+      uint64_t copy_size;
+      if (shape.defined()) {
+        ICHECK_EQ(shape.value().size(), 1);
+        copy_size = shape.value()->data[0] * sizeof(int32_t);
+      } else {
+        copy_size = DeviceAPI::Get(array->device)->GetDataSize(*array.operator->());
+      }
+      memcpy(static_cast<char*>(nptr) + dst_elem_offset * sizeof(int32_t), vec_data, copy_size);
+      return;
+    }
+#endif
+
+    if (shape.defined()) {
+      ICHECK_EQ(shape.value().size(), 1);
+      copy_dst.ndim = 1;
+      copy_dst.shape = shape.value()->data;
+    }
+    copy_dst.byte_offset = dst_elem_offset * sizeof(int32_t);
+
+    DLTensor copy_src;
+    copy_src.data = vec_data;
+    copy_src.device = preferred_host_device_;
+    copy_src.ndim = 1;
+    copy_src.dtype = array->dtype;
+    copy_src.shape = copy_dst.shape;
+    copy_src.strides = nullptr;
+    copy_src.byte_offset = 0;
+    NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream_);
+  }
+
+  std::vector<NDArray> qo_indptr_on_depths_device_;
+  std::vector<NDArray> page_indptr_on_depths_device_;
+  std::vector<NDArray> page_indices_on_depths_device_;
+  std::vector<NDArray> length_info_on_depths_device_;
+  std::vector<NDArray> k_rope_pos_offset_on_depths_device_;
+  std::vector<NDArray> tree_attn_mask_device_;
+  std::vector<NDArray> tree_attn_mn_indptr_device_;
+  NDArray cur_append_length_indptr_device_;
+  NDArray k_ragged_rope_pos_offset_device_;
+  NDArray q_rope_position_map_device_;
+  NDArray append_position_map_device_;
+  NDArray kv_transfer_remote_position_map_device;
+  NDArray kv_transfer_recver_id_device;
+  NDArray kv_transfer_page_to_page_local_position_map_device;
+  NDArray kv_transfer_page_to_page_remote_position_map_device;
+  NDArray kv_transfer_page_to_page_recver_id_device;
+  NDArray commit_copy_length_indptr_device_;
+  NDArray commit_copy_src_dst_pos_in_page_table_device_;
+};
+
+/*!
+ * \brief The cached auxiliary data manager class.
+ * It allocates a large on-device array to store all the auxiliary data.
+ * For each `CopyXXXAsync`, it copies the input data to a local cache on host.
+ * In `CommitAttnAuxDataCopy`, it copies all the data in the local cache to the device
+ * array in a single pass, thus reducing the number of host-to-device copies needed.
+ */
+class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
+ public:
+  explicit CachedPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t num_total_pages,
+                                            int64_t prefill_chunk_size, DLDataType dtype_aux,
+                                            Device device, Device preferred_host_device,
+                                            TVMStreamHandle copy_stream)
+      : PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream),
+        elem_byte_size_((dtype_aux.bits * dtype_aux.lanes + 7) / 8),
+        offset_alignment_(cuda_byte_alignment_ / elem_byte_size_) {
+    // - Calculate cache size of all the attention auxiliary arrays in
+    // local cache and the large on-device array.
+    int64_t attn_aux_data_cache_size =
+        CalculateAttnAuxDataCacheSize(reserved_num_seqs, num_total_pages, prefill_chunk_size);
+    // - Initialize the host auxiliary data buffer.
+    merged_attn_aux_data_host_ =
+        HostMemoryVector(attn_aux_data_cache_size, dtype_aux, preferred_host_device);
+    // - Initialize the device auxiliary data buffer.
+    merged_attn_aux_data_device_ = NDArray::Empty({attn_aux_data_cache_size}, dtype_aux, device);
+
+    // - Calculate cache size of all the compact KV auxiliary arrays in
+    // local cache and the large on-device array.
+    int64_t compact_kv_aux_data_cache_size =
+        CalculateCompactKVAuxDataCacheSize(reserved_num_seqs, prefill_chunk_size);
+    // - Initialize the host auxiliary data buffer.
+    merged_compact_kv_aux_data_host_ =
+        HostMemoryVector(compact_kv_aux_data_cache_size, dtype_aux, preferred_host_device);
+    merged_compact_kv_aux_data_device_ =
+        NDArray::Empty({compact_kv_aux_data_cache_size}, dtype_aux, device);
+  }
+
+  void ResetAttnAuxDataCopy() final { attn_aux_data_copy_offset_ = 0; }
+  NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); }
+  NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final {
+    NDArray mask_1d = CopyAttnAuxVecToCache(data);
+    return mask_1d.CreateView({static_cast<int64_t>(data->size() / 2), 2}, mask_1d->dtype);
+  }
+  NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len,
+                                     HostMemoryVector* sliding_window_offset,
+                                     HostMemoryVector* sink_size, int depth) final {
+    int64_t n_elem = last_page_len->size();
+    std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_,
+                last_page_len->data(), n_elem * elem_byte_size_);
+    std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ + n_elem,
+                sliding_window_offset->data(), n_elem * elem_byte_size_);
+    std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ + 2 * n_elem,
+                sink_size->data(), n_elem * elem_byte_size_);
+    NDArray view = merged_attn_aux_data_device_.CreateView(
+        {3, n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_);
+    attn_aux_data_copy_offset_ += CeilDivElemAlignment(3 * n_elem);
+    return view;
+  }
+
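+  // Note on the bookkeeping above: every CopyXXXAsync of this cached manager only appends
+  // into `merged_attn_aux_data_host_` and advances `attn_aux_data_copy_offset_`, rounded up
+  // to `offset_alignment_` elements via CeilDivElemAlignment. For example, with int32
+  // auxiliary data, elem_byte_size_ = 4 and offset_alignment_ = 16 / 4 = 4, so copying a
+  // 10-element vector advances the offset by 12 elements. The single host-to-device copy
+  // of the packed buffer happens in CommitAttnAuxDataCopy below.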
+  void CommitAttnAuxDataCopy() final {
+    std::vector<int64_t> copy_shape{attn_aux_data_copy_offset_};
+    DLTensor copy_dst;
+    copy_dst.data = merged_attn_aux_data_device_->data;
+    copy_dst.device = device_;
+    copy_dst.ndim = 1;
+    copy_dst.dtype = dtype_aux_;
+    copy_dst.shape = copy_shape.data();
+    copy_dst.strides = nullptr;
+    copy_dst.byte_offset = 0;
+
+    DLTensor copy_src = copy_dst;
+    copy_src.data = merged_attn_aux_data_host_.data();
+    copy_src.device = Device{kDLCPU, 0};
+    NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream_);
+  }
+
+  void ResetCompactKVAuxDataCopy() final { compact_kv_aux_data_copy_offset_ = 0; }
+
+  NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) final {
+    return CopyCompactKVAuxVecToCache(data);
+  }
+  NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data,
+                                              HostMemoryVector* dst_data) final {
+    int64_t n_elem = src_data->size();
+    std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_,
+                src_data->data(), n_elem * elem_byte_size_);
+    std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_ + n_elem,
+                dst_data->data(), n_elem * elem_byte_size_);
+    NDArray view = merged_compact_kv_aux_data_device_.CreateView(
+        {2, n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * elem_byte_size_);
+    compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(2 * n_elem);
+    return view;
+  }
+
+  void CommitCompactKVAuxDataCopy() final {
+    std::vector<int64_t> copy_shape{compact_kv_aux_data_copy_offset_};
+    DLTensor copy_dst;
+    copy_dst.data = merged_compact_kv_aux_data_device_->data;
+    copy_dst.device = device_;
+    copy_dst.ndim = 1;
+    copy_dst.dtype = dtype_aux_;
+    copy_dst.shape = copy_shape.data();
+    copy_dst.strides = nullptr;
+    copy_dst.byte_offset = 0;
+
+    DLTensor copy_src = copy_dst;
+    copy_src.data = merged_compact_kv_aux_data_host_.data();
+    copy_src.device = Device{kDLCPU, 0};
+    NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream_);
+  }
+
+ private:
+  /*!
+   * \brief Calculate the start element offsets of the auxiliary arrays in the local cache.
+   * \return Return the local cache size (total number of elements in the local cache).
+   */
+  int64_t CalculateAttnAuxDataCacheSize(int64_t reserved_num_seqs, int64_t num_total_pages,
+                                        int64_t prefill_chunk_size) {
+    int64_t cache_size = 0;
+    // - Array size of the arrays that every depth has.
+    // Corresponding to the following arrays respectively
+    //  - qo_indptr_in_depth
+    //  - page_indptr_in_depth
+    //  - page_indices_in_depth
+    //  - length_info_in_depth
+    //  - k_rope_pos_offset_in_depth
+    cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
+    cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
+    cache_size += CeilDivElemAlignment(num_total_pages);
+    cache_size += CeilDivElemAlignment(3 * reserved_num_seqs);
+    cache_size += CeilDivElemAlignment(reserved_num_seqs);
+    cache_size *= kPagedKVCacheMaxBlockDepth;
+
+    // - Array size of other arrays.
+ // Corresponding to the following arrays respectively + // - cur_append_length_indptr + // - k_ragged_rope_pos_offset + // - q_rope_position_map + // - append_position_map + // - kv_transfer_remote_position_map + // - kv_transfer_recver_id + // - kv_transfer_page_to_page_local_position_map + // - kv_transfer_page_to_page_remote_position_map + // - kv_transfer_page_to_page_recver_id + // - tree_attn_mask + // - tree_attn_mn_indptr + cache_size += CeilDivElemAlignment(reserved_num_seqs + 1); + cache_size += CeilDivElemAlignment(reserved_num_seqs); + cache_size += CeilDivElemAlignment(prefill_chunk_size); + cache_size += CeilDivElemAlignment(prefill_chunk_size); + cache_size += CeilDivElemAlignment(prefill_chunk_size); + cache_size += CeilDivElemAlignment(prefill_chunk_size); + cache_size += CeilDivElemAlignment(prefill_chunk_size); + cache_size += CeilDivElemAlignment(prefill_chunk_size); + cache_size += CeilDivElemAlignment(prefill_chunk_size); + cache_size += + CeilDivElemAlignment(kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs); + cache_size += CeilDivElemAlignment(reserved_num_seqs + 1); + + return cache_size; + } + + int64_t CalculateCompactKVAuxDataCacheSize(int64_t reserved_num_seqs, + int64_t prefill_chunk_size) { + int64_t cache_size = 0; + // Corresponding to the following arrays respectively + // - commit_copy_length_indptr + // - commit_copy_src_dst_pos_in_page_table + cache_size += CeilDivElemAlignment(reserved_num_seqs + 1); + cache_size += CeilDivElemAlignment( + 2 * std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size)); + + return cache_size; + } + + /*! + * \brief Copy the input data to the cache at the given offset. + * And return the NDArray view of the cache starting at the offset. + */ + NDArray CopyAttnAuxVecToCache(HostMemoryVector* data) { + int64_t n_elem = data->size(); + std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_, data->data(), + n_elem * elem_byte_size_); + NDArray view = merged_attn_aux_data_device_.CreateView( + {n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_); + attn_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem); + return view; + } + + NDArray CopyCompactKVAuxVecToCache(HostMemoryVector* data) { + int64_t n_elem = data->size(); + std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_, + data->data(), n_elem * elem_byte_size_); + NDArray view = merged_compact_kv_aux_data_device_.CreateView( + {n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * elem_byte_size_); + compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem); + return view; + } + + /*! \brief For safety, we align the start offset of the arrays to `offset_alignment`. 
*/ + int64_t CeilDivElemAlignment(int n) { + return (n + offset_alignment_ - 1) / offset_alignment_ * offset_alignment_; + } + + const int64_t cuda_byte_alignment_ = 16; + const int64_t elem_byte_size_; + const int64_t offset_alignment_; + + int64_t attn_aux_data_copy_offset_ = 0; + int64_t compact_kv_aux_data_copy_offset_ = 0; + HostMemoryVector merged_attn_aux_data_host_; + HostMemoryVector merged_compact_kv_aux_data_host_; + NDArray merged_attn_aux_data_device_; + NDArray merged_compact_kv_aux_data_device_; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_ATTN_UTILS_H_ diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc index 1b1867f06093..43b7c7ab4064 100644 --- a/src/runtime/relax_vm/kv_state.cc +++ b/src/runtime/relax_vm/kv_state.cc @@ -77,27 +77,33 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv") TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv_mla") .set_body_method(&AttentionKVCacheObj::DebugGetKVMLA); TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv") - .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, - double attn_score_scaling_factor, NDArray qkv_data, NDArray o_data) { + .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, + NDArray qkv_data, NDArray o_data) { kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data), - attn_score_scaling_factor); + sm_scale); }); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_absorbed") - .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, - double attn_score_scaling_factor, NDArray q_data, NDArray compressed_kv_data, - NDArray k_pe_data, NDArray o_data) { - kv_cache->MLAAbsorbed(layer_id, std::move(q_data), std::move(compressed_kv_data), - std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor); +TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_self_attention") + .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray q_data, + NDArray k_data, NDArray v_data, NDArray o_data, NDArray lse_data) { + kv_cache->SelfAttention(layer_id, std::move(q_data), std::move(k_data), std::move(v_data), + std::move(o_data), std::move(lse_data), sm_scale); }); - -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_normal") - .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, - double attn_score_scaling_factor, NDArray q_data, NDArray k_data, - NDArray v_data, NDArray compressed_kv_data, NDArray k_pe_data, - NDArray o_data) { - kv_cache->MLANormal(layer_id, std::move(q_data), std::move(k_data), std::move(v_data), - std::move(compressed_kv_data), std::move(k_pe_data), std::move(o_data), - attn_score_scaling_factor); +TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_cross_attention") + .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray q_data, + NDArray o_data, NDArray lse_data) { + kv_cache->CrossAttention(layer_id, std::move(q_data), std::move(o_data), std::move(lse_data), + sm_scale); + }); +TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append_mla_kv") + .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, NDArray kv_data) { + kv_cache->AppendMLAKV(layer_id, std::move(kv_data)); + return kv_cache; + }); +TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_merge_attn_output_inplace") + .set_body_typed([](AttentionKVCache kv_cache, NDArray o_self_attn, NDArray lse_self_attn, + NDArray o_cross_attn, 
NDArray lse_cross_attn) { + return kv_cache->MergeAttnOutputInplace(std::move(o_self_attn), std::move(lse_self_attn), + std::move(o_cross_attn), std::move(lse_cross_attn)); }); // RNN State methods diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index 300d22b85909..1e530b41ece5 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -175,47 +175,53 @@ class AttentionKVCacheObj : public KVStateObj { * `(total_length, num_qo_heads + 2 * num_kv_heads, head_dim)`. * \param mask The input mask data, in layout `(total_sqr_length)`. * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`. - * \param attn_score_scaling_factor The additional attention scaling factor. + * \param sm_scale The additional attention scaling factor. * \sa AttentionKVCache::Attention */ virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional mask, - NDArray o_data, double attn_score_scaling_factor) = 0; + NDArray o_data, double sm_scale) = 0; /*! - * \brief Compute multi-head latent attention after applying weight absorption. + * \brief Fine-grained API that computes ragged self attention with Q/K/V data. * \param layer_id The model layer where the attention compute happens. - * \param q_data The input Q data, in layout `(total_length, num_qo_heads, qk_head_dim)` - * \param compressed_kv_data The compressed latent KV data, in layout - * `(total_length, num_kv_heads, kv_lora_rank)` - * \param k_pe_data The positional embedding part of K data, in layout - * `(total_length, num_kv_heads, qk_rope_head_dim)`, where `kv_lora_rank + qk_rope_head_dim` - * equals qk_head_dim + * \param q_data The input Q data. + * \param k_data The input K data. + * \param v_data The input V data. * \param o_data The output O data, in layout `(total_length, num_qo_heads, v_head_dim)`. - * \param attn_score_scaling_factor The additional attention scaling factor. + * \param lse_data The output attention LSE data, in layout `(total_length, num_qo_heads)`. + * \param sm_scale The additional attention scaling factor. */ - virtual void MLAAbsorbed(int64_t layer_id, NDArray q_data, NDArray compressed_kv_data, - NDArray k_pe_data, NDArray o_data, double attn_score_scaling_factor) = 0; + virtual void SelfAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, + NDArray o_data, NDArray lse_data, double sm_scale) = 0; /*! - * \brief Compute multi-head latent attention in normal style. + * \brief Fine-grained API that computes paged cross attention with Q and in-cache KV data. * \param layer_id The model layer where the attention compute happens. - * \param q_data The input Q data, in layout - * `(total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)` - * \param k_data The input K data, in layout - * `(total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)` - * \param v_data The input V data, in layout - * `(total_length, num_qo_heads, v_head_dim)` - * \param compressed_kv_data The compressed latent KV data, in layout - * `(total_length, num_kv_heads, kv_lora_rank)` - * \param k_pe_data The positional embedding part of K data, in layout - * `(total_length, num_kv_heads, qk_rope_head_dim)`, where `kv_lora_rank + qk_rope_head_dim` - * equals qk_head_dim - * \param o_data The output O data, in layout `(total_length, num_qo_heads, v_head_dim)`. - * \param attn_score_scaling_factor The additional attention scaling factor. + * \param q_data The input Q data. + * \param o_data The output O data. 
+ * \param lse_data The output attention LSE data, in layout `(total_length, num_qo_heads)`. + * \param sm_scale The additional attention scaling factor. + */ + virtual void CrossAttention(int64_t layer_id, NDArray q_data, NDArray o_data, NDArray lse_data, + double sm_scale) = 0; + + /*! + * \brief Fine-grained API that appends the MLA K/V data to KV cache. + * \param layer_id The model layer where the attention compute happens. + * \param kv_data The input KV data to append, in layout `(total_length, qk_head_dim)`. + */ + virtual void AppendMLAKV(int64_t layer_id, NDArray kv_data) = 0; + + /*! + * \brief Fine-grained API that merges the attention output from two sources. + * \param o1_data The first source O data. + * \param lse1_data The first source LSE data. + * \param o2_data The second source O data. + * \param lse2_data The second source LSE data. + * \return The merged O and LSE data. */ - virtual void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, - NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data, - double attn_score_scaling_factor) = 0; + virtual Array MergeAttnOutputInplace(NDArray o_self_attn, NDArray lse_self_attn, + NDArray o_cross_attn, NDArray lse_cross_attn) = 0; /*! * \brief Compute linear attention with Q/K/V data. @@ -224,11 +230,11 @@ class AttentionKVCacheObj : public KVStateObj { * \param k_data The input K data, in layout `(total_length, num_kv_heads, head_dim)`. * \param v_data The input V data, in layout `(total_length, num_kv_heads, head_dim)`. * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`. - * \param attn_score_scaling_factor The additional attention scaling factor. + * \param sm_scale The additional attention scaling factor. * \sa AttentionKVCache::Attention */ virtual void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, - double attn_score_scaling_factor) = 0; + double sm_scale) = 0; /************** Positions **************/ diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index a936f429eeec..6b68bb1b1a30 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -33,6 +33,8 @@ #include #include +#include "attn_backend.h" +#include "attn_utils.h" #include "kv_state.h" #if defined(OPENCL_ENABLE_HOST_PTR) #include "../opencl/opencl_common.h" @@ -50,888 +52,6 @@ namespace relax_vm { // runtime API function calls //------------------------------------------- -/*! - * \brief The maximum allowed block depth (a.k.a. number of common - * prefixes) in paged KV cache. - */ -constexpr const int kPagedKVCacheMaxBlockDepth = 2; -/*! \brief The maximum tree size of a single sequence in tree attention. */ -constexpr const int kTreeAttnMaxTreeSize = 256; -/*! \brief The 1MB workspace size for integer attention auxiliary data. */ -constexpr const int kIntAttnWorkspaceByte = 1 * 1024 * 1024; -/*! \brief The 128MB workspace size for floating-point attention auxiliary data. */ -constexpr const int kFloatAttnWorkspaceByte = 768 * 1024 * 1024; -/*! \brief The id of the temporary logical page, which is useful for sliding window. */ -constexpr const int kPagedKVCacheTempPageId = -1; - -/*! - * \brief The supported attention kinds in PagedKVCache. - * "MHA" means multi-head attention, multi-query attention and grouped query attention in general. - * "MLA" means multi-head latent attention. - * "LinearAttn" means linear attention. 
- */ -enum class AttnKind : int { - kMHA = 0, - kMLA = 1, - kLinearAttn = 2, -}; - -ShapeTuple GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int num_sequence, - int64_t num_kv_heads, int64_t page_size, int64_t qk_head_dim, - int64_t v_head_dim, int64_t qk_rope_head_dim) { - if (attn_kind == AttnKind::kMHA) { - // Ignore v_head_dim since multi-head attention requires K/V to have the same head dim. - return {num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}; - } else if (attn_kind == AttnKind::kMLA) { - return {num_total_pages, page_size, qk_head_dim}; - } else if (attn_kind == AttnKind::kLinearAttn) { - return {num_sequence, num_kv_heads, qk_head_dim, v_head_dim}; - } - ICHECK(false); - throw; -} - -/*! - * \brief The block structure in paged KV cache with common prefix support. - * Each block contains a list of pages for cached KV data. - * If a block has `n` pages, the first `n - 1` pages must be - * full, and only the last page can be partially filled. - * - * To support common prefix, each sequence in KV cache is represented - * as one or more blocks, where the common prefix is a standalone - * block among. - * - * Each block has a parent block when it uses a prefix. - */ -struct Block { - /*! - * \brief The ids of the pages in the block. - * Each page can only be used by a unique block (in other - * words, different blocks do not share pages). - */ - std::vector page_ids; - /*! \brief The total sequence length in the block. */ - int32_t seq_length = 0; - /*! - * \brief The start position in sequence of this block. - * This is the absolute position in the sequence for RoPE computation. - */ - int32_t start_pos = 0; - /*! - * \brief The current attention sink length of the block. - * It means the **first** sink size elements will be pinned - * in the KV cache even when sliding window is enabled. - */ - int32_t sink_length = 0; - /*! - * \brief The start offset of the sliding window in the block. - * It is always 0 when sliding window attn is not enabled. - */ - int32_t sliding_window_offset = 0; - - /*! \brief The global index of the block. */ - const int32_t index; - /*! - * \brief The global index of the parent block of this block, or -1 - * if the block does not have a parent. */ - int32_t parent_idx = -1; - /*! - * \brief The external reference counter of the block. - * When a block is externally referred by some block, - * we do not allow appending new KV values to this block. - */ - int external_ref_cnt = 0; - - explicit Block(int32_t index) : index(index) {} - - /*! \brief Reset the block data. */ - void Reset() { - page_ids.clear(); - seq_length = 0; - start_pos = 0; - sink_length = 0; - sliding_window_offset = 0; - parent_idx = -1; - external_ref_cnt = 0; - } -}; - -struct KVTransferMetadata { - int64_t start = std::numeric_limits::max(); - std::vector remote_position_map; - int32_t recver_pe_offset = -1; - std::vector local_position_map; -}; - -/*! - * \brief The sequence structure in paged KV cache with common prefix support. - * Each sequence contains one or more blocks to support common prefix. - */ -struct Sequence { - /*! - * \brief The global index of the last block of the sequence. - * We only store the last block, since all the blocks can be - * tracked with the `parent` field of Block. - */ - int32_t last_block_idx; - /*! - * \brief The total sequence length of the sequence. - * It is the sum of lengths of all its blocks. - */ - int32_t seq_length = 0; - /*! 
- * \brief The sliding window size of the sequence, or -1 if sliding window is not enabled. - * When a sequence is enabled for sliding window, it can no longer be forked. - */ - int sliding_window_size = -1; - /*! - * \brief The attention sink size of the last block of the sequence. - * The **first** sink size elements of the last block will be pinned - * in the KV cache even when sliding window is enabled. - */ - int last_block_attn_sink_size = 0; - - /*! \brief Whether the current appended tokens form a chain (not a tree). */ - bool is_chain = true; - /*! \brief The token tree parent pointer array of the current appended tokens. */ - std::vector token_tree_parent_ptr; - /*! \brief The depth of each node in the token tree. */ - std::vector token_tree_node_depths; - /*! \brief The metadata of kv transfer*/ - KVTransferMetadata kv_transfer_metadata; - /*! - * \brief A boolean denoting whether the accepted token tree indices of - * this sequence are committed - */ - bool accepted_indices_committed = true; - - explicit Sequence(std::vector* global_block_pool, int32_t last_block_idx) { - ++global_block_pool->at(last_block_idx).external_ref_cnt; - this->last_block_idx = last_block_idx; - int32_t block_ptr = last_block_idx; - // Go through each block in the sequence, sum up the length. - while (true) { - const Block& block = global_block_pool->at(block_ptr); - this->seq_length += block.seq_length; - if (block.parent_idx == -1) { - break; - } - block_ptr = block.parent_idx; - } - } - - std::vector GetBlockTrace(const std::vector& global_block_pool) const { - std::vector trace; - // Get the trace from the last block of the sequence to the root block. - int32_t block_ptr = last_block_idx; - while (block_ptr != -1) { - trace.push_back(block_ptr); - block_ptr = global_block_pool[block_ptr].parent_idx; - } - // Reverse the trace so that it starts from the root block. - std::reverse(trace.begin(), trace.end()); - return trace; - } -}; - -/*! - * \brief The rotary embedding mode adopted by the paged KV cache - * when computing attention. - * "None" means RoPE is never applied to q and k. - * "Normal" means RoPE is computed in a standalone kernel. - * "Inline" means RoPE is computed on-the-fly in attention kernels. - */ -enum class RoPEMode : int { - kNone = 0, - kNormal = 1, - kInline = 2, -}; - -/*! - * \brief The class of host memory int32 vector in "std::vector" interface. - * This vector allocates static memory on the specified host memory - * at the time of construction. 
- */ -class HostMemoryVector { - public: - HostMemoryVector() = default; - HostMemoryVector(const HostMemoryVector&) = delete; - HostMemoryVector(HostMemoryVector&& other) = default; - HostMemoryVector& operator=(const HostMemoryVector&) = delete; - HostMemoryVector& operator=(HostMemoryVector&& other) = default; - - explicit HostMemoryVector(int64_t reserved_size, DLDataType dtype, Device device) - : reserved_size_(reserved_size) { - ICHECK(DataType(dtype) == DataType::Int(32)); - data_ = NDArray::Empty({reserved_size}, dtype, device); - } - - void push_back(int32_t value) { - ICHECK_LE(current_size_, reserved_size_); - if (current_size_ == reserved_size_) { - reserved_size_ *= 2; - NDArray new_data = NDArray::Empty({reserved_size_}, data_->dtype, data_->device); - std::memcpy(new_data->data, data_->data, current_size_ * DataType(data_->dtype).bytes()); - data_ = new_data; - } - static_cast(data_->data)[current_size_++] = value; - } - - const int32_t& operator[](int64_t idx) const { - ICHECK_GE(idx, 0) << "Index " << idx << " is negative."; - ICHECK_LT(idx, current_size_) << "Index " << idx << " out of bounds " << current_size_; - return static_cast(data_->data)[idx]; - } - - int32_t back() const { - ICHECK_GT(current_size_, 0) << "Vector is empty"; - return static_cast(data_->data)[current_size_ - 1]; - } - - size_t size() const { return static_cast(current_size_); } - - int32_t* data() const { return static_cast(data_->data); } - - void clear() { current_size_ = 0; } - - /*! \brief Return the vector as an NDArray. */ - NDArray as_ndarray() { return data_.CreateView({current_size_}, data_->dtype); } - - IntTuple as_int_tuple() const { - std::vector values; - values.reserve(current_size_); - for (int i = 0; i < current_size_; ++i) { - values.push_back(static_cast(data_->data)[i]); - } - return IntTuple(values); - } - - private: - int64_t reserved_size_ = 0; - int64_t current_size_ = 0; - NDArray data_{nullptr}; -}; - -/*! - * \brief The paged attention auxiliary data manager class. - * This class manages all the int32 auxiliary data on GPU device, such as - * page table, position arrays, etc.. - * - * The core functions of this class is `CopyXXXAsync` and `CommitAttnAuxDataCopy`. - * `CopyXXXAsync` takes the input data on CPU host, and copy the input data - * to GPU in an asynchronous way, and returns the NDArray view of the data - * on GPU device. - * - * Being asynchronous here means the `CopyXXXAsync` function may not perform - * data copy from CPU to GPU at the time of being called. Therefore, the - * returned NDArray view may have wrong result, until `CommitAttnAuxDataCopy` is - * explicitly invoked and the data copy stream is synchronized. - * - * We design this manager class in order to reduce the data copy overhead. - */ -class PagedKVCacheAuxDataManager { - public: - PagedKVCacheAuxDataManager(DLDataType dtype_aux, Device device, Device preferred_host_device, - TVMStreamHandle copy_stream) - : dtype_aux_(dtype_aux), - device_(device), - preferred_host_device_(preferred_host_device), - copy_stream_(copy_stream) { - ICHECK(DataType(dtype_aux) == DataType::Int(32)); - } - - virtual ~PagedKVCacheAuxDataManager() = default; - /*! \brief Reset the attention auxiliary data status of copy manager. */ - virtual void ResetAttnAuxDataCopy() = 0; - /*! \brief Copy the indptr array of append lengths after coalescing. (see GetChunkedBlockIds) */ - virtual NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; - /*! \brief Copy the indptr array of page table. 
*/ - virtual NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; - /*! \brief Copy the indices array of page table. */ - virtual NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) = 0; - /*! \brief Copy the array of KV slot number used in the last page of the seq. */ - virtual NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) = 0; - /*! - * \brief Copy the length information of the sequences. - * Each NDArray is in shape `(3, n)`. "n" is the number of sequences. - * For a sequence "i", location - * - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), - * - "(1, i)" is the starting offset of the sliding window in the seq, - * - "(2, i)" is the attn sink length of the sequence. - * \note When sliding window is not enabled, only the - * "last_page_len" (a.k.a., the first "n" elements) will be effectively used. - */ - virtual NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, - HostMemoryVector* sliding_window_offset, - HostMemoryVector* sink_size, int depth) = 0; - /*! \brief Copy the k position offset of applying RoPE for each sequence. */ - virtual NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) = 0; - /*! - * \brief Copy the append length indptr array on device. - * \note Since the Q/K/V data may have raggedness in terms of lengths, - * we represent the append lengths in CSR format. - */ - virtual NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) = 0; - /*! \brief Copy the k position offset of applying RoPE for each sequence. */ - virtual NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) = 0; - /*! \brief Copy the q position mapping of applying RoPE for each sequence. */ - virtual NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) = 0; - /*! - * \brief Copy the corresponding position in global KV cache (pages) - * for each position along the length dimension of K/V data when - * appending new K/V data. - */ - virtual NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) = 0; - /*! \brief Copy the remote position map for KV transfer. */ - virtual NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) = 0; - /*! \brief Copy the receiver id for KV transfer. */ - virtual NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) = 0; - /*! \brief Copy the local position map for KV page-to-page transfer. */ - virtual NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) = 0; - /*! \brief Copy the remote position map for KV page-to-page transfer. */ - virtual NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) = 0; - /*! \brief Copy the receiver id for KV page-to-page transfer. */ - virtual NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) = 0; - /*! \brief Copy the tree attention mask. */ - virtual NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) = 0; - /*! \brief Copy the mn indptr of the tree attention mask. */ - virtual NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0; - /*! \brief Commit all the attention auxiliary data copy operations since the last commit. */ - virtual void CommitAttnAuxDataCopy() = 0; - - /*! \brief Reset the compact KV auxiliary data status of copy manager. */ - virtual void ResetCompactKVAuxDataCopy() = 0; - /*! \brief Copy the length indptr array of KV data copy for each sequence. 
*/ - virtual NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) = 0; - /*! \brief Copy the src/dst position arrays for each sequence. */ - virtual NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, - HostMemoryVector* dst_data) = 0; - /*! \brief Commit all the compact KV auxiliary data copy operations since the last commit. */ - virtual void CommitCompactKVAuxDataCopy() = 0; - - protected: - /*! \brief The dtype of the auxiliary data. It is expected to be int32. */ - const DLDataType dtype_aux_; - /*! \brief The device this PagedKVCache runs on. */ - const Device device_; - /*! \brief The preferred host device. */ - const Device preferred_host_device_; - /*! \brief The device stream for copying auxiliary data structure to GPU. */ - const TVMStreamHandle copy_stream_; -}; - -/*! - * \brief The plain auxiliary data manager class. - * It simply issues one host-to-device copy operation for each `CopyXXXAsync`. - */ -class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { - public: - explicit PlainPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t num_total_pages, - int64_t prefill_chunk_size, DLDataType dtype_aux, - Device device, Device preferred_host_device, - TVMStreamHandle copy_stream) - : PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream) { - for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { - qo_indptr_on_depths_device_.push_back( - NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); - page_indptr_on_depths_device_.push_back( - NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); - page_indices_on_depths_device_.push_back( - NDArray::Empty({num_total_pages}, dtype_aux_, device)); - length_info_on_depths_device_.push_back( - NDArray::Empty({3, reserved_num_seqs}, dtype_aux_, device)); - k_rope_pos_offset_on_depths_device_.push_back( - NDArray::Empty({reserved_num_seqs}, dtype_aux_, device)); - tree_attn_mask_device_.push_back(NDArray::Empty( - {kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs}, dtype_aux_, device)); - tree_attn_mn_indptr_device_.push_back( - NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device)); - } - cur_append_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); - k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device); - q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); - append_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); - kv_transfer_remote_position_map_device = - NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); - kv_transfer_recver_id_device = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); - kv_transfer_page_to_page_local_position_map_device = - kv_transfer_page_to_page_remote_position_map_device = - NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); - kv_transfer_page_to_page_recver_id_device = - NDArray::Empty({prefill_chunk_size}, dtype_aux_, device); - commit_copy_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device); - commit_copy_src_dst_pos_in_page_table_device_ = - NDArray::Empty({2, std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size)}, - dtype_aux_, device); - } - - // The reset of the plain auxiliary data manager is no-op. 
- void ResetAttnAuxDataCopy() final {} - NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = qo_indptr_on_depths_device_[depth].CreateView( - {static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = page_indptr_on_depths_device_[depth].CreateView( - {static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = page_indices_on_depths_device_[depth].CreateView( - {static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = length_info_on_depths_device_[depth].CreateView( - {static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = k_rope_pos_offset_on_depths_device_[depth].CreateView( - {static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final { - NDArray view = cur_append_length_indptr_device_.CreateView({static_cast(data->size())}, - dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final { - NDArray view = k_ragged_rope_pos_offset_device_.CreateView({static_cast(data->size())}, - dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { - NDArray view = - q_rope_position_map_device_.CreateView({static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { - NDArray view = - append_position_map_device_.CreateView({static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final { - NDArray view = kv_transfer_remote_position_map_device.CreateView( - {static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) final { - NDArray view = - kv_transfer_recver_id_device.CreateView({static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) final { - NDArray view = kv_transfer_page_to_page_local_position_map_device.CreateView( - {static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) final { - NDArray view = kv_transfer_page_to_page_remote_position_map_device.CreateView( - {static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final { - NDArray view = kv_transfer_page_to_page_recver_id_device.CreateView( - {static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - - NDArray 
CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = - tree_attn_mask_device_[depth].CreateView({static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray view = tree_attn_mn_indptr_device_[depth].CreateView( - {static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - - NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, - HostMemoryVector* sliding_window_offset, - HostMemoryVector* sink_size, int depth) final { - int n_elem = last_page_len->size(); - ICHECK_GT(n_elem, 0); - NDArray view = length_info_on_depths_device_[depth].CreateView({3, n_elem}, dtype_aux_); - ShapeTuple copy_shape{n_elem}; - CopyVecDataToArray(view, last_page_len->data(), copy_shape); - CopyVecDataToArray(view, sliding_window_offset->data(), copy_shape, - /*dst_elem_offset=*/n_elem); - CopyVecDataToArray(view, sink_size->data(), copy_shape, - /*dst_elem_offset=*/2 * n_elem); - return view; - } - - // The commit of the plain auxiliary data manager is no-op. - void CommitAttnAuxDataCopy() final {} - - // The reset of the plain auxiliary data manager is no-op. - void ResetCompactKVAuxDataCopy() final {} - - NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) final { - NDArray view = commit_copy_length_indptr_device_.CreateView( - {static_cast(data->size())}, dtype_aux_); - CopyVecDataToArray(view, data->data()); - return view; - } - NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, - HostMemoryVector* dst_data) final { - int n_elem = src_data->size(); - ICHECK_GT(n_elem, 0); - NDArray view = - commit_copy_src_dst_pos_in_page_table_device_.CreateView({2, n_elem}, dtype_aux_); - ShapeTuple copy_shape{n_elem}; - CopyVecDataToArray(view, src_data->data(), copy_shape); - CopyVecDataToArray(view, dst_data->data(), copy_shape, - /*dst_elem_offset=*/n_elem); - return view; - } - - // The commit of the plain auxiliary data manager is no-op. - void CommitCompactKVAuxDataCopy() final {} - - private: - /*! - * \brief Copy a vector of data to the input NDArray. - * It optionally supports specifying the shape of copy and the element - * offset to the destination NDArray. 
- */ - void CopyVecDataToArray(NDArray array, int32_t* vec_data, Optional shape = NullOpt, - int dst_elem_offset = 0) { - if (array->shape[0] == 0) { - return; - } - DLTensor copy_dst = *array.operator->(); -#if defined(OPENCL_ENABLE_HOST_PTR) - tvm::runtime::cl::OpenCLWorkspace* workspace = tvm::runtime::cl::OpenCLWorkspace::Global(); - if (workspace->IsOpenCLDevice(copy_dst.device)) { - void* nptr = workspace->GetNativePtr(array); - uint64_t copy_size; - if (shape.defined()) { - ICHECK_EQ(shape.value().size(), 1); - copy_size = shape.value()->data[0] * sizeof(int32_t); - } else { - copy_size = DeviceAPI::Get(array->device)->GetDataSize(*array.operator->()); - } - memcpy(static_cast(nptr) + dst_elem_offset * sizeof(int32_t), vec_data, copy_size); - return; - } -#endif - - if (shape.defined()) { - ICHECK_EQ(shape.value().size(), 1); - copy_dst.ndim = 1; - copy_dst.shape = shape.value()->data; - } - copy_dst.byte_offset = dst_elem_offset * sizeof(int32_t); - - DLTensor copy_src; - copy_src.data = vec_data; - copy_src.device = preferred_host_device_; - copy_src.ndim = 1; - copy_src.dtype = array->dtype; - copy_src.shape = copy_dst.shape; - copy_src.strides = nullptr; - copy_src.byte_offset = 0; - NDArray::CopyFromTo(©_src, ©_dst, copy_stream_); - } - - std::vector qo_indptr_on_depths_device_; - std::vector page_indptr_on_depths_device_; - std::vector page_indices_on_depths_device_; - std::vector length_info_on_depths_device_; - std::vector k_rope_pos_offset_on_depths_device_; - std::vector tree_attn_mask_device_; - std::vector tree_attn_mn_indptr_device_; - NDArray cur_append_length_indptr_device_; - NDArray k_ragged_rope_pos_offset_device_; - NDArray q_rope_position_map_device_; - NDArray append_position_map_device_; - NDArray kv_transfer_remote_position_map_device; - NDArray kv_transfer_recver_id_device; - NDArray kv_transfer_page_to_page_local_position_map_device; - NDArray kv_transfer_page_to_page_remote_position_map_device; - NDArray kv_transfer_page_to_page_recver_id_device; - NDArray commit_copy_length_indptr_device_; - NDArray commit_copy_src_dst_pos_in_page_table_device_; -}; - -/*! - * \brief The cached auxiliary data manager class. - * It allocates a large on-device array to store all the auxiliary data. - * For each `CopyXXXAsync`, it copies the input data to a local cache on host. - * In `CommitAttnAuxDataCopy`, it copies all the data in the local cache to the device - * array for a single time, and thus reduce the number of host-to-device copies needed. - */ -class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { - public: - explicit CachedPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t num_total_pages, - int64_t prefill_chunk_size, DLDataType dtype_aux, - Device device, Device preferred_host_device, - TVMStreamHandle copy_stream) - : PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream), - elem_byte_size_((dtype_aux.bits * dtype_aux.lanes + 7) / 8), - offset_alignment_(cuda_byte_alignment_ / elem_byte_size_) { - // - Calculate cache size of all the attention auxiliary arrays in - // local cache and the large on-device array. - int64_t attn_aux_data_cache_size = - CalculateAttnAuxDataCacheSize(reserved_num_seqs, num_total_pages, prefill_chunk_size); - // - Initialize the host auxiliary data buffer. - merged_attn_aux_data_host_ = - HostMemoryVector(attn_aux_data_cache_size, dtype_aux, preferred_host_device); - // - Initialize the device auxiliary data buffer. 
- merged_attn_aux_data_device_ = NDArray::Empty({attn_aux_data_cache_size}, dtype_aux, device); - - // - Calculate cache size of all the compact KV auxiliary arrays in - // local cache and the large on-device array. - int64_t compact_kv_aux_data_cache_size = - CalculateCompactKVAuxDataCacheSize(reserved_num_seqs, prefill_chunk_size); - // - Initialize the host auxiliary data buffer. - merged_compact_kv_aux_data_host_ = - HostMemoryVector(compact_kv_aux_data_cache_size, dtype_aux, preferred_host_device); - merged_compact_kv_aux_data_device_ = - NDArray::Empty({compact_kv_aux_data_cache_size}, dtype_aux, device); - } - - void ResetAttnAuxDataCopy() final { attn_aux_data_copy_offset_ = 0; } - NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); } - NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyKVTransferRecverIDAsync(HostMemoryVector* data) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyKVTransferPage2PageLocalPositionMapAsync(HostMemoryVector* data) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyKVTransferPage2PageRemotePositionMapAsync(HostMemoryVector* data) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyKVTransferPage2PageRecverIDAsync(HostMemoryVector* data) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyTreeAttnMaskOnDepthAsync(HostMemoryVector* data, int depth) final { - NDArray mask_1d = CopyAttnAuxVecToCache(data); - return mask_1d.CreateView({static_cast(data->size() / 2), 2}, mask_1d->dtype); - } - NDArray CopyTreeAttnMNIndptrOnDepthAsync(HostMemoryVector* data, int depth) final { - return CopyAttnAuxVecToCache(data); - } - NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len, - HostMemoryVector* sliding_window_offset, - HostMemoryVector* sink_size, int depth) final { - int64_t n_elem = last_page_len->size(); - std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_, - last_page_len->data(), n_elem * elem_byte_size_); - std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ + n_elem, - sliding_window_offset->data(), n_elem * elem_byte_size_); - std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ + 2 * n_elem, - sink_size->data(), n_elem * elem_byte_size_); - NDArray view = merged_attn_aux_data_device_.CreateView( - {3, n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_); - attn_aux_data_copy_offset_ += CeilDivElemAlignment(3 * n_elem); - return view; - } - - 
void CommitAttnAuxDataCopy() final { - std::vector copy_shape{attn_aux_data_copy_offset_}; - DLTensor copy_dst; - copy_dst.data = merged_attn_aux_data_device_->data; - copy_dst.device = device_; - copy_dst.ndim = 1; - copy_dst.dtype = dtype_aux_; - copy_dst.shape = copy_shape.data(); - copy_dst.strides = nullptr; - copy_dst.byte_offset = 0; - - DLTensor copy_src = copy_dst; - copy_src.data = merged_attn_aux_data_host_.data(); - copy_src.device = Device{kDLCPU, 0}; - NDArray::CopyFromTo(©_src, ©_dst, copy_stream_); - } - - void ResetCompactKVAuxDataCopy() final { compact_kv_aux_data_copy_offset_ = 0; } - - NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) final { - return CopyCompactKVAuxVecToCache(data); - } - NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, - HostMemoryVector* dst_data) final { - int64_t n_elem = src_data->size(); - std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_, - src_data->data(), n_elem * elem_byte_size_); - std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_ + n_elem, - dst_data->data(), n_elem * elem_byte_size_); - NDArray view = merged_compact_kv_aux_data_device_.CreateView( - {2, n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * elem_byte_size_); - compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(2 * n_elem); - return view; - } - - void CommitCompactKVAuxDataCopy() final { - std::vector copy_shape{compact_kv_aux_data_copy_offset_}; - DLTensor copy_dst; - copy_dst.data = merged_compact_kv_aux_data_device_->data; - copy_dst.device = device_; - copy_dst.ndim = 1; - copy_dst.dtype = dtype_aux_; - copy_dst.shape = copy_shape.data(); - copy_dst.strides = nullptr; - copy_dst.byte_offset = 0; - - DLTensor copy_src = copy_dst; - copy_src.data = merged_compact_kv_aux_data_host_.data(); - copy_src.device = Device{kDLCPU, 0}; - NDArray::CopyFromTo(©_src, ©_dst, copy_stream_); - } - - private: - /*! - * \brief Calculate the start element offsets of the auxiliary arrays in the local cache. - * \return Return the local cache size (total number of elements in the local cache). - */ - int64_t CalculateAttnAuxDataCacheSize(int64_t reserved_num_seqs, int64_t num_total_pages, - int64_t prefill_chunk_size) { - int64_t cache_size = 0; - // - Array size of the arrays that every depth has. - // Corresponding to the following arrays respectively - // - qo_indptr_in_depth - // - page_indptr_in_depth - // - page_indices_in_depth - // - length_info_in_depth - // - k_rope_pos_offset_in_depth - cache_size += CeilDivElemAlignment(reserved_num_seqs + 1); - cache_size += CeilDivElemAlignment(reserved_num_seqs + 1); - cache_size += CeilDivElemAlignment(num_total_pages); - cache_size += CeilDivElemAlignment(3 * reserved_num_seqs); - cache_size += CeilDivElemAlignment(reserved_num_seqs); - cache_size *= kPagedKVCacheMaxBlockDepth; - - // - Array size of other arrays. 
- // Corresponding to the following arrays respectively - // - cur_append_length_indptr - // - k_ragged_rope_pos_offset - // - q_rope_position_map - // - append_position_map - // - kv_transfer_remote_position_map - // - kv_transfer_recver_id - // - kv_transfer_page_to_page_local_position_map - // - kv_transfer_page_to_page_remote_position_map - // - kv_transfer_page_to_page_recver_id - // - tree_attn_mask - // - tree_attn_mn_indptr - cache_size += CeilDivElemAlignment(reserved_num_seqs + 1); - cache_size += CeilDivElemAlignment(reserved_num_seqs); - cache_size += CeilDivElemAlignment(prefill_chunk_size); - cache_size += CeilDivElemAlignment(prefill_chunk_size); - cache_size += CeilDivElemAlignment(prefill_chunk_size); - cache_size += CeilDivElemAlignment(prefill_chunk_size); - cache_size += CeilDivElemAlignment(prefill_chunk_size); - cache_size += CeilDivElemAlignment(prefill_chunk_size); - cache_size += CeilDivElemAlignment(prefill_chunk_size); - cache_size += - CeilDivElemAlignment(kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs); - cache_size += CeilDivElemAlignment(reserved_num_seqs + 1); - - return cache_size; - } - - int64_t CalculateCompactKVAuxDataCacheSize(int64_t reserved_num_seqs, - int64_t prefill_chunk_size) { - int64_t cache_size = 0; - // Corresponding to the following arrays respectively - // - commit_copy_length_indptr - // - commit_copy_src_dst_pos_in_page_table - cache_size += CeilDivElemAlignment(reserved_num_seqs + 1); - cache_size += CeilDivElemAlignment( - 2 * std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size)); - - return cache_size; - } - - /*! - * \brief Copy the input data to the cache at the given offset. - * And return the NDArray view of the cache starting at the offset. - */ - NDArray CopyAttnAuxVecToCache(HostMemoryVector* data) { - int64_t n_elem = data->size(); - std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_, data->data(), - n_elem * elem_byte_size_); - NDArray view = merged_attn_aux_data_device_.CreateView( - {n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_); - attn_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem); - return view; - } - - NDArray CopyCompactKVAuxVecToCache(HostMemoryVector* data) { - int64_t n_elem = data->size(); - std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_, - data->data(), n_elem * elem_byte_size_); - NDArray view = merged_compact_kv_aux_data_device_.CreateView( - {n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * elem_byte_size_); - compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem); - return view; - } - - /*! \brief For safety, we align the start offset of the arrays to `offset_alignment`. */ - int64_t CeilDivElemAlignment(int n) { - return (n + offset_alignment_ - 1) / offset_alignment_ * offset_alignment_; - } - - const int64_t cuda_byte_alignment_ = 16; - const int64_t elem_byte_size_; - const int64_t offset_alignment_; - - int64_t attn_aux_data_copy_offset_ = 0; - int64_t compact_kv_aux_data_copy_offset_ = 0; - HostMemoryVector merged_attn_aux_data_host_; - HostMemoryVector merged_compact_kv_aux_data_host_; - NDArray merged_attn_aux_data_device_; - NDArray merged_compact_kv_aux_data_device_; -}; - /*! * \brief The paged KV cache for attention. * - It supports managing the K/V data of **multiple sequences**. @@ -962,6 +82,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { const int64_t num_layers_; /*! \brief The beginning layer id offset. 
*/ const int64_t layer_id_begin_offset_; + /*! \brief The ending layer id offset. */ + const int64_t layer_id_end_offset_; /*! \brief The number of query/output heads in the model. */ const int64_t num_qo_heads_; /*! \brief The number of key/value heads in the model. */ @@ -973,11 +95,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { * For layers that use multi-head attention, this field is overriden by qk_head_dim. */ const int64_t v_head_dim_; - /*! - * \brief The number of features each head has for RoPE in multi-head latent attention. - * This field is ignored for non-MLA. - */ - const int64_t qk_rope_head_dim_; /*! \brief The number of total pages allocated in KV cache. */ const int64_t num_total_pages_; /*! \brief The maximum total sequence length in a prefill. */ @@ -996,6 +113,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /*! \brief The optional RoPE extension factors for RoPE scaling. */ const Optional rope_ext_factors_; + /*! \brief The KV cache dtype. */ + const DataType kv_dtype_; /*! \brief We fix int32 to be the index dtype of auxiliary data. */ const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1)); @@ -1066,9 +185,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray temp_attn_k_device_; NDArray temp_attn_v_device_; NDArray temp_attn_output_device_; - NDArray temp_attn_scores_device_; - NDArray merged_attn_scores_device_; + NDArray temp_attn_lse_device_; + NDArray merged_attn_lse_device_; std::vector temp_int_attn_workspace_; + std::vector temp_int_pinned_attn_workspace_; NDArray temp_float_attn_workspace_; //------------------------------------------- @@ -1114,8 +234,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray kv_transfer_page_to_page_remote_position_map_view_; NDArray kv_transfer_page_to_page_recver_id_view_; NDArray temp_attn_output_view_; - NDArray temp_attn_scores_view_; - NDArray merged_attn_scores_view_; + NDArray temp_attn_lse_view_; + NDArray merged_attn_lse_view_; std::vector qo_indptr_on_depths_view_; std::vector page_indptr_on_depths_view_; std::vector page_indices_on_depths_view_; @@ -1124,29 +244,20 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector tree_attn_mask_view_; std::vector tree_attn_mn_indptr_view_; - PackedFunc f_transpose_append_; - PackedFunc f_transpose_append_mla_; + Optional f_transpose_append_mha_; + Optional f_transpose_append_mla_; Optional f_transfer_kv_; Optional f_transfer_kv_page_to_page_ = NullOpt; PackedFunc f_compact_copy_; - PackedFunc f_attention_prefill_; - PackedFunc f_attention_decode_; - PackedFunc f_attention_prefill_sliding_window_; - PackedFunc f_attention_decode_sliding_window_; - PackedFunc f_attention_prefill_ragged_; - PackedFunc f_attention_prefill_with_tree_mask_; - PackedFunc f_attention_prefill_with_tree_mask_paged_kv_; - Optional f_attention_prefill_ragged_begin_forward_; - Optional f_attention_prefill_ragged_end_forward_; - Optional f_attention_prefill_begin_forward_; - Optional f_attention_prefill_end_forward_; - Optional f_attention_decode_begin_forward_; - Optional f_attention_decode_end_forward_; - PackedFunc f_mla_prefill_; - PackedFunc f_mla_decode_; - PackedFunc f_mla_prefill_ragged_normal_; - PackedFunc f_mla_prefill_ragged_absorbed_; - PackedFunc f_merge_inplace_; + std::unique_ptr f_attention_prefill_ragged_; + std::unique_ptr f_attention_prefill_; + std::unique_ptr f_attention_decode_; + std::unique_ptr f_attention_prefill_sliding_window_; + std::unique_ptr 
f_attention_decode_sliding_window_; + std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv_; + std::unique_ptr f_attention_prefill_with_tree_mask_; + std::unique_ptr f_mla_prefill_; + Array f_merge_inplace_; PackedFunc f_split_rotary_; PackedFunc f_copy_single_page_; Optional f_debug_get_kv_; @@ -1163,34 +274,30 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { public: /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */ explicit PagedAttentionKVCacheObj( - int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset, // - int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, - int64_t qk_rope_head_dim, std::vector attn_kinds, int64_t reserved_num_seqs, + int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset, + int64_t layer_id_end_offset, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, + int64_t v_head_dim, std::vector attn_kinds, int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta, Optional rope_ext_factors, bool enable_kv_transfer, DLDataType dtype, Device device, - PackedFunc f_transpose_append, PackedFunc f_transpose_append_mla, PackedFunc f_compact_copy, - PackedFunc f_attention_prefill, PackedFunc f_attention_decode, - PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window, - PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask, - PackedFunc f_attention_prefill_with_tree_mask_paged_kv, - Optional f_attention_prefill_ragged_begin_forward, - Optional f_attention_prefill_ragged_end_forward, - Optional f_attention_prefill_begin_forward, - Optional f_attention_prefill_end_forward, - Optional f_attention_decode_begin_forward, - Optional f_attention_decode_end_forward, PackedFunc f_mla_prefill, - PackedFunc f_mla_decode, PackedFunc f_mla_prefill_ragged_normal, - PackedFunc f_mla_prefill_ragged_absorbed, PackedFunc f_merge_inplace, - PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional f_debug_get_kv) + Optional f_transpose_append_mha, Optional f_transpose_append_mla, + PackedFunc f_compact_copy, std::unique_ptr f_attention_prefill_ragged, + std::unique_ptr f_attention_prefill, + std::unique_ptr f_attention_decode, + std::unique_ptr f_attention_prefill_sliding_window, + std::unique_ptr f_attention_decode_sliding_window, + std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv, + std::unique_ptr f_attention_prefill_with_tree_mask, + std::unique_ptr f_mla_prefill, Array f_merge_inplace, + PackedFunc f_split_rotary, PackedFunc f_copy_single_page, PackedFunc f_debug_get_kv) : page_size_(page_size), num_layers_(num_layers), layer_id_begin_offset_(layer_id_begin_offset), + layer_id_end_offset_(layer_id_end_offset), num_qo_heads_(num_qo_heads), num_kv_heads_(num_kv_heads), qk_head_dim_(qk_head_dim), v_head_dim_(v_head_dim), - qk_rope_head_dim_(qk_rope_head_dim), num_total_pages_(num_total_pages), prefill_chunk_size_(prefill_chunk_size), support_sliding_window_(support_sliding_window), @@ -1200,28 +307,19 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { rotary_scale_(rotary_scale), rotary_theta_(rotary_theta), rope_ext_factors_(std::move(rope_ext_factors)), - f_transpose_append_(std::move(f_transpose_append)), + kv_dtype_(DataType(dtype)), + f_transpose_append_mha_(std::move(f_transpose_append_mha)), 
f_transpose_append_mla_(std::move(f_transpose_append_mla)), f_compact_copy_(std::move(f_compact_copy)), + f_attention_prefill_ragged_(std::move(f_attention_prefill_ragged)), f_attention_prefill_(std::move(f_attention_prefill)), f_attention_decode_(std::move(f_attention_decode)), f_attention_prefill_sliding_window_(std::move(f_attention_prefill_sliding_window)), f_attention_decode_sliding_window_(std::move(f_attention_decode_sliding_window)), - f_attention_prefill_ragged_(std::move(f_attention_prefill_ragged)), - f_attention_prefill_with_tree_mask_(std::move(f_attention_prefill_with_tree_mask)), f_attention_prefill_with_tree_mask_paged_kv_( std::move(f_attention_prefill_with_tree_mask_paged_kv)), - f_attention_prefill_ragged_begin_forward_( - std::move(f_attention_prefill_ragged_begin_forward)), - f_attention_prefill_ragged_end_forward_(std::move(f_attention_prefill_ragged_end_forward)), - f_attention_prefill_begin_forward_(std::move(f_attention_prefill_begin_forward)), - f_attention_prefill_end_forward_(std::move(f_attention_prefill_end_forward)), - f_attention_decode_begin_forward_(std::move(f_attention_decode_begin_forward)), - f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)), + f_attention_prefill_with_tree_mask_(std::move(f_attention_prefill_with_tree_mask)), f_mla_prefill_(std::move(f_mla_prefill)), - f_mla_decode_(std::move(f_mla_decode)), - f_mla_prefill_ragged_normal_(std::move(f_mla_prefill_ragged_normal)), - f_mla_prefill_ragged_absorbed_(std::move(f_mla_prefill_ragged_absorbed)), f_merge_inplace_(std::move(f_merge_inplace)), f_split_rotary_(std::move(f_split_rotary)), f_copy_single_page_(std::move(f_copy_single_page)), @@ -1262,9 +360,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_transfer_kv_page_to_page_ = *f_transfer_kv_page_to_page_ptr; } else { for (int i = 0; i < num_layers; ++i) { - ShapeTuple kv_cache_shape = GetKVCacheShape( - attn_kinds_[layer_id_begin_offset_ + i], num_total_pages, reserved_num_seqs, - num_kv_heads, page_size, qk_head_dim, v_head_dim, qk_rope_head_dim); + ShapeTuple kv_cache_shape = + GetKVCacheShape(attn_kinds_[layer_id_begin_offset_ + i], num_total_pages, + reserved_num_seqs, num_kv_heads, page_size, qk_head_dim, v_head_dim); pages_.push_back(NDArray::Empty(kv_cache_shape, dtype, device)); } } @@ -1321,7 +419,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { if (NeedKernelBeginForward()) { temp_int_attn_workspace_.push_back( - NDArray::Empty({kIntAttnWorkspaceByte / 4}, DataType::Float(32), device)); + NDArray::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device)); + temp_int_pinned_attn_workspace_.push_back(NDArray::Empty( + {kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device))); } qo_indptr_on_depths_view_.push_back(NDArray()); page_indptr_on_depths_view_.push_back(NDArray()); @@ -1335,9 +435,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Additional workspace for the "prefill with ragged kv" kernel. 
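// The integer attention workspaces allocated below are byte-typed (uint8) and
// sized directly in bytes (kIntAttnWorkspaceByte), replacing the earlier
// float32 buffers of kIntAttnWorkspaceByte / 4 elements. A pinned host mirror
// (temp_int_pinned_attn_workspace_, allocated on GetPreferredHostDevice) is
// kept alongside each device buffer because the FlashInfer BeginForward path
// below takes a host-side integer workspace in addition to the device one
// (see MHAKernelBeginForward / MLAKernelBeginForward).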
if (NeedKernelBeginForward()) { temp_int_attn_workspace_.push_back( - NDArray::Empty({kIntAttnWorkspaceByte / 4}, DataType::Float(32), device)); + NDArray::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device)); + temp_int_pinned_attn_workspace_.push_back(NDArray::Empty( + {kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device))); temp_float_attn_workspace_ = - NDArray::Empty({kFloatAttnWorkspaceByte / 4}, DataType::Float(32), device); + NDArray::Empty({kFloatAttnWorkspaceByte}, DataType::UInt(8), device); } if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHA) != attn_kinds_.end()) { @@ -1350,9 +452,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } temp_attn_output_device_ = NDArray::Empty({prefill_chunk_size_, num_qo_heads, v_head_dim}, dtype, device); - temp_attn_scores_device_ = + temp_attn_lse_device_ = NDArray::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device); - merged_attn_scores_device_ = + merged_attn_lse_device_ = NDArray::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device); for (int64_t page_id = num_total_pages - 1; page_id >= 0; --page_id) { free_page_ids_.push_back(page_id); @@ -1757,7 +859,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - auto [block_ids_on_depths, trailing_blocks] = GetBlockIdsOnDepth(sequences); + auto [block_ids_on_depths, trailing_blocks] = + GetBlockIdsOnDepth(sequences, global_block_pool_, cur_batch_size_); num_depths_ = std::min(static_cast(block_ids_on_depths.size()), kPagedKVCacheMaxBlockDepth); ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth); @@ -1769,7 +872,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // We force the blocks at maximum depth not to coalesce, so that it can be concatenated with // trailing exceeding blocks. auto [chunked_block_ids, use_decode_kernel] = GetChunkedBlockIds( - block_ids_on_depths[d], /*enable_coalesce=*/d != kPagedKVCacheMaxBlockDepth - 1); + block_ids_on_depths[d], /*enable_coalesce=*/d != kPagedKVCacheMaxBlockDepth - 1, + cur_append_lengths_, global_block_pool_, is_decode_request_); chunked_block_ids_arr.push_back(chunked_block_ids); use_decode_kernel_.push_back(use_decode_kernel); } @@ -1784,6 +888,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) { // When GQA group size is at least 4 and FlashInfer is enabled, // we always use prefill kernel for better performance. + // Note: For MLA, we always use prefill kernel, so values in `use_decode_kernel` will + // be ignored for MLA. 
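// use_decode_kernel_ has been computed per depth by GetChunkedBlockIds above.
// The override below only triggers when a FlashInfer backend is present
// (NeedKernelBeginForward) and the GQA group size num_qo_heads_ / num_kv_heads_
// is at least 4; for example, 32 query heads over 4 KV heads gives a group
// size of 8, so every depth takes the prefill-kernel path.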
std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(), /*value=*/false); } @@ -1991,15 +1097,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (kv_transfer_stream_ != nullptr) { DeviceAPI::Get(device_)->SyncStreamFromTo(device_, kv_transfer_stream_, compute_stream_); } - if (!f_attention_prefill_end_forward_.defined() || !f_attention_decode_end_forward_.defined() || - !f_attention_prefill_ragged_end_forward_.defined()) { - return; - } - f_attention_prefill_ragged_end_forward_.value()(); - for (int d = 0; d < num_depths_; ++d) { - f_attention_prefill_end_forward_.value()(d); - f_attention_decode_end_forward_.value()(d); - } } IntTuple DisaggPrepareRecv(int64_t seq_id, int append_length) final { @@ -2085,7 +1182,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional mask, - NDArray o_data, double attn_score_scaling_factor) final { + NDArray o_data, double sm_scale) final { // Part 1. Shape and dtype check. int64_t local_layer_id = layer_id - layer_id_begin_offset_; CHECK_GE(local_layer_id, 0); @@ -2149,8 +1246,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } // Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set. + CHECK(f_transpose_append_mha_.defined()); if (append_before_attn_) { - f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_); + f_transpose_append_mha_.value()(pages_[local_layer_id], k_data, v_data, + append_position_map_view_); } // Part 4: KV transfer if (page_to_page_transfer_kv_) { @@ -2174,153 +1273,132 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { kv_transfer_stream_); } // Part 5: perform attention - AttentionInternal(layer_id, q_data, k_data, v_data, o_data_view, attn_score_scaling_factor); + AttentionInternal(layer_id, q_data, k_data, v_data, o_data_view, sm_scale); // Part 6. Append k/v data to kv-cache if flag "append_before_attn" is not set. if (!append_before_attn_) { - f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_); + f_transpose_append_mha_.value()(pages_[local_layer_id], k_data, v_data, + append_position_map_view_); } } - void MLAAbsorbed(int64_t layer_id, NDArray q_data, NDArray compressed_kv_data, NDArray k_pe_data, - NDArray o_data, double attn_score_scaling_factor) { - // Part 1. Shape and dtype check. + void SelfAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, + NDArray o_data, NDArray lse_data, double sm_scale) final { + // Shape and dtype check. 
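// SelfAttention covers only the newly appended tokens of the current batch:
// q_data attends to the freshly provided k_data/v_data in ragged layout, and
// the per-query log-sum-exp is written to lse_data so that the result can
// later be combined with the cached-KV part via MergeAttnOutputInplace.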
int64_t local_layer_id = layer_id - layer_id_begin_offset_; CHECK_GE(local_layer_id, 0); CHECK_LT(local_layer_id, num_layers_); NDArray pages = pages_[local_layer_id]; CHECK(q_data.DataType() == pages.DataType()); - CHECK(compressed_kv_data.DataType() == pages.DataType()); - CHECK(k_pe_data.DataType() == pages.DataType()); + CHECK(k_data.DataType() == pages.DataType()); + CHECK(v_data.DataType() == pages.DataType()); CHECK(o_data.DataType() == pages.DataType()); - CHECK(attn_kinds_[layer_id] == AttnKind::kMLA); + AttnKind attn_kind = attn_kinds_[layer_id]; // q_data: (num_total_length, num_qo_heads, qk_head_dim) - // compressed_kv_data: (num_total_length, qk_head_dim - qk_rope_head_dim) - // k_pe_data: (num_total_length, qk_rope_head_dim) + // k_data: (num_total_length, num_kv_heads, qk_head_dim) + // v_data: (num_total_length, num_kv_heads, v_head_dim) // o_data: (num_total_length, num_qo_heads, v_head_dim) - CHECK_EQ(q_data->ndim, 3); - CHECK_EQ(compressed_kv_data->ndim, 2); - CHECK_EQ(k_pe_data->ndim, 2); - CHECK_EQ(o_data->ndim, 3); int64_t total_seq_length = 0; for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) { total_seq_length += cur_append_lengths_[seq_id]; } - CHECK_LE(q_data->shape[0], total_seq_length); - CHECK_LE(compressed_kv_data->shape[0], total_seq_length); - CHECK_LE(k_pe_data->shape[0], total_seq_length); - CHECK_LE(o_data->shape[0], total_seq_length); - CHECK_EQ(q_data->shape[1], num_qo_heads_); - CHECK_EQ(o_data->shape[1], num_qo_heads_); - CHECK_EQ(q_data->shape[2], qk_head_dim_); - CHECK_EQ(compressed_kv_data->shape[1], qk_head_dim_ - qk_rope_head_dim_); - CHECK_EQ(k_pe_data->shape[1], qk_rope_head_dim_); - CHECK_EQ(o_data->shape[2], v_head_dim_); + CHECK_EQ(q_data->ndim, 3); + CHECK_EQ(k_data->ndim, 3); + CHECK_EQ(v_data->ndim, 3); + CHECK_EQ(o_data->ndim, 3); + CHECK_EQ(q_data->shape[0], total_seq_length); + CHECK_EQ(k_data->shape[0], total_seq_length); + CHECK_EQ(v_data->shape[0], total_seq_length); + CHECK_EQ(o_data->shape[0], total_seq_length); // Sync the copy stream and the compute stream. ComputeStreamWaitForCopyStream(); // The auxiliary data structure on device must have been synchronized. ICHECK(!dirty_aux_data_device_); - // Append k/v data to kv-cache if flag "append_before_attn" is set. - if (append_before_attn_) { - f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data, - append_position_map_view_); - } - // Perform MLA with weight absorption. - MLAAbsorbedInternal(layer_id, q_data, compressed_kv_data, k_pe_data, o_data, - attn_score_scaling_factor); - // Append k/v data to kv-cache if flag "append_before_attn" is not set. - if (!append_before_attn_) { - f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data, - append_position_map_view_); + if (attn_kind == AttnKind::kMHA) { + MHASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale); + } else { + MLASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale); } } - void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, - NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data, - double attn_score_scaling_factor) { - // Part 1: Basic Checks and Setup. + void CrossAttention(int64_t layer_id, NDArray q_data, NDArray o_data, NDArray lse_data, + double sm_scale) final { + // Shape and dtype check. 
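// CrossAttention covers the K/V already resident in the paged cache: for each
// effective depth the paged prefill/decode kernels run, and the partial
// results are accumulated into (o_data, lse_data). Together with
// SelfAttention above, this forms the two-phase attention whose outputs are
// merged through MergeAttnOutputInplace.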
int64_t local_layer_id = layer_id - layer_id_begin_offset_; CHECK_GE(local_layer_id, 0); CHECK_LT(local_layer_id, num_layers_); NDArray pages = pages_[local_layer_id]; CHECK(q_data.DataType() == pages.DataType()); - CHECK(k_data.DataType() == pages.DataType()); - CHECK(v_data.DataType() == pages.DataType()); - CHECK(compressed_kv_data.DataType() == pages.DataType()); - CHECK(k_pe_data.DataType() == pages.DataType()); CHECK(o_data.DataType() == pages.DataType()); - CHECK(attn_kinds_[layer_id] == AttnKind::kMLA); + AttnKind attn_kind = attn_kinds_[layer_id]; - // Expected shapes: - // q_data: (num_total_length, num_qo_heads, qk_head_dim) - // k_data: (num_total_length, num_qo_heads, qk_head_dim) - // v_data: (num_total_length, num_qo_heads, v_head_dim) - // compressed_kv_data: (num_total_length, qk_head_dim - qk_rope_head_dim) - // k_pe_data: (num_total_length, qk_rope_head_dim) - // o_data: (num_total_length, num_qo_heads, v_head_dim) - CHECK_EQ(q_data->ndim, 3); - CHECK_EQ(k_data->ndim, 3); - CHECK_EQ(v_data->ndim, 3); - CHECK_EQ(compressed_kv_data->ndim, 2); - CHECK_EQ(k_pe_data->ndim, 2); - CHECK_EQ(o_data->ndim, 3); + // q_data: (num_total_length, num_qo_heads, qk_head_dim) + // o_data: (num_total_length, num_qo_heads, v_head_dim) int64_t total_seq_length = 0; - for (int64_t i = 0; i < cur_batch_size_; ++i) { - total_seq_length += cur_append_lengths_[i]; - } - CHECK_LE(q_data->shape[0], total_seq_length); - CHECK_LE(k_data->shape[0], total_seq_length); - CHECK_LE(v_data->shape[0], total_seq_length); - CHECK_LE(compressed_kv_data->shape[0], total_seq_length); - CHECK_LE(k_pe_data->shape[0], total_seq_length); - CHECK_EQ(k_pe_data->shape[1], qk_rope_head_dim_); - CHECK_LE(o_data->shape[0], total_seq_length); + for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) { + total_seq_length += cur_append_lengths_[seq_id]; + } + CHECK_EQ(q_data->ndim, 3); + CHECK_EQ(o_data->ndim, 3); + CHECK_EQ(q_data->shape[0], total_seq_length); + CHECK_EQ(o_data->shape[0], total_seq_length); CHECK_EQ(q_data->shape[1], num_qo_heads_); CHECK_EQ(o_data->shape[1], num_qo_heads_); - CHECK_EQ(k_data->shape[1], num_qo_heads_); - CHECK_EQ(v_data->shape[1], num_qo_heads_); CHECK_EQ(q_data->shape[2], qk_head_dim_); - CHECK_EQ(k_data->shape[2], qk_head_dim_); - CHECK_EQ(v_data->shape[2], v_head_dim_); CHECK_EQ(o_data->shape[2], v_head_dim_); - // Part 2: Synchronize streams and update auxiliary data. + // Sync the copy stream and the compute stream. ComputeStreamWaitForCopyStream(); + // The auxiliary data structure on device must have been synchronized. ICHECK(!dirty_aux_data_device_); - // Append k/v data to kv-cache if flag "append_before_attn" is set. - if (append_before_attn_) { - f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data, - append_position_map_view_); - } - - // Part 4: Call the ragged kernel. - // Here, we use f_mla_prefill_ragged_normal_, which is designed to work for both decode - // and normal prefill cases. Optionally, you could check a flag like `use_decode_kernel_[0]` - // to adjust parameters; here we assume the kernel internally supports both cases. 
- f_mla_prefill_ragged_normal_(q_data, cur_append_length_indptr_view_, k_data, v_data, - cur_append_length_indptr_view_, q_rope_position_map_view_, - k_ragged_rope_pos_offset_view_, - o_data, // output tensor - merged_attn_scores_view_, - /*causal=*/1, static_cast(RoPEMode::kNone), - 0, // Rope param, not important - 0, // Rope param, not important - attn_score_scaling_factor); - - // Part 5: If appending is to occur after attention, call the append kernel. - if (!append_before_attn_) { - f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data, - append_position_map_view_); + if (attn_kind == AttnKind::kMHA) { + MHACrossAttnInternal(local_layer_id, q_data, o_data, lse_data, sm_scale, + /*is_first_kernel=*/true); + } else { + MLACrossAttnInternal(local_layer_id, q_data, o_data, lse_data, sm_scale); } } + void AppendMLAKV(int64_t layer_id, NDArray kv_data) final { + // Shape and dtype check. + int64_t local_layer_id = layer_id - layer_id_begin_offset_; + CHECK_GE(local_layer_id, 0); + CHECK_LT(local_layer_id, num_layers_); + NDArray pages = pages_[local_layer_id]; + CHECK(kv_data.DataType() == pages.DataType()); + CHECK(attn_kinds_[layer_id] == AttnKind::kMLA); + + // kv_data: (num_total_length, qk_head_dim) + CHECK_EQ(kv_data->ndim, 2); + int64_t total_seq_length = 0; + for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) { + total_seq_length += cur_append_lengths_[seq_id]; + } + CHECK_LE(kv_data->shape[0], total_seq_length); + CHECK_EQ(kv_data->shape[1], qk_head_dim_); + // Sync the copy stream and the compute stream. + ComputeStreamWaitForCopyStream(); + // The auxiliary data structure on device must have been synchronized. + ICHECK(!dirty_aux_data_device_); + + CHECK(f_transpose_append_mla_.defined()); + f_transpose_append_mla_.value()(pages_[local_layer_id], kv_data, append_position_map_view_); + } + + Array MergeAttnOutputInplace(NDArray o_self_attn, NDArray lse_self_attn, + NDArray o_cross_attn, NDArray lse_cross_attn) final { + CHECK_GE(f_merge_inplace_.size(), 2) << "The general attention merge function is not defined."; + f_merge_inplace_[1](o_self_attn, lse_self_attn, o_cross_attn, lse_cross_attn); + return {o_self_attn, lse_self_attn}; + } + void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, - double attn_score_scaling_factor) { + double sm_scale) { // Todo(ruihang): implement it } @@ -2795,125 +1873,95 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { dirty_aux_data_device_ = true; } - /*! - * \brief For the given list of sequences, check the block trace of - * each sequence, and return the blocks ids used by the sequences - * on each depth. And if the depth is larger than the kPagedKVCacheMaxBlockDepth, - * the exceeding blocks will concatenate and output separately. - * More precisely, the inner returned vector contains the block ids - * used by the sequences on a certain depth (or "-1" if a sequence - * has fewer depth). The outer returned vector contains the inner - * vectors from the lowest depth to the highest depth. - */ - std::pair>, std::vector>> - GetBlockIdsOnDepth(const std::vector& sequences) const { - // - Get the trace of each sequence. 
- int64_t num_depths = 0; - std::vector> seq_block_traces; - std::vector> trailing_block_traces; - seq_block_traces.reserve(cur_batch_size_); - trailing_block_traces.reserve(cur_batch_size_); - for (int i = 0; i < cur_batch_size_; ++i) { - std::vector trace = sequences[i]->GetBlockTrace(global_block_pool_); - if (static_cast(trace.size()) <= kPagedKVCacheMaxBlockDepth) { - seq_block_traces.push_back(std::vector(trace.begin(), trace.end())); - trailing_block_traces.push_back({}); - num_depths = std::max(num_depths, static_cast(trace.size())); - } else { - seq_block_traces.push_back( - std::vector(trace.begin(), trace.begin() + kPagedKVCacheMaxBlockDepth)); - trailing_block_traces.push_back( - std::vector(trace.begin() + kPagedKVCacheMaxBlockDepth, trace.end())); - num_depths = std::max(num_depths, static_cast(kPagedKVCacheMaxBlockDepth)); + /*! \brief Check whether BeginForward for kernels is needed. */ + bool NeedKernelBeginForward() { + std::vector funcs = {f_attention_prefill_.get(), + f_attention_prefill_ragged_.get(), + f_attention_decode_.get(), + f_attention_prefill_sliding_window_.get(), + f_attention_decode_sliding_window_.get(), + f_attention_prefill_with_tree_mask_.get(), + f_attention_prefill_with_tree_mask_paged_kv_.get(), + f_mla_prefill_.get()}; + for (AttnBackendFunc* func : funcs) { + if (func != nullptr && func->backend_kind == AttnBackendKind::kFlashInfer) { + return true; } } + return false; + } - // "Transpose" the traces, yielding the block ids used on each depth. - std::vector> block_ids_on_depths; - block_ids_on_depths.reserve(num_depths); - for (int d = 0; d < num_depths; ++d) { - std::vector block_ids; - block_ids.reserve(cur_batch_size_); - for (int i = 0; i < cur_batch_size_; ++i) { - block_ids.push_back( - d < static_cast(seq_block_traces[i].size()) ? seq_block_traces[i][d] : -1); - } - block_ids_on_depths.push_back(std::move(block_ids)); + /*! \brief Invoke the "begin forward" functions of underlying kernels. */ + void KernelBeginForward() { + if (!NeedKernelBeginForward()) { + return; + } + + auto it_layer_begin = attn_kinds_.begin() + layer_id_begin_offset_; + auto it_layer_end = attn_kinds_.begin() + layer_id_end_offset_; + if (std::find(it_layer_begin, it_layer_end, AttnKind::kMHA) != it_layer_end) { + MHAKernelBeginForward(); + } + if (std::find(it_layer_begin, it_layer_end, AttnKind::kMLA) != it_layer_end) { + MLAKernelBeginForward(); } - return {block_ids_on_depths, trailing_block_traces}; } - /*! - * \brief This function considers an optimization which coalesces - * adjacent decode attention computations into a single prefill - * attention computation if the adjacent decodes attend to the same - * k/v values under certain conditions. - * If it decides to coalesce on a certain depth, we need to know - * the prefill length after coalescing. This function returns - * - a vector of block ids together with the prefill/decode lengths - * that attend to the blocks. - * - a boolean indicating whether to use decode kernel on for the - * input blocks. - */ - std::pair>, bool> GetChunkedBlockIds( - const std::vector& block_ids, bool enable_coalesce = true) const { - std::vector> uncoalesced_block_ids; - std::vector> coalesced_block_ids; - - // Gather the number of pages before/after coalescing respectively. - int cur_block_id = block_ids[0]; - int chunk_append_length = cur_append_lengths_[0]; - int page_counter_coalesced = 0; - int page_counter_uncoalesced = - block_ids[0] != -1 ? 
global_block_pool_[block_ids[0]].page_ids.size() : 0; - for (int i = 1; i < static_cast(block_ids.size()); ++i) { - if (block_ids[i] != -1) { - page_counter_uncoalesced += global_block_pool_[block_ids[i]].page_ids.size(); + /*! \brief KernelBeginForward for multi-head attention. */ + void MHAKernelBeginForward() { + if (!append_before_attn_) { + if (is_chain_on_depths_[0] && f_attention_prefill_ragged_ != nullptr && + f_attention_prefill_ragged_->backend_kind == AttnBackendKind::kFlashInfer) { + f_attention_prefill_ragged_->BeginForward( + temp_float_attn_workspace_, temp_int_attn_workspace_[0], + temp_int_pinned_attn_workspace_[0], &cur_append_lengths_indptr_host_, + &cur_append_lengths_indptr_host_, cur_batch_size_, + cur_append_lengths_indptr_host_.back(), num_qo_heads_, num_kv_heads_, qk_head_dim_, + v_head_dim_, /*causal=*/true, copy_stream_); + } + } + for (int d = 0; d < num_depths_; ++d) { + if (page_indices_on_depths_view_[d]->shape[0] == 0) { + continue; } - uncoalesced_block_ids.emplace_back(block_ids[i - 1], cur_append_lengths_[i - 1]); - if (block_ids[i] == cur_block_id) { - chunk_append_length += cur_append_lengths_[i]; + CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window."; + if (use_decode_kernel_[d]) { + if (f_attention_decode_ != nullptr && + f_attention_decode_->backend_kind == AttnBackendKind::kFlashInfer) { + f_attention_decode_->BeginForward( + d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1], + temp_int_pinned_attn_workspace_[d + 1], &page_indptr_on_depths_host_[d], + cur_batch_size_, page_size_, num_qo_heads_, num_kv_heads_, qk_head_dim_, v_head_dim_, + rope_mode_, kv_dtype_, kv_dtype_, copy_stream_); + } } else { - coalesced_block_ids.emplace_back(cur_block_id, chunk_append_length); - if (cur_block_id != -1) { - page_counter_coalesced += global_block_pool_[cur_block_id].page_ids.size(); + if (f_attention_prefill_ != nullptr && + f_attention_prefill_->backend_kind == AttnBackendKind::kFlashInfer) { + f_attention_prefill_->BeginForward( + d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1], + temp_int_pinned_attn_workspace_[d + 1], &qo_indptr_on_depths_host_[d], + &page_indptr_on_depths_host_[d], &last_page_len_on_depths_host_[d], + static_cast(qo_indptr_on_depths_host_[d].size()) - 1, + cur_append_lengths_indptr_host_.back(), page_size_, num_qo_heads_, num_kv_heads_, + qk_head_dim_, v_head_dim_, /*causal=*/false, copy_stream_); } - cur_block_id = block_ids[i]; - chunk_append_length = cur_append_lengths_[i]; } } - uncoalesced_block_ids.emplace_back(block_ids.back(), cur_append_lengths_.back()); - coalesced_block_ids.emplace_back(cur_block_id, chunk_append_length); - if (cur_block_id != -1) { - page_counter_coalesced += global_block_pool_[cur_block_id].page_ids.size(); - } - double coalesce_ratio = 1.0 * page_counter_uncoalesced / page_counter_coalesced; - // Do not coalesce and use batch decode kernel when coalesce ratio is small. - bool use_decode_kernel = is_decode_request_ && coalesce_ratio < 32; - return {use_decode_kernel || !enable_coalesce ? uncoalesced_block_ids : coalesced_block_ids, - use_decode_kernel}; - } - - /*! \brief Check whether BeginForward for kernels is needed. */ - bool NeedKernelBeginForward() { - return f_attention_prefill_begin_forward_.defined() && - f_attention_decode_begin_forward_.defined() && - f_attention_prefill_ragged_begin_forward_.defined(); } - /*! \brief Invoke the "begin forward" functions of underlying kernels. 
*/ - void KernelBeginForward() { - if (!NeedKernelBeginForward()) { - return; - } - + /*! \brief KernelBeginForward for multi-head latent attention. */ + void MLAKernelBeginForward() { if (!append_before_attn_) { if (is_chain_on_depths_[0]) { - f_attention_prefill_ragged_begin_forward_.value()( - temp_float_attn_workspace_, temp_int_attn_workspace_[0], - cur_append_lengths_indptr_host_.as_ndarray(), - cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_, - num_kv_heads_, qk_head_dim_, copy_stream_); + if (f_attention_prefill_ragged_ != nullptr && + f_attention_prefill_ragged_->backend_kind == AttnBackendKind::kFlashInfer) { + f_attention_prefill_ragged_->BeginForward( + temp_float_attn_workspace_, temp_int_attn_workspace_[0], + temp_int_pinned_attn_workspace_[0], &cur_append_lengths_indptr_host_, + &cur_append_lengths_indptr_host_, cur_batch_size_, + cur_append_lengths_indptr_host_.back(), num_qo_heads_, num_kv_heads_, qk_head_dim_, + v_head_dim_, /*causal=*/true, copy_stream_); + } } } for (int d = 0; d < num_depths_; ++d) { @@ -2921,19 +1969,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { continue; } CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window."; - if (use_decode_kernel_[d]) { - f_attention_decode_begin_forward_.value()( + if (f_mla_prefill_ != nullptr && + f_mla_prefill_->backend_kind == AttnBackendKind::kFlashInfer) { + f_mla_prefill_->BeginForward( d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1], - page_indptr_on_depths_host_[d].as_ndarray(), - last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, - qk_head_dim_, page_size_, - /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_); - } else { - f_attention_prefill_begin_forward_.value()( - /*depth=*/d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1], - qo_indptr_on_depths_host_[d].as_ndarray(), page_indptr_on_depths_host_[d].as_ndarray(), - static_cast(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_, - num_kv_heads_, qk_head_dim_, page_size_, copy_stream_); + temp_int_pinned_attn_workspace_[d + 1], &qo_indptr_on_depths_host_[d], + &page_indptr_on_depths_host_[d], &last_page_len_on_depths_host_[d], + static_cast(qo_indptr_on_depths_host_[d].size()) - 1, + cur_append_lengths_indptr_host_.back(), page_size_, num_qo_heads_, num_kv_heads_, + qk_head_dim_, v_head_dim_, /*causal=*/false, copy_stream_); } } } @@ -2943,145 +1987,150 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { * input k/v data and the k/v data in cache on the given layer. */ void AttentionInternal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, - NDArray output, double attn_score_scaling_factor) { + NDArray output, double sm_scale) { int64_t local_layer_id = layer_id - layer_id_begin_offset_; CHECK_GE(local_layer_id, 0); CHECK_LT(local_layer_id, num_layers_); - PackedFunc f_prefill = - !support_sliding_window_ ? f_attention_prefill_ : f_attention_prefill_sliding_window_; - PackedFunc f_decode = - !support_sliding_window_ ? f_attention_decode_ : f_attention_decode_sliding_window_; - CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; bool is_first_kernel = true; if (!append_before_attn_) { // The first part of attention, which only involves the q and the newly appended k/v. is_first_kernel = false; - if (is_chain_on_depths_[0]) { - // If the batch does not form a tree, use raggedness prefill kernel. 
- f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data, - cur_append_length_indptr_view_, q_rope_position_map_view_, - k_ragged_rope_pos_offset_view_, output, - merged_attn_scores_view_, - /*causal=*/1, - /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, - rotary_theta_, attn_score_scaling_factor); - } else { - // The batch requires tree attention. - ICHECK(f_attention_prefill_with_tree_mask_.defined()) - << "Function \"f_attention_prefill_with_tree_mask_\" is not defined."; - ICHECK(tree_attn_mask_view_[0].defined()); - ICHECK(tree_attn_mn_indptr_view_[0].defined()); - f_attention_prefill_with_tree_mask_( - q_data, cur_append_length_indptr_view_, k_data, v_data, cur_append_length_indptr_view_, - q_rope_position_map_view_, tree_attn_mn_indptr_view_[0], tree_attn_mask_view_[0], - output, merged_attn_scores_view_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, - rotary_scale_, rotary_theta_, attn_score_scaling_factor, cur_batch_size_); - } + MHASelfAttnInternal(q_data, k_data, v_data, output, merged_attn_lse_view_, sm_scale); } + bool self_attn_computed = !is_first_kernel; + bool cross_attn_computed = MHACrossAttnInternal( + local_layer_id, q_data, output, merged_attn_lse_view_, sm_scale, is_first_kernel); + CHECK(self_attn_computed || cross_attn_computed) + << "Both self-attention and cross-attention are not computed."; + } + void MHASelfAttnInternal(NDArray q_data, NDArray k_data, NDArray v_data, NDArray o_data, + NDArray lse_data, double sm_scale) { + if (is_chain_on_depths_[0]) { + // If the batch does not form a tree, use raggedness prefill kernel. + ICHECK_NOTNULL(f_attention_prefill_ragged_); + f_attention_prefill_ragged_->MHA( + q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_, + q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, /*causal=*/true, rope_mode_, + rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_); + } else { + // The batch requires tree attention. + ICHECK(f_attention_prefill_with_tree_mask_ != nullptr) + << "Function \"f_attention_prefill_with_tree_mask_\" is not defined."; + ICHECK(tree_attn_mask_view_[0].defined()); + ICHECK(tree_attn_mn_indptr_view_[0].defined()); + f_attention_prefill_with_tree_mask_->MHA( + q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_, + q_rope_position_map_view_, tree_attn_mn_indptr_view_[0], tree_attn_mask_view_[0], + rope_mode_, rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_); + } + } + + void MLASelfAttnInternal(NDArray q_data, NDArray k_data, NDArray v_data, NDArray o_data, + NDArray lse_data, double sm_scale) { + CHECK(is_chain_on_depths_[0]) << "Tree attn not able for MLA for now."; + // If the batch does not form a tree, use raggedness prefill kernel. + ICHECK_NOTNULL(f_attention_prefill_ragged_); + f_attention_prefill_ragged_->MHA( + q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_, + q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, /*causal=*/true, RoPEMode::kNone, + rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_); + } + + /*! \brief Compute cross-attention for MHA. Return if there is effective computation. */ + bool MHACrossAttnInternal(int64_t local_layer_id, NDArray q_data, NDArray o_data, + NDArray lse_data, double sm_scale, bool is_first_kernel) { + std::unique_ptr& f_prefill = + !support_sliding_window_ ? 
f_attention_prefill_ : f_attention_prefill_sliding_window_; + std::unique_ptr& f_decode = + !support_sliding_window_ ? f_attention_decode_ : f_attention_decode_sliding_window_; + CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; + + bool cross_attn_computed = false; for (int d = 0; d < num_depths_; ++d) { if (page_indices_on_depths_view_[d]->shape[0] == 0) { continue; } NDArray attn_output; - NDArray attn_scores; + NDArray attn_lse; if (is_first_kernel) { - attn_output = output; - attn_scores = merged_attn_scores_view_; + attn_output = o_data; + attn_lse = lse_data; } else { attn_output = temp_attn_output_view_; - attn_scores = temp_attn_scores_view_; + attn_lse = temp_attn_lse_view_; } if (append_before_attn_ && !is_chain_on_depths_[d]) { - f_attention_prefill_with_tree_mask_paged_kv_( - /*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], + ICHECK_NOTNULL(f_attention_prefill_with_tree_mask_paged_kv_); + f_attention_prefill_with_tree_mask_paged_kv_->MHA( + q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_, - attn_output, attn_scores, - /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, - attn_score_scaling_factor, tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d]); + tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d], rope_mode_, rotary_scale_, + rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); } else if (use_decode_kernel_[d]) { // Use decode kernel for depth d - f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d], - page_indices_on_depths_view_[d], length_info_on_depths_view_[d], - k_rope_pos_offset_view_[d], q_rope_position_map_view_, attn_output, attn_scores, - /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, - attn_score_scaling_factor); + ICHECK_NOTNULL(f_decode); + f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d], + page_indices_on_depths_view_[d], length_info_on_depths_view_[d], + k_rope_pos_offset_view_[d], q_rope_position_map_view_, rope_mode_, + rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, + compute_stream_); } else { // Use prefill kernel for depth d - f_prefill(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], - page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], - length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], - q_rope_position_map_view_, attn_output, attn_scores, /*causal=*/0, - /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_, - attn_score_scaling_factor); + ICHECK_NOTNULL(f_prefill); + f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], + page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], + length_info_on_depths_view_[d], q_rope_position_map_view_, + k_rope_pos_offset_view_[d], /*causal=*/false, + /*rotary_mode=*/rope_mode_, rotary_scale_, rotary_theta_, sm_scale, + attn_output, attn_lse, compute_stream_); } if (!is_first_kernel) { - f_merge_inplace_(output, merged_attn_scores_view_, temp_attn_output_view_, - temp_attn_scores_view_); + f_merge_inplace_[0](o_data, lse_data, temp_attn_output_view_, temp_attn_lse_view_); } else { is_first_kernel = false; } + cross_attn_computed = true; } + return cross_attn_computed; } - void MLAAbsorbedInternal(int64_t layer_id, NDArray 
q_data, NDArray compressed_kv_data, - NDArray k_pe_data, NDArray output, double attn_score_scaling_factor) { - int64_t local_layer_id = layer_id - layer_id_begin_offset_; - CHECK_GE(local_layer_id, 0); - CHECK_LT(local_layer_id, num_layers_); - PackedFunc f_prefill = f_mla_prefill_; - PackedFunc f_decode = f_mla_decode_; + /*! \brief Compute cross-attention for MLA. Return if there is effective computation. */ + bool MLACrossAttnInternal(int64_t local_layer_id, NDArray q_data, NDArray o_data, + NDArray lse_data, double sm_scale) { CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; bool is_first_kernel = true; - if (!append_before_attn_) { - // The first part of attention, which only involves the q and the newly appended k/v. - is_first_kernel = false; - CHECK(is_chain_on_depths_[0]) << "Tree attn not able for MLA for now."; - // If the batch does not form a tree, use raggedness prefill kernel. - f_mla_prefill_ragged_absorbed_(q_data, cur_append_length_indptr_view_, compressed_kv_data, - k_pe_data, cur_append_length_indptr_view_, output, - merged_attn_scores_view_, - /*causal=*/1, attn_score_scaling_factor); - } - for (int d = 0; d < num_depths_; ++d) { if (page_indices_on_depths_view_[d]->shape[0] == 0) { continue; } NDArray attn_output; - NDArray attn_scores; + NDArray attn_lse; if (is_first_kernel) { - attn_output = output; - attn_scores = merged_attn_scores_view_; + attn_output = o_data; + attn_lse = lse_data; } else { attn_output = temp_attn_output_view_; - attn_scores = temp_attn_scores_view_; + attn_lse = temp_attn_lse_view_; } CHECK(is_chain_on_depths_[d]) << "Tree attn not able for MLA for now."; - if (use_decode_kernel_[d]) { - // Use decode kernel for depth d - f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d], - page_indices_on_depths_view_[d], length_info_on_depths_view_[d], attn_output, - attn_scores, attn_score_scaling_factor); - } else { - // Use prefill kernel for depth d - f_prefill(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], - page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], - length_info_on_depths_view_[d], attn_output, attn_scores, /*causal=*/0, - attn_score_scaling_factor); - } + ICHECK_NOTNULL(f_mla_prefill_); + f_mla_prefill_->MLA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], + page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], + length_info_on_depths_view_[d], /*causal=*/false, sm_scale, attn_output, + attn_lse, compute_stream_); if (!is_first_kernel) { - f_merge_inplace_(output, merged_attn_scores_view_, temp_attn_output_view_, - temp_attn_scores_view_); + f_merge_inplace_[0](o_data, lse_data, temp_attn_output_view_, temp_attn_lse_view_); } else { is_first_kernel = false; } } + return !is_first_kernel; } /*! \brief Synchronize the copy stream and the compute stream. */ @@ -3215,10 +2264,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // 16. Create view for temporary arrays for attention computation. 
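// temp_attn_lse_view_ and merged_attn_lse_view_ (created below) hold the
// per-query log-sum-exp (LSE) of the attention scores. f_merge_inplace_[0]
// folds a partial result (o2, lse2) into the accumulated (o1, lse1) using the
// standard streaming-softmax merge; as a reference formulation only (the
// actual kernels are the merge functions passed in via f_merge_inplace_):
//   o1   = (o1 * exp(lse1) + o2 * exp(lse2)) / (exp(lse1) + exp(lse2))
//   lse1 = log(exp(lse1) + exp(lse2))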
temp_attn_output_view_ = temp_attn_output_device_.CreateView( {total_append_length, num_qo_heads_, v_head_dim_}, temp_attn_output_device_->dtype); - temp_attn_scores_view_ = temp_attn_scores_device_.CreateView( - {total_append_length, num_qo_heads_}, temp_attn_scores_device_->dtype); - merged_attn_scores_view_ = merged_attn_scores_device_.CreateView( - {total_append_length, num_qo_heads_}, merged_attn_scores_device_->dtype); + temp_attn_lse_view_ = temp_attn_lse_device_.CreateView({total_append_length, num_qo_heads_}, + temp_attn_lse_device_->dtype); + merged_attn_lse_view_ = merged_attn_lse_device_.CreateView({total_append_length, num_qo_heads_}, + merged_attn_lse_device_->dtype); // - Commit the copy. aux_data_manager_->CommitAttnAuxDataCopy(); @@ -3235,178 +2284,9 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj); TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK(args.size() == 29 || args.size() == 30) - << "Invalid number of KV cache constructor args."; - ShapeTuple cache_config = args[0]; - ShapeTuple layer_indptr_tuple = args[1]; - int num_groups = 1; - int group_id = 0; - if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) { - // In the Disco worker thread - num_groups = disco_worker->num_groups; - group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups); - } - CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1); - int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id]; - int64_t layer_id_begin_offset = layer_indptr_tuple[group_id]; - int64_t num_qo_heads = args[2]; - int64_t num_kv_heads = args[3]; - int64_t head_dim = args[4]; - int rope_mode = args[5]; - double rotary_scale = args[6]; - double rotary_theta = args[7]; - NDArray init = args[8]; - PackedFunc f_transpose_append = args[9]; - PackedFunc f_attention_prefill = args[10]; - PackedFunc f_attention_decode = args[11]; - PackedFunc f_attention_prefill_sliding_window = args[12]; - PackedFunc f_attention_decode_sliding_window = args[13]; - PackedFunc f_attention_prefill_ragged = args[14]; - PackedFunc f_attention_prefill_ragged_begin_forward = args[15]; - PackedFunc f_attention_prefill_ragged_end_forward = args[16]; - PackedFunc f_attention_prefill_begin_forward = args[17]; - PackedFunc f_attention_prefill_end_forward = args[18]; - PackedFunc f_attention_decode_begin_forward = args[19]; - PackedFunc f_attention_decode_end_forward = args[20]; - PackedFunc f_merge_inplace = args[21]; - PackedFunc f_split_rotary = args[22]; - PackedFunc f_copy_single_page = args[23]; - Optional f_debug_get_kv = args[24]; - PackedFunc f_compact_copy = args[25]; - PackedFunc f_attention_prefill_with_tree_mask = args[26]; - PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[27]; - Optional rope_ext_factors = NullOpt; - bool enable_kv_transfer = false; - - if (args[28].IsObjectRef()) { - rope_ext_factors = args[28].AsObjectRef(); - } - if (args.size() >= 30) { - enable_kv_transfer = args[29]; - } - - std::vector attn_kinds(/*size=*/layer_indptr_tuple[num_groups], - /*value=*/AttnKind::kMHA); - - CHECK_EQ(cache_config.size(), 5); - int64_t reserved_num_seqs = cache_config[0]; - int64_t total_token_capacity = cache_config[1]; - int64_t prefill_chunk_size = cache_config[2]; - int64_t page_size = cache_config[3]; - bool support_sliding_window = cache_config[4]; - int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size + 1; - if (support_sliding_window) { - // When sliding window is 
enabled, each sequence may use two more pages at most. - num_total_pages += reserved_num_seqs * 2; - } - // NOTE: We will remove this legacy construction after finishing the transition phase. - // Some `PackedFunc()` here are placeholders that will be filled. - ObjectPtr n = make_object( - page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim, - head_dim, /*qk_rope_head_dim=*/0, attn_kinds, reserved_num_seqs, num_total_pages, - prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), rotary_scale, - rotary_theta, - std::move(rope_ext_factors), // - enable_kv_transfer, init->dtype, init->device, // - std::move(f_transpose_append), PackedFunc(), std::move(f_compact_copy), - std::move(f_attention_prefill), std::move(f_attention_decode), - std::move(f_attention_prefill_sliding_window), - std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), - std::move(f_attention_prefill_with_tree_mask), - std::move(f_attention_prefill_with_tree_mask_paged_kv), - std::move(f_attention_prefill_ragged_begin_forward), - std::move(f_attention_prefill_ragged_end_forward), - std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward), - std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward), - PackedFunc(), PackedFunc(), PackedFunc(), PackedFunc(), std::move(f_merge_inplace), - std::move(f_split_rotary), std::move(f_copy_single_page), std::move(f_debug_get_kv)); - *rv = AttentionKVCache(std::move(n)); - }); - -TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") - .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK(args.size() == 23 || args.size() == 24) - << "Invalid number of KV cache constructor args."; - ShapeTuple cache_config = args[0]; - ShapeTuple layer_indptr_tuple = args[1]; - int num_groups = 1; - int group_id = 0; - if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) { - // In the Disco worker thread - num_groups = disco_worker->num_groups; - group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups); - } - CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1); - int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id]; - int64_t layer_id_begin_offset = layer_indptr_tuple[group_id]; - int64_t num_qo_heads = args[2]; - int64_t num_kv_heads = args[3]; - int64_t head_dim = args[4]; - int rope_mode = args[5]; - double rotary_scale = args[6]; - double rotary_theta = args[7]; - NDArray init = args[8]; - PackedFunc f_transpose_append = args[9]; - PackedFunc f_attention_prefill = args[10]; - PackedFunc f_attention_decode = args[11]; - PackedFunc f_attention_prefill_sliding_window = args[12]; - PackedFunc f_attention_decode_sliding_window = args[13]; - PackedFunc f_attention_prefill_ragged = args[14]; - PackedFunc f_merge_inplace = args[15]; - PackedFunc f_split_rotary = args[16]; - PackedFunc f_copy_single_page = args[17]; - Optional f_debug_get_kv = args[18]; - PackedFunc f_compact_copy = args[19]; - PackedFunc f_attention_prefill_with_tree_mask = args[20]; - PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[21]; - Optional rope_ext_factors = NullOpt; - bool enable_kv_transfer = false; - - if (args[22].IsObjectRef()) { - rope_ext_factors = args[22].AsObjectRef(); - } - if (args.size() >= 24) { - enable_kv_transfer = args[23]; - } - - std::vector attn_kinds(/*size=*/layer_indptr_tuple[num_groups], - /*value=*/AttnKind::kMHA); - - CHECK_EQ(cache_config.size(), 5); - int64_t 
reserved_num_seqs = cache_config[0]; - int64_t total_token_capacity = cache_config[1]; - int64_t prefill_chunk_size = cache_config[2]; - int64_t page_size = cache_config[3]; - bool support_sliding_window = cache_config[4]; - int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size + 1; - if (support_sliding_window) { - // When sliding window is enabled, each sequence may use two more pages at most. - num_total_pages += reserved_num_seqs * 2; - } - // NOTE: We will remove this legacy construction after finishing the transition phase. - // Some `PackedFunc()` here are placeholders that will be filled. - ObjectPtr n = make_object( - page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim, - head_dim, /*qk_rope_head_dim=*/0, attn_kinds, reserved_num_seqs, num_total_pages, - prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), rotary_scale, - rotary_theta, - std::move(rope_ext_factors), // - enable_kv_transfer, init->dtype, init->device, // - std::move(f_transpose_append), PackedFunc(), std::move(f_compact_copy), - std::move(f_attention_prefill), std::move(f_attention_decode), - std::move(f_attention_prefill_sliding_window), - std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), - std::move(f_attention_prefill_with_tree_mask), // - std::move(f_attention_prefill_with_tree_mask_paged_kv), // - NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // - PackedFunc(), PackedFunc(), PackedFunc(), PackedFunc(), std::move(f_merge_inplace), - std::move(f_split_rotary), std::move(f_copy_single_page), std::move(f_debug_get_kv)); - *rv = AttentionKVCache(std::move(n)); - }); - -TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced_mla") - .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK(args.size() == 38) << "Invalid number of KV cache constructor args."; + // Todo: cuda graph arg + CHECK(args.size() == 28 || args.size() == 29) + << "Invalid number of KV cache constructor args: " << args.size(); ShapeTuple cache_config = args[0]; ShapeTuple layer_indptr_tuple = args[1]; int num_groups = 1; @@ -3419,60 +2299,54 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced_mla") CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1); int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id]; int64_t layer_id_begin_offset = layer_indptr_tuple[group_id]; + int64_t layer_id_end_offset = layer_indptr_tuple[group_id + 1]; int64_t num_qo_heads = args[2]; int64_t num_kv_heads = args[3]; int64_t qk_head_dim = args[4]; int64_t v_head_dim = args[5]; - int64_t qk_rope_head_dim = args[6]; - IntTuple attn_kinds = args[7]; + IntTuple attn_kinds = args[6]; + bool enable_kv_transfer = args[7]; int rope_mode = args[8]; double rotary_scale = args[9]; double rotary_theta = args[10]; - NDArray init = args[11]; - PackedFunc f_transpose_append = args[12]; - PackedFunc f_transpose_append_mla = args[13]; - PackedFunc f_attention_prefill = args[14]; - PackedFunc f_attention_decode = args[15]; - PackedFunc f_attention_prefill_sliding_window = args[16]; - PackedFunc f_attention_decode_sliding_window = args[17]; - PackedFunc f_attention_prefill_ragged = args[18]; - Optional f_attention_prefill_ragged_begin_forward = NullOpt; - Optional f_attention_prefill_ragged_end_forward = NullOpt; - Optional f_attention_prefill_begin_forward = NullOpt; - Optional f_attention_prefill_end_forward = NullOpt; - Optional f_attention_decode_begin_forward = NullOpt; - Optional f_attention_decode_end_forward 
= NullOpt; - PackedFunc f_mla_prefill = args[25]; - PackedFunc f_mla_decode = args[26]; - PackedFunc f_mla_prefill_ragged_normal = args[27]; - PackedFunc f_mla_prefill_ragged_absorbed = args[28]; - PackedFunc f_merge_inplace = args[29]; - PackedFunc f_split_rotary = args[30]; - PackedFunc f_copy_single_page = args[31]; - Optional f_debug_get_kv = args[32]; - PackedFunc f_compact_copy = args[33]; - PackedFunc f_attention_prefill_with_tree_mask = args[34]; - PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[35]; - Optional rope_ext_factors = NullOpt; - bool enable_kv_transfer = false; - - if (args[36].IsObjectRef()) { - rope_ext_factors = args[36].AsObjectRef(); + Optional rope_ext_factors = NullOpt; // args[11] + NDArray init = args[12]; + Optional f_transpose_append_mha = NullOpt; // args[13] + Optional f_transpose_append_mla = NullOpt; // args[14] + std::unique_ptr f_attention_prefill_ragged = + ConvertRaggedPrefillFunc(args[15], AttnKind::kMHA); + std::unique_ptr f_attention_prefill = + ConvertPagedPrefillFunc(args[16], AttnKind::kMHA); + std::unique_ptr f_attention_decode = + ConvertPagedDecodeFunc(args[17], AttnKind::kMHA); + std::unique_ptr f_attention_prefill_sliding_window = + ConvertPagedPrefillFunc(args[18], AttnKind::kMHA); + std::unique_ptr f_attention_decode_sliding_window = + ConvertPagedDecodeFunc(args[19], AttnKind::kMHA); + std::unique_ptr f_attention_prefill_with_tree_mask_paged_kv = + ConvertPagedPrefillTreeMaskFunc(args[20], AttnKind::kMHA); + std::unique_ptr f_attention_prefill_with_tree_mask = + ConvertRaggedPrefillTreeMaskFunc(args[21], AttnKind::kMHA); + std::unique_ptr f_mla_prefill = + ConvertPagedPrefillFunc(args[22], AttnKind::kMLA); + Array f_merge_inplace = args[23]; + PackedFunc f_split_rotary = args[24]; + PackedFunc f_copy_single_page = args[25]; + PackedFunc f_debug_get_kv = args[26]; + PackedFunc f_compact_copy = args[27]; + + if (args[11].IsObjectRef()) { + rope_ext_factors = args[11].AsObjectRef(); } - enable_kv_transfer = args[37]; - auto f_convert_optional_packed_func = [&args](int arg_idx) -> Optional { if (args[arg_idx].IsObjectRef()) { return args[arg_idx].AsObjectRef(); } return NullOpt; }; - f_attention_prefill_ragged_begin_forward = f_convert_optional_packed_func(19); - f_attention_prefill_ragged_end_forward = f_convert_optional_packed_func(20); - f_attention_prefill_begin_forward = f_convert_optional_packed_func(21); - f_attention_prefill_end_forward = f_convert_optional_packed_func(22); - f_attention_decode_begin_forward = f_convert_optional_packed_func(23); - f_attention_decode_end_forward = f_convert_optional_packed_func(24); + f_transpose_append_mha = f_convert_optional_packed_func(13); + f_transpose_append_mla = f_convert_optional_packed_func(14); + CHECK(!f_merge_inplace.empty()) << "Merge inplace function is not defined."; std::vector attn_kinds_vec; attn_kinds_vec.reserve(attn_kinds.size()); @@ -3494,25 +2368,20 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced_mla") // NOTE: We will remove this legacy construction after finishing the transition phase. // Some `PackedFunc()` here are placeholders that will be filled. 
ObjectPtr n = make_object( - page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, qk_head_dim, - v_head_dim, qk_rope_head_dim, attn_kinds_vec, reserved_num_seqs, num_total_pages, + page_size, num_layers, layer_id_begin_offset, layer_id_end_offset, num_qo_heads, + num_kv_heads, qk_head_dim, v_head_dim, attn_kinds_vec, reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), rotary_scale, - rotary_theta, - std::move(rope_ext_factors), // - enable_kv_transfer, init->dtype, init->device, // - std::move(f_transpose_append), std::move(f_transpose_append_mla), - std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode), + rotary_theta, std::move(rope_ext_factors), enable_kv_transfer, // + init->dtype, init->device, // + std::move(f_transpose_append_mha), std::move(f_transpose_append_mla), + std::move(f_compact_copy), std::move(f_attention_prefill_ragged), + std::move(f_attention_prefill), std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), - std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), - std::move(f_attention_prefill_with_tree_mask), // + std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_with_tree_mask_paged_kv), // - std::move(f_attention_prefill_ragged_begin_forward), - std::move(f_attention_prefill_ragged_end_forward), - std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward), - std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward), - std::move(f_mla_prefill), std::move(f_mla_decode), std::move(f_mla_prefill_ragged_normal), - std::move(f_mla_prefill_ragged_absorbed), std::move(f_merge_inplace), - std::move(f_split_rotary), std::move(f_copy_single_page), std::move(f_debug_get_kv)); + std::move(f_attention_prefill_with_tree_mask), // + std::move(f_mla_prefill), std::move(f_merge_inplace), std::move(f_split_rotary), + std::move(f_copy_single_page), std::move(f_debug_get_kv)); *rv = AttentionKVCache(std::move(n)); }); diff --git a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py index 483108ca838c..5a3d65dbe0de 100644 --- a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py +++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py @@ -21,11 +21,14 @@ import numpy as np import pytest import scipy.special +import torch import tvm import tvm.testing from tvm import dlight as dl from tvm.relax.frontend.nn.llm.kv_cache import ( + AttnKind, + RopeMode, _attention_decode, _attention_prefill, _attention_prefill_ragged, @@ -62,11 +65,14 @@ def get_comm_rank(): num_qo_heads = 32 num_kv_heads = 4 head_dim = None +sm_scale = None rope_scale = 1.0 rope_theta = 1e4 rope_scaling = {} dtype = None +dtype_torch = None device = tvm.cuda(rank) +device_torch = torch.device(f"cuda:{rank}") fclear = None fadd_sequence = None @@ -147,7 +153,7 @@ def set_global_func(head_dim, dtype): _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling, target), _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling, target), _attention_prefill_ragged( - num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target + num_kv_heads, num_qo_heads, head_dim, head_dim, dtype, rope_scaling, target ), tree_attn(num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target), tree_attn_with_paged_kv_cache( @@ 
-184,7 +190,7 @@ def set_global_func(head_dim, dtype): def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): - fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced") + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") cache = fcreate( tvm.runtime.ShapeTuple( [ @@ -199,41 +205,33 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): num_qo_heads, num_kv_heads, head_dim, + head_dim, # v_head_dim + tvm.runtime.ShapeTuple([int(AttnKind.MHA) for _ in range(num_layers)]), + False, # enable_kv_transfer rope_mode, rope_scale, rope_theta, + None, # rope_ext_factors tvm.nd.empty((), dtype, device=device), ftranspose_append, - fattn_prefill, - fattn_decode, - fattn_prefill_sliding_window, - fattn_decode_sliding_window, - fattn_prefill_ragged, - fmerge_state, + None, # f_transpose_append_mla + ["tir", fattn_prefill_ragged], + ["tir", fattn_prefill], + ["tir", fattn_decode], + ["tir", fattn_prefill_sliding_window], + ["tir", fattn_decode_sliding_window], + ["tir", fattn_prefill_with_tree_mask_paged_kv_cache], + ["tir", fattn_prefill_with_tree_mask], + [], # f_mla_prefill + [fmerge_state], fsplit_rotary, fcopy_single_page, fcopy_cache, fcompact_copy, - fattn_prefill_with_tree_mask, - fattn_prefill_with_tree_mask_paged_kv_cache, - None, - True, ) return cache -class RopeMode(enum.IntEnum): - """The RoPE mode of the Paged KV cache. - If it is none, the KV cache will not apply RoPE to q and k. - If it is normal, RoPE will be applied to k before adding k to cache. - Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly. - """ - - NONE = 0 - NORMAL = 1 - INLINE = 2 - - @pytest.fixture( params=itertools.chain( itertools.product( @@ -251,8 +249,9 @@ class RopeMode(enum.IntEnum): ) ) def kv_cache_and_config(request): - global head_dim, dtype + global head_dim, sm_scale, dtype head_dim, dtype, rope_mode, support_sliding_window = request.param + sm_scale = head_dim ** (-0.5) set_global_func(head_dim, dtype) return create_kv_cache(*request.param), rope_mode, support_sliding_window @@ -266,8 +265,12 @@ def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v): keys = tvm.nd.empty(keys_expected.shape, dtype=dtype, device=device) values = tvm.nd.empty(values_expected.shape, dtype=dtype, device=device) fdebug_get_kv(kv_cache, seq_id, 0, seq_length, keys, values) - tvm.testing.assert_allclose(keys.numpy(), keys_expected, rtol=1e-3, atol=1e-3) - tvm.testing.assert_allclose(values.numpy(), values_expected, rtol=1e-3, atol=1e-3) + torch.testing.assert_close( + torch.from_numpy(keys.numpy()).to(device_torch), keys_expected, rtol=1e-3, atol=1e-3 + ) + torch.testing.assert_close( + torch.from_numpy(values.numpy()).to(device_torch), values_expected, rtol=1e-3, atol=1e-3 + ) def f_apply_rotary(x, offset, scale, theta, offset_list: Optional[List[int]] = None): @@ -275,29 +278,34 @@ def f_apply_rotary(x, offset, scale, theta, offset_list: Optional[List[int]] = N assert len(x.shape) == 3 nfeat = x.shape[-1] nfeat_half = x.shape[-1] // 2 - x = x.astype("float32") - y = np.concatenate([-x[:, :, nfeat_half:], x[:, :, :nfeat_half]], axis=-1) + x_dtype = x.dtype + x = x.to(torch.float32) + y = torch.cat([-x[:, :, nfeat_half:], x[:, :, :nfeat_half]], dim=-1) - inv_freq = scale / (theta ** (np.arange(0, nfeat, 2).astype("float32") / nfeat)) + inv_freq = scale / ( + theta ** (torch.arange(0, nfeat, 2, device=device_torch, dtype=torch.float32) / nfeat) + ) t = ( - np.arange(offset, offset + x.shape[0], 
dtype=inv_freq.dtype) + torch.arange(offset, offset + x.shape[0], device=device_torch, dtype=inv_freq.dtype) if offset_list is None - else (np.array(offset_list, dtype=inv_freq.dtype) + offset) + else (torch.tensor(offset_list, dtype=inv_freq.dtype, device=device_torch) + offset) ) - freqs = np.einsum("i,j->ij", t, inv_freq) - emb = np.concatenate((freqs, freqs), axis=-1) - cos_values = np.cos(emb) - sin_values = np.sin(emb) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos_values = torch.cos(emb) + sin_values = torch.sin(emb) - return np.einsum("ij,ikj->ikj", cos_values, x) + np.einsum("ij,ikj->ikj", sin_values, y) + return torch.einsum("ij,ikj->ikj", cos_values, x).to(x_dtype) + torch.einsum( + "ij,ikj->ikj", sin_values, y + ).to(x_dtype) def apply_attention( kv_cache, rope_mode: RopeMode, batch: List[Tuple[Union[int, Tuple[int, int, int]], int]], - cached_k: Dict[int, np.ndarray], - cached_v: Dict[int, np.ndarray], + cached_k: Dict[int, torch.Tensor], + cached_v: Dict[int, torch.Tensor], sliding_window_sizes: Optional[List[int]] = None, attn_sink_sizes: Optional[List[int]] = None, token_tree_parent_ptr_list: Optional[List[List[int]]] = None, @@ -329,8 +337,12 @@ def apply_attention( elif seq_id not in cached_k: if not only_update_host and not skip_add_sequence: fadd_sequence(kv_cache, seq_id) - cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) - cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + cached_k[seq_id] = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) + cached_v[seq_id] = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) flattened_token_tree_parent_ptr = None token_tree_node_depths_list: List[Optional[List[int]]] = [None for _ in batch] @@ -353,6 +365,7 @@ def apply_attention( ) # depth of each node in the tree (this contains more than the last `append_length` nodes) token_tree_node_depths_list[i] = token_tree_node_depths + if not only_update_host: fbegin_forward( kv_cache, @@ -365,15 +378,45 @@ def apply_attention( ), ) - global_new_q = np.zeros((num_layers, 0, num_qo_heads, head_dim), dtype) - global_new_k = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) - global_new_v = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + global_new_q = torch.zeros( + (num_layers, 0, num_qo_heads, head_dim), dtype=dtype_torch, device=device_torch + ) + global_new_k = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) + global_new_v = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) q_array = [] for i, (seq_id, append_length) in enumerate(batch): - new_q = np.random.rand(num_layers, append_length, num_qo_heads, head_dim).astype(dtype) - new_k = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype) - new_v = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype) + new_q = torch.rand( + num_layers, + append_length, + num_qo_heads, + head_dim, + dtype=dtype_torch, + device=device_torch, + ) + new_k = torch.rand( + num_layers, + append_length, + num_kv_heads, + head_dim, + dtype=dtype_torch, + device=device_torch, + ) + new_v = torch.rand( + num_layers, + append_length, + num_kv_heads, + head_dim, + dtype=dtype_torch, + device=device_torch, + ) + new_q = new_q * 2 - 1 + new_k = new_k * 2 - 1 + new_v = new_v * 2 - 1 q_array.append(new_q) 
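# rope_offset below is the number of tokens already cached for this sequence, minus (for token-tree requests) the part of the token tree that was appended in earlier rounds.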
rope_offset = cached_k[seq_id].shape[1] @@ -381,10 +424,10 @@ def apply_attention( prev_tree_size = len(token_tree_parent_ptr_list[i]) - append_length assert prev_tree_size >= 0 rope_offset -= prev_tree_size - cached_k[seq_id] = np.concatenate( + cached_k[seq_id] = torch.cat( [ cached_k[seq_id], - np.stack( + torch.stack( [ ( new_k[l] @@ -403,27 +446,27 @@ def apply_attention( ) for l in range(num_layers) ], - axis=0, + dim=0, ), ], - axis=1, + dim=1, ) - cached_v[seq_id] = np.concatenate([cached_v[seq_id], new_v], axis=1) - global_new_q = np.concatenate([global_new_q, new_q], axis=1) - global_new_k = np.concatenate([global_new_k, new_k], axis=1) - global_new_v = np.concatenate([global_new_v, new_v], axis=1) + cached_v[seq_id] = torch.cat([cached_v[seq_id], new_v], dim=1) + global_new_q = torch.cat([global_new_q, new_q], dim=1) + global_new_k = torch.cat([global_new_k, new_k], dim=1) + global_new_v = torch.cat([global_new_v, new_v], dim=1) for layer_id in range(num_layers): queries_np = global_new_q[layer_id] keys_np = global_new_k[layer_id] values_np = global_new_v[layer_id] - qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device) + qkv = tvm.nd.array(torch.cat([queries_np, keys_np, values_np], dim=1).cpu().numpy(), device) outputs = tvm.nd.empty(queries_np.shape, dtype, device=device) if not only_update_host: - fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs) + fattention_with_fuse_qkv(kv_cache, layer_id, sm_scale, qkv, outputs) # Compute attention expected results. - outputs = np.expand_dims(outputs.numpy(), axis=0) + outputs = torch.from_numpy(outputs.numpy()).unsqueeze(0).to(device_torch) sum_length = 0 for i, (seq_id, append_length) in enumerate(batch): assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length @@ -447,7 +490,7 @@ def apply_attention( else None ), ) - ).transpose(1, 0, 2) + ).permute(1, 0, 2) k_seq = ( cached_k[seq_id][layer_id] if rope_mode != RopeMode.INLINE @@ -465,41 +508,48 @@ def apply_attention( else None ), ) - ).transpose(1, 2, 0) - v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2) + ).permute(1, 2, 0) + v_seq = cached_v[seq_id][layer_id].permute(1, 0, 2) - k_seq = np.repeat(k_seq, num_qo_heads // num_kv_heads, axis=0) - v_seq = np.repeat(v_seq, num_qo_heads // num_kv_heads, axis=0) - softmax_input = (q_seq.astype("float32") @ k_seq.astype("float32")) / np.sqrt(head_dim) + k_seq = k_seq.repeat_interleave(num_qo_heads // num_kv_heads, dim=0) + v_seq = v_seq.repeat_interleave(num_qo_heads // num_kv_heads, dim=0) + softmax_input = (q_seq.to(torch.float32) @ k_seq.to(torch.float32)) / (head_dim**0.5) softmax_shape = softmax_input.shape assert softmax_shape[-2] == append_length length_diff = softmax_shape[-1] - softmax_shape[-2] assert length_diff >= 0 - mask = np.tril( - np.full_like(softmax_input, np.finfo("float32").max), k=length_diff - ) + np.triu(np.full_like(softmax_input, np.finfo("float32").min), k=length_diff + 1) + mask = torch.tril( + torch.full_like(softmax_input, torch.finfo(torch.float32).max), diagonal=length_diff + ) + torch.triu( + torch.full_like(softmax_input, torch.finfo(torch.float32).min), + diagonal=length_diff + 1, + ) if token_tree_parent_ptr_list is not None: tree_size = len(token_tree_parent_ptr_list[i]) - tree_mask = np.full( - (tree_size, tree_size), np.finfo("float32").min, dtype="float32" + tree_mask = torch.full( + (tree_size, tree_size), + torch.finfo(torch.float32).min, + dtype=torch.float32, + device=device_torch, ) for i, parent in 
enumerate(token_tree_parent_ptr_list[i]): if parent != -1: tree_mask[i] = tree_mask[parent] - tree_mask[i, i] = np.finfo("float32").max - tree_mask = np.broadcast_to(tree_mask, (num_qo_heads, *tree_mask.shape)) + tree_mask[i, i] = torch.finfo(torch.float32).max + tree_mask = tree_mask.expand(num_qo_heads, *tree_mask.shape) mask[:, :, -tree_size:] = tree_mask[:, -append_length:, :] - softmax_input = np.minimum(softmax_input, mask) + softmax_input = torch.minimum(softmax_input, mask) + + results = torch.unsqueeze( + ( + torch.nn.functional.softmax(softmax_input, dim=-1) @ v_seq.to(torch.float32) + ).permute(1, 0, 2), + dim=0, + ).to(dtype_torch) - results = np.expand_dims( - (scipy.special.softmax(softmax_input, axis=-1) @ v_seq.astype("float32")).transpose( - 1, 0, 2 - ), - axis=0, - ).astype(dtype) if not only_update_host: - tvm.testing.assert_allclose( + torch.testing.assert_close( outputs[:, sum_length : sum_length + append_length, ...], results, rtol=1e-3, @@ -549,19 +599,19 @@ def apply_attention( if cached_k[seq_id].shape[1] > sliding_window_size: # Apply sliding window and sink to cached kv. length_to_slide = cached_k[seq_id].shape[1] - sliding_window_size - cached_k[seq_id] = np.concatenate( + cached_k[seq_id] = torch.cat( [ cached_k[seq_id][:, :attn_sink_size, ...], cached_k[seq_id][:, attn_sink_size + length_to_slide :, ...], ], - axis=1, + dim=1, ) - cached_v[seq_id] = np.concatenate( + cached_v[seq_id] = torch.cat( [ cached_v[seq_id][:, :attn_sink_size, ...], cached_v[seq_id][:, attn_sink_size + length_to_slide :, ...], ], - axis=1, + dim=1, ) assert cached_k[seq_id].shape[1] == sliding_window_size diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py index 9487bbf8601a..0d6ee7b54e50 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py @@ -26,6 +26,7 @@ import tvm.testing from tvm import dlight as dl from tvm.relax.frontend.nn.llm.kv_cache import ( + AttnKind, _attention_decode_cpu, _attention_prefill_cpu, _attention_prefill_ragged_cpu, @@ -48,6 +49,7 @@ num_qo_heads = 32 num_kv_heads = 4 head_dim = None +sm_scale = None rope_scale = 1.0 rope_theta = 1e4 rope_scaling = {} @@ -120,7 +122,9 @@ def set_global_func(head_dim, dtype): _attention_decode_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, False, rope_scaling), _attention_prefill_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling), _attention_decode_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling), - _attention_prefill_ragged_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling), + _attention_prefill_ragged_cpu( + num_kv_heads, num_qo_heads, head_dim, head_dim, dtype, rope_scaling + ), tree_attn_cpu(num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling), tree_attn_with_paged_kv_cache_cpu( num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling @@ -156,7 +160,7 @@ def set_global_func(head_dim, dtype): def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): - fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced") + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") cache = fcreate( tvm.runtime.ShapeTuple( [ @@ -171,25 +175,29 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): num_qo_heads, num_kv_heads, head_dim, + head_dim, # v_head_dim + 
tvm.runtime.ShapeTuple([int(AttnKind.MHA) for _ in range(num_layers)]), + False, # enable_kv_transfer rope_mode, rope_scale, rope_theta, + None, # rope_ext_factors tvm.nd.empty((), dtype, device=device), ftranspose_append, - fattn_prefill, - fattn_decode, - fattn_prefill_sliding_window, - fattn_decode_sliding_window, - fattn_prefill_ragged, - fmerge_state, + None, # f_transpose_append_mla + ["tir", fattn_prefill_ragged], + ["tir", fattn_prefill], + ["tir", fattn_decode], + ["tir", fattn_prefill_sliding_window], + ["tir", fattn_decode_sliding_window], + ["tir", fattn_prefill_with_tree_mask_paged_kv_cache], + ["tir", fattn_prefill_with_tree_mask], + [], # f_mla_prefill + [fmerge_state], fsplit_rotary, fcopy_single_page, fcopy_cache, fcompact_copy, - fattn_prefill_with_tree_mask, - fattn_prefill_with_tree_mask_paged_kv_cache, - None, - False, ) return cache @@ -223,8 +231,9 @@ class RopeMode(enum.IntEnum): ) ) def kv_cache_and_config(request): - global head_dim, dtype + global head_dim, sm_scale, dtype head_dim, dtype, rope_mode, support_sliding_window = request.param + sm_scale = head_dim ** (-0.5) set_global_func(head_dim, dtype) return create_kv_cache(*request.param), rope_mode, support_sliding_window @@ -388,7 +397,7 @@ def apply_attention( values_np = global_new_v[layer_id] qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device) outputs = tvm.nd.empty(queries_np.shape, dtype, device=device) - fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs) + fattention_with_fuse_qkv(kv_cache, layer_id, sm_scale, qkv, outputs) # Compute attention expected results. outputs = np.expand_dims(outputs.numpy(), axis=0) @@ -944,6 +953,7 @@ def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config): for head_dim, dtype, rope_mode, support_sliding_window in itertools.product( HEAD_DIMS, DTYPES, ROPE_MODES, SUPPORT_SLIDING_WINDOW ): + sm_scale = head_dim ** (-0.5) set_global_func(head_dim, dtype) cache = create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window) cache_and_config = (cache, rope_mode, support_sliding_window) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index fe4da50cd9bf..589184b51091 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -16,18 +16,22 @@ # under the License. 
from typing import Dict, List, Tuple, Union -import numpy as np import pytest -import scipy.special +import torch import tvm import tvm.testing from tvm import dlight as dl +from tvm import relax +from tvm.contrib import utils from tvm.relax.frontend.nn.llm.kv_cache import ( + AttnKind, RopeMode, + _compact_kv_copy, _copy_single_page, _kv_cache_debug_get_kv, _kv_cache_transpose_append, + _merge_state_inplace, llama_rope_with_position_map, ) from tvm.runtime import ShapeTuple @@ -40,13 +44,15 @@ num_qo_heads = 32 num_kv_heads = 4 head_dim = 128 +sm_scale = head_dim ** (-0.5) rope_scale = 1.0 rope_theta = 1e4 dtype = "float16" +dtype_torch = getattr(torch, dtype) device = tvm.cuda() +device_torch = torch.device("cuda") fclear = None -fcreate = None fadd_sequence = None fremove_sequence = None ffork_sequence = None @@ -60,33 +66,27 @@ fattention_prefill = None fattention_decode = None fattention_prefill_ragged = None -fattention_prefill_begin_forward = None -fattention_prefill_end_forward = None -fattention_decode_begin_forward = None -fattention_decode_end_forward = None -fattention_prefill_ragged_begin_forward = None -fattention_prefill_ragged_end_forward = None +fattention_prefill_plan = None +fattention_decode_plan = None +fattention_prefill_ragged_plan = None fattention_merge_state = None ftranspose_append = None fsplit_rotary = None fcopy_single_page = None fcopy_cache = None +fcompact_copy = None def set_global_func(): - global fclear, fcreate, fadd_sequence, fremove_sequence, ffork_sequence, fpopn + global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fpopn global fbegin_forward, fend_forward, fattention, fattention_with_fuse_qkv, fdebug_get_kv - global fattention_prefill, fattention_prefill_begin_forward, fattention_prefill_end_forward - global fattention_decode, fattention_decode_begin_forward, fattention_decode_end_forward - global fattention_prefill_ragged - global fattention_prefill_ragged_begin_forward - global fattention_prefill_ragged_end_forward + global fattention_prefill, fattention_decode, fattention_prefill_ragged + global fattention_prefill_plan, fattention_decode_plan, fattention_prefill_ragged_plan global fattention_merge_state, fsplit_rotary, fcopy_single_page - global ftranspose_append, fcopy_cache + global ftranspose_append, fcopy_cache, fcompact_copy fclear = tvm.get_global_func("vm.builtin.kv_state_clear") - fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence") fremove_sequence = tvm.get_global_func("vm.builtin.kv_state_remove_sequence") ffork_sequence = tvm.get_global_func("vm.builtin.kv_state_fork_sequence") @@ -98,44 +98,59 @@ def set_global_func(): ) fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv") - fattention_prefill = tvm.get_global_func( - "flashinfer.attention_kernel_prefill_with_paged_kv_cache" - ) - fattention_decode = tvm.get_global_func( - "flashinfer.attention_kernel_decode_with_paged_kv_cache" - ) - fattention_prefill_ragged = tvm.get_global_func( - "flashinfer.attention_kernel_prefill_with_ragged_kv_cache" - ) - fattention_prefill_begin_forward = tvm.get_global_func( - "flashinfer.attention_kernel_prefill_with_paged_kv_cache_begin_forward" - ) - fattention_prefill_end_forward = tvm.get_global_func( - "flashinfer.attention_kernel_prefill_with_paged_kv_cache_end_forward" - ) - fattention_decode_begin_forward = tvm.get_global_func( - "flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward" - ) - 
fattention_decode_end_forward = tvm.get_global_func( - "flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward" - ) - fattention_prefill_ragged_begin_forward = tvm.get_global_func( - "flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward" + def load_module(name: str, static_modules: List[tvm.runtime.Module]): + assert len(static_modules) > 0 + if len(static_modules) == 1: + return static_modules[0] + static_mod = static_modules[0] + for mod in static_modules[1:]: + static_mod.import_module(mod) + temp = utils.tempdir() + mod_path = temp.relpath(f"{name}.so") + static_mod.export_library(mod_path) + return tvm.runtime.load_module(mod_path) + + target = tvm.target.Target.from_device(device) + flashinfer_prefill_mod = load_module( + "flashinfer_prefill", + relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + qk_head_dim=head_dim, + v_head_dim=head_dim, + target=target, + ), ) - fattention_prefill_ragged_end_forward = tvm.get_global_func( - "flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward" + flashinfer_decode_mod = load_module( + "flashinfer_decode", + relax.backend.cuda.flashinfer.gen_flashinfer_decode_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + qk_head_dim=head_dim, + v_head_dim=head_dim, + target=target, + ), ) - fattention_merge_state = tvm.get_global_func("flashinfer.merge_state_in_place") - target = tvm.target.Target.from_device(device) + fattention_prefill = flashinfer_prefill_mod["batch_prefill_with_paged_kv_cache_run"] + fattention_prefill_plan = flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"] + fattention_prefill_ragged = flashinfer_prefill_mod["batch_prefill_with_ragged_kv_cache_run"] + fattention_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"] + fattention_decode = flashinfer_decode_mod["batch_decode_with_paged_kv_cache_run"] + fattention_decode_plan = flashinfer_decode_mod["batch_decode_with_paged_kv_cache_plan"] + builts = [] for tir_func in [ _kv_cache_transpose_append(num_kv_heads, head_dim, dtype), + _merge_state_inplace(num_qo_heads, head_dim, dtype, target), llama_rope_with_position_map( rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype, {} ), _copy_single_page(num_kv_heads, page_size, head_dim, dtype, target), _kv_cache_debug_get_kv(num_layers, num_kv_heads, head_dim, dtype), + _compact_kv_copy(num_kv_heads, head_dim, dtype, target), ]: mod = tvm.IRModule({"main": tir_func}) with target: @@ -143,10 +158,18 @@ def set_global_func(): f = tvm.build(mod["main"], target=target) builts.append(f.entry_func) - ftranspose_append, fsplit_rotary, fcopy_single_page, fcopy_cache = builts + ( + ftranspose_append, + fattention_merge_state, + fsplit_rotary, + fcopy_single_page, + fcopy_cache, + fcompact_copy, + ) = builts def create_kv_cache(rope_mode): + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") support_sliding_window = 0 cache = fcreate( tvm.runtime.ShapeTuple( @@ -162,31 +185,29 @@ def create_kv_cache(rope_mode): num_qo_heads, num_kv_heads, head_dim, + head_dim, # v_head_dim + tvm.runtime.ShapeTuple([int(AttnKind.MHA) for _ in range(num_layers)]), + False, # enable_kv_transfer rope_mode, rope_scale, rope_theta, + None, # rope_ext_factors tvm.nd.empty((), dtype, device=device), ftranspose_append, - fattention_prefill, - fattention_decode, - fattention_prefill, - fattention_decode, - fattention_prefill_ragged, - fattention_prefill_ragged_begin_forward, - 
fattention_prefill_ragged_end_forward, - fattention_prefill_begin_forward, - fattention_prefill_end_forward, - fattention_decode_begin_forward, - fattention_decode_end_forward, - fattention_merge_state, + None, # f_transpose_append_mla + ["flashinfer", fattention_prefill_ragged, fattention_prefill_ragged_plan], + ["flashinfer", fattention_prefill, fattention_prefill_plan], + ["flashinfer", fattention_decode, fattention_decode_plan], + [], # fattn_prefill_sliding_window + [], # fattn_decode_sliding_window + [], # fattn_prefill_with_tree_mask_paged_kv_cache + [], # fattn_prefill_with_tree_mask + [], # f_mla_prefill + [fattention_merge_state], fsplit_rotary, fcopy_single_page, fcopy_cache, - None, - None, - None, - None, - False, + fcompact_copy, ) return cache @@ -206,8 +227,12 @@ def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v): keys = tvm.nd.empty(keys_expected.shape, dtype=dtype, device=device) values = tvm.nd.empty(values_expected.shape, dtype=dtype, device=device) fdebug_get_kv(kv_cache, seq_id, 0, seq_length, keys, values) - tvm.testing.assert_allclose(keys.numpy(), keys_expected, rtol=1e-3, atol=1e-3) - tvm.testing.assert_allclose(values.numpy(), values_expected, rtol=1e-3, atol=1e-3) + torch.testing.assert_close( + torch.from_numpy(keys.numpy()).to(device_torch), keys_expected, rtol=1e-3, atol=1e-3 + ) + torch.testing.assert_close( + torch.from_numpy(values.numpy()).to(device_torch), values_expected, rtol=1e-3, atol=1e-3 + ) def f_apply_rotary(x, offset, scale, theta): @@ -215,25 +240,30 @@ def f_apply_rotary(x, offset, scale, theta): assert len(x.shape) == 3 nfeat = x.shape[-1] nfeat_half = x.shape[-1] // 2 - x = x.astype("float32") - y = np.concatenate([-x[:, :, nfeat_half:], x[:, :, :nfeat_half]], axis=-1) + x_dtype = x.dtype + x = x.to(torch.float32) + y = torch.cat([-x[:, :, nfeat_half:], x[:, :, :nfeat_half]], dim=-1) - inv_freq = scale / (theta ** (np.arange(0, nfeat, 2).astype("float32") / nfeat)) - t = np.arange(offset, offset + x.shape[0], dtype=inv_freq.dtype) - freqs = np.einsum("i,j->ij", t, inv_freq) - emb = np.concatenate((freqs, freqs), axis=-1) - cos_values = np.cos(emb) - sin_values = np.sin(emb) + inv_freq = scale / ( + theta ** (torch.arange(0, nfeat, 2, device=device_torch, dtype=torch.float32) / nfeat) + ) + t = torch.arange(offset, offset + x.shape[0], device=device_torch, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos_values = torch.cos(emb) + sin_values = torch.sin(emb) - return np.einsum("ij,ikj->ikj", cos_values, x) + np.einsum("ij,ikj->ikj", sin_values, y) + return torch.einsum("ij,ikj->ikj", cos_values, x).to(x_dtype) + torch.einsum( + "ij,ikj->ikj", sin_values, y + ).to(x_dtype) def apply_attention( kv_cache, rope_mode: RopeMode, batch: List[Tuple[Union[int, Tuple[int, int, int]], int]], - cached_k: Dict[int, np.ndarray], - cached_v: Dict[int, np.ndarray], + cached_k: Dict[int, torch.Tensor], + cached_v: Dict[int, torch.Tensor], ) -> None: seq_ids = [] append_lengths = [] @@ -257,26 +287,60 @@ def apply_attention( cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos] elif seq_id not in cached_k: fadd_sequence(kv_cache, seq_id) - cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) - cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + cached_k[seq_id] = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) + cached_v[seq_id] = torch.zeros( + (num_layers, 0, num_kv_heads, 
head_dim), dtype=dtype_torch, device=device_torch + ) fbegin_forward(kv_cache, ShapeTuple(seq_ids), ShapeTuple(append_lengths)) - global_new_q = np.zeros((num_layers, 0, num_qo_heads, head_dim), dtype) - global_new_k = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) - global_new_v = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + global_new_q = torch.zeros( + (num_layers, 0, num_qo_heads, head_dim), dtype=dtype_torch, device=device_torch + ) + global_new_k = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) + global_new_v = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) q_array = [] for seq_id, append_length in batch: - new_q = np.random.rand(num_layers, append_length, num_qo_heads, head_dim).astype(dtype) - new_k = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype) - new_v = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype) + new_q = torch.rand( + num_layers, + append_length, + num_qo_heads, + head_dim, + dtype=dtype_torch, + device=device_torch, + ) + new_k = torch.rand( + num_layers, + append_length, + num_kv_heads, + head_dim, + dtype=dtype_torch, + device=device_torch, + ) + new_v = torch.rand( + num_layers, + append_length, + num_kv_heads, + head_dim, + dtype=dtype_torch, + device=device_torch, + ) + new_q = new_q * 2 - 1 + new_k = new_k * 2 - 1 + new_v = new_v * 2 - 1 q_array.append(new_q) - cached_k[seq_id] = np.concatenate( + cached_k[seq_id] = torch.cat( [ cached_k[seq_id], - np.stack( + torch.stack( [ ( new_k[l] @@ -287,26 +351,26 @@ def apply_attention( ) for l in range(num_layers) ], - axis=0, + dim=0, ), ], - axis=1, + dim=1, ) - cached_v[seq_id] = np.concatenate([cached_v[seq_id], new_v], axis=1) - global_new_q = np.concatenate([global_new_q, new_q], axis=1) - global_new_k = np.concatenate([global_new_k, new_k], axis=1) - global_new_v = np.concatenate([global_new_v, new_v], axis=1) + cached_v[seq_id] = torch.cat([cached_v[seq_id], new_v], dim=1) + global_new_q = torch.cat([global_new_q, new_q], dim=1) + global_new_k = torch.cat([global_new_k, new_k], dim=1) + global_new_v = torch.cat([global_new_v, new_v], dim=1) for layer_id in range(num_layers): queries_np = global_new_q[layer_id] keys_np = global_new_k[layer_id] values_np = global_new_v[layer_id] - qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device) + qkv = tvm.nd.array(torch.cat([queries_np, keys_np, values_np], dim=1).cpu().numpy(), device) outputs = tvm.nd.empty(queries_np.shape, dtype, device=device) - fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs) + fattention_with_fuse_qkv(kv_cache, layer_id, sm_scale, qkv, outputs) # Compute attention expected results. 
- outputs = np.expand_dims(outputs.numpy(), axis=0) + outputs = torch.from_numpy(outputs.numpy()).unsqueeze(0).to(device_torch) sum_length = 0 for i, (seq_id, append_length) in enumerate(batch): assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length @@ -321,33 +385,36 @@ def apply_attention( rope_scale, rope_theta, ) - ).transpose(1, 0, 2) + ).permute(1, 0, 2) k_seq = ( cached_k[seq_id][layer_id] if rope_mode != RopeMode.INLINE else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta) - ).transpose(1, 2, 0) - v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2) + ).permute(1, 2, 0) + v_seq = cached_v[seq_id][layer_id].permute(1, 0, 2) - k_seq = np.repeat(k_seq, num_qo_heads // num_kv_heads, axis=0) - v_seq = np.repeat(v_seq, num_qo_heads // num_kv_heads, axis=0) - softmax_input = (q_seq.astype("float32") @ k_seq.astype("float32")) / np.sqrt(head_dim) + k_seq = k_seq.repeat_interleave(num_qo_heads // num_kv_heads, dim=0) + v_seq = v_seq.repeat_interleave(num_qo_heads // num_kv_heads, dim=0) + softmax_input = (q_seq.to(torch.float32) @ k_seq.to(torch.float32)) / (head_dim**0.5) softmax_shape = softmax_input.shape length_diff = softmax_shape[-1] - softmax_shape[-2] assert length_diff >= 0 - mask = np.tril( - np.full_like(softmax_input, np.finfo("float32").max), k=length_diff - ) + np.triu(np.full_like(softmax_input, np.finfo("float32").min), k=length_diff + 1) - softmax_input = np.minimum(softmax_input, mask) - - results = np.expand_dims( - (scipy.special.softmax(softmax_input, axis=-1) @ v_seq.astype("float32")).transpose( - 1, 0, 2 - ), - axis=0, - ).astype(dtype) + mask = torch.tril( + torch.full_like(softmax_input, torch.finfo(torch.float32).max), diagonal=length_diff + ) + torch.triu( + torch.full_like(softmax_input, torch.finfo(torch.float32).min), + diagonal=length_diff + 1, + ) + softmax_input = torch.minimum(softmax_input, mask) + + results = torch.unsqueeze( + ( + torch.nn.functional.softmax(softmax_input, dim=-1) @ v_seq.to(torch.float32) + ).permute(1, 0, 2), + dim=0, + ).to(dtype_torch) - tvm.testing.assert_allclose( + torch.testing.assert_close( outputs[:, sum_length : sum_length + append_length, ...], results, rtol=1e-3, diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py new file mode 100644 index 000000000000..2aeb1b158bf4 --- /dev/null +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py @@ -0,0 +1,593 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
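+"""Tests for PagedKVCache with MLA attention backed by FlashInfer JIT kernels (skipped unless FlashInfer support is enabled)."""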
+import itertools +from typing import Dict, List, Tuple, Union + +import numpy as np +import pytest +import torch + +import tvm +import tvm.testing +from tvm import dlight as dl +from tvm import relax +from tvm.contrib import utils +from tvm.relax.frontend.nn.llm.kv_cache import ( + AttnKind, + RopeMode, + _copy_single_page_mla, + _kv_cache_debug_get_kv_mla, + _kv_cache_transpose_append_mla, + _merge_state_inplace, +) +from tvm.runtime import ShapeTuple + +np.random.seed(0) + +reserved_nseq = 32 +maximum_total_seq_length = 2048 +prefill_chunk_size = 512 +page_size = 16 +num_layers = 4 +num_attention_heads = 128 +qk_nope_head_dim = 128 +qk_rope_head_dim = 64 +v_head_dim = qk_nope_head_dim +sm_scale = (qk_nope_head_dim + qk_rope_head_dim) ** (-0.5) +kv_lora_rank = 512 +dtype = "float16" +dtype_torch = getattr(torch, dtype) +device = tvm.cuda() +device_torch = torch.device("cuda") + +fclear = None +fadd_sequence = None +fremove_sequence = None +ffork_sequence = None +fpopn = None +fbegin_forward = None +fend_forward = None +fself_attn = None +fcross_attn = None +fappend_mla_kv = None +fkv_merge_attn_output = None +fis_empty = None +fdebug_get_kv = None + +ftranspose_append = None +fcopy_cache = None +fmla_prefill = None +fmla_prefill_plan = None +fattn_prefill_ragged = None +fattn_prefill_ragged_plan = None +fmerge_state = None +fmerge_state_additional = None +fcopy_single_page = None + +w_kv = None +w_uk = None +w_uv = None + + +# Register a dumb function for testing purpose. +@tvm.register_func("test.dumb_function", override=True) +def _dumb_function(): + raise RuntimeError("Dumb function isn't supposed to be accessed.") + + +def set_global_func(dtype): + global fclear, fadd_sequence, fremove_sequence, ffork_sequence + global fpopn, fbegin_forward, fend_forward + global fself_attn, fcross_attn, fappend_mla_kv, fkv_merge_attn_output + global fis_empty, fdebug_get_kv + global ftranspose_append, fcopy_cache, fmla_prefill, fmla_prefill_plan + global fattn_prefill_ragged, fattn_prefill_ragged_plan + global fmerge_state, fmerge_state_additional, fcopy_single_page + global w_kv, w_uk, w_uv + + fclear = tvm.get_global_func("vm.builtin.kv_state_clear") + fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence") + fremove_sequence = tvm.get_global_func("vm.builtin.kv_state_remove_sequence") + ffork_sequence = tvm.get_global_func("vm.builtin.kv_state_fork_sequence") + fpopn = tvm.get_global_func("vm.builtin.kv_state_popn") + fbegin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward") + fend_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward") + fself_attn = tvm.get_global_func("vm.builtin.attention_kv_cache_self_attention") + fcross_attn = tvm.get_global_func("vm.builtin.attention_kv_cache_cross_attention") + fappend_mla_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_append_mla_kv") + fkv_merge_attn_output = tvm.get_global_func( + "vm.builtin.attention_kv_cache_merge_attn_output_inplace" + ) + fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty") + fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv_mla") + + def load_module(name: str, static_modules: List[tvm.runtime.Module]): + assert len(static_modules) > 0 + if len(static_modules) == 1: + return static_modules[0] + static_mod = static_modules[0] + for mod in static_modules[1:]: + static_mod.import_module(mod) + temp = utils.tempdir() + mod_path = temp.relpath(f"{name}.so") + static_mod.export_library(mod_path) + return tvm.runtime.load_module(mod_path) + 
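+ # JIT-compile the FlashInfer prefill and MLA attention kernel modules for the current CUDA target.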
+ target = tvm.target.Target.from_device(device) + flashinfer_prefill_mod = load_module( + "flashinfer_prefill", + relax.backend.cuda.flashinfer.gen_flashinfer_prefill_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + target=target, + enable_inline_rope=False, + ), + ) + flashinfer_mla_mod = load_module( + "flashinfer_mla", + relax.backend.cuda.flashinfer.gen_flashinfer_mla_module( + dtype_q=dtype, + dtype_kv=dtype, + dtype_o=dtype, + head_dim_ckv=kv_lora_rank, + head_dim_kpe=qk_rope_head_dim, + target=target, + ), + ) + + fattn_prefill_ragged = flashinfer_prefill_mod["batch_prefill_with_ragged_kv_cache_run"] + fattn_prefill_ragged_plan = flashinfer_prefill_mod["batch_prefill_with_kv_cache_plan"] + fmla_prefill = flashinfer_mla_mod["batch_mla_paged_attention_run"] + fmla_prefill_plan = flashinfer_mla_mod["batch_mla_paged_attention_plan"] + + builts = [] + for tir_func in [ + _kv_cache_transpose_append_mla(kv_lora_rank + qk_rope_head_dim, dtype), + _kv_cache_debug_get_kv_mla(num_layers, kv_lora_rank + qk_rope_head_dim, dtype), + _merge_state_inplace(num_attention_heads, kv_lora_rank, dtype, target), + _merge_state_inplace(num_attention_heads, v_head_dim, dtype, target), + _copy_single_page_mla(page_size, kv_lora_rank + qk_rope_head_dim, dtype, target), + ]: + mod = tvm.IRModule({"main": tir_func}) + with target: + mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) + f = tvm.build(mod["main"], target=target) + builts.append(f.entry_func) + + ( + ftranspose_append, + fcopy_cache, + fmerge_state, + fmerge_state_additional, + fcopy_single_page, + ) = builts + + w_kv = torch.empty( + (kv_lora_rank, num_attention_heads * (qk_nope_head_dim + v_head_dim)), + device=device_torch, + dtype=dtype_torch, + ) + w_kv.uniform_(-0.1, 0.1) + w_uk, w_uv = torch.split( + w_kv.view(kv_lora_rank, num_attention_heads, qk_nope_head_dim + v_head_dim), + [qk_nope_head_dim, v_head_dim], + dim=2, + ) + w_uk = w_uk.permute(1, 2, 0) + w_uv = w_uv.permute(1, 0, 2) + + +def create_kv_cache(dtype): + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") + fdumb = tvm.get_global_func("test.dumb_function") + cache = fcreate( + tvm.runtime.ShapeTuple( + [ + reserved_nseq, + maximum_total_seq_length, + prefill_chunk_size, + page_size, + 0, + ] + ), + tvm.runtime.ShapeTuple([0, num_layers]), + num_attention_heads, + 1, # num_kv_heads + kv_lora_rank + qk_rope_head_dim, + kv_lora_rank, + tvm.runtime.ShapeTuple([int(AttnKind.MLA) for _ in range(num_layers)]), + False, # enable_kv_transfer + RopeMode.NONE, + 1, + 10000, + None, # rope_ext_factors + tvm.nd.empty((), dtype, device=device), + None, # f_transpose_append_mha + ftranspose_append, + ["flashinfer", fattn_prefill_ragged, fattn_prefill_ragged_plan], # fattn_prefill_ragged + [], # fattn_prefill + [], # fattn_decode + [], # fattn_prefill_sliding_window + [], # fattn_decode_sliding_window + [], # fattn_prefill_with_tree_mask_paged_kv_cache + [], # fattn_prefill_with_tree_mask + ["flashinfer", fmla_prefill, fmla_prefill_plan], + [fmerge_state, fmerge_state_additional], + fdumb, # fsplit_rotary + fcopy_single_page, + fcopy_cache, + fdumb, # fcompact_copy + ) + return cache + + +@pytest.fixture(params=itertools.product(["float16"])) +def kv_cache_and_config(request): + global dtype, dtype_torch + (dtype,) = request.param + dtype_torch = getattr(torch, dtype) + set_global_func(dtype) + return (create_kv_cache(dtype),) + + +def verify_cached_kv(kv_cache, seq_ids, 
expected_kv): + for seq_id in seq_ids: + kv_expected = expected_kv[seq_id] + seq_length = expected_kv[seq_id].shape[1] + kv_actual = tvm.nd.empty(kv_expected.shape, dtype=dtype, device=device) + fdebug_get_kv(kv_cache, seq_id, 0, seq_length, kv_actual) + torch.testing.assert_close( + torch.from_numpy(kv_actual.numpy()).to(device_torch), kv_expected, rtol=1e-3, atol=1e-3 + ) + + +def apply_attention( + kv_cache, + batch: List[Tuple[Union[int, Tuple[int, int, int]], int]], + cached_kv: Dict[int, torch.Tensor], +) -> None: + seq_ids = [] + append_lengths = [] + for i, (seq_id, append_length) in enumerate(batch): + fork_parent_id = None + if isinstance(seq_id, tuple): + # Fork sequence + seq_id, fork_parent_id, fork_pos = seq_id + batch[i] = (seq_id, append_length) + seq_ids.append(seq_id) + append_lengths.append(append_length) + if fork_parent_id is not None: + assert fork_parent_id in cached_kv + assert seq_id not in cached_kv + ffork_sequence(kv_cache, fork_parent_id, seq_id, fork_pos) + if fork_pos == -1: + cached_kv[seq_id] = cached_kv[fork_parent_id] + else: + cached_kv[seq_id] = cached_kv[fork_parent_id][::, :fork_pos] + elif seq_id not in cached_kv: + fadd_sequence(kv_cache, seq_id) + cached_kv[seq_id] = torch.zeros( + (num_layers, 0, kv_lora_rank + qk_rope_head_dim), + dtype=dtype_torch, + device=device_torch, + ) + + fbegin_forward(kv_cache, ShapeTuple(seq_ids), ShapeTuple(append_lengths), None) + + global_new_q = torch.zeros( + (num_layers, 0, num_attention_heads, qk_nope_head_dim + qk_rope_head_dim), + dtype=dtype_torch, + device=device_torch, + ) + global_new_kv = torch.zeros( + (num_layers, 0, kv_lora_rank + qk_rope_head_dim), + dtype=dtype_torch, + device=device_torch, + ) + + q_array = [] + is_decode_request = True + all_new_sequences = True + for i, (seq_id, append_length) in enumerate(batch): + new_q_np = np.random.uniform( + -0.1, + 0.1, + size=( + num_layers, + append_length, + num_attention_heads, + qk_nope_head_dim + qk_rope_head_dim, + ), + ).astype(dtype) + new_kv_np = np.random.uniform( + -0.1, 0.1, size=(num_layers, append_length, kv_lora_rank + qk_rope_head_dim) + ).astype(dtype) + q_array.append(new_q_np) + + # Convert the numpy arrays to torch tensors on device. + new_q_tensor = torch.from_numpy(new_q_np).to(device_torch) + new_kv_tensor = torch.from_numpy(new_kv_np).to(device_torch) + + all_new_sequences = all_new_sequences and cached_kv[seq_id].shape[1] == 0 + cached_kv[seq_id] = torch.cat([cached_kv[seq_id], new_kv_tensor], dim=1) + global_new_q = torch.cat([global_new_q, new_q_tensor], dim=1) + global_new_kv = torch.cat([global_new_kv, new_kv_tensor], dim=1) + + if append_length > 1: + is_decode_request = False + + for layer_id in range(num_layers): + queries = tvm.nd.array(global_new_q[layer_id].cpu().numpy(), device) + key_value = tvm.nd.array(global_new_kv[layer_id].cpu().numpy(), device) + total_seq_length = global_new_q[layer_id].shape[0] + outputs1 = tvm.nd.empty( + (total_seq_length, num_attention_heads, v_head_dim), dtype, device=device + ) + lse1 = tvm.nd.empty((total_seq_length, num_attention_heads), "float32", device=device) + outputs2 = tvm.nd.empty( + (total_seq_length, num_attention_heads, kv_lora_rank), dtype, device=device + ) + lse2 = tvm.nd.empty((total_seq_length, num_attention_heads), "float32", device=device) + + fappend_mla_kv(kv_cache, layer_id, key_value) + if not is_decode_request: + # Part 1. 
self-attention + latent, k_pe = torch.split( + global_new_kv[layer_id], [kv_lora_rank, qk_rope_head_dim], dim=1 + ) + keys, values = torch.split( + (latent @ w_kv).to(dtype_torch).reshape(total_seq_length, num_attention_heads, -1), + [qk_nope_head_dim, v_head_dim], + dim=2, + ) + k_pe_expanded = torch.unsqueeze(k_pe, 1).expand( + total_seq_length, num_attention_heads, qk_rope_head_dim + ) + keys = torch.cat([keys, k_pe_expanded], dim=2) + keys_tvm = tvm.nd.array(keys.cpu().numpy(), device) + values_tvm = tvm.nd.array(values.cpu().numpy(), device) + fself_attn(kv_cache, layer_id, sm_scale, queries, keys_tvm, values_tvm, outputs1, lse1) + + if not all_new_sequences or is_decode_request: + # Part 2. cross-attention + queries_lora_np, q_pe = torch.split( + global_new_q[layer_id], [qk_nope_head_dim, qk_rope_head_dim], dim=2 + ) + queries_lora_np = torch.cat( + [torch.bmm(queries_lora_np.permute(1, 0, 2), w_uk).permute(1, 0, 2), q_pe], dim=2 + ) + queries_lora = tvm.nd.array(queries_lora_np.cpu().numpy(), device) + fcross_attn(kv_cache, layer_id, sm_scale, queries_lora, outputs2, lse2) + cross_attn_output = tvm.nd.array( + torch.bmm( + torch.from_numpy(outputs2.numpy()).to(device_torch).permute(1, 0, 2), w_uv + ) + .permute(1, 0, 2) + .cpu() + .numpy(), + device, + ) + + if not is_decode_request: + if not all_new_sequences: + fkv_merge_attn_output(kv_cache, outputs1, lse1, cross_attn_output, lse2) + else: + outputs1 = cross_attn_output + + # Compute attention expected results. + outputs = torch.unsqueeze(torch.tensor(outputs1.numpy()).to(device_torch), 0) + sum_length = 0 + for i, (seq_id, append_length) in enumerate(batch): + assert cached_kv[seq_id].shape[1] >= append_length + + q_seq = torch.from_numpy(q_array[i][layer_id]).to(device_torch).permute(1, 0, 2) + latent_seq, k_pe_seq = torch.split( + torch.unsqueeze(cached_kv[seq_id][layer_id], 1), + [kv_lora_rank, qk_rope_head_dim], + dim=2, + ) + k_seq, v_seq = torch.split( + (latent_seq @ w_kv).reshape(k_pe_seq.shape[0], num_attention_heads, -1), + [qk_nope_head_dim, v_head_dim], + dim=2, + ) + k_pe_seq = k_pe_seq.expand(k_pe_seq.shape[0], num_attention_heads, qk_rope_head_dim) + k_seq = torch.cat([k_seq, k_pe_seq], dim=2).permute(1, 2, 0) + v_seq = v_seq.permute(1, 0, 2) + + softmax_input = (q_seq.to(torch.float32) @ k_seq.to(torch.float32)) / torch.sqrt( + torch.tensor(qk_nope_head_dim + qk_rope_head_dim, dtype=torch.float32) + ) + softmax_shape = softmax_input.shape + assert softmax_shape[-2] == append_length + length_diff = softmax_shape[-1] - softmax_shape[-2] + assert length_diff >= 0 + # Create a mask similar to np.tril and np.triu. + mask = torch.tril( + torch.full_like(softmax_input, float(np.finfo(np.float32).max)), + diagonal=length_diff, + ) + torch.triu( + torch.full_like(softmax_input, float(np.finfo(np.float32).min)), + diagonal=length_diff + 1, + ) + softmax_input = torch.minimum(softmax_input, mask) + + results = torch.nn.functional.softmax(softmax_input, dim=-1) @ v_seq.to(torch.float32) + results = results.permute(1, 0, 2).unsqueeze(0).to(dtype_torch) + + torch.testing.assert_close( + outputs[:, sum_length : sum_length + append_length, ...], + results, + rtol=1e-3, + atol=1e-3, + ) + sum_length += append_length + fend_forward(kv_cache) + + # Verify + verify_cached_kv(kv_cache, seq_ids, cached_kv) + + +@pytest.mark.skip(reason="Require FlashInfer enabled") +def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config): + (kv_cache,) = kv_cache_and_config + fclear(kv_cache) + + # Prefill. 
+ operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 20)]] + operation_seq += [[(6, 21), (7, 24)], [(2, 5), (4, 7), (8, 24)]] + operation_seq += [[(6, 13)], [(8, 19)], [(0, 1)], [(1, 3), (3, 8), (5, 12), (7, 11)]] + # Decode + operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]] + operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]] + operation_seq += [[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1)]] + operation_seq += [[(4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]] + + cached_kv = {} + for batch in operation_seq: + apply_attention(kv_cache, batch, cached_kv) + + +@pytest.mark.skip(reason="Require FlashInfer enabled") +def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config): + (kv_cache,) = kv_cache_and_config + fclear(kv_cache) + + num_sequences = 5 + batch = [(seq_id, 1) for seq_id in range(num_sequences)] + cached_kv = {} + for seq_id_to_remove in range(num_sequences): + apply_attention(kv_cache, batch, cached_kv) + # Remove sequence. + fremove_sequence(kv_cache, seq_id_to_remove) + cached_kv.pop(seq_id_to_remove) + verify_cached_kv( + kv_cache, + seq_ids=[seq_id for seq_id in range(num_sequences) if seq_id != seq_id_to_remove], + expected_kv=cached_kv, + ) + + +@pytest.mark.skip(reason="Require FlashInfer enabled") +def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config): + (kv_cache,) = kv_cache_and_config + fclear(kv_cache) + + cached_kv = {} + batch = [(0, 60), (1, 88), (2, 17), (3, 4)] + apply_attention(kv_cache, batch, cached_kv) + # Fork existing sequences. + apply_attention(kv_cache, [((4, 3, -1), 35)], cached_kv) + apply_attention(kv_cache, [((5, 0, -1), 20)], cached_kv) + apply_attention(kv_cache, [((6, 5, -1), 102)], cached_kv) + apply_attention(kv_cache, [((7, 0, -1), 3)], cached_kv) + apply_attention(kv_cache, [((8, 5, -1), 71), ((9, 5, -1), 20)], cached_kv) + # 0 <- 5 <- 6,8,9 + # 0 <- 7 + # 3 <- 4 + # Mixture of decode and prefill. 
+ operation_seq = [ + [(2, 1), (4, 1), (7, 1), (6, 1), (8, 1), (9, 1)], + [(7, 1), (6, 1), (8, 1), (9, 1)], + [(7, 1), (1, 1), (6, 1), (2, 1), (8, 1), (4, 1), (9, 1)], + [(7, 10), (6, 2), (8, 3), (9, 4)], + ] + for batch in operation_seq: + apply_attention(kv_cache, batch, cached_kv) + + apply_attention(kv_cache, [((10, 1, 33), 11)], cached_kv) + apply_attention(kv_cache, [((11, 0, 60), 45), ((12, 0, 15), 14)], cached_kv) + apply_attention(kv_cache, [((13, 0, 16), 19), ((14, 0, 17), 19)], cached_kv) + apply_attention(kv_cache, [((15, 5, 60), 8), ((16, 5, 80), 10)], cached_kv) + apply_attention( + kv_cache, + [((17, 5, 75), 11), ((18, 5, 76), 45), ((19, 5, 77), 14)], + cached_kv, + ) + + operation_seq = [ + [(6, 1), (11, 1), (13, 1), (9, 1)], + [(10, 1), (16, 1), (18, 1), (19, 1)], + [(8, 1), (15, 1), (17, 1), (12, 1), (14, 1)], + [(10, 10), (6, 2), (8, 3), (19, 4)], + ] + for batch in operation_seq: + apply_attention(kv_cache, batch, cached_kv) + + num_sequence = 20 + for i in range(num_sequence): + fremove_sequence(kv_cache, i) + cached_kv.pop(i) + verify_cached_kv( + kv_cache, + seq_ids=list(range(i + 1, num_sequence)), + expected_kv=cached_kv, + ) + + assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences" + + # Test fork after page recycle + apply_attention(kv_cache, [(0, 7), (1, 24)], cached_kv) + apply_attention(kv_cache, [((2, 1, -1), 10)], cached_kv) + apply_attention(kv_cache, [((3, 0, -1), 20)], cached_kv) + apply_attention(kv_cache, [(2, 1), (3, 1)], cached_kv) + + apply_attention(kv_cache, [(10, 7), (11, 24)], cached_kv) + apply_attention(kv_cache, [((12, 11, -1), 200)], cached_kv) + apply_attention(kv_cache, [(10, 1), (12, 1)], cached_kv) + + +@pytest.mark.skip(reason="Require FlashInfer enabled") +def test_paged_attention_kv_cache_popn(kv_cache_and_config): + (kv_cache,) = kv_cache_and_config + fclear(kv_cache) + + cached_kv = {} + batch = [(0, 35), (1, 88), (2, 17), (3, 4)] + apply_attention(kv_cache, batch, cached_kv) + apply_attention(kv_cache, [((4, 3, -1), 35)], cached_kv) + + popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 37)] + for seq_id, pop_length in popn_operations: + fpopn(kv_cache, seq_id, pop_length) + if pop_length != 0: + cached_kv[seq_id] = cached_kv[seq_id][:, :-pop_length, ...] 
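+ # The host-side reference cache mirrors each pop; the device cache is verified against it below.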
+ verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_kv=cached_kv) + + num_sequence = 5 + for seq_id in range(num_sequence): + fremove_sequence(kv_cache, seq_id) + verify_cached_kv( + kv_cache, + seq_ids=list(range(seq_id + 1, num_sequence)), + expected_kv=cached_kv, + ) + + assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences" + + +if __name__ == "__main__": + set_global_func(dtype) + cache = create_kv_cache(dtype) + cache_and_config = (cache,) + test_paged_attention_kv_cache_prefill_and_decode(cache_and_config) + test_paged_attention_kv_cache_remove_sequence(cache_and_config) + test_paged_attention_kv_cache_fork_sequence(cache_and_config) + test_paged_attention_kv_cache_popn(cache_and_config) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py index 72a45b8a4cf3..bee4cfe1a3cf 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py @@ -19,7 +19,7 @@ import numpy as np import pytest -import scipy.special +import torch import tvm import tvm.testing @@ -27,9 +27,8 @@ from tvm.relax.frontend.nn.llm.kv_cache import ( AttnKind, RopeMode, - _attention_decode_mla, _attention_prefill_mla, - _attention_prefill_ragged_mla_absorbed, + _attention_prefill_ragged, _copy_single_page_mla, _kv_cache_debug_get_kv_mla, _kv_cache_transpose_append_mla, @@ -45,9 +44,13 @@ num_attention_heads = 128 qk_nope_head_dim = 128 qk_rope_head_dim = 64 +v_head_dim = qk_nope_head_dim +sm_scale = (qk_nope_head_dim + qk_rope_head_dim) ** (-0.5) kv_lora_rank = 512 -dtype = None +dtype = "float16" +dtype_torch = getattr(torch, dtype) device = tvm.cuda() +device_torch = torch.device("cuda") fclear = None fadd_sequence = None @@ -56,32 +59,39 @@ fpopn = None fbegin_forward = None fend_forward = None -fmla_absorbed = None +fself_attn = None +fcross_attn = None +fappend_mla_kv = None +fkv_merge_attn_output = None fis_empty = None fdebug_get_kv = None ftranspose_append = None fcopy_cache = None -fattn_prefill = None -fattn_decode = None -fattn_prefill_ragged_absorbed = None +fmla_prefill = None +fmla_prefill_ragged = None fmerge_state = None fcopy_single_page = None +w_kv = None +w_uk = None +w_uv = None + # Register a dumb function for testing purpose. 
-@tvm.register_func("test.dumb_function") +@tvm.register_func("test.dumb_function", override=True) def _dumb_function(): - pass + raise RuntimeError("Dumb function isn't supposed to be accessed.") def set_global_func(dtype): global fclear, fadd_sequence, fremove_sequence, ffork_sequence global fpopn, fbegin_forward, fend_forward - global fmla_absorbed, fis_empty, fdebug_get_kv - global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode - global fattn_prefill_ragged_absorbed - global fmerge_state, fcopy_single_page + global fself_attn, fcross_attn, fappend_mla_kv, fkv_merge_attn_output + global fis_empty, fdebug_get_kv + global ftranspose_append, fcopy_cache, fmla_prefill, fmla_prefill_ragged + global fmerge_state, fmerge_state_additional, fcopy_single_page + global w_kv, w_uk, w_uv fclear = tvm.get_global_func("vm.builtin.kv_state_clear") fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence") @@ -90,25 +100,34 @@ def set_global_func(dtype): fpopn = tvm.get_global_func("vm.builtin.kv_state_popn") fbegin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward") fend_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward") - fmla_absorbed = tvm.get_global_func("vm.builtin.attention_kv_cache_mla_absorbed") + fself_attn = tvm.get_global_func("vm.builtin.attention_kv_cache_self_attention") + fcross_attn = tvm.get_global_func("vm.builtin.attention_kv_cache_cross_attention") + fappend_mla_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_append_mla_kv") + fkv_merge_attn_output = tvm.get_global_func( + "vm.builtin.attention_kv_cache_merge_attn_output_inplace" + ) fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty") fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv_mla") target = tvm.target.Target.from_device(device) builts = [] for tir_func in [ - _kv_cache_transpose_append_mla(kv_lora_rank, qk_rope_head_dim, dtype), + _kv_cache_transpose_append_mla(kv_lora_rank + qk_rope_head_dim, dtype), _kv_cache_debug_get_kv_mla(num_layers, kv_lora_rank + qk_rope_head_dim, dtype), _attention_prefill_mla( num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target ), - _attention_decode_mla( - num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target - ), - _attention_prefill_ragged_mla_absorbed( - num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, target + _attention_prefill_ragged( + num_attention_heads, + num_attention_heads, + qk_nope_head_dim + qk_rope_head_dim, + v_head_dim, + dtype, + {}, + target, ), _merge_state_inplace(num_attention_heads, kv_lora_rank, dtype, target), + _merge_state_inplace(num_attention_heads, v_head_dim, dtype, target), _copy_single_page_mla(page_size, kv_lora_rank + qk_rope_head_dim, dtype, target), ]: mod = tvm.IRModule({"main": tir_func}) @@ -120,16 +139,30 @@ def set_global_func(dtype): ( ftranspose_append, fcopy_cache, - fattn_prefill, - fattn_decode, - fattn_prefill_ragged_absorbed, + fmla_prefill, + fmla_prefill_ragged, fmerge_state, + fmerge_state_additional, fcopy_single_page, ) = builts + w_kv = torch.empty( + (kv_lora_rank, num_attention_heads * (qk_nope_head_dim + v_head_dim)), + device=device_torch, + dtype=dtype_torch, + ) + w_kv.uniform_(-0.1, 0.1) + w_uk, w_uv = torch.split( + w_kv.view(kv_lora_rank, num_attention_heads, qk_nope_head_dim + v_head_dim), + [qk_nope_head_dim, v_head_dim], + dim=2, + ) + w_uk = w_uk.permute(1, 2, 0) + w_uv = w_uv.permute(1, 0, 2) + def create_kv_cache(dtype): - fcreate = 
tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced_mla") + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") fdumb = tvm.get_global_func("test.dumb_function") cache = fcreate( tvm.runtime.ShapeTuple( @@ -143,49 +176,40 @@ def create_kv_cache(dtype): ), tvm.runtime.ShapeTuple([0, num_layers]), num_attention_heads, - 1, + 1, # num_kv_heads kv_lora_rank + qk_rope_head_dim, kv_lora_rank, - qk_rope_head_dim, tvm.runtime.ShapeTuple([int(AttnKind.MLA) for _ in range(num_layers)]), + False, # enable_kv_transfer RopeMode.NONE, 1, 10000, + None, # rope_ext_factors tvm.nd.empty((), dtype, device=device), - fdumb, + None, # f_transpose_append_mha ftranspose_append, - fdumb, - fdumb, - fdumb, - fdumb, - fdumb, - 0, - 0, - 0, - 0, - 0, - 0, - fattn_prefill, - fattn_decode, - fdumb, - fattn_prefill_ragged_absorbed, - fmerge_state, - fdumb, + ["tir", fmla_prefill_ragged], # fattn_prefill_ragged + [], # fattn_prefill + [], # fattn_decode + [], # fattn_prefill_sliding_window + [], # fattn_decode_sliding_window + [], # fattn_prefill_with_tree_mask_paged_kv_cache + [], # fattn_prefill_with_tree_mask + ["tir", fmla_prefill], + [fmerge_state, fmerge_state_additional], + fdumb, # fsplit_rotary fcopy_single_page, fcopy_cache, - fdumb, - fdumb, - fdumb, - None, - False, + fdumb, # fcompact_copy ) return cache @pytest.fixture(params=itertools.product(["float16"])) def kv_cache_and_config(request): - global dtype + global dtype, dtype_torch (dtype,) = request.param + dtype_torch = getattr(torch, dtype) set_global_func(dtype) return (create_kv_cache(dtype),) @@ -196,13 +220,15 @@ def verify_cached_kv(kv_cache, seq_ids, expected_kv): seq_length = expected_kv[seq_id].shape[1] kv_actual = tvm.nd.empty(kv_expected.shape, dtype=dtype, device=device) fdebug_get_kv(kv_cache, seq_id, 0, seq_length, kv_actual) - tvm.testing.assert_allclose(kv_actual.numpy(), kv_expected, rtol=1e-3, atol=1e-3) + torch.testing.assert_close( + torch.from_numpy(kv_actual.numpy()).to(device_torch), kv_expected, rtol=1e-3, atol=1e-3 + ) def apply_attention( kv_cache, batch: List[Tuple[Union[int, Tuple[int, int, int]], int]], - cached_kv: Dict[int, np.ndarray], + cached_kv: Dict[int, torch.Tensor], ) -> None: seq_ids = [] append_lengths = [] @@ -224,72 +250,156 @@ def apply_attention( cached_kv[seq_id] = cached_kv[fork_parent_id][::, :fork_pos] elif seq_id not in cached_kv: fadd_sequence(kv_cache, seq_id) - cached_kv[seq_id] = np.zeros((num_layers, 0, kv_lora_rank + qk_rope_head_dim), dtype) + cached_kv[seq_id] = torch.zeros( + (num_layers, 0, kv_lora_rank + qk_rope_head_dim), + dtype=dtype_torch, + device=device_torch, + ) fbegin_forward(kv_cache, ShapeTuple(seq_ids), ShapeTuple(append_lengths), None) - global_new_q = np.zeros( - (num_layers, 0, num_attention_heads, kv_lora_rank + qk_rope_head_dim), dtype + global_new_q = torch.zeros( + (num_layers, 0, num_attention_heads, qk_nope_head_dim + qk_rope_head_dim), + dtype=dtype_torch, + device=device_torch, + ) + global_new_kv = torch.zeros( + (num_layers, 0, kv_lora_rank + qk_rope_head_dim), + dtype=dtype_torch, + device=device_torch, ) - global_new_kv = np.zeros((num_layers, 0, kv_lora_rank + qk_rope_head_dim), dtype) q_array = [] + is_decode_request = True + all_new_sequences = True for i, (seq_id, append_length) in enumerate(batch): - new_q = np.random.rand( - num_layers, append_length, num_attention_heads, kv_lora_rank + qk_rope_head_dim + new_q_np = np.random.uniform( + -0.1, + 0.1, + size=( + num_layers, + append_length, + num_attention_heads, + 
qk_nope_head_dim + qk_rope_head_dim, + ), ).astype(dtype) - new_kv = np.random.rand(num_layers, append_length, kv_lora_rank + qk_rope_head_dim).astype( - dtype - ) - q_array.append(new_q) + new_kv_np = np.random.uniform( + -0.1, 0.1, size=(num_layers, append_length, kv_lora_rank + qk_rope_head_dim) + ).astype(dtype) + q_array.append(new_q_np) + + # Convert the numpy arrays to torch tensors on device. + new_q_tensor = torch.from_numpy(new_q_np).to(device_torch) + new_kv_tensor = torch.from_numpy(new_kv_np).to(device_torch) + + all_new_sequences = all_new_sequences and cached_kv[seq_id].shape[1] == 0 + cached_kv[seq_id] = torch.cat([cached_kv[seq_id], new_kv_tensor], dim=1) + global_new_q = torch.cat([global_new_q, new_q_tensor], dim=1) + global_new_kv = torch.cat([global_new_kv, new_kv_tensor], dim=1) - cached_kv[seq_id] = np.concatenate([cached_kv[seq_id], new_kv], axis=1) - global_new_q = np.concatenate([global_new_q, new_q], axis=1) - global_new_kv = np.concatenate([global_new_kv, new_kv], axis=1) + if append_length > 1: + is_decode_request = False for layer_id in range(num_layers): - queries_np = global_new_q[layer_id] - queries = tvm.nd.array(queries_np, device) - compressed_kv = tvm.nd.array(global_new_kv[layer_id][:, :kv_lora_rank], device) - k_pe = tvm.nd.array(global_new_kv[layer_id][:, kv_lora_rank:], device) - outputs = tvm.nd.empty( - (queries_np.shape[0], queries_np.shape[1], kv_lora_rank), dtype, device=device + queries = tvm.nd.array(global_new_q[layer_id].cpu().numpy(), device) + key_value = tvm.nd.array(global_new_kv[layer_id].cpu().numpy(), device) + total_seq_length = global_new_q[layer_id].shape[0] + outputs1 = tvm.nd.empty( + (total_seq_length, num_attention_heads, v_head_dim), dtype, device=device ) - fmla_absorbed(kv_cache, layer_id, 1.0, queries, compressed_kv, k_pe, outputs) + lse1 = tvm.nd.empty((total_seq_length, num_attention_heads), "float32", device=device) + outputs2 = tvm.nd.empty( + (total_seq_length, num_attention_heads, kv_lora_rank), dtype, device=device + ) + lse2 = tvm.nd.empty((total_seq_length, num_attention_heads), "float32", device=device) + + fappend_mla_kv(kv_cache, layer_id, key_value) + if not is_decode_request: + # Part 1. self-attention + latent, k_pe = torch.split( + global_new_kv[layer_id], [kv_lora_rank, qk_rope_head_dim], dim=1 + ) + keys, values = torch.split( + (latent @ w_kv).to(dtype_torch).reshape(total_seq_length, num_attention_heads, -1), + [qk_nope_head_dim, v_head_dim], + dim=2, + ) + k_pe_expanded = torch.unsqueeze(k_pe, 1).expand( + total_seq_length, num_attention_heads, qk_rope_head_dim + ) + keys = torch.cat([keys, k_pe_expanded], dim=2) + keys_tvm = tvm.nd.array(keys.cpu().numpy(), device) + values_tvm = tvm.nd.array(values.cpu().numpy(), device) + fself_attn(kv_cache, layer_id, sm_scale, queries, keys_tvm, values_tvm, outputs1, lse1) + + if not all_new_sequences or is_decode_request: + # Part 2. 
cross-attention + queries_lora_np, q_pe = torch.split( + global_new_q[layer_id], [qk_nope_head_dim, qk_rope_head_dim], dim=2 + ) + queries_lora_np = torch.cat( + [torch.bmm(queries_lora_np.permute(1, 0, 2), w_uk).permute(1, 0, 2), q_pe], dim=2 + ) + queries_lora = tvm.nd.array(queries_lora_np.cpu().numpy(), device) + fcross_attn(kv_cache, layer_id, sm_scale, queries_lora, outputs2, lse2) + cross_attn_output = tvm.nd.array( + torch.bmm( + torch.from_numpy(outputs2.numpy()).to(device_torch).permute(1, 0, 2), w_uv + ) + .permute(1, 0, 2) + .cpu() + .numpy(), + device, + ) + + if not is_decode_request: + if not all_new_sequences: + fkv_merge_attn_output(kv_cache, outputs1, lse1, cross_attn_output, lse2) + else: + outputs1 = cross_attn_output # Compute attention expected results. - outputs = np.expand_dims(outputs.numpy(), axis=0) + outputs = torch.unsqueeze(torch.tensor(outputs1.numpy()).to(device_torch), 0) sum_length = 0 for i, (seq_id, append_length) in enumerate(batch): assert cached_kv[seq_id].shape[1] >= append_length - q_seq = q_array[i][layer_id].transpose(1, 0, 2) - k_seq = np.expand_dims(cached_kv[seq_id][layer_id], axis=1).transpose(1, 2, 0) - v_seq = np.expand_dims(cached_kv[seq_id][layer_id], axis=1).transpose(1, 0, 2)[ - :, :, :kv_lora_rank - ] + q_seq = torch.from_numpy(q_array[i][layer_id]).to(device_torch).permute(1, 0, 2) + latent_seq, k_pe_seq = torch.split( + torch.unsqueeze(cached_kv[seq_id][layer_id], 1), + [kv_lora_rank, qk_rope_head_dim], + dim=2, + ) + k_seq, v_seq = torch.split( + (latent_seq @ w_kv).reshape(k_pe_seq.shape[0], num_attention_heads, -1), + [qk_nope_head_dim, v_head_dim], + dim=2, + ) + k_pe_seq = k_pe_seq.expand(k_pe_seq.shape[0], num_attention_heads, qk_rope_head_dim) + k_seq = torch.cat([k_seq, k_pe_seq], dim=2).permute(1, 2, 0) + v_seq = v_seq.permute(1, 0, 2) - k_seq = np.repeat(k_seq, num_attention_heads, axis=0) - v_seq = np.repeat(v_seq, num_attention_heads, axis=0) - softmax_input = q_seq.astype("float32") @ k_seq.astype("float32") + softmax_input = (q_seq.to(torch.float32) @ k_seq.to(torch.float32)) / torch.sqrt( + torch.tensor(qk_nope_head_dim + qk_rope_head_dim, dtype=torch.float32) + ) softmax_shape = softmax_input.shape assert softmax_shape[-2] == append_length length_diff = softmax_shape[-1] - softmax_shape[-2] assert length_diff >= 0 - mask = np.tril( - np.full_like(softmax_input, np.finfo("float32").max), k=length_diff - ) + np.triu(np.full_like(softmax_input, np.finfo("float32").min), k=length_diff + 1) - - softmax_input = np.minimum(softmax_input, mask) + # Create a mask similar to np.tril and np.triu. 
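+            # Entries on or below the diagonal shifted by `length_diff` keep float32 max,
+            # so torch.minimum leaves those attention scores unchanged; entries above the
+            # shifted diagonal become float32 min and are masked out before the softmax.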
+ mask = torch.tril( + torch.full_like(softmax_input, float(np.finfo(np.float32).max)), + diagonal=length_diff, + ) + torch.triu( + torch.full_like(softmax_input, float(np.finfo(np.float32).min)), + diagonal=length_diff + 1, + ) + softmax_input = torch.minimum(softmax_input, mask) - results = np.expand_dims( - (scipy.special.softmax(softmax_input, axis=-1) @ v_seq.astype("float32")).transpose( - 1, 0, 2 - ), - axis=0, - ).astype(dtype) + results = torch.nn.functional.softmax(softmax_input, dim=-1) @ v_seq.to(torch.float32) + results = results.permute(1, 0, 2).unsqueeze(0).to(dtype_torch) - tvm.testing.assert_allclose( + torch.testing.assert_close( outputs[:, sum_length : sum_length + append_length, ...], results, rtol=1e-3, @@ -445,12 +555,10 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_config): if __name__ == "__main__": - DTYPES = ["float16"] - for (dtype,) in itertools.product(DTYPES): - set_global_func(dtype) - cache = create_kv_cache(dtype) - cache_and_config = (cache,) - test_paged_attention_kv_cache_prefill_and_decode(cache_and_config) - test_paged_attention_kv_cache_remove_sequence(cache_and_config) - test_paged_attention_kv_cache_fork_sequence(cache_and_config) - test_paged_attention_kv_cache_popn(cache_and_config) + set_global_func(dtype) + cache = create_kv_cache(dtype) + cache_and_config = (cache,) + test_paged_attention_kv_cache_prefill_and_decode(cache_and_config) + test_paged_attention_kv_cache_remove_sequence(cache_and_config) + test_paged_attention_kv_cache_fork_sequence(cache_and_config) + test_paged_attention_kv_cache_popn(cache_and_config) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index e30debabfede..70a016697758 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -17,14 +17,14 @@ import itertools from typing import Dict, List, Optional, Tuple, Union -import numpy as np import pytest -import scipy.special +import torch import tvm import tvm.testing from tvm import dlight as dl from tvm.relax.frontend.nn.llm.kv_cache import ( + AttnKind, RopeMode, _attention_decode, _attention_prefill, @@ -48,12 +48,14 @@ num_qo_heads = 32 num_kv_heads = 4 head_dim = None +sm_scale = None rope_scale = 1.0 rope_theta = 1e4 rope_scaling = {} dtype = None +dtype_torch = None device = tvm.cuda() - +device_torch = torch.device("cuda") fclear = None fadd_sequence = None fremove_sequence = None @@ -123,7 +125,7 @@ def set_global_func(head_dim, dtype): _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling, target), _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, True, rope_scaling, target), _attention_prefill_ragged( - num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target + num_kv_heads, num_qo_heads, head_dim, head_dim, dtype, rope_scaling, target ), tree_attn(num_kv_heads, num_qo_heads, head_dim, dtype, rope_scaling, target), tree_attn_with_paged_kv_cache( @@ -160,7 +162,7 @@ def set_global_func(head_dim, dtype): def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): - fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced") + fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create") cache = fcreate( tvm.runtime.ShapeTuple( [ @@ -175,25 +177,29 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): num_qo_heads, 
num_kv_heads, head_dim, + head_dim, # v_head_dim + tvm.runtime.ShapeTuple([int(AttnKind.MHA) for _ in range(num_layers)]), + False, # enable_kv_transfer rope_mode, rope_scale, rope_theta, + None, # rope_ext_factors tvm.nd.empty((), dtype, device=device), ftranspose_append, - fattn_prefill, - fattn_decode, - fattn_prefill_sliding_window, - fattn_decode_sliding_window, - fattn_prefill_ragged, - fmerge_state, + None, # f_transpose_append_mla + ["tir", fattn_prefill_ragged], + ["tir", fattn_prefill], + ["tir", fattn_decode], + ["tir", fattn_prefill_sliding_window], + ["tir", fattn_decode_sliding_window], + ["tir", fattn_prefill_with_tree_mask_paged_kv_cache], + ["tir", fattn_prefill_with_tree_mask], + [], # f_mla_prefill + [fmerge_state], fsplit_rotary, fcopy_single_page, fcopy_cache, fcompact_copy, - fattn_prefill_with_tree_mask, - fattn_prefill_with_tree_mask_paged_kv_cache, - None, - False, ) return cache @@ -215,8 +221,10 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): ) ) def kv_cache_and_config(request): - global head_dim, dtype + global head_dim, sm_scale, dtype, dtype_torch head_dim, dtype, rope_mode, support_sliding_window = request.param + dtype_torch = getattr(torch, dtype) + sm_scale = head_dim ** (-0.5) set_global_func(head_dim, dtype) return create_kv_cache(*request.param), rope_mode, support_sliding_window @@ -230,8 +238,12 @@ def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v): keys = tvm.nd.empty(keys_expected.shape, dtype=dtype, device=device) values = tvm.nd.empty(values_expected.shape, dtype=dtype, device=device) fdebug_get_kv(kv_cache, seq_id, 0, seq_length, keys, values) - tvm.testing.assert_allclose(keys.numpy(), keys_expected, rtol=1e-3, atol=1e-3) - tvm.testing.assert_allclose(values.numpy(), values_expected, rtol=1e-3, atol=1e-3) + torch.testing.assert_close( + torch.from_numpy(keys.numpy()).to(device_torch), keys_expected, rtol=1e-3, atol=1e-3 + ) + torch.testing.assert_close( + torch.from_numpy(values.numpy()).to(device_torch), values_expected, rtol=1e-3, atol=1e-3 + ) def f_apply_rotary(x, offset, scale, theta, offset_list: Optional[List[int]] = None): @@ -239,29 +251,34 @@ def f_apply_rotary(x, offset, scale, theta, offset_list: Optional[List[int]] = N assert len(x.shape) == 3 nfeat = x.shape[-1] nfeat_half = x.shape[-1] // 2 - x = x.astype("float32") - y = np.concatenate([-x[:, :, nfeat_half:], x[:, :, :nfeat_half]], axis=-1) + x_dtype = x.dtype + x = x.to(torch.float32) + y = torch.cat([-x[:, :, nfeat_half:], x[:, :, :nfeat_half]], dim=-1) - inv_freq = scale / (theta ** (np.arange(0, nfeat, 2).astype("float32") / nfeat)) + inv_freq = scale / ( + theta ** (torch.arange(0, nfeat, 2, device=device_torch, dtype=torch.float32) / nfeat) + ) t = ( - np.arange(offset, offset + x.shape[0], dtype=inv_freq.dtype) + torch.arange(offset, offset + x.shape[0], device=device_torch, dtype=inv_freq.dtype) if offset_list is None - else (np.array(offset_list, dtype=inv_freq.dtype) + offset) + else (torch.tensor(offset_list, dtype=inv_freq.dtype, device=device_torch) + offset) ) - freqs = np.einsum("i,j->ij", t, inv_freq) - emb = np.concatenate((freqs, freqs), axis=-1) - cos_values = np.cos(emb) - sin_values = np.sin(emb) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos_values = torch.cos(emb) + sin_values = torch.sin(emb) - return np.einsum("ij,ikj->ikj", cos_values, x) + np.einsum("ij,ikj->ikj", sin_values, y) + return torch.einsum("ij,ikj->ikj", cos_values, x).to(x_dtype) + torch.einsum( + 
"ij,ikj->ikj", sin_values, y + ).to(x_dtype) def apply_attention( kv_cache, rope_mode: RopeMode, batch: List[Tuple[Union[int, Tuple[int, int, int]], int]], - cached_k: Dict[int, np.ndarray], - cached_v: Dict[int, np.ndarray], + cached_k: Dict[int, torch.Tensor], + cached_v: Dict[int, torch.Tensor], sliding_window_sizes: Optional[List[int]] = None, attn_sink_sizes: Optional[List[int]] = None, token_tree_parent_ptr_list: Optional[List[List[int]]] = None, @@ -289,8 +306,12 @@ def apply_attention( cached_v[seq_id] = cached_v[fork_parent_id][::, :fork_pos] elif seq_id not in cached_k: fadd_sequence(kv_cache, seq_id) - cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) - cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + cached_k[seq_id] = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) + cached_v[seq_id] = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) flattened_token_tree_parent_ptr = None token_tree_node_depths_list: List[Optional[List[int]]] = [None for _ in batch] @@ -325,15 +346,45 @@ def apply_attention( ), ) - global_new_q = np.zeros((num_layers, 0, num_qo_heads, head_dim), dtype) - global_new_k = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) - global_new_v = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + global_new_q = torch.zeros( + (num_layers, 0, num_qo_heads, head_dim), dtype=dtype_torch, device=device_torch + ) + global_new_k = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) + global_new_v = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) q_array = [] for i, (seq_id, append_length) in enumerate(batch): - new_q = np.random.rand(num_layers, append_length, num_qo_heads, head_dim).astype(dtype) - new_k = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype) - new_v = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype) + new_q = torch.rand( + num_layers, + append_length, + num_qo_heads, + head_dim, + dtype=dtype_torch, + device=device_torch, + ) + new_k = torch.rand( + num_layers, + append_length, + num_kv_heads, + head_dim, + dtype=dtype_torch, + device=device_torch, + ) + new_v = torch.rand( + num_layers, + append_length, + num_kv_heads, + head_dim, + dtype=dtype_torch, + device=device_torch, + ) + new_q = new_q * 2 - 1 + new_k = new_k * 2 - 1 + new_v = new_v * 2 - 1 q_array.append(new_q) rope_offset = cached_k[seq_id].shape[1] @@ -341,10 +392,10 @@ def apply_attention( prev_tree_size = len(token_tree_parent_ptr_list[i]) - append_length assert prev_tree_size >= 0 rope_offset -= prev_tree_size - cached_k[seq_id] = np.concatenate( + cached_k[seq_id] = torch.cat( [ cached_k[seq_id], - np.stack( + torch.stack( [ ( new_k[l] @@ -363,26 +414,26 @@ def apply_attention( ) for l in range(num_layers) ], - axis=0, + dim=0, ), ], - axis=1, + dim=1, ) - cached_v[seq_id] = np.concatenate([cached_v[seq_id], new_v], axis=1) - global_new_q = np.concatenate([global_new_q, new_q], axis=1) - global_new_k = np.concatenate([global_new_k, new_k], axis=1) - global_new_v = np.concatenate([global_new_v, new_v], axis=1) + cached_v[seq_id] = torch.cat([cached_v[seq_id], new_v], dim=1) + global_new_q = torch.cat([global_new_q, new_q], dim=1) + global_new_k = torch.cat([global_new_k, new_k], dim=1) + global_new_v = torch.cat([global_new_v, new_v], dim=1) for layer_id in 
range(num_layers): queries_np = global_new_q[layer_id] keys_np = global_new_k[layer_id] values_np = global_new_v[layer_id] - qkv = tvm.nd.array(np.concatenate([queries_np, keys_np, values_np], axis=1), device) + qkv = tvm.nd.array(torch.cat([queries_np, keys_np, values_np], dim=1).cpu().numpy(), device) outputs = tvm.nd.empty(queries_np.shape, dtype, device=device) - fattention_with_fuse_qkv(kv_cache, layer_id, 1.0, qkv, outputs) + fattention_with_fuse_qkv(kv_cache, layer_id, sm_scale, qkv, outputs) # Compute attention expected results. - outputs = np.expand_dims(outputs.numpy(), axis=0) + outputs = torch.from_numpy(outputs.numpy()).unsqueeze(0).to(device_torch) sum_length = 0 for i, (seq_id, append_length) in enumerate(batch): assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length @@ -406,7 +457,7 @@ def apply_attention( else None ), ) - ).transpose(1, 0, 2) + ).permute(1, 0, 2) k_seq = ( cached_k[seq_id][layer_id] if rope_mode != RopeMode.INLINE @@ -424,41 +475,47 @@ def apply_attention( else None ), ) - ).transpose(1, 2, 0) - v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2) + ).permute(1, 2, 0) + v_seq = cached_v[seq_id][layer_id].permute(1, 0, 2) - k_seq = np.repeat(k_seq, num_qo_heads // num_kv_heads, axis=0) - v_seq = np.repeat(v_seq, num_qo_heads // num_kv_heads, axis=0) - softmax_input = (q_seq.astype("float32") @ k_seq.astype("float32")) / np.sqrt(head_dim) + k_seq = k_seq.repeat_interleave(num_qo_heads // num_kv_heads, dim=0) + v_seq = v_seq.repeat_interleave(num_qo_heads // num_kv_heads, dim=0) + softmax_input = (q_seq.to(torch.float32) @ k_seq.to(torch.float32)) / (head_dim**0.5) softmax_shape = softmax_input.shape assert softmax_shape[-2] == append_length length_diff = softmax_shape[-1] - softmax_shape[-2] assert length_diff >= 0 - mask = np.tril( - np.full_like(softmax_input, np.finfo("float32").max), k=length_diff - ) + np.triu(np.full_like(softmax_input, np.finfo("float32").min), k=length_diff + 1) + mask = torch.tril( + torch.full_like(softmax_input, torch.finfo(torch.float32).max), diagonal=length_diff + ) + torch.triu( + torch.full_like(softmax_input, torch.finfo(torch.float32).min), + diagonal=length_diff + 1, + ) if token_tree_parent_ptr_list is not None: tree_size = len(token_tree_parent_ptr_list[i]) - tree_mask = np.full( - (tree_size, tree_size), np.finfo("float32").min, dtype="float32" + tree_mask = torch.full( + (tree_size, tree_size), + torch.finfo(torch.float32).min, + dtype=torch.float32, + device=device_torch, ) for i, parent in enumerate(token_tree_parent_ptr_list[i]): if parent != -1: tree_mask[i] = tree_mask[parent] - tree_mask[i, i] = np.finfo("float32").max - tree_mask = np.broadcast_to(tree_mask, (num_qo_heads, *tree_mask.shape)) + tree_mask[i, i] = torch.finfo(torch.float32).max + tree_mask = tree_mask.expand(num_qo_heads, *tree_mask.shape) mask[:, :, -tree_size:] = tree_mask[:, -append_length:, :] - softmax_input = np.minimum(softmax_input, mask) + softmax_input = torch.minimum(softmax_input, mask) - results = np.expand_dims( - (scipy.special.softmax(softmax_input, axis=-1) @ v_seq.astype("float32")).transpose( - 1, 0, 2 - ), - axis=0, - ).astype(dtype) + results = torch.unsqueeze( + ( + torch.nn.functional.softmax(softmax_input, dim=-1) @ v_seq.to(torch.float32) + ).permute(1, 0, 2), + dim=0, + ).to(dtype_torch) - tvm.testing.assert_allclose( + torch.testing.assert_close( outputs[:, sum_length : sum_length + append_length, ...], results, rtol=1e-3, @@ -506,19 +563,19 @@ def apply_attention( if cached_k[seq_id].shape[1] > 
sliding_window_size: # Apply sliding window and sink to cached kv. length_to_slide = cached_k[seq_id].shape[1] - sliding_window_size - cached_k[seq_id] = np.concatenate( + cached_k[seq_id] = torch.cat( [ cached_k[seq_id][:, :attn_sink_size, ...], cached_k[seq_id][:, attn_sink_size + length_to_slide :, ...], ], - axis=1, + dim=1, ) - cached_v[seq_id] = np.concatenate( + cached_v[seq_id] = torch.cat( [ cached_v[seq_id][:, :attn_sink_size, ...], cached_v[seq_id][:, attn_sink_size + length_to_slide :, ...], ], - axis=1, + dim=1, ) assert cached_k[seq_id].shape[1] == sliding_window_size @@ -759,8 +816,12 @@ def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config): ): fadd_sequence(kv_cache, seq_id) fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size, attn_sink_size) - cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) - cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + cached_k[seq_id] = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) + cached_v[seq_id] = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) # Prefill. operation_seq = [[(0, 4)], [(1, 6)], [(2, 6), (3, 7), (4, 7)]] @@ -807,8 +868,12 @@ def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config): ): fadd_sequence(kv_cache, seq_id) fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size, attn_sink_size) - cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) - cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype) + cached_k[seq_id] = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) + cached_v[seq_id] = torch.zeros( + (num_layers, 0, num_kv_heads, head_dim), dtype=dtype_torch, device=device_torch + ) apply_attention( kv_cache, rope_mode, @@ -951,6 +1016,8 @@ def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config): for head_dim, dtype, rope_mode, support_sliding_window in itertools.product( HEAD_DIMS, DTYPES, ROPE_MODES, SUPPORT_SLIDING_WINDOW ): + dtype_torch = getattr(torch, dtype) + sm_scale = head_dim ** (-0.5) set_global_func(head_dim, dtype) cache = create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window) cache_and_config = (cache, rope_mode, support_sliding_window)