From 7000ac8ce56c13c54aa6b8436a8ed78160f625b6 Mon Sep 17 00:00:00 2001 From: Amir Balwel Date: Thu, 27 Nov 2025 05:57:36 +0000 Subject: [PATCH 01/10] init Signed-off-by: Amir Balwel --- aiter/jit/utils/chip_info.py | 1 + .../flash_attn_triton_amd/utils.py | 2 +- .../gemm/R9700-GEMM-A16W16-N=1024-K=1024.json | 86 +++++ .../gemm/R9700-GEMM-A16W16-N=1024-K=2048.json | 86 +++++ .../gemm/R9700-GEMM-A16W16-N=1024-K=3072.json | 86 +++++ .../gemm/R9700-GEMM-A16W16-N=4096-K=1024.json | 86 +++++ .../gemm/R9700-GEMM-A16W16-N=6144-K=1024.json | 86 +++++ .../configs/gemm/R9700-GEMM-A16W16.json | 100 ++++++ aiter/ops/triton/tune_a16w16.py | 299 ++++++++++++++++++ aiter/ops/triton/utils/_triton/arch_info.py | 3 +- aiter/ops/triton/utils/types.py | 4 +- aiter/utility/dtypes.py | 1 + op_tests/triton_tests/test_moe.py | 2 +- op_tests/triton_tests/test_moe_routing.py | 2 +- 14 files changed, 838 insertions(+), 6 deletions(-) create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=2048.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=3072.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=4096-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=6144-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16.json create mode 100644 aiter/ops/triton/tune_a16w16.py diff --git a/aiter/jit/utils/chip_info.py b/aiter/jit/utils/chip_info.py index b91c449e65..30c5f7263a 100644 --- a/aiter/jit/utils/chip_info.py +++ b/aiter/jit/utils/chip_info.py @@ -18,6 +18,7 @@ 6: "gfx945", 7: "gfx1100", 8: "gfx950", + 9: "gfx1201", } diff --git a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py index 44c8a53541..2173b35df2 100644 --- a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py +++ b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py @@ -632,7 +632,7 @@ def is_fp8(x) -> bool: def _is_fp8_single(t: torch.Tensor) -> bool: if is_dtype_fp8(t.dtype): arch = get_arch() - if arch not in ("gfx942", "gfx950"): + if arch not in ("gfx942", "gfx950", "gfx1201"): raise RuntimeError( f"{arch} is not in the list of supported architectures for FP8" ) diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=1024.json new file mode 100644 index 0000000000..25c11e56cf --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=1024.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "any": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=2048.json new file mode 100644 index 0000000000..25c11e56cf --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=2048.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "any": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=3072.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=3072.json new file mode 100644 index 0000000000..25c11e56cf --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=3072.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + 
"matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "any": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=4096-K=1024.json new file mode 100644 index 0000000000..25c11e56cf --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=4096-K=1024.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "any": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=6144-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=6144-K=1024.json new file mode 100644 index 0000000000..25c11e56cf --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=6144-K=1024.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_64": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "any": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16.json new file mode 100644 index 0000000000..b6fbdca5bc --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16.json @@ -0,0 +1,100 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "any": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + 
"NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + }, + "default": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1 + } +} diff --git a/aiter/ops/triton/tune_a16w16.py b/aiter/ops/triton/tune_a16w16.py new file mode 100644 index 0000000000..67bc319494 --- /dev/null +++ b/aiter/ops/triton/tune_a16w16.py @@ -0,0 +1,299 @@ +import argparse +import json +import multiprocessing as mp +import os +import time +import triton +from datetime import datetime + +import torch +from tqdm import tqdm + + +from gemm_a16w16 import gemm_a16w16 # type: ignore +from utils.core import AITER_TRITON_CONFIGS_PATH # type: ignore + +mp.set_start_method("spawn", force=True) + + +DTYPE_MAP = { + "float32": torch.float32, + "float16": torch.float16, + "half": torch.half, + "bfloat16": torch.bfloat16, +} + + +def get_configs_compute_bound(): + configs = [] + for num_stages in [2]: + for block_m in [16]: + for block_k in [64]: + for block_n in [32]: + for num_warps in [4]: + for group_size in [1]: + for waves_per_eu in [3]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, # TODO check if compatible + "matrix_instr_nonkdim": 16, # TODO + "cache_modifier": None, # TODO + "NUM_KSPLIT": 1, # TODO + "kpack": 1, # TODO + "SPLITK_BLOCK_SIZE": 1, + } + ) + return configs + + +# def get_configs_compute_bound(): +# configs = [] +# for num_stages in [2, 3, 4, 5]: +# for block_m in [16, 32, 64, 128, 256]: +# for block_k in [64, 128]: +# for block_n in [32, 64, 128, 256]: +# for num_warps in [4, 8]: +# for group_size in [1, 16, 32, 64]: +# for waves_per_eu in [1,2,3,4]: +# configs.append( +# { +# "BLOCK_SIZE_M": block_m, +# "BLOCK_SIZE_N": block_n, +# "BLOCK_SIZE_K": block_k, +# "GROUP_SIZE_M": group_size, +# "num_warps": num_warps, +# "num_stages": num_stages, +# "waves_per_eu": waves_per_eu, # TODO check if compatible +# "matrix_instr_nonkdim": 16, # TODO +# "cache_modifier": None, # TODO +# "NUM_KSPLIT": 1, # TODO +# "kpack": 1, # TODO +# "SPLITK_BLOCK_SIZE":1, +# } +# ) +# return configs + + +def get_weight_shapes(tp_size): + total = [ + (1024, 1024), + (4096, 1024), + (1024, 2048), + (6144, 1024), + (1024, 3072), + ] + + weight_shapes = [] + for t in total: + weight_shapes.append(t) + + return weight_shapes + + +def benchmark_config(x, w, bias, dtype, y, config, activation, num_iters=10): + def run(): + gemm_a16w16(x, w, bias, dtype, y, config, activation) + + torch.cuda.synchronize() + # JIT complication & warmup + for _ in range(5): + run() + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + run() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg + + +def tune(M, N, K, out_dtype, search_space, input_type): + if input_type == "bfloat16": + fp16_info = torch.finfo(torch.bfloat16) + fp16_max, fp16_min = fp16_info.max, fp16_info.min + + x_fp32 = ( + (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp16_max + ) + x = 
x_fp32.clamp(min=fp16_min, max=fp16_max).to(torch.bfloat16) + + w_fp32 = ( + (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp16_max + ) + w = w_fp32.clamp(min=fp16_min, max=fp16_max).to(torch.bfloat16) + else: + raise RuntimeError("Currently, only support tune w16a16 block fp16 kernel.") + + best_config = None + best_time = float("inf") + for config in tqdm(search_space): + try: + kernel_time = benchmark_config( + x=x, + w=w, + bias=None, + dtype=torch.bfloat16, + y=None, + config=config, + activation=None, + num_iters=10, + ) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={M}") + assert best_config is not None + return best_config + + +def save_configs( + N, + K, + configs, + save_path, +) -> None: + os.makedirs(save_path, exist_ok=True) + device_name = "R9700" # TODO: Hardcoded, make it dynamic + json_file_name = f"{device_name}-GEMM-A16W16-N={N}-K={K}.json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing best config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def tune_on_gpu(args_dict): + """Run tuning on a specific GPU.""" + gpu_id = args_dict["gpu_id"] + batch_sizes = args_dict["batch_sizes"] + weight_shapes = args_dict["weight_shapes"] + args = args_dict["args"] + + torch.cuda.set_device(gpu_id) + print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + + out_dtype = DTYPE_MAP[args.out_dtype] + save_path = AITER_TRITON_CONFIGS_PATH + "/gemm/" + input_type = args.input_type + + search_space = get_configs_compute_bound() + + start = time.time() + for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"): + N, K = shape[0], shape[1] + print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`") + benchmark_results = [ + tune( + batch_size, + N, + K, + out_dtype, + search_space, + input_type, + ) + for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") + ] + best_configs = { + ("any" if i == len(batch_sizes) - 1 else f"M_LEQ_{M}"): config + for i, (M, config) in enumerate(zip(batch_sizes, benchmark_results)) + } + save_configs(N, K, best_configs, save_path) + + end = time.time() + print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") + + +def distribute_batch_sizes(batch_sizes, num_gpus): + """Distribute batch sizes across available GPUs.""" + batches_per_gpu = [] + for i in range(num_gpus): + start_idx = i * len(batch_sizes) // num_gpus + end_idx = (i + 1) * len(batch_sizes) // num_gpus + batches_per_gpu.append(batch_sizes[start_idx:end_idx]) + return batches_per_gpu + + +def main(args): + print(args) + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + raise RuntimeError("No GPU available for tuning") + print(f"Found {num_gpus} GPUs for parallel tuning") + + torch.cuda.init() + + if args.batch_size is None: + batch_sizes = [ + 64, + 128, + 256, + 512, + 2048, + 4096, + ] + else: + batch_sizes = [args.batch_size] + num_gpus = 1 # If only one batch size, use only one GPU + + weight_shapes = get_weight_shapes(args.tp_size) + + batches_per_gpu = distribute_batch_sizes(batch_sizes, 1) + + process_args = [] + for gpu_id in range(1): + process_args.append( + { + "gpu_id": gpu_id, + "batch_sizes": batches_per_gpu[gpu_id], + "weight_shapes": weight_shapes, # Each GPU 
processes all weight shapes + "args": args, + } + ) + + ctx = mp.get_context("spawn") + with ctx.Pool(1) as pool: + pool.map(tune_on_gpu, process_args) + + print("Multi-GPU tuning completed") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument("--tp-size", "-tp", type=int, default=1) + parser.add_argument( + "--input-type", type=str, choices=["bfloat16"], default="bfloat16" + ) + parser.add_argument( + "--out-dtype", + type=str, + choices=["float32", "float16", "bfloat16", "half"], + default="bfloat16", + ) + parser.add_argument("--batch-size", type=int, required=False) + args = parser.parse_args() + + main(args) diff --git a/aiter/ops/triton/utils/_triton/arch_info.py b/aiter/ops/triton/utils/_triton/arch_info.py index 20bab95679..11de410532 100644 --- a/aiter/ops/triton/utils/_triton/arch_info.py +++ b/aiter/ops/triton/utils/_triton/arch_info.py @@ -4,6 +4,7 @@ _ARCH_TO_DEVICE = { "gfx942": "MI300X", "gfx950": "MI350X", + "gfx1201": "R9700", } @@ -30,4 +31,4 @@ def is_fp4_avail(): def is_fp8_avail(): - return get_arch() in ("gfx942", "gfx950") + return get_arch() in ("gfx942", "gfx950", "gfx1201") diff --git a/aiter/ops/triton/utils/types.py b/aiter/ops/triton/utils/types.py index 92d18c6eed..6a07c63753 100644 --- a/aiter/ops/triton/utils/types.py +++ b/aiter/ops/triton/utils/types.py @@ -11,7 +11,7 @@ def get_dtype_max(dtype): def get_fp8_dtypes(): - if arch_info.get_arch() in ("gfx950"): + if arch_info.get_arch() in ("gfx950", "gfx1201"): e5m2_dtype = torch.float8_e5m2 e4m3_dtype = torch.float8_e4m3fn else: @@ -22,7 +22,7 @@ def get_fp8_dtypes(): def get_fp8_e4m3_dtype(): - if arch_info.get_arch() in ("gfx950"): + if arch_info.get_arch() in ("gfx950", "gfx1201"): e4m3_dtype = torch.float8_e4m3fn else: e4m3_dtype = torch.float8_e4m3fnuz diff --git a/aiter/utility/dtypes.py b/aiter/utility/dtypes.py index 5252f5f8b3..e1eb66a153 100644 --- a/aiter/utility/dtypes.py +++ b/aiter/utility/dtypes.py @@ -7,6 +7,7 @@ defaultDtypes = { "gfx942": {"fp8": torch.float8_e4m3fnuz}, "gfx950": {"fp8": torch.float8_e4m3fn}, + "gfx1201": {"fp8": torch.float8_e4m3fn}, } _8bit_fallback = torch.uint8 diff --git a/op_tests/triton_tests/test_moe.py b/op_tests/triton_tests/test_moe.py index a2c93957fb..ce7522c34c 100644 --- a/op_tests/triton_tests/test_moe.py +++ b/op_tests/triton_tests/test_moe.py @@ -324,7 +324,7 @@ def quantize_fp8( tensor: torch.Tensor, dim=() ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: dev = arch_info.get_device() - if dev == "MI350X": + if dev in ["MI350X", "R9700"]: fp8_type = torch.float8_e4m3fn else: fp8_type = torch.float8_e4m3fnuz diff --git a/op_tests/triton_tests/test_moe_routing.py b/op_tests/triton_tests/test_moe_routing.py index 85477f0dda..ee6985dd2f 100644 --- a/op_tests/triton_tests/test_moe_routing.py +++ b/op_tests/triton_tests/test_moe_routing.py @@ -95,7 +95,7 @@ def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"): @pytest.mark.parametrize("use_expt_indx", [False, True]) @pytest.mark.parametrize("sm_first", [True, False]) def test_op(n_tokens, n_expts_tot, n_expts_act, sm_first, use_expt_indx): - if get_arch() != "gfx950": + if get_arch() not in ["gfx950", "gfx1201"]: pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.") device = "cuda" From 6b72cab62e599733d5378b141beecbb2d2f86065 Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Thu, 27 Nov 2025 07:10:26 +0000 Subject: [PATCH 02/10] added logging for gemm_a16w16 tests and atomic 
kernel tune --- aiter/ops/triton/README_tune_atomic.md | 103 +++++ .../ops/triton/_triton_kernels/gemm_a16w16.py | 8 + .../_triton_kernels/gemm_a16w16_atomic.py | 9 + ...9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json | 100 +++++ ...9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json | 100 +++++ ...9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json | 100 +++++ ...9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json | 100 +++++ ...9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json | 100 +++++ .../gemm/R9700-GEMM-A16W16-ATOMIC.json | 100 +++++ aiter/ops/triton/gemm_a16w16.py | 1 + aiter/ops/triton/tune_a16w16_atomic.py | 385 ++++++++++++++++++ op_tests/triton_tests/README.md | 4 +- op_tests/triton_tests/test_gemm_a16w16.py | 9 + pyproject.toml | 4 + 14 files changed, 1122 insertions(+), 1 deletion(-) create mode 100644 aiter/ops/triton/README_tune_atomic.md create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json create mode 100644 aiter/ops/triton/tune_a16w16_atomic.py diff --git a/aiter/ops/triton/README_tune_atomic.md b/aiter/ops/triton/README_tune_atomic.md new file mode 100644 index 0000000000..eafa38498d --- /dev/null +++ b/aiter/ops/triton/README_tune_atomic.md @@ -0,0 +1,103 @@ +# GEMM A16W16 Atomic Kernel Tuning + +This document explains how to use the tuning script for the atomic GEMM kernel. + +## Overview + +The atomic GEMM kernel (`gemm_a16w16_atomic`) is a specialized kernel that uses atomic operations for split-K reduction. This allows for better parallelization in certain scenarios, especially for larger matrix dimensions. + +## Tuning Script + +The tuning script is located at `aiter/ops/triton/tune_a16w16_atomic.py`. It follows a similar pattern to the regular GEMM tuning script but with specific adaptations for the atomic kernel. + +### Key Differences from Regular GEMM Tuning + +1. **NUM_KSPLIT Parameter**: The atomic kernel includes a `NUM_KSPLIT` parameter that controls how the K dimension is split across multiple thread blocks for parallel reduction. + +2. **Configuration Categories**: The atomic kernel uses different configuration categories based on the M dimension: + - `small`: M < 32 + - `medium_M32`: M โ‰ค 128 with BLOCK_SIZE_M = 32 + - `medium_M64`: M โ‰ค 128 with BLOCK_SIZE_M = 64 + - `medium_M128`: M โ‰ค 128 with BLOCK_SIZE_M = 128 + - `large`: M โ‰ค 256 + - `xlarge`: M > 256 + +3. **No Bias Parameter**: The atomic kernel doesn't support a bias parameter. 
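For reference, the sketch below is illustrative only (the helper name is ours); it mirrors the M-based category selection that `_get_config` in `gemm_a16w16_atomic.py` performs when picking one of the entries above:

```python
# Illustrative sketch: mirrors the M-dimension bucketing used by _get_config
# in gemm_a16w16_atomic.py. The function name is hypothetical.
import triton


def select_atomic_config_category(M: int) -> str:
    """Map the GEMM M dimension to the atomic-kernel config category."""
    if M < 32:
        return "small"
    if M <= 128:
        # Round M up to a power of two to pick the matching BLOCK_SIZE_M bucket.
        blk_m = triton.next_power_of_2(M)
        return {32: "medium_M32", 64: "medium_M64", 128: "medium_M128"}[blk_m]
    if M <= 256:
        return "large"
    return "xlarge"


# Example: M = 100 rounds up to 128, so the "medium_M128" entry is used.
assert select_atomic_config_category(100) == "medium_M128"
```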
+ +### Running the Tuning Script + +To run the tuning script: + +```bash +cd aiter/ops/triton +python tune_a16w16_atomic.py +``` + +You can also specify specific parameters: + +```bash +python tune_a16w16_atomic.py --batch-size 512 --input-type bfloat16 --out-dtype bfloat16 +``` + +### Output + +The tuning script will generate configuration files in the `aiter/ops/triton/configs/gemm/` directory with names like: +- `R9700-GEMM-A16W16-ATOMIC-N={N}-K={K}.json` for specific N,K dimensions +- `R9700-GEMM-A16W16-ATOMIC.json` - A default config file (without N,K parameters) that contains the most common optimal configurations across all tested shapes + +The default config file is created by analyzing all the specific N,K configurations and selecting the most common optimal configuration for each category (small, medium_M32, etc.). This provides a good general-purpose configuration when a specific N,K configuration is not available. + +### Configuration Parameters + +The tuning script searches through these parameters: +- `BLOCK_SIZE_M`: [16, 32, 64, 128] +- `BLOCK_SIZE_N`: [32, 64, 128, 256] +- `BLOCK_SIZE_K`: [64, 128] +- `GROUP_SIZE_M`: [1, 8, 16] +- `NUM_KSPLIT`: [1, 2, 4, 8] (atomic kernel specific) +- `num_warps`: [4, 8] +- `num_stages`: [2] +- `waves_per_eu`: [3] + +### Batch Sizes Tested + +The script tests these batch sizes (M dimensions): +- 16 (small) +- 32 (medium_M32) +- 64 (medium_M64) +- 128 (medium_M128) +- 256 (large) +- 512 (large) +- 2048 (xlarge) +- 4096 (xlarge) + +### Weight Shapes Tested + +The script tunes for these weight shapes (N, K): +- (1024, 1024) +- (4096, 1024) +- (1024, 2048) +- (6144, 1024) +- (1024, 3072) + +## Usage in Code + +Once the tuning is complete, the atomic kernel can be used in your code: + +```python +from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic + +# Basic usage +x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") +w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") + +# The kernel will automatically use the tuned configurations +y = gemm_a16w16_atomic(x, w, dtype=torch.bfloat16) +``` + +## Notes + +1. The atomic kernel is particularly useful for larger matrices where split-K parallelization provides benefits. +2. For smaller matrices, the regular GEMM kernel might be more efficient. +3. The tuning process can take considerable time as it tests many configurations. +4. Make sure you have sufficient GPU memory for the tuning process. 
\ No newline at end of file diff --git a/aiter/ops/triton/_triton_kernels/gemm_a16w16.py b/aiter/ops/triton/_triton_kernels/gemm_a16w16.py index f5377e7eb3..9653e16c91 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a16w16.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a16w16.py @@ -7,6 +7,9 @@ from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH from ..utils._triton.kernel_repr import make_kernel_repr +from ..utils.logger import AiterTritonLogger + +_LOGGER = AiterTritonLogger() _gemm_a16w16_repr = make_kernel_repr( @@ -260,6 +263,7 @@ def _get_config( dev = arch_info.get_device() _get_config._config_dict = {} fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-A16W16.json" + _LOGGER.info(f"Loading default GEMM config from: {fpath}") with open(fpath, "r") as file: config = json.load(file) _get_config._config_dict["default"] = config @@ -269,19 +273,23 @@ def _get_config( dev = arch_info.get_device() fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-A16W16-N={N}-K={K}.json" if os.path.exists(fpath): + _LOGGER.info(f"Loading specific GEMM config from: {fpath}") with open(fpath, "r") as file: config = json.load(file) _get_config._config_dict[key] = config else: key = "default" # fall back to default config + _LOGGER.info(f"Specific config not found, using default config for N={N}, K={K}") bounds = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] for bound in bounds: if M <= bound and f"M_LEQ_{bound}" in _get_config._config_dict[key]: temp_config = _get_config._config_dict[key][f"M_LEQ_{bound}"] + _LOGGER.info(f"Using config for M <= {bound} with N={N}, K={K}") break else: temp_config = _get_config._config_dict[key]["any"] + _LOGGER.info(f"Using 'any' config for M={M}, N={N}, K={K}") # Copy to avoid mutating the cached config chosen_config = dict(temp_config) diff --git a/aiter/ops/triton/_triton_kernels/gemm_a16w16_atomic.py b/aiter/ops/triton/_triton_kernels/gemm_a16w16_atomic.py index d08a1b2760..507f22a4f1 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a16w16_atomic.py @@ -11,6 +11,9 @@ from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH from ..utils._triton.kernel_repr import make_kernel_repr +from ..utils.logger import AiterTritonLogger + +_LOGGER = AiterTritonLogger() _gemm_a16w16_atomic_repr = make_kernel_repr( @@ -171,16 +174,22 @@ def _get_config( # single config. 
for the default path return _get_config._config_dict[key]["any"] if M < 32: + _LOGGER.info(f"Using 'small' ATOMIC config for M={M}, N={N}, K={K}") return _get_config._config_dict[key]["small"] elif M <= 128: BLK_M = triton.next_power_of_2(M) if BLK_M == 32: + _LOGGER.info(f"Using 'medium_M32' ATOMIC config for M={M}, N={N}, K={K}") return _get_config._config_dict[key]["medium_M32"] elif BLK_M == 64: + _LOGGER.info(f"Using 'medium_M64' ATOMIC config for M={M}, N={N}, K={K}") return _get_config._config_dict[key]["medium_M64"] elif BLK_M == 128: + _LOGGER.info(f"Using 'medium_M128' ATOMIC config for M={M}, N={N}, K={K}") return _get_config._config_dict[key]["medium_M128"] elif M <= 256: + _LOGGER.info(f"Using 'large' ATOMIC config for M={M}, N={N}, K={K}") return _get_config._config_dict[key]["large"] else: + _LOGGER.info(f"Using 'xlarge' ATOMIC config for M={M}, N={N}, K={K}") return _get_config._config_dict[key]["xlarge"] diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json new file mode 100644 index 0000000000..3b3b9424eb --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "large": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "xlarge": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json new file mode 100644 index 0000000000..4bd20b9771 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json @@ -0,0 +1,100 @@ +{ 
+ "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 2048 + }, + "medium_M32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 2048 + }, + "medium_M64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 2048 + }, + "medium_M128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 2048 + }, + "large": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 2048 + }, + "xlarge": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 2048 + }, + "any": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 2048 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json new file mode 100644 index 0000000000..e9df57e6e1 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 3072 + }, + "large": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 
2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 3072 + }, + "xlarge": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json new file mode 100644 index 0000000000..3b3b9424eb --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "large": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "xlarge": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json new file mode 100644 index 0000000000..3b3b9424eb --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + 
"kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "large": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "xlarge": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json new file mode 100644 index 0000000000..3bcb88ed00 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 4, + "waves_per_eu": 3 + }, + "medium_M32": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 4, + "waves_per_eu": 3 + }, + "medium_M64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 4, + "waves_per_eu": 3 + }, + "medium_M128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 4, + "waves_per_eu": 3 + }, + "large": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 4, + "waves_per_eu": 3 + }, + "xlarge": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + 
"SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 4, + "waves_per_eu": 3 + }, + "any": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 4, + "waves_per_eu": 3 + } +} diff --git a/aiter/ops/triton/gemm_a16w16.py b/aiter/ops/triton/gemm_a16w16.py index 549ddd36d8..4a746a57d9 100644 --- a/aiter/ops/triton/gemm_a16w16.py +++ b/aiter/ops/triton/gemm_a16w16.py @@ -57,6 +57,7 @@ def gemm_a16w16( if config is None: config = _get_config(M, N, K) + _LOGGER.info(f"Using GEMM config: {config}") if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): y = torch.empty((M, N), dtype=dtype, device=x.device) diff --git a/aiter/ops/triton/tune_a16w16_atomic.py b/aiter/ops/triton/tune_a16w16_atomic.py new file mode 100644 index 0000000000..25c1a51f6f --- /dev/null +++ b/aiter/ops/triton/tune_a16w16_atomic.py @@ -0,0 +1,385 @@ +import argparse +import json +import multiprocessing as mp +import os +import time +import triton +from datetime import datetime + +import torch +from tqdm import tqdm + + +from gemm_a16w16_atomic import gemm_a16w16_atomic # type: ignore +from utils.core import AITER_TRITON_CONFIGS_PATH # type: ignore + +mp.set_start_method("spawn", force=True) + + +DTYPE_MAP = { + "float32": torch.float32, + "float16": torch.float16, + "half": torch.half, + "bfloat16": torch.bfloat16, +} + + +def get_configs_compute_bound(): + configs = [] + for num_stages in [2]: + for block_m in [ + 16, + ]: + for block_k in [ + 64, + ]: + for block_n in [ + 32, + ]: + for num_warps in [ + 4, + ]: + for group_size in [ + 1, + ]: + for num_ksplit in [ + 1, + ]: # Atomic kernel specific parameter + for waves_per_eu in [3]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, # TODO check if compatible + "matrix_instr_nonkdim": 16, # TODO + "cache_modifier": "", # Empty string for atomic kernel + "NUM_KSPLIT": num_ksplit, # Atomic kernel specific + "kpack": 1, # TODO + "SPLITK_BLOCK_SIZE": 1, # Will be set dynamically + } + ) + return configs + + +# def get_configs_compute_bound(): +# configs = [] +# for num_stages in [2, 3, 4, 5]: +# for block_m in [16, 32, 64, 128, 256]: +# for block_k in [64, 128]: +# for block_n in [32, 64, 128, 256]: +# for num_warps in [4, 8]: +# for group_size in [1, 8, 16, 32, 64]: +# for num_ksplit in [1, 2, 4, 8]: # Atomic kernel specific parameter +# for waves_per_eu in [1,2,3,4]: +# configs.append( +# { +# "BLOCK_SIZE_M": block_m, +# "BLOCK_SIZE_N": block_n, +# "BLOCK_SIZE_K": block_k, +# "GROUP_SIZE_M": group_size, +# "num_warps": num_warps, +# "num_stages": num_stages, +# "waves_per_eu": waves_per_eu, # TODO check if compatible +# "matrix_instr_nonkdim": 16, # TODO +# "cache_modifier": None, # TODO +# "NUM_KSPLIT": num_ksplit, # Atomic kernel specific +# "kpack": 1, # TODO +# "SPLITK_BLOCK_SIZE": 1, # Will be set dynamically +# } +# ) +# return configs + + +def get_weight_shapes(tp_size): + total = [ + (1024, 1024), + (4096, 1024), + (1024, 2048), + (6144, 1024), + (1024, 3072), + ] + + weight_shapes = [] + for t in total: + weight_shapes.append(t) + + return weight_shapes + + +def benchmark_config(x, w, dtype, y, config, num_iters=10): + def run(): + 
gemm_a16w16_atomic(x, w, dtype, y, config) + + torch.cuda.synchronize() + # JIT complication & warmup + for _ in range(5): + run() + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + run() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg + + +def tune(M, N, K, out_dtype, search_space, input_type): + if input_type == "bfloat16": + fp16_info = torch.finfo(torch.bfloat16) + fp16_max, fp16_min = fp16_info.max, fp16_info.min + + x_fp32 = ( + (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp16_max + ) + x = x_fp32.clamp(min=fp16_min, max=fp16_max).to(torch.bfloat16) + + w_fp32 = ( + (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp16_max + ) + w = w_fp32.clamp(min=fp16_min, max=fp16_max).to(torch.bfloat16) + else: + raise RuntimeError("Currently, only support tune w16a16 block fp16 kernel.") + + best_config = None + best_time = float("inf") + for config in tqdm(search_space): + try: + kernel_time = benchmark_config( + x=x, + w=w, + dtype=torch.bfloat16, + y=None, + config=config, + num_iters=10, + ) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={M}") + assert best_config is not None + return best_config + + +def save_configs( + N, + K, + configs, + save_path, +) -> None: + os.makedirs(save_path, exist_ok=True) + device_name = "R9700" # TODO: Hardcoded, make it dynamic + json_file_name = f"{device_name}-GEMM-A16W16-ATOMIC-N={N}-K={K}.json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing best config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def tune_on_gpu(args_dict): + """Run tuning on a specific GPU.""" + gpu_id = args_dict["gpu_id"] + batch_sizes = args_dict["batch_sizes"] + weight_shapes = args_dict["weight_shapes"] + args = args_dict["args"] + + torch.cuda.set_device(gpu_id) + print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + + out_dtype = DTYPE_MAP[args.out_dtype] + save_path = AITER_TRITON_CONFIGS_PATH + "/gemm/" + input_type = args.input_type + + search_space = get_configs_compute_bound() + + start = time.time() + + # Collect all configs to determine best overall config + all_configs = [] + + for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"): + N, K = shape[0], shape[1] + print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`") + benchmark_results = [ + tune( + batch_size, + N, + K, + out_dtype, + search_space, + input_type, + ) + for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") + ] + best_configs = {} + # Create configs for different M size categories as expected by the atomic kernel + for i, (M, config) in enumerate(zip(batch_sizes, benchmark_results)): + if i == len(batch_sizes) - 1: + best_configs["any"] = config + elif M < 32: + best_configs["small"] = config + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32: + best_configs["medium_M32"] = config + elif BLK_M == 64: + best_configs["medium_M64"] = config + elif 
BLK_M == 128: + best_configs["medium_M128"] = config + elif M <= 256: + best_configs["large"] = config + else: + best_configs["xlarge"] = config + # Store configs for later analysis + all_configs.append(best_configs) + save_configs(N, K, best_configs, save_path) + + # Create a default config file (without N,K parameters) by selecting most common config + default_config = create_default_config(all_configs) + save_default_config(default_config, save_path) + + end = time.time() + print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") + + +def create_default_config(all_configs): + """Create a default config by selecting the most common config across all shapes.""" + from collections import Counter + + # Collect all configs for each category + category_configs = { + "small": [], + "medium_M32": [], + "medium_M64": [], + "medium_M128": [], + "large": [], + "xlarge": [], + "any": [] + } + + for config in all_configs: + for category, params in config.items(): + if category in category_configs: + # Convert config to a hashable tuple for counting + config_tuple = tuple(sorted(params.items())) + category_configs[category].append(config_tuple) + + # Find the most common config for each category + default_config = {} + for category, configs in category_configs.items(): + if configs: + most_common = Counter(configs).most_common(1)[0][0] + default_config[category] = dict(most_common) + + return default_config + + +def save_default_config(config, save_path): + """Save the default config file (without N,K parameters).""" + os.makedirs(save_path, exist_ok=True) + device_name = "R9700" # TODO: Hardcoded, make it dynamic + json_file_name = f"{device_name}-GEMM-A16W16-ATOMIC.json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing default config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(config, f, indent=4) + f.write("\n") + + +def distribute_batch_sizes(batch_sizes, num_gpus): + """Distribute batch sizes across available GPUs.""" + batches_per_gpu = [] + for i in range(num_gpus): + start_idx = i * len(batch_sizes) // num_gpus + end_idx = (i + 1) * len(batch_sizes) // num_gpus + batches_per_gpu.append(batch_sizes[start_idx:end_idx]) + return batches_per_gpu + + +def main(args): + print(args) + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + raise RuntimeError("No GPU available for tuning") + print(f"Found {num_gpus} GPUs for parallel tuning") + + torch.cuda.init() + + if args.batch_size is None: + batch_sizes = [ + 16, # For small config + 32, # For medium_M32 config + 64, # For medium_M64 config + 128, # For medium_M128 config + 256, # For large config + 512, # For large config + 2048, # For xlarge config + 4096, # For xlarge config + ] + else: + batch_sizes = [args.batch_size] + num_gpus = 1 # If only one batch size, use only one GPU + + weight_shapes = get_weight_shapes(args.tp_size) + + batches_per_gpu = distribute_batch_sizes(batch_sizes, 1) + + process_args = [] + for gpu_id in range(1): + process_args.append( + { + "gpu_id": gpu_id, + "batch_sizes": batches_per_gpu[gpu_id], + "weight_shapes": weight_shapes, # Each GPU processes all weight shapes + "args": args, + } + ) + + ctx = mp.get_context("spawn") + with ctx.Pool(1) as pool: + pool.map(tune_on_gpu, process_args) + + print("Multi-GPU tuning completed") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument("--tp-size", "-tp", type=int, default=1) + parser.add_argument( + 
"--input-type", type=str, choices=["bfloat16"], default="bfloat16" + ) + parser.add_argument( + "--out-dtype", + type=str, + choices=["float32", "float16", "bfloat16", "half"], + default="bfloat16", + ) + parser.add_argument("--batch-size", type=int, required=False) + args = parser.parse_args() + + main(args) +# diff --git a/op_tests/triton_tests/README.md b/op_tests/triton_tests/README.md index a10e3275b5..7a9ef98b1a 100644 --- a/op_tests/triton_tests/README.md +++ b/op_tests/triton_tests/README.md @@ -6,4 +6,6 @@ Triton's implementation uses numpy under the hood. Switch back to Triton if you're experiencing OOM issues. 2. When possible, generate test inputs directly on the GPU (e.g with `torch.randn((M, K), device="cuda")` as opposed to `torch.randn((M, K)).cuda()`). -It's ~2 orders of magnitude faster for large test cases. \ No newline at end of file +It's ~2 orders of magnitude faster for large test cases. + +** Run the tests with ```AITER_TRITON_LOG_LEVEL=INFO``` to get logging \ No newline at end of file diff --git a/op_tests/triton_tests/test_gemm_a16w16.py b/op_tests/triton_tests/test_gemm_a16w16.py index 66f64b6119..bc63eeb4ff 100644 --- a/op_tests/triton_tests/test_gemm_a16w16.py +++ b/op_tests/triton_tests/test_gemm_a16w16.py @@ -4,10 +4,14 @@ import torch import torch.nn.functional as F import pytest +import logging from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic from op_tests.triton_tests.utils.types import str_to_torch_dtype +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True, bias=False): if isinstance(dtype, str): @@ -80,6 +84,7 @@ def get_x_vals(): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("output", [True, False]) def test_gemm_a16_w16_activation(M: int, N: int, K: int, dtype, output, activation): + logger.info(f"Running test_gemm_a16_w16_activation with M={M}, N={N}, K={K}, dtype={dtype}, output={output}, activation={activation}") x, w, _, out_dtype, y = generate_gemm_a16w16_inputs( M, N, @@ -123,6 +128,7 @@ def test_gemm_a16_w16_activation(M: int, N: int, K: int, dtype, output, activati @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("output", [True, False]) def test_gemm_a16_w16(M: int, N: int, K: int, dtype, output): + logger.info(f"Running test_gemm_a16_w16 with M={M}, N={N}, K={K}, dtype={dtype}, output={output}") torch.cuda.empty_cache() # Helps avoid hangs in large tests x, w, bias, out_dtype, y = generate_gemm_a16w16_inputs( @@ -144,6 +150,7 @@ def test_gemm_a16_w16(M: int, N: int, K: int, dtype, output): @pytest.mark.parametrize("layout", ["TT", "NN", "NT"]) @pytest.mark.parametrize("output", [True, False]) def test_gemm_a16_w16_layout(M: int, N: int, K: int, dtype, layout, output): + logger.info(f"Running test_gemm_a16_w16_layout with M={M}, N={N}, K={K}, dtype={dtype}, layout={layout}, output={output}") torch.cuda.empty_cache() # Helps avoid hangs in large tests x, w, _, out_dtype, y = generate_gemm_a16w16_inputs( @@ -164,6 +171,7 @@ def test_gemm_a16_w16_layout(M: int, N: int, K: int, dtype, layout, output): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("output", [True, False]) def test_gemm_a16_w16_atomic(M: int, N: int, K: int, dtype, output): + logger.info(f"Running test_gemm_a16_w16_atomic with M={M}, N={N}, K={K}, dtype={dtype}, output={output}") 
torch.cuda.empty_cache() # Helps avoid hangs in large tests x, w, _, out_dtype, y = generate_gemm_a16w16_inputs(M, N, K, dtype, output=output) @@ -185,6 +193,7 @@ def test_gemm_a16_w16_atomic(M: int, N: int, K: int, dtype, output): @pytest.mark.parametrize("layout", ["TT", "NN", "NT"]) @pytest.mark.parametrize("output", [True, False]) def test_gemm_a16_w16_atomic_layout(M: int, N: int, K: int, dtype, layout, output): + logger.info(f"Running test_gemm_a16_w16_atomic_layout with M={M}, N={N}, K={K}, dtype={dtype}, layout={layout}, output={output}") torch.cuda.empty_cache() # Helps avoid hangs in large tests x, w, _, out_dtype, y = generate_gemm_a16w16_inputs( diff --git a/pyproject.toml b/pyproject.toml index 36ada796c1..e39cd81668 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,3 +14,7 @@ requires = [ write_to = "aiter/_version.py" write_to_template = "__version__ = '{version}'\n" fallback_version = "0.1.0" + +[tool.pytest.ini_options] +log_cli = true +log_cli_level = "INFO" \ No newline at end of file From 9492f4dd0762fbf1d29bdd8735044126cc4abaa1 Mon Sep 17 00:00:00 2001 From: Amir Balwel Date: Mon, 1 Dec 2025 08:01:01 +0000 Subject: [PATCH 03/10] improve tuning script and add test for qwen shapes Signed-off-by: Amir Balwel --- aiter/ops/triton/tune_a16w16.py | 179 ++++++++---------- .../triton_tests/test_gemm_a16w16_qwen3.py | 179 ++++++++++++++++++ 2 files changed, 254 insertions(+), 104 deletions(-) create mode 100644 op_tests/triton_tests/test_gemm_a16w16_qwen3.py diff --git a/aiter/ops/triton/tune_a16w16.py b/aiter/ops/triton/tune_a16w16.py index 67bc319494..d21e832edc 100644 --- a/aiter/ops/triton/tune_a16w16.py +++ b/aiter/ops/triton/tune_a16w16.py @@ -9,30 +9,55 @@ import torch from tqdm import tqdm +from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 + +from op_tests.triton_tests.utils.types import str_to_torch_dtype -from gemm_a16w16 import gemm_a16w16 # type: ignore from utils.core import AITER_TRITON_CONFIGS_PATH # type: ignore mp.set_start_method("spawn", force=True) -DTYPE_MAP = { - "float32": torch.float32, - "float16": torch.float16, - "half": torch.half, - "bfloat16": torch.bfloat16, -} +def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True, bias=False): + if isinstance(dtype, str): + dtype = str_to_torch_dtype[dtype] + + # TN is default layout + if layout[0] == "T": + print(M, K) + print(dtype) + x = torch.randn((M, K), dtype=dtype, device="cuda") + else: + x = torch.randn((K, M), dtype=dtype, device="cuda").T + + if layout[1] == "T": + weight = torch.randn((K, N), dtype=dtype, device="cuda").T + else: + weight = torch.randn((N, K), dtype=dtype, device="cuda") + + bias_tensor = None + if bias: + bias_tensor = torch.empty((N), dtype=dtype, device="cuda") + + y = None + if output: + y = torch.empty((M, N), dtype=dtype, device="cuda") + out_dtype = (None,) + else: + out_dtype = dtype + + return x, weight, bias_tensor, out_dtype, y def get_configs_compute_bound(): configs = [] - for num_stages in [2]: - for block_m in [16]: - for block_k in [64]: - for block_n in [32]: - for num_warps in [4]: - for group_size in [1]: - for waves_per_eu in [3]: + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [64, 128]: + for block_n in [32, 64, 128, 256]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + for waves_per_eu in [2, 3, 4]: configs.append( { "BLOCK_SIZE_M": block_m, @@ -41,46 +66,17 @@ def get_configs_compute_bound(): "GROUP_SIZE_M": group_size, "num_warps": num_warps, "num_stages": num_stages, - 
"waves_per_eu": waves_per_eu, # TODO check if compatible - "matrix_instr_nonkdim": 16, # TODO - "cache_modifier": None, # TODO - "NUM_KSPLIT": 1, # TODO - "kpack": 1, # TODO - "SPLITK_BLOCK_SIZE": 1, + "waves_per_eu": waves_per_eu, + "NUM_KSPLIT": 1, + "kpack": 1, + "SPLITK_BLOCK_SIZE": 1, # Why are those 2 needed for gfx1201 + "cache_modifier": None # but does not exist in other configs? } ) return configs -# def get_configs_compute_bound(): -# configs = [] -# for num_stages in [2, 3, 4, 5]: -# for block_m in [16, 32, 64, 128, 256]: -# for block_k in [64, 128]: -# for block_n in [32, 64, 128, 256]: -# for num_warps in [4, 8]: -# for group_size in [1, 16, 32, 64]: -# for waves_per_eu in [1,2,3,4]: -# configs.append( -# { -# "BLOCK_SIZE_M": block_m, -# "BLOCK_SIZE_N": block_n, -# "BLOCK_SIZE_K": block_k, -# "GROUP_SIZE_M": group_size, -# "num_warps": num_warps, -# "num_stages": num_stages, -# "waves_per_eu": waves_per_eu, # TODO check if compatible -# "matrix_instr_nonkdim": 16, # TODO -# "cache_modifier": None, # TODO -# "NUM_KSPLIT": 1, # TODO -# "kpack": 1, # TODO -# "SPLITK_BLOCK_SIZE":1, -# } -# ) -# return configs - - -def get_weight_shapes(tp_size): +def get_weight_shapes(): total = [ (1024, 1024), (4096, 1024), @@ -98,8 +94,8 @@ def get_weight_shapes(tp_size): def benchmark_config(x, w, bias, dtype, y, config, activation, num_iters=10): def run(): - gemm_a16w16(x, w, bias, dtype, y, config, activation) - + gemm_a16w16(x, w, bias=bias, dtype=dtype, y=y, config=config, activation=activation) + torch.cuda.synchronize() # JIT complication & warmup for _ in range(5): @@ -121,22 +117,20 @@ def run(): return avg -def tune(M, N, K, out_dtype, search_space, input_type): - if input_type == "bfloat16": - fp16_info = torch.finfo(torch.bfloat16) - fp16_max, fp16_min = fp16_info.max, fp16_info.min - - x_fp32 = ( - (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp16_max - ) - x = x_fp32.clamp(min=fp16_min, max=fp16_max).to(torch.bfloat16) - - w_fp32 = ( - (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp16_max - ) - w = w_fp32.clamp(min=fp16_min, max=fp16_max).to(torch.bfloat16) - else: - raise RuntimeError("Currently, only support tune w16a16 block fp16 kernel.") +def tune(M, N, K, dtype, search_space): + ( + x, + w, + _, + _, + _, + ) = generate_gemm_a16w16_inputs( + M, + N, + K, + dtype, + output=False, + ) best_config = None best_time = float("inf") @@ -146,7 +140,7 @@ def tune(M, N, K, out_dtype, search_space, input_type): x=x, w=w, bias=None, - dtype=torch.bfloat16, + dtype=dtype, y=None, config=config, activation=None, @@ -193,9 +187,8 @@ def tune_on_gpu(args_dict): torch.cuda.set_device(gpu_id) print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") - out_dtype = DTYPE_MAP[args.out_dtype] + dtype = str_to_torch_dtype[args.dtype] save_path = AITER_TRITON_CONFIGS_PATH + "/gemm/" - input_type = args.input_type search_space = get_configs_compute_bound() @@ -208,9 +201,8 @@ def tune_on_gpu(args_dict): batch_size, N, K, - out_dtype, + dtype, search_space, - input_type, ) for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") ] @@ -224,16 +216,6 @@ def tune_on_gpu(args_dict): print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") -def distribute_batch_sizes(batch_sizes, num_gpus): - """Distribute batch sizes across available GPUs.""" - batches_per_gpu = [] - for i in range(num_gpus): - start_idx = i * len(batch_sizes) // num_gpus - end_idx = (i + 1) * len(batch_sizes) // num_gpus - 
batches_per_gpu.append(batch_sizes[start_idx:end_idx]) - return batches_per_gpu - - def main(args): print(args) num_gpus = torch.cuda.device_count() @@ -243,29 +225,23 @@ def main(args): torch.cuda.init() - if args.batch_size is None: - batch_sizes = [ - 64, - 128, - 256, - 512, - 2048, - 4096, - ] - else: - batch_sizes = [args.batch_size] - num_gpus = 1 # If only one batch size, use only one GPU - - weight_shapes = get_weight_shapes(args.tp_size) + batch_sizes = [ + 64, + 128, + 256, + 512, + 2048, + 4096, + ] - batches_per_gpu = distribute_batch_sizes(batch_sizes, 1) + weight_shapes = get_weight_shapes() process_args = [] for gpu_id in range(1): process_args.append( { "gpu_id": gpu_id, - "batch_sizes": batches_per_gpu[gpu_id], + "batch_sizes": batch_sizes, "weight_shapes": weight_shapes, # Each GPU processes all weight shapes "args": args, } @@ -283,17 +259,12 @@ def main(args): formatter_class=argparse.RawTextHelpFormatter, ) - parser.add_argument("--tp-size", "-tp", type=int, default=1) - parser.add_argument( - "--input-type", type=str, choices=["bfloat16"], default="bfloat16" - ) parser.add_argument( - "--out-dtype", + "--dtype", type=str, choices=["float32", "float16", "bfloat16", "half"], default="bfloat16", ) - parser.add_argument("--batch-size", type=int, required=False) args = parser.parse_args() main(args) diff --git a/op_tests/triton_tests/test_gemm_a16w16_qwen3.py b/op_tests/triton_tests/test_gemm_a16w16_qwen3.py new file mode 100644 index 0000000000..9681e7f5fa --- /dev/null +++ b/op_tests/triton_tests/test_gemm_a16w16_qwen3.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +import torch +import torch.nn.functional as F +import pytest +from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 +from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic +from op_tests.triton_tests.utils.types import str_to_torch_dtype + + +def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True, bias=False): + if isinstance(dtype, str): + dtype = str_to_torch_dtype[dtype] + + # TN is default layout + if layout[0] == "T": + x = torch.randn((M, K), dtype=dtype, device="cuda") + else: + x = torch.randn((K, M), dtype=dtype, device="cuda").T + + if layout[1] == "T": + weight = torch.randn((K, N), dtype=dtype, device="cuda").T + else: + weight = torch.randn((N, K), dtype=dtype, device="cuda") + + bias_tensor = None + if bias: + bias_tensor = torch.empty((N), dtype=dtype, device="cuda") + + y = None + if output: + y = torch.empty((M, N), dtype=dtype, device="cuda") + out_dtype = (None,) + else: + out_dtype = dtype + + return x, weight, bias_tensor, out_dtype, y + + +def get_x_vals(): + x_vals = [] # minimal case + for batch in [64,128,256,512,2048,4096]: + x_vals += [ + (batch, 1024, 1024), + (batch, 4096, 1024), + (batch, 1024, 2048), + (batch, 6144, 1024), + (batch, 1024, 3072), + ] + return x_vals + + +@pytest.mark.parametrize("activation", ["gelu", "gelu_tanh", "silu", "silu_exp2"]) +@pytest.mark.parametrize("M, N, K", get_x_vals()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("output", [True, False]) +def test_gemm_a16_w16_activation(M: int, N: int, K: int, dtype, output, activation): + x, w, _, out_dtype, y = generate_gemm_a16w16_inputs( + M, + N, + K, + dtype, + output=output, + ) + + torch_out = F.linear(x, w, bias=None) + if activation == "gelu": + torch_out = F.gelu(torch_out) + elif activation == "gelu_tanh": + torch_out = F.gelu(torch_out, 
approximate="tanh") + elif activation == "silu": + torch_out = F.silu(torch_out) + elif activation == "silu_exp2": + torch_out = F.silu(torch_out) + + if output: + triton_out = gemm_a16w16( + x, + w, + None, + out_dtype, + y, + activation=activation, + ) + else: + triton_out = gemm_a16w16( + x, + w, + None, + out_dtype, + activation=activation, + ) + + torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-2) + + +@pytest.mark.parametrize("M, N, K", get_x_vals()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("output", [True, False]) +def test_gemm_a16_w16(M: int, N: int, K: int, dtype, output): + torch.cuda.empty_cache() # Helps avoid hangs in large tests + + x, w, bias, out_dtype, y = generate_gemm_a16w16_inputs( + M, N, K, dtype, output=output, bias=True + ) + + torch_out = F.linear(x, w, bias=bias) + + if output: + triton_out = gemm_a16w16(x, w, bias, out_dtype, y) + else: + triton_out = gemm_a16w16(x, w, bias, out_dtype) + + torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("M, N, K", get_x_vals()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("layout", ["TT", "NN", "NT"]) +@pytest.mark.parametrize("output", [True, False]) +def test_gemm_a16_w16_layout(M: int, N: int, K: int, dtype, layout, output): + torch.cuda.empty_cache() # Helps avoid hangs in large tests + + x, w, _, out_dtype, y = generate_gemm_a16w16_inputs( + M, N, K, dtype, layout=layout, output=output + ) + + torch_out = F.linear(x, w, bias=None) + + if output: + triton_out = gemm_a16w16(x, w, None, out_dtype, y) + else: + triton_out = gemm_a16w16(x, w, None, out_dtype) + + torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("M, N, K", get_x_vals()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("output", [True, False]) +def test_gemm_a16_w16_atomic(M: int, N: int, K: int, dtype, output): + torch.cuda.empty_cache() # Helps avoid hangs in large tests + + x, w, _, out_dtype, y = generate_gemm_a16w16_inputs(M, N, K, dtype, output=output) + + torch_out = F.linear(x, w, bias=None) + + # Accumulation in bf16/fp16 leads to precision loss, cast y to fp32 to prevent that + if output: + y = y.to(torch.float32).zero_() + triton_out = gemm_a16w16_atomic(x, w, torch.float32, y).to(dtype) + else: + triton_out = gemm_a16w16_atomic(x, w, dtype=torch.float32).to(dtype) + + torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("M, N, K", get_x_vals()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("layout", ["TT", "NN", "NT"]) +@pytest.mark.parametrize("output", [True, False]) +def test_gemm_a16_w16_atomic_layout(M: int, N: int, K: int, dtype, layout, output): + torch.cuda.empty_cache() # Helps avoid hangs in large tests + + x, w, _, out_dtype, y = generate_gemm_a16w16_inputs( + M, N, K, dtype, layout=layout, output=output + ) + + torch_out = F.linear(x, w, bias=None) + + # Accumulation in bf16/fp16 leads to precision loss, cast y to fp32 to prevent that + if output: + y = y.to(torch.float32).zero_() + triton_out = gemm_a16w16_atomic(x, w, torch.float32, y).to(dtype) + else: + triton_out = gemm_a16w16_atomic(x, w, dtype=torch.float32).to(dtype) + + torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) From 4c99eff4b8840aff2448907a50919a4ef3cf1f76 Mon Sep 17 00:00:00 2001 From: 
Amir Balwel Date: Thu, 4 Dec 2025 04:40:59 +0000 Subject: [PATCH 04/10] proper tuning and results Signed-off-by: Amir Balwel --- .../gemm/R9700-GEMM-A16W16-N=1024-K=1024.json | 78 ++++++++-------- .../gemm/R9700-GEMM-A16W16-N=1024-K=2048.json | 84 ++++++++--------- .../gemm/R9700-GEMM-A16W16-N=1024-K=3072.json | 82 ++++++++--------- .../gemm/R9700-GEMM-A16W16-N=4096-K=1024.json | 82 ++++++++--------- .../gemm/R9700-GEMM-A16W16-N=6144-K=1024.json | 70 +++++++-------- .../configs/gemm/R9700-GEMM-A16W16.json | 90 ++++++++----------- aiter/ops/triton/tune_a16w16.py | 70 +++++++-------- 7 files changed, 267 insertions(+), 289 deletions(-) diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=1024.json index 25c11e56cf..3f2af9911b 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=1024.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=1024.json @@ -1,86 +1,86 @@ { "M_LEQ_64": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_128": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_256": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 64, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_512": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_2048": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 64, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "any": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 } } diff --git 
a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=2048.json index 25c11e56cf..0a62925162 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=2048.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=2048.json @@ -1,86 +1,86 @@ { "M_LEQ_64": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 3, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 2048 }, "M_LEQ_128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 2048 }, "M_LEQ_256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 2048 }, "M_LEQ_512": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2, "waves_per_eu": 3, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 2048 }, "M_LEQ_2048": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 16, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 2048 }, "any": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 2048 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=3072.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=3072.json index 25c11e56cf..79e50ebc8d 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=3072.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=3072.json @@ -1,86 +1,86 @@ { "M_LEQ_64": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, + "waves_per_eu": 4, 
"matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 3072 }, "M_LEQ_128": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 3072 }, "M_LEQ_256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 3072 }, "M_LEQ_512": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 3072 }, "M_LEQ_2048": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 3072 }, "any": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 32, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 3, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 3072 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=4096-K=1024.json index 25c11e56cf..c8d365ec8e 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=4096-K=1024.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=4096-K=1024.json @@ -1,86 +1,86 @@ { "M_LEQ_64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3, "waves_per_eu": 3, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 256, + 
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 32, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_512": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_2048": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "any": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=6144-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=6144-K=1024.json index 25c11e56cf..3b7c3066ec 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=6144-K=1024.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=6144-K=1024.json @@ -1,86 +1,86 @@ { "M_LEQ_64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 3, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 32, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 3, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 32, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_512": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 16, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 3, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_2048": 
{ - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 32, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "any": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 3, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16.json index b6fbdca5bc..3f2af9911b 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16.json @@ -1,100 +1,86 @@ { "M_LEQ_64": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_128": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_256": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 64, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_512": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "M_LEQ_2048": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 64, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 }, "any": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": null, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 - }, - "default": { - "BLOCK_SIZE_M": 16, 
- "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, - "matrix_instr_nonkdim": 16, "cache_modifier": null, - "NUM_KSPLIT": 1, - "kpack": 1, - "SPLITK_BLOCK_SIZE": 1 + "SPLITK_BLOCK_SIZE": 1024 } } diff --git a/aiter/ops/triton/tune_a16w16.py b/aiter/ops/triton/tune_a16w16.py index d21e832edc..08bb254c62 100644 --- a/aiter/ops/triton/tune_a16w16.py +++ b/aiter/ops/triton/tune_a16w16.py @@ -1,6 +1,5 @@ import argparse import json -import multiprocessing as mp import os import time import triton @@ -15,8 +14,6 @@ from utils.core import AITER_TRITON_CONFIGS_PATH # type: ignore -mp.set_start_method("spawn", force=True) - def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True, bias=False): if isinstance(dtype, str): @@ -67,10 +64,10 @@ def get_configs_compute_bound(): "num_warps": num_warps, "num_stages": num_stages, "waves_per_eu": waves_per_eu, + "matrix_instr_nonkdim": 16, "NUM_KSPLIT": 1, "kpack": 1, - "SPLITK_BLOCK_SIZE": 1, # Why are those 2 needed for gfx1201 - "cache_modifier": None # but does not exist in other configs? + "cache_modifier": None, } ) return configs @@ -93,9 +90,13 @@ def get_weight_shapes(): def benchmark_config(x, w, bias, dtype, y, config, activation, num_iters=10): + torch_out = torch.nn.functional.linear(x, w, bias=bias) # Ground truth + def run(): - gemm_a16w16(x, w, bias=bias, dtype=dtype, y=y, config=config, activation=activation) - + return gemm_a16w16( + x, w, bias=bias, dtype=dtype, y=y, config=config, activation=None + ) + torch.cuda.synchronize() # JIT complication & warmup for _ in range(5): @@ -109,46 +110,42 @@ def run(): for i in range(num_iters): torch.cuda.synchronize() start_event.record() - run() + triton_out = run() end_event.record() end_event.synchronize() latencies.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) avg = sum(latencies) / (num_iters * 10) * 1000 # us return avg def tune(M, N, K, dtype, search_space): - ( - x, - w, - _, - _, - _, - ) = generate_gemm_a16w16_inputs( - M, - N, - K, - dtype, - output=False, + x, w, bias, out_dtype, y = generate_gemm_a16w16_inputs( + M, N, K, dtype, output=True, bias=True ) best_config = None best_time = float("inf") for config in tqdm(search_space): + config["SPLITK_BLOCK_SIZE"]= triton.cdiv(K, config["NUM_KSPLIT"]) try: kernel_time = benchmark_config( x=x, w=w, - bias=None, - dtype=dtype, - y=None, + bias=bias, + dtype=out_dtype, + y=y, config=config, activation=None, num_iters=10, ) except triton.runtime.autotuner.OutOfResources: + print("OutOfResources encountered during tuning.") # Some configurations may be invalid and fail to compile. 
continue + except AssertionError: + print("AssertionError encountered during tuning.") + continue if kernel_time < best_time: best_time = kernel_time @@ -221,7 +218,7 @@ def main(args): num_gpus = torch.cuda.device_count() if num_gpus == 0: raise RuntimeError("No GPU available for tuning") - print(f"Found {num_gpus} GPUs for parallel tuning") + print(f"Found {num_gpus} GPUs for tuning") torch.cuda.init() @@ -236,22 +233,17 @@ def main(args): weight_shapes = get_weight_shapes() - process_args = [] - for gpu_id in range(1): - process_args.append( - { - "gpu_id": gpu_id, - "batch_sizes": batch_sizes, - "weight_shapes": weight_shapes, # Each GPU processes all weight shapes - "args": args, - } - ) - - ctx = mp.get_context("spawn") - with ctx.Pool(1) as pool: - pool.map(tune_on_gpu, process_args) + # Run tuning sequentially on GPU 0 + tune_on_gpu( + { + "gpu_id": 0, + "batch_sizes": batch_sizes, + "weight_shapes": weight_shapes, + "args": args, + } + ) - print("Multi-GPU tuning completed") + print("Tuning completed") if __name__ == "__main__": From 8f4551eb16abb798304d7e35916ea25c3d8b2640 Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Thu, 4 Dec 2025 07:11:39 +0000 Subject: [PATCH 05/10] updating atomic kernel --- aiter/ops/triton/gemm_a16w16_atomic.py | 2 +- aiter/ops/triton/tune_a16w16_atomic.py | 216 ++++++++++++++++--------- 2 files changed, 142 insertions(+), 76 deletions(-) diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index 78026c80f0..7a6e41e924 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -18,7 +18,7 @@ def gemm_a16w16_atomic( x, w, - dtype: Optional[float] = torch.bfloat16, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[dict] = None, ): diff --git a/aiter/ops/triton/tune_a16w16_atomic.py b/aiter/ops/triton/tune_a16w16_atomic.py index 25c1a51f6f..746d7b2ae8 100644 --- a/aiter/ops/triton/tune_a16w16_atomic.py +++ b/aiter/ops/triton/tune_a16w16_atomic.py @@ -5,7 +5,7 @@ import time import triton from datetime import datetime - +from typing import List, Dict, Any, Union, Tuple, Optional import torch from tqdm import tqdm @@ -24,7 +24,48 @@ } -def get_configs_compute_bound(): +# def get_configs_compute_bound(): +# configs = [] +# for num_stages in [2]: +# for block_m in [ +# 16, +# ]: +# for block_k in [ +# 64, +# ]: +# for block_n in [ +# 32, +# ]: +# for num_warps in [ +# 4, +# ]: +# for group_size in [ +# 1, +# ]: +# for num_ksplit in [ +# 1, +# ]: # Atomic kernel specific parameter +# for waves_per_eu in [3]: +# configs.append( +# { +# "BLOCK_SIZE_M": block_m, +# "BLOCK_SIZE_N": block_n, +# "BLOCK_SIZE_K": block_k, +# "GROUP_SIZE_M": group_size, +# "num_warps": num_warps, +# "num_stages": num_stages, +# "waves_per_eu": waves_per_eu, # TODO check if compatible +# "matrix_instr_nonkdim": 16, # TODO +# "cache_modifier": "", # Empty string for atomic kernel +# "NUM_KSPLIT": num_ksplit, # Atomic kernel specific +# "kpack": 1, # TODO +# "SPLITK_BLOCK_SIZE": 1, # Will be set dynamically +# } +# ) +# return configs + + +def get_configs_compute_bound() -> List[Dict[str, int | str]]: configs = [] for num_stages in [2]: for block_m in [ @@ -46,6 +87,19 @@ def get_configs_compute_bound(): 1, ]: # Atomic kernel specific parameter for waves_per_eu in [3]: + # for num_stages in [2, 3, 4, 5]: + # for block_m in [16, 32, 64, 128, 256]: + # for block_k in [64, 128]: + # for block_n in [32, 64, 128, 256]: + # for num_warps in [4, 8]: + # for 
group_size in [1, 8, 16, 32, 64]: + # for num_ksplit in [ + # 1, + # 2, + # 4, + # 8, + # ]: # Atomic kernel specific parameter + # for waves_per_eu in [1, 2, 3, 4]: configs.append( { "BLOCK_SIZE_M": block_m, @@ -65,36 +119,7 @@ def get_configs_compute_bound(): return configs -# def get_configs_compute_bound(): -# configs = [] -# for num_stages in [2, 3, 4, 5]: -# for block_m in [16, 32, 64, 128, 256]: -# for block_k in [64, 128]: -# for block_n in [32, 64, 128, 256]: -# for num_warps in [4, 8]: -# for group_size in [1, 8, 16, 32, 64]: -# for num_ksplit in [1, 2, 4, 8]: # Atomic kernel specific parameter -# for waves_per_eu in [1,2,3,4]: -# configs.append( -# { -# "BLOCK_SIZE_M": block_m, -# "BLOCK_SIZE_N": block_n, -# "BLOCK_SIZE_K": block_k, -# "GROUP_SIZE_M": group_size, -# "num_warps": num_warps, -# "num_stages": num_stages, -# "waves_per_eu": waves_per_eu, # TODO check if compatible -# "matrix_instr_nonkdim": 16, # TODO -# "cache_modifier": None, # TODO -# "NUM_KSPLIT": num_ksplit, # Atomic kernel specific -# "kpack": 1, # TODO -# "SPLITK_BLOCK_SIZE": 1, # Will be set dynamically -# } -# ) -# return configs - - -def get_weight_shapes(tp_size): +def get_weight_shapes(tp_size: int) -> List[Tuple[int, int]]: total = [ (1024, 1024), (4096, 1024), @@ -103,14 +128,48 @@ def get_weight_shapes(tp_size): (1024, 3072), ] - weight_shapes = [] + weight_shapes: List[Tuple[int, int]] = [] for t in total: weight_shapes.append(t) return weight_shapes -def benchmark_config(x, w, dtype, y, config, num_iters=10): +def benchmark_config( + x: torch.Tensor, + w: torch.Tensor, + dtype: torch.dtype, + config: Dict[str, Union[str, int]], + y: Optional[torch.Tensor] = None, + num_iters=10, +) -> float: + """ + Benchmark the performance of a GEMM operation with a specific configuration. + + This function measures the execution time of the gemm_a16w16_atomic kernel by running + it multiple times with synchronization points to ensure accurate timing. It performs + warmup runs before the actual benchmarking to account for JIT compilation overhead. + + Args: + x (torch.Tensor): Input tensor of shape (M, K) representing the first matrix operand. + w (torch.Tensor): Weight tensor of shape (N, K) representing the second matrix operand. + dtype (torch.dtype): Data type for the computation (e.g., torch.bfloat16). + config (Dict[str, Union[str, int]]): Configuration dictionary containing kernel + parameters such as block sizes, number of warps, etc. + y (Optional[torch.Tensor], optional): Output tensor to store the result. If None, + a new tensor will be allocated. Defaults to None. + num_iters (int, optional): Number of benchmark iterations to run. Defaults to 10. + + Returns: + float: Average execution time in microseconds (us) per iteration. + + Note: + The function performs 5 warmup iterations before benchmarking to account for + JIT compilation and GPU warmup effects. The timing is measured using CUDA events + for accurate GPU kernel timing. 
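+        Unlike the tuner for the non-atomic kernel, this helper does not compare the
+        result against a torch reference; it only measures latency, so the numerical
+        correctness of a candidate config is not validated here.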
+ """ + + # run the kernel def run(): gemm_a16w16_atomic(x, w, dtype, y, config) @@ -135,20 +194,23 @@ def run(): return avg -def tune(M, N, K, out_dtype, search_space, input_type): +def tune( + M: int, N: int, K: int, search_space: List[Dict[str, int | str]], input_type: str +): if input_type == "bfloat16": - fp16_info = torch.finfo(torch.bfloat16) - fp16_max, fp16_min = fp16_info.max, fp16_info.min + bf16_info = torch.finfo(torch.bfloat16) + bf16_max, bf_16_max = bf16_info.max, bf16_info.min + # create random weights downcasted from torch.float32 to torch.bf16 x_fp32 = ( - (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp16_max + (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * bf16_max ) - x = x_fp32.clamp(min=fp16_min, max=fp16_max).to(torch.bfloat16) + x = x_fp32.clamp(min=bf_16_max, max=bf16_max).to(torch.bfloat16) w_fp32 = ( - (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp16_max + (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * bf16_max ) - w = w_fp32.clamp(min=fp16_min, max=fp16_max).to(torch.bfloat16) + w = w_fp32.clamp(min=bf_16_max, max=bf16_max).to(torch.bfloat16) else: raise RuntimeError("Currently, only support tune w16a16 block fp16 kernel.") @@ -164,7 +226,9 @@ def tune(M, N, K, out_dtype, search_space, input_type): config=config, num_iters=10, ) - except triton.runtime.autotuner.OutOfResources: + except triton.runtime.autotuner.OutOfResources as e: + print("config failed!", config) + print("error: ", e) # Some configurations may be invalid and fail to compile. continue @@ -195,27 +259,25 @@ def save_configs( f.write("\n") -def tune_on_gpu(args_dict): +def tune_on_gpu( + gpu_id: int, + batch_sizes: List[int], + weight_shapes: List[Tuple[int, int]], + input_type: str, +) -> None: """Run tuning on a specific GPU.""" - gpu_id = args_dict["gpu_id"] - batch_sizes = args_dict["batch_sizes"] - weight_shapes = args_dict["weight_shapes"] - args = args_dict["args"] - torch.cuda.set_device(gpu_id) print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") - out_dtype = DTYPE_MAP[args.out_dtype] save_path = AITER_TRITON_CONFIGS_PATH + "/gemm/" - input_type = args.input_type search_space = get_configs_compute_bound() start = time.time() - + # Collect all configs to determine best overall config - all_configs = [] - + all_configs: List[Dict[str, Dict[str, int | str]]] = [] + for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"): N, K = shape[0], shape[1] print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`") @@ -224,13 +286,12 @@ def tune_on_gpu(args_dict): batch_size, N, K, - out_dtype, search_space, input_type, ) for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") ] - best_configs = {} + best_configs: Dict[str, Dict[str, int | str]] = {} # Create configs for different M size categories as expected by the atomic kernel for i, (M, config) in enumerate(zip(batch_sizes, benchmark_results)): if i == len(batch_sizes) - 1: @@ -252,7 +313,7 @@ def tune_on_gpu(args_dict): # Store configs for later analysis all_configs.append(best_configs) save_configs(N, K, best_configs, save_path) - + # Create a default config file (without N,K parameters) by selecting most common config default_config = create_default_config(all_configs) save_default_config(default_config, save_path) @@ -261,10 +322,12 @@ def tune_on_gpu(args_dict): print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") -def create_default_config(all_configs): +def create_default_config( + 
all_configs: List[Dict[str, Dict[str, Union[int, str]]]], +) -> Dict[str, Dict[str, Union[int, str]]]: """Create a default config by selecting the most common config across all shapes.""" from collections import Counter - + # Collect all configs for each category category_configs = { "small": [], @@ -273,43 +336,45 @@ def create_default_config(all_configs): "medium_M128": [], "large": [], "xlarge": [], - "any": [] + "any": [], } - + for config in all_configs: for category, params in config.items(): if category in category_configs: # Convert config to a hashable tuple for counting config_tuple = tuple(sorted(params.items())) category_configs[category].append(config_tuple) - + # Find the most common config for each category - default_config = {} + default_config: Dict[str, Dict[str, Union[int, str]]] = {} for category, configs in category_configs.items(): if configs: most_common = Counter(configs).most_common(1)[0][0] default_config[category] = dict(most_common) - + return default_config -def save_default_config(config, save_path): +def save_default_config( + config: Dict[str, Dict[str, Union[int, str]]], save_path: str +) -> None: """Save the default config file (without N,K parameters).""" os.makedirs(save_path, exist_ok=True) device_name = "R9700" # TODO: Hardcoded, make it dynamic json_file_name = f"{device_name}-GEMM-A16W16-ATOMIC.json" - + config_file_path = os.path.join(save_path, json_file_name) print(f"Writing default config to {config_file_path}...") - + with open(config_file_path, "w") as f: json.dump(config, f, indent=4) f.write("\n") -def distribute_batch_sizes(batch_sizes, num_gpus): +def distribute_batch_sizes(batch_sizes: List[int], num_gpus: int) -> List[List[int]]: """Distribute batch sizes across available GPUs.""" - batches_per_gpu = [] + batches_per_gpu: List[List[int]] = [] for i in range(num_gpus): start_idx = i * len(batch_sizes) // num_gpus end_idx = (i + 1) * len(batch_sizes) // num_gpus @@ -345,20 +410,21 @@ def main(args): batches_per_gpu = distribute_batch_sizes(batch_sizes, 1) + # Prepare arguments for each GPU process process_args = [] for gpu_id in range(1): process_args.append( - { - "gpu_id": gpu_id, - "batch_sizes": batches_per_gpu[gpu_id], - "weight_shapes": weight_shapes, # Each GPU processes all weight shapes - "args": args, - } + ( + gpu_id, + batches_per_gpu[gpu_id], + weight_shapes, # Each GPU processes all weight shapes + args.input_type + ) ) ctx = mp.get_context("spawn") with ctx.Pool(1) as pool: - pool.map(tune_on_gpu, process_args) + pool.starmap(tune_on_gpu, process_args) print("Multi-GPU tuning completed") From 8c77318e9e8e2329f7835d635b3f41abf5ab910e Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Fri, 5 Dec 2025 14:41:57 +0000 Subject: [PATCH 06/10] added gemm_a8w8_blockscale tuning --- ...9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json | 79 ++- ...9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json | 64 +-- ...9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json | 36 +- ...9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json | 62 +-- ...9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json | 40 +- .../gemm/R9700-GEMM-A16W16-ATOMIC.json | 38 +- aiter/ops/triton/gemm_a8w8.py | 2 +- aiter/ops/triton/gemm_a8w8_blockscale.py | 2 +- aiter/ops/triton/tune_a16w16_atomic.py | 154 +++--- aiter/ops/triton/tune_a8w8_blockscale.py | 492 ++++++++++++++++++ 10 files changed, 709 insertions(+), 260 deletions(-) create mode 100644 aiter/ops/triton/tune_a8w8_blockscale.py diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json 
b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json index 3b3b9424eb..84fb9646ed 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json @@ -1,100 +1,93 @@ { "small": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "medium_M32": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "medium_M64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_warps": 8, + "num_stages": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "medium_M128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3, + "waves_per_eu": 8, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "large": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3, + "waves_per_eu": 8, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "xlarge": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 32, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "any": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, - "waves_per_eu": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json index 4bd20b9771..6807313ca6 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json @@ -1,12 +1,12 @@ { "small": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, 
"cache_modifier": "", "NUM_KSPLIT": 1, @@ -14,13 +14,13 @@ "SPLITK_BLOCK_SIZE": 2048 }, "medium_M32": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -28,13 +28,13 @@ "SPLITK_BLOCK_SIZE": 2048 }, "medium_M64": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -42,13 +42,13 @@ "SPLITK_BLOCK_SIZE": 2048 }, "medium_M128": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -56,13 +56,13 @@ "SPLITK_BLOCK_SIZE": 2048 }, "large": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -70,13 +70,13 @@ "SPLITK_BLOCK_SIZE": 2048 }, "xlarge": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -84,13 +84,13 @@ "SPLITK_BLOCK_SIZE": 2048 }, "any": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json index e9df57e6e1..f0d8836a5a 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json @@ -1,11 +1,11 @@ { "small": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, + "num_stages": 1, "waves_per_eu": 3, "matrix_instr_nonkdim": 16, "cache_modifier": "", @@ -14,12 +14,12 @@ "SPLITK_BLOCK_SIZE": 3072 }, "medium_M32": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, + "num_stages": 1, "waves_per_eu": 3, "matrix_instr_nonkdim": 16, "cache_modifier": "", @@ -28,13 +28,13 @@ "SPLITK_BLOCK_SIZE": 3072 }, "medium_M64": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -42,13 +42,13 @@ "SPLITK_BLOCK_SIZE": 3072 }, 
"medium_M128": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -56,12 +56,12 @@ "SPLITK_BLOCK_SIZE": 3072 }, "large": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, + "num_stages": 1, "waves_per_eu": 3, "matrix_instr_nonkdim": 16, "cache_modifier": "", @@ -70,13 +70,13 @@ "SPLITK_BLOCK_SIZE": 3072 }, "xlarge": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -84,13 +84,13 @@ "SPLITK_BLOCK_SIZE": 3072 }, "any": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json index 3b3b9424eb..973c7c2623 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json @@ -1,12 +1,12 @@ { "small": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -14,13 +14,13 @@ "SPLITK_BLOCK_SIZE": 1024 }, "medium_M32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 8, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -28,13 +28,13 @@ "SPLITK_BLOCK_SIZE": 1024 }, "medium_M64": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -42,13 +42,13 @@ "SPLITK_BLOCK_SIZE": 1024 }, "medium_M128": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -56,13 +56,13 @@ "SPLITK_BLOCK_SIZE": 1024 }, "large": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -70,13 +70,13 @@ "SPLITK_BLOCK_SIZE": 1024 }, "xlarge": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 
1, + "GROUP_SIZE_M": 8, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -84,13 +84,13 @@ "SPLITK_BLOCK_SIZE": 1024 }, "any": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json index 3b3b9424eb..d02f02664f 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json @@ -1,12 +1,12 @@ { "small": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -14,12 +14,12 @@ "SPLITK_BLOCK_SIZE": 1024 }, "medium_M32": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, + "num_stages": 1, "waves_per_eu": 3, "matrix_instr_nonkdim": 16, "cache_modifier": "", @@ -28,13 +28,13 @@ "SPLITK_BLOCK_SIZE": 1024 }, "medium_M64": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -42,13 +42,13 @@ "SPLITK_BLOCK_SIZE": 1024 }, "medium_M128": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -56,13 +56,13 @@ "SPLITK_BLOCK_SIZE": 1024 }, "large": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -70,13 +70,13 @@ "SPLITK_BLOCK_SIZE": 1024 }, "xlarge": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -84,13 +84,13 @@ "SPLITK_BLOCK_SIZE": 1024 }, "any": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, + "num_stages": 1, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json index 3bcb88ed00..920e68b04e 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json @@ -1,7 +1,7 @@ { "small": { "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 
1, "NUM_KSPLIT": 1, @@ -9,13 +9,13 @@ "cache_modifier": "", "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 2, + "num_stages": 1, "num_warps": 4, - "waves_per_eu": 3 + "waves_per_eu": 4 }, "medium_M32": { "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "NUM_KSPLIT": 1, @@ -23,13 +23,13 @@ "cache_modifier": "", "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 2, + "num_stages": 1, "num_warps": 4, "waves_per_eu": 3 }, "medium_M64": { "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "NUM_KSPLIT": 1, @@ -37,13 +37,13 @@ "cache_modifier": "", "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 2, + "num_stages": 1, "num_warps": 4, - "waves_per_eu": 3 + "waves_per_eu": 4 }, "medium_M128": { "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "NUM_KSPLIT": 1, @@ -51,13 +51,13 @@ "cache_modifier": "", "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 2, + "num_stages": 1, "num_warps": 4, - "waves_per_eu": 3 + "waves_per_eu": 2 }, "large": { "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "NUM_KSPLIT": 1, @@ -65,13 +65,13 @@ "cache_modifier": "", "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 2, + "num_stages": 1, "num_warps": 4, "waves_per_eu": 3 }, "xlarge": { "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "NUM_KSPLIT": 1, @@ -79,13 +79,13 @@ "cache_modifier": "", "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 2, + "num_stages": 1, "num_warps": 4, - "waves_per_eu": 3 + "waves_per_eu": 4 }, "any": { "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "NUM_KSPLIT": 1, @@ -93,8 +93,8 @@ "cache_modifier": "", "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 2, + "num_stages": 1, "num_warps": 4, - "waves_per_eu": 3 + "waves_per_eu": 2 } } diff --git a/aiter/ops/triton/gemm_a8w8.py b/aiter/ops/triton/gemm_a8w8.py index 3602ef2ff6..16fed7622e 100644 --- a/aiter/ops/triton/gemm_a8w8.py +++ b/aiter/ops/triton/gemm_a8w8.py @@ -22,7 +22,7 @@ def gemm_a8w8( x_scale: torch.Tensor, w_scale: torch.Tensor, bias: Optional[torch.Tensor] = None, - dtype: Optional[float] = torch.bfloat16, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[dict] = None, ): diff --git a/aiter/ops/triton/gemm_a8w8_blockscale.py b/aiter/ops/triton/gemm_a8w8_blockscale.py index ee2d072822..e6db089975 100644 --- a/aiter/ops/triton/gemm_a8w8_blockscale.py +++ b/aiter/ops/triton/gemm_a8w8_blockscale.py @@ -22,7 +22,7 @@ def gemm_a8w8_blockscale( w: torch.Tensor, x_scale: torch.Tensor, w_scale: torch.Tensor, - dtype: Optional[float] = torch.bfloat16, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[dict] = None, skip_reduce: Optional[bool] = False, diff --git a/aiter/ops/triton/tune_a16w16_atomic.py b/aiter/ops/triton/tune_a16w16_atomic.py index 746d7b2ae8..44a73c4494 100644 --- a/aiter/ops/triton/tune_a16w16_atomic.py +++ b/aiter/ops/triton/tune_a16w16_atomic.py @@ -5,7 +5,7 @@ import time import triton from datetime import datetime -from typing import List, Dict, Any, Union, Tuple, Optional +from typing import List, Dict, Union, Tuple, Optional import torch from tqdm import tqdm @@ -24,82 +24,48 @@ } -# def get_configs_compute_bound(): -# configs = [] -# for num_stages in [2]: 
-# for block_m in [ -# 16, -# ]: -# for block_k in [ -# 64, -# ]: -# for block_n in [ -# 32, -# ]: -# for num_warps in [ -# 4, -# ]: -# for group_size in [ -# 1, -# ]: -# for num_ksplit in [ -# 1, -# ]: # Atomic kernel specific parameter -# for waves_per_eu in [3]: -# configs.append( -# { -# "BLOCK_SIZE_M": block_m, -# "BLOCK_SIZE_N": block_n, -# "BLOCK_SIZE_K": block_k, -# "GROUP_SIZE_M": group_size, -# "num_warps": num_warps, -# "num_stages": num_stages, -# "waves_per_eu": waves_per_eu, # TODO check if compatible -# "matrix_instr_nonkdim": 16, # TODO -# "cache_modifier": "", # Empty string for atomic kernel -# "NUM_KSPLIT": num_ksplit, # Atomic kernel specific -# "kpack": 1, # TODO -# "SPLITK_BLOCK_SIZE": 1, # Will be set dynamically -# } -# ) -# return configs +def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True, bias=False): + if isinstance(dtype, str): + dtype = DTYPE_MAP[dtype] + + # TN is default layout + if layout[0] == "T": + x = torch.randn((M, K), dtype=dtype, device="cuda") + else: + x = torch.randn((K, M), dtype=dtype, device="cuda").T + + if layout[1] == "T": + weight = torch.randn((K, N), dtype=dtype, device="cuda").T + else: + weight = torch.randn((N, K), dtype=dtype, device="cuda") + + bias_tensor = None + if bias: + bias_tensor = torch.empty((N), dtype=dtype, device="cuda") + + y = None + if output: + y = torch.empty((M, N), dtype=dtype, device="cuda") + out_dtype = (None,) + else: + out_dtype = dtype + + return x, weight, bias_tensor, out_dtype, y def get_configs_compute_bound() -> List[Dict[str, int | str]]: configs = [] - for num_stages in [2]: - for block_m in [ - 16, - ]: - for block_k in [ - 64, - ]: - for block_n in [ - 32, - ]: - for num_warps in [ - 4, - ]: - for group_size in [ - 1, - ]: - for num_ksplit in [ - 1, - ]: # Atomic kernel specific parameter - for waves_per_eu in [3]: - # for num_stages in [2, 3, 4, 5]: - # for block_m in [16, 32, 64, 128, 256]: - # for block_k in [64, 128]: - # for block_n in [32, 64, 128, 256]: - # for num_warps in [4, 8]: - # for group_size in [1, 8, 16, 32, 64]: - # for num_ksplit in [ - # 1, - # 2, - # 4, - # 8, - # ]: # Atomic kernel specific parameter - # for waves_per_eu in [1, 2, 3, 4]: + # Expanded search space for the atomic A16W16 kernel; only parameters the + # kernel actually uses are swept. matrix_instr_nonkdim and cache_modifier + # stay fixed, and SPLITK_BLOCK_SIZE is set dynamically. + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128]: + for block_k in [64, 128]: + for block_n in [32, 64, 128, 256]: + for group_size in [1, 8, 16, 32, 64]: + for num_warps in [4, 8]: + for num_ksplit in [1, 2, 4, 8]: # Atomic kernel specific parameter + for waves_per_eu in [2, 4, 8]: configs.append( { "BLOCK_SIZE_M": block_m, @@ -108,12 +74,12 @@ def get_configs_compute_bound() -> List[Dict[str, int | str]]: "GROUP_SIZE_M": group_size, "num_warps": num_warps, "num_stages": num_stages, - "waves_per_eu": waves_per_eu, # TODO check if compatible - "matrix_instr_nonkdim": 16, # TODO + "waves_per_eu": waves_per_eu, + "matrix_instr_nonkdim": 16, # Fixed value used in kernel "cache_modifier": "", # Empty string for atomic kernel "NUM_KSPLIT": num_ksplit, # Atomic kernel specific - "kpack": 1, # TODO - "SPLITK_BLOCK_SIZE": 1, # Will be set dynamically + # "kpack": 1, # Fixed value used in kernel + # "SPLITK_BLOCK_SIZE": 1, # Will be set dynamically } ) return configs @@ -168,10
+134,11 @@ def benchmark_config( JIT compilation and GPU warmup effects. The timing is measured using CUDA events for accurate GPU kernel timing. """ + torch_out = torch.nn.functional.linear(x, w, bias=None) # run the kernel def run(): - gemm_a16w16_atomic(x, w, dtype, y, config) + return gemm_a16w16_atomic(x, w, dtype, y, config) torch.cuda.synchronize() # JIT complication & warmup @@ -186,10 +153,13 @@ def run(): for i in range(num_iters): torch.cuda.synchronize() start_event.record() - run() + triton_out_raw = run() + # Convert to the same dtype as the reference for comparison + triton_out = triton_out_raw.to(torch_out.dtype) end_event.record() end_event.synchronize() latencies.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) avg = sum(latencies) / (num_iters * 10) * 1000 # us return avg @@ -198,19 +168,10 @@ def tune( M: int, N: int, K: int, search_space: List[Dict[str, int | str]], input_type: str ): if input_type == "bfloat16": - bf16_info = torch.finfo(torch.bfloat16) - bf16_max, bf_16_max = bf16_info.max, bf16_info.min - - # create random weights downcasted from torch.float32 to torch.bf16 - x_fp32 = ( - (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * bf16_max + # Use the same input generation as test file + x, w, _, out_dtype, y = generate_gemm_a16w16_inputs( + M, N, K, torch.bfloat16, output=True ) - x = x_fp32.clamp(min=bf_16_max, max=bf16_max).to(torch.bfloat16) - - w_fp32 = ( - (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * bf16_max - ) - w = w_fp32.clamp(min=bf_16_max, max=bf16_max).to(torch.bfloat16) else: raise RuntimeError("Currently, only support tune w16a16 block fp16 kernel.") @@ -221,16 +182,19 @@ def tune( kernel_time = benchmark_config( x=x, w=w, - dtype=torch.bfloat16, + dtype=torch.float32, y=None, config=config, num_iters=10, ) except triton.runtime.autotuner.OutOfResources as e: - print("config failed!", config) - print("error: ", e) + # print("config failed!", config) + # print("error: ", e) # Some configurations may be invalid and fail to compile. continue + except AssertionError as e: + print("Assert error:", e) + continue if kernel_time < best_time: best_time = kernel_time @@ -418,7 +382,7 @@ def main(args): gpu_id, batches_per_gpu[gpu_id], weight_shapes, # Each GPU processes all weight shapes - args.input_type + args.input_type, ) ) diff --git a/aiter/ops/triton/tune_a8w8_blockscale.py b/aiter/ops/triton/tune_a8w8_blockscale.py new file mode 100644 index 0000000000..47c26c0d50 --- /dev/null +++ b/aiter/ops/triton/tune_a8w8_blockscale.py @@ -0,0 +1,492 @@ +import argparse +import json +import multiprocessing as mp +import os +import time +import triton +from datetime import datetime +from typing import List, Dict, Union, Tuple, Optional +import torch +from tqdm import tqdm + + +from gemm_a8w8_blockscale import gemm_a8w8_blockscale # type: ignore +from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH # type: ignore +from aiter.ops.triton.utils.types import get_fp8_dtypes + +mp.set_start_method("spawn", force=True) + +# Get FP8 data types +e5m2_type, e4m3_type = get_fp8_dtypes() + + +def generate_gemm_a8w8_blockscale_inputs(M, N, K, dtype, block_shape=(128, 128), layout="TN", output=True, bias=False): + """ + Generate inputs for gemm_a8w8_blockscale kernel. 
+ + Args: + M, N, K: Matrix dimensions + dtype: Output data type + block_shape: Tuple of (block_shape_n, block_shape_k) for block scaling + layout: Input matrix layout + output: Whether to generate output tensor + bias: Whether to generate bias tensor + + Returns: + Tuple of (x, w, x_scale, w_scale, bias, y) + """ + block_shape_n, block_shape_k = block_shape + scale_n = (N + block_shape_n - 1) // block_shape_n + scale_k = (K + block_shape_k - 1) // block_shape_k + + # Generate input matrix x (M, K) + if layout[0] == "T": + x = (torch.rand((M, K), dtype=torch.float16, device="cuda") / 10).to(e4m3_type) + else: + x = ((torch.rand((K, M), dtype=torch.float16, device="cuda") / 10).to(e4m3_type)).T + + # Generate weight matrix w (N, K) + if layout[1] == "N": + w = (torch.rand((N, K), dtype=torch.float16, device="cuda") / 10).to(e4m3_type) + else: + w = ((torch.rand((K, N), dtype=torch.float16, device="cuda") / 10).to(e4m3_type)).T + + # Generate scale tensors + x_scale = torch.rand([M, scale_k], dtype=torch.float32, device="cuda") + w_scale = torch.rand([scale_n, scale_k], dtype=torch.float32, device="cuda") + + # Generate bias tensor if needed + bias_tensor = None + if bias: + bias_tensor = torch.empty((N), dtype=dtype, device="cuda") + + # Generate output tensor if needed + y = None + if output: + y = torch.empty((M, N), dtype=dtype, device="cuda") + + return x, w, x_scale, w_scale, bias_tensor, y + + +def get_configs_compute_bound() -> List[Dict[str, int | str]]: + """ + Generate configuration space for tuning the gemm_a8w8_blockscale kernel. + Based on the sample config file, we'll tune around those values. + Note: GROUP_K must equal BLOCK_SIZE_K as required by the kernel. + """ + configs = [] + # Based on the sample config from MI300X-GEMM-A8W8_BLOCKSCALE.json + # We'll explore a reasonable range around these values + for num_stages in [1, 2]: # Sample config uses 2 + for block_m in [64, 128, 256]: # Sample config uses 128 + for block_k in [64, 128, 256]: # Sample config uses 128 + for block_n in [64, 128, 256]: # Sample config uses 128 + for group_size in [1, 8]: # Sample config uses 1 + for num_warps in [4, 8]: # Sample config uses 4 + for num_ksplit in [1, 2, 4]: # Sample config uses 1 + for waves_per_eu in [1, 2, 4]: # Sample config uses 2 + for kpack in [1,2 ]: # Sample config uses 2 + for cache_modifier in ["", ".cg"]: # Sample config uses ".cg" + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, + "matrix_instr_nonkdim": 16, # Fixed value used in kernel + "cache_modifier": cache_modifier, + "NUM_KSPLIT": num_ksplit, + "kpack": kpack, + # "SPLITK_BLOCK_SIZE": 1, # Will be set dynamically + } + ) + return configs + + +def get_weight_shapes(tp_size: int) -> List[Tuple[int, int]]: + """Get weight shapes to test during tuning.""" + total = [ + (1024, 1024), + (4096, 1024), + (1024, 2048), + (6144, 1024), + (1024, 3072), + ] + + weight_shapes: List[Tuple[int, int]] = [] + for t in total: + weight_shapes.append(t) + + return weight_shapes + + +def run_torch_reference(x, w, x_scale, w_scale, bias, dtype=torch.bfloat16, block_shape=(128, 128)): + """ + Run reference implementation using PyTorch. + This is used for correctness verification. + Based on the test file implementation. 
+ """ + block_shape_n, block_shape_k = block_shape + m, k = x.shape + n = w.shape[0] + scale_n = (n + block_shape_n - 1) // block_shape_n + scale_k = (k + block_shape_k - 1) // block_shape_k + + # Expand scales to match the full matrix dimensions + x_scale_expanded = x_scale.repeat_interleave(block_shape_k, dim=1) + x_scaled = x.to(x_scale_expanded.dtype) * x_scale_expanded[:m, :k] + x_scaled = x_scaled.view(m, k) + + w_scale_expanded = w_scale.repeat_interleave(block_shape_n, dim=0) + w_scale_expanded = w_scale_expanded.repeat_interleave(block_shape_k, dim=1) + w_scale_expanded = w_scale_expanded[:n, :k] + w_scaled = w.to(w_scale_expanded.dtype) * w_scale_expanded + + # Compute the matrix multiplication with bias if provided + # Convert bias to float32 if it's not None to match the other tensors + bias_float32 = bias.to(torch.float32) if bias is not None else None + out = torch.nn.functional.linear(x_scaled.to(torch.float32), w_scaled.to(torch.float32), bias=bias_float32) + + return out.to(dtype) + + +def benchmark_config( + x: torch.Tensor, + w: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + bias: Optional[torch.Tensor], + dtype: torch.dtype, + config: Dict[str, Union[str, int]], + y: Optional[torch.Tensor] = None, + num_iters=10, +) -> float: + """ + Benchmark the performance of a GEMM operation with a specific configuration. + + This function measures the execution time of the gemm_a8w8_blockscale kernel by running + it multiple times with synchronization points to ensure accurate timing. It performs + warmup runs before the actual benchmarking to account for JIT compilation overhead. + + Args: + x (torch.Tensor): Input tensor of shape (M, K) representing the first matrix operand. + w (torch.Tensor): Weight tensor of shape (N, K) representing the second matrix operand. + x_scale (torch.Tensor): Scale tensor for x with shape (M, scale_k). + w_scale (torch.Tensor): Scale tensor for w with shape (scale_n, scale_k). + dtype (torch.dtype): Data type for the computation (e.g., torch.bfloat16). + config (Dict[str, Union[str, int]]): Configuration dictionary containing kernel + parameters such as block sizes, number of warps, etc. + y (Optional[torch.Tensor], optional): Output tensor to store the result. If None, + a new tensor will be allocated. Defaults to None. + num_iters (int, optional): Number of benchmark iterations to run. Defaults to 10. + + Returns: + float: Average execution time in microseconds (us) per iteration. + + Note: + The function performs 5 warmup iterations before benchmarking to account for + JIT compilation and GPU warmup effects. The timing is measured using CUDA events + for accurate GPU kernel timing. 
+ """ + # Calculate GROUP_K based on the input dimensions and w_scale shape + # This must match BLOCK_SIZE_K to satisfy the kernel assertion + M, K = x.shape + N, _ = w.shape + w_scale_T = w_scale.T # Transpose to match kernel's expectation + group_k = triton.next_power_of_2(triton.cdiv(K, w_scale_T.shape[0])) + + # Create a copy of config to modify + modified_config = config.copy() + # Set BLOCK_SIZE_K to match GROUP_K to satisfy kernel assertion + modified_config["BLOCK_SIZE_K"] = group_k + + # Get reference output for correctness verification + torch_out = run_torch_reference(x, w, x_scale, w_scale, bias, dtype) + + # Run kernel + def run(): + # Pass the modified config to the kernel + return gemm_a8w8_blockscale(x, w, x_scale, w_scale, dtype, y, modified_config, skip_reduce=False) + + torch.cuda.synchronize() + # JIT compilation & warmup + for _ in range(5): + run() + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + triton_out_raw = run() + # Convert to the same dtype as the reference for comparison + # Handle the case where triton_out_raw might be None + if triton_out_raw is not None: + triton_out = triton_out_raw.to(torch_out.dtype) + else: + triton_out = torch_out # Fallback to reference output + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg + + +def tune( + M: int, N: int, K: int, search_space: List[Dict[str, int | str]], input_type: str +): + """Tune the kernel for specific matrix dimensions.""" + if input_type == "bfloat16": + # Use the same input generation as test file + x, w, x_scale, w_scale, bias, y = generate_gemm_a8w8_blockscale_inputs( + M, N, K, torch.bfloat16, bias=True + ) + else: + raise RuntimeError("Currently, only support tune a8w8 blockscale kernel with bfloat16 output.") + + best_config = None + best_time = float("inf") + for config in tqdm(search_space): + try: + kernel_time = benchmark_config( + x=x, + w=w, + x_scale=x_scale, + w_scale=w_scale, + bias=bias, + dtype=torch.bfloat16, + y=None, + config=config, + num_iters=10, + ) + except triton.runtime.autotuner.OutOfResources as e: + # Some configurations may be invalid and fail to compile. 
+ continue + except AssertionError as e: + print("Assert error:", e) + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={M}") + assert best_config is not None + return best_config + + +def save_configs( + N, + K, + configs, + save_path, +) -> None: + """Save the best configurations to a JSON file.""" + os.makedirs(save_path, exist_ok=True) + device_name = "MI300X" # TODO: Hardcoded, make it dynamic + json_file_name = f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}.json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing best config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def tune_on_gpu( + gpu_id: int, + batch_sizes: List[int], + weight_shapes: List[Tuple[int, int]], + input_type: str, +) -> None: + """Run tuning on a specific GPU.""" + torch.cuda.set_device(gpu_id) + print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + + save_path = AITER_TRITON_CONFIGS_PATH + "/gemm/" + + search_space = get_configs_compute_bound() + + start = time.time() + + # Collect all configs to determine the best overall config + all_configs: List[Dict[str, Dict[str, int | str]]] = [] + + for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"): + N, K = shape[0], shape[1] + print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`") + benchmark_results = [ + tune( + batch_size, + N, + K, + search_space, + input_type, + ) + for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") + ] + best_configs: Dict[str, Dict[str, int | str]] = {} + # Create configs for different M size categories as expected by the kernel + for i, (M, config) in enumerate(zip(batch_sizes, benchmark_results)): + if i == len(batch_sizes) - 1: + best_configs["any"] = config + elif M < 32: + best_configs["small"] = config + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32: + best_configs["medium_M32"] = config + elif BLK_M == 64: + best_configs["medium_M64"] = config + elif BLK_M == 128: + best_configs["medium_M128"] = config + elif M <= 256: + best_configs["large"] = config + else: + best_configs["xlarge"] = config + # Store configs for later analysis + all_configs.append(best_configs) + save_configs(N, K, best_configs, save_path) + + # Create a default config file (without N,K parameters) by selecting the most common config + default_config = create_default_config(all_configs) + save_default_config(default_config, save_path) + + end = time.time() + print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") + + +def create_default_config( + all_configs: List[Dict[str, Dict[str, Union[int, str]]]], +) -> Dict[str, Dict[str, Union[int, str]]]: + """Create a default config by selecting the most common config across all shapes.""" + from collections import Counter + + # Collect all configs for each category + category_configs = { + "small": [], + "medium_M32": [], + "medium_M64": [], + "medium_M128": [], + "large": [], + "xlarge": [], + "any": [], + } + + for config in all_configs: + for category, params in config.items(): + if category in category_configs: + # Convert config to a hashable tuple for counting + config_tuple = tuple(sorted(params.items())) + category_configs[category].append(config_tuple) + + # Find the most common config for each category + default_config: Dict[str, Dict[str, Union[int, str]]] = {} + for category, configs in 
category_configs.items(): + if configs: + most_common = Counter(configs).most_common(1)[0][0] + default_config[category] = dict(most_common) + + return default_config + + +def save_default_config( + config: Dict[str, Dict[str, Union[int, str]]], save_path: str +) -> None: + """Save the default config file (without N,K parameters).""" + os.makedirs(save_path, exist_ok=True) + device_name = "MI300X" # TODO: Hardcoded, make it dynamic + json_file_name = f"{device_name}-GEMM-A8W8_BLOCKSCALE.json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing default config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(config, f, indent=4) + f.write("\n") + + +def distribute_batch_sizes(batch_sizes: List[int], num_gpus: int) -> List[List[int]]: + """Distribute batch sizes across available GPUs.""" + batches_per_gpu: List[List[int]] = [] + for i in range(num_gpus): + start_idx = i * len(batch_sizes) // num_gpus + end_idx = (i + 1) * len(batch_sizes) // num_gpus + batches_per_gpu.append(batch_sizes[start_idx:end_idx]) + return batches_per_gpu + + +def main(args): + print(args) + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + raise RuntimeError("No GPU available for tuning") + print(f"Found {num_gpus} GPUs for parallel tuning") + + torch.cuda.init() + + if args.batch_size is None: + batch_sizes = [ + 16, # For small config + 32, # For medium_M32 config + 64, # For medium_M64 config + 128, # For medium_M128 config + 256, # For large config + 512, # For large config + 2048, # For xlarge config + 4096, # For xlarge config + ] + else: + batch_sizes = [args.batch_size] + num_gpus = 1 # If only one batch size, use only one GPU + + weight_shapes = get_weight_shapes(args.tp_size) + + batches_per_gpu = distribute_batch_sizes(batch_sizes, 1) + + # Prepare arguments for each GPU process + process_args = [] + for gpu_id in range(1): + process_args.append( + ( + gpu_id, + batches_per_gpu[gpu_id], + weight_shapes, # Each GPU processes all weight shapes + args.input_type, + ) + ) + + ctx = mp.get_context("spawn") + with ctx.Pool(1) as pool: + pool.starmap(tune_on_gpu, process_args) + + print("Multi-GPU tuning completed") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument("--tp-size", "-tp", type=int, default=1) + parser.add_argument( + "--input-type", type=str, choices=["bfloat16"], default="bfloat16" + ) + parser.add_argument( + "--out-dtype", + type=str, + choices=["float32", "float16", "bfloat16", "half"], + default="bfloat16", + ) + parser.add_argument("--batch-size", type=int, required=False) + args = parser.parse_args() + + main(args) From 1574097ddd84e2ec82c9323ecaa694dd55adc3bd Mon Sep 17 00:00:00 2001 From: Amir Balwel Date: Mon, 8 Dec 2025 07:46:16 +0000 Subject: [PATCH 07/10] fix: shared memory issue on R9700 for certain sizes in unified attention kernel Signed-off-by: Amir Balwel --- aiter/ops/triton/unified_attention.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/aiter/ops/triton/unified_attention.py b/aiter/ops/triton/unified_attention.py index b2231ee563..78a7879db6 100644 --- a/aiter/ops/triton/unified_attention.py +++ b/aiter/ops/triton/unified_attention.py @@ -1,5 +1,6 @@ # The kernels in this file are adapted from vLLM: # https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_unified_attention.py +from aiter.ops.triton.utils._triton import arch_info import triton import torch from 
aiter.ops.triton.utils.device_info import get_num_sms @@ -27,8 +28,12 @@ def select_2d_config( TILE_SIZE = 64 # in case head_size is large max_num_stages_2d = 4 + dev = arch_info.get_device() if head_size > 128: - max_num_stages_2d = 2 + if block_size >=64 and dev == "R9700": + max_num_stages_2d = 1 + else: + max_num_stages_2d = 2 if all_decode == False: num_stages_2d = 1 num_warps = 2 From 42137a103c2f3e7d9d6b832932f92c3fd355ea31 Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Tue, 9 Dec 2025 04:19:08 +0000 Subject: [PATCH 08/10] testing a8w8 tuning --- ...9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json | 44 +- ...9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json | 67 +- ...9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json | 81 +- ...9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json | 59 +- ...9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json | 81 +- .../gemm/R9700-GEMM-A16W16-ATOMIC.json | 85 +- ...00-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json | 15 + ...00-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048.json | 15 + ...00-GEMM-A8W8_BLOCKSCALE-N=1024-K=3072.json | 15 + ...00-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json | 15 + ...00-GEMM-A8W8_BLOCKSCALE-N=6144-K=1024.json | 15 + .../gemm/R9700-GEMM-A8W8_BLOCKSCALE.json | 15 + ...MM-A8W8_PER_TOKEN_SCALE-N=1024-K=1024.json | 100 ++ ...MM-A8W8_PER_TOKEN_SCALE-N=1024-K=2048.json | 100 ++ ...MM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json | 100 ++ ...ER_TOKEN_SCALE-N=1024-K=3072_batch_16.json | 16 + ...TOKEN_SCALE-N=1024-K=3072_batch_16_32.json | 30 + ...EN_SCALE-N=1024-K=3072_batch_16_32_64.json | 44 + ...CALE-N=1024-K=3072_batch_16_32_64_128.json | 58 + ...-N=1024-K=3072_batch_16_32_64_128_256.json | 72 + ...024-K=3072_batch_16_32_64_128_256_512.json | 86 + ...=3072_batch_16_32_64_128_256_512_2048.json | 100 ++ ..._batch_16_32_64_128_256_512_2048_4096.json | 100 ++ ...MM-A8W8_PER_TOKEN_SCALE-N=4096-K=1024.json | 100 ++ ...MM-A8W8_PER_TOKEN_SCALE-N=6144-K=1024.json | 100 ++ .../gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json | 100 ++ aiter/ops/triton/gemm_a8w8_per_token_scale.py | 2 +- aiter/ops/triton/tune_a16w16_atomic.py | 32 +- aiter/ops/triton/tune_a8w8_blockscale.py | 853 +++++++-- aiter/ops/triton/tune_a8w8_per_token_scale.py | 1520 +++++++++++++++++ 30 files changed, 3636 insertions(+), 384 deletions(-) create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=3072.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=6144-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=2048.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128.json create 
mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048_4096.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=4096-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=6144-K=1024.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json create mode 100644 aiter/ops/triton/tune_a8w8_per_token_scale.py diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json index 84fb9646ed..b5c279bfb2 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json @@ -1,11 +1,11 @@ { "small": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4, + "num_warps": 4, + "num_stages": 2, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", @@ -13,13 +13,13 @@ "SPLITK_BLOCK_SIZE": 1024 }, "medium_M32": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 2, + "num_stages": 3, + "waves_per_eu": 8, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -27,12 +27,12 @@ }, "medium_M64": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3, - "waves_per_eu": 2, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -40,25 +40,25 @@ }, "medium_M128": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 8, "num_warps": 8, - "num_stages": 3, - "waves_per_eu": 8, + "num_stages": 2, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 1024 }, "large": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, "num_warps": 8, - "num_stages": 3, - "waves_per_eu": 8, + "num_stages": 2, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -68,10 +68,10 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 2, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, @@ -81,10 +81,10 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 2, - "waves_per_eu": 2, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json index 
6807313ca6..3ff0922c5e 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json @@ -2,99 +2,92 @@ "small": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 1, + "num_stages": 5, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 2048 }, "medium_M32": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 4, + "num_stages": 3, + "waves_per_eu": 8, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 2048 }, "medium_M64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 4, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 2048 }, "medium_M128": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 4, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 2048 }, "large": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 4, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 2048 }, "xlarge": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 1, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 2048 }, "any": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 1, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 2048 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json index f0d8836a5a..9c626d0fac 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json @@ -1,100 +1,93 @@ { "small": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 3, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 3072 }, "medium_M32": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + 
"BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 3, + "num_stages": 2, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 3072 }, "medium_M64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 1, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 8, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 3072 }, "medium_M128": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 1, + "num_stages": 2, "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 3072 }, "large": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 3, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 3072 }, "xlarge": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 1, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 3072 }, "any": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 4, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 3072 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json index 973c7c2623..3ad52f6053 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json @@ -1,86 +1,80 @@ { "small": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 1, + "num_stages": 2, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "medium_M32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 4, + "num_stages": 3, + "waves_per_eu": 8, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "medium_M64": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 1, + "num_stages": 4, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 
1024 }, "medium_M128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 1, + "num_warps": 8, + "num_stages": 2, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "large": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 1, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "xlarge": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, - "num_warps": 4, - "num_stages": 1, + "num_warps": 8, + "num_stages": 2, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "any": { @@ -88,13 +82,12 @@ "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json index d02f02664f..b0605c7435 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json @@ -1,100 +1,93 @@ { "small": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 4, + "num_stages": 3, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "medium_M32": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 3, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "medium_M64": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 4, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 8, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "medium_M128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "large": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 1, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, 
"matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "xlarge": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 }, "any": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, + "num_warps": 8, + "num_stages": 2, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "cache_modifier": "", "NUM_KSPLIT": 1, - "kpack": 1, "SPLITK_BLOCK_SIZE": 1024 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json index 920e68b04e..7d1c3825bc 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json @@ -1,100 +1,93 @@ { "small": { - "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 1, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, "cache_modifier": "", - "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 1, + "num_stages": 3, "num_warps": 4, - "waves_per_eu": 4 + "waves_per_eu": 2 }, "medium_M32": { - "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, "cache_modifier": "", - "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 4, - "waves_per_eu": 3 + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 4 }, "medium_M64": { - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, "cache_modifier": "", - "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 4, - "waves_per_eu": 4 + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 8 }, "medium_M128": { "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, "cache_modifier": "", - "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 4, - "waves_per_eu": 2 + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 4 }, "large": { "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, "cache_modifier": "", - "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 4, - "waves_per_eu": 3 + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 2 }, "xlarge": { "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, "cache_modifier": "", - "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 4, + "num_stages": 2, + "num_warps": 8, "waves_per_eu": 4 }, "any": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 
1, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, "cache_modifier": "", - "kpack": 1, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 4, - "waves_per_eu": 2 + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 4 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json new file mode 100644 index 0000000000..5b8b362d82 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json @@ -0,0 +1,15 @@ +{ + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 2, + "cache_modifier": ".cg" + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048.json new file mode 100644 index 0000000000..fc17b9ca81 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048.json @@ -0,0 +1,15 @@ +{ + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 2, + "cache_modifier": ".cg" + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=3072.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=3072.json new file mode 100644 index 0000000000..5b8b362d82 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=3072.json @@ -0,0 +1,15 @@ +{ + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 2, + "cache_modifier": ".cg" + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json new file mode 100644 index 0000000000..a03fd4ef76 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json @@ -0,0 +1,15 @@ +{ + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 2, + "cache_modifier": ".cg" + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=6144-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=6144-K=1024.json new file mode 100644 index 0000000000..8281184361 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=6144-K=1024.json @@ -0,0 +1,15 @@ +{ + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 2, + "cache_modifier": ".cg" + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE.json new file mode 100644 index 0000000000..de5a5d4300 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE.json @@ -0,0 +1,15 @@ +{ + "any": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + 
"NUM_KSPLIT": 1, + "cache_modifier": ".cg", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 2 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=1024.json new file mode 100644 index 0000000000..e02c3d9496 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=1024.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=2048.json new file mode 100644 index 0000000000..bf4b466417 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=2048.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + 
"waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 2048 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json new file mode 100644 index 0000000000..e808bfd22f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "large": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "xlarge": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + 
"SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16.json new file mode 100644 index 0000000000..f39c54a983 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16.json @@ -0,0 +1,16 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32.json new file mode 100644 index 0000000000..835e37a7f7 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32.json @@ -0,0 +1,30 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64.json new file mode 100644 index 0000000000..7f8b5719d1 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64.json @@ -0,0 +1,44 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128.json new file mode 100644 index 0000000000..6495d74112 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128.json @@ -0,0 +1,58 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 
3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256.json new file mode 100644 index 0000000000..23396fa7b5 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256.json @@ -0,0 +1,72 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512.json new file mode 100644 index 0000000000..696cc82039 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + 
"NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "large": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048.json new file mode 100644 index 0000000000..e808bfd22f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "large": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "xlarge": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + 
"NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048_4096.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048_4096.json new file mode 100644 index 0000000000..e808bfd22f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048_4096.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "large": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "xlarge": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=4096-K=1024.json new file mode 100644 index 0000000000..7010775054 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=4096-K=1024.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=6144-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=6144-K=1024.json new file mode 100644 index 0000000000..fe0fd5eb63 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=6144-K=1024.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + 
"waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json new file mode 100644 index 0000000000..5df51493e2 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": ".cg", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 1, + "num_warps": 2, + "waves_per_eu": 2 + }, + "medium_M32": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": ".cg", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 1, + "num_warps": 2, + "waves_per_eu": 2 + }, + "medium_M64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": ".cg", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 1, + "num_warps": 2, + "waves_per_eu": 2 + }, + "medium_M128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 1, + "num_warps": 2, + "waves_per_eu": 2 + }, + "large": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 1, + "num_warps": 2, + "waves_per_eu": 2 + }, + "xlarge": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 1, + "num_warps": 2, + "waves_per_eu": 2 + }, + "any": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 1, + "num_warps": 2, + "waves_per_eu": 2 + } +} diff --git a/aiter/ops/triton/gemm_a8w8_per_token_scale.py b/aiter/ops/triton/gemm_a8w8_per_token_scale.py index e8032bdbeb..3a29f0a52a 100644 --- a/aiter/ops/triton/gemm_a8w8_per_token_scale.py +++ b/aiter/ops/triton/gemm_a8w8_per_token_scale.py @@ -19,7 +19,7 @@ def gemm_a8w8_per_token_scale( w: torch.Tensor, x_scale: torch.Tensor, w_scale: torch.Tensor, - dtype: Optional[float] = torch.bfloat16, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config=None, ): diff --git a/aiter/ops/triton/tune_a16w16_atomic.py b/aiter/ops/triton/tune_a16w16_atomic.py index 44a73c4494..0b57b6e596 100644 --- a/aiter/ops/triton/tune_a16w16_atomic.py +++ b/aiter/ops/triton/tune_a16w16_atomic.py @@ -58,14 +58,20 @@ def get_configs_compute_bound() -> List[Dict[str, int | str]]: # Optimize parameters based on kernel analysis # Only test parameters that are actually used in the kernel # Based on the generated configs, we'll use the optimal values - for num_stages in [2,3,4,5]: # Only 1 stage is used - for block_m in [16,32,64,128]: # Fixed to 64 as in current configs - for block_k in [64,128]: # Fixed to 64 as in current 
configs - for block_n in [32,64,128,256]: # Fixed to 32 as in current configs - for group_size in [1,8,16,32,64]: - for num_warps in [4,8]: - for num_ksplit in [1,2,4,8]: # Only test 1 since higher values may cause issues - for waves_per_eu in [2,4,8]: # Fixed to 3 as in current configs + for num_stages in [2, 3, 4, 5]: # Only 1 stage is used + for block_m in [16, 32, 64, 128, 256]: # Fixed to 64 as in current configs + for block_n in [32, 64, 128, 256]: # Fixed to 32 as in current configs + for block_k in [64, 128]: # Fixed to 64 as in current configs + for group_size in [1, 8, 16, 32, 64]: + for num_warps in [4, 8]: + for num_ksplit in [ + 1, + ]: # Only test 1 since higher values may cause issues + for waves_per_eu in [ + 2, + 4, + 8, + ]: # Fixed to 3 as in current configs configs.append( { "BLOCK_SIZE_M": block_m, @@ -87,9 +93,9 @@ def get_configs_compute_bound() -> List[Dict[str, int | str]]: def get_weight_shapes(tp_size: int) -> List[Tuple[int, int]]: total = [ - (1024, 1024), - (4096, 1024), - (1024, 2048), + # (1024, 1024), + # (4096, 1024), + # (1024, 2048), (6144, 1024), (1024, 3072), ] @@ -169,7 +175,7 @@ def tune( ): if input_type == "bfloat16": # Use the same input generation as test file - x, w, _, out_dtype, y = generate_gemm_a16w16_inputs( + x, weight, x_scale, w_scale, y = generate_gemm_a16w16_inputs( M, N, K, torch.bfloat16, output=True ) else: @@ -181,7 +187,7 @@ def tune( try: kernel_time = benchmark_config( x=x, - w=w, + w=weight, dtype=torch.float32, y=None, config=config, diff --git a/aiter/ops/triton/tune_a8w8_blockscale.py b/aiter/ops/triton/tune_a8w8_blockscale.py index 47c26c0d50..4edd45d72a 100644 --- a/aiter/ops/triton/tune_a8w8_blockscale.py +++ b/aiter/ops/triton/tune_a8w8_blockscale.py @@ -1,16 +1,29 @@ import argparse import json +import logging import multiprocessing as mp import os +import signal import time import triton from datetime import datetime -from typing import List, Dict, Union, Tuple, Optional +from typing import List, Dict, Union, Tuple, Optional, Any import torch +from rich.console import Console +from rich.progress import ( + Progress, + SpinnerColumn, + TextColumn, + BarColumn, + TaskProgressColumn, + TimeElapsedColumn, +) +from rich.table import Table +from rich.panel import Panel from tqdm import tqdm -from gemm_a8w8_blockscale import gemm_a8w8_blockscale # type: ignore +from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale # type: ignore from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH # type: ignore from aiter.ops.triton.utils.types import get_fp8_dtypes @@ -20,52 +33,340 @@ e5m2_type, e4m3_type = get_fp8_dtypes() -def generate_gemm_a8w8_blockscale_inputs(M, N, K, dtype, block_shape=(128, 128), layout="TN", output=True, bias=False): +class TimeoutError(Exception): + """Custom exception for timeout errors.""" + + pass + + +# Global variables to track bad configurations and current state +BAD_CONFIGS = { + "timeouts": [], + "out_of_resources": [], + "assert_errors": [], + "other_errors": [], +} + +# Global variables to track current state for SIGINT handling +CURRENT_CONFIG: Dict[str, Any] = { + "M": None, + "N": None, + "K": None, + "config": None, + "config_index": None, + "total_configs": None, + "batch_size": None, + "weight_shape_index": None, + "total_weight_shapes": None, + "gpu_id": None, +} + +INTERRUPTED = False + + +def sigint_handler(signum, frame): + """Handle SIGINT (Ctrl+C) gracefully by logging the current configuration.""" + global INTERRUPTED + INTERRUPTED = True + + print("\n" + "=" 
* 80) + print("๐Ÿ›‘ TUNING INTERRUPTED BY USER (Ctrl+C)") + print("=" * 80) + + if CURRENT_CONFIG["M"] is not None: + print("๐Ÿ“ Last configuration being processed:") + print(f" ๐ŸŽฏ GPU: {CURRENT_CONFIG['gpu_id']}") + print( + f" ๐Ÿ“Š Matrix: M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}" + ) + print(f" ๐Ÿ“ฆ Batch Size: {CURRENT_CONFIG['batch_size']}") + print( + f" ๐Ÿ”„ Progress: Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}" + ) + print( + f" ๐Ÿ—๏ธ Weight Shape: {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}" + ) + + if CURRENT_CONFIG["config"]: + config = CURRENT_CONFIG["config"] + print(" โš™๏ธ Parameters:") + print(f" BLOCK_SIZE_M: {config.get('BLOCK_SIZE_M', 'N/A')}") + print(f" BLOCK_SIZE_N: {config.get('BLOCK_SIZE_N', 'N/A')}") + print(f" BLOCK_SIZE_K: {config.get('BLOCK_SIZE_K', 'N/A')}") + print(f" num_warps: {config.get('num_warps', 'N/A')}") + print(f" num_stages: {config.get('num_stages', 'N/A')}") + print(f" NUM_KSPLIT: {config.get('NUM_KSPLIT', 'N/A')}") + print(f" waves_per_eu: {config.get('waves_per_eu', 'N/A')}") + print(f" kpack: {config.get('kpack', 'N/A')}") + print(f" cache_modifier: {config.get('cache_modifier', 'N/A')}") + + # Log the interruption to the file if logger is available + try: + logger = logging.getLogger("gemm_a8w8_blockscale_tuning") + if logger.handlers: + log_entry = { + "timestamp": datetime.now().isoformat(), + "event_type": "user_interrupt", + "batch_size": CURRENT_CONFIG["batch_size"], + "matrix_dims": f"M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}", + "config": CURRENT_CONFIG["config"], + "progress": f"Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}", + "weight_shape_progress": f"Shape {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}", + } + logger.info(f"USER_INTERRUPT: {log_entry}") + + # Force flush to write immediately + for handler in logger.handlers: + if hasattr(handler, "stream"): + handler.stream.flush() + + print(" ๐Ÿ“ Interruption logged to tuning log file") + except Exception as e: + print(f" โš ๏ธ Could not log interruption: {e}") + + print("\n๐Ÿ’ก You can use this information to:") + print(" โ€ข Skip this problematic configuration in future runs") + print(" โ€ข Analyze why this specific config might be causing issues") + print(" โ€ข Adjust the search space to avoid similar parameter combinations") + print("=" * 80) + + # Exit gracefully + import sys + + sys.exit(1) + + +def setup_logger(log_file_path: str) -> logging.Logger: + """ + Setup logger for recording bad configurations during tuning. 
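+
+    Illustrative use (the path here is a placeholder; the tuner derives the real
+    one from AITER_TRITON_CONFIGS_PATH and the GPU id in tune_on_gpu below):
+
+        logger = setup_logger("/tmp/tune_a8w8_blockscale_bad_configs_gpu0.log")
+        logger.info("tuning started")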
+ + Args: + log_file_path: Path to the log file + + Returns: + Configured logger instance + """ + logger = logging.getLogger("gemm_a8w8_blockscale_tuning") + logger.setLevel(logging.INFO) + + # Clear existing handlers + logger.handlers.clear() + + # Create file handler with live writing (immediate flush) + file_handler = logging.FileHandler(log_file_path, mode="w") + file_handler.setLevel(logging.INFO) + + # Create custom formatter that flushes immediately + file_handler.flush = lambda: file_handler.stream.flush() # type: ignore + + # Create console handler + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.WARNING) + + # Create formatter + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) + file_handler.setFormatter(formatter) + console_handler.setFormatter(formatter) + + # Add handlers to logger + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger + + +def log_bad_config( + logger: logging.Logger, + error_type: str, + M: int, + N: int, + K: int, + config: Dict[str, Union[str, int]], + error_msg: str = "", +): """ - Generate inputs for gemm_a8w8_blockscale kernel. - + Log a bad configuration that failed during tuning. + Args: + logger: Logger instance + error_type: Type of error ('timeout', 'out_of_resources', 'assert_error', 'other_error') M, N, K: Matrix dimensions - dtype: Output data type - block_shape: Tuple of (block_shape_n, block_shape_k) for block scaling - layout: Input matrix layout - output: Whether to generate output tensor - bias: Whether to generate bias tensor - + config: Configuration that failed + error_msg: Additional error message + """ + log_entry = { + "timestamp": datetime.now().isoformat(), + "error_type": error_type, + "batch_size": M, + "matrix_dims": f"M={M} N={N} K={K}", + "config": config, + "error_msg": str(error_msg), + } + + # Log to file + logger.info(f"BAD_CONFIG_{error_type.upper()}: {log_entry}") + + # Force flush to write immediately + for handler in logger.handlers: + if hasattr(handler, "stream"): + handler.stream.flush() + + # Store in global list for summary + if error_type == "timeout": + BAD_CONFIGS["timeouts"].append(log_entry) + elif error_type == "out_of_resources": + BAD_CONFIGS["out_of_resources"].append(log_entry) + elif error_type == "assert_error": + BAD_CONFIGS["assert_errors"].append(log_entry) + else: + BAD_CONFIGS["other_errors"].append(log_entry) + + +def log_bad_config_summary(logger: logging.Logger, total_configs_tested: int): + """ + Log a summary of all bad configurations encountered during tuning. 
+ + Args: + logger: Logger instance + total_configs_tested: Total number of configurations tested + """ + total_bad = ( + len(BAD_CONFIGS["timeouts"]) + + len(BAD_CONFIGS["out_of_resources"]) + + len(BAD_CONFIGS["assert_errors"]) + + len(BAD_CONFIGS["other_errors"]) + ) + success_rate = ( + ((total_configs_tested - total_bad) / total_configs_tested * 100) + if total_configs_tested > 0 + else 0 + ) + + logger.info("=" * 80) + logger.info("BAD CONFIGURATION SUMMARY") + logger.info("=" * 80) + logger.info(f"Total configurations tested: {total_configs_tested}") + logger.info(f"Successful configurations: {total_configs_tested - total_bad}") + logger.info(f"Failed configurations: {total_bad}") + logger.info(f"Success rate: {success_rate:.2f}%") + logger.info("") + + logger.info(f"Timeouts: {len(BAD_CONFIGS['timeouts'])}") + logger.info(f"Out of Resources: {len(BAD_CONFIGS['out_of_resources'])}") + logger.info(f"Assert Errors: {len(BAD_CONFIGS['assert_errors'])}") + logger.info(f"Other Errors: {len(BAD_CONFIGS['other_errors'])}") + logger.info("") + + if BAD_CONFIGS["timeouts"]: + logger.info("TIMEOUT CONFIGS (most problematic):") + for entry in BAD_CONFIGS["timeouts"]: + config = entry["config"] + logger.info( + f" - Batch {entry['batch_size']} | {entry['matrix_dims']} | BM:{config.get('BLOCK_SIZE_M', 'N/A')}, BN:{config.get('BLOCK_SIZE_N', 'N/A')}, BK:{config.get('BLOCK_SIZE_K', 'N/A')}, W:{config.get('num_warps', 'N/A')}, S:{config.get('num_stages', 'N/A')}, KS:{config.get('NUM_KSPLIT', 'N/A')}" + ) + + if BAD_CONFIGS["out_of_resources"]: + logger.info("OUT OF RESOURCE CONFIGS:") + for entry in BAD_CONFIGS["out_of_resources"]: + config = entry["config"] + logger.info( + f" - Batch {entry['batch_size']} | {entry['matrix_dims']} | BM:{config.get('BLOCK_SIZE_M', 'N/A')}, BN:{config.get('BLOCK_SIZE_N', 'N/A')}, BK:{config.get('BLOCK_SIZE_K', 'N/A')}, W:{config.get('num_warps', 'N/A')}, S:{config.get('num_stages', 'N/A')}, KS:{config.get('NUM_KSPLIT', 'N/A')}" + ) + + logger.info("=" * 80) + + # Print summary to console as well + print("\n๐Ÿ“Š Bad Configuration Summary:") + print(f" Total tested: {total_configs_tested}") + print(f" โœ… Successful: {total_configs_tested - total_bad}") + print( + f" โŒ Failed: {total_bad} ({len(BAD_CONFIGS['timeouts'])} timeouts, {len(BAD_CONFIGS['out_of_resources'])} OOM, {len(BAD_CONFIGS['assert_errors'])} assert, {len(BAD_CONFIGS['other_errors'])} other)" + ) + print(f" ๐Ÿ“ˆ Success rate: {success_rate:.1f}%") + + +def timeout_handler(signum, frame): + """Signal handler for timeout.""" + raise TimeoutError("Kernel execution timed out") + + +def run_with_timeout(func, timeout_seconds=3, *args, **kwargs): + """ + Run a function with a timeout limit. 
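+
+    Illustrative use, mirroring how the benchmark code below wraps its warmup
+    and launch closures (`run` stands for any zero-argument callable):
+
+        result = run_with_timeout(run, timeout_seconds=3)
+
+    Because the timeout is driven by signal.alarm/SIGALRM, this helper must run
+    in the main thread of its process, which holds for the spawned tuning workers.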
+ + Args: + func: Function to execute + timeout_seconds: Timeout in seconds (default: 3) + *args: Arguments to pass to the function + **kwargs: Keyword arguments to pass to the function + Returns: - Tuple of (x, w, x_scale, w_scale, bias, y) + Result of the function call + + Raises: + TimeoutError: If function execution exceeds timeout + """ + # Set the signal handler + old_handler = signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(timeout_seconds) + + try: + result = func(*args, **kwargs) + return result + finally: + # Cancel the alarm and restore the old handler + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + +def generate_gemm_a8w8_blockscale_inputs( + M: int, + N: int, + K: int, + block_shape_n: int, + block_shape_k: int, + dtype=torch.bfloat16, + layout: str = "TN", + output=False, +): + """ + The GEMM kernel expects: + - x: (M, K) -> row-major format + - w: (N, K) -> column-major format """ - block_shape_n, block_shape_k = block_shape scale_n = (N + block_shape_n - 1) // block_shape_n scale_k = (K + block_shape_k - 1) // block_shape_k - # Generate input matrix x (M, K) if layout[0] == "T": x = (torch.rand((M, K), dtype=torch.float16, device="cuda") / 10).to(e4m3_type) else: - x = ((torch.rand((K, M), dtype=torch.float16, device="cuda") / 10).to(e4m3_type)).T + x = ( + (torch.rand((K, M), dtype=torch.float16, device="cuda") / 10) + .to(e4m3_type) + .T + ) - # Generate weight matrix w (N, K) if layout[1] == "N": - w = (torch.rand((N, K), dtype=torch.float16, device="cuda") / 10).to(e4m3_type) + weight = (torch.rand((N, K), dtype=torch.float16, device="cuda") / 10).to( + e4m3_type + ) else: - w = ((torch.rand((K, N), dtype=torch.float16, device="cuda") / 10).to(e4m3_type)).T + weight = ( + (torch.rand((K, N), dtype=torch.float16, device="cuda") / 10) + .to(e4m3_type) + .T + ) - # Generate scale tensors x_scale = torch.rand([M, scale_k], dtype=torch.float32, device="cuda") w_scale = torch.rand([scale_n, scale_k], dtype=torch.float32, device="cuda") - # Generate bias tensor if needed - bias_tensor = None - if bias: - bias_tensor = torch.empty((N), dtype=dtype, device="cuda") - - # Generate output tensor if needed y = None if output: - y = torch.empty((M, N), dtype=dtype, device="cuda") + y = torch.empty((M, N), dtype=dtype, device="cuda").cuda() - return x, w, x_scale, w_scale, bias_tensor, y + return x, weight, x_scale, w_scale, y def get_configs_compute_bound() -> List[Dict[str, int | str]]: @@ -75,38 +376,64 @@ def get_configs_compute_bound() -> List[Dict[str, int | str]]: Note: GROUP_K must equal BLOCK_SIZE_K as required by the kernel. 
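    For illustration, an entry that satisfies this constraint could look like
    (hypothetical values, not a tuned result):

        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128,
         "GROUP_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2,
         "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "NUM_KSPLIT": 1,
         "kpack": 2, "cache_modifier": ".cg"}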
""" configs = [] - # Based on the sample config from MI300X-GEMM-A8W8_BLOCKSCALE.json - # We'll explore a reasonable range around these values - for num_stages in [1, 2]: # Sample config uses 2 - for block_m in [64, 128, 256]: # Sample config uses 128 - for block_k in [64, 128, 256]: # Sample config uses 128 - for block_n in [64, 128, 256]: # Sample config uses 128 - for group_size in [1, 8]: # Sample config uses 1 - for num_warps in [4, 8]: # Sample config uses 4 - for num_ksplit in [1, 2, 4]: # Sample config uses 1 - for waves_per_eu in [1, 2, 4]: # Sample config uses 2 - for kpack in [1,2 ]: # Sample config uses 2 - for cache_modifier in ["", ".cg"]: # Sample config uses ".cg" - configs.append( - { - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": group_size, - "num_warps": num_warps, - "num_stages": num_stages, - "waves_per_eu": waves_per_eu, - "matrix_instr_nonkdim": 16, # Fixed value used in kernel - "cache_modifier": cache_modifier, - "NUM_KSPLIT": num_ksplit, - "kpack": kpack, - # "SPLITK_BLOCK_SIZE": 1, # Will be set dynamically - } - ) + # Start with the known working configuration from MI300X-GEMM-A8W8_BLOCKSCALE.json + base_config = { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 2, + "cache_modifier": ".cg", + } + + # Add the base config first (known to work) + configs.append(base_config.copy()) + + # Generate variations around the base config, but be conservative + for block_m in [ + 32, + 64, + 128, + ]: + for block_n in [ + 32, + 64, + 128, + ]: + for block_k in [64, 128]: # Keep as power of 2 + for num_warps in [4, 8]: + for num_stages in [2, 3, 4, 5]: + for waves_per_eu in [2, 4, 8]: + for cache_modifier in [ + ".cg", + "", + ]: # Start with cache modifier + config = { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_K": block_k, + "GROUP_SIZE_M": 1, # Keep fixed for now + "num_warps": num_warps, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, # Keep fixed for now + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 2, # Keep fixed for now + "cache_modifier": cache_modifier, + } + configs.append(config) + + print(f"Generated {len(configs)} configurations") return configs -def get_weight_shapes(tp_size: int) -> List[Tuple[int, int]]: +def get_weight_shapes(tp_size: int = 1) -> List[Tuple[int, int]]: """Get weight shapes to test during tuning.""" total = [ (1024, 1024), @@ -123,33 +450,27 @@ def get_weight_shapes(tp_size: int) -> List[Tuple[int, int]]: return weight_shapes -def run_torch_reference(x, w, x_scale, w_scale, bias, dtype=torch.bfloat16, block_shape=(128, 128)): - """ - Run reference implementation using PyTorch. - This is used for correctness verification. - Based on the test file implementation. 
- """ +def run_torch( + x, weight, x_scale, w_scale, block_shape: Tuple[int, int], dtype=torch.bfloat16 +): block_shape_n, block_shape_k = block_shape m, k = x.shape - n = w.shape[0] - scale_n = (n + block_shape_n - 1) // block_shape_n - scale_k = (k + block_shape_k - 1) // block_shape_k - - # Expand scales to match the full matrix dimensions + n = weight.shape[0] + + # Expand scales to match the block sizes x_scale_expanded = x_scale.repeat_interleave(block_shape_k, dim=1) - x_scaled = x.to(x_scale_expanded.dtype) * x_scale_expanded[:m, :k] - x_scaled = x_scaled.view(m, k) - + x_dequant = x.to(x_scale_expanded.dtype) * x_scale_expanded[:m, :k] + + # Expand weight scales: first repeat along N dimension, then along K dimension w_scale_expanded = w_scale.repeat_interleave(block_shape_n, dim=0) w_scale_expanded = w_scale_expanded.repeat_interleave(block_shape_k, dim=1) w_scale_expanded = w_scale_expanded[:n, :k] - w_scaled = w.to(w_scale_expanded.dtype) * w_scale_expanded + weight_dequant = weight.to(w_scale_expanded.dtype) * w_scale_expanded + + out = torch.nn.functional.linear( + x_dequant.to(torch.float32), weight_dequant.to(torch.float32) + ) - # Compute the matrix multiplication with bias if provided - # Convert bias to float32 if it's not None to match the other tensors - bias_float32 = bias.to(torch.float32) if bias is not None else None - out = torch.nn.functional.linear(x_scaled.to(torch.float32), w_scaled.to(torch.float32), bias=bias_float32) - return out.to(dtype) @@ -158,7 +479,6 @@ def benchmark_config( w: torch.Tensor, x_scale: torch.Tensor, w_scale: torch.Tensor, - bias: Optional[torch.Tensor], dtype: torch.dtype, config: Dict[str, Union[str, int]], y: Optional[torch.Tensor] = None, @@ -166,11 +486,11 @@ def benchmark_config( ) -> float: """ Benchmark the performance of a GEMM operation with a specific configuration. - + This function measures the execution time of the gemm_a8w8_blockscale kernel by running it multiple times with synchronization points to ensure accurate timing. It performs warmup runs before the actual benchmarking to account for JIT compilation overhead. - + Args: x (torch.Tensor): Input tensor of shape (M, K) representing the first matrix operand. w (torch.Tensor): Weight tensor of shape (N, K) representing the second matrix operand. @@ -182,39 +502,44 @@ def benchmark_config( y (Optional[torch.Tensor], optional): Output tensor to store the result. If None, a new tensor will be allocated. Defaults to None. num_iters (int, optional): Number of benchmark iterations to run. Defaults to 10. - + Returns: float: Average execution time in microseconds (us) per iteration. - + Note: The function performs 5 warmup iterations before benchmarking to account for JIT compilation and GPU warmup effects. The timing is measured using CUDA events for accurate GPU kernel timing. 
""" - # Calculate GROUP_K based on the input dimensions and w_scale shape - # This must match BLOCK_SIZE_K to satisfy the kernel assertion - M, K = x.shape - N, _ = w.shape - w_scale_T = w_scale.T # Transpose to match kernel's expectation - group_k = triton.next_power_of_2(triton.cdiv(K, w_scale_T.shape[0])) - - # Create a copy of config to modify - modified_config = config.copy() - # Set BLOCK_SIZE_K to match GROUP_K to satisfy kernel assertion - modified_config["BLOCK_SIZE_K"] = group_k - - # Get reference output for correctness verification - torch_out = run_torch_reference(x, w, x_scale, w_scale, bias, dtype) + + torch_out = run_torch( + x, + w, + x_scale, + w_scale, + (128, 128), # follow test using (128,128) + dtype, + ) # Run kernel def run(): # Pass the modified config to the kernel - return gemm_a8w8_blockscale(x, w, x_scale, w_scale, dtype, y, modified_config, skip_reduce=False) + return gemm_a8w8_blockscale( + x, w, x_scale, w_scale, dtype, y, config, skip_reduce=False + ) torch.cuda.synchronize() - # JIT compilation & warmup - for _ in range(5): - run() + + # JIT compilation & warmup with timeout for entire warmup phase + def run_warmup(): + for i in range(5): + run() + + try: + run_with_timeout(run_warmup, timeout_seconds=3) + except TimeoutError: + # If warmup times out, this config is likely bad, skip it + raise TimeoutError("Warmup phase timed out after 3 seconds") torch.cuda.synchronize() start_event = torch.Event(enable_timing=True) @@ -224,7 +549,12 @@ def run(): for i in range(num_iters): torch.cuda.synchronize() start_event.record() - triton_out_raw = run() + try: + triton_out_raw = run_with_timeout(run, timeout_seconds=3) + except TimeoutError: + # If benchmark iteration times out, skip this config + raise TimeoutError(f"Benchmark iteration {i + 1} timed out after 3 seconds") + # Convert to the same dtype as the reference for comparison # Handle the case where triton_out_raw might be None if triton_out_raw is not None: @@ -240,44 +570,227 @@ def run(): def tune( - M: int, N: int, K: int, search_space: List[Dict[str, int | str]], input_type: str + M: int, + N: int, + K: int, + search_space: List[Dict[str, int | str]], + input_type: str, + logger: logging.Logger, ): """Tune the kernel for specific matrix dimensions.""" - if input_type == "bfloat16": - # Use the same input generation as test file - x, w, x_scale, w_scale, bias, y = generate_gemm_a8w8_blockscale_inputs( - M, N, K, torch.bfloat16, bias=True + # Register SIGINT handler if not already registered + if not signal.getsignal(signal.SIGINT) == sigint_handler: + signal.signal(signal.SIGINT, sigint_handler) + + if input_type != "bfloat16": + raise RuntimeError( + "Currently, only support tune a8w8 blockscale kernel with bfloat16 output." 
) - else: - raise RuntimeError("Currently, only support tune a8w8 blockscale kernel with bfloat16 output.") - best_config = None + best_config: Dict[str, Union[int, str]] = {} best_time = float("inf") - for config in tqdm(search_space): - try: - kernel_time = benchmark_config( - x=x, - w=w, - x_scale=x_scale, - w_scale=w_scale, - bias=bias, - dtype=torch.bfloat16, - y=None, - config=config, - num_iters=10, + slow_config_threshold = ( + 1000 # microseconds - configs slower than this get highlighted + ) + + # Initialize Rich console for better formatting + console = Console() + + # Create progress display with Rich + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + console=console, + transient=False, # Keep progress bar visible + ) as progress: + task = progress.add_task( + f"๐Ÿ”ง Tuning M={M} N={N} K={K}", + total=len(search_space), + ) + + for i, config in enumerate(search_space): + # Check if we were interrupted + if INTERRUPTED: + break + + # Update global state for SIGINT handling + CURRENT_CONFIG.update( + { + "M": M, + "N": N, + "K": K, + "config": config, + "config_index": i, + "total_configs": len(search_space), + "batch_size": M, + } ) - except triton.runtime.autotuner.OutOfResources as e: - # Some configurations may be invalid and fail to compile. - continue - except AssertionError as e: - print("Assert error:", e) - continue - - if kernel_time < best_time: - best_time = kernel_time - best_config = config - now = datetime.now() - print(f"{now.ctime()}] Completed tuning for batch_size={M}") + + # Update progress + progress.update( + task, + advance=1, + description=f"๐Ÿ”ง Testing config {i + 1}/{len(search_space)}", + ) + + # Show current config (only every 10 configs to avoid flicker) + if i % 10 == 0 or i == len(search_space) - 1: + # Create fresh config table with matrix dimensions and batch size + config_table = Table( + title="Current Configuration", + show_header=True, + header_style="bold magenta", + ) + config_table.add_column("Parameter", style="cyan", width=15) + config_table.add_column("Value", style="green", width=10) + + # Add matrix dimensions and batch size first + config_table.add_row( + "[bold yellow]Matrix M[/bold yellow]", str(M), style="yellow" + ) + config_table.add_row( + "[bold yellow]Matrix N[/bold yellow]", str(N), style="yellow" + ) + config_table.add_row( + "[bold yellow]Matrix K[/bold yellow]", str(K), style="yellow" + ) + config_table.add_row( + "[bold yellow]Batch Size[/bold yellow]", str(M), style="yellow" + ) + config_table.add_row("", "") # Separator + config_table.add_row( + "BLOCK_SIZE_M", str(config.get("BLOCK_SIZE_M", "N/A")) + ) + config_table.add_row( + "BLOCK_SIZE_N", str(config.get("BLOCK_SIZE_N", "N/A")) + ) + config_table.add_row( + "BLOCK_SIZE_K", str(config.get("BLOCK_SIZE_K", "N/A")) + ) + config_table.add_row("num_warps", str(config.get("num_warps", "N/A"))) + config_table.add_row("num_stages", str(config.get("num_stages", "N/A"))) + config_table.add_row("NUM_KSPLIT", str(config.get("NUM_KSPLIT", "N/A"))) + config_table.add_row( + "waves_per_eu", str(config.get("waves_per_eu", "N/A")) + ) + + # Create summary header with all tuning parameters + header_text = f"[bold blue]๐Ÿ”ง Tuning M={M} N={N} K={K} | Batch Size={M} | Config {i + 1}/{len(search_space)}[/bold blue]" + + # Show config info (don't clear screen to avoid issues in multiprocessing) + console.print(f"\n{header_text}") + console.print(config_table) + + if best_time != 
float("inf"): + console.print( + f"[yellow]๐Ÿ† Best time so far: {best_time:.1f}ฮผs[/yellow]" + ) + console.print("โ”€" * 70) + + try: + # Use the same input generation as test file + x, w, x_scale, w_scale, _ = generate_gemm_a8w8_blockscale_inputs( + M, + N, + K, + int(config["BLOCK_SIZE_N"]), + int(config["BLOCK_SIZE_K"]), + torch.bfloat16, + ) + kernel_time = benchmark_config( + x=x, + w=w, + x_scale=x_scale, + w_scale=w_scale, + dtype=torch.bfloat16, + y=None, + config=config, + num_iters=10, + ) + + # Warn about slow configs + if kernel_time > slow_config_threshold: + console.print( + f"\n[bold yellow]โš ๏ธ SLOW CONFIG DETECTED: {kernel_time:.1f}ฮผs[/bold yellow]" + ) + console.print( + f"[cyan]๐Ÿ“Š Matrix: M={M} N={N} K={K} | Config:[/cyan] BM:{config.get('BLOCK_SIZE_M', 'N/A')}, BN:{config.get('BLOCK_SIZE_N', 'N/A')}, BK:{config.get('BLOCK_SIZE_K', 'N/A')}, W:{config.get('num_warps', 'N/A')}, S:{config.get('num_stages', 'N/A')}, KS:{config.get('NUM_KSPLIT', 'N/A')}" + ) + + # Update best time and config + if kernel_time < best_time: + best_time = kernel_time + best_config = config + + except triton.runtime.autotuner.OutOfResources as e: + # Log and skip out of resources configurations + log_bad_config(logger, "out_of_resources", M, N, K, config, str(e)) + console.print( + f"\n[bold red]โš ๏ธ Out of resources for M={M} N={N} K={K} - logged[/bold red]" + ) + continue + except AssertionError as e: + # Log and skip assert error configurations + log_bad_config(logger, "assert_error", M, N, K, config, str(e)) + console.print( + f"\n[bold red]โŒ Assert error for M={M} N={N} K={K} - logged[/bold red]" + ) + console.print(f"[red]๐Ÿ’ฌ Error:[/red] {e}") + continue + except TimeoutError as e: + # Log and skip timeout configurations + log_bad_config(logger, "timeout", M, N, K, config, str(e)) + console.print( + f"\n[bold orange1]โฑ๏ธ TIMEOUT for M={M} N={N} K={K} - logged[/bold orange1]" + ) + console.print(f"[orange1]๐Ÿ’ฌ Timeout:[/orange1] {e}") + continue + except Exception as e: + # Log and skip other error configurations + log_bad_config(logger, "other_error", M, N, K, config, str(e)) + console.print( + f"\n[bold red]๐Ÿ’ฅ Unexpected error for M={M} N={N} K={K} - logged[/bold red]" + ) + console.print(f"[red]๐Ÿ’ฌ Error:[/red] {e}") + continue + + # Show final completion message with Rich + print("\n" + "=" * 70) + + # Create best config table with matrix dimensions + best_table = Table( + title="๐Ÿ† Best Configuration Found", show_header=True, header_style="bold green" + ) + best_table.add_column("Parameter", style="cyan", width=15) + best_table.add_column("Value", style="green", width=10) + + # Add matrix dimensions and batch size first + best_table.add_row( + "[bold yellow]Matrix M (Batch)[/bold yellow]", str(M), style="yellow" + ) + best_table.add_row("[bold yellow]Matrix N[/bold yellow]", str(N), style="yellow") + best_table.add_row("[bold yellow]Matrix K[/bold yellow]", str(K), style="yellow") + best_table.add_row("", "") # Separator + best_table.add_row("Performance", f"{best_time:.1f}ฮผs") + best_table.add_row("BLOCK_SIZE_M", str(best_config.get("BLOCK_SIZE_M", "N/A"))) + best_table.add_row("BLOCK_SIZE_N", str(best_config.get("BLOCK_SIZE_N", "N/A"))) + best_table.add_row("BLOCK_SIZE_K", str(best_config.get("BLOCK_SIZE_K", "N/A"))) + best_table.add_row("num_warps", str(best_config.get("num_warps", "N/A"))) + best_table.add_row("num_stages", str(best_config.get("num_stages", "N/A"))) + best_table.add_row("NUM_KSPLIT", str(best_config.get("NUM_KSPLIT", "N/A"))) + 
best_table.add_row("waves_per_eu", str(best_config.get("waves_per_eu", "N/A"))) + + completion_panel = Panel( + best_table, + title=f"[bold green]โœ… Completed Tuning for M={M} N={N} K={K} (Batch Size={M})[/bold green]", + border_style="green", + ) + console.print(completion_panel) + print("=" * 70) + assert best_config is not None return best_config @@ -290,7 +803,7 @@ def save_configs( ) -> None: """Save the best configurations to a JSON file.""" os.makedirs(save_path, exist_ok=True) - device_name = "MI300X" # TODO: Hardcoded, make it dynamic + device_name = "R9700" # TODO: Hardcoded, make it dynamic json_file_name = f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}.json" config_file_path = os.path.join(save_path, json_file_name) @@ -308,31 +821,79 @@ def tune_on_gpu( input_type: str, ) -> None: """Run tuning on a specific GPU.""" + # Register SIGINT handler and set GPU ID in global state + signal.signal(signal.SIGINT, sigint_handler) + CURRENT_CONFIG["gpu_id"] = gpu_id + torch.cuda.set_device(gpu_id) - print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + print(f"๐Ÿš€ Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") save_path = AITER_TRITON_CONFIGS_PATH + "/gemm/" + # Setup logger for this GPU with proper prefix + log_file_path = os.path.join( + save_path, f"tune_a8w8_blockscale_bad_configs_gpu{gpu_id}.log" + ) + logger = setup_logger(log_file_path) + logger.info(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + search_space = get_configs_compute_bound() + total_configs = len(search_space) + total_tests = total_configs * len(batch_sizes) * len(weight_shapes) + + print(f" ๐Ÿ“Š Search space: {total_configs:,} configurations") + print(f" ๐ŸŽฏ Total tests to run: {total_tests:,}") + print( + f" โšก Estimated tests per weight shape: {total_configs * len(batch_sizes):,}" + ) + print(f" ๐Ÿ“ Bad configurations will be logged to: {log_file_path}") start = time.time() # Collect all configs to determine the best overall config all_configs: List[Dict[str, Dict[str, int | str]]] = [] - for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"): + for i, shape in enumerate(weight_shapes): + # Check if we were interrupted + if INTERRUPTED: + break + + # Update weight shape tracking + CURRENT_CONFIG.update( + {"weight_shape_index": i, "total_weight_shapes": len(weight_shapes)} + ) + N, K = shape[0], shape[1] - print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`") - benchmark_results = [ - tune( + print( + f"\n๐Ÿš€ [GPU {gpu_id}] Shape {i + 1}/{len(weight_shapes)}: Starting tuning for N:{N}, K:{K}" + ) + print( + f" ๐Ÿ“Š Testing {len(search_space):,} configurations across {len(batch_sizes)} batch sizes" + ) + + benchmark_results = [] + for batch_size in batch_sizes: + # Check if we were interrupted + if INTERRUPTED: + break + + print( + f"\n ๐Ÿ” [GPU {gpu_id}] Testing batch size M={batch_size} for N={N}, K={K}" + ) + result = tune( batch_size, N, K, search_space, input_type, + logger, ) - for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") - ] + + # Check if tune() was interrupted + if INTERRUPTED: + break + + benchmark_results.append(result) best_configs: Dict[str, Dict[str, int | str]] = {} # Create configs for different M size categories as expected by the kernel for i, (M, config) in enumerate(zip(batch_sizes, benchmark_results)): @@ -361,6 +922,10 @@ def tune_on_gpu( save_default_config(default_config, save_path) end = time.time() + + # Log summary of bad configurations + log_bad_config_summary(logger, 
total_tests) + print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") @@ -403,7 +968,7 @@ def save_default_config( ) -> None: """Save the default config file (without N,K parameters).""" os.makedirs(save_path, exist_ok=True) - device_name = "MI300X" # TODO: Hardcoded, make it dynamic + device_name = "R9700" # TODO: Hardcoded, make it dynamic json_file_name = f"{device_name}-GEMM-A8W8_BLOCKSCALE.json" config_file_path = os.path.join(save_path, json_file_name) @@ -450,11 +1015,11 @@ def main(args): weight_shapes = get_weight_shapes(args.tp_size) - batches_per_gpu = distribute_batch_sizes(batch_sizes, 1) + batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus) # Prepare arguments for each GPU process process_args = [] - for gpu_id in range(1): + for gpu_id in range(num_gpus): process_args.append( ( gpu_id, @@ -465,7 +1030,7 @@ def main(args): ) ctx = mp.get_context("spawn") - with ctx.Pool(1) as pool: + with ctx.Pool(num_gpus) as pool: pool.starmap(tune_on_gpu, process_args) print("Multi-GPU tuning completed") diff --git a/aiter/ops/triton/tune_a8w8_per_token_scale.py b/aiter/ops/triton/tune_a8w8_per_token_scale.py new file mode 100644 index 0000000000..ad7419b4ed --- /dev/null +++ b/aiter/ops/triton/tune_a8w8_per_token_scale.py @@ -0,0 +1,1520 @@ +import argparse +import json +import logging +import multiprocessing as mp +import os +import signal +import sys +import time +import triton +from datetime import datetime +from typing import List, Dict, Union, Tuple, Optional, Any +import torch +import pytz +from rich.console import Console +from rich.table import Table +from rich.panel import Panel +from rich.columns import Columns +from rich.live import Live +from rich.layout import Layout +from rich.progress import Progress, BarColumn, TextColumn + + +from aiter.ops.triton.gemm_a8w8_per_token_scale import gemm_a8w8_per_token_scale # type: ignore +from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH # type: ignore +from aiter.ops.triton.utils.types import get_fp8_dtypes + +mp.set_start_method("spawn", force=True) + +# Get FP8 data types +e5m2_type, e4m3_type = get_fp8_dtypes() + + +class TimeoutError(Exception): + """Custom exception for timeout errors.""" + + pass + + +# Global variables to track bad configurations and current state +BAD_CONFIGS = { + "timeouts": [], + "out_of_resources": [], + "assert_errors": [], + "other_errors": [], +} + +# Global variables to track current state for SIGINT handling +CURRENT_CONFIG: Dict[str, Any] = { + "M": None, + "N": None, + "K": None, + "config": None, + "config_index": None, + "total_configs": None, + "batch_size": None, + "weight_shape_index": None, + "total_weight_shapes": None, + "gpu_id": None, +} + +INTERRUPTED = False + + +class GMT8Formatter(logging.Formatter): + """Custom formatter that uses GMT+8 timezone.""" + + def __init__(self, fmt=None, datefmt=None): + super().__init__(fmt, datefmt) + self.gmt8 = pytz.timezone("Asia/Shanghai") # GMT+8 + + def formatTime(self, record, datefmt=None): + # Convert timestamp to GMT+8 + dt = datetime.fromtimestamp(record.created, tz=self.gmt8) + if datefmt: + return dt.strftime(datefmt) + else: + return dt.strftime("%Y-%m-%d %H:%M:%S") + + def format(self, record): + # Add timezone info to the formatted message + original = super().format(record) + return original.replace("[GMT+8]", "") # Remove any existing timezone tag + + +def get_timestamped_filename(base_name: str, extension: str = ".log") -> str: + """Generate a filename with timestamp in GMT+8 timezone.""" + gmt8 = 
pytz.timezone("Asia/Shanghai") + timestamp = datetime.now(gmt8).strftime("%Y%m%d_%H%M%S") + return f"{base_name}_{timestamp}{extension}" + + +def sigint_handler(signum, frame): + """Handle SIGINT (Ctrl+C) gracefully by logging the current configuration.""" + global INTERRUPTED + global CURRENT_CONFIG + INTERRUPTED = True + + print("\n" + "=" * 80) + print("๐Ÿ›‘ TUNING INTERRUPTED BY USER (Ctrl+C)") + print("=" * 80) + + if CURRENT_CONFIG["M"] is not None: + print("๐Ÿ“ Last configuration being processed:") + print(f" ๐ŸŽฏ GPU: {CURRENT_CONFIG['gpu_id']}") + print( + f" ๐Ÿ“Š Matrix: M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}" + ) + print(f" ๐Ÿ“ฆ Batch Size: {CURRENT_CONFIG['batch_size']}") + print( + f" ๐Ÿ”„ Progress: Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}" + ) + print( + f" ๐Ÿ—๏ธ Weight Shape: {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}" + ) + + if CURRENT_CONFIG["config"]: + config = CURRENT_CONFIG["config"] + print(" โš™๏ธ Parameters:") + print(f" BLOCK_SIZE_M: {config.get('BLOCK_SIZE_M', 'N/A')}") + print(f" BLOCK_SIZE_N: {config.get('BLOCK_SIZE_N', 'N/A')}") + print(f" BLOCK_SIZE_K: {config.get('BLOCK_SIZE_K', 'N/A')}") + print(f" num_warps: {config.get('num_warps', 'N/A')}") + print(f" num_stages: {config.get('num_stages', 'N/A')}") + print(f" NUM_KSPLIT: {config.get('NUM_KSPLIT', 'N/A')}") + print(f" waves_per_eu: {config.get('waves_per_eu', 'N/A')}") + print(f" kpack: {config.get('kpack', 'N/A')}") + print(f" cache_modifier: {config.get('cache_modifier', 'N/A')}") + print(f" GROUP_SIZE_M: {config.get('GROUP_SIZE_M', 'N/A')}") + + # Show config in same format as console output for consistency + config_num = CURRENT_CONFIG["config_index"] + 1 + console_format = f" ๐Ÿ’ป Config {config_num} (INTERRUPTED): | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} kpack:{config.get('kpack')} cache:{config.get('cache_modifier')}" + print(console_format) + + # Log the interruption to the file if logger is available + try: + logger = logging.getLogger("gemm_a8w8_per_token_scale_tuning") + if logger.handlers: + # Use GMT+8 timestamp for consistency + gmt8 = pytz.timezone("Asia/Shanghai") + + # Create detailed log entry + detailed_log_entry = { + "timestamp": datetime.now(gmt8).isoformat(), + "event_type": "user_interrupt", + "gpu_id": CURRENT_CONFIG.get("gpu_id", "N/A"), + "batch_size": CURRENT_CONFIG["batch_size"], + "matrix_dims": f"M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}", + "config": CURRENT_CONFIG["config"], + "progress": f"Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}", + "weight_shape_progress": f"Shape {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}", + } + + # Log detailed interruption info + logger.info(f"=== USER INTERRUPT ===") + logger.info( + f"Interrupted while testing: Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}" + ) + logger.info(f"GPU: {CURRENT_CONFIG.get('gpu_id', 'N/A')}") + logger.info( + f"Matrix: M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}" + ) + logger.info( + f"Weight Shape Progress: {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}" + ) + + # Log config details in same format as console output for consistency + if CURRENT_CONFIG["config"]: + 
config = CURRENT_CONFIG["config"] + config_num = CURRENT_CONFIG["config_index"] + 1 + config_str = f"Config {config_num} (INTERRUPTED): | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} kpack:{config.get('kpack')} cache:{config.get('cache_modifier')} GROUP_SIZE_M:{config.get('GROUP_SIZE_M')}" + logger.info(f"CONFIG_DETAILS: {config_str}") + + logger.info(f"DETAILED_ENTRY: {detailed_log_entry}") + logger.info(f"=== END USER INTERRUPT ===") + + # Force flush to write immediately + for handler in logger.handlers: + if hasattr(handler, "stream"): + handler.stream.flush() + + print(" ๐Ÿ“ Interruption logged to tuning log file") + except Exception as e: + print(f" โš ๏ธ Could not log interruption: {e}") + + print("\n๐Ÿ’ก You can use this information to:") + print(" โ€ข Skip this problematic configuration in future runs") + print(" โ€ข Analyze why this specific config might be causing issues") + print(" โ€ข Adjust the search space to avoid similar parameter combinations") + print("=" * 80) + + # Exit gracefully + import sys + + sys.exit(1) + + +def setup_logger(log_file_path: str, mode: str = "a") -> logging.Logger: + """ + Setup logger for recording bad configurations during tuning. + + Args: + log_file_path: Path to the log file + mode: File write mode - 'a' to append to existing logs, 'w' to overwrite + + Returns: + Configured logger instance + """ + logger = logging.getLogger("gemm_a8w8_per_token_scale_tuning") + logger.setLevel(logging.INFO) + + # Clear existing handlers + logger.handlers.clear() + + # Create file handler with live writing (immediate flush) + # Default to append mode to preserve logs across resume sessions + file_handler = logging.FileHandler(log_file_path, mode=mode) + file_handler.setLevel(logging.INFO) + + # Create custom formatter that flushes immediately + file_handler.flush = lambda: file_handler.stream.flush() # type: ignore + + # Create console handler + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.WARNING) + + # Create GMT+8 formatter + formatter = GMT8Formatter( + "%(asctime)s [GMT+8] - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) + file_handler.setFormatter(formatter) + console_handler.setFormatter(formatter) + + # Add handlers to logger + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger + + +def log_bad_config( + logger: logging.Logger, + error_type: str, + M: int, + N: int, + K: int, + config: Dict[str, Union[str, int]], + error_msg: str = "", +): + """ + Log a bad configuration that failed during tuning. 
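+
+    Each entry is written to the log file immediately (handlers are flushed) and
+    also appended to the in-memory BAD_CONFIGS lists used for the end-of-run
+    summary. For illustration only, with hypothetical values, a call looks like:
+
+        log_bad_config(logger, "timeout", 64, 1024, 3072,
+                       {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_K": 64}, "warmup timed out")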
+ + Args: + logger: Logger instance + error_type: Type of error ('timeout', 'out_of_resources', 'assert_error', 'other_error') + M, N, K: Matrix dimensions + config: Configuration that failed + error_msg: Additional error message + """ + # Use GMT+8 timestamp for consistency + gmt8 = pytz.timezone("Asia/Shanghai") + log_entry = { + "timestamp": datetime.now(gmt8).isoformat(), + "error_type": error_type, + "batch_size": M, + "matrix_dims": f"M={M} N={N} K={K}", + "config": config, + "error_msg": str(error_msg), + } + + # Log to file + logger.info(f"BAD_CONFIG_{error_type.upper()}: {log_entry}") + + # Force flush to write immediately + for handler in logger.handlers: + if hasattr(handler, "stream"): + handler.stream.flush() + + # Store in global list for summary + if error_type == "timeout": + BAD_CONFIGS["timeouts"].append(log_entry) + elif error_type == "out_of_resources": + BAD_CONFIGS["out_of_resources"].append(log_entry) + elif error_type == "assert_error": + BAD_CONFIGS["assert_errors"].append(log_entry) + else: + BAD_CONFIGS["other_errors"].append(log_entry) + + +def log_bad_config_summary(logger: logging.Logger, total_configs_tested: int): + """ + Log a summary of all bad configurations encountered during tuning. + + Args: + logger: Logger instance + total_configs_tested: Total number of configurations tested + """ + total_bad = ( + len(BAD_CONFIGS["timeouts"]) + + len(BAD_CONFIGS["out_of_resources"]) + + len(BAD_CONFIGS["assert_errors"]) + + len(BAD_CONFIGS["other_errors"]) + ) + success_rate = ( + ((total_configs_tested - total_bad) / total_configs_tested * 100) + if total_configs_tested > 0 + else 0 + ) + + logger.info("=" * 80) + logger.info("BAD CONFIGURATION SUMMARY") + logger.info("=" * 80) + logger.info(f"Total configurations tested: {total_configs_tested}") + logger.info(f"Successful configurations: {total_configs_tested - total_bad}") + logger.info(f"Failed configurations: {total_bad}") + logger.info(f"Success rate: {success_rate:.2f}%") + logger.info("") + + logger.info(f"Timeouts: {len(BAD_CONFIGS['timeouts'])}") + logger.info(f"Out of Resources: {len(BAD_CONFIGS['out_of_resources'])}") + logger.info(f"Assert Errors: {len(BAD_CONFIGS['assert_errors'])}") + logger.info(f"Other Errors: {len(BAD_CONFIGS['other_errors'])}") + logger.info("") + + if BAD_CONFIGS["timeouts"]: + logger.info("TIMEOUT CONFIGS (most problematic):") + for entry in BAD_CONFIGS["timeouts"]: + config = entry["config"] + logger.info( + f" - Batch {entry['batch_size']} | {entry['matrix_dims']} | BM:{config.get('BLOCK_SIZE_M', 'N/A')}, BN:{config.get('BLOCK_SIZE_N', 'N/A')}, BK:{config.get('BLOCK_SIZE_K', 'N/A')}, W:{config.get('num_warps', 'N/A')}, S:{config.get('num_stages', 'N/A')}, KS:{config.get('NUM_KSPLIT', 'N/A')}" + ) + + if BAD_CONFIGS["out_of_resources"]: + logger.info("OUT OF RESOURCE CONFIGS:") + for entry in BAD_CONFIGS["out_of_resources"]: + config = entry["config"] + logger.info( + f" - Batch {entry['batch_size']} | {entry['matrix_dims']} | BM:{config.get('BLOCK_SIZE_M', 'N/A')}, BN:{config.get('BLOCK_SIZE_N', 'N/A')}, BK:{config.get('BLOCK_SIZE_K', 'N/A')}, W:{config.get('num_warps', 'N/A')}, S:{config.get('num_stages', 'N/A')}, KS:{config.get('NUM_KSPLIT', 'N/A')}" + ) + + logger.info("=" * 80) + + # Print summary to console as well + print("\n๐Ÿ“Š Bad Configuration Summary:") + print(f" Total tested: {total_configs_tested}") + print(f" โœ… Successful: {total_configs_tested - total_bad}") + print( + f" โŒ Failed: {total_bad} ({len(BAD_CONFIGS['timeouts'])} timeouts, 
{len(BAD_CONFIGS['out_of_resources'])} OOM, {len(BAD_CONFIGS['assert_errors'])} assert, {len(BAD_CONFIGS['other_errors'])} other)" + ) + print(f" ๐Ÿ“ˆ Success rate: {success_rate:.1f}%") + + +def timeout_handler(signum, frame): + """Signal handler for timeout.""" + raise TimeoutError("Kernel execution timed out") + + +def run_with_timeout(func, timeout_seconds=3, *args, **kwargs): + """ + Run a function with a timeout limit. + + Args: + func: Function to execute + timeout_seconds: Timeout in seconds (default: 3) + *args: Arguments to pass to the function + **kwargs: Keyword arguments to pass to the function + + Returns: + Result of the function call + + Raises: + TimeoutError: If function execution exceeds timeout + """ + # Set the signal handler + old_handler = signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(timeout_seconds) + + try: + result = func(*args, **kwargs) + return result + finally: + # Cancel the alarm and restore the old handler + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + +def generate_gemm_a8w8_per_token_scale_inputs(M, N, K, dtype, output=True, bias=False): + """ + Generate inputs for gemm_a8w8_per_token_scale kernel. + + Args: + M, N, K: Matrix dimensions + dtype: Output data type + output: Whether to generate output tensor + bias: Whether to generate bias tensor + + Returns: + Tuple of (x, w, x_scale, w_scale, bias, y) + """ + # Generate input matrix x (M, K) - FP8 E4M3 (matching test file pattern) + x = (torch.rand((M, K), dtype=torch.float16, device="cuda") / 10).to(e4m3_type) + + # Generate weight matrix w (N, K) - FP8 E4M3 (matching test file pattern) + w = (torch.rand((N, K), dtype=torch.float16, device="cuda") / 10).to(e4m3_type) + + # Generate per-token and per-channel scale tensors - 2D [M,1] and [N,1] + x_scale = torch.rand([M, 1], dtype=torch.float32, device="cuda") + w_scale = torch.rand([N, 1], dtype=torch.float32, device="cuda") + + # Generate bias tensor if needed + bias_tensor = None + if bias: + bias_tensor = torch.empty((N), dtype=dtype, device="cuda") + + # Generate output tensor if needed + y = None + if output: + y = torch.empty((M, N), dtype=dtype, device="cuda") + + return x, w, x_scale, w_scale, bias_tensor, y + + +def get_configs_compute_bound() -> List[Dict[str, int | str]]: + """ + Generate configuration space for tuning the gemm_a8w8_per_token_scale kernel. + Focus on parameters that affect performance for this specific kernel. + Comprehensive search space matching atomic kernel patterns. 
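+
+    Each returned entry is a plain dict of Triton launch parameters. With the
+    reduced sweep currently enabled below (the comprehensive sweep is kept
+    commented out), an example entry looks like:
+
+        {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 1, "NUM_KSPLIT": 1,
+         "waves_per_eu": 2, "kpack": 2, "matrix_instr_nonkdim": 16,
+         "cache_modifier": ""}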
+ """ + configs = [] + + # Explore optimized parameter space (removed large block sizes that cause slowdowns) + # for num_stages in [1, 2, 3, 4]: + # for block_m in [32, 64, 128]: # Removed 256 (causes slowdowns) + # for block_n in [32, 64, 128]: # Removed 256 (causes slowdowns) + # for block_k in [64, 128, 256]: + # for group_size in [1, 8, 16]: + # for num_warps in [2, 4, 8]: + # for num_ksplit in [ + # 1, + # 2, + # 4, + # ]: # Key parameter for K-splitting + # for waves_per_eu in [2, 4, 8]: + # for kpack in [2]: + # for cache_modifier in ["", ".cg"]: + # configs.append( + # { + # "BLOCK_SIZE_M": block_m, + # "BLOCK_SIZE_N": block_n, + # "BLOCK_SIZE_K": block_k, + # "GROUP_SIZE_M": group_size, + # "num_warps": num_warps, + # "num_stages": num_stages, + # "NUM_KSPLIT": num_ksplit, + # "waves_per_eu": waves_per_eu, + # "kpack": kpack, + # "matrix_instr_nonkdim": 16, # Fixed value from atomic kernel + # "cache_modifier": cache_modifier, + # } + # ) + # return configs + for num_stages in [ + 1, + ]: + for block_m in [ + 32, + ]: # Removed 256 (causes slowdowns) + for block_n in [ + 32, + ]: # Removed 256 (causes slowdowns) + for block_k in [ + 64, + ]: + for group_size in [ + 1, + ]: + for num_warps in [ + 2, + ]: + for num_ksplit in [ + 1, + ]: # Key parameter for K-splitting + for waves_per_eu in [ + 2, + ]: + for kpack in [2]: + for cache_modifier in ["", ".cg"]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + "NUM_KSPLIT": num_ksplit, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + "matrix_instr_nonkdim": 16, # Fixed value from atomic kernel + "cache_modifier": cache_modifier, + } + ) + return configs + + +def get_weight_shapes(tp_size: int) -> List[Tuple[int, int]]: + """Get weight shapes to test during tuning.""" + total = [ + # (1024, 1024), + # (4096, 1024), + # (1024, 2048), + # (6144, 1024), + (1024, 3072), + ] + + weight_shapes: List[Tuple[int, int]] = [] + for t in total: + weight_shapes.append(t) + + return weight_shapes + + +def run_torch_reference(x, w, x_scale, w_scale, bias, dtype=torch.bfloat16): + """ + Run reference implementation using PyTorch. + This is used for correctness verification. + """ + # Apply scaling as in test file: convert to scale dtype, multiply, then compute + x = x.to(x_scale.dtype) * x_scale + w = w.to(w_scale.dtype) * w_scale + + # Compute the matrix multiplication - note weights need to be transposed for torch linear + out = torch.nn.functional.linear(x.to(torch.float32), w.to(torch.float32)) + + return out.to(dtype) + + +def benchmark_config( + x: torch.Tensor, + w: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + bias: Optional[torch.Tensor], + dtype: torch.dtype, + config: Dict[str, Union[str, int]], + y: Optional[torch.Tensor] = None, + num_iters=10, +) -> float: + """ + Benchmark the performance of a GEMM operation with a specific configuration. + + This function measures the execution time of the gemm_a8w8_per_token_scale kernel by running + it multiple times with synchronization points to ensure accurate timing. It performs + warmup runs before the actual benchmarking to account for JIT compilation overhead. + + Args: + x (torch.Tensor): Input tensor of shape (M, K) representing the first matrix operand. + w (torch.Tensor): Weight tensor of shape (N, K) representing the second matrix operand. + x_scale (torch.Tensor): Per-token scale tensor for x with shape (M, 1). 
+ w_scale (torch.Tensor): Per-output-channel scale tensor for w with shape (N, 1). + dtype (torch.dtype): Data type for the computation (e.g., torch.bfloat16). + config (Dict[str, Union[str, int]]): Configuration dictionary containing kernel + parameters such as block sizes, number of warps, etc. + y (Optional[torch.Tensor], optional): Output tensor to store the result. If None, + a new tensor will be allocated. Defaults to None. + num_iters (int, optional): Number of benchmark iterations to run. Defaults to 10. + + Returns: + float: Average execution time in microseconds (us) per iteration. + + Note: + The function performs 5 warmup iterations before benchmarking to account for + JIT compilation and GPU warmup effects. The timing is measured using CUDA events + for accurate GPU kernel timing. + """ + # Get reference output for correctness verification + torch_out = run_torch_reference(x, w, x_scale, w_scale, bias, dtype) + + # Add SPLITK_BLOCK_SIZE computation as done in the kernel function + _, K = x.shape + _, K = w.shape + num_ksplit = int(config["NUM_KSPLIT"]) + block_k = int(config["BLOCK_SIZE_K"]) + splitk_block_size = triton.cdiv(K, num_ksplit) + + config["SPLITK_BLOCK_SIZE"] = splitk_block_size + if block_k > splitk_block_size: + block_k = triton.next_power_of_2(splitk_block_size) + if block_k > splitk_block_size: + block_k = block_k // 4 + block_k = max(block_k, 16) + config["BLOCK_SIZE_K"] = block_k + + # Run kernel + def run(): + return gemm_a8w8_per_token_scale(x, w, x_scale, w_scale, dtype, y, config) + + torch.cuda.synchronize() + + # JIT compilation & warmup with timeout for entire warmup phase + def run_warmup(): + for i in range(5): + run() + + try: + run_with_timeout(run_warmup, timeout_seconds=3) + except TimeoutError: + # If warmup times out, this config is likely bad, skip it + raise TimeoutError("Warmup phase timed out after 3 seconds") + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + try: + triton_out_raw = run_with_timeout(run, timeout_seconds=3) + except TimeoutError: + # If benchmark iteration times out, skip this config + raise TimeoutError(f"Benchmark iteration {i + 1} timed out after 3 seconds") + + # Convert to the same dtype as the reference for comparison + # Handle the case where triton_out_raw might be None + if triton_out_raw is not None: + triton_out = triton_out_raw.to(torch_out.dtype) + else: + triton_out = torch_out # Fallback to reference output + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg + + +# Global variable to store console output +console_output = [] + + +def create_live_display( + M: int, + N: int, + K: int, + current_config: Dict[str, Union[str, int]], + best_config: Dict[str, Union[str, int]], + best_time: float, + config_index: int, + total_configs: int, + console_messages: Optional[List[str]] = None, +) -> Layout: + """Create a live display layout with current and best configuration tables.""" + + layout = Layout() + + # Use global console_output if none provided + if console_messages is None: + global console_output + console_messages = console_output + + # Create progress bar + progress = Progress( + TextColumn("[bold blue]{task.description}"), + 
BarColumn(bar_width=40), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TextColumn("({task.completed}/{task.total})"), + ) + task = progress.add_task( + f"๐Ÿ”ง Tuning M={M} N={N} K={K} | Batch Size={M}", + total=total_configs, + completed=config_index, + ) + + # Create status information + status_text = "" + if best_time != float("inf"): + status_text = f"๐Ÿ† Best Performance: {best_time:.1f}ฮผs" + else: + status_text = "๐Ÿ” Searching for best configuration..." + + # Create console area + console_text = ( + "\n".join(console_messages[-10:]) + if console_messages + else "Waiting for results..." + ) + console_table = Table(show_header=False, box=None, padding=0) + console_table.add_column("Output", style="white") + console_table.add_row(console_text) + + # Create current config table + config_table = Table( + title="Current Configuration", + show_header=True, + header_style="bold magenta", + ) + config_table.add_column("Parameter", style="cyan", width=15) + config_table.add_column("Value", style="green", width=10) + + # Add matrix dimensions and batch size first + config_table.add_row("[bold yellow]Matrix M[/bold yellow]", str(M), style="yellow") + config_table.add_row("[bold yellow]Matrix N[/bold yellow]", str(N), style="yellow") + config_table.add_row("[bold yellow]Matrix K[/bold yellow]", str(K), style="yellow") + config_table.add_row( + "[bold yellow]Batch Size[/bold yellow]", str(M), style="yellow" + ) + config_table.add_row("", "") # Separator + config_table.add_row("BLOCK_SIZE_M", str(current_config.get("BLOCK_SIZE_M", "N/A"))) + config_table.add_row("BLOCK_SIZE_N", str(current_config.get("BLOCK_SIZE_N", "N/A"))) + config_table.add_row("BLOCK_SIZE_K", str(current_config.get("BLOCK_SIZE_K", "N/A"))) + config_table.add_row("num_warps", str(current_config.get("num_warps", "N/A"))) + config_table.add_row("num_stages", str(current_config.get("num_stages", "N/A"))) + config_table.add_row("NUM_KSPLIT", str(current_config.get("NUM_KSPLIT", "N/A"))) + config_table.add_row("waves_per_eu", str(current_config.get("waves_per_eu", "N/A"))) + config_table.add_row("kpack", str(current_config.get("kpack", "N/A"))) + config_table.add_row( + "cache_modifier", str(current_config.get("cache_modifier", "N/A")) + ) + config_table.add_row("GROUP_SIZE_M", str(current_config.get("GROUP_SIZE_M", "N/A"))) + + # Create best config table if we have a best configuration + best_config_table = None + if best_time != float("inf"): + best_config_table = Table( + title="๐Ÿ† Best Configuration So Far", + show_header=True, + header_style="bold green", + ) + best_config_table.add_column("Parameter", style="cyan", width=15) + best_config_table.add_column("Value", style="green", width=10) + + # Add performance and matrix dimensions + best_config_table.add_row( + "[bold green]Performance[/bold green]", f"{best_time:.1f}ฮผs", style="green" + ) + best_config_table.add_row("", "") # Separator + best_config_table.add_row( + "[bold yellow]BLOCK_SIZE_M[/bold yellow]", + str(best_config.get("BLOCK_SIZE_M", "N/A")), + style="yellow", + ) + best_config_table.add_row( + "[bold yellow]BLOCK_SIZE_N[/bold yellow]", + str(best_config.get("BLOCK_SIZE_N", "N/A")), + style="yellow", + ) + best_config_table.add_row( + "[bold yellow]BLOCK_SIZE_K[/bold yellow]", + str(best_config.get("BLOCK_SIZE_K", "N/A")), + style="yellow", + ) + best_config_table.add_row("num_warps", str(best_config.get("num_warps", "N/A"))) + best_config_table.add_row( + "num_stages", str(best_config.get("num_stages", "N/A")) + ) + 
best_config_table.add_row( + "NUM_KSPLIT", str(best_config.get("NUM_KSPLIT", "N/A")) + ) + best_config_table.add_row( + "waves_per_eu", str(best_config.get("waves_per_eu", "N/A")) + ) + best_config_table.add_row("kpack", str(best_config.get("kpack", "N/A"))) + best_config_table.add_row( + "cache_modifier", str(best_config.get("cache_modifier", "N/A")) + ) + best_config_table.add_row( + "GROUP_SIZE_M", str(best_config.get("GROUP_SIZE_M", "N/A")) + ) + + # Create combined layout + if best_config_table: + # Display tables side by side + tables = Columns([config_table, best_config_table], equal=True, expand=True) + layout.split_column( + Layout(Panel(progress, title="Progress", border_style="blue"), size=5), + Layout(Panel(status_text, title="Status", border_style="green"), size=3), + Layout(tables), + Layout( + Panel(console_table, title="Console Output", border_style="cyan"), + size=10, + ), + ) + else: + # Display only current config + layout.split_column( + Layout(Panel(progress, title="Progress", border_style="blue"), size=5), + Layout(Panel(status_text, title="Status", border_style="green"), size=3), + Layout(config_table), + Layout( + Panel(console_table, title="Console Output", border_style="cyan"), + size=10, + ), + ) + + return layout + + +def tune( + M: int, + N: int, + K: int, + search_space: List[Dict[str, int | str]], + input_type: str, + logger: logging.Logger, +): + """Tune the kernel for specific matrix dimensions.""" + # Register SIGINT handler if not already registered + if not signal.getsignal(signal.SIGINT) == sigint_handler: + signal.signal(signal.SIGINT, sigint_handler) + + if input_type == "bfloat16": + # Use the same input generation as test file + x, w, x_scale, w_scale, bias, y = generate_gemm_a8w8_per_token_scale_inputs( + M, N, K, torch.bfloat16, bias=True + ) + else: + raise RuntimeError( + "Currently, only support tune a8w8 per-token scale kernel with bfloat16 output." 
+ ) + + best_config: Dict[str, Union[int, str]] = {} + best_time = float("inf") + slow_config_threshold = ( + 1000 # microseconds - configs slower than this get highlighted + ) + + # Clear console output for fresh start + global console_output + console_output = [] + + # Initialize Rich console for better formatting + console = Console() + + # Create initial live display + initial_layout = create_live_display( + M, + N, + K, + search_space[0] if search_space else {}, + best_config, + best_time, + 0, + len(search_space), + console_output, + ) + + # Create progress display with Rich and Live + with Live(initial_layout, refresh_per_second=4, console=console) as live: + for i, config in enumerate(search_space): + # Check if we were interrupted + if INTERRUPTED: + break + + # Update global state for SIGINT handling + CURRENT_CONFIG.update( + { + "M": M, + "N": N, + "K": K, + "config": config, + "config_index": i, + "total_configs": len(search_space), + "batch_size": M, + } + ) + + # Update live display with current configuration + layout = create_live_display( + M, + N, + K, + config, + best_config, + best_time, + i + 1, + len(search_space), + console_output, + ) + live.update(layout) + + try: + kernel_time = benchmark_config( + x=x, + w=w, + x_scale=x_scale, + w_scale=w_scale, + bias=bias, + dtype=torch.bfloat16, + y=None, + config=config, + num_iters=10, + ) + + # Add kernel time to console output + if kernel_time > slow_config_threshold: + console_msg = f"[yellow]โš ๏ธ Config {i + 1} (SLOW): {kernel_time:.1f}ฮผs | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')}[/yellow]" + live.console.print(console_msg) + console_output.append(console_msg) + else: + console_msg = f"[green]โœ… Config {i + 1}: {kernel_time:.1f}ฮผs | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')}[/green]" + console_output.append(console_msg) + + # Update best time and config + if kernel_time < best_time: + best_time = kernel_time + best_config = config + best_msg = ( + f"[bold green]๐Ÿ† NEW BEST: {kernel_time:.1f}ฮผs![/bold green]" + ) + console_output.append(best_msg) + + # Update live display with current configuration and console output + layout = create_live_display( + M, + N, + K, + config, + best_config, + best_time, + i + 1, + len(search_space), + console_output, + ) + live.update(layout) + + except triton.runtime.autotuner.OutOfResources as e: + # Log and skip out of resources configurations + log_bad_config(logger, "out_of_resources", M, N, K, config, str(e)) + error_msg = f"[red]โŒ Config {i + 1} (OOM): Out of resources | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')}[/red]" + console_output.append(error_msg) + live.console.print(error_msg) + continue + except AssertionError as e: + # Log and skip assert error configurations + log_bad_config(logger, "assert_error", M, N, K, config, str(e)) + error_msg = f"[red]โŒ Config {i + 1} (ASSERT): Assert error | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} | {e}[/red]" + console_output.append(error_msg) + live.console.print(error_msg) + 
continue + except TimeoutError as e: + # Log and skip timeout configurations + log_bad_config(logger, "timeout", M, N, K, config, str(e)) + error_msg = f"[orange1]โฑ๏ธ Config {i + 1} (TIMEOUT): {e} | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')}[/orange1]" + console_output.append(error_msg) + live.console.print(error_msg) + continue + except Exception as e: + # Log and skip other error configurations + log_bad_config(logger, "other_error", M, N, K, config, str(e)) + error_msg = f"[red]๐Ÿ’ฅ Config {i + 1} (ERROR): {e} | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')}[/red]" + console_output.append(error_msg) + live.console.print(error_msg) + continue + + # Show final completion message with Rich + print("\n" + "=" * 70) + + # Create best config table with matrix dimensions + best_table = Table( + title="๐Ÿ† Best Configuration Found", show_header=True, header_style="bold green" + ) + best_table.add_column("Parameter", style="cyan", width=15) + best_table.add_column("Value", style="green", width=10) + + # Add matrix dimensions and batch size first + best_table.add_row( + "[bold yellow]Matrix M (Batch)[/bold yellow]", str(M), style="yellow" + ) + best_table.add_row("[bold yellow]Matrix N[/bold yellow]", str(N), style="yellow") + best_table.add_row("[bold yellow]Matrix K[/bold yellow]", str(K), style="yellow") + best_table.add_row("", "") # Separator + best_table.add_row( + "[bold green]Performance[/bold green]", f"{best_time:.1f}ฮผs", style="green" + ) + best_table.add_row("", "") # Separator + best_table.add_row( + "[bold yellow]BLOCK_SIZE_M[/bold yellow]", + str(best_config.get("BLOCK_SIZE_M", "N/A")), + style="yellow", + ) + best_table.add_row( + "[bold yellow]BLOCK_SIZE_N[/bold yellow]", + str(best_config.get("BLOCK_SIZE_N", "N/A")), + style="yellow", + ) + best_table.add_row( + "[bold yellow]BLOCK_SIZE_K[/bold yellow]", + str(best_config.get("BLOCK_SIZE_K", "N/A")), + style="yellow", + ) + best_table.add_row("num_warps", str(best_config.get("num_warps", "N/A"))) + best_table.add_row("num_stages", str(best_config.get("num_stages", "N/A"))) + best_table.add_row("NUM_KSPLIT", str(best_config.get("NUM_KSPLIT", "N/A"))) + best_table.add_row("waves_per_eu", str(best_config.get("waves_per_eu", "N/A"))) + best_table.add_row("kpack", str(best_config.get("kpack", "N/A"))) + best_table.add_row("cache_modifier", str(best_config.get("cache_modifier", "N/A"))) + best_table.add_row("GROUP_SIZE_M", str(best_config.get("GROUP_SIZE_M", "N/A"))) + best_table.add_row( + "matrix_instr_nonkdim", str(best_config.get("matrix_instr_nonkdim", "N/A")) + ) + + completion_panel = Panel( + best_table, + title=f"[bold green]โœ… Completed Tuning for M={M} N={N} K={K} (Batch Size={M})[/bold green]", + border_style="green", + ) + console.print(completion_panel) + print("=" * 70) + + assert best_config is not None + return best_config + + +def save_configs( + N, + K, + configs, + save_path, + is_incremental=False, + completed_batch_sizes=None, +) -> None: + """Save the best configurations to a JSON file.""" + os.makedirs(save_path, exist_ok=True) + device_name = "R9700" # TODO: Hardcoded, make it dynamic + + if is_incremental: + # Save incremental progress with batch size info in filename + batch_sizes_str = ( + "_".join(map(str, completed_batch_sizes)) 
+ if completed_batch_sizes + else "partial" + ) + json_file_name = f"{device_name}-GEMM-A8W8_PER_TOKEN_SCALE-N={N}-K={K}_batch_{batch_sizes_str}.json" + progress_file = os.path.join( + save_path, + f"{device_name}-GEMM-A8W8_PER_TOKEN_SCALE-N={N}-K={K}_progress.json", + ) + + # Save progress info + progress_info = { + "completed_batch_sizes": completed_batch_sizes or [], + "configs": configs, + "last_updated": datetime.now().isoformat(), + } + + with open(progress_file, "w") as f: + json.dump(progress_info, f, indent=4) + f.write("\n") + else: + json_file_name = f"{device_name}-GEMM-A8W8_PER_TOKEN_SCALE-N={N}-K={K}.json" + + config_file_path = os.path.join(save_path, json_file_name) + + # Add incremental flag to filename + action = "Updating incremental" if is_incremental else "Writing" + print(f"{action} config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def load_progress(N, K, save_path): + """Load previously saved progress for a given N,K configuration.""" + device_name = "R9700" # TODO: Hardcoded, make it dynamic + progress_file = os.path.join( + save_path, f"{device_name}-GEMM-A8W8_PER_TOKEN_SCALE-N={N}-K={K}_progress.json" + ) + + if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + progress_info = json.load(f) + return progress_info.get("completed_batch_sizes", []), progress_info.get( + "configs", {} + ) + except Exception as e: + print(f"Warning: Could not load progress file {progress_file}: {e}") + return [], {} + return [], {} + + +def tune_on_gpu( + gpu_id: int, + batch_sizes: List[int], + weight_shapes: List[Tuple[int, int]], + input_type: str, + resume: bool = True, + log_filename: Optional[str] = None, +) -> None: + """Run tuning on a specific GPU.""" + # Register SIGINT handler and set GPU ID in global state + signal.signal(signal.SIGINT, sigint_handler) + CURRENT_CONFIG["gpu_id"] = gpu_id + + torch.cuda.set_device(gpu_id) + print(f"๐Ÿš€ Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + + save_path = AITER_TRITON_CONFIGS_PATH + "/gemm/" + + # Setup logger for this GPU with custom or timestamped filename + if log_filename: + # Use custom filename, ensure it has .log extension + if not log_filename.endswith(".log"): + log_filename += ".log" + # If no path separator, assume it's just a filename + if "/" not in log_filename and "\\" not in log_filename: + log_filename = os.path.join(save_path, log_filename) + else: + log_filename = log_filename # Use full path as provided + else: + # Fall back to timestamped filename + log_filename = os.path.join( + save_path, + get_timestamped_filename( + f"tune_a8w8_per_token_scale_bad_configs_gpu{gpu_id}" + ), + ) + + # Choose appropriate logging mode: append for resume, overwrite for fresh start + log_mode = "a" if resume else "w" + logger = setup_logger(log_filename, mode=log_mode) + + # Log the start time in GMT+8 + gmt8 = pytz.timezone("Asia/Shanghai") + start_time_gmt8 = datetime.now(gmt8).strftime("%Y-%m-%d %H:%M:%S") + logger.info(f"=== TUNING SESSION STARTED AT {start_time_gmt8} [GMT+8] ===") + logger.info(f"GPU: {gpu_id}") + logger.info(f"Batch sizes: {batch_sizes}") + + search_space = get_configs_compute_bound() + total_configs = len(search_space) + total_tests = total_configs * len(batch_sizes) * len(weight_shapes) + + print(f" ๐Ÿ“Š Search space: {total_configs:,} configurations") + print(f" ๐ŸŽฏ Total tests to run: {total_tests:,}") + print( + f" โšก Estimated tests per weight shape: {total_configs * 
len(batch_sizes):,}" + ) + log_action = ( + "Appending to existing" + if resume and os.path.exists(log_filename) + else "Writing to new" + ) + print(f" ๐Ÿ“ Bad configurations will be logged to: {log_filename}") + print(f" ๐Ÿ“ Logging mode: {log_action}") + + start = time.time() + + # Collect all configs to determine the best overall config + all_configs: List[Dict[str, Dict[str, int | str]]] = [] + + for i, shape in enumerate(weight_shapes): + # Check if we were interrupted + if INTERRUPTED: + break + + # Update weight shape tracking + CURRENT_CONFIG.update( + {"weight_shape_index": i, "total_weight_shapes": len(weight_shapes)} + ) + + N, K = shape[0], shape[1] + print( + f"\n๐Ÿš€ [GPU {gpu_id}] Shape {i + 1}/{len(weight_shapes)}: Starting tuning for N:{N}, K:{K}" + ) + print( + f" ๐Ÿ“Š Testing {len(search_space):,} configurations across {len(batch_sizes)} batch sizes" + ) + + # Check for existing progress and resume from there (if resume is enabled) + if resume: + completed_batch_sizes, existing_configs = load_progress(N, K, save_path) + else: + completed_batch_sizes, existing_configs = [], {} + + # Filter batch_sizes to only those not yet completed + remaining_batch_sizes = [ + bs for bs in batch_sizes if bs not in completed_batch_sizes + ] + + if completed_batch_sizes and resume: + print(f"\n ๐Ÿ“‚ [GPU {gpu_id}] Found progress for N={N}, K={K}") + print(f" โœ… Already completed batch sizes: {completed_batch_sizes}") + print(f" ๐Ÿ”„ Remaining batch sizes to tune: {remaining_batch_sizes}") + elif not resume: + print( + f"\n ๐Ÿ”„ [GPU {gpu_id}] Starting fresh (resume disabled) for N={N}, K={K}" + ) + elif not remaining_batch_sizes: + print( + f"\n โœ… [GPU {gpu_id}] All batch sizes already completed for N={N}, K={K}" + ) + # Add existing configs to all_configs and continue to next shape + if existing_configs: + all_configs.append(existing_configs) + save_configs(N, K, existing_configs, save_path) + continue + + # Initialize benchmark_results with existing results if any + benchmark_results :List[Dict[str, str |int]]= [] + if existing_configs: + # Reconstruct benchmark_results from existing configs + # We need to map the configs back to their corresponding batch sizes + for i, batch_size in enumerate(batch_sizes): + if batch_size in completed_batch_sizes: + # Find the config for this batch size + config_to_add = None + + # Try to find matching config based on batch size category + if batch_size < 32 and "small" in existing_configs: + config_to_add = existing_configs["small"] + elif batch_size <= 128: + BLK_M = triton.next_power_of_2(batch_size) + if BLK_M == 32 and "medium_M32" in existing_configs: + config_to_add = existing_configs["medium_M32"] + elif BLK_M == 64 and "medium_M64" in existing_configs: + config_to_add = existing_configs["medium_M64"] + elif BLK_M == 128 and "medium_M128" in existing_configs: + config_to_add = existing_configs["medium_M128"] + elif batch_size <= 256 and "large" in existing_configs: + config_to_add = existing_configs["large"] + elif batch_size > 256 and "xlarge" in existing_configs: + config_to_add = existing_configs["xlarge"] + + if config_to_add: + benchmark_results.append(config_to_add) + else: + # If we couldn't find a matching config, we'll need to retune this batch size + remaining_batch_sizes.append(batch_size) + + for batch_size in remaining_batch_sizes: + # Check if we were interrupted + if INTERRUPTED: + break + + print( + f"\n ๐Ÿ” [GPU {gpu_id}] Testing batch size M={batch_size} for N={N}, K={K}" + ) + result = tune( + batch_size, + N, + K, + 
search_space, + input_type, + logger, + ) + + # Check if tune() was interrupted + if INTERRUPTED: + break + + benchmark_results.append(result) + + # Save incremental progress immediately after each batch size + updated_completed_batch_sizes = completed_batch_sizes + [batch_size] + + # Create configs for different M size categories as expected by the kernel + incremental_configs: Dict[str, Dict[str, int | str]] = {} + for i, (M, config) in enumerate( + zip(batch_sizes[: len(benchmark_results)], benchmark_results) + ): + if i == len(batch_sizes[: len(benchmark_results)]) - 1: + incremental_configs["any"] = config + elif M < 32: + incremental_configs["small"] = config + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32: + incremental_configs["medium_M32"] = config + elif BLK_M == 64: + incremental_configs["medium_M64"] = config + elif BLK_M == 128: + incremental_configs["medium_M128"] = config + elif M <= 256: + incremental_configs["large"] = config + else: + incremental_configs["xlarge"] = config + + # Save the incremental progress + save_configs( + N, + K, + incremental_configs, + save_path, + is_incremental=True, + completed_batch_sizes=updated_completed_batch_sizes, + ) + + print(f" ๐Ÿ’พ [GPU {gpu_id}] Saved progress for batch size {batch_size}") + + # Update completed_batch_sizes for next iteration + completed_batch_sizes = updated_completed_batch_sizes + + # Create final configs for different M size categories as expected by the kernel + best_configs: Dict[str, Dict[str, int | str]] = {} + for i, (M, config) in enumerate(zip(batch_sizes, benchmark_results)): + if i == len(batch_sizes) - 1: + best_configs["any"] = config + elif M < 32: + best_configs["small"] = config + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32: + best_configs["medium_M32"] = config + elif BLK_M == 64: + best_configs["medium_M64"] = config + elif BLK_M == 128: + best_configs["medium_M128"] = config + elif M <= 256: + best_configs["large"] = config + else: + best_configs["xlarge"] = config + + # Store configs for later analysis + all_configs.append(best_configs) + + # Save the final complete config (non-incremental) + save_configs(N, K, best_configs, save_path) + + # Clean up progress file since we completed successfully + device_name = "R9700" # TODO: Hardcoded, make it dynamic + progress_file = os.path.join( + save_path, + f"{device_name}-GEMM-A8W8_PER_TOKEN_SCALE-N={N}-K={K}_progress.json", + ) + if os.path.exists(progress_file): + os.remove(progress_file) + print(f" ๐Ÿงน [GPU {gpu_id}] Cleaned up progress file for N={N}, K={K}") + + # Create a default config file (without N,K parameters) by selecting the most common config + default_config = create_default_config(all_configs) + save_default_config(default_config, save_path) + + end = time.time() + + # Log session end time in GMT+8 + gmt8 = pytz.timezone("Asia/Shanghai") + end_time_gmt8 = datetime.now(gmt8).strftime("%Y-%m-%d %H:%M:%S") + duration = end - start + logger.info(f"=== TUNING SESSION COMPLETED AT {end_time_gmt8} [GMT+8] ===") + logger.info(f"Total duration: {duration:.2f} seconds") + + # Log summary of bad configurations + log_bad_config_summary(logger, total_tests) + + print(f"Tuning on GPU {gpu_id} took {duration:.2f} seconds") + + +def create_default_config( + all_configs: List[Dict[str, Dict[str, Union[int, str]]]], +) -> Dict[str, Dict[str, Union[int, str]]]: + """Create a default config by selecting the most common config across all shapes.""" + from collections import Counter + + # Collect all configs for 
each category + category_configs = { + "small": [], + "medium_M32": [], + "medium_M64": [], + "medium_M128": [], + "large": [], + "xlarge": [], + "any": [], + } + + for config in all_configs: + for category, params in config.items(): + if category in category_configs: + # Convert config to a hashable tuple for counting + config_tuple = tuple(sorted(params.items())) + category_configs[category].append(config_tuple) + + # Find the most common config for each category + default_config: Dict[str, Dict[str, Union[int, str]]] = {} + for category, configs in category_configs.items(): + if configs: + most_common = Counter(configs).most_common(1)[0][0] + default_config[category] = dict(most_common) + + return default_config + + +def save_default_config( + config: Dict[str, Dict[str, Union[int, str]]], save_path: str +) -> None: + """Save the default config file (without N,K parameters).""" + os.makedirs(save_path, exist_ok=True) + device_name = "R9700" # TODO: Hardcoded, make it dynamic + json_file_name = f"{device_name}-GEMM-A8W8_PER_TOKEN_SCALE.json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing default config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(config, f, indent=4) + f.write("\n") + + +def distribute_batch_sizes(batch_sizes: List[int], num_gpus: int) -> List[List[int]]: + """Distribute batch sizes across available GPUs.""" + batches_per_gpu: List[List[int]] = [] + for i in range(num_gpus): + start_idx = i * len(batch_sizes) // num_gpus + end_idx = (i + 1) * len(batch_sizes) // num_gpus + batches_per_gpu.append(batch_sizes[start_idx:end_idx]) + return batches_per_gpu + + +def main(args): + print(args) + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + raise RuntimeError("No GPU available for tuning") + print(f"Found {num_gpus} GPUs for parallel tuning") + + torch.cuda.init() + + if args.batch_size is None: + batch_sizes = [ + 16, # For small config + 32, # For medium_M32 config + 64, # For medium_M64 config + 128, # For medium_M128 config + 256, # For large config + 512, # For large config + 2048, # For xlarge config + 4096, # For xlarge config + ] + else: + batch_sizes = [args.batch_size] + num_gpus = 1 # If only one batch size, use only one GPU + + weight_shapes = get_weight_shapes(args.tp_size) + + batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus) + + # Prepare arguments for each GPU process + process_args = [] + for gpu_id in range(num_gpus): + process_args.append( + ( + gpu_id, + batches_per_gpu[gpu_id], + weight_shapes, # Each GPU processes all weight shapes + args.input_type, + args.resume, + args.log_filename, + ) + ) + + # Set up signal handler for main process to gracefully terminate workers + def main_sigint_handler(signum, frame): # type: ignore + print("\n" + "=" * 80) + print("๐Ÿ›‘ MAIN PROCESS INTERRUPTED BY USER (Ctrl+C)") + print("๐Ÿ“ก Sending termination signal to worker processes...") + print("โณ Giving workers 3 seconds to log their current state...") + print("=" * 80) + # Set a flag for workers to check and give them time to cleanup + global INTERRUPTED + INTERRUPTED = True + import time + + time.sleep(3) # Give workers time to handle the signal and log + sys.exit(1) + + # Register main process signal handler + if not signal.getsignal(signal.SIGINT) == main_sigint_handler: + signal.signal(signal.SIGINT, main_sigint_handler) + + ctx = mp.get_context("spawn") + try: + with ctx.Pool(num_gpus) as pool: + pool.starmap(tune_on_gpu, process_args) + except KeyboardInterrupt: + 
print("\n๐Ÿ›‘ Keyboard interrupt received in main process") + print("๐Ÿ“ก Worker processes terminated") + sys.exit(1) + except Exception as e: + print(f"\nโŒ Error in main process: {e}") + sys.exit(1) + + print("Multi-GPU tuning completed") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument("--tp-size", "-tp", type=int, default=1) + parser.add_argument( + "--input-type", type=str, choices=["bfloat16"], default="bfloat16" + ) + parser.add_argument( + "--out-dtype", + type=str, + choices=["float32", "float16", "bfloat16", "half"], + default="bfloat16", + ) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument( + "--log-filename", + type=str, + default=None, + help="Custom log filename (without .log extension). If not provided, timestamped filename will be used.", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Disable resume functionality and start fresh tuning", + ) + args = parser.parse_args() + + # Convert no_resume flag to resume boolean + args.resume = not args.no_resume + + main(args) From e1eca714d4cb36d1379b427cc244b7520be31f62 Mon Sep 17 00:00:00 2001 From: big-yellow-duck Date: Wed, 10 Dec 2025 08:34:12 +0000 Subject: [PATCH 09/10] tuning for a8w8blockscale --- .../ops/triton/_triton_kernels/gemm_a16w16.py | 1 + ...00-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json | 105 ++- ...8W8_BLOCKSCALE-N=1024-K=1024_batch_16.json | 18 + ..._BLOCKSCALE-N=1024-K=1024_batch_16_32.json | 34 + ...OCKSCALE-N=1024-K=1024_batch_16_32_64.json | 50 + ...CALE-N=1024-K=1024_batch_16_32_64_128.json | 66 ++ ...-N=1024-K=1024_batch_16_32_64_128_256.json | 82 ++ ...024-K=1024_batch_16_32_64_128_256_512.json | 98 ++ ...=1024_batch_16_32_64_128_256_512_2048.json | 114 +++ ..._batch_16_32_64_128_256_512_2048_4096.json | 114 +++ ...8W8_BLOCKSCALE-N=1024-K=2048_batch_16.json | 18 + ..._BLOCKSCALE-N=1024-K=2048_batch_16_32.json | 34 + ...OCKSCALE-N=1024-K=2048_batch_16_32_64.json | 50 + ...CALE-N=1024-K=2048_batch_16_32_64_128.json | 66 ++ ...8W8_BLOCKSCALE-N=1024-K=2048_progress.json | 75 ++ ...00-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json | 107 ++- ...8W8_BLOCKSCALE-N=4096-K=1024_batch_16.json | 18 + ..._BLOCKSCALE-N=4096-K=1024_batch_16_32.json | 34 + ...OCKSCALE-N=4096-K=1024_batch_16_32_64.json | 50 + ...CALE-N=4096-K=1024_batch_16_32_64_128.json | 66 ++ ...-N=4096-K=1024_batch_16_32_64_128_256.json | 82 ++ ...096-K=1024_batch_16_32_64_128_256_512.json | 98 ++ ...=1024_batch_16_32_64_128_256_512_2048.json | 114 +++ ..._batch_16_32_64_128_256_512_2048_4096.json | 114 +++ ...MM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json | 86 +- ...ER_TOKEN_SCALE-N=1024-K=3072_batch_16.json | 12 +- ...TOKEN_SCALE-N=1024-K=3072_batch_16_32.json | 26 +- ...EN_SCALE-N=1024-K=3072_batch_16_32_64.json | 38 +- ...CALE-N=1024-K=3072_batch_16_32_64_128.json | 48 +- ...-N=1024-K=3072_batch_16_32_64_128_256.json | 60 +- ...024-K=3072_batch_16_32_64_128_256_512.json | 74 +- ...=3072_batch_16_32_64_128_256_512_2048.json | 88 +- ..._batch_16_32_64_128_256_512_2048_4096.json | 86 +- .../gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json | 86 +- aiter/ops/triton/tune_a8w8_blockscale.py | 884 ++++++++++++++---- aiter/ops/triton/tune_a8w8_per_token_scale.py | 104 +-- 36 files changed, 2633 insertions(+), 567 deletions(-) create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16.json create mode 100644 
aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048_4096.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64_128.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_progress.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048.json create mode 100644 aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048_4096.json diff --git a/aiter/ops/triton/_triton_kernels/gemm_a16w16.py b/aiter/ops/triton/_triton_kernels/gemm_a16w16.py index 9653e16c91..4df1d94985 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a16w16.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a16w16.py @@ -274,6 +274,7 @@ def _get_config( fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-A16W16-N={N}-K={K}.json" if os.path.exists(fpath): _LOGGER.info(f"Loading specific GEMM config from: {fpath}") + print('config path', fpath) with open(fpath, "r") as file: config = json.load(file) _get_config._config_dict[key] = config diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json index 5b8b362d82..010661e88c 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json @@ -1,15 +1,114 @@ { - "any": { + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 
128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, "waves_per_eu": 2, + "kpack": 2, "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, "NUM_KSPLIT": 1, + "waves_per_eu": 8, "kpack": 2, - "cache_modifier": ".cg" + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16.json new file mode 100644 index 0000000000..a1913185c8 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16.json @@ -0,0 +1,18 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32.json new file mode 100644 index 0000000000..8fd3a5c3a4 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32.json @@ -0,0 +1,34 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + 
"cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64.json new file mode 100644 index 0000000000..6f1641f003 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64.json @@ -0,0 +1,50 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128.json new file mode 100644 index 0000000000..a275ac85f7 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128.json @@ -0,0 +1,66 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256.json new file mode 100644 index 0000000000..ca20e4c6fd --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256.json @@ -0,0 +1,82 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512.json new file mode 100644 index 0000000000..61c2f0a956 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512.json @@ -0,0 +1,98 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { 
+ "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048.json new file mode 100644 index 0000000000..538ec8bdb1 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048.json @@ -0,0 +1,114 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048_4096.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048_4096.json new file mode 100644 index 0000000000..010661e88c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048_4096.json @@ -0,0 +1,114 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + 
"NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16.json new file mode 100644 index 0000000000..f0401b39de --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16.json @@ -0,0 +1,18 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32.json new file mode 100644 index 0000000000..c6168c8dfe --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32.json @@ -0,0 +1,34 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64.json new file mode 100644 index 0000000000..852626e62c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64.json @@ -0,0 +1,50 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64_128.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64_128.json new file mode 100644 index 0000000000..5cf3dd7af5 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64_128.json @@ -0,0 +1,66 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_progress.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_progress.json new file mode 100644 index 0000000000..b85ccf2329 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_progress.json @@ -0,0 
+1,75 @@ +{ + "completed_batch_sizes": [ + 16, + 32, + 64, + 128 + ], + "configs": { + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + } + }, + "last_updated": "2025-12-09T16:50:24.402167" +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json index a03fd4ef76..67408d99fb 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json @@ -1,15 +1,114 @@ { - "any": { - "BLOCK_SIZE_M": 64, + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, "waves_per_eu": 2, + "kpack": 2, "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 2, 
+ "num_stages": 1, "NUM_KSPLIT": 1, + "waves_per_eu": 8, "kpack": 2, - "cache_modifier": ".cg" + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16.json new file mode 100644 index 0000000000..530370c6e9 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16.json @@ -0,0 +1,18 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32.json new file mode 100644 index 0000000000..fc7e680f33 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32.json @@ -0,0 +1,34 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64.json new file mode 100644 index 0000000000..07b4ee6f63 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64.json @@ -0,0 +1,50 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git 
a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128.json new file mode 100644 index 0000000000..db03b10ab6 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128.json @@ -0,0 +1,66 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256.json new file mode 100644 index 0000000000..24d51459eb --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256.json @@ -0,0 +1,82 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + 
"cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512.json new file mode 100644 index 0000000000..6966dfeab9 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512.json @@ -0,0 +1,98 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048.json new file mode 100644 index 0000000000..ac35d393df --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048.json @@ -0,0 +1,114 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048_4096.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048_4096.json new file mode 100644 index 0000000000..67408d99fb --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048_4096.json @@ -0,0 +1,114 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 
128 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json index e808bfd22f..8a704a8fab 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json @@ -2,54 +2,54 @@ "small": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M32": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M64": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M128": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", @@ -57,41 +57,41 @@ }, "large": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 8, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "xlarge": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "any": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - 
"GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16.json index f39c54a983..5f96695c9d 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16.json @@ -2,15 +2,15 @@ "any": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32.json index 835e37a7f7..b1a818f586 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32.json @@ -2,29 +2,29 @@ "small": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "any": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64.json index 7f8b5719d1..12f7ea9bdd 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64.json @@ -2,43 +2,43 @@ "small": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M32": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + 
"cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "any": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 } } diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128.json index 6495d74112..e99366c468 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128.json @@ -2,54 +2,54 @@ "small": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M32": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M64": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "any": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256.json index 23396fa7b5..adfa60a35f 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256.json @@ -2,54 +2,54 @@ "small": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M32": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M64": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M128": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", @@ -57,13 +57,13 @@ }, "any": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 8, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512.json index 696cc82039..ff4a63e228 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512.json @@ -2,54 +2,54 @@ "small": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M32": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M64": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M128": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", @@ -57,27 +57,27 @@ }, "large": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, 
- "num_stages": 1, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 8, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "any": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048.json index e808bfd22f..1542ef92a6 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048.json @@ -2,54 +2,54 @@ "small": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M32": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M64": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M128": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", @@ -57,41 +57,41 @@ }, "large": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 8, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "xlarge": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "SPLITK_BLOCK_SIZE": 
3072 }, "any": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048_4096.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048_4096.json index e808bfd22f..8a704a8fab 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048_4096.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048_4096.json @@ -2,54 +2,54 @@ "small": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M32": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M64": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", + "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "medium_M128": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 8, + "num_stages": 4, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", @@ -57,41 +57,41 @@ }, "large": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 8, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "xlarge": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, "NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", "SPLITK_BLOCK_SIZE": 3072 }, "any": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, 
"NUM_KSPLIT": 1, - "waves_per_eu": 2, + "waves_per_eu": 4, "kpack": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "", diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json index 5df51493e2..dee783b5dc 100644 --- a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json @@ -1,49 +1,49 @@ { "small": { - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 256, "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, - "cache_modifier": ".cg", + "cache_modifier": "", "kpack": 2, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 2, - "waves_per_eu": 2 + "num_stages": 4, + "num_warps": 8, + "waves_per_eu": 4 }, "medium_M32": { - "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 8, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, - "cache_modifier": ".cg", + "cache_modifier": "", "kpack": 2, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 2, - "waves_per_eu": 2 + "num_stages": 3, + "num_warps": 8, + "waves_per_eu": 4 }, "medium_M64": { - "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, - "cache_modifier": ".cg", + "cache_modifier": "", "kpack": 2, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 2, - "waves_per_eu": 2 + "num_stages": 4, + "num_warps": 8, + "waves_per_eu": 4 }, "medium_M128": { - "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "GROUP_SIZE_M": 1, "NUM_KSPLIT": 1, @@ -51,50 +51,50 @@ "cache_modifier": "", "kpack": 2, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 2, - "waves_per_eu": 2 + "num_stages": 4, + "num_warps": 8, + "waves_per_eu": 4 }, "large": { - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, "cache_modifier": "", "kpack": 2, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 2, - "waves_per_eu": 2 + "num_stages": 3, + "num_warps": 4, + "waves_per_eu": 8 }, "xlarge": { - "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, "cache_modifier": "", "kpack": 2, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 2, - "waves_per_eu": 2 + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 4 }, "any": { "BLOCK_SIZE_K": 64, - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 3072, "cache_modifier": "", "kpack": 2, "matrix_instr_nonkdim": 16, - "num_stages": 1, - "num_warps": 2, - "waves_per_eu": 2 + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 4 } } diff --git a/aiter/ops/triton/tune_a8w8_blockscale.py b/aiter/ops/triton/tune_a8w8_blockscale.py index 4edd45d72a..8106fd8199 100644 --- a/aiter/ops/triton/tune_a8w8_blockscale.py +++ b/aiter/ops/triton/tune_a8w8_blockscale.py @@ -4,23 +4,20 @@ import multiprocessing as mp import os import signal +import sys import time import 
triton from datetime import datetime from typing import List, Dict, Union, Tuple, Optional, Any import torch +import pytz from rich.console import Console -from rich.progress import ( - Progress, - SpinnerColumn, - TextColumn, - BarColumn, - TaskProgressColumn, - TimeElapsedColumn, -) from rich.table import Table from rich.panel import Panel -from tqdm import tqdm +from rich.columns import Columns +from rich.live import Live +from rich.layout import Layout +from rich.progress import Progress, BarColumn, TextColumn from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale # type: ignore @@ -64,9 +61,38 @@ class TimeoutError(Exception): INTERRUPTED = False +class GMT8Formatter(logging.Formatter): + """Custom formatter that uses GMT+8 timezone.""" + + def __init__(self, fmt=None, datefmt=None): + super().__init__(fmt, datefmt) + self.gmt8 = pytz.timezone("Asia/Shanghai") # GMT+8 + + def formatTime(self, record, datefmt=None): + # Convert timestamp to GMT+8 + dt = datetime.fromtimestamp(record.created, tz=self.gmt8) + if datefmt: + return dt.strftime(datefmt) + else: + return dt.strftime("%Y-%m-%d %H:%M:%S") + + def format(self, record): + # Add timezone info to the formatted message + original = super().format(record) + return original.replace("[GMT+8]", "") # Remove any existing timezone tag + + +def get_timestamped_filename(base_name: str, extension: str = ".log") -> str: + """Generate a filename with timestamp in GMT+8 timezone.""" + gmt8 = pytz.timezone("Asia/Shanghai") + timestamp = datetime.now(gmt8).strftime("%Y%m%d_%H%M%S") + return f"{base_name}_{timestamp}{extension}" + + def sigint_handler(signum, frame): """Handle SIGINT (Ctrl+C) gracefully by logging the current configuration.""" global INTERRUPTED + global CURRENT_CONFIG INTERRUPTED = True print("\n" + "=" * 80) @@ -99,21 +125,55 @@ def sigint_handler(signum, frame): print(f" waves_per_eu: {config.get('waves_per_eu', 'N/A')}") print(f" kpack: {config.get('kpack', 'N/A')}") print(f" cache_modifier: {config.get('cache_modifier', 'N/A')}") + print(f" GROUP_SIZE_M: {config.get('GROUP_SIZE_M', 'N/A')}") + print(f" GROUP_K: {config.get('GROUP_K', 'N/A')}") + + # Show config in same format as console output for consistency + config_num = CURRENT_CONFIG["config_index"] + 1 + console_format = f" ๐Ÿ’ป Config {config_num} (INTERRUPTED): | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} kpack:{config.get('kpack')} cache:{config.get('cache_modifier')} GSM:{config.get('GROUP_SIZE_M')}" + print(console_format) # Log the interruption to the file if logger is available try: logger = logging.getLogger("gemm_a8w8_blockscale_tuning") if logger.handlers: - log_entry = { - "timestamp": datetime.now().isoformat(), + # Use GMT+8 timestamp for consistency + gmt8 = pytz.timezone("Asia/Shanghai") + + # Create detailed log entry + detailed_log_entry = { + "timestamp": datetime.now(gmt8).isoformat(), "event_type": "user_interrupt", + "gpu_id": CURRENT_CONFIG.get("gpu_id", "N/A"), "batch_size": CURRENT_CONFIG["batch_size"], "matrix_dims": f"M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}", "config": CURRENT_CONFIG["config"], "progress": f"Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}", "weight_shape_progress": f"Shape {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}", } - logger.info(f"USER_INTERRUPT: {log_entry}") + + # Log 
detailed interruption info + logger.info(f"=== USER INTERRUPT ===") + logger.info( + f"Interrupted while testing: Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}" + ) + logger.info(f"GPU: {CURRENT_CONFIG.get('gpu_id', 'N/A')}") + logger.info( + f"Matrix: M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}" + ) + logger.info( + f"Weight Shape Progress: {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}" + ) + + # Log config details in same format as console output for consistency + if CURRENT_CONFIG["config"]: + config = CURRENT_CONFIG["config"] + config_num = CURRENT_CONFIG["config_index"] + 1 + config_str = f"Config {config_num} (INTERRUPTED): | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} kpack:{config.get('kpack')} cache:{config.get('cache_modifier')} GROUP_SIZE_M:{config.get('GROUP_SIZE_M')} GROUP_K:{config.get('GROUP_K')}" + logger.info(f"CONFIG_DETAILS: {config_str}") + + logger.info(f"DETAILED_ENTRY: {detailed_log_entry}") + logger.info(f"=== END USER INTERRUPT ===") # Force flush to write immediately for handler in logger.handlers: @@ -136,12 +196,13 @@ def sigint_handler(signum, frame): sys.exit(1) -def setup_logger(log_file_path: str) -> logging.Logger: +def setup_logger(log_file_path: str, mode: str = "a") -> logging.Logger: """ Setup logger for recording bad configurations during tuning. Args: log_file_path: Path to the log file + mode: File write mode - 'a' to append to existing logs, 'w' to overwrite Returns: Configured logger instance @@ -153,7 +214,8 @@ def setup_logger(log_file_path: str) -> logging.Logger: logger.handlers.clear() # Create file handler with live writing (immediate flush) - file_handler = logging.FileHandler(log_file_path, mode="w") + # Default to append mode to preserve logs across resume sessions + file_handler = logging.FileHandler(log_file_path, mode=mode) file_handler.setLevel(logging.INFO) # Create custom formatter that flushes immediately @@ -163,9 +225,9 @@ def setup_logger(log_file_path: str) -> logging.Logger: console_handler = logging.StreamHandler() console_handler.setLevel(logging.WARNING) - # Create formatter - formatter = logging.Formatter( - "%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + # Create GMT+8 formatter + formatter = GMT8Formatter( + "%(asctime)s [GMT+8] - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) file_handler.setFormatter(formatter) console_handler.setFormatter(formatter) @@ -196,8 +258,10 @@ def log_bad_config( config: Configuration that failed error_msg: Additional error message """ + # Use GMT+8 timestamp for consistency + gmt8 = pytz.timezone("Asia/Shanghai") log_entry = { - "timestamp": datetime.now().isoformat(), + "timestamp": datetime.now(gmt8).isoformat(), "error_type": error_type, "batch_size": M, "matrix_dims": f"M={M} N={N} K={K}", @@ -373,63 +437,48 @@ def get_configs_compute_bound() -> List[Dict[str, int | str]]: """ Generate configuration space for tuning the gemm_a8w8_blockscale kernel. Based on the sample config file, we'll tune around those values. - Note: GROUP_K must equal BLOCK_SIZE_K as required by the kernel. + Note: With (128, 128) quantization blocks, GROUP_K will be computed as 128, + so we must only use BLOCK_SIZE_K = 128 to satisfy the constraint GROUP_K == BLOCK_SIZE_K. 
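+    With the loops below this expands to 4 * 3^6 * 2 = 5,832 candidate
+    configurations (num_stages x BLOCK_SIZE_M x BLOCK_SIZE_N x GROUP_SIZE_M x
+    num_warps x NUM_KSPLIT x waves_per_eu x cache_modifier, with kpack and
+    matrix_instr_nonkdim held fixed).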
""" configs = [] - # Start with the known working configuration from MI300X-GEMM-A8W8_BLOCKSCALE.json - base_config = { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 16, - "NUM_KSPLIT": 1, - "kpack": 2, - "cache_modifier": ".cg", - } - # Add the base config first (known to work) - configs.append(base_config.copy()) - - # Generate variations around the base config, but be conservative - for block_m in [ - 32, - 64, - 128, - ]: - for block_n in [ - 32, - 64, - 128, - ]: - for block_k in [64, 128]: # Keep as power of 2 - for num_warps in [4, 8]: - for num_stages in [2, 3, 4, 5]: - for waves_per_eu in [2, 4, 8]: - for cache_modifier in [ - ".cg", - "", - ]: # Start with cache modifier - config = { - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_K": block_k, - "GROUP_SIZE_M": 1, # Keep fixed for now - "num_warps": num_warps, - "num_stages": num_stages, - "waves_per_eu": waves_per_eu, # Keep fixed for now - "matrix_instr_nonkdim": 16, - "NUM_KSPLIT": 1, - "kpack": 2, # Keep fixed for now - "cache_modifier": cache_modifier, - } - configs.append(config) - - print(f"Generated {len(configs)} configurations") + # For blockscale kernel with (128, 128) quantization blocks: + # - scale_k = ceil(K / 128) = 8 for K=1024 + # - GROUP_K = next_power_of_2(K / scale_k) = next_power_of_2(128) = 128 + # - BLOCK_SIZE_K must equal GROUP_K, so BLOCK_SIZE_K must be 128 + block_k = 128 # Fixed to match computed GROUP_K + + # Explore optimized parameter space for blockscale kernel + for num_stages in [1, 2, 3, 4]: + for block_m in [32, 64, 128]: # Removed 256 (causes slowdowns) + for block_n in [32, 64, 128]: # Removed 256 (causes slowdowns) + for group_size_m in [1, 8, 16]: + for num_warps in [2, 4, 8]: + for num_ksplit in [ + 1, + 2, + 4, + ]: # Key parameter for K-splitting + for waves_per_eu in [2, 4, 8]: + for kpack in [2]: + for cache_modifier in ["", ".cg"]: + # Note: GROUP_K and GROUP_N are computed by the kernel function + # from the scale tensor shapes, not hardcoded in config + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, # Must be 128 to match computed GROUP_K + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + "NUM_KSPLIT": num_ksplit, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + "matrix_instr_nonkdim": 16, # Fixed value from atomic kernel + "cache_modifier": cache_modifier, + } + ) return configs @@ -450,14 +499,23 @@ def get_weight_shapes(tp_size: int = 1) -> List[Tuple[int, int]]: return weight_shapes -def run_torch( - x, weight, x_scale, w_scale, block_shape: Tuple[int, int], dtype=torch.bfloat16 +def run_torch_reference( + x: torch.Tensor, + w: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + block_shape: Tuple[int, int], + dtype=torch.bfloat16, ): + """ + Run reference implementation using PyTorch for blockscale kernel. + This is used for correctness verification. 
+ """ block_shape_n, block_shape_k = block_shape m, k = x.shape - n = weight.shape[0] + n = w.shape[0] - # Expand scales to match the block sizes + # Expand scales to match the block sizes (same as original) x_scale_expanded = x_scale.repeat_interleave(block_shape_k, dim=1) x_dequant = x.to(x_scale_expanded.dtype) * x_scale_expanded[:m, :k] @@ -465,7 +523,7 @@ def run_torch( w_scale_expanded = w_scale.repeat_interleave(block_shape_n, dim=0) w_scale_expanded = w_scale_expanded.repeat_interleave(block_shape_k, dim=1) w_scale_expanded = w_scale_expanded[:n, :k] - weight_dequant = weight.to(w_scale_expanded.dtype) * w_scale_expanded + weight_dequant = w.to(w_scale_expanded.dtype) * w_scale_expanded out = torch.nn.functional.linear( x_dequant.to(torch.float32), weight_dequant.to(torch.float32) @@ -474,6 +532,169 @@ def run_torch( return out.to(dtype) +# Global variable to store console output +console_output = [] + + +def create_live_display( + M: int, + N: int, + K: int, + current_config: Dict[str, Union[str, int]], + best_config: Dict[str, Union[str, int]], + best_time: float, + config_index: int, + total_configs: int, + console_messages: Optional[List[str]] = None, +) -> Layout: + """Create a live display layout with current and best configuration tables.""" + + layout = Layout() + + # Use global console_output if none provided + if console_messages is None: + global console_output + console_messages = console_output + + # Create progress bar + progress = Progress( + TextColumn("[bold blue]{task.description}"), + BarColumn(bar_width=40), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TextColumn("({task.completed}/{task.total})"), + ) + task = progress.add_task( + f"๐Ÿ”ง Tuning M={M} N={N} K={K} | Batch Size={M}", + total=total_configs, + completed=config_index, + ) + + # Create status information + status_text = "" + if best_time != float("inf"): + status_text = f"๐Ÿ† Best Performance: {best_time:.1f}ฮผs" + else: + status_text = "๐Ÿ” Searching for best configuration..." + + # Create console area + console_text = ( + "\n".join(console_messages[-10:]) + if console_messages + else "Waiting for results..." 
+ ) + console_table = Table(show_header=False, box=None, padding=0) + console_table.add_column("Output", style="white") + console_table.add_row(console_text) + + # Create current config table + config_table = Table( + title="Current Configuration", + show_header=True, + header_style="bold magenta", + ) + config_table.add_column("Parameter", style="cyan", width=15) + config_table.add_column("Value", style="green", width=10) + + # Add matrix dimensions and batch size first + config_table.add_row("[bold yellow]Matrix M[/bold yellow]", str(M), style="yellow") + config_table.add_row("[bold yellow]Matrix N[/bold yellow]", str(N), style="yellow") + config_table.add_row("[bold yellow]Matrix K[/bold yellow]", str(K), style="yellow") + config_table.add_row( + "[bold yellow]Batch Size[/bold yellow]", str(M), style="yellow" + ) + config_table.add_row("", "") # Separator + config_table.add_row("BLOCK_SIZE_M", str(current_config.get("BLOCK_SIZE_M", "N/A"))) + config_table.add_row("BLOCK_SIZE_N", str(current_config.get("BLOCK_SIZE_N", "N/A"))) + config_table.add_row("BLOCK_SIZE_K", str(current_config.get("BLOCK_SIZE_K", "N/A"))) + config_table.add_row("num_warps", str(current_config.get("num_warps", "N/A"))) + config_table.add_row("num_stages", str(current_config.get("num_stages", "N/A"))) + config_table.add_row("NUM_KSPLIT", str(current_config.get("NUM_KSPLIT", "N/A"))) + config_table.add_row("waves_per_eu", str(current_config.get("waves_per_eu", "N/A"))) + config_table.add_row("kpack", str(current_config.get("kpack", "N/A"))) + config_table.add_row( + "cache_modifier", str(current_config.get("cache_modifier", "N/A")) + ) + config_table.add_row("GROUP_SIZE_M", str(current_config.get("GROUP_SIZE_M", "N/A"))) + config_table.add_row("GROUP_K", str(current_config.get("GROUP_K", "N/A"))) + + # Create best config table if we have a best configuration + best_config_table = None + if best_time != float("inf"): + best_config_table = Table( + title="๐Ÿ† Best Configuration So Far", + show_header=True, + header_style="bold green", + ) + best_config_table.add_column("Parameter", style="cyan", width=15) + best_config_table.add_column("Value", style="green", width=10) + + # Add performance and matrix dimensions + best_config_table.add_row( + "[bold green]Performance[/bold green]", f"{best_time:.1f}ฮผs", style="green" + ) + best_config_table.add_row("", "") # Separator + best_config_table.add_row( + "[bold yellow]BLOCK_SIZE_M[/bold yellow]", + str(best_config.get("BLOCK_SIZE_M", "N/A")), + style="yellow", + ) + best_config_table.add_row( + "[bold yellow]BLOCK_SIZE_N[/bold yellow]", + str(best_config.get("BLOCK_SIZE_N", "N/A")), + style="yellow", + ) + best_config_table.add_row( + "[bold yellow]BLOCK_SIZE_K[/bold yellow]", + str(best_config.get("BLOCK_SIZE_K", "N/A")), + style="yellow", + ) + best_config_table.add_row("num_warps", str(best_config.get("num_warps", "N/A"))) + best_config_table.add_row( + "num_stages", str(best_config.get("num_stages", "N/A")) + ) + best_config_table.add_row( + "NUM_KSPLIT", str(best_config.get("NUM_KSPLIT", "N/A")) + ) + best_config_table.add_row( + "waves_per_eu", str(best_config.get("waves_per_eu", "N/A")) + ) + best_config_table.add_row("kpack", str(best_config.get("kpack", "N/A"))) + best_config_table.add_row( + "cache_modifier", str(best_config.get("cache_modifier", "N/A")) + ) + best_config_table.add_row( + "GROUP_SIZE_M", str(best_config.get("GROUP_SIZE_M", "N/A")) + ) + best_config_table.add_row("GROUP_K", str(best_config.get("GROUP_K", "N/A"))) + + # Create combined layout + if 
best_config_table: + # Display tables side by side + tables = Columns([config_table, best_config_table], equal=True, expand=True) + layout.split_column( + Layout(Panel(progress, title="Progress", border_style="blue"), size=5), + Layout(Panel(status_text, title="Status", border_style="green"), size=3), + Layout(tables), + Layout( + Panel(console_table, title="Console Output", border_style="cyan"), + size=10, + ), + ) + else: + # Display only current config + layout.split_column( + Layout(Panel(progress, title="Progress", border_style="blue"), size=5), + Layout(Panel(status_text, title="Status", border_style="green"), size=3), + Layout(config_table), + Layout( + Panel(console_table, title="Console Output", border_style="cyan"), + size=10, + ), + ) + + return layout + + def benchmark_config( x: torch.Tensor, w: torch.Tensor, @@ -511,19 +732,33 @@ def benchmark_config( JIT compilation and GPU warmup effects. The timing is measured using CUDA events for accurate GPU kernel timing. """ - - torch_out = run_torch( + # Get reference output for correctness verification + torch_out = run_torch_reference( x, w, x_scale, w_scale, - (128, 128), # follow test using (128,128) - dtype, + (128, 128), + dtype, # follow test using (128,128) ) + # Add SPLITK_BLOCK_SIZE computation as done in the kernel function + _, K = x.shape + _, K = w.shape + num_ksplit = int(config["NUM_KSPLIT"]) + block_k = int(config["BLOCK_SIZE_K"]) + splitk_block_size = triton.cdiv(K, num_ksplit) + + config["SPLITK_BLOCK_SIZE"] = splitk_block_size + if block_k > splitk_block_size: + block_k = triton.next_power_of_2(splitk_block_size) + if block_k > splitk_block_size: + block_k = block_k // 4 + block_k = max(block_k, 16) + config["BLOCK_SIZE_K"] = block_k + # Run kernel def run(): - # Pass the modified config to the kernel return gemm_a8w8_blockscale( x, w, x_scale, w_scale, dtype, y, config, skip_reduce=False ) @@ -566,6 +801,7 @@ def run_warmup(): latencies.append(start_event.elapsed_time(end_event)) torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg @@ -582,7 +818,15 @@ def tune( if not signal.getsignal(signal.SIGINT) == sigint_handler: signal.signal(signal.SIGINT, sigint_handler) - if input_type != "bfloat16": + if input_type == "bfloat16": + # Use the same input generation as test file + # IMPORTANT: Use fixed quantization block sizes (128, 128), not kernel BLOCK_SIZE_N/BLOCK_SIZE_K + # These are completely different concepts - quantization blocks vs kernel tiling + quant_block_size_n, quant_block_size_k = 128, 128 + x, w, x_scale, w_scale, _ = generate_gemm_a8w8_blockscale_inputs( + M, N, K, quant_block_size_n, quant_block_size_k, torch.bfloat16 + ) + else: raise RuntimeError( "Currently, only support tune a8w8 blockscale kernel with bfloat16 output." 
) @@ -593,24 +837,28 @@ def tune( 1000 # microseconds - configs slower than this get highlighted ) + # Clear console output for fresh start + global console_output + console_output = [] + # Initialize Rich console for better formatting console = Console() - # Create progress display with Rich - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - TimeElapsedColumn(), - console=console, - transient=False, # Keep progress bar visible - ) as progress: - task = progress.add_task( - f"๐Ÿ”ง Tuning M={M} N={N} K={K}", - total=len(search_space), - ) + # Create initial live display + initial_layout = create_live_display( + M, + N, + K, + search_space[0] if search_space else {}, + best_config, + best_time, + 0, + len(search_space), + console_output, + ) + # Create progress display with Rich and Live + with Live(initial_layout, refresh_per_second=4, console=console) as live: for i, config in enumerate(search_space): # Check if we were interrupted if INTERRUPTED: @@ -629,77 +877,21 @@ def tune( } ) - # Update progress - progress.update( - task, - advance=1, - description=f"๐Ÿ”ง Testing config {i + 1}/{len(search_space)}", + # Update live display with current configuration + layout = create_live_display( + M, + N, + K, + config, + best_config, + best_time, + i + 1, + len(search_space), + console_output, ) - - # Show current config (only every 10 configs to avoid flicker) - if i % 10 == 0 or i == len(search_space) - 1: - # Create fresh config table with matrix dimensions and batch size - config_table = Table( - title="Current Configuration", - show_header=True, - header_style="bold magenta", - ) - config_table.add_column("Parameter", style="cyan", width=15) - config_table.add_column("Value", style="green", width=10) - - # Add matrix dimensions and batch size first - config_table.add_row( - "[bold yellow]Matrix M[/bold yellow]", str(M), style="yellow" - ) - config_table.add_row( - "[bold yellow]Matrix N[/bold yellow]", str(N), style="yellow" - ) - config_table.add_row( - "[bold yellow]Matrix K[/bold yellow]", str(K), style="yellow" - ) - config_table.add_row( - "[bold yellow]Batch Size[/bold yellow]", str(M), style="yellow" - ) - config_table.add_row("", "") # Separator - config_table.add_row( - "BLOCK_SIZE_M", str(config.get("BLOCK_SIZE_M", "N/A")) - ) - config_table.add_row( - "BLOCK_SIZE_N", str(config.get("BLOCK_SIZE_N", "N/A")) - ) - config_table.add_row( - "BLOCK_SIZE_K", str(config.get("BLOCK_SIZE_K", "N/A")) - ) - config_table.add_row("num_warps", str(config.get("num_warps", "N/A"))) - config_table.add_row("num_stages", str(config.get("num_stages", "N/A"))) - config_table.add_row("NUM_KSPLIT", str(config.get("NUM_KSPLIT", "N/A"))) - config_table.add_row( - "waves_per_eu", str(config.get("waves_per_eu", "N/A")) - ) - - # Create summary header with all tuning parameters - header_text = f"[bold blue]๐Ÿ”ง Tuning M={M} N={N} K={K} | Batch Size={M} | Config {i + 1}/{len(search_space)}[/bold blue]" - - # Show config info (don't clear screen to avoid issues in multiprocessing) - console.print(f"\n{header_text}") - console.print(config_table) - - if best_time != float("inf"): - console.print( - f"[yellow]๐Ÿ† Best time so far: {best_time:.1f}ฮผs[/yellow]" - ) - console.print("โ”€" * 70) + live.update(layout) try: - # Use the same input generation as test file - x, w, x_scale, w_scale, _ = generate_gemm_a8w8_blockscale_inputs( - M, - N, - K, - int(config["BLOCK_SIZE_N"]), - int(config["BLOCK_SIZE_K"]), - torch.bfloat16, - ) 
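+                # Inputs were generated once per (M, N, K) above with fixed
+                # (128, 128) quantization blocks, so every config in the sweep
+                # is benchmarked against the same tensors; benchmark_config
+                # derives SPLITK_BLOCK_SIZE (and may shrink BLOCK_SIZE_K) from
+                # NUM_KSPLIT before launching the kernel.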
kernel_time = benchmark_config( x=x, w=w, @@ -711,50 +903,65 @@ def tune( num_iters=10, ) - # Warn about slow configs + # Add kernel time to console output if kernel_time > slow_config_threshold: - console.print( - f"\n[bold yellow]โš ๏ธ SLOW CONFIG DETECTED: {kernel_time:.1f}ฮผs[/bold yellow]" - ) - console.print( - f"[cyan]๐Ÿ“Š Matrix: M={M} N={N} K={K} | Config:[/cyan] BM:{config.get('BLOCK_SIZE_M', 'N/A')}, BN:{config.get('BLOCK_SIZE_N', 'N/A')}, BK:{config.get('BLOCK_SIZE_K', 'N/A')}, W:{config.get('num_warps', 'N/A')}, S:{config.get('num_stages', 'N/A')}, KS:{config.get('NUM_KSPLIT', 'N/A')}" - ) + console_msg = f"[yellow]โš ๏ธ Config {i + 1} (SLOW): {kernel_time:.1f}ฮผs | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} GSM:{config.get('GROUP_SIZE_M')}[/yellow]" + live.console.print(console_msg) + console_output.append(console_msg) + else: + console_msg = f"[green]โœ… Config {i + 1}: {kernel_time:.1f}ฮผs | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} GSM:{config.get('GROUP_SIZE_M')}[/green]" + console_output.append(console_msg) # Update best time and config if kernel_time < best_time: best_time = kernel_time best_config = config + best_msg = ( + f"[bold green]๐Ÿ† NEW BEST: {kernel_time:.1f}ฮผs![/bold green]" + ) + console_output.append(best_msg) + + # Update live display with current configuration and console output + layout = create_live_display( + M, + N, + K, + config, + best_config, + best_time, + i + 1, + len(search_space), + console_output, + ) + live.update(layout) except triton.runtime.autotuner.OutOfResources as e: # Log and skip out of resources configurations log_bad_config(logger, "out_of_resources", M, N, K, config, str(e)) - console.print( - f"\n[bold red]โš ๏ธ Out of resources for M={M} N={N} K={K} - logged[/bold red]" - ) + error_msg = f"[red]โŒ Config {i + 1} (OOM): Out of resources | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} GSM:{config.get('GROUP_SIZE_M')}[/red]" + console_output.append(error_msg) + live.console.print(error_msg) continue except AssertionError as e: # Log and skip assert error configurations log_bad_config(logger, "assert_error", M, N, K, config, str(e)) - console.print( - f"\n[bold red]โŒ Assert error for M={M} N={N} K={K} - logged[/bold red]" - ) - console.print(f"[red]๐Ÿ’ฌ Error:[/red] {e}") + error_msg = f"[red]โŒ Config {i + 1} (ASSERT): Assert error | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} GSM:{config.get('GROUP_SIZE_M')} | {e}[/red]" + console_output.append(error_msg) + live.console.print(error_msg) continue except TimeoutError as e: # Log and skip timeout configurations log_bad_config(logger, "timeout", M, N, K, config, str(e)) - console.print( - f"\n[bold orange1]โฑ๏ธ TIMEOUT for M={M} N={N} K={K} - logged[/bold orange1]" - ) - console.print(f"[orange1]๐Ÿ’ฌ Timeout:[/orange1] {e}") + error_msg = f"[orange1]โฑ๏ธ Config {i + 1} (TIMEOUT): {e} | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} 
S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} GSM:{config.get('GROUP_SIZE_M')}[/orange1]" + console_output.append(error_msg) + live.console.print(error_msg) continue except Exception as e: # Log and skip other error configurations log_bad_config(logger, "other_error", M, N, K, config, str(e)) - console.print( - f"\n[bold red]๐Ÿ’ฅ Unexpected error for M={M} N={N} K={K} - logged[/bold red]" - ) - console.print(f"[red]๐Ÿ’ฌ Error:[/red] {e}") + error_msg = f"[red]๐Ÿ’ฅ Config {i + 1} (ERROR): {e} | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} GSM:{config.get('GROUP_SIZE_M')}[/red]" + console_output.append(error_msg) + live.console.print(error_msg) continue # Show final completion message with Rich @@ -774,14 +981,36 @@ def tune( best_table.add_row("[bold yellow]Matrix N[/bold yellow]", str(N), style="yellow") best_table.add_row("[bold yellow]Matrix K[/bold yellow]", str(K), style="yellow") best_table.add_row("", "") # Separator - best_table.add_row("Performance", f"{best_time:.1f}ฮผs") - best_table.add_row("BLOCK_SIZE_M", str(best_config.get("BLOCK_SIZE_M", "N/A"))) - best_table.add_row("BLOCK_SIZE_N", str(best_config.get("BLOCK_SIZE_N", "N/A"))) - best_table.add_row("BLOCK_SIZE_K", str(best_config.get("BLOCK_SIZE_K", "N/A"))) + best_table.add_row( + "[bold green]Performance[/bold green]", f"{best_time:.1f}ฮผs", style="green" + ) + best_table.add_row("", "") # Separator + best_table.add_row( + "[bold yellow]BLOCK_SIZE_M[/bold yellow]", + str(best_config.get("BLOCK_SIZE_M", "N/A")), + style="yellow", + ) + best_table.add_row( + "[bold yellow]BLOCK_SIZE_N[/bold yellow]", + str(best_config.get("BLOCK_SIZE_N", "N/A")), + style="yellow", + ) + best_table.add_row( + "[bold yellow]BLOCK_SIZE_K[/bold yellow]", + str(best_config.get("BLOCK_SIZE_K", "N/A")), + style="yellow", + ) best_table.add_row("num_warps", str(best_config.get("num_warps", "N/A"))) best_table.add_row("num_stages", str(best_config.get("num_stages", "N/A"))) best_table.add_row("NUM_KSPLIT", str(best_config.get("NUM_KSPLIT", "N/A"))) best_table.add_row("waves_per_eu", str(best_config.get("waves_per_eu", "N/A"))) + best_table.add_row("kpack", str(best_config.get("kpack", "N/A"))) + best_table.add_row("cache_modifier", str(best_config.get("cache_modifier", "N/A"))) + best_table.add_row("GROUP_SIZE_M", str(best_config.get("GROUP_SIZE_M", "N/A"))) + best_table.add_row("GROUP_K", str(best_config.get("GROUP_K", "N/A"))) + best_table.add_row( + "matrix_instr_nonkdim", str(best_config.get("matrix_instr_nonkdim", "N/A")) + ) completion_panel = Panel( best_table, @@ -800,25 +1029,77 @@ def save_configs( K, configs, save_path, + is_incremental=False, + completed_batch_sizes=None, ) -> None: """Save the best configurations to a JSON file.""" os.makedirs(save_path, exist_ok=True) device_name = "R9700" # TODO: Hardcoded, make it dynamic - json_file_name = f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}.json" + + if is_incremental: + # Save incremental progress with batch size info in filename + batch_sizes_str = ( + "_".join(map(str, completed_batch_sizes)) + if completed_batch_sizes + else "partial" + ) + json_file_name = f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}_batch_{batch_sizes_str}.json" + progress_file = os.path.join( + save_path, + f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}_progress.json", + ) + + # Save progress info + progress_info = { + "completed_batch_sizes": 
completed_batch_sizes or [], + "configs": configs, + "last_updated": datetime.now().isoformat(), + } + + with open(progress_file, "w") as f: + json.dump(progress_info, f, indent=4) + f.write("\n") + else: + json_file_name = f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}.json" config_file_path = os.path.join(save_path, json_file_name) - print(f"Writing best config to {config_file_path}...") + + # Add incremental flag to filename + action = "Updating incremental" if is_incremental else "Writing" + print(f"{action} config to {config_file_path}...") with open(config_file_path, "w") as f: json.dump(configs, f, indent=4) f.write("\n") +def load_progress(N, K, save_path): + """Load previously saved progress for a given N,K configuration.""" + device_name = "R9700" # TODO: Hardcoded, make it dynamic + progress_file = os.path.join( + save_path, f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}_progress.json" + ) + + if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + progress_info = json.load(f) + return progress_info.get("completed_batch_sizes", []), progress_info.get( + "configs", {} + ) + except Exception as e: + print(f"Warning: Could not load progress file {progress_file}: {e}") + return [], {} + return [], {} + + def tune_on_gpu( gpu_id: int, batch_sizes: List[int], weight_shapes: List[Tuple[int, int]], input_type: str, + resume: bool = True, + log_filename: Optional[str] = None, ) -> None: """Run tuning on a specific GPU.""" # Register SIGINT handler and set GPU ID in global state @@ -830,12 +1111,33 @@ def tune_on_gpu( save_path = AITER_TRITON_CONFIGS_PATH + "/gemm/" - # Setup logger for this GPU with proper prefix - log_file_path = os.path.join( - save_path, f"tune_a8w8_blockscale_bad_configs_gpu{gpu_id}.log" - ) - logger = setup_logger(log_file_path) - logger.info(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + # Setup logger for this GPU with custom or timestamped filename + if log_filename: + # Use custom filename, ensure it has .log extension + if not log_filename.endswith(".log"): + log_filename += ".log" + # If no path separator, assume it's just a filename + if "/" not in log_filename and "\\" not in log_filename: + log_filename = os.path.join(save_path, log_filename) + else: + log_filename = log_filename # Use full path as provided + else: + # Fall back to timestamped filename + log_filename = os.path.join( + save_path, + get_timestamped_filename(f"tune_a8w8_blockscale_bad_configs_gpu{gpu_id}"), + ) + + # Choose appropriate logging mode: append for resume, overwrite for fresh start + log_mode = "a" if resume else "w" + logger = setup_logger(log_filename, mode=log_mode) + + # Log the start time in GMT+8 + gmt8 = pytz.timezone("Asia/Shanghai") + start_time_gmt8 = datetime.now(gmt8).strftime("%Y-%m-%d %H:%M:%S") + logger.info(f"=== TUNING SESSION STARTED AT {start_time_gmt8} [GMT+8] ===") + logger.info(f"GPU: {gpu_id}") + logger.info(f"Batch sizes: {batch_sizes}") search_space = get_configs_compute_bound() total_configs = len(search_space) @@ -846,7 +1148,13 @@ def tune_on_gpu( print( f" โšก Estimated tests per weight shape: {total_configs * len(batch_sizes):,}" ) - print(f" ๐Ÿ“ Bad configurations will be logged to: {log_file_path}") + log_action = ( + "Appending to existing" + if resume and os.path.exists(log_filename) + else "Writing to new" + ) + print(f" ๐Ÿ“ Bad configurations will be logged to: {log_filename}") + print(f" ๐Ÿ“ Logging mode: {log_action}") start = time.time() @@ -871,8 +1179,68 @@ def tune_on_gpu( f" ๐Ÿ“Š 
Testing {len(search_space):,} configurations across {len(batch_sizes)} batch sizes" ) - benchmark_results = [] - for batch_size in batch_sizes: + # Check for existing progress and resume from there (if resume is enabled) + if resume: + completed_batch_sizes, existing_configs = load_progress(N, K, save_path) + else: + completed_batch_sizes, existing_configs = [], {} + + # Filter batch_sizes to only those not yet completed + remaining_batch_sizes = [ + bs for bs in batch_sizes if bs not in completed_batch_sizes + ] + + if completed_batch_sizes and resume: + print(f"\n ๐Ÿ“‚ [GPU {gpu_id}] Found progress for N={N}, K={K}") + print(f" โœ… Already completed batch sizes: {completed_batch_sizes}") + print(f" ๐Ÿ”„ Remaining batch sizes to tune: {remaining_batch_sizes}") + elif not resume: + print( + f"\n ๐Ÿ”„ [GPU {gpu_id}] Starting fresh (resume disabled) for N={N}, K={K}" + ) + elif not remaining_batch_sizes: + print( + f"\n โœ… [GPU {gpu_id}] All batch sizes already completed for N={N}, K={K}" + ) + # Add existing configs to all_configs and continue to next shape + if existing_configs: + all_configs.append(existing_configs) + save_configs(N, K, existing_configs, save_path) + continue + + # Initialize benchmark_results with existing results if any + benchmark_results: List[Dict[str, str | int]] = [] + if existing_configs: + # Reconstruct benchmark_results from existing configs + # We need to map the configs back to their corresponding batch sizes + for i, batch_size in enumerate(batch_sizes): + if batch_size in completed_batch_sizes: + # Find the config for this batch size + config_to_add = None + + # Try to find matching config based on batch size category + if batch_size < 32 and "small" in existing_configs: + config_to_add = existing_configs["small"] + elif batch_size <= 128: + BLK_M = triton.next_power_of_2(batch_size) + if BLK_M == 32 and "medium_M32" in existing_configs: + config_to_add = existing_configs["medium_M32"] + elif BLK_M == 64 and "medium_M64" in existing_configs: + config_to_add = existing_configs["medium_M64"] + elif BLK_M == 128 and "medium_M128" in existing_configs: + config_to_add = existing_configs["medium_M128"] + elif batch_size <= 256 and "large" in existing_configs: + config_to_add = existing_configs["large"] + elif batch_size > 256 and "xlarge" in existing_configs: + config_to_add = existing_configs["xlarge"] + + if config_to_add: + benchmark_results.append(config_to_add) + else: + # If we couldn't find a matching config, we'll need to retune this batch size + remaining_batch_sizes.append(batch_size) + + for batch_size in remaining_batch_sizes: # Check if we were interrupted if INTERRUPTED: break @@ -894,8 +1262,49 @@ def tune_on_gpu( break benchmark_results.append(result) + + # Save incremental progress immediately after each batch size + updated_completed_batch_sizes = completed_batch_sizes + [batch_size] + + # Create configs for different M size categories as expected by the kernel + incremental_configs: Dict[str, Dict[str, int | str]] = {} + for i, (M, config) in enumerate( + zip(batch_sizes[: len(benchmark_results)], benchmark_results) + ): + if i == len(batch_sizes[: len(benchmark_results)]) - 1: + incremental_configs["any"] = config + elif M < 32: + incremental_configs["small"] = config + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32: + incremental_configs["medium_M32"] = config + elif BLK_M == 64: + incremental_configs["medium_M64"] = config + elif BLK_M == 128: + incremental_configs["medium_M128"] = config + elif M <= 256: + 
incremental_configs["large"] = config + else: + incremental_configs["xlarge"] = config + + # Save the incremental progress + save_configs( + N, + K, + incremental_configs, + save_path, + is_incremental=True, + completed_batch_sizes=updated_completed_batch_sizes, + ) + + print(f" ๐Ÿ’พ [GPU {gpu_id}] Saved progress for batch size {batch_size}") + + # Update completed_batch_sizes for next iteration + completed_batch_sizes = updated_completed_batch_sizes + + # Create final configs for different M size categories as expected by the kernel best_configs: Dict[str, Dict[str, int | str]] = {} - # Create configs for different M size categories as expected by the kernel for i, (M, config) in enumerate(zip(batch_sizes, benchmark_results)): if i == len(batch_sizes) - 1: best_configs["any"] = config @@ -913,20 +1322,40 @@ def tune_on_gpu( best_configs["large"] = config else: best_configs["xlarge"] = config + # Store configs for later analysis all_configs.append(best_configs) + + # Save the final complete config (non-incremental) save_configs(N, K, best_configs, save_path) + # Clean up progress file since we completed successfully + device_name = "R9700" # TODO: Hardcoded, make it dynamic + progress_file = os.path.join( + save_path, + f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}_progress.json", + ) + if os.path.exists(progress_file): + os.remove(progress_file) + print(f" ๐Ÿงน [GPU {gpu_id}] Cleaned up progress file for N={N}, K={K}") + # Create a default config file (without N,K parameters) by selecting the most common config default_config = create_default_config(all_configs) save_default_config(default_config, save_path) end = time.time() + # Log session end time in GMT+8 + gmt8 = pytz.timezone("Asia/Shanghai") + end_time_gmt8 = datetime.now(gmt8).strftime("%Y-%m-%d %H:%M:%S") + duration = end - start + logger.info(f"=== TUNING SESSION COMPLETED AT {end_time_gmt8} [GMT+8] ===") + logger.info(f"Total duration: {duration:.2f} seconds") + # Log summary of bad configurations log_bad_config_summary(logger, total_tests) - print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") + print(f"Tuning on GPU {gpu_id} took {duration:.2f} seconds") def create_default_config( @@ -1026,12 +1455,41 @@ def main(args): batches_per_gpu[gpu_id], weight_shapes, # Each GPU processes all weight shapes args.input_type, + args.resume, + args.log_filename, ) ) + # Set up signal handler for main process to gracefully terminate workers + def main_sigint_handler(signum, frame): # type: ignore + print("\n" + "=" * 80) + print("๐Ÿ›‘ MAIN PROCESS INTERRUPTED BY USER (Ctrl+C)") + print("๐Ÿ“ก Sending termination signal to worker processes...") + print("โณ Giving workers 3 seconds to log their current state...") + print("=" * 80) + # Set a flag for workers to check and give them time to cleanup + global INTERRUPTED + INTERRUPTED = True + import time + + time.sleep(3) # Give workers time to handle the signal and log + sys.exit(1) + + # Register main process signal handler + if not signal.getsignal(signal.SIGINT) == main_sigint_handler: + signal.signal(signal.SIGINT, main_sigint_handler) + ctx = mp.get_context("spawn") - with ctx.Pool(num_gpus) as pool: - pool.starmap(tune_on_gpu, process_args) + try: + with ctx.Pool(num_gpus) as pool: + pool.starmap(tune_on_gpu, process_args) + except KeyboardInterrupt: + print("\n๐Ÿ›‘ Keyboard interrupt received in main process") + print("๐Ÿ“ก Worker processes terminated") + sys.exit(1) + except Exception as e: + print(f"\nโŒ Error in main process: {e}") + sys.exit(1) print("Multi-GPU tuning 
completed") @@ -1052,6 +1510,20 @@ def main(args): default="bfloat16", ) parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument( + "--log-filename", + type=str, + default=None, + help="Custom log filename (without .log extension). If not provided, timestamped filename will be used.", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Disable resume functionality and start fresh tuning", + ) args = parser.parse_args() + # Convert no_resume flag to resume boolean + args.resume = not args.no_resume + main(args) diff --git a/aiter/ops/triton/tune_a8w8_per_token_scale.py b/aiter/ops/triton/tune_a8w8_per_token_scale.py index ad7419b4ed..c232a0ff94 100644 --- a/aiter/ops/triton/tune_a8w8_per_token_scale.py +++ b/aiter/ops/triton/tune_a8w8_per_token_scale.py @@ -429,60 +429,18 @@ def get_configs_compute_bound() -> List[Dict[str, int | str]]: configs = [] # Explore optimized parameter space (removed large block sizes that cause slowdowns) - # for num_stages in [1, 2, 3, 4]: - # for block_m in [32, 64, 128]: # Removed 256 (causes slowdowns) - # for block_n in [32, 64, 128]: # Removed 256 (causes slowdowns) - # for block_k in [64, 128, 256]: - # for group_size in [1, 8, 16]: - # for num_warps in [2, 4, 8]: - # for num_ksplit in [ - # 1, - # 2, - # 4, - # ]: # Key parameter for K-splitting - # for waves_per_eu in [2, 4, 8]: - # for kpack in [2]: - # for cache_modifier in ["", ".cg"]: - # configs.append( - # { - # "BLOCK_SIZE_M": block_m, - # "BLOCK_SIZE_N": block_n, - # "BLOCK_SIZE_K": block_k, - # "GROUP_SIZE_M": group_size, - # "num_warps": num_warps, - # "num_stages": num_stages, - # "NUM_KSPLIT": num_ksplit, - # "waves_per_eu": waves_per_eu, - # "kpack": kpack, - # "matrix_instr_nonkdim": 16, # Fixed value from atomic kernel - # "cache_modifier": cache_modifier, - # } - # ) - # return configs - for num_stages in [ - 1, - ]: - for block_m in [ - 32, - ]: # Removed 256 (causes slowdowns) - for block_n in [ - 32, - ]: # Removed 256 (causes slowdowns) - for block_k in [ - 64, - ]: - for group_size in [ - 1, - ]: - for num_warps in [ - 2, - ]: + for num_stages in [1, 2, 3, 4]: + for block_m in [32, 64, 128]: # Removed 256 (causes slowdowns) + for block_n in [32, 64, 128]: # Removed 256 (causes slowdowns) + for block_k in [64, 128, 256]: + for group_size in [1, 8, 16]: + for num_warps in [2, 4, 8]: for num_ksplit in [ 1, + 2, + 4, ]: # Key parameter for K-splitting - for waves_per_eu in [ - 2, - ]: + for waves_per_eu in [2, 4, 8]: for kpack in [2]: for cache_modifier in ["", ".cg"]: configs.append( @@ -501,6 +459,48 @@ def get_configs_compute_bound() -> List[Dict[str, int | str]]: } ) return configs + # for num_stages in [ + # 1, + # ]: + # for block_m in [ + # 32, + # ]: # Removed 256 (causes slowdowns) + # for block_n in [ + # 32, + # ]: # Removed 256 (causes slowdowns) + # for block_k in [ + # 64, + # ]: + # for group_size in [ + # 1, + # ]: + # for num_warps in [ + # 2, + # ]: + # for num_ksplit in [ + # 1, + # ]: # Key parameter for K-splitting + # for waves_per_eu in [ + # 2, + # ]: + # for kpack in [2]: + # for cache_modifier in ["", ".cg"]: + # configs.append( + # { + # "BLOCK_SIZE_M": block_m, + # "BLOCK_SIZE_N": block_n, + # "BLOCK_SIZE_K": block_k, + # "GROUP_SIZE_M": group_size, + # "num_warps": num_warps, + # "num_stages": num_stages, + # "NUM_KSPLIT": num_ksplit, + # "waves_per_eu": waves_per_eu, + # "kpack": kpack, + # "matrix_instr_nonkdim": 16, # Fixed value from atomic kernel + # "cache_modifier": cache_modifier, + # } + # ) + # 
return configs def get_weight_shapes(tp_size: int) -> List[Tuple[int, int]]: @@ -1200,7 +1200,7 @@ def tune_on_gpu( continue # Initialize benchmark_results with existing results if any - benchmark_results :List[Dict[str, str |int]]= [] + benchmark_results: List[Dict[str, str | int]] = [] if existing_configs: # Reconstruct benchmark_results from existing configs # We need to map the configs back to their corresponding batch sizes From f39d95e22d50aab944fe2c5660c6cf45a3c7e2ca Mon Sep 17 00:00:00 2001 From: Amir Balwel Date: Tue, 23 Dec 2025 04:09:29 +0000 Subject: [PATCH 10/10] Create tuning script base and tune moe Co-authored-by: Jeff Aw Signed-off-by: Amir Balwel --- .../configs/moe/gfx1201-MOE-DEFAULT.json | 32 ++ .../configs/moe/gfx1201-MOE-INT4_W4A16.json | 32 ++ .../configs/moe/gfx1201-MOE-INT8_W8A16.json | 32 ++ aiter/ops/triton/gemm_a16w16_atomic.py | 2 - aiter/ops/triton/tune/base.py | 139 ++++++++ aiter/ops/triton/tune/tune_moe_op.py | 314 ++++++++++++++++++ aiter/ops/triton/unified_attention.py | 4 +- op_tests/triton_tests/moe/test_moe.py | 3 - 8 files changed, 551 insertions(+), 7 deletions(-) create mode 100644 aiter/ops/triton/configs/moe/gfx1201-MOE-DEFAULT.json create mode 100644 aiter/ops/triton/configs/moe/gfx1201-MOE-INT4_W4A16.json create mode 100644 aiter/ops/triton/configs/moe/gfx1201-MOE-INT8_W8A16.json create mode 100644 aiter/ops/triton/tune/base.py create mode 100644 aiter/ops/triton/tune/tune_moe_op.py diff --git a/aiter/ops/triton/configs/moe/gfx1201-MOE-DEFAULT.json b/aiter/ops/triton/configs/moe/gfx1201-MOE-DEFAULT.json new file mode 100644 index 0000000000..4e2f6baa27 --- /dev/null +++ b/aiter/ops/triton/configs/moe/gfx1201-MOE-DEFAULT.json @@ -0,0 +1,32 @@ +{ + "small_M": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16 + }, + "medium_M": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16 + }, + "large_M": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16 + } +} diff --git a/aiter/ops/triton/configs/moe/gfx1201-MOE-INT4_W4A16.json b/aiter/ops/triton/configs/moe/gfx1201-MOE-INT4_W4A16.json new file mode 100644 index 0000000000..ceb9bc98a5 --- /dev/null +++ b/aiter/ops/triton/configs/moe/gfx1201-MOE-INT4_W4A16.json @@ -0,0 +1,32 @@ +{ + "small_M": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16 + }, + "medium_M": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16 + }, + "large_M": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + } +} diff --git a/aiter/ops/triton/configs/moe/gfx1201-MOE-INT8_W8A16.json b/aiter/ops/triton/configs/moe/gfx1201-MOE-INT8_W8A16.json new file mode 100644 index 0000000000..b6e5cffab1 --- /dev/null +++ b/aiter/ops/triton/configs/moe/gfx1201-MOE-INT8_W8A16.json @@ -0,0 +1,32 @@ +{ + "small_M": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 3, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16 + }, + "medium_M": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16 + }, + "large_M": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + } +} diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index 38341b3efb..411f9ad672 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -4,8 +4,6 @@ from typing import Optional import torch import triton -import triton.language as tl -import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton._triton_kernels.gemm_a16w16_atomic import ( _gemm_a16_w16_atomic_kernel, _get_config, diff --git a/aiter/ops/triton/tune/base.py b/aiter/ops/triton/tune/base.py new file mode 100644 index 0000000000..07f6f2a869 --- /dev/null +++ b/aiter/ops/triton/tune/base.py @@ -0,0 +1,139 @@ +import json +import os +import triton +import torch +import tqdm + +def tune_kernel( + search_space, + make_run_and_gt_fn, + config_callback=None, + num_iters=10, + atol=1e-2, + rtol=1e-2, +): + """ + Args: + search_space: List of config dicts. + make_run_and_gt_fn: Callable(config) -> (run_fn, ground_truth) + config_callback: Optional function to modify config before use. + num_iters: Number of iterations for benchmarking. + atol, rtol: Tolerances for output comparison. + Returns: + The best config dict. + """ + best_config = None + best_time = float("inf") + for config in tqdm.tqdm(search_space): + if config_callback: + config_callback(config) + run, ground_truth = make_run_and_gt_fn(config) + try: + kernel_time = benchmark_config( + run, ground_truth, num_iters=num_iters, atol=atol, rtol=rtol + ) + except triton.runtime.autotuner.OutOfResources: + # print("OutOfResources encountered during tuning.") + # Some configurations may be invalid and fail to compile. + continue + except AssertionError as e: + print(f"AssertionError encountered during tuning: {e}") + continue + except Exception as e: + print(f"Config failed: {e}") + continue + if kernel_time < best_time: + best_time = kernel_time + best_config = config + assert best_config is not None + return best_config + +def benchmark_config(run, ground_truth, num_iters=10, atol=1e-1, rtol=1e-1): + """ + Args: + run: Callable that returns the kernel output when called (no arguments). + ground_truth: The expected output tensor to compare against. + num_iters: Number of iterations to benchmark. + atol: Absolute tolerance for comparison. + rtol: Relative tolerance for comparison. + Returns: + Average latency in microseconds. + """ + torch.cuda.synchronize() + # JIT compilation & warmup + for _ in range(5): + run() + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for _ in range(num_iters): + torch.cuda.synchronize() + start_event.record() + kernel_out = run() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(kernel_out, ground_truth, atol=atol, rtol=rtol) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg + + +def get_search_space(small: bool = False): + """ + Returns the search space for tuning. 
+ Args: + small (bool): If True, returns a small search space for testing. If False, returns the full search space. + """ + configs = [] + if small: + num_stages_list = [2, 3] + block_m_list = [16, 32] + block_k_list = [64] + block_n_list = [32, 64] + num_warps_list = [4] + group_size_list = [1] + waves_per_eu_list = [3] + else: + num_stages_list = [2, 3, 4, 5] + block_m_list = [16, 32, 64, 128, 256] + block_k_list = [64, 128] + block_n_list = [32, 64, 128, 256] + num_warps_list = [4, 8] + group_size_list = [1, 16, 32, 64] + waves_per_eu_list = [2, 3, 4] + + for num_stages in num_stages_list: + for block_m in block_m_list: + for block_k in block_k_list: + for block_n in block_n_list: + for num_warps in num_warps_list: + for group_size in group_size_list: + for waves_per_eu in waves_per_eu_list: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, + "matrix_instr_nonkdim": 16, + } + ) + return configs + +def save_configs_to_json( + json_file_name: str, + save_path: str, + configs: list[dict], +) -> None: + os.makedirs(save_path, exist_ok=True) + config_file_path = os.path.join(save_path, json_file_name) + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") \ No newline at end of file diff --git a/aiter/ops/triton/tune/tune_moe_op.py b/aiter/ops/triton/tune/tune_moe_op.py new file mode 100644 index 0000000000..2f4bbb4cae --- /dev/null +++ b/aiter/ops/triton/tune/tune_moe_op.py @@ -0,0 +1,314 @@ +import torch +import time + +from tqdm import tqdm + +from aiter.ops.triton.tune.base import ( + tune_kernel, + get_search_space, + save_configs_to_json, +) +from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH +from aiter.ops.triton.utils.types import torch_to_triton_dtype +from aiter.ops.triton.utils._triton import arch_info +from aiter.ops.triton.moe_op import fused_moe + +from op_tests.triton_tests.moe.test_moe import ( + quantize_fp8, + quantize_int8, + torch_moe_ref, + torch_moe_align_block_size_ref, +) + + +def input_helper( + M: int, + N: int, + K: int, + top_k: int, + E: int, + routed_weight: bool, + dtype, + fp8_w8a8: bool, + int8_w8a16: bool, + config: dict, +): + assert not (fp8_w8a8 and int8_w8a16) + + a = torch.randn((M, K), dtype=dtype, device="cuda") + b = torch.rand((E, N, K), dtype=dtype, device="cuda") + a_scale = None + b_scale = None + + if fp8_w8a8: + b, _, b_scale = quantize_fp8(b, dim=(0,)) + + if int8_w8a16: + b, _, b_scale = quantize_int8(b, dim=(0,)) + + b_zp = False + + c = torch.zeros((M, top_k, N), dtype=dtype, device="cuda") + c_silu = torch.zeros((M * top_k, N // 2), dtype=dtype, device="cuda") + + values = torch.randn(M, E, dtype=dtype, device="cuda") + + softmax_vals = torch.softmax(values, dim=1) + topk_weights, topk_ids = torch.topk(softmax_vals, k=top_k, dim=1) + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + torch_moe_align_block_size_ref(topk_ids, config["BLOCK_SIZE_M"], E) + ) + + return ( + a, + b, + c, + c_silu, + b_zp, + a_scale, + b_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + config, + ) + + +def make_run_and_gt_fn_factory( + M, N, K, top_k, E, routed_weight, dtype, fp8_w8a8, int8_w8a16, int4_w4a16 +): + def make_run_and_gt(config): + ( + a, + b, + triton_out, + triton_out_silu, + b_zp, + a_scale, + b_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + 
num_tokens_post_padded, + _, + ) = input_helper( + M, + N, + K, + top_k, + E, + routed_weight=routed_weight, + dtype=dtype, + fp8_w8a8=fp8_w8a8, + int8_w8a16=int8_w8a16, + config=config, + ) + + def run(): + torch.cuda.empty_cache() + fused_moe( + a, + b, + triton_out, + a_scale, + b_scale, + b_zp, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + routed_weight, + top_k, + torch_to_triton_dtype[dtype], + fp8_w8a8, + int8_w8a16, + False, + config=config, + ) + return triton_out + + torch_out = torch.empty_like(triton_out) + ground_truth = torch_moe_ref( + a, + b, + torch_out, + a_scale, + b_scale, + None, + 0, + topk_ids, + topk_weights, + routed_weight, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + dtype, + fp8_w8a8, + int8_w8a16, + False, + ) + return run, ground_truth + + return make_run_and_gt + + +def tune_fused_moe( + M, + N, + K, + top_k, + E, + routed_weight, + dtype, + fp8_w8a8, + int8_w8a16, + int4_w4a16, + search_space, +): + make_run_and_gt = make_run_and_gt_fn_factory( + M, N, K, top_k, E, routed_weight, dtype, fp8_w8a8, int8_w8a16, int4_w4a16 + ) + + best_config = tune_kernel( + search_space=search_space, + make_run_and_gt_fn=make_run_and_gt, + ) + return best_config + + +def tune_and_save_configs( + batch_sizes, + N, + K, + topk, + E, + routed_weight, + dtype, + fp8_w8a8, + int8_w8a16, + int4_w4a16, + search_space, + save_path, + device_name, + tag, +): + start = time.time() + benchmark_results = [ + tune_fused_moe( + batch_size, + N, + K, + topk, + E, + routed_weight=routed_weight, + dtype=dtype, + fp8_w8a8=fp8_w8a8, + int8_w8a16=int8_w8a16, + int4_w4a16=int4_w4a16, + search_space=search_space, + ) + for batch_size in tqdm(batch_sizes) + ] + best_configs = { + ( + "small_M" + if batch_size <= 256 + else "medium_M" + if batch_size <= 2048 + else "large_M" + ): config + for batch_size, config in zip(batch_sizes, benchmark_results) + } + json_file_name = f"{device_name}-MOE-{tag}.json" + save_configs_to_json(json_file_name, save_path, best_configs) + end = time.time() + print(f"Tuning for {tag} took {end - start:.2f} seconds") + + +def main(): + dev = arch_info.get_arch() + + torch.cuda.init() + + batch_sizes = [256, 2048, 4096] # M + N = 384 + K = 768 + topk = 8 + E = 128 + search_space = get_search_space() + save_path = AITER_TRITON_CONFIGS_PATH + "/moe/" + + # Tune for default (float16) + # tune_and_save_configs( + # batch_sizes=batch_sizes, + # N=N, + # K=K, + # topk=topk, + # E=E, + # routed_weight=False, + # dtype=torch.float16, + # fp8_w8a8=False, + # int8_w8a16=False, + # int4_w4a16=False, + # search_space=search_space, + # save_path=save_path, + # device_name=dev, + # tag="DEFAULT", + # ) + # tune_and_save_configs( + # batch_sizes=batch_sizes, + # N=N, + # K=K, + # topk=topk, + # E=E, + # routed_weight=False, + # dtype=torch.float16, + # fp8_w8a8=True, + # int8_w8a16=False, + # int4_w4a16=False, + # search_space=search_space, + # save_path=save_path, + # device_name=dev, + # tag="FP8_W8A8", + # ) + tune_and_save_configs( + batch_sizes=batch_sizes, + N=N, + K=K, + topk=topk, + E=E, + routed_weight=False, + dtype=torch.float16, + fp8_w8a8=False, + int8_w8a16=True, + int4_w4a16=False, + search_space=search_space, + save_path=save_path, + device_name=dev, + tag="INT8_W8A16", + ) + tune_and_save_configs( + batch_sizes=batch_sizes, + N=N, + K=K, + topk=topk, + E=E, + routed_weight=False, + dtype=torch.float16, + fp8_w8a8=False, + int8_w8a16=False, + int4_w4a16=True, + search_space=search_space, + save_path=save_path, + 
device_name=dev, + tag="INT4_W4A16", + ) + + +if __name__ == "__main__": + main() diff --git a/aiter/ops/triton/unified_attention.py b/aiter/ops/triton/unified_attention.py index 78a7879db6..1b0ecf3861 100644 --- a/aiter/ops/triton/unified_attention.py +++ b/aiter/ops/triton/unified_attention.py @@ -28,9 +28,9 @@ def select_2d_config( TILE_SIZE = 64 # in case head_size is large max_num_stages_2d = 4 - dev = arch_info.get_device() + dev = arch_info.get_arch() if head_size > 128: - if block_size >=64 and dev == "R9700": + if block_size >= 64 and dev == "gfx1201": max_num_stages_2d = 1 else: max_num_stages_2d = 2 diff --git a/op_tests/triton_tests/moe/test_moe.py b/op_tests/triton_tests/moe/test_moe.py index 240afbaec7..ce2c563a09 100644 --- a/op_tests/triton_tests/moe/test_moe.py +++ b/op_tests/triton_tests/moe/test_moe.py @@ -372,7 +372,6 @@ def quantize_int8( def quantize_int4( tensor: torch.Tensor, group_size: int, has_zp: bool ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # reshape tensor k, n = tensor.shape tensor = tensor.reshape(-1, group_size, n) @@ -486,7 +485,6 @@ def input_helper_int4_w4a16( group_size: int, has_zp: bool, ): - a = torch.randn((M, K), dtype=dtype, device="cuda") b = torch.rand((E, N, K), dtype=dtype, device="cuda") @@ -775,7 +773,6 @@ def test_fused_moe_int4_w4a16( persistent: bool, silu_fused: bool, ): - torch.cuda.empty_cache() # Helps avoid hangs in large tests if ( M == 1