Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4c4f729
add sparse attn. kernels
cagrikymk Oct 29, 2025
b10497f
formatting
cagrikymk Oct 29, 2025
afdfd81
add comments
cagrikymk Oct 29, 2025
bed26a4
formatting
cagrikymk Oct 29, 2025
de49d68
comment changes
cagrikymk Oct 29, 2025
02079a7
fp8 mqa kernel optimizations
cagrikymk Nov 5, 2025
069118d
formatting
cagrikymk Nov 5, 2025
205ba1c
fp8 mqa updates
cagrikymk Nov 5, 2025
53c8017
Merge branch 'main' into cagri/sparse_mla_attn_upstream
cagrikymk Nov 5, 2025
6a89d5d
renaming
cagrikymk Nov 5, 2025
57ea55e
formatting
cagrikymk Nov 5, 2025
d51e75a
fix comments
cagrikymk Nov 6, 2025
da77e58
Merge branch 'main' into cagri/sparse_mla_attn_upstream
cagrikymk Nov 7, 2025
3216d3d
nonkdim change for better perf
cagrikymk Nov 7, 2025
0f0253c
add mqa bench
cagrikymk Nov 12, 2025
d74b829
help msgs
cagrikymk Nov 12, 2025
ce23e9e
update bench code
cagrikymk Nov 12, 2025
ba72f16
formatting
cagrikymk Nov 12, 2025
4f637e6
more optimizations
cagrikymk Nov 12, 2025
67d470a
Merge branch 'main' into cagri/sparse_mla_attn_upstream
cagrikymk Nov 13, 2025
e632edd
cleanup
cagrikymk Nov 13, 2025
70b58f7
Merge branch 'main' into cagri/sparse_mla_attn_upstream
cagrikymk Nov 17, 2025
19ed0d1
loop bound changes
cagrikymk Nov 17, 2025
bfe4cb1
kernel name fix
cagrikymk Nov 17, 2025
33f0541
formatting changes
cagrikymk Nov 17, 2025
9792c1a
Merge branch 'main' into cagri/sparse_mla_attn_upstream
cagrikymk Nov 17, 2025
5233364
mqa bench fix and formatting
cagrikymk Nov 18, 2025
3952c05
Merge branch 'main' into cagri/sparse_mla_attn_upstream
cagrikymk Nov 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 28 additions & 25 deletions aiter/ops/triton/_triton_kernels/fp8_mqa_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def _fp8_mqa_logits_kernel(
BLOCK_KV: tl.constexpr,
):
row_id = tl.program_id(0)

