Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
106 commits
Select commit Hold shift + click to select a range
aabd8f0
Modify op_benchmark directory structure to add bench_tests/ and bench…
willzhou-amd Jun 27, 2025
937b27e
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jun 27, 2025
a4b0e50
Update table formatting for bench_gemm_a8w8 and add tests for bench_g…
willzhou-amd Jun 30, 2025
a53a847
Add tensor parallel in bench_gemm_a8w8.py
willzhou-amd Jun 30, 2025
c4d03ee
Add -no_glu arg, fix error in tensor parallelism, and reset folder st…
willzhou-amd Jun 30, 2025
f09d24e
Fix argparse & tensor parallel bug
willzhou-amd Jun 30, 2025
fa770c1
Update bench_gemm_a8w8_blockscale.py and add repeated code to benchma…
willzhou-amd Jun 30, 2025
cc5fa63
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jun 30, 2025
d2c4817
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 2, 2025
0537963
Consolidate bench fn
willzhou-amd Jul 2, 2025
649f953
Consolidate bench fn: int8 blockscale
willzhou-amd Jul 2, 2025
71c0172
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 7, 2025
32eabab
Unify argparse for MHA benchmarking
willzhou-amd Jul 7, 2025
00dc362
Update configs for mha bench
willzhou-amd Jul 7, 2025
8e2266c
Broadcast updates to bench_batched_gemm_afp4wfp4.py
willzhou-amd Jul 7, 2025
d541d1b
Fix issue with arg names in bench_batched_gemm_afp4wfp4
willzhou-amd Jul 7, 2025
59e9c93
Add stride shape upcasting
willzhou-amd Jul 8, 2025
78b70fa
Broadcast changes to batch_gemm_afp4wfp4_pre_quant
willzhou-amd Jul 8, 2025
6539b54
Improve code reuse + fix benchmarking FLOP computation bug
willzhou-amd Jul 8, 2025
6762797
Fix shape order to allow plots to display properly
willzhou-amd Jul 8, 2025
ecd55db
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 8, 2025
30db901
Sweep through moe, extend_attn, prefill, rmsnorm, rope to fix bugs an…
willzhou-amd Jul 8, 2025
d327305
Add --model and --shape support to bench_routing.py
willzhou-amd Jul 8, 2025
577b79c
Add MOE information to deepseek model config
willzhou-amd Jul 8, 2025
a97417c
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 8, 2025
adadc86
Revert linting changes in the CK dir
willzhou-amd Jul 8, 2025
9fafe20
Revert linting changes to ck dir
willzhou-amd Jul 8, 2025
7db6fac
Black linting change
willzhou-amd Jul 8, 2025
dad6093
Fix f-string issue
willzhou-amd Jul 8, 2025
2140a72
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 9, 2025
4926570
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 9, 2025
f440ce9
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 9, 2025
e595628
Add --model support to bench_topk.py & set int64 stride flag in mha
willzhou-amd Jul 9, 2025
c11ec5d
Merge branch 'willz/benchmarking-improvements' of https://github.com/…
willzhou-amd Jul 9, 2025
e33a0de
Undo linting changes to csrc
willzhou-amd Jul 9, 2025
9cc6cb4
Add informative error when trying to benchmark non-MoE models
willzhou-amd Jul 9, 2025
1cee921
Format with Black
willzhou-amd Jul 9, 2025
2ab4989
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 10, 2025
ee6e70f
Support model flag for bench_gemm_a16w16
willzhou-amd Jul 10, 2025
0ac91c3
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 10, 2025
b606c7a
Add --layout flag support to int8 and fp16 GEMMs + set graph axes to …
willzhou-amd Jul 10, 2025
3fcd0a9
Add --layout support to afp4wfp4 GEMM
willzhou-amd Jul 10, 2025
9e92dbe
Fix function naming in bench_gemm_afp4wfp4
willzhou-amd Jul 10, 2025
0a0fa9c
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 10, 2025
5ee387d
Replace missing comma
willzhou-amd Jul 10, 2025
8d9d0c3
Merge branch 'willz/benchmarking-improvements' into willz/benchmarkin…
willzhou-amd Jul 10, 2025
eb2de54
Add --layout support to batched afp4wfp4 pre quant gemm
willzhou-amd Jul 10, 2025
930a674
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 10, 2025
4071f66
Remove linting changes that removed CN comments
willzhou-amd Jul 10, 2025
4eee8ab
Merge branch 'main' into willz/benchmarking-memory-layout
willzhou-amd Jul 11, 2025
97b98ea
Remove merge duplicates
willzhou-amd Jul 11, 2025
e1d9a0c
Undo linting changes that removed CN comments
willzhou-amd Jul 11, 2025
b61a599
Fix bug with -M flag
willzhou-amd Jul 11, 2025
72ef22f
Merge branch 'main' into willz/benchmarking-memory-layout
willzhou-amd Jul 11, 2025
68a430b
Add --layout support to a8w8 blockscale gemm
willzhou-amd Jul 11, 2025
4834300
Add --layout support to batched afp4wfp4 GEMM
willzhou-amd Jul 11, 2025
c35e57e
Formatting changes
willzhou-amd Jul 11, 2025
a6ea3ab
Formatting changes
willzhou-amd Jul 11, 2025
a21a1f5
Debug shape issue that causes segfault when K > M
willzhou-amd Jul 15, 2025
ea4b16c
Black linting change
willzhou-amd Jul 15, 2025
ddab4d2
Fix issue where running batched GEMM benchmarking scripts with no arg…
willzhou-amd Jul 15, 2025
3798b0e
Add batched a8w8 benchmark
willzhou-amd Jul 15, 2025
a1cd571
Add batched bf16 benchmark
willzhou-amd Jul 15, 2025
d4b1ad0
Update a8fp4 tests to add input generating function
willzhou-amd Jul 15, 2025
44bfebb
Update test shapes
willzhou-amd Jul 15, 2025
5effb1a
Add benchmarking script for afp4wfp4 pre quant GEMM
willzhou-amd Jul 15, 2025
9f1b415
Linting changes
willzhou-amd Jul 15, 2025
7c4683a
Linting changes
willzhou-amd Jul 15, 2025
c4868af
Merge branch 'willz/benchmarking-improvements' into willz/additional-…
willzhou-amd Jul 15, 2025
a4aff2e
Stash changes
willzhou-amd Jul 16, 2025
923dbe3
Add -o flag and other fixes for benchmark scripts
Jul 16, 2025
57a4823
Fix moe_routing_sigmoid benchmark
Jul 16, 2025
e43f37f
add Mi350 config json for extend attention
Jul 16, 2025
2ec741d
Linting fixes
Jul 16, 2025
355266f
Merge remote-tracking branch 'origin/main' into willz/benchmarking-me…
Jul 17, 2025
ec5803e
More formatting fixes
Jul 17, 2025
f64ec83
batched_gemm mxfp4 fixes
Jul 17, 2025
ba10d1c
Merge branch 'main' into willz/benchmarking-memory-layout
willzhou-amd Jul 17, 2025
eb0b066
Linting changes
willzhou-amd Jul 17, 2025
51742c3
Merge branch 'willz/benchmarking-memory-layout' into willz/additional…
willzhou-amd Jul 17, 2025
73b1fc0
Update mla decode benchmark
willzhou-amd Jul 17, 2025
296428b
Update argparse for mla decode rope benchmark
willzhou-amd Jul 17, 2025
09d5f78
Stashing changes
willzhou-amd Jul 17, 2025
aba1a1a
Fix kpack bug on MI350x
willzhou-amd Jul 17, 2025
c403e3f
Complete support for --model flag for mla_decode_rope benchmarking
willzhou-amd Jul 18, 2025
935041e
Linting changes
willzhou-amd Jul 18, 2025
32e30a6
Slight tune
willzhou-amd Jul 18, 2025
ca3b340
Merge branch 'main' into willz/additional-gemm-benchmarks
willzhou-amd Jul 18, 2025
a67adb1
Revert unintentional changes from main in merge
willzhou-amd Jul 18, 2025
66c9d4c
Remove MLA decode from PR - will be in next one
willzhou-amd Jul 18, 2025
08fec10
Remove changes made to attention benchmarking scripts for this PR
willzhou-amd Jul 18, 2025
406bdfb
Undo accidental deletions & fix linting errors
willzhou-amd Jul 18, 2025
c5aad51
Merge branch 'main' into willz/additional-gemm-benchmarks
willzhou-amd Jul 18, 2025
a328795
Add a8wfp4 benchmarking script & fix minimal test error
willzhou-amd Jul 18, 2025
28b9143
Add --atomic flag for a16w16 GEMM benchmark
willzhou-amd Jul 21, 2025
536cbfd
Linting fix
willzhou-amd Jul 21, 2025
1f7055c
Update a8w8 benchmark - fix --shape flag for 4 args
willzhou-amd Jul 21, 2025
28ad9f0
Update a16w16 benchmark - misc fixes for GEMM layout flag, -B flag, d…
willzhou-amd Jul 21, 2025
68538fe
Fix bug with -M flag for all new batched GEMM kernels
willzhou-amd Jul 21, 2025
b33b333
Fix errors with afp4wfp4_pre_quant_atomic tests (change accum from bf…
willzhou-amd Jul 22, 2025
3cd4829
Arg changes to batched benchmarking scripts to support model name dis…
willzhou-amd Jul 22, 2025
046c261
Fold common batched model benchmarking config code into utility script
willzhou-amd Jul 22, 2025
b21080b
Add --get_vgpr flag to all GEMM benchmarking scripts
willzhou-amd Jul 22, 2025
d4bd3fe
Fix .json.json config (whoops)
willzhou-amd Jul 22, 2025
dd6d73d
Fix issue with vgpr table output generation parsing where tables with…
willzhou-amd Jul 22, 2025
a85d39a
Fix vgpr bug with --model on benched batched gemm afp4wfp4 pre quant
willzhou-amd Jul 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions op_tests/op_benchmarks/triton/bench_batched_gemm_a16w16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import sys
import torch
import triton
import math
from op_tests.triton_tests.test_batched_gemm_bf16 import (
generate_batched_gemm_a16w16_inputs,
)
from op_tests.op_benchmarks.triton.utils.argparse import (
get_parser,
add_argparse_ff,
get_ff_args,
)
from op_tests.op_benchmarks.triton.utils.benchmark_utils import (
get_model_benchmark_object,
get_shape_benchmark_object,
batched_model_benchmark_shapes,
print_vgpr,
)
from aiter.ops.triton.batched_gemm_bf16 import batched_gemm_bf16


def bench_gemm_fn(batch: int, M: int, N: int, K: int, metric: str, layout: str):
    """Benchmark one batched bf16 GEMM shape and return the requested metric.

    Args:
        batch: Number of independent GEMMs in the batch.
        M, N, K: GEMM dimensions (per batch element).
        metric: One of "time" (ms), "throughput" (TFLOPS), "bandwidth" (GB/s).
        layout: Memory layout string forwarded to the input generator.

    Returns:
        A single scalar for the active metric.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    c_dtype = torch.bfloat16
    x, w, bias, y = generate_batched_gemm_a16w16_inputs(
        batch, M, N, K, dtype=c_dtype, layout=layout, output=True
    )
    # flops: 2*M*N*K multiply-adds per batch element
    flops = 2.0 * M * N * K * batch
    # memory transfer (bytes read)
    mem_read = x.numel() * x.element_size() + w.numel() * w.element_size()
    # Fix: the write side is the full batched output (batch * M * N elements),
    # not a single M x N slice. Using the actual output tensor also makes the
    # size correct for any c_dtype instead of hard-coding 2 bytes (bf16).
    mem_write = y.numel() * y.element_size()
    mem = mem_read + mem_write

    ms = triton.testing.do_bench(
        lambda: batched_gemm_bf16(x, w, bias, c_dtype, YQ=y),
        warmup=25,
        rep=100,
    )

    # Return exactly one scalar depending on which metric is active
    if metric == "time":
        return ms
    elif metric == "throughput":
        tflops = flops / ms * 1e-9
        return tflops
    elif metric == "bandwidth":
        bandwidth = mem / (ms * 1e-3) * 1e-9  # GB/s
        return bandwidth
    else:
        raise ValueError("Unknown metric: " + metric)


def run_model_benchmark(args):
    """Benchmark the batched bf16 GEMM over shapes derived from model configs.

    For each model, benchmarks the fc1 and fc2 feed-forward layers, applying
    GLU doubling and tensor-parallel sharding as requested via CLI flags.
    """
    benchmark = get_model_benchmark_object(
        # Fix: this file benchmarks the bf16 (a16w16) kernel, not MXFP4 —
        # the old label was copy-pasted from the MXFP4 script.
        plot_name="Batched GEMM bf16 Benchmark",
        args=args,
        x_names=["M", "hidden_dim", "intermediate_dim", "batch", "model_name"],
        model_benchmark_shapes_fn=batched_model_benchmark_shapes,
    )

    @triton.testing.perf_report([benchmark])
    def bench_batched_gemm_a16w16(
        M, hidden_dim, intermediate_dim, batch, metric, layer, **kwargs
    ):
        if layer == "fc1":
            if args.no_glu:
                N, K = intermediate_dim, hidden_dim
            else:
                # GLU fuses two projections into one GEMM, doubling N.
                N, K = intermediate_dim * 2, hidden_dim
            # fc1 is column-parallel: shard N across tensor-parallel ranks.
            N = math.ceil(N / args.tp)
        elif layer == "fc2":
            N, K = hidden_dim, intermediate_dim
            # fc2 is row-parallel: shard K across tensor-parallel ranks.
            K = math.ceil(K / args.tp)

        return bench_gemm_fn(batch, M, N, K, metric, args.layout)

    bench_batched_gemm_a16w16.run(save_path="." if args.o else None, print_data=True)


def run_shape_benchmark(args):
    """Benchmark the batched bf16 GEMM over user-supplied (batch, M, N, K) shapes."""
    benchmark = get_shape_benchmark_object(
        # Fix: this file benchmarks the bf16 (a16w16) kernel, not MXFP4 —
        # the old label was copy-pasted from the MXFP4 script.
        plot_name="Batched GEMM bf16 Benchmark",
        args=args,
        x_names=["batch", "M", "N", "K"],
    )

    @triton.testing.perf_report([benchmark])
    def bench_batched_gemm_a16w16(batch, M, N, K, metric, provider):
        return bench_gemm_fn(batch, M, N, K, metric, args.layout)

    bench_batched_gemm_a16w16.run(save_path="." if args.o else None, print_data=True)


def run_benchmark(args, defaults):
    """Dispatch to model- or shape-based benchmarking after validating flags.

    Raises:
        AssertionError: If mutually-exclusive shape/model flags are combined.
        Exception: If a flag is set that the selected mode does not support.
    """
    # Fix: the original used `or` between the two exclusivity conditions, so
    # `--shape` together with `--model` passed validation whenever -M was
    # unset (and vice versa). Both conditions must hold simultaneously.
    assert not (args.shape and args.model) and not (
        args.shape and args.M
    ), "User can specify --shape or --model MODEL -M VAL exclusively"

    if args.model:
        # No flags are currently model-mode-only; kept for symmetry/extension.
        unsupported_args = []
        for arg in unsupported_args:
            if getattr(args, arg, None) != getattr(defaults, arg, None):
                raise Exception(
                    f"Argument '{arg}' is not supported for benchmarking with the --model flag."
                )
        run_model_benchmark(args)
    else:
        # These flags only make sense when shapes come from a model config.
        unsupported_args = [
            "fc1",
            "fc2",
            "no_glu",
            "tp",
        ]
        for arg in unsupported_args:
            if getattr(args, arg, None) != getattr(defaults, arg, None):
                raise Exception(
                    f"Argument '{arg}' is not supported for benchmarking without the --model flag."
                )
        run_shape_benchmark(args)


def parse_args():
    """Build the CLI parser for the batched bf16 GEMM benchmark.

    Returns:
        Tuple of (parsed args, parser defaults) — defaults are used later to
        detect which flags the user explicitly set.
    """
    # Fix: description said "Batched Int8 x Int8 GEMM" — copy-pasted from the
    # a8w8 script; this one benchmarks the bf16 (a16w16) kernel.
    parser = get_parser("Batched bf16 GEMM")
    parser = add_argparse_ff(parser)
    parser.add_argument(
        "-B",
        type=int,
        required=False,
        help="Batch size to be used when using --model flag.",
    )
    return get_ff_args(parser)


def main():
    """Entry point: parse args, then run either the benchmark or a VGPR probe."""
    args, defaults = parse_args()
    if not args.print_vgpr:
        run_benchmark(args, defaults)
        return None
    # VGPR mode: wrap the benchmark so print_vgpr can capture kernel stats.
    print("Retrieving VGPR usage for Triton kernels...")
    print_vgpr(lambda: run_benchmark(args, defaults), "Batched GEMM")
    return 0


if __name__ == "__main__":
sys.exit(main())
155 changes: 155 additions & 0 deletions op_tests/op_benchmarks/triton/bench_batched_gemm_a8w8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import sys
import torch
import triton
import math
from op_tests.triton_tests.test_batched_gemm_a8w8 import (
generate_batched_gemm_a8w8_inputs,
)
from op_tests.op_benchmarks.triton.utils.argparse import (
get_parser,
add_argparse_ff,
get_ff_args,
)
from op_tests.op_benchmarks.triton.utils.benchmark_utils import (
get_model_benchmark_object,
get_shape_benchmark_object,
batched_model_benchmark_shapes,
print_vgpr,
)
from aiter.ops.triton.batched_gemm_a8w8 import (
batched_gemm_a8w8 as batched_gemm_a8w8,
)


def bench_gemm_fn(batch: int, M: int, N: int, K: int, metric: str, layout: str):
    """Benchmark one batched int8 (a8w8) GEMM shape and return the requested metric.

    Args:
        batch: Number of independent GEMMs in the batch.
        M, N, K: GEMM dimensions (per batch element).
        metric: One of "time" (ms), "throughput" (TFLOPS), "bandwidth" (GB/s).
        layout: Memory layout string forwarded to the input generator.

    Returns:
        A single scalar for the active metric.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    c_dtype = torch.bfloat16
    x, w, x_scale, w_scale, bias, y = generate_batched_gemm_a8w8_inputs(
        batch, M, N, K, dtype=c_dtype, layout=layout, output=True
    )
    # flops: 2*M*N*K multiply-adds per batch element
    flops = 2.0 * M * N * K * batch
    # memory transfer (bytes read): operands plus dequantization scales
    mem_read = x.numel() * x.element_size() + w.numel() * w.element_size()
    mem_read += (
        x_scale.numel() * x_scale.element_size()
        + w_scale.numel() * w_scale.element_size()
    )
    # Fix: the write side is the full batched output (batch * M * N elements),
    # not a single M x N slice. Using the actual output tensor also makes the
    # size correct for any c_dtype instead of hard-coding 2 bytes (bf16).
    mem_write = y.numel() * y.element_size()
    mem = mem_read + mem_write

    ms = triton.testing.do_bench(
        lambda: batched_gemm_a8w8(x, w, x_scale, w_scale, bias, c_dtype, YQ=y),
        warmup=25,
        rep=100,
    )

    # Return exactly one scalar depending on which metric is active
    if metric == "time":
        return ms
    elif metric == "throughput":
        tflops = flops / ms * 1e-9
        return tflops
    elif metric == "bandwidth":
        bandwidth = mem / (ms * 1e-3) * 1e-9  # GB/s
        return bandwidth
    else:
        raise ValueError("Unknown metric: " + metric)


def run_model_benchmark(args):
    """Benchmark the batched a8w8 GEMM over shapes derived from model configs.

    For each model, benchmarks the fc1 and fc2 feed-forward layers, applying
    GLU doubling and tensor-parallel sharding as requested via CLI flags.
    """
    benchmark = get_model_benchmark_object(
        # Fix: this file benchmarks the int8 (a8w8) kernel, not MXFP4 —
        # the old label was copy-pasted from the MXFP4 script.
        plot_name="Batched GEMM Int8 x Int8 Benchmark",
        args=args,
        x_names=["M", "hidden_dim", "intermediate_dim", "batch", "model_name"],
        model_benchmark_shapes_fn=batched_model_benchmark_shapes,
    )

    @triton.testing.perf_report([benchmark])
    def bench_batched_gemm_a8w8(
        M, hidden_dim, intermediate_dim, batch, metric, layer, **kwargs
    ):
        if layer == "fc1":
            if args.no_glu:
                N, K = intermediate_dim, hidden_dim
            else:
                # GLU fuses two projections into one GEMM, doubling N.
                N, K = intermediate_dim * 2, hidden_dim
            # fc1 is column-parallel: shard N across tensor-parallel ranks.
            N = math.ceil(N / args.tp)
        elif layer == "fc2":
            N, K = hidden_dim, intermediate_dim
            # fc2 is row-parallel: shard K across tensor-parallel ranks.
            K = math.ceil(K / args.tp)

        return bench_gemm_fn(batch, M, N, K, metric, args.layout)

    bench_batched_gemm_a8w8.run(save_path="." if args.o else None, print_data=True)


def run_shape_benchmark(args):
    """Benchmark the batched a8w8 GEMM over user-supplied (batch, M, N, K) shapes."""
    benchmark = get_shape_benchmark_object(
        # Fix: this file benchmarks the int8 (a8w8) kernel, not MXFP4 —
        # the old label was copy-pasted from the MXFP4 script.
        plot_name="Batched GEMM Int8 x Int8 Benchmark",
        args=args,
        x_names=["batch", "M", "N", "K"],
    )

    @triton.testing.perf_report([benchmark])
    def bench_batched_gemm_a8w8(batch, M, N, K, metric, provider):
        return bench_gemm_fn(batch, M, N, K, metric, args.layout)

    bench_batched_gemm_a8w8.run(save_path="." if args.o else None, print_data=True)


def run_benchmark(args, defaults):
    """Dispatch to model- or shape-based benchmarking after validating flags.

    Raises:
        AssertionError: If mutually-exclusive shape/model flags are combined.
        Exception: If a flag is set that the selected mode does not support.
    """
    # Fix: the original used `or` between the two exclusivity conditions, so
    # `--shape` together with `--model` passed validation whenever -M was
    # unset (and vice versa). Both conditions must hold simultaneously.
    assert not (args.shape and args.model) and not (
        args.shape and args.M
    ), "User can specify --shape or --model MODEL -M VAL exclusively"

    if args.model:
        # No flags are currently model-mode-only; kept for symmetry/extension.
        unsupported_args = []
        for arg in unsupported_args:
            if getattr(args, arg, None) != getattr(defaults, arg, None):
                raise Exception(
                    f"Argument '{arg}' is not supported for benchmarking with the --model flag."
                )
        run_model_benchmark(args)
    else:
        # These flags only make sense when shapes come from a model config.
        unsupported_args = [
            "fc1",
            "fc2",
            "no_glu",
            "tp",
        ]
        for arg in unsupported_args:
            if getattr(args, arg, None) != getattr(defaults, arg, None):
                raise Exception(
                    f"Argument '{arg}' is not supported for benchmarking without the --model flag."
                )
        run_shape_benchmark(args)


def parse_args():
    """Build the CLI parser for the batched int8 GEMM benchmark.

    Returns:
        Tuple of (parsed args, parser defaults) — defaults are used later to
        detect which flags the user explicitly set.
    """
    parser = add_argparse_ff(get_parser("Batched Int8 x Int8 GEMM"))
    parser.add_argument(
        "-B",
        type=int,
        required=False,
        help="Batch size to be used when using --model flag.",
    )
    return get_ff_args(parser)


def main():
    """Entry point: parse args, then run either the benchmark or a VGPR probe."""
    args, defaults = parse_args()
    if not args.print_vgpr:
        run_benchmark(args, defaults)
        return None
    # VGPR mode: wrap the benchmark so print_vgpr can capture kernel stats.
    print("Retrieving VGPR usage for Triton kernels...")
    print_vgpr(lambda: run_benchmark(args, defaults), "Batched GEMM")
    return 0


if __name__ == "__main__":
sys.exit(main())
38 changes: 17 additions & 21 deletions op_tests/op_benchmarks/triton/bench_batched_gemm_afp4wfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,15 @@
from op_tests.op_benchmarks.triton.utils.benchmark_utils import (
get_model_benchmark_object,
get_shape_benchmark_object,
get_model_configs,
batched_model_benchmark_shapes,
print_vgpr,
)
from aiter.ops.triton.batched_gemm_afp4wfp4 import (
batched_gemm_afp4wfp4 as batched_gemm_afp4wfp4,
)
import aiter.ops.triton.utils.arch_info as arch_info


def model_benchmark_shapes(args):
config_file = args.model_configs
configs = get_model_configs(config_path=config_file, models=args.model)
M_list = [4096] if args.model == "all" else [2**i for i in range(0, 15)]
shapes = []
for M in M_list:
for model_name, config in configs.items():
N = config["intermediate_size"]
K = config["hidden_size"]

shapes.append(
(model_name, M, N, K, 16)
) # rearrange batch to last dim so M is graph x-axis

return shapes


def bench_gemm_fn(
batch: int, M: int, N: int, K: int, metric: str, layout: str, model_name=None
):
Expand Down Expand Up @@ -87,7 +71,7 @@ def run_model_benchmark(args):
plot_name="Batched GEMM MXFP4 x MXFP4 Benchmark",
args=args,
x_names=["model_name", "M", "hidden_dim", "intermediate_dim", "batch"],
model_benchmark_shapes_fn=model_benchmark_shapes,
model_benchmark_shapes_fn=batched_model_benchmark_shapes,
)

