Merged
70 commits
aabd8f0
Modify op_benchmark directory structure to add bench_tests/ and bench…
willzhou-amd Jun 27, 2025
937b27e
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jun 27, 2025
a4b0e50
Update table formatting for bench_gemm_a8w8 and add tests for bench_g…
willzhou-amd Jun 30, 2025
a53a847
Add tensor parallel in bench_gemm_a8w8.py
willzhou-amd Jun 30, 2025
c4d03ee
Add -no_glu arg, fix error in tensor parallelism, and reset folder st…
willzhou-amd Jun 30, 2025
f09d24e
Fix argparse & tensor parallel bug
willzhou-amd Jun 30, 2025
fa770c1
Update bench_gemm_a8w8_blockscale.py and add repeated code to benchma…
willzhou-amd Jun 30, 2025
cc5fa63
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jun 30, 2025
d2c4817
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 2, 2025
0537963
Consolidate bench fn
willzhou-amd Jul 2, 2025
649f953
Consolidate bench fn: int8 blockscale
willzhou-amd Jul 2, 2025
71c0172
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 7, 2025
32eabab
Unify argparse for MHA benchmarking
willzhou-amd Jul 7, 2025
00dc362
Update configs for mha bench
willzhou-amd Jul 7, 2025
8e2266c
Broadcast updates to bench_batched_gemm_afp4wfp4.py
willzhou-amd Jul 7, 2025
d541d1b
Fix issue with arg names in bench_batched_gemm_afp4wfp4
willzhou-amd Jul 7, 2025
59e9c93
Add stride shape upcasting
willzhou-amd Jul 8, 2025
78b70fa
Broadcast changes to batch_gemm_afp4wfp4_pre_quant
willzhou-amd Jul 8, 2025
6539b54
Improve code reuse + fix benchmarking FLOP computation bug
willzhou-amd Jul 8, 2025
6762797
Fix shape order to allow plots to display properly
willzhou-amd Jul 8, 2025
ecd55db
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 8, 2025
30db901
Sweep through moe, extend_attn, prefill, rmsnorm, rope to fix bugs an…
willzhou-amd Jul 8, 2025
d327305
Add --model and --shape support to bench_routing.py
willzhou-amd Jul 8, 2025
577b79c
Add MOE information to deepseek model config
willzhou-amd Jul 8, 2025
a97417c
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 8, 2025
adadc86
Revert linting changes in the CK dir
willzhou-amd Jul 8, 2025
9fafe20
Revert linting changes to ck dir
willzhou-amd Jul 8, 2025
7db6fac
Black linting change
willzhou-amd Jul 8, 2025
dad6093
Fix f-string issue
willzhou-amd Jul 8, 2025
2140a72
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 9, 2025
4926570
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 9, 2025
f440ce9
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 9, 2025
e595628
Add --model support to bench_topk.py & set int64 stride flag in mha
willzhou-amd Jul 9, 2025
c11ec5d
Merge branch 'willz/benchmarking-improvements' of https://github.com/…
willzhou-amd Jul 9, 2025
e33a0de
Undo linting changes to csrc
willzhou-amd Jul 9, 2025
9cc6cb4
Add informative error when trying to benchmark non-MoE models
willzhou-amd Jul 9, 2025
1cee921
Format with Black
willzhou-amd Jul 9, 2025
2ab4989
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 10, 2025
ee6e70f
Support model flag for bench_gemm_a16w16
willzhou-amd Jul 10, 2025
0ac91c3
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 10, 2025
b606c7a
Add --layout flag support to int8 and fp16 GEMMs + set graph axes to …
willzhou-amd Jul 10, 2025
3fcd0a9
Add --layout support to afp4wfp4 GEMM
willzhou-amd Jul 10, 2025
9e92dbe
Fix function naming in bench_gemm_afp4wfp4
willzhou-amd Jul 10, 2025
0a0fa9c
Merge branch 'main' into willz/benchmarking-improvements
willzhou-amd Jul 10, 2025
5ee387d
Replace missing comma
willzhou-amd Jul 10, 2025
8d9d0c3
Merge branch 'willz/benchmarking-improvements' into willz/benchmarkin…
willzhou-amd Jul 10, 2025
eb2de54
Add --layout support to batched afp4wfp4 pre quant gemm
willzhou-amd Jul 10, 2025
4eee8ab
Merge branch 'main' into willz/benchmarking-memory-layout
willzhou-amd Jul 11, 2025
97b98ea
Remove merge duplicates
willzhou-amd Jul 11, 2025
e1d9a0c
Undo linting changes that removed CN comments
willzhou-amd Jul 11, 2025
b61a599
Fix bug with -M flag
willzhou-amd Jul 11, 2025
72ef22f
Merge branch 'main' into willz/benchmarking-memory-layout
willzhou-amd Jul 11, 2025
68a430b
Add --layout support to a8w8 blockscale gemm
willzhou-amd Jul 11, 2025
4834300
Add --layout support to batched afp4wfp4 GEMM
willzhou-amd Jul 11, 2025
c35e57e
Formatting changes
willzhou-amd Jul 11, 2025
a6ea3ab
Formatting changes
willzhou-amd Jul 11, 2025
a21a1f5
Debug shape issue that causes segfault when K > M
willzhou-amd Jul 15, 2025
ea4b16c
Black linting change
willzhou-amd Jul 15, 2025
ddab4d2
Fix issue where running batched GEMM benchmarking scripts with no arg…
willzhou-amd Jul 15, 2025
7c4683a
Linting changes
willzhou-amd Jul 15, 2025
923dbe3
Add -o flag and other fixes for benchmark scripts
Jul 16, 2025
57a4823
Fix moe_routing_sigmoid benchmark
Jul 16, 2025
e43f37f
add Mi350 config json for extend attention
Jul 16, 2025
2ec741d
Linting fixes
Jul 16, 2025
355266f
Merge remote-tracking branch 'origin/main' into willz/benchmarking-me…
Jul 17, 2025
ec5803e
More formatting fixes
Jul 17, 2025
f64ec83
batched_gemm mxfp4 fixes
Jul 17, 2025
ba10d1c
Merge branch 'main' into willz/benchmarking-memory-layout
willzhou-amd Jul 17, 2025
eb0b066
Linting changes
willzhou-amd Jul 17, 2025
54f9478
Fix batched_gemm_afp4wfp4_pre_quant benchmark
Jul 17, 2025
10 changes: 10 additions & 0 deletions aiter/ops/triton/configs/MI350X-EXTEND_ATTENTION.json
@@ -0,0 +1,10 @@
{
"default": {
"BLOCK_M": 64,
"BLOCK_N": 64,
"waves_per_eu": 1,
"num_warps": 4,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
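
The per-architecture JSON above replaces hard-coded kernel parameters for extend attention. As a rough illustration, a config like this would typically be read once and its "default" entry handed to the kernel launch. The helper below is a minimal sketch only: it assumes AITER_TRITON_CONFIGS_PATH (imported later in this diff) points at the configs/ directory shown in the file path, and its function name is hypothetical.

```python
# Minimal sketch: load the MI350X extend-attention config and return the
# "default" kernel parameters. The filename pattern and the location of
# AITER_TRITON_CONFIGS_PATH are assumptions based on this diff.
import json
import os

from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH


def load_extend_attention_config(gpu_arch: str = "MI350X") -> dict:
    path = os.path.join(AITER_TRITON_CONFIGS_PATH, f"{gpu_arch}-EXTEND_ATTENTION.json")
    with open(path) as f:
        configs = json.load(f)
    # e.g. {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, ...}
    return configs["default"]
```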
70 changes: 70 additions & 0 deletions aiter/ops/triton/configs/moe/MI350X-MOE_ROUTING_SIGMOID_TOPK1.json
@@ -0,0 +1,70 @@
{
"N16": {
"small" :{
"BLOCK_M": 16,
"BLOCK_K": 256,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 3,
"kpack": 1
},
"medium" :{
"BLOCK_M": 16,
"BLOCK_K": 256,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 3,
"kpack": 1
},
"large" :{
"BLOCK_M": 16,
"BLOCK_K": 256,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 3,
"kpack": 2
},
"xlarge" :{
"BLOCK_M": 32,
"BLOCK_K": 128,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"kpack": 2
}
},
"N128": {
"small" :{
"BLOCK_M": 16,
"BLOCK_K": 256,
"num_warps": 8,
"num_stages": 1,
"waves_per_eu": 0,
"kpack": 1
},
"medium" :{
"BLOCK_M": 16,
"BLOCK_K": 256,
"num_warps": 8,
"num_stages": 1,
"waves_per_eu": 0,
"kpack": 2
},
"large" :{
"BLOCK_M": 16,
"BLOCK_K": 256,
"num_warps": 8,
"num_stages": 1,
"waves_per_eu": 2,
"kpack": 2
},
"xlarge" :{
"BLOCK_M": 32,
"BLOCK_K": 128,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"kpack": 2
}
}
}
61 changes: 0 additions & 61 deletions aiter/ops/triton/moe_routing_sigmoid_top1_fused.py
@@ -11,67 +11,6 @@
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH


def get_config_heuristic(M, K, N):
"""
Return the best Triton configuration based on input dimensions.

Args:
M: Batch dimension
K: Hidden dimension
N: Number of experts (16 or 128)
TOPK: Top-k value (default: 1)

Returns:
triton.Config: Configuration for the Triton kernel
"""
# Determine M bucket (small: <2048, medium: 2048-4095, large: 4096-8191, very_large: 8192+)
m_bucket = (
"very_large"
if M >= 8192
else "large" if M >= 4096 else "medium" if M >= 2048 else "small"
)

# Create parameter configuration using nested dictionaries
configs = {
# Format: {N: {m_bucket: (BLOCK_M, BLOCK_K, num_warps, num_stages, waves_per_eu, kpack)}}
16: {
"small": (16, 256, 4, 2, 3, 1),
"medium": (16, 256, 4, 2, 3, 1),
"large": (16, 256, 4, 2, 3, 2),
"very_large": (32, 256, 4, 2, 0, 1),
},
128: {
"small": (16, 256, 8, 1, 0, 1),
"medium": (16, 256, 8, 1, 0, 2),
"large": (16, 256, 8, 1, 2, 2),
"very_large": (32, 128, 8, 2, 2, 2),
},
256: {
"small": (16, 64, 8, 1, 0, 1),
"medium": (16, 64, 8, 1, 0, 2),
"large": (16, 64, 8, 1, 2, 2),
"very_large": (16, 64, 8, 2, 2, 2),
},
}

# Get configuration parameters
BLOCK_M, BLOCK_K, num_warps, num_stages, waves_per_eu, kpack = configs[N][m_bucket]

# Return Triton configuration
return triton.Config(
{
"BLOCK_M": BLOCK_M,
"BLOCK_K": BLOCK_K,
"matrix_instr_nonkdim": 16, # Always 16
"waves_per_eu": waves_per_eu,
"kpack": kpack,
},
num_warps=num_warps,
num_stages=num_stages,
num_ctas=1,
)


@triton.jit
def _routing_sigmoid_top1_kernel(
X_ptr,
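
The heuristic deleted above is superseded by the MI350X-MOE_ROUTING_SIGMOID_TOPK1.json added earlier in this PR. Below is a minimal sketch of an equivalent JSON-backed lookup, assuming the M-bucket thresholds carry over from the removed function (small < 2048, medium < 4096, large < 8192, otherwise xlarge) and that the expert count maps to the "N16"/"N128" keys; the helper name and file-loading details are hypothetical, not the kernel's actual selection code.

```python
# Hypothetical JSON-backed replacement for the removed get_config_heuristic().
# Bucket thresholds mirror the deleted Python tables; everything else is an
# assumption about how the new config file might be consumed.
import json
import os

import triton

from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH


def get_config_from_json(M: int, N: int, gpu_arch: str = "MI350X") -> triton.Config:
    path = os.path.join(
        AITER_TRITON_CONFIGS_PATH, "moe", f"{gpu_arch}-MOE_ROUTING_SIGMOID_TOPK1.json"
    )
    with open(path) as f:
        table = json.load(f)

    m_bucket = (
        "xlarge"
        if M >= 8192
        else "large" if M >= 4096 else "medium" if M >= 2048 else "small"
    )
    cfg = table[f"N{N}"][m_bucket]

    return triton.Config(
        {
            "BLOCK_M": cfg["BLOCK_M"],
            "BLOCK_K": cfg["BLOCK_K"],
            "matrix_instr_nonkdim": 16,  # fixed at 16, as in the removed heuristic
            "waves_per_eu": cfg["waves_per_eu"],
            "kpack": cfg["kpack"],
        },
        num_warps=cfg["num_warps"],
        num_stages=cfg["num_stages"],
        num_ctas=1,
    )
```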
48 changes: 29 additions & 19 deletions op_tests/op_benchmarks/triton/bench_batched_gemm_afp4wfp4.py
@@ -18,28 +18,39 @@
from aiter.ops.triton.batched_gemm_afp4wfp4 import (
batched_gemm_afp4wfp4 as batched_gemm_afp4wfp4,
)
import aiter.ops.triton.utils.arch_info as arch_info


def model_benchmark_shapes(args):
config_file = args.model_configs
configs = get_model_configs(config_path=config_file, models=args.model)
M_list = [args.M] if args.model == "all" else [2**i for i in range(0, 15)]
M_list = [4096] if args.model == "all" else [2**i for i in range(0, 15)]
shapes = []
for M in M_list:
for _, config in configs.items():
for model_name, config in configs.items():
N = config["intermediate_size"]
K = config["hidden_size"]

shapes.append(
(M, N, K, 16)
(model_name, M, N, K, 16)
) # rearrange batch to last dim so M is graph x-axis

return shapes


def bench_gemm_fn(batch, M, N, K, metric):
def bench_gemm_fn(
batch: int, M: int, N: int, K: int, metric: str, layout: str, model_name=None
):
c_dtype = torch.bfloat16
x, w, x_scale, w_scale = generate_batched_gemm_afp4wfp4_inputs(batch, M, N, K)
x, w, x_scale, w_scale, y = generate_batched_gemm_afp4wfp4_inputs(
batch,
M,
N,
K,
c_dtype,
layout=layout,
output=True,
)
# print(f"M: {M}, N: {N}, K: {K}, x.shape: {x.shape}, x.stride(): {x.stride()}, w.shape: {w.shape}, w.stride(): {w.stride()}")
# flops
flops = 2.0 * M * N * K * batch
@@ -51,12 +62,9 @@ def bench_gemm_fn(batch, M, N, K, metric):
)
mem_write = (M * N) * 2 # TODO: Fix for c_dtype != bf16
mem = mem_read + mem_write
out = torch.empty(
x.shape[0], x.shape[1], w.shape[2], device=x.device, dtype=c_dtype
)

ms = triton.testing.do_bench(
lambda: batched_gemm_afp4wfp4(x, w, x_scale, w_scale, c_dtype, out),
lambda: batched_gemm_afp4wfp4(x, w, x_scale, w_scale, c_dtype, y),
warmup=25,
rep=100,
)
@@ -78,13 +86,13 @@ def run_model_benchmark(args):
benchmark = get_model_benchmark_object(
plot_name="Batched GEMM MXFP4 x MXFP4 Benchmark",
args=args,
x_names=["M", "hidden_dim", "intermediate_dim", "batch"],
x_names=["model_name", "M", "hidden_dim", "intermediate_dim", "batch"],
model_benchmark_shapes_fn=model_benchmark_shapes,
)

@triton.testing.perf_report([benchmark])
def bench_batched_gemm_afp4wfp4(
M, hidden_dim, intermediate_dim, batch, metric, layer, **kwargs
M, hidden_dim, intermediate_dim, batch, metric, layer, model_name=None, **kwargs
):
if layer == "fc1":
if args.no_glu:
@@ -99,9 +107,9 @@ def bench_batched_gemm_afp4wfp4(
K = math.ceil(K / args.tp)
# print(f"Layer: {layer}, B: {batch}, M: {M}, N: {N}, K: {K}, hidden_dim: {hidden_dim}, intermediate_dim: {intermediate_dim}")

return bench_gemm_fn(batch, M, N, K, metric)
return bench_gemm_fn(batch, M, N, K, metric, layout=args.layout)

bench_batched_gemm_afp4wfp4.run(save_path=".", print_data=True)
bench_batched_gemm_afp4wfp4.run(save_path="." if args.o else None, print_data=True)


def run_shape_benchmark(args):
@@ -112,10 +120,10 @@ def run_shape_benchmark(args):
)

@triton.testing.perf_report([benchmark])
def bench_batched_gemm_afp4wfp4(M, N, K, batch, metric, provider):
return bench_gemm_fn(batch, M, N, K, metric)
def bench_batched_gemm_afp4wfp4(M, N, K, batch, metric, provider, model_name=None):
return bench_gemm_fn(batch, M, N, K, metric, layout=args.layout)

bench_batched_gemm_afp4wfp4.run(save_path=".", print_data=True)
bench_batched_gemm_afp4wfp4.run(save_path="." if args.o else None, print_data=True)


def run_benchmark(args, defaults):
@@ -124,9 +132,7 @@ def run_benchmark(args, defaults):
), "User can specify --shape or --model MODEL -M VAL exclusively"

if args.model:
unsupported_args = [
"layout",
]
unsupported_args = []
for arg in unsupported_args:
if getattr(args, arg, None) != getattr(defaults, arg, None):
raise Exception(
@@ -154,6 +160,10 @@ def parse_args():


def main():
if not (arch_info.is_fp4_avail()):
print("MXFP4 is not available on this architecture")
sys.exit()

args, defaults = parse_args()
run_benchmark(args, defaults)

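
For context on the metric plumbing in bench_gemm_fn: the script computes flops = 2 * M * N * K * batch and a byte count covering the fp4 operands, their scales, and the bf16 output, then times the kernel with triton.testing.do_bench. The return statement is collapsed in the diff, so the conversion below is only a sketch of the usual arithmetic; the metric names and variable names are assumptions, not the script's exact code.

```python
# Sketch of converting a measured kernel time (ms) into the reported metric.
# The elided return in bench_gemm_fn presumably does something equivalent.
def to_metric(ms: float, flops: float, mem_bytes: float, metric: str) -> float:
    if metric == "time":
        return ms  # milliseconds per call
    if metric == "throughput":
        # TFLOPS: flops / (ms * 1e-3 s) / 1e12
        return flops / ms * 1e-9
    if metric == "bandwidth":
        # GB/s: bytes / (ms * 1e-3 s) / 1e9
        return mem_bytes / ms * 1e-6
    raise ValueError(f"Unknown metric: {metric}")
```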