11 changes: 11 additions & 0 deletions aiter/ops/cache.py
@@ -84,3 +84,14 @@ def reshape_and_cache_with_block_quant_for_asm_pa(
asm_layout: bool,
ori_block_size: int = 128, # [128/256]
) -> None: ...


@compile_ops("module_cache")
def concat_and_cache_mla(
kv_c: Tensor,
k_pe: Tensor,
kv_cache: Tensor,
slot_mapping: Tensor,
kv_cache_dtype: str,
scale: Tensor,
) -> None: ...
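
For reference, a minimal Python call sketch for the new binding (the sizes below are illustrative assumptions; the call pattern mirrors op_tests/test_concat_cache_mla.py further down):

import torch
import aiter

# Illustrative sizes only; kv_cache layout is [num_blocks, block_size, kv_lora_rank + pe_dim].
num_tokens, kv_lora_rank, pe_dim = 128, 512, 64
num_blocks, block_size = 2, 64

kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn(num_tokens, pe_dim, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.zeros(num_blocks, block_size, kv_lora_rank + pe_dim,
                       dtype=torch.bfloat16, device="cuda")
slot_mapping = torch.randperm(num_blocks * block_size, device="cuda")[:num_tokens]  # int64 slot ids
scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

# "auto" stores values in the input dtype; "fp8" additionally divides by scale before casting.
aiter.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, "auto", scale)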
8 changes: 8 additions & 0 deletions csrc/include/cache.h
@@ -65,4 +65,12 @@ void reshape_and_cache_with_block_quant_for_asm_pa(
const bool asm_layout,
const int ori_block_size = 128);

void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale);

} // namespace aiter
14 changes: 13 additions & 1 deletion csrc/include/rocm_ops.hpp
@@ -236,7 +236,19 @@
py::arg("v_dequant_scales"), \
py::arg("slot_mapping"), \
py::arg("asm_layout"), \
py::arg("ori_block_size") = 128);
py::arg("ori_block_size") = 128); \
m.def("concat_and_cache_mla", &aiter::concat_and_cache_mla, \
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe," \
" Tensor! kv_cache," \
" Tensor slot_mapping," \
" str kv_cache_dtype," \
" Tensor scale) -> ()", \
py::arg("kv_c"), \
py::arg("k_pe"), \
py::arg("kv_cache"), \
py::arg("slot_mapping"), \
py::arg("kv_cache_dtype"), \
py::arg("scale")); \

#define CUSTOM_ALL_REDUCE_PYBIND \
m.def("init_custom_ar", \
99 changes: 99 additions & 0 deletions csrc/kernels/cache_kernels.cu
@@ -973,6 +973,51 @@ __global__ void reshape_and_cache_with_block_quant_kernel_for_asmpa(
}
}
}

template <typename scalar_t, typename cache_t, vllm::Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size, //
const float* scale //
) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const float inverted_kscale = 1.0f / *scale;
auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
int src_stride, int dst_stride, int size, int offset) {
for (int i = threadIdx.x; i < size; i += blockDim.x) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i + offset;
if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx]= ck_tile::type_convert<cache_t>(
ck_tile::type_convert<float>(src[src_idx]) * inverted_kscale);

Copilot AI (Oct 14, 2025): Missing space before the assignment operator. Should be dst[dst_idx] = ck_tile::type_convert<cache_t>(.
}
}
};

copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}

} // namespace aiter

// KV_T is the stored data type of kv-cache.
@@ -1237,6 +1282,19 @@ void reshape_and_cache_flash(
ori_block_size); \
}

// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
aiter::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));

namespace aiter {

void reshape_and_cache_with_pertoken_quant(
@@ -1490,4 +1548,45 @@ void reshape_and_cache_with_block_quant_for_asm_pa(
TORCH_CHECK(false, "Unsupported data type of kv cache: ", key_cache.dtype());
}
}

void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);

TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
//TORCH_CHECK(kv_cache_dtype != "fp8");
Copilot AI (Oct 14, 2025): Commented-out code should be removed rather than left in the codebase. If this check is needed for future implementation, consider adding a TODO comment explaining why it's disabled.
Suggested change:
// TODO: Enable the following check if/when "fp8" support is implemented.
// TORCH_CHECK(kv_cache_dtype != "fp8");

int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(kv_c));
const hipStream_t stream = at::hip::getCurrentHIPStream();

//if (kv_cache_dtype == "fp8_ds_mla") {
// dim3 grid(num_tokens);
// // For the NoPE part, each tile of 128 elements is handled by half of one
// // warp (16 threads). There are 4 total tiles, so 2 warps (64 threads).
// // Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
// // The RoPE part (last 64 elements) is handled by another 1 warp (32
// // threads). So in total, we use 3 warps (96 threads) per block.
// dim3 block(96);
// DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
// CALL_CONCAT_AND_CACHE_DS_MLA);
//} else {
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
//}
Comment on lines +1574 to +1589
Copilot AI (Oct 14, 2025): Large blocks of commented-out code should be removed. If this functionality is planned for future implementation, consider using feature flags or moving it to a separate branch.
Suggested change: drop the commented-out "fp8_ds_mla" branch and keep only
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
}

} // namespace aiter
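
To clarify the addressing used by concat_and_cache_mla_kernel, here is a host-side Python sketch of the same logic (the helper name concat_and_cache_mla_ref is hypothetical, not part of this PR): each token's slot is split into (block_idx, block_offset), kv_c and k_pe are concatenated into one cache entry, padded tokens with slot < 0 are skipped, and the fp8 path divides by scale before casting.

import torch

def concat_and_cache_mla_ref(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale):
    # Reference for a contiguous [num_blocks, block_size, kv_lora_rank + pe_dim] cache.
    block_size = kv_cache.shape[1]
    for i, slot in enumerate(slot_mapping.tolist()):
        if slot < 0:  # padded token; the kernel returns early in this case too
            continue
        block_idx, block_offset = slot // block_size, slot % block_size
        entry = torch.cat([kv_c[i], k_pe[i]])
        if kv_cache_dtype == "fp8":
            entry = entry.float() / scale.item()
        kv_cache[block_idx, block_offset] = entry.to(kv_cache.dtype)
    return kv_cache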
204 changes: 204 additions & 0 deletions op_tests/test_concat_cache_mla.py
@@ -0,0 +1,204 @@
import torch
import aiter
from aiter.test_common import checkAllclose, perftest, benchmark
from aiter import dtypes
from typing import Tuple
import argparse
import itertools
import pandas as pd
import random
import time
from vllm import _custom_ops as ops


@perftest()
def run_aiter(
kv_c,
k_pe,
kv_cache,
slot_mapping,
kv_cache_dtype: str,
scale,
):
aiter.concat_and_cache_mla(
kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale
)
return kv_cache


@perftest(3)
def run_torch(
kv_c,
k_pe,
kv_cache,
slot_mapping,
kv_cache_dtype: str,
scale,
dtype,
):

block_size = kv_cache.shape[1]
num_tokens = kv_c.shape[0]
kv_lora_rank = kv_c.shape[-1]

for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[i]

if kv_cache_dtype == "fp8":
ref_kv_cache = (kv_cache.to(torch.float32) / scale.item()).to(dtype)
else:
ref_kv_cache = kv_cache
return ref_kv_cache


@benchmark()
def test_concat_and_cache_mla(
kv_lora_rank: int,
qk_rope_head_dim: int,
num_tokens: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
device: str,
kv_cache_dtype: str,
) -> None:
ret = {}
torch.set_default_device(device)

total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
entry_size = kv_lora_rank + qk_rope_head_dim

scale = torch.tensor(0.1, dtype=torch.float32, device=device)
cache_dtype = dtypes.fp8 if kv_cache_dtype == "fp8" else dtype
kv_cache = torch.zeros(
num_blocks, block_size, entry_size, dtype=cache_dtype, device=device
)

kv_cache, avg_us = run_aiter(
kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale
)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
ref_kv_cache, ref_us = run_torch(
kv_c, k_pe, ref_temp, slot_mapping, kv_cache_dtype, scale, kv_cache.dtype
)

if kv_cache_dtype == "fp8":
result_temp = kv_cache.to(torch.float32) * scale
expected_temp = ref_kv_cache.to(torch.float32) * scale
checkAllclose(result_temp, expected_temp, atol=0.01, rtol=0.01)
else:
checkAllclose(kv_cache, ref_kv_cache)
ret["aiter_us"] = avg_us
ret["torch_us"] = ref_us
return ret


df = []
kv_lora_rank = 128
qk_rope_head_dim = 64
l_num_tokens = [128, 256, 512, 1024, 2048, 4096] # 8192, 16384
block_size = 64
dtype = torch.bfloat16
device = "cuda"
l_kv_cache_dtypes = ["auto", "fp8"]

parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter,
description="config input of test",
)
parser.add_argument(
"-k",
"--kv_lora_rank",
type=int,
default=512,
help="""kv lora rank.
e.g.: -k 512""",
)
parser.add_argument(
"-qr",
"--qk_rope_head_dim",
type=int,
default=64,
help="""qk rope head dim.
e.g.: -qr 64""",
)

parser.add_argument(
"-blk",
"--block_size",
type=int,
default=64,
help="""Block size.
e.g.: -blk 1""",
)
parser.add_argument(
"-d",
"--dtype",
type=str,
choices=["fp16", "bf16"],
nargs="*",
default="bf16",
help="""Data type of input.
e.g.: -d bf16""",
)
parser.add_argument(
"-kvd",
"--kv_dtype",
type=str,
choices=["auto", "fp8"],
nargs="*",
default=["auto", "fp8"],
help="""Data type of KV cache.
e.g.: -kvd auto""",
)
parser.add_argument(
"-t",
"--token",
type=int,
nargs="*",
default=l_num_tokens,
help="""token nums.
e.g.: -t 128""",
)


args = parser.parse_args()
if args.dtype is not None:
dtype = dtypes.d_dtypes[args.dtype]
if args.token is not None:
l_num_tokens = args.token
if args.kv_dtype is not None:
l_kv_cache_dtypes = args.kv_dtype
if args.block_size is not None:
block_size = args.block_size
if args.qk_rope_head_dim is not None:
qk_rope_head_dim = args.qk_rope_head_dim
if args.kv_lora_rank is not None:
kv_lora_rank = args.kv_lora_rank

for num_token in l_num_tokens:
num_blocks = num_token // block_size
for kv_cache_dtype in l_kv_cache_dtypes:
ret = test_concat_and_cache_mla(
kv_lora_rank,
qk_rope_head_dim,
num_token,
block_size,
num_blocks,
dtype,
device,
kv_cache_dtype,
)

df.append(ret)
df = pd.DataFrame(df)
aiter.logger.info(f"summary:\n{df}")
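
When executed directly, the script sweeps the configured token counts and KV-cache dtypes and logs a pandas summary; a typical invocation using the flags defined above might be: python op_tests/test_concat_cache_mla.py -t 128 256 -kvd auto fp8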