# go from larger to smaller in terms of work
# to reduce the tail effect
row_id = tl.num_programs(0) - row_id - 1
tl.assume(row_id >= 0)
tl.assume(stride_q_s > 0)
tl.assume(stride_q_h > 0)
Expand All @@ -39,6 +41,8 @@ def _fp8_mqa_logits_kernel(
tl.assume(stride_w_s > 0)
tl.assume(stride_w_h > 0)

logits_row_ptrs = logits_ptr + row_id * stride_logits_s

h_inds = tl.arange(0, NUM_HEADS)[:, None]
d_inds = tl.arange(0, HEAD_SIZE)

Expand All @@ -57,9 +61,9 @@ def _fp8_mqa_logits_kernel(

start_ind = tl.maximum(start_ind, 0)
end_ind = tl.minimum(end_ind, seq_len_kv)
unmasked_end_ind = (end_ind // BLOCK_KV) * BLOCK_KV
shifted_end = end_ind - start_ind
shifted_unmasked_end = shifted_end // BLOCK_KV * BLOCK_KV

logits_row_ptrs = logits_ptr + row_id * stride_logits_s
kv_col_offsets = tl.arange(0, BLOCK_KV) + start_ind
kv_ptrs = (
KV_ptr + kv_col_offsets[None, :] * stride_kv_s + d_inds[:, None] * stride_kv_d
Expand All @@ -70,12 +74,12 @@ def _fp8_mqa_logits_kernel(
logits_ptrs = logits_row_ptrs + kv_col_offsets * stride_logits_k

# Loop over KV tiles
for _ in tl.range(start_ind, unmasked_end_ind, BLOCK_KV):
for _ in tl.range(0, shifted_unmasked_end, BLOCK_KV):
kv_block = tl.load(kv_ptrs)
kv_scales = tl.load(kv_scales_ptrs)

# [NUM_HEADS, BLOCK_KV] = [NUM_HEADS, HEAD_SIZE] x [HEAD_SIZE, BLOCK_KV]
scores = tl.dot(q_block, kv_block)
scores = tl.dot(q_block, kv_block, input_precision="ieee")
# Multiply by kv_scales (broadcast along rows)
scores = scores * kv_scales[None, :]
# ReLU
Expand All @@ -88,23 +92,22 @@ def _fp8_mqa_logits_kernel(
kv_ptrs += BLOCK_KV * stride_kv_s
kv_scales_ptrs += BLOCK_KV
logits_ptrs += BLOCK_KV * stride_logits_k

if unmasked_end_ind != end_ind:
# masked load
kv_col_offsets = tl.arange(0, BLOCK_KV) + unmasked_end_ind
kv_col_mask = kv_col_offsets < seq_len_kv
kv_block = tl.load(kv_ptrs, mask=kv_col_mask[None, :], other=0.0)
kv_scales = tl.load(kv_scales_ptrs, mask=kv_col_mask, other=0.0)

# [NUM_HEADS, BLOCK_KV] = [NUM_HEADS, HEAD_SIZE] x [HEAD_SIZE, BLOCK_KV]
scores = tl.dot(q_block, kv_block)
# Multiply by kv_scales (broadcast along rows)
scores = scores * kv_scales[None, :]
# ReLU
scores = tl.maximum(scores, 0.0)
scores = scores * w_block
# [NUM_HEADS, BLOCK_KV] -> [BLOCK_KV, ]
scores = tl.sum(scores, axis=0)
# masked store
in_window = (kv_col_offsets >= start_ind) & (kv_col_offsets < end_ind)
tl.store(logits_ptrs, scores, mask=in_window)
kv_col_offsets += BLOCK_KV

# masked load
kv_col_mask = kv_col_offsets < end_ind
kv_block = tl.load(kv_ptrs, mask=kv_col_mask[None, :], other=0.0)
kv_scales = tl.load(kv_scales_ptrs, mask=kv_col_mask, other=0.0)

# [NUM_HEADS, BLOCK_KV] = [NUM_HEADS, HEAD_SIZE] x [HEAD_SIZE, BLOCK_KV]
scores = tl.dot(q_block, kv_block, input_precision="ieee")
# Multiply by kv_scales (broadcast along rows)
scores = scores * kv_scales[None, :]
# ReLU
scores = tl.maximum(scores, 0.0)
scores = scores * w_block
# [NUM_HEADS, BLOCK_KV] -> [BLOCK_KV, ]
scores = tl.sum(scores, axis=0)
# masked store
in_window = (kv_col_offsets >= start_ind) & (kv_col_offsets < end_ind)
tl.store(logits_ptrs, scores, mask=in_window)
8 changes: 7 additions & 1 deletion aiter/ops/triton/fp8_mqa_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def fp8_mqa_logits(
stride_kv_s, stride_kv_d = KV.stride()
stride_w_s, stride_w_h = weights.stride()
stride_logits_s, stride_logits_k = logits.stride()

# heuristic for MFMA instruction shape
matrix_instr_nonkdim = 32
if seq_len <= 1024:
matrix_instr_nonkdim = 16

_fp8_mqa_logits_kernel[(seq_len,)](
Q_ptr=Q,
KV_ptr=KV,
Expand All @@ -69,7 +75,7 @@ def fp8_mqa_logits(
num_warps=4,
num_stages=2,
waves_per_eu=2,
matrix_instr_nonkdim=16,
matrix_instr_nonkdim=matrix_instr_nonkdim,
)

return logits
117 changes: 117 additions & 0 deletions op_tests/op_benchmarks/triton/bench_fp8_mqa_logits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import torch
import triton
import argparse
from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits
from aiter.ops.triton.utils.types import e4m3_dtype
from op_tests.triton_tests.test_fp8_mqa_logits import (
per_custom_dims_cast_to_fp8,
generate_cp_test_data,
)
from op_tests.op_benchmarks.triton.utils.benchmark_utils import (
print_vgpr,
get_caller_name_no_ext,
)


def calculate_tflops(start_inds, end_inds, num_heads_q, head_dim, time_ms):
    """Return the achieved TFLOP/s for one timed FP8 MQA logits launch.

    Every query row ``i`` is charged ``2 * num_heads_q * head_dim`` FLOPs for
    each KV position in its window ``[start_inds[i], end_inds[i])``.

    Args:
        start_inds: Integer tensor with the per-row window start (inclusive).
        end_inds: Integer tensor with the per-row window end (exclusive).
        num_heads_q: Number of query heads.
        head_dim: Size of each head.
        time_ms: Measured runtime in milliseconds.

    Returns:
        Throughput in TFLOP/s as a Python float.
    """
    # Widen to int64 before reducing so the sum over rows cannot overflow the
    # int32 index dtype, and clamp at zero so that an empty window
    # (end <= start) counts as no work instead of subtracting FLOPs.
    window_lens = end_inds.to(torch.int64) - start_inds.to(torch.int64)
    total_kv_positions = window_lens.clamp(min=0).sum().item()
    total_flops = 2.0 * num_heads_q * head_dim * total_kv_positions

    # TFLOPs = total FLOPs / (time in seconds * 1e12)
    time_s = time_ms * 1e-3
    return total_flops / (time_s * 1e12)


def run_benchmark(args):
    """Benchmark ``fp8_mqa_logits`` for the single problem shape in ``args``.

    Args:
        args: Parsed command-line namespace providing ``seq_q_l``,
            ``seq_kv_l``, ``num_heads_q``, ``head_dim``, ``metric`` and ``o``.

    Raises:
        NotImplementedError: If ``args.metric`` is neither ``"time"`` nor
            ``"throughput"``.
    """
    shape = [args.seq_q_l, args.seq_kv_l, args.num_heads_q, args.head_dim]

    # Axis label (also used as the single line name) for each metric.
    metric_labels = {"time": "Time (ms)", "throughput": "TFLOPs"}
    if args.metric not in metric_labels:
        raise NotImplementedError(f"{args.metric} is not supported")
    ylabel = metric_labels[args.metric]

    config = triton.testing.Benchmark(
        x_names=["seq_q_l", "seq_kv_l", "num_heads_q", "head_dim"],
        x_vals=[shape],
        line_arg="unit",
        line_vals=[ylabel],
        line_names=[ylabel],
        styles=[("green", "-")],
        ylabel=ylabel,
        plot_name=get_caller_name_no_ext(),
        args={"metric": args.metric},
    )

    @triton.testing.perf_report([config])
    def bench_fp8_mqa_logits(
        seq_q_l, seq_kv_l, num_heads_q, head_dim, metric, **kwargs
    ):
        # **kwargs absorbs the "unit" line argument, which is not used here.
        q = torch.randn(
            seq_q_l, num_heads_q, head_dim, device="cuda", dtype=torch.bfloat16
        )
        kv = torch.randn(seq_kv_l, head_dim, device="cuda", dtype=torch.bfloat16)
        weights = torch.randn(seq_q_l, num_heads_q, device="cuda", dtype=torch.float32)

        # Row i gets the window [0, i + (seq_kv_l - seq_q_l)).
        kv_offset = seq_kv_l - seq_q_l
        ks = torch.zeros(seq_q_l, dtype=torch.int, device="cuda")
        ke = torch.arange(seq_q_l, dtype=torch.int, device="cuda") + kv_offset

        q_fp8 = q.to(e4m3_dtype)
        kv_fp8, scales = per_custom_dims_cast_to_fp8(kv, (0,), False)

        def launch():
            return fp8_mqa_logits(q_fp8, kv_fp8, scales, weights, ks, ke)

        time_ms = triton.testing.do_bench(launch, warmup=25, rep=100)
        tflops = calculate_tflops(ks, ke, num_heads_q, head_dim, time_ms)

        # Return exactly one scalar depending on which metric is active.
        if metric not in ("time", "throughput"):
            raise ValueError("Unknown metric: " + metric)
        return time_ms if metric == "time" else tflops

    save_path = "." if args.o else None
    bench_fp8_mqa_logits.run(save_path=save_path, print_data=True)


def main():
    """Parse the command-line options and run the FP8 MQA logits benchmark."""
    parser = argparse.ArgumentParser(
        description="FP8 MQA Logits Benchmark",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--num_heads_q", type=int, default=64, help="num. q heads")
    parser.add_argument("--head_dim", type=int, default=128, help="head dim size")
    # seq_q_l sizes the query tensor and seq_kv_l sizes the KV tensor in
    # run_benchmark, so the help text names them as such.
    parser.add_argument(
        "--seq_q_l", type=int, default=4096, help="Query sequence length"
    )
    parser.add_argument(
        "--seq_kv_l", type=int, default=4096, help="KV sequence length"
    )
    parser.add_argument(
        "-o", action="store_true", help="Write performance results to CSV file"
    )
    parser.add_argument(
        "--metric",
        type=str,
        choices=["time", "throughput"],
        default="throughput",
        help="metric to plot",
    )
    args = parser.parse_args()
    run_benchmark(args)


# Only run the benchmark when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Loading