15 changes: 10 additions & 5 deletions aiter/mla.py
@@ -162,6 +162,8 @@ def mla_decode_fwd(
q_scale=None,
kv_scale=None,
intra_batch_mode=False,
return_logits=False,
return_lse=False,
):
device = q.device
assert logit_cap <= 0, f"{logit_cap=} is not support yet"
@@ -271,7 +273,7 @@ def mla_decode_fwd(
):
# Natively support cases
pass
elif nhead in range(32, 128 + 1, 16) and persistent_mode and max_seqlen_q == 1:
elif nhead in range(32, 128 + 1, 16) and persistent_mode:
# we use nhead=16 to simulate such cases by customized metadata
# metadata also views qo's tensor as shape (total_s * (nhead // 16), 16, ...)
total_s = ori_total_s * (ori_nhead // 16)
@@ -292,7 +294,11 @@ def mla_decode_fwd(
dtype=dtypes.fp32,
device=device,
)
final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)
final_lse = (
torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)
if return_lse
else None
)

aiter.mla_decode_stage1_asm_fwd(
q,
@@ -326,10 +332,9 @@ def mla_decode_fwd(
)

if io_transformed:
if persistent_mode:
if return_logits:
logits = logits.view(-1, 1, ori_nhead, v_head_dim)
else:
logits = logits.view(ori_total_s, num_kv_splits, ori_nhead, v_head_dim)

q = q.view(ori_total_s, ori_nhead, -1)
o = o.view(ori_total_s, ori_nhead, -1)

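For readers skimming the Python change above: the new return_lse flag gates whether the (total_s, nhead) log-sum-exp buffer is allocated at all. A minimal standalone sketch of that pattern follows; the dtype and device handling are illustrative assumptions, not the exact mla_decode_fwd internals.

import torch
from typing import Optional

def maybe_alloc_lse(total_s: int, nhead: int, return_lse: bool,
                    device: torch.device) -> Optional[torch.Tensor]:
    # Mirror of the diff: only materialize the (total_s, nhead) LSE buffer when
    # the caller asks for it; otherwise return None so no extra memory is touched.
    if return_lse:
        return torch.empty((total_s, nhead), dtype=torch.float32, device=device)
    return None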
3 changes: 2 additions & 1 deletion aiter/ops/attention.py
@@ -629,7 +629,8 @@ def get_mla_metadata_info_v1(

max_qo_tiles_per_batch = (
int(math.ceil(max_seqlen_qo * num_head_qo / 128))
if num_head_qo == 16 or (num_head_qo == 128 and kv_dtype == dtypes.fp8)
if num_head_qo == 16
or (num_head_qo == 128 and kv_dtype == dtypes.fp8 and q_dtype == dtypes.fp8)
else int(math.ceil(max_seqlen_qo * num_head_qo / 16))
)
batch_size = batch_size * max_seqlen_qo if is_sparse else batch_size
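The condition change in get_mla_metadata_info_v1 above tightens when the 128-wide qo tiling applies: with 128 heads it now also requires the query dtype to be fp8, not just the kv cache. A small self-contained sketch of the arithmetic; the helper name and example values are illustrative, only the branch mirrors the diff.

import math

def max_qo_tiles_per_batch(max_seqlen_qo: int, num_head_qo: int,
                           q_is_fp8: bool, kv_is_fp8: bool) -> int:
    # 128-wide tiles: num_head_qo == 16, or 128 heads with fp8 q *and* fp8 kv.
    if num_head_qo == 16 or (num_head_qo == 128 and kv_is_fp8 and q_is_fp8):
        return math.ceil(max_seqlen_qo * num_head_qo / 128)
    # Everything else falls back to 16-wide tiles.
    return math.ceil(max_seqlen_qo * num_head_qo / 16)

For example, with num_head_qo=128 and max_seqlen_qo=2 this gives 2 tiles for fp8 q/kv but 16 tiles once the query side is bf16, which is the case the PR now routes to the 16-wide path.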
62 changes: 47 additions & 15 deletions csrc/kernels/mla/metadata/v1_2_device.cuh
@@ -28,12 +28,34 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
{
using QoState = QoState<Traits>;

const int32_t ori_seqlen_qo = [&]() {
if constexpr (Traits::kIsSparse)
{
return params.p_seqlens_qo_indptr[1] - params.p_seqlens_qo_indptr[0];
}
else
{
return params.ori_seqlen_qo;
}
}();

const int32_t num_batches = [&]() {
if constexpr (Traits::kIsSparse)
{
return params.num_batches * ori_seqlen_qo;
}
else
{
return params.num_batches;
}
}();

extern __shared__ uint8_t p_smem[];
int32_t* p_lds_seqlens_qo = reinterpret_cast<int32_t*>(p_smem);
int32_t* p_lds_seqlens_kv = p_lds_seqlens_qo + (QoState::is_unique() ? 0 : params.num_batches);
int32_t* p_lds_seqlens_kv = p_lds_seqlens_qo + (QoState::is_unique() ? 0 : num_batches);

QoState qo_state(
params.uni_seqlen_qo, params.ori_seqlen_qo, p_lds_seqlens_qo, params.p_seqlens_qo_indptr);
params.uni_seqlen_qo, ori_seqlen_qo, p_lds_seqlens_qo, params.p_seqlens_qo_indptr);

auto get_num_qo_tiles = [&](const int32_t batch_idx) {
if constexpr(Traits::kQoSplits)
@@ -53,10 +75,10 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
MlaWorkInfo* p_work_info_set = reinterpret_cast<MlaWorkInfo*>(params.p_work_info_set_raw);

int32_t sum_blocks = 0;
for(int32_t bid = lane_idx; bid < params.num_batches; bid += ck_tile::get_warp_size())
for(int32_t bid = lane_idx; bid < num_batches; bid += ck_tile::get_warp_size())
{
const int32_t bid_ori = Traits::kIsSparse
? (bid / params.ori_seqlen_qo / params.qk_batch_ratio)
? (bid / ori_seqlen_qo / params.qk_batch_ratio)
: (bid / params.qk_batch_ratio);
const int32_t kv_end = params.p_seqlens_kv_indptr[bid_ori + 1];
const int32_t seqlen_kv =
@@ -119,7 +141,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
for(int32_t cid = 0; cid < params.num_cu; ++cid)
{
int32_t remain_payload = payload;
while(curr_batch < params.num_batches)
while(curr_batch < num_batches)
{
const int32_t num_qo_tiles = get_num_qo_tiles(curr_batch);
const int32_t qo_tile_size =
@@ -143,9 +165,17 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
work_info.qo_end = ck_tile::min(work_info.qo_start + qo_tile_size,
qo_state.get_end(curr_batch));
work_info.kv_start = curr_kv_begin + (curr_kv_block * params.kv_granularity);
int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx);
if constexpr(!Traits::kIsSparse)
{
if (params.qk_batch_ratio != 1)
{
batch_tail = num_qo_tiles - (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - 1;
}
}
work_info.kv_end = ck_tile::min(
work_info.kv_start + (remain_kv_blocks * params.kv_granularity),
curr_kv_end - (num_qo_tiles - 1 - curr_qo_tile_idx));
curr_kv_end - batch_tail);
work_info.kv_offset = curr_kv_end - work_info.kv_end;

// split related info
@@ -202,7 +232,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
curr_sub_head_idx = (curr_sub_head_idx == (params.qk_batch_ratio - 1))
? 0
: (curr_sub_head_idx + 1);
if(curr_batch < params.num_batches)
if(curr_batch < num_batches)
{
if(curr_sub_head_idx == 0)
{
@@ -213,7 +243,7 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
else
{
const int32_t bid_ori = Traits::kIsSparse
? (curr_batch / params.ori_seqlen_qo /
? (curr_batch / ori_seqlen_qo /
params.qk_batch_ratio)
: (curr_batch / params.qk_batch_ratio);
curr_kv_seqlen = params.p_seqlens_kv_indptr[bid_ori + 1] -
@@ -251,9 +281,17 @@ __launch_bounds__(ck_tile::get_warp_size(), 1) __global__
qo_state.get_end(curr_batch));
work_info.kv_start =
curr_kv_begin + (curr_kv_block * params.kv_granularity);
int32_t batch_tail = (num_qo_tiles - 1 - curr_qo_tile_idx);
if constexpr(!Traits::kIsSparse)
{
if (params.qk_batch_ratio != 1)
{
batch_tail = num_qo_tiles - (work_info.qo_start / params.qk_batch_ratio) % ori_seqlen_qo - 1;
}
}
work_info.kv_end = ck_tile::min(
work_info.kv_start + (consuming_blks * params.kv_granularity),
curr_kv_end - (num_qo_tiles - 1 - curr_qo_tile_idx));
curr_kv_end - batch_tail);
work_info.kv_offset = curr_kv_end - work_info.kv_end;
work_info.partial_qo_loc = partial_idx;
p_work_info_set[num_works] = work_info;
@@ -365,12 +403,6 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
num_batches *= qk_batch_ratio;
}

if(is_sparse)
{
num_batches *= uni_seqlen_qo;
uni_seqlen_qo = 1;
}

TORCH_CHECK((num_heads == 16) || (num_heads == 128),
__func__,
": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
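To make the index arithmetic in the scheduler change easier to follow, here is a rough Python rendering of the kv_end clamping from the dense (non-sparse) path above. Names track the kernel variables, the surrounding persistent scheduling loop is omitted, and Python integer division stands in for the C++ semantics; this is a sketch of the logic, not the kernel itself.

def clamp_kv_end(kv_start: int, remain_kv_blocks: int, kv_granularity: int,
                 curr_kv_end: int, num_qo_tiles: int, curr_qo_tile_idx: int,
                 qo_start: int, qk_batch_ratio: int, ori_seqlen_qo: int,
                 is_sparse: bool) -> int:
    # Tail used before this PR: derived from the qo tile index within the batch.
    batch_tail = num_qo_tiles - 1 - curr_qo_tile_idx
    # New in this PR: in the dense path with fused batches (qk_batch_ratio != 1),
    # derive the tail from the qo position within the original sequence instead.
    if (not is_sparse) and qk_batch_ratio != 1:
        batch_tail = num_qo_tiles - (qo_start // qk_batch_ratio) % ori_seqlen_qo - 1
    return min(kv_start + remain_kv_blocks * kv_granularity,
               curr_kv_end - batch_tail)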
43 changes: 25 additions & 18 deletions op_tests/test_mla.py
@@ -19,6 +19,12 @@
# qdtype fp8, kdtype fp8: nhead16, nhead128


def check_support(dtype, kv_dtype, nhead):
if dtype == dtypes.fp8 and kv_dtype == dtypes.bf16:
return False
return True


def cal_diff(
x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False
) -> None:
@@ -445,11 +451,11 @@ def test_absorb_decode_fp8():

err = None
us_asm_decode = 1e12
if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and nhead in [16, 128]:
if dtype == torch.bfloat16 and nhead in [16, 128]:
err, us_asm_decode = test_absorb_decode_bf16()

elif kvtype == dtypes.fp8 and nhead in [16, 128]:
err, us_asm_decode = test_absorb_decode_fp8()

ret["decode:err"] = err
ret["decode:asm_576"] = us_asm_decode

@@ -599,22 +605,23 @@ def test_absorb_decode_fp8():
for dtype, kvtype, ctx_len, batch_size, split_per_batch in itertools.product(
list_dtype, l_kv_dtype, args.ctxLen, args.batchSize, args.split_per_batch
):
ret = test_mla(
ctx_len,
batch_size,
nhead,
args.kv_lora_rank,
args.qk_nope_head_dim,
args.qk_rope_head_dim,
args.v_head_dim,
dtype,
kvtype,
args.block_size,
varlen=args.varlen,
decode_qlen=decode_qlen,
split_per_batch=split_per_batch,
)
df.append(ret)
if check_support(dtype, kvtype, nhead):
ret = test_mla(
ctx_len,
batch_size,
nhead,
args.kv_lora_rank,
args.qk_nope_head_dim,
args.qk_rope_head_dim,
args.v_head_dim,
dtype,
kvtype,
args.block_size,
varlen=args.varlen,
decode_qlen=decode_qlen,
split_per_batch=split_per_batch,
)
df.append(ret)
df = pd.DataFrame(df)
# df.to_csv(f"mla_nhead{nhead}decode_qlen{decode_qlen}.csv")
aiter.logger.info(f"summary:\n{df}")
51 changes: 26 additions & 25 deletions op_tests/test_mla_persistent.py
@@ -18,6 +18,12 @@
# qdtype fp8, kdtype bf16: nhead16


def check_support(dtype, kv_dtype, nhead):
if dtype == dtypes.fp8 and kv_dtype == dtypes.bf16:
return False
return True


def cal_diff(
x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False
) -> None:
@@ -401,15 +407,9 @@ def test_absorb_decode_fp8():

err = None
us_asm_decode = 1e12
if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and (
(nhead in [16]) or (decode_qlen == 1 and nhead in range(32, 128 + 1, 16))
):
if dtype == torch.bfloat16:
err, us_asm_decode = test_absorb_decode_bf16()
elif kvtype == dtypes.fp8 and (
(dtype == dtypes.fp8 and nhead in [16, 128])
or (dtype == dtypes.bf16 and nhead in [16])
or (decode_qlen == 1 and nhead in range(32, 128 + 1, 16))
):
elif kvtype == dtypes.fp8:
err, us_asm_decode = test_absorb_decode_fp8()
ret["decode:err"] = err
ret["decode:asm_576"] = us_asm_decode
@@ -566,23 +566,24 @@ def test_absorb_decode_fp8():
for dtype, kvtype, ctx_len, batch_size, max_split_per_batch in itertools.product(
list_dtype, l_kv_dtype, args.ctxLen, args.batchSize, args.max_split_per_batch
):
ret = test_mla(
ctx_len,
batch_size,
nhead,
args.kv_lora_rank,
args.qk_nope_head_dim,
args.qk_rope_head_dim,
args.v_head_dim,
dtype,
kvtype,
args.block_size,
varlen=args.varlen,
decode_qlen=decode_qlen,
max_split_per_batch=max_split_per_batch,
non_persistent_mode=args.non_persistent_mode,
)
df.append(ret)
if check_support(dtype, kvtype, nhead):
ret = test_mla(
ctx_len,
batch_size,
nhead,
args.kv_lora_rank,
args.qk_nope_head_dim,
args.qk_rope_head_dim,
args.v_head_dim,
dtype,
kvtype,
args.block_size,
varlen=args.varlen,
decode_qlen=decode_qlen,
max_split_per_batch=max_split_per_batch,
non_persistent_mode=args.non_persistent_mode,
)
df.append(ret)
df = pd.DataFrame(df)
# df.to_csv(f"mla_nhead{nhead}decode_qlen{decode_qlen}.csv")
aiter.logger.info(f"summary:\n{df}")
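Both test drivers now wrap the itertools.product sweep with the new check_support gate, so the q-fp8 / kv-bf16 combination is skipped before test_mla is ever invoked. A tiny, purely hypothetical illustration of the gating pattern, using string stand-ins for the dtypes module:

import itertools

def check_support(dtype, kv_dtype, nhead):
    # Combination skipped by the new gate in both test drivers.
    if dtype == "fp8" and kv_dtype == "bf16":
        return False
    return True

for dtype, kvtype in itertools.product(["bf16", "fp8"], ["bf16", "fp8"]):
    if check_support(dtype, kvtype, nhead=16):
        print("would run test_mla for", dtype, kvtype)  # fp8/bf16 is filtered out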