Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
f365115
Add optional fused activation + gating args to a16w16 GEMM
willzhou-amd Jul 25, 2025
fd7625f
Add tests for fused a16w16 GEMM
willzhou-amd Jul 25, 2025
910170d
Formatting changes
willzhou-amd Jul 25, 2025
abef383
Add gating & activation tests for fused a16w16 GEMM
willzhou-amd Jul 25, 2025
c30e5dd
Factor out fused GEMM kernel into separate file
willzhou-amd Jul 28, 2025
30b3a9c
Revert testing tolerances to not affect CI
willzhou-amd Jul 28, 2025
4c2cd6e
Merge branch 'main' into willz/fused-ff-gemms
willzhou-amd Jul 28, 2025
927f804
Update a16w16 gated GEMM tests
willzhou-amd Jul 28, 2025
9ced0f4
Factor out FF block function into separate file + write tests + bench…
willzhou-amd Jul 29, 2025
c11dfba
Add tests for full FF interface
willzhou-amd Jul 29, 2025
f827c20
Formatting changes
willzhou-amd Jul 29, 2025
c40c8c0
Update fused-act-gate a16w16 benchmark
willzhou-amd Jul 29, 2025
22c5ba1
Fix shape error for FF interface when y is not provided
willzhou-amd Jul 29, 2025
082eddb
Add tests for ungated FF function
willzhou-amd Jul 29, 2025
daa3f01
Tune a16w16 performance for standard & gated GEMMs
willzhou-amd Jul 29, 2025
03b9ac7
Remove FF interface (will refactor for next PR)
willzhou-amd Jul 30, 2025
934bcb1
Add mi300x tuning configs for A16W16 & Gated A16W16 GEMMs
willzhou-amd Jul 30, 2025
1b8e84b
Fix error with DS config under new config format
willzhou-amd Aug 1, 2025
6420a5b
Restore config to previously tuned variables
willzhou-amd Aug 1, 2025
ad5d0ec
Add logging to a16w16 gated GEMM
willzhou-amd Aug 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions aiter/ops/triton/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,19 @@ def _gelu_tanh(x):
return 0.5 * x * (1.0 + _tanh(inner))


@triton.jit
def _relu(x):
return tl.maximum(0.0, x)


@tl.constexpr_function
def _get_activation_from_str(activation: str):
mapping = {
"gelu": _gelu,
"gelu_tanh": _gelu_tanh,
"silu": _silu,
"silu_exp2": _silu_exp2,
"relu": _relu,
}
return mapping[activation]

Expand Down
74 changes: 74 additions & 0 deletions aiter/ops/triton/configs/gemm/MI300X-GEMM-A16W16-gated.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
{
"M_LEQ_64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"kpack": 1
},
"M_LEQ_2048": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"kpack": 1
},
"M_GEQ_4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"kpack": 1
}
}
62 changes: 61 additions & 1 deletion aiter/ops/triton/configs/gemm/MI300X-GEMM-A16W16.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,65 @@
{
"any": {
"M_LEQ_64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"kpack": 1
},
"M_LEQ_2048": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"kpack": 1
},
"M_GEQ_4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
Expand Down
66 changes: 57 additions & 9 deletions aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=256-K=7168.json
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
{
"any": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"M_LEQ_64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_stages": 3,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 32,
"cache_modifier": null,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"small": {
"M_LEQ_128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 512,
Expand All @@ -22,5 +22,53 @@
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 32,
"cache_modifier": null,
"kpack": 1
},
"M_LEQ_512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 32,
"cache_modifier": null,
"kpack": 1
},
"M_LEQ_2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 32,
"cache_modifier": null,
"kpack": 1
},
"M_GEQ_4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 32,
"cache_modifier": null,
"kpack": 1
}
}
}
74 changes: 74 additions & 0 deletions aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-gated.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
{
"M_LEQ_64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 3,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 3,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_256": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_512": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_2048": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"kpack": 1
},
"M_GEQ_4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"kpack": 1
}
}
66 changes: 63 additions & 3 deletions aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,73 @@
{
"any": {
"M_LEQ_64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 3,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 3,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_256": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_512": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"kpack": 1
},
"M_LEQ_2048": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"kpack": 1
},
"M_GEQ_4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 32,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"kpack": 1
}
Expand Down
Loading