@triton.testing.perf_report([benchmark])
Expand Down Expand Up @@ -116,11 +100,11 @@ def run_shape_benchmark(args):
benchmark = get_shape_benchmark_object(
plot_name="Batched GEMM MXFP4 x MXFP4 Benchmark",
args=args,
x_names=["M", "N", "K", "batch"],
x_names=["batch", "M", "N", "K"],
)

@triton.testing.perf_report([benchmark])
def bench_batched_gemm_afp4wfp4(M, N, K, batch, metric, provider, model_name=None):
def bench_batched_gemm_afp4wfp4(batch, M, N, K, metric, **kwargs):
return bench_gemm_fn(batch, M, N, K, metric, layout=args.layout)

bench_batched_gemm_afp4wfp4.run(save_path="." if args.o else None, print_data=True)
Expand All @@ -144,6 +128,7 @@ def run_benchmark(args, defaults):
"fc1",
"fc2",
"no_glu",
"tp",
]
for arg in unsupported_args:
if getattr(args, arg, None) != getattr(defaults, arg, None):
Expand All @@ -156,6 +141,12 @@ def run_benchmark(args, defaults):
def parse_args():
parser = get_parser("MXFP4 x MXFP4 GEMM")
parser = add_argparse_ff(parser)
parser.add_argument(
"-B",
type=int,
required=False,
help="Batch size to be used when using --model flag.",
)
return get_ff_args(parser)


Expand All @@ -165,6 +156,11 @@ def main():
sys.exit()

args, defaults = parse_args()
if args.print_vgpr:
print("Retrieving VGPR usage for Triton kernels...")
fun = lambda: run_benchmark(args, defaults) # noqa: E731
print_vgpr(fun, "Batched GEMM")
return 0
run_benchmark(args, defaults)


Expand Down
Loading