diff --git a/aiter/ops/triton/fused_kv_cache.py b/aiter/ops/triton/fused_kv_cache.py
index 51e707ce07..8aa74c7147 100644
--- a/aiter/ops/triton/fused_kv_cache.py
+++ b/aiter/ops/triton/fused_kv_cache.py
@@ -1,15 +1,74 @@
 import torch
 import triton
+from typing import Tuple
 from aiter.ops.triton._triton_kernels.fused_kv_cache import (
     _fused_qk_rope_cat_and_cache_mla_kernel,
     _fused_qk_rope_reshape_and_cache_kernel,
     _fused_qk_rope_cosine_cache_llama_kernel,
 )
+from aiter.jit.utils.torch_guard import torch_compile_guard
 from aiter.ops.triton.utils.logger import AiterTritonLogger
 
 _LOGGER = AiterTritonLogger()
 
 
+def fused_qk_rope_cat_and_cache_mla_fake_tensor(
+    q_nope: torch.Tensor,
+    q_pe: torch.Tensor,
+    k_nope: torch.Tensor,
+    k_pe: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    pos: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    k_scale: torch.Tensor,
+    is_neox: bool,
+    num_decode_toks_for_zeros: int = 0,
+    apply_scale: bool = True,
+    q_out: torch.Tensor = None,
+    decode_q_pe_out: torch.Tensor = None,
+    k_pe_out: torch.Tensor = None,
+    q_out_dtype: torch.dtype = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    b, qh, d_nope = q_nope.shape
+    _, _, d_pe = q_pe.shape
+    bk, kh, dk_nope = k_nope.shape
+
+    if q_out is None:
+        q_out = torch.empty(
+            (b, qh, d_nope + d_pe),
+            dtype=q_out_dtype if q_out_dtype is not None else q_nope.dtype,
+            device=q_nope.device,
+        )
+
+    if decode_q_pe_out is None:
+        decode_q_pe_out = torch.empty(
+            (num_decode_toks_for_zeros, qh, d_pe),
+            dtype=q_nope.dtype,
+            device=q_nope.device,
+        )
+
+    if k_pe_out is None:
+        k_pe_out = torch.empty((bk, kh, d_pe), dtype=k_pe.dtype, device=k_pe.device)
+
+    if num_decode_toks_for_zeros > 0:
+        q_nope_zeros_out = torch.empty(
+            (num_decode_toks_for_zeros, qh, dk_nope),
+            dtype=q_nope.dtype,
+            device=q_nope.device,
+        )
+    else:
+        q_nope_zeros_out = torch.empty(
+            (0, qh, dk_nope),
+            dtype=q_nope.dtype,
+            device=q_nope.device,
+        )
+
+    return q_out, decode_q_pe_out, k_pe_out, kv_cache, q_nope_zeros_out
+
+
+@torch_compile_guard(gen_fake=fused_qk_rope_cat_and_cache_mla_fake_tensor)
 def fused_qk_rope_cat_and_cache_mla(
     q_nope: torch.Tensor,
     q_pe: torch.Tensor,
@@ -27,8 +86,8 @@ def fused_qk_rope_cat_and_cache_mla(
     q_out: torch.Tensor = None,
     decode_q_pe_out: torch.Tensor = None,
     k_pe_out: torch.Tensor = None,
-    q_out_dtype=None,
-):
+    q_out_dtype: torch.dtype = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Perform RoPE on q_pe and k_pe and concat q_nope with q_pe and k_nope with k_pe along the last dimension
     the concatentaed k_nope and k_pe are copied to kv_cache inplace
@@ -75,6 +134,9 @@ def fused_qk_rope_cat_and_cache_mla(
     assert (d_freq == d_pe // 2) or (
         d_freq == d_pe
     ), "cos/sin last dim should be the same or half of the qk last dim"
+    assert (
+        num_decode_toks_for_zeros >= 0
+    ), "num_decode_toks_for_zeros must be non-negative to avoid invalid tensor creation"
     if isinstance(k_scale, torch.Tensor):
         assert k_scale.numel() == 1, "k_scale should be a single-element torch.Tensor"
     reuse_freqs_front_part = d_freq == d_pe // 2
@@ -169,9 +231,14 @@ def fused_qk_rope_cat_and_cache_mla(
         num_warps=1,
     )
 
-    if num_decode_toks_for_zeros > 0:
-        return q_out, decode_q_pe_out, k_pe_out, kv_cache, q_nope_zeros_out
-    return q_out, decode_q_pe_out, k_pe_out, kv_cache
+    if q_nope_zeros_out is None:
+        # change q_nope_zeros_out from None to a tensor for torch.compile
+        q_nope_zeros_out = torch.empty(
+            (num_decode_toks_for_zeros, qh, dk_nope),
+            dtype=q_nope.dtype,
+            device=q_nope.device,
+        )
+    return q_out, decode_q_pe_out, k_pe_out, kv_cache, q_nope_zeros_out
 
 
 def fused_qk_rope_reshape_and_cache(
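
Note on the fake-tensor implementation: fused_qk_rope_cat_and_cache_mla_fake_tensor mirrors the op's output metadata (shapes, dtypes, devices) with plain torch.empty calls, which is what lets torch.compile trace through the custom Triton op without executing it; as the gen_fake= hook name suggests, torch_compile_guard registers it as the op's fake/meta function. Because the function is ordinary shape arithmetic, its contract can be sanity-checked eagerly. The sketch below is illustrative only: the shapes (b=4, qh=16, d_nope=128, d_pe=64) are made up, and it is not part of the patch or the repository's tests.

    import torch
    from aiter.ops.triton.fused_kv_cache import (
        fused_qk_rope_cat_and_cache_mla_fake_tensor,
    )

    # Hypothetical MLA-style shapes, chosen only for this example.
    b, qh, kh, d_nope, d_pe = 4, 16, 1, 128, 64
    q_nope = torch.empty(b, qh, d_nope)
    q_pe = torch.empty(b, qh, d_pe)
    k_nope = torch.empty(b, kh, d_nope)
    k_pe = torch.empty(b, kh, d_pe)
    kv_cache = torch.empty(1024, 1, d_nope + d_pe)
    slot_mapping = torch.empty(b, dtype=torch.int64)
    pos = torch.empty(b, dtype=torch.int64)
    cos = torch.empty(4096, d_pe // 2)
    sin = torch.empty(4096, d_pe // 2)
    k_scale = torch.ones(1)

    q_out, decode_q_pe, k_pe_out, cache, zeros = fused_qk_rope_cat_and_cache_mla_fake_tensor(
        q_nope, q_pe, k_nope, k_pe, kv_cache, slot_mapping,
        pos, cos, sin, k_scale, is_neox=False, num_decode_toks_for_zeros=0,
    )

    # The returned metadata matches what the real kernel would produce,
    # but no Triton kernel is launched and no cache memory is touched.
    assert q_out.shape == (b, qh, d_nope + d_pe)
    assert zeros.shape == (0, qh, d_nope)  # empty when there are no decode tokens
    assert cache is kv_cache               # the real op updates kv_cache in place

This also explains the change to the eager path's return statement: with a fake function registered, the op must always return the same number of outputs, so the previous four-element return for num_decode_toks_for_zeros == 0 is replaced by a uniform five-tuple whose last element may be an empty tensor.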