Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
153 commits
Select commit Hold shift + click to select a range
fa2c2d2
add num_kv_splits_indptr to mla for mtp<=4 case for now
valarLip Jun 26, 2025
15f6155
update
valarLip Jun 27, 2025
8dd5617
update new kernel
valarLip Jul 1, 2025
c871e8d
infrastructures
ruanjm Jul 14, 2025
3750b5f
1st version of split kernel
ruanjm Jul 16, 2025
7ca2598
Fix issues raised by Lingpeng and fix the issue on batch_size
ruanjm Jul 16, 2025
7c5891c
update mla
valarLip Jul 16, 2025
12def78
update mla_stage2
valarLip Jul 18, 2025
5dc5a6d
Merge branch 'main' into mla_splitkv_enhance
valarLip Jul 18, 2025
eae14ae
Merge branch 'main' into mla_splitkv_enhance
valarLip Jul 18, 2025
f244f11
Merge branch 'mla_splitkv_enhance' into jruan/mla_splitkv_enhance_spl…
ruanjm Jul 22, 2025
224f89f
1st draft of v1 split program
ruanjm Jul 22, 2025
ef442fd
add kv_offset
ruanjm Jul 28, 2025
f10235e
mla_splitkv_enhance_split_alg_inte
Zzz9990 Jul 29, 2025
600b5dd
splitkv debug
Zzz9990 Jul 29, 2025
5c58ae8
1st version of reduce kernel
ruanjm Jul 29, 2025
9700bc5
metadata & kernel finish
Zzz9990 Jul 30, 2025
4a86304
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 30, 2025
d49c0cd
add reduce
Zzz9990 Jul 30, 2025
e4bf891
final_lse is optional now.
ruanjm Jul 30, 2025
7bf6aa4
update kernel
Zzz9990 Jul 30, 2025
2411f1f
bug fix
ruanjm Jul 30, 2025
e21600d
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 30, 2025
ffcc113
bug fix 1
ruanjm Jul 30, 2025
07e4ed1
modify reduce api
Zzz9990 Jul 30, 2025
3f2bf25
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 30, 2025
7c877c4
update kernel
Zzz9990 Jul 30, 2025
d10cdab
fix max splits
Zzz9990 Jul 30, 2025
bac5750
bug fix 3
ruanjm Jul 30, 2025
f59a3e6
fix s80 early return
Zzz9990 Jul 30, 2025
1ae58d1
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 30, 2025
5680c26
udpate calculation of partial_indx
ruanjm Jul 30, 2025
fa87c91
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 31, 2025
0dad74c
add per split test
Zzz9990 Jul 31, 2025
a8fa0b1
make lse support by ref
ruanjm Jul 31, 2025
56e964f
test split
Zzz9990 Jul 31, 2025
a76610a
fix redundant calculation of head offset in reduce kernel
ruanjm Jul 31, 2025
4ffd393
add custom test
Zzz9990 Jul 31, 2025
b3747df
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Jul 31, 2025
ba36541
Add support of 128 head size
ruanjm Jul 31, 2025
e5a1b17
update comments
ruanjm Aug 1, 2025
a68879c
1. Let large work be assigned first.
ruanjm Aug 1, 2025
7209c36
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Aug 4, 2025
4494b36
Calculate kv_limit dynamically
ruanjm Aug 4, 2025
09c4ca8
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Aug 4, 2025
1e5e71a
Fix bug about difference in split_kv(bool)
ruanjm Aug 4, 2025
f35cf04
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Aug 4, 2025
f7cf2b9
add test
Zzz9990 Aug 5, 2025
5b91267
fix seed
Zzz9990 Aug 5, 2025
59af206
Add global tolerance 16 in kv seqlen because main kernel cannot handl…
ruanjm Aug 5, 2025
e1b9065
Fix warp=1 error
ruanjm Aug 8, 2025
2adf050
Add redundant mode to make the size of output of metadata be fixed ad…
ruanjm Aug 8, 2025
c0df46b
Merge branch 'jruan/mla_splitkv_enhance_split_alg' into mla_splitkv_e…
Zzz9990 Aug 12, 2025
fbff664
fp8 setup
Zzz9990 Aug 12, 2025
1d36311
first version of device metadata
ruanjm Aug 12, 2025
4212a41
Add work_ptrs
ruanjm Aug 12, 2025
818229e
Compatibility to CUDA Graph
ruanjm Aug 13, 2025
704324a
Refactor code. Merge 2 iterations of generate work together.
ruanjm Aug 14, 2025
6be798a
Make sure that each batch of workload can never be splited to more th…
ruanjm Aug 14, 2025
1b0e26f
Adjust metadata. Get 1% perf gain.
ruanjm Aug 14, 2025
36e9b53
Paralize most of metadata kernel
ruanjm Aug 15, 2025
4403c82
add scale
Zzz9990 Aug 18, 2025
fcb36f0
1. Use warp-level bitonic sort to sort batch idx based on their cost …
ruanjm Aug 18, 2025
5dc1eb7
fp8 function pass
Zzz9990 Aug 19, 2025
b46a8e3
Fix issues:
ruanjm Aug 19, 2025
d8d92bc
fp8 ready
Zzz9990 Aug 19, 2025
ead163a
fix
Zzz9990 Aug 19, 2025
7fefc29
Merge remote-tracking branch 'origin/jruan/mla_splitkv_enhance_split_…
Zzz9990 Aug 19, 2025
cc7ffdc
persistent ready
Zzz9990 Aug 19, 2025
5e32d5d
add nv acc test
Zzz9990 Aug 21, 2025
a97fcf8
rename
Zzz9990 Sep 1, 2025
e0c72f8
updata metashape
Zzz9990 Sep 1, 2025
7220b04
update reduce cu num
Zzz9990 Sep 1, 2025
07bf6bb
update optest for mla
Zzz9990 Sep 1, 2025
3a7bd04
fix cu num
Zzz9990 Sep 1, 2025
88c8a0d
Update metadata and reduce kernels.
ruanjm Sep 1, 2025
7f86b0b
rename kernels
Zzz9990 Sep 1, 2025
018798d
Add new param kv_granularity to metadata kernel.
ruanjm Sep 2, 2025
3bf1623
Introduce cal_workload_limit_global_v2
ruanjm Sep 9, 2025
907dbed
Support qhead=128 cases.
ruanjm Sep 11, 2025
b2bed66
Change get_mla_metadata() api. Make some not important parameters be …
ruanjm Sep 12, 2025
a658ad8
fix potential problem on calculating tot_qo_tiles
ruanjm Sep 12, 2025
325e03f
refactor metadata files
ruanjm Sep 15, 2025
7072d90
update metadata v1_2
Zzz9990 Sep 18, 2025
851a888
update gqa_128 mla_ps & fix metadata v1_2
Zzz9990 Sep 18, 2025
b56eb25
Optimize mla metadata v1.2
ruanjm Sep 19, 2025
8ea8f73
Optimize mla metadata v1.2 Part.2
ruanjm Sep 19, 2025
9020ce8
Optimize mla metadata v1.2 Part.3
ruanjm Sep 19, 2025
59d8e33
update qlen <=4
Zzz9990 Sep 19, 2025
b401744
fix mla qlen1
Zzz9990 Sep 19, 2025
3bf8b2b
Optimize mla metadata v1.2 Part.4
ruanjm Sep 22, 2025
3f376b5
Make reduce_final_map be optional in mla_reduce_v1
ruanjm Sep 23, 2025
7c865a5
Slightly increase reduce perf
ruanjm Sep 23, 2025
8a17f56
Add persistent mode for mla reduce kernel
ruanjm Sep 24, 2025
75ebf74
add mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_ps.co
fangche123 Sep 28, 2025
3f67dbe
update deepseekv32 sparse mla metadata
Zzz9990 Oct 1, 2025
84e9616
update mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_ps.co
fangche123 Oct 9, 2025
ce9096f
Adjust code for sparse attn
ruanjm Oct 10, 2025
71abd03
Optimize the a16w8 kernel
fangche123 Oct 10, 2025
ebb2591
Improve metadata v1.1 perf
ruanjm Oct 10, 2025
9afba8f
Make metadata v1.1 support sparse attn
ruanjm Oct 10, 2025
2150d8f
Remove redundant code in mla_reduce
ruanjm Oct 11, 2025
363707f
futile struggle
ruanjm Oct 11, 2025
0b874cb
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
ruanjm Oct 13, 2025
ce9abd8
Fix issue after merge. aiter main branch is using torch.library.infer…
ruanjm Oct 13, 2025
64c3e29
Adjust metadata v1.1 and make this branch be ready to be merged to ma…
ruanjm Oct 14, 2025
57b9d57
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
ruanjm Oct 14, 2025
b70d8d4
remove invalid co kernel
Zzz9990 Oct 14, 2025
f668d60
Fix issue brought from f794ae4 which disabled hipify by default.
ruanjm Oct 14, 2025
33ea0e8
support qolen>1 for sparse mla
Zzz9990 Oct 14, 2025
6e2c4ff
make code become prettier
ruanjm Oct 14, 2025
c3813fb
Fix issue in metadata v1.1
ruanjm Oct 14, 2025
bcd219a
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
ruanjm Oct 15, 2025
33b0499
Fix issue in test_mla.py
ruanjm Oct 16, 2025
53f5826
Fix lint fails
ruanjm Oct 16, 2025
41576e1
Fix sub-test fails in op_test/test_mla.py
ruanjm Oct 16, 2025
68ef089
Fix regression in test_mla.py where mtp>1
ruanjm Oct 16, 2025
f7efe97
Add head_dim=128 support to reduce
ruanjm Oct 16, 2025
8440195
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
ruanjm Oct 17, 2025
1c5b77b
Add nhead=8 for pa and add assert to make sure the input tensors are in
ruanjm Oct 17, 2025
69d41a0
fix issue in vllm benchmark for deepseek: remove metadata v0 because …
ruanjm Oct 17, 2025
0cf3db2
fix lint
ruanjm Oct 17, 2025
ae96787
Revert all the change about mi350 gemm.
ruanjm Oct 17, 2025
be55ef5
add a8w8 and a16w8 kernel in mla mi350
fangche123 Oct 20, 2025
600d993
add A8W8 Non-persistent mode kernel
fangche123 Oct 21, 2025
6c7f795
Fix issue reported by Copilot
ruanjm Oct 22, 2025
573c3cd
add mla non-persistent test
fangche123 Oct 22, 2025
0cfc1a3
script: update a16w8 kernel
fangche123 Oct 23, 2025
0490f21
rm test_mla_persistent_mi350.py and support mi350 in test_mla_persist…
fangche123 Oct 24, 2025
8ca7679
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
valarLip Oct 24, 2025
526ddbc
add mla_a16w16_qh16_m16x4_n16x1_coex0_mask1_ps.co
minmengdie Oct 28, 2025
836e226
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
valarLip Oct 29, 2025
9d00e42
fix a8w8 num_kv_split=1
Zzz9990 Oct 29, 2025
2d47448
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
valarLip Oct 29, 2025
f5bb18a
Fix issue in metadata v1.2 on qo_tiles > 1
ruanjm Oct 31, 2025
bcb18b5
fix ut bandwidth
Zzz9990 Oct 31, 2025
cefebe9
Use nhead=16 simulate cases that nhead=16*N where N is in range(32,16…
ruanjm Oct 29, 2025
bdbb626
Add new api get_mla_metadata_info
ruanjm Oct 31, 2025
545189b
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
ruanjm Oct 31, 2025
7f95c68
fix lint format issues
ruanjm Oct 31, 2025
ce2098e
Adjust get_mla_metadata_info_v1's parameters.
ruanjm Oct 31, 2025
93b0171
update A16W8 kernel
fangche123 Oct 31, 2025
94800db
update A16W8 kernel2
fangche123 Oct 31, 2025
fd3c1d6
update A16W8 for mi300
fangche123 Nov 2, 2025
9e51724
fix ut and rename some kernels
Zzz9990 Nov 2, 2025
269a917
rename mla kernel name for head 128
fangche123 Nov 2, 2025
6cea09f
remove log
Zzz9990 Nov 2, 2025
d111de8
fix format
Zzz9990 Nov 2, 2025
a39e0c1
add nativly back
Zzz9990 Nov 2, 2025
e3a8834
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
valarLip Nov 2, 2025
4e7b7f4
Merge branch 'main' into mla_splitkv_enhance_split_alg_inte
Zzz9990 Nov 3, 2025
80526ac
change zeros into empty
Zzz9990 Nov 3, 2025
6ff53e2
fix with comments
Zzz9990 Nov 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aiter/dist/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
* Copyright © Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2024, The vLLM team.
* Copyright (C) Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2024-2025, The vLLM team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
27 changes: 27 additions & 0 deletions aiter/jit/optCompilerConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -946,5 +946,32 @@
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_mla_metadata": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/mla_metadata_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/mla/metadata.cu'",
"f'{AITER_CSRC_DIR}/kernels/mla/metadata/v1_comm.cuh'",
"f'{AITER_CSRC_DIR}/kernels/mla/metadata/v1_1_device.cuh'",
"f'{AITER_CSRC_DIR}/kernels/mla/metadata/v1_1_host.cuh'",
"f'{AITER_CSRC_DIR}/kernels/mla/metadata/v1_2_device.cuh'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_mla_reduce": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/mla_reduce_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/mla/reduce.cu'"],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
}
}
Loading