add concat_and_cache_mla kernel #1194
Conversation
Pull Request Overview
This PR adds a new kernel function, concat_and_cache_mla, that concatenates KV components and caches them for MLA (Multi-head Latent Attention). The implementation includes the CUDA kernel, Python bindings, and comprehensive test coverage.
Key changes:
- Implements a new CUDA kernel for concatenating KV components and caching them with optional FP8 quantization (see the reference sketch after this list)
- Adds Python API bindings and function declarations
- Provides comprehensive test suite with performance benchmarking
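For orientation, here is a minimal PyTorch sketch of the semantics the kernel implements on the non-quantized (kAuto) path. It is an illustration, not the PR's code: the names k_pe and slot_mapping and the [num_blocks, block_size, kv_lora_rank + pe_dim] cache layout are assumptions inferred from the shape checks visible in the diff.

```python
import torch

def concat_and_cache_mla_reference(
    kv_c: torch.Tensor,         # [num_tokens, kv_lora_rank]  latent / NoPE part
    k_pe: torch.Tensor,         # [num_tokens, pe_dim]        RoPE part -- name assumed
    kv_cache: torch.Tensor,     # [num_blocks, block_size, kv_lora_rank + pe_dim]
    slot_mapping: torch.Tensor, # [num_tokens] flat slot index per token -- name assumed
) -> None:
    """Pure-PyTorch sketch of the kAuto (no FP8 quantization) path."""
    block_size = kv_cache.size(1)
    for token_idx in range(kv_c.size(0)):
        slot = int(slot_mapping[token_idx])
        if slot < 0:  # padded tokens are conventionally skipped (assumption)
            continue
        block_idx = slot // block_size
        block_offset = slot % block_size
        # Concatenate the latent and RoPE parts and write them into the cache slot.
        kv_cache[block_idx, block_offset] = torch.cat(
            (kv_c[token_idx], k_pe[token_idx]), dim=-1
        ).to(kv_cache.dtype)
```

The FP8 path instead converts each element to the cache dtype via ck_tile::type_convert, as seen in the first review comment below.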
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| op_tests/test_concat_cache_mla.py | Comprehensive test suite with performance benchmarks for the new kernel |
| csrc/kernels/cache_kernels.cu | Core CUDA kernel implementation and host function for concat_and_cache_mla |
| csrc/include/rocm_ops.hpp | Python binding definitions for the new function |
| csrc/include/cache.h | Function declaration for concat_and_cache_mla |
| aiter/ops/cache.py | Python API wrapper for the new operation (usage sketched below) |
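A usage sketch of the wrapper follows. The import path, argument order, the "auto" dtype string, and the scale argument are assumptions based on the file list above and the host-function arguments visible in the diff, not the confirmed signature.

```python
import torch
# Import path assumed from the file layout (aiter/ops/cache.py).
from aiter.ops.cache import concat_and_cache_mla

num_tokens, kv_lora_rank, pe_dim = 32, 512, 64
num_blocks, block_size = 8, 16

kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn(num_tokens, pe_dim, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.zeros(num_blocks, block_size, kv_lora_rank + pe_dim,
                       dtype=torch.bfloat16, device="cuda")
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")
scale = torch.ones(1, dtype=torch.float32, device="cuda")  # assumed; used by the FP8 path

# Argument order mirrors the CUDA host function seen in the diff (assumption);
# "auto" selects the pass-through (non-FP8) path.
concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, "auto", scale)
```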
    if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
        dst[dst_idx] = src[src_idx];
    } else {
        dst[dst_idx]= ck_tile::type_convert<cache_t>(
Copilot AI · Oct 14, 2025
Missing space before the assignment operator. Should be dst[dst_idx] = ck_tile::type_convert<cache_t>(.
    int block_size = kv_cache.size(1);

    TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
    //TORCH_CHECK(kv_cache_dtype != "fp8");
Copilot AI · Oct 14, 2025
Commented-out code should be removed rather than left in the codebase. If this check is needed for future implementation, consider adding a TODO comment explaining why it's disabled.
Suggested change (replace the commented-out check with):

    // TODO: Enable the following check if/when "fp8" support is implemented.
    // TORCH_CHECK(kv_cache_dtype != "fp8");
    //if (kv_cache_dtype == "fp8_ds_mla") {
    //  dim3 grid(num_tokens);
    //  // For the NoPE part, each tile of 128 elements is handled by half of one
    //  // warp (16 threads). There are 4 total tiles, so 2 warps (64 threads).
    //  // Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
    //  // The RoPE part (last 64 elements) is handled by another 1 warp (32
    //  // threads). So in total, we use 3 warps (96 threads) per block.
    //  dim3 block(96);
    //  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
    //                             CALL_CONCAT_AND_CACHE_DS_MLA);
    //} else {
    dim3 grid(num_tokens);
    dim3 block(std::min(kv_lora_rank, 512));
    DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                               CALL_CONCAT_AND_CACHE_MLA);
    //}
Copilot AI · Oct 14, 2025
Large blocks of commented-out code should be removed. If this functionality is planned for future implementation, consider using feature flags or moving it to a separate branch.
Suggested change (drop the commented-out fp8_ds_mla block and keep only the active launch configuration):

    dim3 grid(num_tokens);
    dim3 block(std::min(kv_lora_rank, 512));
    DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                               CALL_CONCAT_AND_CACHE_MLA);
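Putting numbers to the launch configuration above: the active path launches one block per token with up to 512 threads covering the latent dimension, while the disabled fp8_ds_mla path budgets 96 threads per block (4 NoPE tiles × 16 threads plus one 32-thread warp for the RoPE part). A small illustrative calculation, not part of the PR:

```python
def mla_launch_config(num_tokens: int, kv_lora_rank: int) -> tuple[int, int]:
    """Grid/block sizing used by the active path: one block per token,
    one thread per latent element, capped at 512 threads per block."""
    grid = num_tokens
    block = min(kv_lora_rank, 512)
    return grid, block

# Disabled fp8_ds_mla path (per the commented-out code above):
# 4 tiles of 128 NoPE elements, each handled by 16 threads (half a warp),
# plus one 32-thread warp for the 64-element RoPE part.
ds_mla_threads = 4 * 16 + 32
assert ds_mla_threads == 96

print(mla_launch_config(num_tokens=32, kv_lora_rank=512))  # (32, 512)
```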
* add concat_and_cache_mla kernel
* fix interface
Motivation
Technical Details
Test Plan
python op_tests/test_concat_cache_mla.py
Test Result
Submission Checklist