From 50eca1290fd7c7d50847c1f9aac42f829a3935ae Mon Sep 17 00:00:00 2001
From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com>
Date: Thu, 18 Dec 2025 12:04:26 +0000
Subject: [PATCH 1/5] add gen_fake for MLA RoPE operator

---
 aiter/ops/triton/fused_kv_cache.py | 69 +++++++++++++++++++++++++++---
 1 file changed, 64 insertions(+), 5 deletions(-)

diff --git a/aiter/ops/triton/fused_kv_cache.py b/aiter/ops/triton/fused_kv_cache.py
index 51e707ce07..87adf75a9f 100644
--- a/aiter/ops/triton/fused_kv_cache.py
+++ b/aiter/ops/triton/fused_kv_cache.py
@@ -1,15 +1,67 @@
 import torch
 import triton
+from typing import Tuple
 
 from aiter.ops.triton._triton_kernels.fused_kv_cache import (
     _fused_qk_rope_cat_and_cache_mla_kernel,
     _fused_qk_rope_reshape_and_cache_kernel,
     _fused_qk_rope_cosine_cache_llama_kernel,
 )
+from aiter.jit.utils.torch_guard import torch_compile_guard
 from aiter.ops.triton.utils.logger import AiterTritonLogger
 
 _LOGGER = AiterTritonLogger()
 
 
+def fused_qk_rope_cat_and_cache_mla_fake_tensor(
+    q_nope: torch.Tensor,
+    q_pe: torch.Tensor,
+    k_nope: torch.Tensor,
+    k_pe: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    pos: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    k_scale: torch.Tensor,
+    is_neox: bool,
+    num_decode_toks_for_zeros: int = 0,
+    apply_scale: bool = True,
+    q_out: torch.Tensor = None,
+    decode_q_pe_out: torch.Tensor = None,
+    k_pe_out: torch.Tensor = None,
+    q_out_dtype: torch.dtype=None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    b, qh, d_nope = q_nope.shape
+    _, _, d_pe = q_pe.shape
+    bk, kh, dk_nope = k_nope.shape
+
+    if q_out is None:
+        q_out = torch.empty(
+            (b, qh, d_nope + d_pe),
+            dtype=q_out_dtype if q_out_dtype is not None else q_nope.dtype,
+            device=q_nope.device,
+        )
+
+    if decode_q_pe_out is None:
+        decode_q_pe_out = torch.empty(
+            (num_decode_toks_for_zeros, qh, d_pe),
+            dtype=q_nope.dtype,
+            device=q_nope.device,
+        )
+
+    if k_pe_out is None:
+        k_pe_out = torch.empty((bk, kh, d_pe), dtype=k_pe.dtype, device=k_pe.device)
+    q_nope_zeros_out = torch.empty(
+        (num_decode_toks_for_zeros, qh, dk_nope),
+        dtype=q_nope.dtype,
+        device=q_nope.device,
+    )
+
+    return q_out, decode_q_pe_out, k_pe_out, kv_cache, q_nope_zeros_out
+
+
+@torch_compile_guard(gen_fake=fused_qk_rope_cat_and_cache_mla_fake_tensor)
 def fused_qk_rope_cat_and_cache_mla(
     q_nope: torch.Tensor,
     q_pe: torch.Tensor,
@@ -27,8 +79,8 @@ def fused_qk_rope_cat_and_cache_mla(
     q_out: torch.Tensor = None,
     decode_q_pe_out: torch.Tensor = None,
     k_pe_out: torch.Tensor = None,
-    q_out_dtype=None,
-):
+    q_out_dtype: torch.dtype=None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Perform RoPE on q_pe and k_pe and concat q_nope with q_pe and k_nope with k_pe along the last dimension
     the concatenated k_nope and k_pe are copied to kv_cache in place
@@ -75,6 +127,7 @@ def fused_qk_rope_cat_and_cache_mla(
     assert (d_freq == d_pe // 2) or (
         d_freq == d_pe
     ), "cos/sin last dim should be the same or half of the qk last dim"
+    assert num_decode_toks_for_zeros >= 0, "in case tensor creation failure"
     if isinstance(k_scale, torch.Tensor):
         assert k_scale.numel() == 1, "k_scale should be a single-element torch.Tensor"
     reuse_freqs_front_part = d_freq == d_pe // 2
@@ -113,6 +166,7 @@ def fused_qk_rope_cat_and_cache_mla(
             bk == b_k_pe_out and kh == hk_k_pe_out and d_pe == d_k_pe_out
         ), "k_pe_out shape mismatch, expected (bk, kh, d_pe)"
 
+    q_nope_zeros_out = None
     if num_decode_toks_for_zeros > 0:
         q_nope_zeros_out = torch.empty(
@@ -169,9 +223,14 @@ def fused_qk_rope_cat_and_cache_mla(
         num_warps=1,
     )
 
-    if num_decode_toks_for_zeros > 0:
-        return q_out, decode_q_pe_out, k_pe_out, kv_cache, q_nope_zeros_out
-    return q_out, decode_q_pe_out, k_pe_out, kv_cache
+    if q_nope_zeros_out == None:
+        # change q_nope_zeros_out from None to a tensor for torch compile
+        q_nope_zeros_out = torch.empty(
+            (num_decode_toks_for_zeros, qh, dk_nope),
+            dtype=q_nope.dtype,
+            device=q_nope.device,
+        )
+    return q_out, decode_q_pe_out, k_pe_out, kv_cache, q_nope_zeros_out
 
 
 def fused_qk_rope_reshape_and_cache(

From 835e9dc4e8905ff4c00283a97a4e518266d3afe3 Mon Sep 17 00:00:00 2001
From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com>
Date: Tue, 23 Dec 2025 22:26:01 -0600
Subject: [PATCH 2/5] fix code style

---
 aiter/ops/triton/fused_kv_cache.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/aiter/ops/triton/fused_kv_cache.py b/aiter/ops/triton/fused_kv_cache.py
index 87adf75a9f..aa0be6a96f 100644
--- a/aiter/ops/triton/fused_kv_cache.py
+++ b/aiter/ops/triton/fused_kv_cache.py
@@ -29,7 +29,7 @@ def fused_qk_rope_cat_and_cache_mla_fake_tensor(
     q_out: torch.Tensor = None,
     decode_q_pe_out: torch.Tensor = None,
     k_pe_out: torch.Tensor = None,
-    q_out_dtype: torch.dtype=None,
+    q_out_dtype: torch.dtype = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     b, qh, d_nope = q_nope.shape
     _, _, d_pe = q_pe.shape
     bk, kh, dk_nope = k_nope.shape
@@ -79,7 +79,7 @@ def fused_qk_rope_cat_and_cache_mla(
     q_out: torch.Tensor = None,
     decode_q_pe_out: torch.Tensor = None,
     k_pe_out: torch.Tensor = None,
-    q_out_dtype: torch.dtype=None,
+    q_out_dtype: torch.dtype = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Perform RoPE on q_pe and k_pe and concat q_nope with q_pe and k_nope with k_pe along the last dimension
@@ -127,7 +127,9 @@ def fused_qk_rope_cat_and_cache_mla(
     assert (d_freq == d_pe // 2) or (
         d_freq == d_pe
     ), "cos/sin last dim should be the same or half of the qk last dim"
-    assert num_decode_toks_for_zeros >= 0, "in case tensor creation failure"
+    assert (
+        num_decode_toks_for_zeros >= 0
+    ), "num_decode_toks_for_zeros must be non-negative to avoid invalid tensor creation"
     if isinstance(k_scale, torch.Tensor):
         assert k_scale.numel() == 1, "k_scale should be a single-element torch.Tensor"
     reuse_freqs_front_part = d_freq == d_pe // 2
@@ -223,7 +225,7 @@ def fused_qk_rope_cat_and_cache_mla(
         num_warps=1,
     )
 
-    if q_nope_zeros_out == None:
+    if q_nope_zeros_out is None:
         # change q_nope_zeros_out from None to a tensor for torch compile
         q_nope_zeros_out = torch.empty(
             (num_decode_toks_for_zeros, qh, dk_nope),

From 442664f1935c234565a5ea3be8f64f4437d9b815 Mon Sep 17 00:00:00 2001
From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com>
Date: Tue, 23 Dec 2025 22:45:08 -0600
Subject: [PATCH 3/5] sync logic in fake with actual function

---
 aiter/ops/triton/fused_kv_cache.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/aiter/ops/triton/fused_kv_cache.py b/aiter/ops/triton/fused_kv_cache.py
index aa0be6a96f..813f04ff58 100644
--- a/aiter/ops/triton/fused_kv_cache.py
+++ b/aiter/ops/triton/fused_kv_cache.py
@@ -52,11 +52,19 @@ def fused_qk_rope_cat_and_cache_mla_fake_tensor(
 
     if k_pe_out is None:
         k_pe_out = torch.empty((bk, kh, d_pe), dtype=k_pe.dtype, device=k_pe.device)
-    q_nope_zeros_out = torch.empty(
-        (num_decode_toks_for_zeros, qh, dk_nope),
-        dtype=q_nope.dtype,
-        device=q_nope.device,
-    )
+
+    if num_decode_toks_for_zeros > 0:
+        q_nope_zeros_out = torch.empty(
+            (num_decode_toks_for_zeros, qh, dk_nope),
+            dtype=q_nope.dtype,
+            device=q_nope.device,
+        )
+    else:
+        q_nope_zeros_out = torch.empty(
+            (0, qh, dk_nope),
+            dtype=q_nope.dtype,
+            device=q_nope.device,
+        )
 
     return q_out, decode_q_pe_out, k_pe_out, kv_cache, q_nope_zeros_out

From 464e7dc91db068424224a26839f58628f2e71c4a Mon Sep 17 00:00:00 2001
From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com>
Date: Wed, 24 Dec 2025 00:31:36 -0600
Subject: [PATCH 4/5] fix black error

---
 aiter/ops/triton/fused_kv_cache.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/aiter/ops/triton/fused_kv_cache.py b/aiter/ops/triton/fused_kv_cache.py
index 813f04ff58..631af00dfb 100644
--- a/aiter/ops/triton/fused_kv_cache.py
+++ b/aiter/ops/triton/fused_kv_cache.py
@@ -52,7 +52,6 @@ def fused_qk_rope_cat_and_cache_mla_fake_tensor(
 
     if k_pe_out is None:
         k_pe_out = torch.empty((bk, kh, d_pe), dtype=k_pe.dtype, device=k_pe.device)
-
     if num_decode_toks_for_zeros > 0:
         q_nope_zeros_out = torch.empty(
             (num_decode_toks_for_zeros, qh, dk_nope),

From da6f0d83ce80fb5b0d85904c889b3df7e86f3230 Mon Sep 17 00:00:00 2001
From: Marvin Tsai <62472426+mqhc2020@users.noreply.github.com>
Date: Wed, 24 Dec 2025 00:39:23 -0600
Subject: [PATCH 5/5] fix black error again

---
 aiter/ops/triton/fused_kv_cache.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/aiter/ops/triton/fused_kv_cache.py b/aiter/ops/triton/fused_kv_cache.py
index 631af00dfb..8aa74c7147 100644
--- a/aiter/ops/triton/fused_kv_cache.py
+++ b/aiter/ops/triton/fused_kv_cache.py
@@ -175,7 +175,6 @@ def fused_qk_rope_cat_and_cache_mla(
         bk == b_k_pe_out and kh == hk_k_pe_out and d_pe == d_k_pe_out
     ), "k_pe_out shape mismatch, expected (bk, kh, d_pe)"
 
-
     q_nope_zeros_out = None
     if num_decode_toks_for_zeros > 0:
         q_nope_zeros_out = torch.empty(
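
A minimal reviewer-side sketch of what this series enables (not part of the patch; the device, shapes, and sizes below are illustrative assumptions): registering fused_qk_rope_cat_and_cache_mla_fake_tensor via torch_compile_guard(gen_fake=...) lets torch.compile trace the Triton op through fake-tensor propagation instead of graph-breaking on the opaque kernel launch. This is also why the op now always returns a real, possibly zero-size q_nope_zeros_out tensor rather than None: the compiled graph needs one fixed output schema across both branches of num_decode_toks_for_zeros, and the fake function's branches (patches 3-5) keep its output shapes in sync with the real kernel's.

    import torch
    from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla

    # Illustrative MLA-style sizes; none of these values come from the patch.
    b, qh, d_nope, d_pe = 4, 16, 512, 64
    dev = "cuda"
    dt = torch.bfloat16

    q_nope = torch.randn(b, qh, d_nope, device=dev, dtype=dt)
    q_pe = torch.randn(b, qh, d_pe, device=dev, dtype=dt)
    k_nope = torch.randn(b, 1, d_nope, device=dev, dtype=dt)
    k_pe = torch.randn(b, 1, d_pe, device=dev, dtype=dt)
    # Assumed cache layout for this sketch: one slot per row, d_nope + d_pe wide.
    kv_cache = torch.zeros(1024, 1, d_nope + d_pe, device=dev, dtype=dt)
    slot_mapping = torch.arange(b, device=dev, dtype=torch.int64)
    pos = torch.arange(b, device=dev, dtype=torch.int64)
    # Per the assert in the op, cos/sin may carry d_pe // 2 or d_pe entries.
    cos = torch.randn(4096, d_pe // 2, device=dev, dtype=torch.float32)
    sin = torch.randn(4096, d_pe // 2, device=dev, dtype=torch.float32)
    k_scale = torch.ones(1, device=dev, dtype=torch.float32)

    @torch.compile(fullgraph=True)
    def step():
        # With gen_fake registered, this call should be traceable on fake
        # tensors; before this series it would force a graph break.
        return fused_qk_rope_cat_and_cache_mla(
            q_nope, q_pe, k_nope, k_pe, kv_cache, slot_mapping,
            pos, cos, sin, k_scale, is_neox=False,
            num_decode_toks_for_zeros=0,
        )

    q_out, decode_q_pe, k_pe_res, kv_cache_res, q_nope_zeros = step()
    assert q_out.shape == (b, qh, d_nope + d_pe)
    assert q_nope_zeros.shape == (0, qh, d_nope)  # zero-size tensor, never None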