
Conversation

@yzhou103
Contributor

Motivation

Technical Details

  1. Add the concat_and_cache_mla kernel function (CUDA kernel, host wrapper, Python binding, and tests).

Test Plan

python op_tests/test_concat_cache_mla.py

Test Result

Submission Checklist

Copilot AI review requested due to automatic review settings October 14, 2025 10:51
Contributor

Copilot AI left a comment


Pull Request Overview

This PR adds a new kernel function concat_and_cache_mla that concatenates KV components and caches them for MLA (Multi-head Latent Attention). The implementation includes CUDA kernel code, Python bindings, and comprehensive test coverage.

Key changes:

  • Implements a new CUDA kernel for concatenating KV components and caching them with optional FP8 quantization
  • Adds Python API bindings and function declarations
  • Provides a comprehensive test suite with performance benchmarking
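For orientation, here is a minimal PyTorch reference sketch of the semantics the kernel is expected to implement, based only on the host-side checks quoted later in this review (kv_cache shaped [num_blocks, block_size, kv_lora_rank + pe_dim], block_size = kv_cache.size(1), one block of work per token). The argument names, the negative-slot skip convention, and the omission of the FP8 path are illustrative assumptions, not the PR's exact API; see aiter/ops/cache.py for the real signature.

import torch

def concat_and_cache_mla_ref(kv_c, k_pe, kv_cache, slot_mapping):
    # Reference for the non-quantized ("auto") path only; the CUDA kernel
    # additionally supports scaling and type-converting each element for FP8.
    # kv_c:         [num_tokens, kv_lora_rank]   latent (NoPE) component
    # k_pe:         [num_tokens, pe_dim]         rotary (RoPE) component
    # kv_cache:     [num_blocks, block_size, kv_lora_rank + pe_dim]
    # slot_mapping: [num_tokens] flat slot index per token
    block_size = kv_cache.size(1)
    concat = torch.cat([kv_c, k_pe], dim=-1)  # [num_tokens, kv_lora_rank + pe_dim]
    for token, slot in enumerate(slot_mapping.tolist()):
        if slot < 0:
            continue  # assumed convention for padded/unmapped tokens
        kv_cache[slot // block_size, slot % block_size] = concat[token]
    return kv_cache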

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

File / Description
op_tests/test_concat_cache_mla.py - Comprehensive test suite with performance benchmarks for the new kernel
csrc/kernels/cache_kernels.cu - Core CUDA kernel implementation and host function for concat_and_cache_mla
csrc/include/rocm_ops.hpp - Python binding definitions for the new function
csrc/include/cache.h - Function declaration for concat_and_cache_mla
aiter/ops/cache.py - Python API wrapper for the new operation


if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
    dst[dst_idx] = src[src_idx];
} else {
    dst[dst_idx]= ck_tile::type_convert<cache_t>(

Copilot AI Oct 14, 2025


Missing space before the assignment operator. Should be dst[dst_idx] = ck_tile::type_convert<cache_t>(.

int block_size = kv_cache.size(1);

TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
//TORCH_CHECK(kv_cache_dtype != "fp8");

Copilot AI Oct 14, 2025


Commented-out code should be removed rather than left in the codebase. If this check is needed for future implementation, consider adding a TODO comment explaining why it's disabled.

Suggested change
-//TORCH_CHECK(kv_cache_dtype != "fp8");
+// TODO: Enable the following check if/when "fp8" support is implemented.
+// TORCH_CHECK(kv_cache_dtype != "fp8");

Comment on lines +1574 to +1589
//if (kv_cache_dtype == "fp8_ds_mla") {
// dim3 grid(num_tokens);
// // For the NoPE part, each tile of 128 elements is handled by half of one
// // warp (16 threads). There are 4 total tiles, so 2 warps (64 threads).
// // Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
// // The RoPE part (last 64 elements) is handled by another 1 warp (32
// // threads). So in total, we use 3 warps (96 threads) per block.
// dim3 block(96);
// DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
// CALL_CONCAT_AND_CACHE_DS_MLA);
//} else {
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
//}

Copilot AI Oct 14, 2025


Large blocks of commented-out code should be removed. If this functionality is planned for future implementation, consider using feature flags or moving it to a separate branch.

Suggested change
-//if (kv_cache_dtype == "fp8_ds_mla") {
-// dim3 grid(num_tokens);
-// // For the NoPE part, each tile of 128 elements is handled by half of one
-// // warp (16 threads). There are 4 total tiles, so 2 warps (64 threads).
-// // Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
-// // The RoPE part (last 64 elements) is handled by another 1 warp (32
-// // threads). So in total, we use 3 warps (96 threads) per block.
-// dim3 block(96);
-// DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
-// CALL_CONCAT_AND_CACHE_DS_MLA);
-//} else {
-dim3 grid(num_tokens);
-dim3 block(std::min(kv_lora_rank, 512));
-DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
-CALL_CONCAT_AND_CACHE_MLA);
-//}
+dim3 grid(num_tokens);
+dim3 block(std::min(kv_lora_rank, 512));
+DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
+CALL_CONCAT_AND_CACHE_MLA);
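The configuration that the suggestion keeps is easy to read off the host code: one thread block per token, with the block sized to the latent dimension and capped at 512 threads. A small Python sketch of that launch geometry (hypothetical helper name, for illustration only):

def mla_cache_launch_dims(num_tokens: int, kv_lora_rank: int):
    # Mirrors the active host code above:
    #   dim3 grid(num_tokens);                   -> one block per token
    #   dim3 block(std::min(kv_lora_rank, 512)); -> at most 512 threads per block
    # If kv_lora_rank exceeded 512, the kernel would presumably have each
    # thread cover several elements of the concatenated row in a strided loop.
    grid = (num_tokens,)
    block = (min(kv_lora_rank, 512),)
    return grid, block

# Example: 4096 tokens with a DeepSeek-style kv_lora_rank of 512
# mla_cache_launch_dims(4096, 512) -> ((4096,), (512,))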

@valarLip merged commit b8e5dde into main Oct 14, 2025
13 of 16 checks passed
@valarLip deleted the concat_cache_mla branch October 14, 2025 14:35
eliotwang pushed a commit to eliotwang/aiter that referenced this pull request Oct 21, 2025
* add concat_and_cache_mla kernel

* fix interface