77 changes: 72 additions & 5 deletions aiter/ops/triton/fused_kv_cache.py
@@ -1,15 +1,74 @@
import torch
import triton
from typing import Tuple
from aiter.ops.triton._triton_kernels.fused_kv_cache import (
    _fused_qk_rope_cat_and_cache_mla_kernel,
    _fused_qk_rope_reshape_and_cache_kernel,
    _fused_qk_rope_cosine_cache_llama_kernel,
)
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.ops.triton.utils.logger import AiterTritonLogger

_LOGGER = AiterTritonLogger()


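# Fake (meta) implementation registered via torch_compile_guard below: it only
# allocates outputs with matching shapes/dtypes/devices so torch.compile can
# trace the op without launching the Triton kernel.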
def fused_qk_rope_cat_and_cache_mla_fake_tensor(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_nope: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    pos: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    k_scale: torch.Tensor,
    is_neox: bool,
    num_decode_toks_for_zeros: int = 0,
    apply_scale: bool = True,
    q_out: torch.Tensor = None,
    decode_q_pe_out: torch.Tensor = None,
    k_pe_out: torch.Tensor = None,
    q_out_dtype: torch.dtype = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    b, qh, d_nope = q_nope.shape
    _, _, d_pe = q_pe.shape
    bk, kh, dk_nope = k_nope.shape

    if q_out is None:
        q_out = torch.empty(
            (b, qh, d_nope + d_pe),
            dtype=q_out_dtype if q_out_dtype is not None else q_nope.dtype,
            device=q_nope.device,
        )

    if decode_q_pe_out is None:
        decode_q_pe_out = torch.empty(
            (num_decode_toks_for_zeros, qh, d_pe),
            dtype=q_nope.dtype,
            device=q_nope.device,
        )

    if k_pe_out is None:
        k_pe_out = torch.empty((bk, kh, d_pe), dtype=k_pe.dtype, device=k_pe.device)

    if num_decode_toks_for_zeros > 0:
        q_nope_zeros_out = torch.empty(
            (num_decode_toks_for_zeros, qh, dk_nope),
            dtype=q_nope.dtype,
            device=q_nope.device,
        )
    else:
        q_nope_zeros_out = torch.empty(
            (0, qh, dk_nope),
            dtype=q_nope.dtype,
            device=q_nope.device,
        )

    return q_out, decode_q_pe_out, k_pe_out, kv_cache, q_nope_zeros_out


@torch_compile_guard(gen_fake=fused_qk_rope_cat_and_cache_mla_fake_tensor)
def fused_qk_rope_cat_and_cache_mla(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
@@ -27,8 +86,8 @@ def fused_qk_rope_cat_and_cache_mla(
    q_out: torch.Tensor = None,
    decode_q_pe_out: torch.Tensor = None,
    k_pe_out: torch.Tensor = None,
    q_out_dtype=None,
):
    q_out_dtype: torch.dtype = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Perform RoPE on q_pe and k_pe and concat q_nope with q_pe and k_nope with k_pe along the last dimension
the concatentaed k_nope and k_pe are copied to kv_cache inplace
@@ -75,6 +134,9 @@ def fused_qk_rope_cat_and_cache_mla(
    assert (d_freq == d_pe // 2) or (
        d_freq == d_pe
    ), "cos/sin last dim should be the same or half of the qk last dim"
    assert (
        num_decode_toks_for_zeros >= 0
    ), "num_decode_toks_for_zeros must be non-negative to avoid invalid tensor creation"
    if isinstance(k_scale, torch.Tensor):
        assert k_scale.numel() == 1, "k_scale should be a single-element torch.Tensor"
    # cos/sin may store only the front half of the rotary dims; the kernel then
    # reuses that front part for the second half
    reuse_freqs_front_part = d_freq == d_pe // 2
@@ -169,9 +231,14 @@ def fused_qk_rope_cat_and_cache_mla(
        num_warps=1,
    )

    if num_decode_toks_for_zeros > 0:
        return q_out, decode_q_pe_out, k_pe_out, kv_cache, q_nope_zeros_out
    return q_out, decode_q_pe_out, k_pe_out, kv_cache
    if q_nope_zeros_out is None:
        # torch.compile requires a consistent number of tensor outputs, so
        # replace None with an empty placeholder instead of dropping it
        q_nope_zeros_out = torch.empty(
            (num_decode_toks_for_zeros, qh, dk_nope),
            dtype=q_nope.dtype,
            device=q_nope.device,
        )
    return q_out, decode_q_pe_out, k_pe_out, kv_cache, q_nope_zeros_out


def fused_qk_rope_reshape_and_cache(
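For reference, a minimal shape-check sketch against the fake (meta) implementation added above. The head dimensions and token counts are illustrative assumptions, not values taken from this PR, and it assumes the `aiter` module imports cleanly on the host:

```python
import torch

from aiter.ops.triton.fused_kv_cache import (
    fused_qk_rope_cat_and_cache_mla_fake_tensor,
)

# Illustrative MLA-style shapes (assumptions, not from this PR)
b, qh, kh = 4, 16, 1      # tokens, query heads, kv heads
d_nope, d_pe = 512, 64    # no-RoPE and RoPE head dims

q_nope = torch.randn(b, qh, d_nope)
q_pe = torch.randn(b, qh, d_pe)
k_nope = torch.randn(b, kh, d_nope)
k_pe = torch.randn(b, kh, d_pe)
kv_cache = torch.empty(0)  # passed through untouched by the fake impl
dummy = torch.empty(0)     # slot_mapping/pos/cos/sin/k_scale don't affect shapes

q_out, dq_pe, k_pe_out, cache, zeros = fused_qk_rope_cat_and_cache_mla_fake_tensor(
    q_nope, q_pe, k_nope, k_pe, kv_cache,
    dummy, dummy, dummy, dummy, dummy,
    is_neox=True,
    num_decode_toks_for_zeros=2,
)
assert q_out.shape == (b, qh, d_nope + d_pe)  # q_nope and q_pe concatenated
assert dq_pe.shape == (2, qh, d_pe)
assert k_pe_out.shape == (b, kh, d_pe)
assert zeros.shape == (2, qh, d_nope)         # zero-fill placeholder block
```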