-
Notifications
You must be signed in to change notification settings - Fork 171
add concat_and_cache_mla kernel #1194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -973,6 +973,51 @@ __global__ void reshape_and_cache_with_block_quant_kernel_for_asmpa( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template <typename scalar_t, typename cache_t, vllm::Fp8KVCacheDataType kv_dt> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| __global__ void concat_and_cache_mla_kernel( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // + pe_dim)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int64_t* __restrict__ slot_mapping, // [num_tokens] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int block_stride, // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int entry_stride, // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int kv_c_stride, // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int k_pe_stride, // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int kv_lora_rank, // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int pe_dim, // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int block_size, // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const float* scale // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int64_t token_idx = blockIdx.x; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int64_t slot_idx = slot_mapping[token_idx]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // NOTE: slot_idx can be -1 if the token is padded | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (slot_idx < 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int64_t block_idx = slot_idx / block_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int64_t block_offset = slot_idx % block_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const float inverted_kscale = 1.0f / *scale; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int src_stride, int dst_stride, int size, int offset) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = threadIdx.x; i < size; i += blockDim.x) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int64_t src_idx = token_idx * src_stride + i; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int64_t dst_idx = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_idx * block_stride + block_offset * entry_stride + i + offset; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dst[dst_idx] = src[src_idx]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dst[dst_idx]= ck_tile::type_convert<cache_t>( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ck_tile::type_convert<float>(src[src_idx]) * inverted_kscale); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace aiter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // KV_T is the stored data type of kv-cache. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1237,6 +1282,19 @@ void reshape_and_cache_flash( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ori_block_size); \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // KV_T is the data type of key and value tensors. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // CACHE_T is the stored data type of kv-cache. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // KV_DTYPE is the real data type of kv-cache. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| aiter::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| <<<grid, block, 0, stream>>>( \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reinterpret_cast<KV_T*>(kv_c.data_ptr()), \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reinterpret_cast<KV_T*>(k_pe.data_ptr()), \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reinterpret_cast<const float*>(scale.data_ptr())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace aiter { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void reshape_and_cache_with_pertoken_quant( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1490,4 +1548,45 @@ void reshape_and_cache_with_block_quant_for_asm_pa( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TORCH_CHECK(false, "Unsupported data type of kv cache: ", key_cache.dtype()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void concat_and_cache_mla( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor& k_pe, // [num_tokens, pe_dim] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // pe_dim)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const std::string& kv_cache_dtype, torch::Tensor& scale) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int num_tokens = slot_mapping.size(0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int kv_lora_rank = kv_c.size(1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int pe_dim = k_pe.size(1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int block_size = kv_cache.size(1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| //TORCH_CHECK(kv_cache_dtype != "fp8"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| //TORCH_CHECK(kv_cache_dtype != "fp8"); | |
| // TODO: Enable the following check if/when "fp8" support is implemented. | |
| // TORCH_CHECK(kv_cache_dtype != "fp8"); |
Copilot
AI
Oct 14, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Large blocks of commented-out code should be removed. If this functionality is planned for future implementation, consider using feature flags or moving it to a separate branch.
| //if (kv_cache_dtype == "fp8_ds_mla") { | |
| // dim3 grid(num_tokens); | |
| // // For the NoPE part, each tile of 128 elements is handled by half of one | |
| // // warp (16 threads). There are 4 total tiles, so 2 warps (64 threads). | |
| // // Lanes 0 and 16 of each warp write the scale values for that warp's tiles. | |
| // // The RoPE part (last 64 elements) is handled by another 1 warp (32 | |
| // // threads). So in total, we use 3 warps (96 threads) per block. | |
| // dim3 block(96); | |
| // DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, | |
| // CALL_CONCAT_AND_CACHE_DS_MLA); | |
| //} else { | |
| dim3 grid(num_tokens); | |
| dim3 block(std::min(kv_lora_rank, 512)); | |
| DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, | |
| CALL_CONCAT_AND_CACHE_MLA); | |
| //} | |
| dim3 grid(num_tokens); | |
| dim3 block(std::min(kv_lora_rank, 512)); | |
| DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, | |
| CALL_CONCAT_AND_CACHE_MLA); |
Copilot
AI
Oct 14, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Large blocks of commented-out code should be removed. If this functionality is planned for future implementation, consider using feature flags or moving it to a separate branch.
| //if (kv_cache_dtype == "fp8_ds_mla") { | |
| // dim3 grid(num_tokens); | |
| // // For the NoPE part, each tile of 128 elements is handled by half of one | |
| // // warp (16 threads). There are 4 total tiles, so 2 warps (64 threads). | |
| // // Lanes 0 and 16 of each warp write the scale values for that warp's tiles. | |
| // // The RoPE part (last 64 elements) is handled by another 1 warp (32 | |
| // // threads). So in total, we use 3 warps (96 threads) per block. | |
| // dim3 block(96); | |
| // DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, | |
| // CALL_CONCAT_AND_CACHE_DS_MLA); | |
| //} else { | |
| dim3 grid(num_tokens); | |
| dim3 block(std::min(kv_lora_rank, 512)); | |
| DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, | |
| CALL_CONCAT_AND_CACHE_MLA); | |
| //} | |
| dim3 grid(num_tokens); | |
| dim3 block(std::min(kv_lora_rank, 512)); | |
| DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, | |
| CALL_CONCAT_AND_CACHE_MLA); |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,204 @@ | ||
| import torch | ||
| import aiter | ||
| from aiter.test_common import checkAllclose, perftest, benchmark | ||
| from aiter import dtypes | ||
| from typing import Tuple | ||
| import argparse | ||
| import itertools | ||
| import pandas as pd | ||
| import random | ||
| import time | ||
| from vllm import _custom_ops as ops | ||
|
|
||
|
|
||
| @perftest() | ||
| def run_aiter( | ||
| kv_c, | ||
| k_pe, | ||
| kv_cache, | ||
| slot_mapping, | ||
| kv_cache_dtype: str, | ||
| scale, | ||
| ): | ||
| aiter.concat_and_cache_mla( | ||
| kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale | ||
| ) | ||
| return kv_cache | ||
|
|
||
|
|
||
| @perftest(3) | ||
| def run_torch( | ||
| kv_c, | ||
| k_pe, | ||
| kv_cache, | ||
| slot_mapping, | ||
| kv_cache_dtype: str, | ||
| scale, | ||
| dtype, | ||
| ): | ||
|
|
||
| block_size = kv_cache.shape[1] | ||
| num_tokens = kv_c.shape[0] | ||
| kv_lora_rank = kv_c.shape[-1] | ||
|
|
||
| for i in range(num_tokens): | ||
| slot = slot_mapping[i].item() | ||
| block_idx = slot // block_size | ||
| block_offset = slot % block_size | ||
| kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[i] | ||
| kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[i] | ||
|
|
||
| if kv_cache_dtype == "fp8": | ||
| ref_kv_cache = (kv_cache.to(torch.float32) / scale.item()).to(dtype) | ||
| else: | ||
| ref_kv_cache = kv_cache | ||
| return ref_kv_cache | ||
|
|
||
|
|
||
| @benchmark() | ||
| def test_concat_and_cache_mla( | ||
| kv_lora_rank: int, | ||
| qk_rope_head_dim: int, | ||
| num_tokens: int, | ||
| block_size: int, | ||
| num_blocks: int, | ||
| dtype: torch.dtype, | ||
| device: str, | ||
| kv_cache_dtype: str, | ||
| ) -> None: | ||
| ret = {} | ||
| torch.set_default_device(device) | ||
|
|
||
| total_slots = num_blocks * block_size | ||
| slot_mapping_lst = random.sample(range(total_slots), num_tokens) | ||
| slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) | ||
|
|
||
| kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) | ||
| k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device) | ||
| entry_size = kv_lora_rank + qk_rope_head_dim | ||
|
|
||
| scale = torch.tensor(0.1, dtype=torch.float32, device=device) | ||
| cache_dtype = dtypes.fp8 if kv_cache_dtype == "fp8" else dtype | ||
| kv_cache = torch.zeros( | ||
| num_blocks, block_size, entry_size, dtype=cache_dtype, device=device | ||
| ) | ||
|
|
||
| kv_cache, avg_us = run_aiter( | ||
| kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale | ||
| ) | ||
| ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) | ||
| ref_kv_cache, ref_us = run_torch( | ||
| kv_c, k_pe, ref_temp, slot_mapping, kv_cache_dtype, scale, kv_cache.dtype | ||
| ) | ||
|
|
||
| if kv_cache_dtype == "fp8": | ||
| result_temp = kv_cache.to(torch.float32) * scale | ||
| expected_temp = ref_kv_cache.to(torch.float32) * scale | ||
| checkAllclose(result_temp, expected_temp, atol=0.01, rtol=0.01) | ||
| else: | ||
| checkAllclose(kv_cache, ref_kv_cache) | ||
| ret["aiter_us"] = avg_us | ||
| ret["torch_us"] = ref_us | ||
| return ret | ||
|
|
||
|
|
||
| df = [] | ||
| kv_lora_rank = 128 | ||
| qk_rope_head_dim = 64 | ||
| l_num_tokens = [128, 256, 512, 1024, 2048, 4096] # 8192, 16384 | ||
| block_size = 64 | ||
| dtype = torch.bfloat16 | ||
| device = "cuda" | ||
| l_kv_cache_dtypes = ["auto", "fp8"] | ||
|
|
||
| parser = argparse.ArgumentParser( | ||
| formatter_class=argparse.RawTextHelpFormatter, | ||
| description="config input of test", | ||
| ) | ||
| parser.add_argument( | ||
| "-k", | ||
| "--kv_lora_rank", | ||
| type=int, | ||
| default=512, | ||
| help="""kv lora rank. | ||
| e.g.: -k 512""", | ||
| ) | ||
| parser.add_argument( | ||
| "-qr", | ||
| "--qk_rope_head_dim", | ||
| type=int, | ||
| default=64, | ||
| help="""qk rope head dim. | ||
| e.g.: -qr 64""", | ||
| ) | ||
|
|
||
| parser.add_argument( | ||
| "-blk", | ||
| "--block_size", | ||
| type=int, | ||
| default=64, | ||
| help="""Block size. | ||
| e.g.: -blk 1""", | ||
| ) | ||
| parser.add_argument( | ||
| "-d", | ||
| "--dtype", | ||
| type=str, | ||
| choices=["fp16", "bf16"], | ||
| nargs="*", | ||
| default="bf16", | ||
| help="""Data type of input. | ||
| e.g.: -d bf16""", | ||
| ) | ||
| parser.add_argument( | ||
| "-kvd", | ||
| "--kv_dtype", | ||
| type=str, | ||
| choices=["auto", "fp8"], | ||
| nargs="*", | ||
| default=["auto", "fp8"], | ||
| help="""Data type of KV cache. | ||
| e.g.: -kvd auto""", | ||
| ) | ||
| parser.add_argument( | ||
| "-t", | ||
| "--token", | ||
| type=int, | ||
| nargs="*", | ||
| default=l_num_tokens, | ||
| help="""token nums. | ||
| e.g.: -t 128""", | ||
| ) | ||
|
|
||
|
|
||
| args = parser.parse_args() | ||
| if args.dtype is not None: | ||
| dtype = dtypes.d_dtypes[args.dtype] | ||
| if args.token is not None: | ||
| l_num_tokens = args.token | ||
| if args.kv_dtype is not None: | ||
| l_kv_cache_dtypes = args.kv_dtype | ||
| if args.block_size is not None: | ||
| block_size = args.block_size | ||
| if args.qk_rope_head_dim is not None: | ||
| qk_rope_head_dim = args.qk_rope_head_dim | ||
| if args.kv_lora_rank is not None: | ||
| kv_lora_rank = args.kv_lora_rank | ||
|
|
||
| for num_token in l_num_tokens: | ||
| num_blocks = num_token // block_size | ||
| for kv_cache_dtype in l_kv_cache_dtypes: | ||
| ret = test_concat_and_cache_mla( | ||
| kv_lora_rank, | ||
| qk_rope_head_dim, | ||
| num_token, | ||
| block_size, | ||
| num_blocks, | ||
| dtype, | ||
| device, | ||
| kv_cache_dtype, | ||
| ) | ||
|
|
||
| df.append(ret) | ||
| df = pd.DataFrame(df) | ||
| aiter.logger.info(f"summary:\n{df}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing space before the assignment operator. Should be
dst[dst_idx] = ck_tile::type_convert<cache_t>(.