9 changes: 6 additions & 3 deletions aiter/mla.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

# user interface

@@ -235,7 +235,8 @@ def mla_decode_fwd(
)

if num_kv_splits == 1 and (
q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4) or (
q.dtype == dtypes.bf16 and kv_buffer.dtype == dtypes.bf16 and nhead in [32, 64])
):
return logits.view(total_s, nhead, v_head_dim), attn_lse

@@ -269,7 +270,8 @@ def mla_decode_fwd(
if num_kv_splits is None:
num_kv_splits = get_cu_num()
if nhead == 16 or (
nhead == 128 and q.dtype == dtypes.fp8 and kv_buffer.dtype == dtypes.fp8
nhead == 128 and q.dtype == dtypes.fp8 and kv_buffer.dtype == dtypes.fp8) or (
nhead in [32, 64] and q.dtype == dtypes.bf16 and kv_buffer.dtype == dtypes.bf16
):
        # Natively supported cases
pass
@@ -281,6 +283,7 @@
q = q.view(total_s, nhead, -1)
o = o.view(total_s, nhead, -1)
io_transformed = True
max_seqlen_q = 1
else:
assert False, f"{nhead=} and {max_seqlen_q=} not supported"

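The two mla.py hunks above widen the fused single-split fast path: besides fp8 queries and the existing bf16 case with max_seqlen_q == 4, bf16 q/kv with 32 or 64 heads now returns the kernel output directly, and the same head counts are treated as natively supported so they no longer fall back to the reshaped 16-head path. A minimal sketch of the widened predicate, with plain strings standing in for aiter's dtypes objects (an assumption for illustration only):

# Hedged sketch, not the aiter implementation: "fp8"/"bf16" strings stand in
# for dtypes.fp8 / dtypes.bf16.
def takes_single_split_fast_path(num_kv_splits, q_dtype, kv_dtype, nhead, max_seqlen_q):
    if num_kv_splits != 1:
        return False
    return (
        q_dtype == "fp8"
        or (q_dtype == "bf16" and max_seqlen_q == 4)
        # new in this change: bf16 q/kv with 32 or 64 heads
        or (q_dtype == "bf16" and kv_dtype == "bf16" and nhead in (32, 64))
    )

print(takes_single_split_fast_path(1, "bf16", "bf16", 64, 1))  # True after this change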
3 changes: 2 additions & 1 deletion aiter/ops/attention.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import math
from typing import Optional, Tuple
@@ -631,6 +631,7 @@ def get_mla_metadata_info_v1(
int(math.ceil(max_seqlen_qo * num_head_qo / 128))
if num_head_qo == 16
or (num_head_qo == 128 and kv_dtype == dtypes.fp8 and q_dtype == dtypes.fp8)
or (num_head_qo in [32, 64] and kv_dtype == dtypes.bf16 and q_dtype == dtypes.bf16)
else int(math.ceil(max_seqlen_qo * num_head_qo / 16))
)
batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
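The extra clause in get_mla_metadata_info_v1 keeps the metadata sizing consistent with the kernels: natively supported (num_head_qo, dtype) combinations pack 128 query-head rows per workgroup, everything else packs 16. A rough standalone version of that computation (dtype strings are illustrative stand-ins, not aiter's dtypes):

import math

def num_qo_tiles(max_seqlen_qo, num_head_qo, q_dtype, kv_dtype):
    # Mirrors the conditional shown above; hedged sketch only.
    native = (
        num_head_qo == 16
        or (num_head_qo == 128 and q_dtype == "fp8" and kv_dtype == "fp8")
        or (num_head_qo in (32, 64) and q_dtype == "bf16" and kv_dtype == "bf16")
    )
    packed_qo_len_per_wg = 128 if native else 16
    return math.ceil(max_seqlen_qo * num_head_qo / packed_qo_len_per_wg)

# 64 bf16 heads with a single query token now need one 128-wide tile instead of four 16-wide ones:
print(num_qo_tiles(1, 64, "bf16", "bf16"))  # 1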
44 changes: 27 additions & 17 deletions csrc/kernels/mla/metadata/v1_2_device.cuh
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.

#include "v1_comm.cuh"

@@ -371,8 +371,6 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
torch::Tensor& reduce_final_map,
torch::Tensor& reduce_partial_map)
{
constexpr int32_t kPackedQoLenPerWg = 128;

const hipStream_t stream = at::hip::getCurrentHIPStream();

hipDevice_t dev;
@@ -394,16 +392,25 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
(q_dtype == at::ScalarType::Float8_e4m3fnuz || q_dtype == at::ScalarType::Float8_e4m3fn);
const bool kv_is_fp8 =
(kv_dtype == at::ScalarType::Float8_e4m3fnuz || kv_dtype == at::ScalarType::Float8_e4m3fn);
const bool natively_supported =
(num_heads == 16) || ((num_heads == 128) && q_is_fp8 && kv_is_fp8);

const bool q_is_bf16 = q_dtype == at::ScalarType::BFloat16;
const bool kv_is_bf16 = kv_dtype == at::ScalarType::BFloat16;

const bool natively_supported = (num_heads == 16) ||
((num_heads == 128) && q_is_fp8 && kv_is_fp8) ||
((num_heads == 64) && q_is_bf16 && kv_is_bf16) ||
((num_heads == 32) && q_is_bf16 && kv_is_bf16);

if((natively_supported == false) && (num_heads % 16 == 0))
{
qk_batch_ratio = num_heads / 16;
num_heads = 16;
num_batches *= qk_batch_ratio;
}

TORCH_CHECK((num_heads == 16) || (num_heads == 128),
TORCH_CHECK((num_heads == 16) || (num_heads == 128) ||
((num_heads == 64) && q_is_bf16 && kv_is_bf16) ||
((num_heads == 32) && q_is_bf16 && kv_is_bf16),
__func__,
": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
"N is in [2, 8).")
@@ -435,15 +442,18 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
params.qk_batch_ratio = qk_batch_ratio;

// launch kernel
MLA_METADATA_DISPATCHER(
max_seqlen_qo * num_heads_per_head_k,
kPackedQoLenPerWg,
params.uni_seqlen_qo,
topk,
dispatch_mla_metadata_v1_2_device<kPackedQoLenPerWg, kQoSplits, kUniSeqlenQo, kIsSparse>(
params,
stream,
max_seqlen_qo,
dev_prop.warpSize,
dev_prop.maxSharedMemoryPerMultiProcessor));
MLA_NUM_HEADS_DISPATCHER(
num_heads_per_head_k,
MLA_METADATA_DISPATCHER(
max_seqlen_qo * num_heads_per_head_k,
kPackedQoLenPerWg,
params.uni_seqlen_qo,
topk,
dispatch_mla_metadata_v1_2_device<kPackedQoLenPerWg, kQoSplits, kUniSeqlenQo, kIsSparse>(
params,
stream,
max_seqlen_qo,
dev_prop.warpSize,
dev_prop.maxSharedMemoryPerMultiProcessor)));

}
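For head counts that are still not natively supported, the device metadata path keeps the existing fallback: any multiple of 16 is folded into the batch dimension and handled by the 16-head kernels. A small Python sketch of that bookkeeping (mirroring the qk_batch_ratio logic above, not the actual HIP code):

def fold_heads_into_batch(num_heads, num_batches, natively_supported):
    # Non-native head counts divisible by 16 are reshaped into extra batches.
    qk_batch_ratio = 1
    if not natively_supported and num_heads % 16 == 0:
        qk_batch_ratio = num_heads // 16
        num_heads = 16
        num_batches *= qk_batch_ratio
    return num_heads, num_batches, qk_batch_ratio

# 48 heads are not in the native list, so they run as 3 folded 16-head "batches":
print(fold_heads_into_batch(48, 4, natively_supported=False))  # (16, 12, 3)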
21 changes: 21 additions & 0 deletions csrc/kernels/mla/metadata/v1_comm.cuh
@@ -384,3 +384,24 @@ private:
MLA_UNI_SEQLEN_DISPATCHER((UNI_SEQLEN_QO), __VA_ARGS__); \
} \
}

#define MLA_NUM_HEADS_CASE(C_NUM_HEADS, ...) \
case C_NUM_HEADS: \
{ \
constexpr int32_t kPackedQoLenPerWg = C_NUM_HEADS; \
__VA_ARGS__; \
break; \
}

#define MLA_NUM_HEADS_DISPATCHER(NUM_HEADS, ...) \
switch (NUM_HEADS) \
{ \
MLA_NUM_HEADS_CASE(32, __VA_ARGS__); \
MLA_NUM_HEADS_CASE(64, __VA_ARGS__); \
default: \
{ \
constexpr int32_t kPackedQoLenPerWg = 128; \
__VA_ARGS__; \
break; \
} \
}
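The new MLA_NUM_HEADS_DISPATCHER replaces the previously hard-coded kPackedQoLenPerWg = 128: it switches on the runtime head count and instantiates the metadata kernel with a matching compile-time tile size for 32 and 64 heads, falling back to 128 otherwise. The selection it performs is roughly (a Python paraphrase for readability, not the constexpr dispatch itself):

def packed_qo_len_per_wg(num_heads_per_head_k: int) -> int:
    # 32 and 64 heads get matching tile sizes; everything else keeps the old 128.
    return {32: 32, 64: 64}.get(num_heads_per_head_k, 128)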
4 changes: 4 additions & 0 deletions csrc/kernels/mla/reduce.cu
@@ -606,6 +606,10 @@ __global__ void kn_mla_reduce_v1(
NUM_HEAD, 16, HEAD_DIM, 128, NUM_WG_PER_SEQ, NAME, __VA_ARGS__) \
MLA_REDUCE_CASE_EF( \
NUM_HEAD, 16, HEAD_DIM, 512, NUM_WG_PER_SEQ, NAME, __VA_ARGS__) \
MLA_REDUCE_CASE_EF( \
NUM_HEAD, 32, HEAD_DIM, 512, NUM_WG_PER_SEQ, NAME, __VA_ARGS__) \
MLA_REDUCE_CASE_EF( \
NUM_HEAD, 64, HEAD_DIM, 512, NUM_WG_PER_SEQ, NAME, __VA_ARGS__) \
MLA_REDUCE_CASE_EF( \
NUM_HEAD, 128, HEAD_DIM, 128, NUM_WG_PER_SEQ, NAME, __VA_ARGS__) \
MLA_REDUCE_CASE_EF( \
38 changes: 38 additions & 0 deletions csrc/py_itfs_cu/asm_mla.cu
@@ -171,6 +171,44 @@ void mla_decode_stage1_asm_fwd(
"/mla/mla_dec_stage1_bf16_a16w16_subQ128_mqa128.co");
impl_ptr = &impl_a16w16_bf16_subQ128;
}
else if(gqa_ratio == 32)
{
if(persistent)
{
sub_Q = 64;
static AiterAsmKernel impl_a16w16_bf16_subQ32(
"_ZN5aiter42mla_a16w16_qh16_m32x1_n16x1_coex0_mask1_psE",
"/mla/MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16_PS.co");
impl_ptr = &impl_a16w16_bf16_subQ32;
}
else
{
sub_Q = 64;
static AiterAsmKernel impl_a16w16_bf16_subQ32(
"_ZN5aiter39mla_a16w16_qh16_m32x1_n16x1_coex0_mask1E",
"/mla/MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co");
impl_ptr = &impl_a16w16_bf16_subQ32;
}
}
else if(gqa_ratio == 64)
{
if(persistent)
{
sub_Q = 64;
static AiterAsmKernel impl_a16w16_bf16_subQ64(
"_ZN5aiter42mla_a16w16_qh16_m64x1_n16x1_coex0_mask1_psE",
"/mla/MLA_A16W16_1TG_4W_64mx1_16nx1_Coex0_Msk1_QH16_PS.co");
impl_ptr = &impl_a16w16_bf16_subQ64;
}
else
{
sub_Q = 64;
static AiterAsmKernel impl_a16w16_bf16_subQ64(
"_ZN5aiter39mla_a16w16_qh16_m64x1_n16x1_coex0_mask1E",
"/mla/MLA_A16W16_1TG_4W_64mx1_16nx1_Coex0_Msk1_QH16.co");
impl_ptr = &impl_a16w16_bf16_subQ64;
}
}
else if(gqa_ratio == 16)
{
if(persistent)
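The asm_mla.cu additions follow the existing selection pattern: each new (gqa_ratio, persistent) combination lazily constructs a function-local static AiterAsmKernel bound to a prebuilt code object and points impl_ptr at it, with sub_Q = 64 in all four new branches. The mapping the new branches encode, shown as a plain lookup table (the NEW_BF16_KERNELS name is illustrative; the real code keeps C++ statics, and the paths are copied from the diff):

NEW_BF16_KERNELS = {
    # (gqa_ratio, persistent) -> prebuilt kernel object
    (32, True):  "/mla/MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16_PS.co",
    (32, False): "/mla/MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co",
    (64, True):  "/mla/MLA_A16W16_1TG_4W_64mx1_16nx1_Coex0_Msk1_QH16_PS.co",
    (64, False): "/mla/MLA_A16W16_1TG_4W_64mx1_16nx1_Coex0_Msk1_QH16.co",
}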
8 binary files not shown.
26 changes: 13 additions & 13 deletions op_tests/test_mla.py
@@ -216,11 +216,11 @@ def test_normal_prefill():
out_dtype = torch.bfloat16

us_aiter = None
if (
dtype == torch.bfloat16 and kvtype == torch.bfloat16
) and batch_size * ctx_lens * nhead < 256 * 8192 * 16:
us_aiter = test_normal_prefill()
ret["prefill:ck_192"] = us_aiter
# if (
# dtype == torch.bfloat16 and kvtype == torch.bfloat16
# ) and batch_size * ctx_lens * nhead < 256 * 8192 * 16:
# us_aiter = test_normal_prefill()
# ret["prefill:ck_192"] = us_aiter

torch.cuda.empty_cache()
# absorb init
@@ -304,11 +304,11 @@ def test_absorb_prefill():
return us_asm

us_asm = None
if (
dtype == torch.bfloat16 and kvtype == torch.bfloat16
) and batch_size * ctx_lens * nhead < 32 * 8192 * 16:
us_asm = test_absorb_prefill()
ret["prefill:asm_576"] = us_asm
# if (
# dtype == torch.bfloat16 and kvtype == torch.bfloat16
# ) and batch_size * ctx_lens * nhead < 32 * 8192 * 16:
# us_asm = test_absorb_prefill()
# ret["prefill:asm_576"] = us_asm

torch.cuda.empty_cache()

@@ -451,7 +451,7 @@ def test_absorb_decode_fp8():

err = None
us_asm_decode = 1e12
if dtype == torch.bfloat16 and nhead in [16, 128]:
if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and nhead in [16, 32, 64, 128]:
err, us_asm_decode = test_absorb_decode_bf16()
elif kvtype == dtypes.fp8 and nhead in [16, 128]:
err, us_asm_decode = test_absorb_decode_fp8()
@@ -481,7 +481,7 @@ def test_absorb_decode_fp8():
block_size = 1
list_dtype = ["bf16", "fp8"]
l_kv_dtype = ["bf16", "fp8"]
list_nhead = [(16, 1), (16, 2), (16, 4), (128, 1), (128, 2)]
list_nhead = [(16, 1), (16, 2), (16, 4), (32, 1), (64, 1), (128, 1), (128, 2)]

parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter,
Expand Down Expand Up @@ -561,7 +561,7 @@ def test_absorb_decode_fp8():
"--batchSize",
type=int,
nargs="*",
default=[1, 3, 5, 16, 32, 64, 128, 256],
default=[4, 6, 8, 12, 12, 18, 16, 24],
help="""Batch size.
e.g.: -b 16""",
)
16 changes: 11 additions & 5 deletions op_tests/test_mla_persistent.py
@@ -213,7 +213,9 @@ def test_mla(
max_seqlen_qo = seq_lens_qo.max().item()
qo_indptr[1 : batch_size + 1] = torch.cumsum(seq_lens_qo, dim=0)
total_q = qo_indptr[-1].item()
q = torch.randn((total_q, nhead, qk_head_dim), dtype=torch.bfloat16)
# q = torch.randn((total_q, nhead, qk_head_dim), dtype=torch.bfloat16)
q = torch.randn((total_q + 1, nhead, qk_head_dim), dtype=torch.bfloat16)
q = q[:-1]

    # torch implementation
out_ref, lse_ref = torch_mla_extend(
@@ -300,22 +302,28 @@ def test_mla(
dtype_q=dtype,
dtype_kv=kvtype,
)
q_scale = torch.ones([1], dtype=torch.float, device="cuda")
kv_scale = torch.ones([1], dtype=torch.float, device="cuda")


def test_absorb_decode_bf16():
kv_last_page_lens = torch.ones(batch_size, dtype=torch.int)
out_asm = torch.empty((total_q, nhead, v_head_dim), dtype=out_dtype).fill_(-1)

kv_buffer_cal = kv_buffer.to(kvtype)

(attn_logits, attn_lse), us_asm_decode = run_perftest(
aiter.mla.mla_decode_fwd,
q,
kv_buffer.view(num_page, page_size, nhead_kv, qk_head_dim),
kv_buffer_cal.view(num_page, page_size, nhead_kv, qk_head_dim),
out_asm,
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale,
kv_scale=kv_scale,
num_kv_splits=max_split_per_batch,
work_meta_data=work_meta_data,
work_indptr=work_indptr,
Expand All @@ -336,17 +344,15 @@ def test_absorb_decode_bf16():
out_asm,
msg=f"mla_decode-absorb [golden vs aiter_asm]: {us_asm_decode:>8.2f} us......",
)

return err, us_asm_decode

def test_absorb_decode_fp8():
kv_last_page_lens = torch.ones(batch_size, dtype=torch.int)
out_asm = torch.empty((total_q, nhead, v_head_dim), dtype=out_dtype).fill_(-1)

q_fp8 = q.to(dtypes.fp8)
q_scale = torch.ones([1], dtype=torch.float, device="cuda")

kv_buffer_fp8 = kv_buffer.to(dtypes.fp8)
kv_scale = torch.ones([1], dtype=torch.float, device="cuda")

out_ref_fp8, lse_ref_fp8 = torch_mla_extend(
q_fp8 if dtype == dtypes.fp8 else q,
11 changes: 7 additions & 4 deletions op_tests/test_mla_sparse.py
@@ -464,6 +464,7 @@ def test_mla(
dtype_q=dtype,
dtype_kv=kvtype,
)
# import pdb;pdb.set_trace()

# generate kv topk per token & convert indices into per token
token_indices = generate_topk_kv(kv_indptr, decode_qlen)
@@ -493,11 +494,15 @@ def test_mla(
is_causal=False,
dtype=out_dtype,
)
q_scale = torch.ones([1], dtype=torch.float, device="cuda")
kv_scale = torch.ones([1], dtype=torch.float, device="cuda")

def test_sparse_mla_bf16():
kv_last_page_lens = torch.ones(batch_size, dtype=torch.int)
out_asm = torch.empty((total_q, nhead, v_head_dim), dtype=out_dtype).fill_(-1)

kv_buffer_cal = kv_buffer.to(kvtype)

(attn_logits, attn_lse), us_asm_decode = run_perftest(
aiter.mla.mla_decode_fwd,
q,
Expand All @@ -510,6 +515,7 @@ def test_sparse_mla_bf16():
kv_last_page_lens,
1,
sm_scale,
kv_scale=kv_scale,
num_kv_splits=max_split_per_batch,
work_meta_data=work_meta_data,
work_indptr=work_indptr,
@@ -540,10 +546,7 @@ def test_sparse_mla_fp8():
out_asm = torch.empty((total_q, nhead, v_head_dim), dtype=out_dtype).fill_(-1)

q_fp8 = q.to(dtypes.fp8)
q_scale = torch.ones([1], dtype=torch.float, device="cuda")

kv_buffer_fp8 = kv_buffer.to(kvtype)
kv_scale = torch.ones([1], dtype=torch.float, device="cuda")

out_ref_fp8, lse_ref_fp8 = torch_mla_extend(
q_fp8 if dtype == dtypes.fp8 else q,
@@ -632,7 +635,7 @@ def test_sparse_mla_fp8():
block_size = 1
list_dtype = ["bf16", "fp8"]
l_kv_dtype = ["bf16", "fp8"]
list_nhead = [(16, 2), (48, 1), (128, 2)]
list_nhead = [(16, 2), (48, 1), (64, 2), (128, 2)]

parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter,