diff --git a/aiter/ops/triton/README_tune_atomic.md b/aiter/ops/triton/README_tune_atomic.md new file mode 100644 index 0000000000..eafa38498d --- /dev/null +++ b/aiter/ops/triton/README_tune_atomic.md @@ -0,0 +1,103 @@ +# GEMM A16W16 Atomic Kernel Tuning + +This document explains how to use the tuning script for the atomic GEMM kernel. + +## Overview + +The atomic GEMM kernel (`gemm_a16w16_atomic`) is a specialized kernel that uses atomic operations for split-K reduction. This allows for better parallelization in certain scenarios, especially when the K dimension is large. + +## Tuning Script + +The tuning script is located at `aiter/ops/triton/tune_a16w16_atomic.py`. It follows a similar pattern to the regular GEMM tuning script, with adaptations specific to the atomic kernel. + +### Key Differences from Regular GEMM Tuning + +1. **NUM_KSPLIT Parameter**: The atomic kernel includes a `NUM_KSPLIT` parameter that controls how the K dimension is split across multiple thread blocks for parallel reduction. + +2. **Configuration Categories**: The atomic kernel uses different configuration categories based on the M dimension: + - `small`: M < 32 + - `medium_M32`: M ≤ 128, tuned at M = 32 + - `medium_M64`: M ≤ 128, tuned at M = 64 + - `medium_M128`: M ≤ 128, tuned at M = 128 + - `large`: M ≤ 512 + - `xlarge`: M > 512 + +3. **No Bias Parameter**: The atomic kernel does not support a bias parameter. + +### Running the Tuning Script + +To run the tuning script: + +```bash +cd aiter/ops/triton +python tune_a16w16_atomic.py +``` + +You can also specify parameters explicitly: + +```bash +python tune_a16w16_atomic.py --batch-size 512 --input-type bfloat16 --out-dtype bfloat16 +``` + +### Output + +The tuning script generates configuration files in the `aiter/ops/triton/configs/gemm/` directory with names like: +- `R9700-GEMM-A16W16-ATOMIC-N={N}-K={K}.json` for specific N,K dimensions +- `R9700-GEMM-A16W16-ATOMIC.json` - a default config file (without N,K in the name) that contains the most common optimal configuration for each category across all tested shapes + +The default config file is created by analyzing all the shape-specific N,K configurations and selecting, for each category (small, medium_M32, etc.), the configuration that wins most often. This provides a reasonable general-purpose fallback when no configuration exists for the exact N,K shape.
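+
+The selection amounts to a per-category majority vote over the shape-specific files. The snippet below is a minimal sketch of that idea, assuming the naming scheme above; the helper (`build_default_config`) and its exact logic are illustrative only and are not the actual `tune_a16w16_atomic.py` implementation.
+
+```python
+import glob
+import json
+from collections import Counter
+
+
+def build_default_config(config_dir="aiter/ops/triton/configs/gemm"):
+    # Count, per category, how often each exact config wins across the tuned shapes.
+    winners = {}
+    for path in glob.glob(f"{config_dir}/R9700-GEMM-A16W16-ATOMIC-N=*-K=*.json"):
+        with open(path) as f:
+            shape_cfg = json.load(f)
+        for category, cfg in shape_cfg.items():
+            key = json.dumps(cfg, sort_keys=True)  # hashable, order-insensitive
+            winners.setdefault(category, Counter())[key] += 1
+
+    # Keep the most frequent winner for each category (small, medium_M32, ...).
+    return {
+        category: json.loads(counts.most_common(1)[0][0])
+        for category, counts in winners.items()
+    }
+
+
+if __name__ == "__main__":
+    default_cfg = build_default_config()
+    out_path = "aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json"
+    with open(out_path, "w") as f:
+        json.dump(default_cfg, f, indent=4, sort_keys=True)
+```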
+ +### Configuration Parameters + +The tuning script searches through these parameters: +- `BLOCK_SIZE_M`: [16, 32, 64, 128] +- `BLOCK_SIZE_N`: [32, 64, 128, 256] +- `BLOCK_SIZE_K`: [64, 128] +- `GROUP_SIZE_M`: [1, 8, 16] +- `NUM_KSPLIT`: [1, 2, 4, 8] (atomic kernel specific) +- `num_warps`: [4, 8] +- `num_stages`: [2] +- `waves_per_eu`: [3] + +### Batch Sizes Tested + +The script tests these batch sizes (M dimensions): +- 16 (small) +- 32 (medium_M32) +- 64 (medium_M64) +- 128 (medium_M128) +- 256 (large) +- 512 (large) +- 2048 (xlarge) +- 4096 (xlarge) + +### Weight Shapes Tested + +The script tunes for these weight shapes (N, K): +- (1024, 1024) +- (4096, 1024) +- (1024, 2048) +- (6144, 1024) +- (1024, 3072) + +## Usage in Code + +Once the tuning is complete, the atomic kernel can be used in your code: + +```python +import torch + +from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic + +# Example problem size; (N, K) = (4096, 1024) is one of the tuned weight shapes +M, N, K = 512, 4096, 1024 + +# Basic usage: x is (M, K), w is (N, K) +x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") +w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") + +# The kernel will automatically use the tuned configurations +y = gemm_a16w16_atomic(x, w, dtype=torch.bfloat16) +``` + +## Notes + +1. The atomic kernel is particularly useful for larger matrices where split-K parallelization provides benefits. +2. For smaller matrices, the regular GEMM kernel might be more efficient. +3. The tuning process can take considerable time as it tests many configurations. +4. Make sure you have sufficient GPU memory for the tuning process. \ No newline at end of file diff --git a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py index 104c88d5ba..e91bae25e4 100644 --- a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py +++ b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py @@ -632,7 +632,7 @@ def is_fp8(x) -> bool: def _is_fp8_single(t: torch.Tensor) -> bool: if is_dtype_fp8(t.dtype): arch = get_arch() - if arch not in ("gfx942", "gfx950"): + if arch not in ("gfx942", "gfx950", "gfx1201"): raise RuntimeError( f"{arch} is not in the list of supported architectures for FP8" ) diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json new file mode 100644 index 0000000000..b5c279bfb2 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=1024.json @@ -0,0 +1,93 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, +
"SPLITK_BLOCK_SIZE": 1024 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json new file mode 100644 index 0000000000..3ff0922c5e --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=2048.json @@ -0,0 +1,93 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 2048 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 2048 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 2048 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 2048 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 2048 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 2048 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 2048 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json new file mode 100644 index 0000000000..9c626d0fac --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=1024-K=3072.json @@ -0,0 +1,93 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + 
"cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json new file mode 100644 index 0000000000..3ad52f6053 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=4096-K=1024.json @@ -0,0 +1,93 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + 
"cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json new file mode 100644 index 0000000000..b0605c7435 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC-N=6144-K=1024.json @@ -0,0 +1,93 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "large": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json new file mode 100644 index 0000000000..7d1c3825bc --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-ATOMIC.json @@ -0,0 +1,93 @@ +{ + "small": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "matrix_instr_nonkdim": 16, + "num_stages": 3, + "num_warps": 4, + "waves_per_eu": 2 + }, + "medium_M32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 4 + }, + "medium_M64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + 
"matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 8 + }, + "medium_M128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 4 + }, + "large": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 2 + }, + "xlarge": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 4 + }, + "any": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 4 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=1024.json new file mode 100644 index 0000000000..3f2af9911b --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=1024.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=2048.json new file mode 100644 index 0000000000..0a62925162 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=2048.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 2048 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 2048 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 2048 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 2048 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 2048 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 2048 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=3072.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=3072.json new file mode 100644 index 0000000000..79e50ebc8d --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=1024-K=3072.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 3072 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 3072 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 3072 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 3072 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + 
"matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=4096-K=1024.json new file mode 100644 index 0000000000..c8d365ec8e --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=4096-K=1024.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=6144-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=6144-K=1024.json new file mode 100644 index 0000000000..3b7c3066ec --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16-N=6144-K=1024.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": 
null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16.json new file mode 100644 index 0000000000..3f2af9911b --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A16W16.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": null, + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json new file mode 100644 index 0000000000..010661e88c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024.json @@ -0,0 +1,114 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, 
+ "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16.json new file mode 100644 index 0000000000..a1913185c8 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16.json @@ -0,0 +1,18 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32.json new file mode 100644 index 0000000000..8fd3a5c3a4 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32.json @@ -0,0 +1,34 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64.json new file mode 100644 index 0000000000..6f1641f003 --- /dev/null +++ 
b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64.json @@ -0,0 +1,50 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128.json new file mode 100644 index 0000000000..a275ac85f7 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128.json @@ -0,0 +1,66 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256.json new file mode 100644 index 0000000000..ca20e4c6fd --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256.json @@ -0,0 +1,82 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512.json new file mode 100644 index 0000000000..61c2f0a956 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512.json @@ -0,0 +1,98 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git 
a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048.json new file mode 100644 index 0000000000..538ec8bdb1 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048.json @@ -0,0 +1,114 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048_4096.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048_4096.json new file mode 100644 index 0000000000..010661e88c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=1024_batch_16_32_64_128_256_512_2048_4096.json @@ -0,0 +1,114 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + 
"NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048.json new file mode 100644 index 0000000000..fc17b9ca81 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048.json @@ -0,0 +1,15 @@ +{ + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 2, + "cache_modifier": ".cg" + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16.json new file mode 100644 index 0000000000..f0401b39de --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16.json @@ -0,0 +1,18 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32.json new file mode 100644 index 0000000000..c6168c8dfe --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32.json @@ -0,0 +1,34 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 
1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64.json new file mode 100644 index 0000000000..852626e62c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64.json @@ -0,0 +1,50 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64_128.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64_128.json new file mode 100644 index 0000000000..5cf3dd7af5 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_batch_16_32_64_128.json @@ -0,0 +1,66 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git 
a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_progress.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_progress.json new file mode 100644 index 0000000000..b85ccf2329 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=2048_progress.json @@ -0,0 +1,75 @@ +{ + "completed_batch_sizes": [ + 16, + 32, + 64, + 128 + ], + "configs": { + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048, + "GROUP_K": 128, + "GROUP_N": 128 + } + }, + "last_updated": "2025-12-09T16:50:24.402167" +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=3072.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=3072.json new file mode 100644 index 0000000000..5b8b362d82 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=1024-K=3072.json @@ -0,0 +1,15 @@ +{ + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 2, + "cache_modifier": ".cg" + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json new file mode 100644 index 0000000000..67408d99fb --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024.json @@ -0,0 +1,114 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + 
"SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16.json new file mode 100644 index 0000000000..530370c6e9 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16.json @@ -0,0 +1,18 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32.json new file mode 100644 index 0000000000..fc7e680f33 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32.json @@ -0,0 +1,34 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64.json new file mode 100644 index 0000000000..07b4ee6f63 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64.json @@ -0,0 +1,50 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + 
"waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128.json new file mode 100644 index 0000000000..db03b10ab6 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128.json @@ -0,0 +1,66 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256.json new file mode 100644 index 0000000000..24d51459eb --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256.json @@ -0,0 +1,82 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512.json new file mode 100644 index 0000000000..6966dfeab9 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512.json @@ -0,0 +1,98 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048.json new file mode 100644 index 0000000000..ac35d393df --- /dev/null +++ 
b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048.json @@ -0,0 +1,114 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048_4096.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048_4096.json new file mode 100644 index 0000000000..67408d99fb --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=4096-K=1024_batch_16_32_64_128_256_512_2048_4096.json @@ -0,0 +1,114 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 1, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024, + "GROUP_K": 128, + "GROUP_N": 128 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=6144-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=6144-K=1024.json new file mode 100644 index 0000000000..8281184361 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE-N=6144-K=1024.json @@ -0,0 +1,15 @@ +{ + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 2, + "cache_modifier": ".cg" + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE.json new file mode 100644 index 0000000000..de5a5d4300 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_BLOCKSCALE.json @@ -0,0 +1,15 @@ +{ + "any": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "cache_modifier": ".cg", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 2 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=1024.json new file mode 100644 index 0000000000..e02c3d9496 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=1024.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M64": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=2048.json new file mode 100644 index 0000000000..bf4b466417 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=2048.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "SPLITK_BLOCK_SIZE": 2048 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 
8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 2048 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json new file mode 100644 index 0000000000..8a704a8fab --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "large": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16.json new file mode 100644 index 0000000000..5f96695c9d --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16.json @@ -0,0 +1,16 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32.json new file mode 100644 index 0000000000..b1a818f586 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32.json @@ -0,0 +1,30 @@ +{ + "small": { + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64.json new file mode 100644 index 0000000000..12f7ea9bdd --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64.json @@ -0,0 +1,44 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128.json new file mode 100644 index 0000000000..e99366c468 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128.json @@ -0,0 +1,58 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256.json new file mode 
100644 index 0000000000..adfa60a35f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256.json @@ -0,0 +1,72 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512.json new file mode 100644 index 0000000000..ff4a63e228 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "large": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + 
"num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048.json new file mode 100644 index 0000000000..1542ef92a6 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "large": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048_4096.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048_4096.json new file mode 100644 index 0000000000..8a704a8fab --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=1024-K=3072_batch_16_32_64_128_256_512_2048_4096.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M32": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + 
"matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "large": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 3072 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=4096-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=4096-K=1024.json new file mode 100644 index 0000000000..7010775054 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=4096-K=1024.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { 
+ "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=6144-K=1024.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=6144-K=1024.json new file mode 100644 index 0000000000..fe0fd5eb63 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE-N=6144-K=1024.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 2, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "large": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 8, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "NUM_KSPLIT": 1, + "waves_per_eu": 4, + "kpack": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": "", + "SPLITK_BLOCK_SIZE": 1024 + } +} diff --git a/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json new file mode 100644 index 0000000000..dee783b5dc --- /dev/null +++ b/aiter/ops/triton/configs/gemm/R9700-GEMM-A8W8_PER_TOKEN_SCALE.json @@ -0,0 +1,100 @@ +{ + "small": { + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 4, + "num_warps": 8, + "waves_per_eu": 4 + }, + "medium_M32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 3, + "num_warps": 8, + "waves_per_eu": 4 + }, + "medium_M64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, 
+ "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 4, + "num_warps": 8, + "waves_per_eu": 4 + }, + "medium_M128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 4, + "num_warps": 8, + "waves_per_eu": 4 + }, + "large": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 3, + "num_warps": 4, + "waves_per_eu": 8 + }, + "xlarge": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 4 + }, + "any": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "NUM_KSPLIT": 1, + "SPLITK_BLOCK_SIZE": 3072, + "cache_modifier": "", + "kpack": 2, + "matrix_instr_nonkdim": 16, + "num_stages": 2, + "num_warps": 8, + "waves_per_eu": 4 + } +} diff --git a/aiter/ops/triton/configs/moe/gfx1201-MOE-DEFAULT.json b/aiter/ops/triton/configs/moe/gfx1201-MOE-DEFAULT.json new file mode 100644 index 0000000000..4e2f6baa27 --- /dev/null +++ b/aiter/ops/triton/configs/moe/gfx1201-MOE-DEFAULT.json @@ -0,0 +1,32 @@ +{ + "small_M": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16 + }, + "medium_M": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16 + }, + "large_M": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16 + } +} diff --git a/aiter/ops/triton/configs/moe/gfx1201-MOE-INT4_W4A16.json b/aiter/ops/triton/configs/moe/gfx1201-MOE-INT4_W4A16.json new file mode 100644 index 0000000000..ceb9bc98a5 --- /dev/null +++ b/aiter/ops/triton/configs/moe/gfx1201-MOE-INT4_W4A16.json @@ -0,0 +1,32 @@ +{ + "small_M": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16 + }, + "medium_M": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16 + }, + "large_M": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + } +} diff --git a/aiter/ops/triton/configs/moe/gfx1201-MOE-INT8_W8A16.json b/aiter/ops/triton/configs/moe/gfx1201-MOE-INT8_W8A16.json new file mode 100644 index 0000000000..b6e5cffab1 --- /dev/null +++ b/aiter/ops/triton/configs/moe/gfx1201-MOE-INT8_W8A16.json @@ -0,0 +1,32 @@ +{ + "small_M": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16 + }, + "medium_M": { + "BLOCK_SIZE_M": 
32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16 + }, + "large_M": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + } +} diff --git a/aiter/ops/triton/gemm_a16w16.py b/aiter/ops/triton/gemm_a16w16.py index 549ddd36d8..4a746a57d9 100644 --- a/aiter/ops/triton/gemm_a16w16.py +++ b/aiter/ops/triton/gemm_a16w16.py @@ -57,6 +57,7 @@ def gemm_a16w16( if config is None: config = _get_config(M, N, K) + _LOGGER.info(f"Using GEMM config: {config}") if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): y = torch.empty((M, N), dtype=dtype, device=x.device) diff --git a/aiter/ops/triton/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm_a16w16_atomic.py index 38341b3efb..411f9ad672 100644 --- a/aiter/ops/triton/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm_a16w16_atomic.py @@ -4,8 +4,6 @@ from typing import Optional import torch import triton -import triton.language as tl -import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton._triton_kernels.gemm_a16w16_atomic import ( _gemm_a16_w16_atomic_kernel, _get_config, diff --git a/aiter/ops/triton/gemm_a8w8.py b/aiter/ops/triton/gemm_a8w8.py index 3602ef2ff6..16fed7622e 100644 --- a/aiter/ops/triton/gemm_a8w8.py +++ b/aiter/ops/triton/gemm_a8w8.py @@ -22,7 +22,7 @@ def gemm_a8w8( x_scale: torch.Tensor, w_scale: torch.Tensor, bias: Optional[torch.Tensor] = None, - dtype: Optional[float] = torch.bfloat16, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[dict] = None, ): diff --git a/aiter/ops/triton/gemm_a8w8_blockscale.py b/aiter/ops/triton/gemm_a8w8_blockscale.py index ee2d072822..e6db089975 100644 --- a/aiter/ops/triton/gemm_a8w8_blockscale.py +++ b/aiter/ops/triton/gemm_a8w8_blockscale.py @@ -22,7 +22,7 @@ def gemm_a8w8_blockscale( w: torch.Tensor, x_scale: torch.Tensor, w_scale: torch.Tensor, - dtype: Optional[float] = torch.bfloat16, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[dict] = None, skip_reduce: Optional[bool] = False, diff --git a/aiter/ops/triton/gemm_a8w8_per_token_scale.py b/aiter/ops/triton/gemm_a8w8_per_token_scale.py index e8032bdbeb..3a29f0a52a 100644 --- a/aiter/ops/triton/gemm_a8w8_per_token_scale.py +++ b/aiter/ops/triton/gemm_a8w8_per_token_scale.py @@ -19,7 +19,7 @@ def gemm_a8w8_per_token_scale( w: torch.Tensor, x_scale: torch.Tensor, w_scale: torch.Tensor, - dtype: Optional[float] = torch.bfloat16, + dtype: Optional[torch.dtype] = torch.bfloat16, y: Optional[torch.Tensor] = None, config=None, ): diff --git a/aiter/ops/triton/tune/base.py b/aiter/ops/triton/tune/base.py new file mode 100644 index 0000000000..07f6f2a869 --- /dev/null +++ b/aiter/ops/triton/tune/base.py @@ -0,0 +1,139 @@ +import json +import os +import triton +import torch +import tqdm + +def tune_kernel( + search_space, + make_run_and_gt_fn, + config_callback=None, + num_iters=10, + atol=1e-2, + rtol=1e-2, +): + """ + Args: + search_space: List of config dicts. + make_run_and_gt_fn: Callable(config) -> (run_fn, ground_truth) + config_callback: Optional function to modify config before use. + num_iters: Number of iterations for benchmarking. + atol, rtol: Tolerances for output comparison. + Returns: + The best config dict. 
+ """ + best_config = None + best_time = float("inf") + for config in tqdm.tqdm(search_space): + if config_callback: + config_callback(config) + run, ground_truth = make_run_and_gt_fn(config) + try: + kernel_time = benchmark_config( + run, ground_truth, num_iters=num_iters, atol=atol, rtol=rtol + ) + except triton.runtime.autotuner.OutOfResources: + # print("OutOfResources encountered during tuning.") + # Some configurations may be invalid and fail to compile. + continue + except AssertionError as e: + print(f"AssertionError encountered during tuning: {e}") + continue + except Exception as e: + print(f"Config failed: {e}") + continue + if kernel_time < best_time: + best_time = kernel_time + best_config = config + assert best_config is not None + return best_config + +def benchmark_config(run, ground_truth, num_iters=10, atol=1e-1, rtol=1e-1): + """ + Args: + run: Callable that returns the kernel output when called (no arguments). + ground_truth: The expected output tensor to compare against. + num_iters: Number of iterations to benchmark. + atol: Absolute tolerance for comparison. + rtol: Relative tolerance for comparison. + Returns: + Average latency in microseconds. + """ + torch.cuda.synchronize() + # JIT compilation & warmup + for _ in range(5): + run() + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for _ in range(num_iters): + torch.cuda.synchronize() + start_event.record() + kernel_out = run() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(kernel_out, ground_truth, atol=atol, rtol=rtol) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg + + +def get_search_space(small: bool = False): + """ + Returns the search space for tuning. + Args: + small (bool): If True, returns a small search space for testing. If False, returns the full search space. 
+ """ + configs = [] + if small: + num_stages_list = [2, 3] + block_m_list = [16, 32] + block_k_list = [64] + block_n_list = [32, 64] + num_warps_list = [4] + group_size_list = [1] + waves_per_eu_list = [3] + else: + num_stages_list = [2, 3, 4, 5] + block_m_list = [16, 32, 64, 128, 256] + block_k_list = [64, 128] + block_n_list = [32, 64, 128, 256] + num_warps_list = [4, 8] + group_size_list = [1, 16, 32, 64] + waves_per_eu_list = [2, 3, 4] + + for num_stages in num_stages_list: + for block_m in block_m_list: + for block_k in block_k_list: + for block_n in block_n_list: + for num_warps in num_warps_list: + for group_size in group_size_list: + for waves_per_eu in waves_per_eu_list: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, + "matrix_instr_nonkdim": 16, + } + ) + return configs + +def save_configs_to_json( + json_file_name: str, + save_path: str, + configs: list[dict], +) -> None: + os.makedirs(save_path, exist_ok=True) + config_file_path = os.path.join(save_path, json_file_name) + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") \ No newline at end of file diff --git a/aiter/ops/triton/tune/tune_moe_op.py b/aiter/ops/triton/tune/tune_moe_op.py new file mode 100644 index 0000000000..2f4bbb4cae --- /dev/null +++ b/aiter/ops/triton/tune/tune_moe_op.py @@ -0,0 +1,314 @@ +import torch +import time + +from tqdm import tqdm + +from aiter.ops.triton.tune.base import ( + tune_kernel, + get_search_space, + save_configs_to_json, +) +from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH +from aiter.ops.triton.utils.types import torch_to_triton_dtype +from aiter.ops.triton.utils._triton import arch_info +from aiter.ops.triton.moe_op import fused_moe + +from op_tests.triton_tests.moe.test_moe import ( + quantize_fp8, + quantize_int8, + torch_moe_ref, + torch_moe_align_block_size_ref, +) + + +def input_helper( + M: int, + N: int, + K: int, + top_k: int, + E: int, + routed_weight: bool, + dtype, + fp8_w8a8: bool, + int8_w8a16: bool, + config: dict, +): + assert not (fp8_w8a8 and int8_w8a16) + + a = torch.randn((M, K), dtype=dtype, device="cuda") + b = torch.rand((E, N, K), dtype=dtype, device="cuda") + a_scale = None + b_scale = None + + if fp8_w8a8: + b, _, b_scale = quantize_fp8(b, dim=(0,)) + + if int8_w8a16: + b, _, b_scale = quantize_int8(b, dim=(0,)) + + b_zp = False + + c = torch.zeros((M, top_k, N), dtype=dtype, device="cuda") + c_silu = torch.zeros((M * top_k, N // 2), dtype=dtype, device="cuda") + + values = torch.randn(M, E, dtype=dtype, device="cuda") + + softmax_vals = torch.softmax(values, dim=1) + topk_weights, topk_ids = torch.topk(softmax_vals, k=top_k, dim=1) + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + torch_moe_align_block_size_ref(topk_ids, config["BLOCK_SIZE_M"], E) + ) + + return ( + a, + b, + c, + c_silu, + b_zp, + a_scale, + b_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + config, + ) + + +def make_run_and_gt_fn_factory( + M, N, K, top_k, E, routed_weight, dtype, fp8_w8a8, int8_w8a16, int4_w4a16 +): + def make_run_and_gt(config): + ( + a, + b, + triton_out, + triton_out_silu, + b_zp, + a_scale, + b_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + _, + ) = input_helper( + M, + N, + K, + top_k, + E, + routed_weight=routed_weight, + 
dtype=dtype, + fp8_w8a8=fp8_w8a8, + int8_w8a16=int8_w8a16, + config=config, + ) + + def run(): + torch.cuda.empty_cache() + fused_moe( + a, + b, + triton_out, + a_scale, + b_scale, + b_zp, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + routed_weight, + top_k, + torch_to_triton_dtype[dtype], + fp8_w8a8, + int8_w8a16, + False, + config=config, + ) + return triton_out + + torch_out = torch.empty_like(triton_out) + ground_truth = torch_moe_ref( + a, + b, + torch_out, + a_scale, + b_scale, + None, + 0, + topk_ids, + topk_weights, + routed_weight, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + dtype, + fp8_w8a8, + int8_w8a16, + False, + ) + return run, ground_truth + + return make_run_and_gt + + +def tune_fused_moe( + M, + N, + K, + top_k, + E, + routed_weight, + dtype, + fp8_w8a8, + int8_w8a16, + int4_w4a16, + search_space, +): + make_run_and_gt = make_run_and_gt_fn_factory( + M, N, K, top_k, E, routed_weight, dtype, fp8_w8a8, int8_w8a16, int4_w4a16 + ) + + best_config = tune_kernel( + search_space=search_space, + make_run_and_gt_fn=make_run_and_gt, + ) + return best_config + + +def tune_and_save_configs( + batch_sizes, + N, + K, + topk, + E, + routed_weight, + dtype, + fp8_w8a8, + int8_w8a16, + int4_w4a16, + search_space, + save_path, + device_name, + tag, +): + start = time.time() + benchmark_results = [ + tune_fused_moe( + batch_size, + N, + K, + topk, + E, + routed_weight=routed_weight, + dtype=dtype, + fp8_w8a8=fp8_w8a8, + int8_w8a16=int8_w8a16, + int4_w4a16=int4_w4a16, + search_space=search_space, + ) + for batch_size in tqdm(batch_sizes) + ] + best_configs = { + ( + "small_M" + if batch_size <= 256 + else "medium_M" + if batch_size <= 2048 + else "large_M" + ): config + for batch_size, config in zip(batch_sizes, benchmark_results) + } + json_file_name = f"{device_name}-MOE-{tag}.json" + save_configs_to_json(json_file_name, save_path, best_configs) + end = time.time() + print(f"Tuning for {tag} took {end - start:.2f} seconds") + + +def main(): + dev = arch_info.get_arch() + + torch.cuda.init() + + batch_sizes = [256, 2048, 4096] # M + N = 384 + K = 768 + topk = 8 + E = 128 + search_space = get_search_space() + save_path = AITER_TRITON_CONFIGS_PATH + "/moe/" + + # Tune for default (float16) + # tune_and_save_configs( + # batch_sizes=batch_sizes, + # N=N, + # K=K, + # topk=topk, + # E=E, + # routed_weight=False, + # dtype=torch.float16, + # fp8_w8a8=False, + # int8_w8a16=False, + # int4_w4a16=False, + # search_space=search_space, + # save_path=save_path, + # device_name=dev, + # tag="DEFAULT", + # ) + # tune_and_save_configs( + # batch_sizes=batch_sizes, + # N=N, + # K=K, + # topk=topk, + # E=E, + # routed_weight=False, + # dtype=torch.float16, + # fp8_w8a8=True, + # int8_w8a16=False, + # int4_w4a16=False, + # search_space=search_space, + # save_path=save_path, + # device_name=dev, + # tag="FP8_W8A8", + # ) + tune_and_save_configs( + batch_sizes=batch_sizes, + N=N, + K=K, + topk=topk, + E=E, + routed_weight=False, + dtype=torch.float16, + fp8_w8a8=False, + int8_w8a16=True, + int4_w4a16=False, + search_space=search_space, + save_path=save_path, + device_name=dev, + tag="INT8_W8A16", + ) + tune_and_save_configs( + batch_sizes=batch_sizes, + N=N, + K=K, + topk=topk, + E=E, + routed_weight=False, + dtype=torch.float16, + fp8_w8a8=False, + int8_w8a16=False, + int4_w4a16=True, + search_space=search_space, + save_path=save_path, + device_name=dev, + tag="INT4_W4A16", + ) + + +if __name__ == "__main__": + main() diff --git 
a/aiter/ops/triton/tune_a16w16.py b/aiter/ops/triton/tune_a16w16.py new file mode 100644 index 0000000000..08bb254c62 --- /dev/null +++ b/aiter/ops/triton/tune_a16w16.py @@ -0,0 +1,262 @@ +import argparse +import json +import os +import time +import triton +from datetime import datetime + +import torch +from tqdm import tqdm + +from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 + +from op_tests.triton_tests.utils.types import str_to_torch_dtype + +from utils.core import AITER_TRITON_CONFIGS_PATH # type: ignore + + +def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True, bias=False): + if isinstance(dtype, str): + dtype = str_to_torch_dtype[dtype] + + # TN is default layout + if layout[0] == "T": + print(M, K) + print(dtype) + x = torch.randn((M, K), dtype=dtype, device="cuda") + else: + x = torch.randn((K, M), dtype=dtype, device="cuda").T + + if layout[1] == "T": + weight = torch.randn((K, N), dtype=dtype, device="cuda").T + else: + weight = torch.randn((N, K), dtype=dtype, device="cuda") + + bias_tensor = None + if bias: + bias_tensor = torch.empty((N), dtype=dtype, device="cuda") + + y = None + if output: + y = torch.empty((M, N), dtype=dtype, device="cuda") + out_dtype = (None,) + else: + out_dtype = dtype + + return x, weight, bias_tensor, out_dtype, y + + +def get_configs_compute_bound(): + configs = [] + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [64, 128]: + for block_n in [32, 64, 128, 256]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + for waves_per_eu in [2, 3, 4]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, + "matrix_instr_nonkdim": 16, + "NUM_KSPLIT": 1, + "kpack": 1, + "cache_modifier": None, + } + ) + return configs + + +def get_weight_shapes(): + total = [ + (1024, 1024), + (4096, 1024), + (1024, 2048), + (6144, 1024), + (1024, 3072), + ] + + weight_shapes = [] + for t in total: + weight_shapes.append(t) + + return weight_shapes + + +def benchmark_config(x, w, bias, dtype, y, config, activation, num_iters=10): + torch_out = torch.nn.functional.linear(x, w, bias=bias) # Ground truth + + def run(): + return gemm_a16w16( + x, w, bias=bias, dtype=dtype, y=y, config=config, activation=None + ) + + torch.cuda.synchronize() + # JIT complication & warmup + for _ in range(5): + run() + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + triton_out = run() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg + + +def tune(M, N, K, dtype, search_space): + x, w, bias, out_dtype, y = generate_gemm_a16w16_inputs( + M, N, K, dtype, output=True, bias=True + ) + + best_config = None + best_time = float("inf") + for config in tqdm(search_space): + config["SPLITK_BLOCK_SIZE"]= triton.cdiv(K, config["NUM_KSPLIT"]) + try: + kernel_time = benchmark_config( + x=x, + w=w, + bias=bias, + dtype=out_dtype, + y=y, + config=config, + activation=None, + num_iters=10, + ) + except triton.runtime.autotuner.OutOfResources: + print("OutOfResources encountered during tuning.") + # Some 
configurations may be invalid and fail to compile. + continue + except AssertionError: + print("AssertionError encountered during tuning.") + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={M}") + assert best_config is not None + return best_config + + +def save_configs( + N, + K, + configs, + save_path, +) -> None: + os.makedirs(save_path, exist_ok=True) + device_name = "R9700" # TODO: Hardcoded, make it dynamic + json_file_name = f"{device_name}-GEMM-A16W16-N={N}-K={K}.json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing best config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def tune_on_gpu(args_dict): + """Run tuning on a specific GPU.""" + gpu_id = args_dict["gpu_id"] + batch_sizes = args_dict["batch_sizes"] + weight_shapes = args_dict["weight_shapes"] + args = args_dict["args"] + + torch.cuda.set_device(gpu_id) + print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + + dtype = str_to_torch_dtype[args.dtype] + save_path = AITER_TRITON_CONFIGS_PATH + "/gemm/" + + search_space = get_configs_compute_bound() + + start = time.time() + for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"): + N, K = shape[0], shape[1] + print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`") + benchmark_results = [ + tune( + batch_size, + N, + K, + dtype, + search_space, + ) + for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") + ] + best_configs = { + ("any" if i == len(batch_sizes) - 1 else f"M_LEQ_{M}"): config + for i, (M, config) in enumerate(zip(batch_sizes, benchmark_results)) + } + save_configs(N, K, best_configs, save_path) + + end = time.time() + print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") + + +def main(args): + print(args) + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + raise RuntimeError("No GPU available for tuning") + print(f"Found {num_gpus} GPUs for tuning") + + torch.cuda.init() + + batch_sizes = [ + 64, + 128, + 256, + 512, + 2048, + 4096, + ] + + weight_shapes = get_weight_shapes() + + # Run tuning sequentially on GPU 0 + tune_on_gpu( + { + "gpu_id": 0, + "batch_sizes": batch_sizes, + "weight_shapes": weight_shapes, + "args": args, + } + ) + + print("Tuning completed") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--dtype", + type=str, + choices=["float32", "float16", "bfloat16", "half"], + default="bfloat16", + ) + args = parser.parse_args() + + main(args) diff --git a/aiter/ops/triton/tune_a16w16_atomic.py b/aiter/ops/triton/tune_a16w16_atomic.py new file mode 100644 index 0000000000..0b57b6e596 --- /dev/null +++ b/aiter/ops/triton/tune_a16w16_atomic.py @@ -0,0 +1,421 @@ +import argparse +import json +import multiprocessing as mp +import os +import time +import triton +from datetime import datetime +from typing import List, Dict, Union, Tuple, Optional +import torch +from tqdm import tqdm + + +from gemm_a16w16_atomic import gemm_a16w16_atomic # type: ignore +from utils.core import AITER_TRITON_CONFIGS_PATH # type: ignore + +mp.set_start_method("spawn", force=True) + + +DTYPE_MAP = { + "float32": torch.float32, + "float16": torch.float16, + "half": torch.half, + "bfloat16": torch.bfloat16, +} + + +def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", 
output=True, bias=False): + if isinstance(dtype, str): + dtype = DTYPE_MAP[dtype] + + # TN is default layout + if layout[0] == "T": + x = torch.randn((M, K), dtype=dtype, device="cuda") + else: + x = torch.randn((K, M), dtype=dtype, device="cuda").T + + if layout[1] == "T": + weight = torch.randn((K, N), dtype=dtype, device="cuda").T + else: + weight = torch.randn((N, K), dtype=dtype, device="cuda") + + bias_tensor = None + if bias: + bias_tensor = torch.empty((N), dtype=dtype, device="cuda") + + y = None + if output: + y = torch.empty((M, N), dtype=dtype, device="cuda") + out_dtype = (None,) + else: + out_dtype = dtype + + return x, weight, bias_tensor, out_dtype, y + + +def get_configs_compute_bound() -> List[Dict[str, int | str]]: + configs = [] + # Optimize parameters based on kernel analysis + # Only test parameters that are actually used in the kernel + # Based on the generated configs, we'll use the optimal values + for num_stages in [2, 3, 4, 5]: # Only 1 stage is used + for block_m in [16, 32, 64, 128, 256]: # Fixed to 64 as in current configs + for block_n in [32, 64, 128, 256]: # Fixed to 32 as in current configs + for block_k in [64, 128]: # Fixed to 64 as in current configs + for group_size in [1, 8, 16, 32, 64]: + for num_warps in [4, 8]: + for num_ksplit in [ + 1, + ]: # Only test 1 since higher values may cause issues + for waves_per_eu in [ + 2, + 4, + 8, + ]: # Fixed to 3 as in current configs + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, + "matrix_instr_nonkdim": 16, # Fixed value used in kernel + "cache_modifier": "", # Empty string for atomic kernel + "NUM_KSPLIT": num_ksplit, # Atomic kernel specific + # "kpack": 1, # Fixed value used in kernel + # "SPLITK_BLOCK_SIZE": 1, # Will be set dynamically + } + ) + return configs + + +def get_weight_shapes(tp_size: int) -> List[Tuple[int, int]]: + total = [ + # (1024, 1024), + # (4096, 1024), + # (1024, 2048), + (6144, 1024), + (1024, 3072), + ] + + weight_shapes: List[Tuple[int, int]] = [] + for t in total: + weight_shapes.append(t) + + return weight_shapes + + +def benchmark_config( + x: torch.Tensor, + w: torch.Tensor, + dtype: torch.dtype, + config: Dict[str, Union[str, int]], + y: Optional[torch.Tensor] = None, + num_iters=10, +) -> float: + """ + Benchmark the performance of a GEMM operation with a specific configuration. + + This function measures the execution time of the gemm_a16w16_atomic kernel by running + it multiple times with synchronization points to ensure accurate timing. It performs + warmup runs before the actual benchmarking to account for JIT compilation overhead. + + Args: + x (torch.Tensor): Input tensor of shape (M, K) representing the first matrix operand. + w (torch.Tensor): Weight tensor of shape (N, K) representing the second matrix operand. + dtype (torch.dtype): Data type for the computation (e.g., torch.bfloat16). + config (Dict[str, Union[str, int]]): Configuration dictionary containing kernel + parameters such as block sizes, number of warps, etc. + y (Optional[torch.Tensor], optional): Output tensor to store the result. If None, + a new tensor will be allocated. Defaults to None. + num_iters (int, optional): Number of benchmark iterations to run. Defaults to 10. + + Returns: + float: Average execution time in microseconds (us) per iteration. 
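+
+    Example (illustrative only; the config values below are placeholders, not tuned
+    results):
+
+        cfg = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128,
+               "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2,
+               "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": "",
+               "NUM_KSPLIT": 1, "SPLITK_BLOCK_SIZE": 1024}
+        x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
+        w = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
+        avg_us = benchmark_config(x, w, dtype=torch.float32, config=cfg)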
+
+    Note:
+        The function performs 5 warmup iterations before benchmarking to account for
+        JIT compilation and GPU warmup effects. The timing is measured using CUDA events
+        for accurate GPU kernel timing.
+    """
+    torch_out = torch.nn.functional.linear(x, w, bias=None)
+
+    # run the kernel
+    def run():
+        return gemm_a16w16_atomic(x, w, dtype, y, config)
+
+    torch.cuda.synchronize()
+    # JIT compilation & warmup
+    for _ in range(5):
+        run()
+    torch.cuda.synchronize()
+
+    start_event = torch.Event(enable_timing=True)
+    end_event = torch.Event(enable_timing=True)
+
+    latencies: list[float] = []
+    for i in range(num_iters):
+        torch.cuda.synchronize()
+        start_event.record()
+        triton_out_raw = run()
+        # Convert to the same dtype as the reference for comparison
+        triton_out = triton_out_raw.to(torch_out.dtype)
+        end_event.record()
+        end_event.synchronize()
+        latencies.append(start_event.elapsed_time(end_event))
+        torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1)
+    avg = sum(latencies) / (num_iters * 10) * 1000  # us
+    return avg
+
+
+def tune(
+    M: int, N: int, K: int, search_space: List[Dict[str, int | str]], input_type: str
+):
+    if input_type == "bfloat16":
+        # Use the same input generation as the test file
+        x, weight, _bias, _out_dtype, y = generate_gemm_a16w16_inputs(
+            M, N, K, torch.bfloat16, output=True
+        )
+    else:
+        raise RuntimeError(
+            "Currently, only bfloat16 inputs are supported when tuning the a16w16 atomic kernel."
+        )
+
+    best_config = None
+    best_time = float("inf")
+    for config in tqdm(search_space):
+        try:
+            kernel_time = benchmark_config(
+                x=x,
+                w=weight,
+                dtype=torch.float32,
+                y=None,
+                config=config,
+                num_iters=10,
+            )
+        except triton.runtime.autotuner.OutOfResources as e:
+            # print("config failed!", config)
+            # print("error: ", e)
+            # Some configurations may be invalid and fail to compile.
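+            # Note: configurations that hang (rather than raise) are not caught here;
+            # tune_a8w8_blockscale.py below wraps the launch in a SIGALRM-based
+            # timeout (run_with_timeout) for that case.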
+ continue + except AssertionError as e: + print("Assert error:", e) + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={M}") + assert best_config is not None + return best_config + + +def save_configs( + N, + K, + configs, + save_path, +) -> None: + os.makedirs(save_path, exist_ok=True) + device_name = "R9700" # TODO: Hardcoded, make it dynamic + json_file_name = f"{device_name}-GEMM-A16W16-ATOMIC-N={N}-K={K}.json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing best config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def tune_on_gpu( + gpu_id: int, + batch_sizes: List[int], + weight_shapes: List[Tuple[int, int]], + input_type: str, +) -> None: + """Run tuning on a specific GPU.""" + torch.cuda.set_device(gpu_id) + print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + + save_path = AITER_TRITON_CONFIGS_PATH + "/gemm/" + + search_space = get_configs_compute_bound() + + start = time.time() + + # Collect all configs to determine best overall config + all_configs: List[Dict[str, Dict[str, int | str]]] = [] + + for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"): + N, K = shape[0], shape[1] + print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`") + benchmark_results = [ + tune( + batch_size, + N, + K, + search_space, + input_type, + ) + for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") + ] + best_configs: Dict[str, Dict[str, int | str]] = {} + # Create configs for different M size categories as expected by the atomic kernel + for i, (M, config) in enumerate(zip(batch_sizes, benchmark_results)): + if i == len(batch_sizes) - 1: + best_configs["any"] = config + elif M < 32: + best_configs["small"] = config + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32: + best_configs["medium_M32"] = config + elif BLK_M == 64: + best_configs["medium_M64"] = config + elif BLK_M == 128: + best_configs["medium_M128"] = config + elif M <= 256: + best_configs["large"] = config + else: + best_configs["xlarge"] = config + # Store configs for later analysis + all_configs.append(best_configs) + save_configs(N, K, best_configs, save_path) + + # Create a default config file (without N,K parameters) by selecting most common config + default_config = create_default_config(all_configs) + save_default_config(default_config, save_path) + + end = time.time() + print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") + + +def create_default_config( + all_configs: List[Dict[str, Dict[str, Union[int, str]]]], +) -> Dict[str, Dict[str, Union[int, str]]]: + """Create a default config by selecting the most common config across all shapes.""" + from collections import Counter + + # Collect all configs for each category + category_configs = { + "small": [], + "medium_M32": [], + "medium_M64": [], + "medium_M128": [], + "large": [], + "xlarge": [], + "any": [], + } + + for config in all_configs: + for category, params in config.items(): + if category in category_configs: + # Convert config to a hashable tuple for counting + config_tuple = tuple(sorted(params.items())) + category_configs[category].append(config_tuple) + + # Find the most common config for each category + default_config: Dict[str, Dict[str, Union[int, str]]] = {} + for category, configs in category_configs.items(): + if configs: + most_common = 
Counter(configs).most_common(1)[0][0] + default_config[category] = dict(most_common) + + return default_config + + +def save_default_config( + config: Dict[str, Dict[str, Union[int, str]]], save_path: str +) -> None: + """Save the default config file (without N,K parameters).""" + os.makedirs(save_path, exist_ok=True) + device_name = "R9700" # TODO: Hardcoded, make it dynamic + json_file_name = f"{device_name}-GEMM-A16W16-ATOMIC.json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing default config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(config, f, indent=4) + f.write("\n") + + +def distribute_batch_sizes(batch_sizes: List[int], num_gpus: int) -> List[List[int]]: + """Distribute batch sizes across available GPUs.""" + batches_per_gpu: List[List[int]] = [] + for i in range(num_gpus): + start_idx = i * len(batch_sizes) // num_gpus + end_idx = (i + 1) * len(batch_sizes) // num_gpus + batches_per_gpu.append(batch_sizes[start_idx:end_idx]) + return batches_per_gpu + + +def main(args): + print(args) + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + raise RuntimeError("No GPU available for tuning") + print(f"Found {num_gpus} GPUs for parallel tuning") + + torch.cuda.init() + + if args.batch_size is None: + batch_sizes = [ + 16, # For small config + 32, # For medium_M32 config + 64, # For medium_M64 config + 128, # For medium_M128 config + 256, # For large config + 512, # For large config + 2048, # For xlarge config + 4096, # For xlarge config + ] + else: + batch_sizes = [args.batch_size] + num_gpus = 1 # If only one batch size, use only one GPU + + weight_shapes = get_weight_shapes(args.tp_size) + + batches_per_gpu = distribute_batch_sizes(batch_sizes, 1) + + # Prepare arguments for each GPU process + process_args = [] + for gpu_id in range(1): + process_args.append( + ( + gpu_id, + batches_per_gpu[gpu_id], + weight_shapes, # Each GPU processes all weight shapes + args.input_type, + ) + ) + + ctx = mp.get_context("spawn") + with ctx.Pool(1) as pool: + pool.starmap(tune_on_gpu, process_args) + + print("Multi-GPU tuning completed") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument("--tp-size", "-tp", type=int, default=1) + parser.add_argument( + "--input-type", type=str, choices=["bfloat16"], default="bfloat16" + ) + parser.add_argument( + "--out-dtype", + type=str, + choices=["float32", "float16", "bfloat16", "half"], + default="bfloat16", + ) + parser.add_argument("--batch-size", type=int, required=False) + args = parser.parse_args() + + main(args) +# diff --git a/aiter/ops/triton/tune_a8w8_blockscale.py b/aiter/ops/triton/tune_a8w8_blockscale.py new file mode 100644 index 0000000000..8106fd8199 --- /dev/null +++ b/aiter/ops/triton/tune_a8w8_blockscale.py @@ -0,0 +1,1529 @@ +import argparse +import json +import logging +import multiprocessing as mp +import os +import signal +import sys +import time +import triton +from datetime import datetime +from typing import List, Dict, Union, Tuple, Optional, Any +import torch +import pytz +from rich.console import Console +from rich.table import Table +from rich.panel import Panel +from rich.columns import Columns +from rich.live import Live +from rich.layout import Layout +from rich.progress import Progress, BarColumn, TextColumn + + +from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale # type: ignore +from aiter.ops.triton.utils.core import 
AITER_TRITON_CONFIGS_PATH # type: ignore +from aiter.ops.triton.utils.types import get_fp8_dtypes + +mp.set_start_method("spawn", force=True) + +# Get FP8 data types +e5m2_type, e4m3_type = get_fp8_dtypes() + + +class TimeoutError(Exception): + """Custom exception for timeout errors.""" + + pass + + +# Global variables to track bad configurations and current state +BAD_CONFIGS = { + "timeouts": [], + "out_of_resources": [], + "assert_errors": [], + "other_errors": [], +} + +# Global variables to track current state for SIGINT handling +CURRENT_CONFIG: Dict[str, Any] = { + "M": None, + "N": None, + "K": None, + "config": None, + "config_index": None, + "total_configs": None, + "batch_size": None, + "weight_shape_index": None, + "total_weight_shapes": None, + "gpu_id": None, +} + +INTERRUPTED = False + + +class GMT8Formatter(logging.Formatter): + """Custom formatter that uses GMT+8 timezone.""" + + def __init__(self, fmt=None, datefmt=None): + super().__init__(fmt, datefmt) + self.gmt8 = pytz.timezone("Asia/Shanghai") # GMT+8 + + def formatTime(self, record, datefmt=None): + # Convert timestamp to GMT+8 + dt = datetime.fromtimestamp(record.created, tz=self.gmt8) + if datefmt: + return dt.strftime(datefmt) + else: + return dt.strftime("%Y-%m-%d %H:%M:%S") + + def format(self, record): + # Add timezone info to the formatted message + original = super().format(record) + return original.replace("[GMT+8]", "") # Remove any existing timezone tag + + +def get_timestamped_filename(base_name: str, extension: str = ".log") -> str: + """Generate a filename with timestamp in GMT+8 timezone.""" + gmt8 = pytz.timezone("Asia/Shanghai") + timestamp = datetime.now(gmt8).strftime("%Y%m%d_%H%M%S") + return f"{base_name}_{timestamp}{extension}" + + +def sigint_handler(signum, frame): + """Handle SIGINT (Ctrl+C) gracefully by logging the current configuration.""" + global INTERRUPTED + global CURRENT_CONFIG + INTERRUPTED = True + + print("\n" + "=" * 80) + print("๐Ÿ›‘ TUNING INTERRUPTED BY USER (Ctrl+C)") + print("=" * 80) + + if CURRENT_CONFIG["M"] is not None: + print("๐Ÿ“ Last configuration being processed:") + print(f" ๐ŸŽฏ GPU: {CURRENT_CONFIG['gpu_id']}") + print( + f" ๐Ÿ“Š Matrix: M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}" + ) + print(f" ๐Ÿ“ฆ Batch Size: {CURRENT_CONFIG['batch_size']}") + print( + f" ๐Ÿ”„ Progress: Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}" + ) + print( + f" ๐Ÿ—๏ธ Weight Shape: {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}" + ) + + if CURRENT_CONFIG["config"]: + config = CURRENT_CONFIG["config"] + print(" โš™๏ธ Parameters:") + print(f" BLOCK_SIZE_M: {config.get('BLOCK_SIZE_M', 'N/A')}") + print(f" BLOCK_SIZE_N: {config.get('BLOCK_SIZE_N', 'N/A')}") + print(f" BLOCK_SIZE_K: {config.get('BLOCK_SIZE_K', 'N/A')}") + print(f" num_warps: {config.get('num_warps', 'N/A')}") + print(f" num_stages: {config.get('num_stages', 'N/A')}") + print(f" NUM_KSPLIT: {config.get('NUM_KSPLIT', 'N/A')}") + print(f" waves_per_eu: {config.get('waves_per_eu', 'N/A')}") + print(f" kpack: {config.get('kpack', 'N/A')}") + print(f" cache_modifier: {config.get('cache_modifier', 'N/A')}") + print(f" GROUP_SIZE_M: {config.get('GROUP_SIZE_M', 'N/A')}") + print(f" GROUP_K: {config.get('GROUP_K', 'N/A')}") + + # Show config in same format as console output for consistency + config_num = CURRENT_CONFIG["config_index"] + 1 + console_format = f" ๐Ÿ’ป Config {config_num} (INTERRUPTED): | BM:{config.get('BLOCK_SIZE_M')} 
BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} kpack:{config.get('kpack')} cache:{config.get('cache_modifier')} GSM:{config.get('GROUP_SIZE_M')}" + print(console_format) + + # Log the interruption to the file if logger is available + try: + logger = logging.getLogger("gemm_a8w8_blockscale_tuning") + if logger.handlers: + # Use GMT+8 timestamp for consistency + gmt8 = pytz.timezone("Asia/Shanghai") + + # Create detailed log entry + detailed_log_entry = { + "timestamp": datetime.now(gmt8).isoformat(), + "event_type": "user_interrupt", + "gpu_id": CURRENT_CONFIG.get("gpu_id", "N/A"), + "batch_size": CURRENT_CONFIG["batch_size"], + "matrix_dims": f"M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}", + "config": CURRENT_CONFIG["config"], + "progress": f"Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}", + "weight_shape_progress": f"Shape {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}", + } + + # Log detailed interruption info + logger.info(f"=== USER INTERRUPT ===") + logger.info( + f"Interrupted while testing: Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}" + ) + logger.info(f"GPU: {CURRENT_CONFIG.get('gpu_id', 'N/A')}") + logger.info( + f"Matrix: M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}" + ) + logger.info( + f"Weight Shape Progress: {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}" + ) + + # Log config details in same format as console output for consistency + if CURRENT_CONFIG["config"]: + config = CURRENT_CONFIG["config"] + config_num = CURRENT_CONFIG["config_index"] + 1 + config_str = f"Config {config_num} (INTERRUPTED): | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} kpack:{config.get('kpack')} cache:{config.get('cache_modifier')} GROUP_SIZE_M:{config.get('GROUP_SIZE_M')} GROUP_K:{config.get('GROUP_K')}" + logger.info(f"CONFIG_DETAILS: {config_str}") + + logger.info(f"DETAILED_ENTRY: {detailed_log_entry}") + logger.info(f"=== END USER INTERRUPT ===") + + # Force flush to write immediately + for handler in logger.handlers: + if hasattr(handler, "stream"): + handler.stream.flush() + + print(" ๐Ÿ“ Interruption logged to tuning log file") + except Exception as e: + print(f" โš ๏ธ Could not log interruption: {e}") + + print("\n๐Ÿ’ก You can use this information to:") + print(" โ€ข Skip this problematic configuration in future runs") + print(" โ€ข Analyze why this specific config might be causing issues") + print(" โ€ข Adjust the search space to avoid similar parameter combinations") + print("=" * 80) + + # Exit gracefully + import sys + + sys.exit(1) + + +def setup_logger(log_file_path: str, mode: str = "a") -> logging.Logger: + """ + Setup logger for recording bad configurations during tuning. 
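+
+    Log records use the GMT+8 formatter and are flushed to disk immediately, so bad
+    configurations are visible while a long tuning run is still in progress.
+    Illustrative usage (the log file name here is just an example):
+
+        logger = setup_logger("tune_a8w8_blockscale_bad_configs_gpu0.log", mode="a")
+        logger.info("tuning session started")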
+ + Args: + log_file_path: Path to the log file + mode: File write mode - 'a' to append to existing logs, 'w' to overwrite + + Returns: + Configured logger instance + """ + logger = logging.getLogger("gemm_a8w8_blockscale_tuning") + logger.setLevel(logging.INFO) + + # Clear existing handlers + logger.handlers.clear() + + # Create file handler with live writing (immediate flush) + # Default to append mode to preserve logs across resume sessions + file_handler = logging.FileHandler(log_file_path, mode=mode) + file_handler.setLevel(logging.INFO) + + # Create custom formatter that flushes immediately + file_handler.flush = lambda: file_handler.stream.flush() # type: ignore + + # Create console handler + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.WARNING) + + # Create GMT+8 formatter + formatter = GMT8Formatter( + "%(asctime)s [GMT+8] - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) + file_handler.setFormatter(formatter) + console_handler.setFormatter(formatter) + + # Add handlers to logger + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger + + +def log_bad_config( + logger: logging.Logger, + error_type: str, + M: int, + N: int, + K: int, + config: Dict[str, Union[str, int]], + error_msg: str = "", +): + """ + Log a bad configuration that failed during tuning. + + Args: + logger: Logger instance + error_type: Type of error ('timeout', 'out_of_resources', 'assert_error', 'other_error') + M, N, K: Matrix dimensions + config: Configuration that failed + error_msg: Additional error message + """ + # Use GMT+8 timestamp for consistency + gmt8 = pytz.timezone("Asia/Shanghai") + log_entry = { + "timestamp": datetime.now(gmt8).isoformat(), + "error_type": error_type, + "batch_size": M, + "matrix_dims": f"M={M} N={N} K={K}", + "config": config, + "error_msg": str(error_msg), + } + + # Log to file + logger.info(f"BAD_CONFIG_{error_type.upper()}: {log_entry}") + + # Force flush to write immediately + for handler in logger.handlers: + if hasattr(handler, "stream"): + handler.stream.flush() + + # Store in global list for summary + if error_type == "timeout": + BAD_CONFIGS["timeouts"].append(log_entry) + elif error_type == "out_of_resources": + BAD_CONFIGS["out_of_resources"].append(log_entry) + elif error_type == "assert_error": + BAD_CONFIGS["assert_errors"].append(log_entry) + else: + BAD_CONFIGS["other_errors"].append(log_entry) + + +def log_bad_config_summary(logger: logging.Logger, total_configs_tested: int): + """ + Log a summary of all bad configurations encountered during tuning. 
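+
+    The reported success rate is (tested - failed) / tested * 100; for example, 1000
+    tested configurations with 50 recorded failures give a success rate of 95.0%.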
+ + Args: + logger: Logger instance + total_configs_tested: Total number of configurations tested + """ + total_bad = ( + len(BAD_CONFIGS["timeouts"]) + + len(BAD_CONFIGS["out_of_resources"]) + + len(BAD_CONFIGS["assert_errors"]) + + len(BAD_CONFIGS["other_errors"]) + ) + success_rate = ( + ((total_configs_tested - total_bad) / total_configs_tested * 100) + if total_configs_tested > 0 + else 0 + ) + + logger.info("=" * 80) + logger.info("BAD CONFIGURATION SUMMARY") + logger.info("=" * 80) + logger.info(f"Total configurations tested: {total_configs_tested}") + logger.info(f"Successful configurations: {total_configs_tested - total_bad}") + logger.info(f"Failed configurations: {total_bad}") + logger.info(f"Success rate: {success_rate:.2f}%") + logger.info("") + + logger.info(f"Timeouts: {len(BAD_CONFIGS['timeouts'])}") + logger.info(f"Out of Resources: {len(BAD_CONFIGS['out_of_resources'])}") + logger.info(f"Assert Errors: {len(BAD_CONFIGS['assert_errors'])}") + logger.info(f"Other Errors: {len(BAD_CONFIGS['other_errors'])}") + logger.info("") + + if BAD_CONFIGS["timeouts"]: + logger.info("TIMEOUT CONFIGS (most problematic):") + for entry in BAD_CONFIGS["timeouts"]: + config = entry["config"] + logger.info( + f" - Batch {entry['batch_size']} | {entry['matrix_dims']} | BM:{config.get('BLOCK_SIZE_M', 'N/A')}, BN:{config.get('BLOCK_SIZE_N', 'N/A')}, BK:{config.get('BLOCK_SIZE_K', 'N/A')}, W:{config.get('num_warps', 'N/A')}, S:{config.get('num_stages', 'N/A')}, KS:{config.get('NUM_KSPLIT', 'N/A')}" + ) + + if BAD_CONFIGS["out_of_resources"]: + logger.info("OUT OF RESOURCE CONFIGS:") + for entry in BAD_CONFIGS["out_of_resources"]: + config = entry["config"] + logger.info( + f" - Batch {entry['batch_size']} | {entry['matrix_dims']} | BM:{config.get('BLOCK_SIZE_M', 'N/A')}, BN:{config.get('BLOCK_SIZE_N', 'N/A')}, BK:{config.get('BLOCK_SIZE_K', 'N/A')}, W:{config.get('num_warps', 'N/A')}, S:{config.get('num_stages', 'N/A')}, KS:{config.get('NUM_KSPLIT', 'N/A')}" + ) + + logger.info("=" * 80) + + # Print summary to console as well + print("\n๐Ÿ“Š Bad Configuration Summary:") + print(f" Total tested: {total_configs_tested}") + print(f" โœ… Successful: {total_configs_tested - total_bad}") + print( + f" โŒ Failed: {total_bad} ({len(BAD_CONFIGS['timeouts'])} timeouts, {len(BAD_CONFIGS['out_of_resources'])} OOM, {len(BAD_CONFIGS['assert_errors'])} assert, {len(BAD_CONFIGS['other_errors'])} other)" + ) + print(f" ๐Ÿ“ˆ Success rate: {success_rate:.1f}%") + + +def timeout_handler(signum, frame): + """Signal handler for timeout.""" + raise TimeoutError("Kernel execution timed out") + + +def run_with_timeout(func, timeout_seconds=3, *args, **kwargs): + """ + Run a function with a timeout limit. 
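+
+    The timeout is implemented with signal.SIGALRM, so it only works on Unix and only
+    when called from the main thread. Illustrative usage:
+
+        result = run_with_timeout(run, timeout_seconds=3)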
+ + Args: + func: Function to execute + timeout_seconds: Timeout in seconds (default: 3) + *args: Arguments to pass to the function + **kwargs: Keyword arguments to pass to the function + + Returns: + Result of the function call + + Raises: + TimeoutError: If function execution exceeds timeout + """ + # Set the signal handler + old_handler = signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(timeout_seconds) + + try: + result = func(*args, **kwargs) + return result + finally: + # Cancel the alarm and restore the old handler + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + +def generate_gemm_a8w8_blockscale_inputs( + M: int, + N: int, + K: int, + block_shape_n: int, + block_shape_k: int, + dtype=torch.bfloat16, + layout: str = "TN", + output=False, +): + """ + The GEMM kernel expects: + - x: (M, K) -> row-major format + - w: (N, K) -> column-major format + """ + scale_n = (N + block_shape_n - 1) // block_shape_n + scale_k = (K + block_shape_k - 1) // block_shape_k + + if layout[0] == "T": + x = (torch.rand((M, K), dtype=torch.float16, device="cuda") / 10).to(e4m3_type) + else: + x = ( + (torch.rand((K, M), dtype=torch.float16, device="cuda") / 10) + .to(e4m3_type) + .T + ) + + if layout[1] == "N": + weight = (torch.rand((N, K), dtype=torch.float16, device="cuda") / 10).to( + e4m3_type + ) + else: + weight = ( + (torch.rand((K, N), dtype=torch.float16, device="cuda") / 10) + .to(e4m3_type) + .T + ) + + x_scale = torch.rand([M, scale_k], dtype=torch.float32, device="cuda") + w_scale = torch.rand([scale_n, scale_k], dtype=torch.float32, device="cuda") + + y = None + if output: + y = torch.empty((M, N), dtype=dtype, device="cuda").cuda() + + return x, weight, x_scale, w_scale, y + + +def get_configs_compute_bound() -> List[Dict[str, int | str]]: + """ + Generate configuration space for tuning the gemm_a8w8_blockscale kernel. + Based on the sample config file, we'll tune around those values. + Note: With (128, 128) quantization blocks, GROUP_K will be computed as 128, + so we must only use BLOCK_SIZE_K = 128 to satisfy the constraint GROUP_K == BLOCK_SIZE_K. 
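+
+    Worked example for the K values tuned here: K = 2048 gives scale_k = 16, so
+    GROUP_K = next_power_of_2(2048 / 16) = 128 and BLOCK_SIZE_K must also be 128;
+    the same holds for K = 1024 (scale_k = 8) and K = 3072 (scale_k = 24).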
+ """ + configs = [] + + # For blockscale kernel with (128, 128) quantization blocks: + # - scale_k = ceil(K / 128) = 8 for K=1024 + # - GROUP_K = next_power_of_2(K / scale_k) = next_power_of_2(128) = 128 + # - BLOCK_SIZE_K must equal GROUP_K, so BLOCK_SIZE_K must be 128 + block_k = 128 # Fixed to match computed GROUP_K + + # Explore optimized parameter space for blockscale kernel + for num_stages in [1, 2, 3, 4]: + for block_m in [32, 64, 128]: # Removed 256 (causes slowdowns) + for block_n in [32, 64, 128]: # Removed 256 (causes slowdowns) + for group_size_m in [1, 8, 16]: + for num_warps in [2, 4, 8]: + for num_ksplit in [ + 1, + 2, + 4, + ]: # Key parameter for K-splitting + for waves_per_eu in [2, 4, 8]: + for kpack in [2]: + for cache_modifier in ["", ".cg"]: + # Note: GROUP_K and GROUP_N are computed by the kernel function + # from the scale tensor shapes, not hardcoded in config + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, # Must be 128 to match computed GROUP_K + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + "NUM_KSPLIT": num_ksplit, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + "matrix_instr_nonkdim": 16, # Fixed value from atomic kernel + "cache_modifier": cache_modifier, + } + ) + return configs + + +def get_weight_shapes(tp_size: int = 1) -> List[Tuple[int, int]]: + """Get weight shapes to test during tuning.""" + total = [ + (1024, 1024), + (4096, 1024), + (1024, 2048), + (6144, 1024), + (1024, 3072), + ] + + weight_shapes: List[Tuple[int, int]] = [] + for t in total: + weight_shapes.append(t) + + return weight_shapes + + +def run_torch_reference( + x: torch.Tensor, + w: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + block_shape: Tuple[int, int], + dtype=torch.bfloat16, +): + """ + Run reference implementation using PyTorch for blockscale kernel. + This is used for correctness verification. 
+ """ + block_shape_n, block_shape_k = block_shape + m, k = x.shape + n = w.shape[0] + + # Expand scales to match the block sizes (same as original) + x_scale_expanded = x_scale.repeat_interleave(block_shape_k, dim=1) + x_dequant = x.to(x_scale_expanded.dtype) * x_scale_expanded[:m, :k] + + # Expand weight scales: first repeat along N dimension, then along K dimension + w_scale_expanded = w_scale.repeat_interleave(block_shape_n, dim=0) + w_scale_expanded = w_scale_expanded.repeat_interleave(block_shape_k, dim=1) + w_scale_expanded = w_scale_expanded[:n, :k] + weight_dequant = w.to(w_scale_expanded.dtype) * w_scale_expanded + + out = torch.nn.functional.linear( + x_dequant.to(torch.float32), weight_dequant.to(torch.float32) + ) + + return out.to(dtype) + + +# Global variable to store console output +console_output = [] + + +def create_live_display( + M: int, + N: int, + K: int, + current_config: Dict[str, Union[str, int]], + best_config: Dict[str, Union[str, int]], + best_time: float, + config_index: int, + total_configs: int, + console_messages: Optional[List[str]] = None, +) -> Layout: + """Create a live display layout with current and best configuration tables.""" + + layout = Layout() + + # Use global console_output if none provided + if console_messages is None: + global console_output + console_messages = console_output + + # Create progress bar + progress = Progress( + TextColumn("[bold blue]{task.description}"), + BarColumn(bar_width=40), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TextColumn("({task.completed}/{task.total})"), + ) + task = progress.add_task( + f"๐Ÿ”ง Tuning M={M} N={N} K={K} | Batch Size={M}", + total=total_configs, + completed=config_index, + ) + + # Create status information + status_text = "" + if best_time != float("inf"): + status_text = f"๐Ÿ† Best Performance: {best_time:.1f}ฮผs" + else: + status_text = "๐Ÿ” Searching for best configuration..." + + # Create console area + console_text = ( + "\n".join(console_messages[-10:]) + if console_messages + else "Waiting for results..." 
+ ) + console_table = Table(show_header=False, box=None, padding=0) + console_table.add_column("Output", style="white") + console_table.add_row(console_text) + + # Create current config table + config_table = Table( + title="Current Configuration", + show_header=True, + header_style="bold magenta", + ) + config_table.add_column("Parameter", style="cyan", width=15) + config_table.add_column("Value", style="green", width=10) + + # Add matrix dimensions and batch size first + config_table.add_row("[bold yellow]Matrix M[/bold yellow]", str(M), style="yellow") + config_table.add_row("[bold yellow]Matrix N[/bold yellow]", str(N), style="yellow") + config_table.add_row("[bold yellow]Matrix K[/bold yellow]", str(K), style="yellow") + config_table.add_row( + "[bold yellow]Batch Size[/bold yellow]", str(M), style="yellow" + ) + config_table.add_row("", "") # Separator + config_table.add_row("BLOCK_SIZE_M", str(current_config.get("BLOCK_SIZE_M", "N/A"))) + config_table.add_row("BLOCK_SIZE_N", str(current_config.get("BLOCK_SIZE_N", "N/A"))) + config_table.add_row("BLOCK_SIZE_K", str(current_config.get("BLOCK_SIZE_K", "N/A"))) + config_table.add_row("num_warps", str(current_config.get("num_warps", "N/A"))) + config_table.add_row("num_stages", str(current_config.get("num_stages", "N/A"))) + config_table.add_row("NUM_KSPLIT", str(current_config.get("NUM_KSPLIT", "N/A"))) + config_table.add_row("waves_per_eu", str(current_config.get("waves_per_eu", "N/A"))) + config_table.add_row("kpack", str(current_config.get("kpack", "N/A"))) + config_table.add_row( + "cache_modifier", str(current_config.get("cache_modifier", "N/A")) + ) + config_table.add_row("GROUP_SIZE_M", str(current_config.get("GROUP_SIZE_M", "N/A"))) + config_table.add_row("GROUP_K", str(current_config.get("GROUP_K", "N/A"))) + + # Create best config table if we have a best configuration + best_config_table = None + if best_time != float("inf"): + best_config_table = Table( + title="๐Ÿ† Best Configuration So Far", + show_header=True, + header_style="bold green", + ) + best_config_table.add_column("Parameter", style="cyan", width=15) + best_config_table.add_column("Value", style="green", width=10) + + # Add performance and matrix dimensions + best_config_table.add_row( + "[bold green]Performance[/bold green]", f"{best_time:.1f}ฮผs", style="green" + ) + best_config_table.add_row("", "") # Separator + best_config_table.add_row( + "[bold yellow]BLOCK_SIZE_M[/bold yellow]", + str(best_config.get("BLOCK_SIZE_M", "N/A")), + style="yellow", + ) + best_config_table.add_row( + "[bold yellow]BLOCK_SIZE_N[/bold yellow]", + str(best_config.get("BLOCK_SIZE_N", "N/A")), + style="yellow", + ) + best_config_table.add_row( + "[bold yellow]BLOCK_SIZE_K[/bold yellow]", + str(best_config.get("BLOCK_SIZE_K", "N/A")), + style="yellow", + ) + best_config_table.add_row("num_warps", str(best_config.get("num_warps", "N/A"))) + best_config_table.add_row( + "num_stages", str(best_config.get("num_stages", "N/A")) + ) + best_config_table.add_row( + "NUM_KSPLIT", str(best_config.get("NUM_KSPLIT", "N/A")) + ) + best_config_table.add_row( + "waves_per_eu", str(best_config.get("waves_per_eu", "N/A")) + ) + best_config_table.add_row("kpack", str(best_config.get("kpack", "N/A"))) + best_config_table.add_row( + "cache_modifier", str(best_config.get("cache_modifier", "N/A")) + ) + best_config_table.add_row( + "GROUP_SIZE_M", str(best_config.get("GROUP_SIZE_M", "N/A")) + ) + best_config_table.add_row("GROUP_K", str(best_config.get("GROUP_K", "N/A"))) + + # Create combined layout + if 
best_config_table: + # Display tables side by side + tables = Columns([config_table, best_config_table], equal=True, expand=True) + layout.split_column( + Layout(Panel(progress, title="Progress", border_style="blue"), size=5), + Layout(Panel(status_text, title="Status", border_style="green"), size=3), + Layout(tables), + Layout( + Panel(console_table, title="Console Output", border_style="cyan"), + size=10, + ), + ) + else: + # Display only current config + layout.split_column( + Layout(Panel(progress, title="Progress", border_style="blue"), size=5), + Layout(Panel(status_text, title="Status", border_style="green"), size=3), + Layout(config_table), + Layout( + Panel(console_table, title="Console Output", border_style="cyan"), + size=10, + ), + ) + + return layout + + +def benchmark_config( + x: torch.Tensor, + w: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + dtype: torch.dtype, + config: Dict[str, Union[str, int]], + y: Optional[torch.Tensor] = None, + num_iters=10, +) -> float: + """ + Benchmark the performance of a GEMM operation with a specific configuration. + + This function measures the execution time of the gemm_a8w8_blockscale kernel by running + it multiple times with synchronization points to ensure accurate timing. It performs + warmup runs before the actual benchmarking to account for JIT compilation overhead. + + Args: + x (torch.Tensor): Input tensor of shape (M, K) representing the first matrix operand. + w (torch.Tensor): Weight tensor of shape (N, K) representing the second matrix operand. + x_scale (torch.Tensor): Scale tensor for x with shape (M, scale_k). + w_scale (torch.Tensor): Scale tensor for w with shape (scale_n, scale_k). + dtype (torch.dtype): Data type for the computation (e.g., torch.bfloat16). + config (Dict[str, Union[str, int]]): Configuration dictionary containing kernel + parameters such as block sizes, number of warps, etc. + y (Optional[torch.Tensor], optional): Output tensor to store the result. If None, + a new tensor will be allocated. Defaults to None. + num_iters (int, optional): Number of benchmark iterations to run. Defaults to 10. + + Returns: + float: Average execution time in microseconds (us) per iteration. + + Note: + The function performs 5 warmup iterations before benchmarking to account for + JIT compilation and GPU warmup effects. The timing is measured using CUDA events + for accurate GPU kernel timing. 
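+
+    The split-K bookkeeping below follows the kernel function: SPLITK_BLOCK_SIZE is set
+    to cdiv(K, NUM_KSPLIT), and BLOCK_SIZE_K is only reduced when it exceeds that value.
+    For example, K=1024 with NUM_KSPLIT=4 gives SPLITK_BLOCK_SIZE=256, so the fixed
+    BLOCK_SIZE_K of 128 is left unchanged.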
+ """ + # Get reference output for correctness verification + torch_out = run_torch_reference( + x, + w, + x_scale, + w_scale, + (128, 128), + dtype, # follow test using (128,128) + ) + + # Add SPLITK_BLOCK_SIZE computation as done in the kernel function + _, K = x.shape + _, K = w.shape + num_ksplit = int(config["NUM_KSPLIT"]) + block_k = int(config["BLOCK_SIZE_K"]) + splitk_block_size = triton.cdiv(K, num_ksplit) + + config["SPLITK_BLOCK_SIZE"] = splitk_block_size + if block_k > splitk_block_size: + block_k = triton.next_power_of_2(splitk_block_size) + if block_k > splitk_block_size: + block_k = block_k // 4 + block_k = max(block_k, 16) + config["BLOCK_SIZE_K"] = block_k + + # Run kernel + def run(): + return gemm_a8w8_blockscale( + x, w, x_scale, w_scale, dtype, y, config, skip_reduce=False + ) + + torch.cuda.synchronize() + + # JIT compilation & warmup with timeout for entire warmup phase + def run_warmup(): + for i in range(5): + run() + + try: + run_with_timeout(run_warmup, timeout_seconds=3) + except TimeoutError: + # If warmup times out, this config is likely bad, skip it + raise TimeoutError("Warmup phase timed out after 3 seconds") + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + try: + triton_out_raw = run_with_timeout(run, timeout_seconds=3) + except TimeoutError: + # If benchmark iteration times out, skip this config + raise TimeoutError(f"Benchmark iteration {i + 1} timed out after 3 seconds") + + # Convert to the same dtype as the reference for comparison + # Handle the case where triton_out_raw might be None + if triton_out_raw is not None: + triton_out = triton_out_raw.to(torch_out.dtype) + else: + triton_out = torch_out # Fallback to reference output + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + + return avg + + +def tune( + M: int, + N: int, + K: int, + search_space: List[Dict[str, int | str]], + input_type: str, + logger: logging.Logger, +): + """Tune the kernel for specific matrix dimensions.""" + # Register SIGINT handler if not already registered + if not signal.getsignal(signal.SIGINT) == sigint_handler: + signal.signal(signal.SIGINT, sigint_handler) + + if input_type == "bfloat16": + # Use the same input generation as test file + # IMPORTANT: Use fixed quantization block sizes (128, 128), not kernel BLOCK_SIZE_N/BLOCK_SIZE_K + # These are completely different concepts - quantization blocks vs kernel tiling + quant_block_size_n, quant_block_size_k = 128, 128 + x, w, x_scale, w_scale, _ = generate_gemm_a8w8_blockscale_inputs( + M, N, K, quant_block_size_n, quant_block_size_k, torch.bfloat16 + ) + else: + raise RuntimeError( + "Currently, only support tune a8w8 blockscale kernel with bfloat16 output." 
+ ) + + best_config: Dict[str, Union[int, str]] = {} + best_time = float("inf") + slow_config_threshold = ( + 1000 # microseconds - configs slower than this get highlighted + ) + + # Clear console output for fresh start + global console_output + console_output = [] + + # Initialize Rich console for better formatting + console = Console() + + # Create initial live display + initial_layout = create_live_display( + M, + N, + K, + search_space[0] if search_space else {}, + best_config, + best_time, + 0, + len(search_space), + console_output, + ) + + # Create progress display with Rich and Live + with Live(initial_layout, refresh_per_second=4, console=console) as live: + for i, config in enumerate(search_space): + # Check if we were interrupted + if INTERRUPTED: + break + + # Update global state for SIGINT handling + CURRENT_CONFIG.update( + { + "M": M, + "N": N, + "K": K, + "config": config, + "config_index": i, + "total_configs": len(search_space), + "batch_size": M, + } + ) + + # Update live display with current configuration + layout = create_live_display( + M, + N, + K, + config, + best_config, + best_time, + i + 1, + len(search_space), + console_output, + ) + live.update(layout) + + try: + kernel_time = benchmark_config( + x=x, + w=w, + x_scale=x_scale, + w_scale=w_scale, + dtype=torch.bfloat16, + y=None, + config=config, + num_iters=10, + ) + + # Add kernel time to console output + if kernel_time > slow_config_threshold: + console_msg = f"[yellow]โš ๏ธ Config {i + 1} (SLOW): {kernel_time:.1f}ฮผs | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} GSM:{config.get('GROUP_SIZE_M')}[/yellow]" + live.console.print(console_msg) + console_output.append(console_msg) + else: + console_msg = f"[green]โœ… Config {i + 1}: {kernel_time:.1f}ฮผs | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} GSM:{config.get('GROUP_SIZE_M')}[/green]" + console_output.append(console_msg) + + # Update best time and config + if kernel_time < best_time: + best_time = kernel_time + best_config = config + best_msg = ( + f"[bold green]๐Ÿ† NEW BEST: {kernel_time:.1f}ฮผs![/bold green]" + ) + console_output.append(best_msg) + + # Update live display with current configuration and console output + layout = create_live_display( + M, + N, + K, + config, + best_config, + best_time, + i + 1, + len(search_space), + console_output, + ) + live.update(layout) + + except triton.runtime.autotuner.OutOfResources as e: + # Log and skip out of resources configurations + log_bad_config(logger, "out_of_resources", M, N, K, config, str(e)) + error_msg = f"[red]โŒ Config {i + 1} (OOM): Out of resources | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} GSM:{config.get('GROUP_SIZE_M')}[/red]" + console_output.append(error_msg) + live.console.print(error_msg) + continue + except AssertionError as e: + # Log and skip assert error configurations + log_bad_config(logger, "assert_error", M, N, K, config, str(e)) + error_msg = f"[red]โŒ Config {i + 1} (ASSERT): Assert error | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} 
KS:{config.get('NUM_KSPLIT')} GSM:{config.get('GROUP_SIZE_M')} | {e}[/red]" + console_output.append(error_msg) + live.console.print(error_msg) + continue + except TimeoutError as e: + # Log and skip timeout configurations + log_bad_config(logger, "timeout", M, N, K, config, str(e)) + error_msg = f"[orange1]โฑ๏ธ Config {i + 1} (TIMEOUT): {e} | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} GSM:{config.get('GROUP_SIZE_M')}[/orange1]" + console_output.append(error_msg) + live.console.print(error_msg) + continue + except Exception as e: + # Log and skip other error configurations + log_bad_config(logger, "other_error", M, N, K, config, str(e)) + error_msg = f"[red]๐Ÿ’ฅ Config {i + 1} (ERROR): {e} | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} GSM:{config.get('GROUP_SIZE_M')}[/red]" + console_output.append(error_msg) + live.console.print(error_msg) + continue + + # Show final completion message with Rich + print("\n" + "=" * 70) + + # Create best config table with matrix dimensions + best_table = Table( + title="๐Ÿ† Best Configuration Found", show_header=True, header_style="bold green" + ) + best_table.add_column("Parameter", style="cyan", width=15) + best_table.add_column("Value", style="green", width=10) + + # Add matrix dimensions and batch size first + best_table.add_row( + "[bold yellow]Matrix M (Batch)[/bold yellow]", str(M), style="yellow" + ) + best_table.add_row("[bold yellow]Matrix N[/bold yellow]", str(N), style="yellow") + best_table.add_row("[bold yellow]Matrix K[/bold yellow]", str(K), style="yellow") + best_table.add_row("", "") # Separator + best_table.add_row( + "[bold green]Performance[/bold green]", f"{best_time:.1f}ฮผs", style="green" + ) + best_table.add_row("", "") # Separator + best_table.add_row( + "[bold yellow]BLOCK_SIZE_M[/bold yellow]", + str(best_config.get("BLOCK_SIZE_M", "N/A")), + style="yellow", + ) + best_table.add_row( + "[bold yellow]BLOCK_SIZE_N[/bold yellow]", + str(best_config.get("BLOCK_SIZE_N", "N/A")), + style="yellow", + ) + best_table.add_row( + "[bold yellow]BLOCK_SIZE_K[/bold yellow]", + str(best_config.get("BLOCK_SIZE_K", "N/A")), + style="yellow", + ) + best_table.add_row("num_warps", str(best_config.get("num_warps", "N/A"))) + best_table.add_row("num_stages", str(best_config.get("num_stages", "N/A"))) + best_table.add_row("NUM_KSPLIT", str(best_config.get("NUM_KSPLIT", "N/A"))) + best_table.add_row("waves_per_eu", str(best_config.get("waves_per_eu", "N/A"))) + best_table.add_row("kpack", str(best_config.get("kpack", "N/A"))) + best_table.add_row("cache_modifier", str(best_config.get("cache_modifier", "N/A"))) + best_table.add_row("GROUP_SIZE_M", str(best_config.get("GROUP_SIZE_M", "N/A"))) + best_table.add_row("GROUP_K", str(best_config.get("GROUP_K", "N/A"))) + best_table.add_row( + "matrix_instr_nonkdim", str(best_config.get("matrix_instr_nonkdim", "N/A")) + ) + + completion_panel = Panel( + best_table, + title=f"[bold green]โœ… Completed Tuning for M={M} N={N} K={K} (Batch Size={M})[/bold green]", + border_style="green", + ) + console.print(completion_panel) + print("=" * 70) + + assert best_config is not None + return best_config + + +def save_configs( + N, + K, + configs, + save_path, + is_incremental=False, + completed_batch_sizes=None, +) -> None: + """Save the best 
configurations to a JSON file.""" + os.makedirs(save_path, exist_ok=True) + device_name = "R9700" # TODO: Hardcoded, make it dynamic + + if is_incremental: + # Save incremental progress with batch size info in filename + batch_sizes_str = ( + "_".join(map(str, completed_batch_sizes)) + if completed_batch_sizes + else "partial" + ) + json_file_name = f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}_batch_{batch_sizes_str}.json" + progress_file = os.path.join( + save_path, + f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}_progress.json", + ) + + # Save progress info + progress_info = { + "completed_batch_sizes": completed_batch_sizes or [], + "configs": configs, + "last_updated": datetime.now().isoformat(), + } + + with open(progress_file, "w") as f: + json.dump(progress_info, f, indent=4) + f.write("\n") + else: + json_file_name = f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}.json" + + config_file_path = os.path.join(save_path, json_file_name) + + # Add incremental flag to filename + action = "Updating incremental" if is_incremental else "Writing" + print(f"{action} config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def load_progress(N, K, save_path): + """Load previously saved progress for a given N,K configuration.""" + device_name = "R9700" # TODO: Hardcoded, make it dynamic + progress_file = os.path.join( + save_path, f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}_progress.json" + ) + + if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + progress_info = json.load(f) + return progress_info.get("completed_batch_sizes", []), progress_info.get( + "configs", {} + ) + except Exception as e: + print(f"Warning: Could not load progress file {progress_file}: {e}") + return [], {} + return [], {} + + +def tune_on_gpu( + gpu_id: int, + batch_sizes: List[int], + weight_shapes: List[Tuple[int, int]], + input_type: str, + resume: bool = True, + log_filename: Optional[str] = None, +) -> None: + """Run tuning on a specific GPU.""" + # Register SIGINT handler and set GPU ID in global state + signal.signal(signal.SIGINT, sigint_handler) + CURRENT_CONFIG["gpu_id"] = gpu_id + + torch.cuda.set_device(gpu_id) + print(f"๐Ÿš€ Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + + save_path = AITER_TRITON_CONFIGS_PATH + "/gemm/" + + # Setup logger for this GPU with custom or timestamped filename + if log_filename: + # Use custom filename, ensure it has .log extension + if not log_filename.endswith(".log"): + log_filename += ".log" + # If no path separator, assume it's just a filename + if "/" not in log_filename and "\\" not in log_filename: + log_filename = os.path.join(save_path, log_filename) + else: + log_filename = log_filename # Use full path as provided + else: + # Fall back to timestamped filename + log_filename = os.path.join( + save_path, + get_timestamped_filename(f"tune_a8w8_blockscale_bad_configs_gpu{gpu_id}"), + ) + + # Choose appropriate logging mode: append for resume, overwrite for fresh start + log_mode = "a" if resume else "w" + logger = setup_logger(log_filename, mode=log_mode) + + # Log the start time in GMT+8 + gmt8 = pytz.timezone("Asia/Shanghai") + start_time_gmt8 = datetime.now(gmt8).strftime("%Y-%m-%d %H:%M:%S") + logger.info(f"=== TUNING SESSION STARTED AT {start_time_gmt8} [GMT+8] ===") + logger.info(f"GPU: {gpu_id}") + logger.info(f"Batch sizes: {batch_sizes}") + + search_space = get_configs_compute_bound() + total_configs = len(search_space) + total_tests 
= total_configs * len(batch_sizes) * len(weight_shapes) + + print(f" ๐Ÿ“Š Search space: {total_configs:,} configurations") + print(f" ๐ŸŽฏ Total tests to run: {total_tests:,}") + print( + f" โšก Estimated tests per weight shape: {total_configs * len(batch_sizes):,}" + ) + log_action = ( + "Appending to existing" + if resume and os.path.exists(log_filename) + else "Writing to new" + ) + print(f" ๐Ÿ“ Bad configurations will be logged to: {log_filename}") + print(f" ๐Ÿ“ Logging mode: {log_action}") + + start = time.time() + + # Collect all configs to determine the best overall config + all_configs: List[Dict[str, Dict[str, int | str]]] = [] + + for i, shape in enumerate(weight_shapes): + # Check if we were interrupted + if INTERRUPTED: + break + + # Update weight shape tracking + CURRENT_CONFIG.update( + {"weight_shape_index": i, "total_weight_shapes": len(weight_shapes)} + ) + + N, K = shape[0], shape[1] + print( + f"\n๐Ÿš€ [GPU {gpu_id}] Shape {i + 1}/{len(weight_shapes)}: Starting tuning for N:{N}, K:{K}" + ) + print( + f" ๐Ÿ“Š Testing {len(search_space):,} configurations across {len(batch_sizes)} batch sizes" + ) + + # Check for existing progress and resume from there (if resume is enabled) + if resume: + completed_batch_sizes, existing_configs = load_progress(N, K, save_path) + else: + completed_batch_sizes, existing_configs = [], {} + + # Filter batch_sizes to only those not yet completed + remaining_batch_sizes = [ + bs for bs in batch_sizes if bs not in completed_batch_sizes + ] + + if completed_batch_sizes and resume: + print(f"\n ๐Ÿ“‚ [GPU {gpu_id}] Found progress for N={N}, K={K}") + print(f" โœ… Already completed batch sizes: {completed_batch_sizes}") + print(f" ๐Ÿ”„ Remaining batch sizes to tune: {remaining_batch_sizes}") + elif not resume: + print( + f"\n ๐Ÿ”„ [GPU {gpu_id}] Starting fresh (resume disabled) for N={N}, K={K}" + ) + elif not remaining_batch_sizes: + print( + f"\n โœ… [GPU {gpu_id}] All batch sizes already completed for N={N}, K={K}" + ) + # Add existing configs to all_configs and continue to next shape + if existing_configs: + all_configs.append(existing_configs) + save_configs(N, K, existing_configs, save_path) + continue + + # Initialize benchmark_results with existing results if any + benchmark_results: List[Dict[str, str | int]] = [] + if existing_configs: + # Reconstruct benchmark_results from existing configs + # We need to map the configs back to their corresponding batch sizes + for i, batch_size in enumerate(batch_sizes): + if batch_size in completed_batch_sizes: + # Find the config for this batch size + config_to_add = None + + # Try to find matching config based on batch size category + if batch_size < 32 and "small" in existing_configs: + config_to_add = existing_configs["small"] + elif batch_size <= 128: + BLK_M = triton.next_power_of_2(batch_size) + if BLK_M == 32 and "medium_M32" in existing_configs: + config_to_add = existing_configs["medium_M32"] + elif BLK_M == 64 and "medium_M64" in existing_configs: + config_to_add = existing_configs["medium_M64"] + elif BLK_M == 128 and "medium_M128" in existing_configs: + config_to_add = existing_configs["medium_M128"] + elif batch_size <= 256 and "large" in existing_configs: + config_to_add = existing_configs["large"] + elif batch_size > 256 and "xlarge" in existing_configs: + config_to_add = existing_configs["xlarge"] + + if config_to_add: + benchmark_results.append(config_to_add) + else: + # If we couldn't find a matching config, we'll need to retune this batch size + 
remaining_batch_sizes.append(batch_size) + + for batch_size in remaining_batch_sizes: + # Check if we were interrupted + if INTERRUPTED: + break + + print( + f"\n ๐Ÿ” [GPU {gpu_id}] Testing batch size M={batch_size} for N={N}, K={K}" + ) + result = tune( + batch_size, + N, + K, + search_space, + input_type, + logger, + ) + + # Check if tune() was interrupted + if INTERRUPTED: + break + + benchmark_results.append(result) + + # Save incremental progress immediately after each batch size + updated_completed_batch_sizes = completed_batch_sizes + [batch_size] + + # Create configs for different M size categories as expected by the kernel + incremental_configs: Dict[str, Dict[str, int | str]] = {} + for i, (M, config) in enumerate( + zip(batch_sizes[: len(benchmark_results)], benchmark_results) + ): + if i == len(batch_sizes[: len(benchmark_results)]) - 1: + incremental_configs["any"] = config + elif M < 32: + incremental_configs["small"] = config + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32: + incremental_configs["medium_M32"] = config + elif BLK_M == 64: + incremental_configs["medium_M64"] = config + elif BLK_M == 128: + incremental_configs["medium_M128"] = config + elif M <= 256: + incremental_configs["large"] = config + else: + incremental_configs["xlarge"] = config + + # Save the incremental progress + save_configs( + N, + K, + incremental_configs, + save_path, + is_incremental=True, + completed_batch_sizes=updated_completed_batch_sizes, + ) + + print(f" ๐Ÿ’พ [GPU {gpu_id}] Saved progress for batch size {batch_size}") + + # Update completed_batch_sizes for next iteration + completed_batch_sizes = updated_completed_batch_sizes + + # Create final configs for different M size categories as expected by the kernel + best_configs: Dict[str, Dict[str, int | str]] = {} + for i, (M, config) in enumerate(zip(batch_sizes, benchmark_results)): + if i == len(batch_sizes) - 1: + best_configs["any"] = config + elif M < 32: + best_configs["small"] = config + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32: + best_configs["medium_M32"] = config + elif BLK_M == 64: + best_configs["medium_M64"] = config + elif BLK_M == 128: + best_configs["medium_M128"] = config + elif M <= 256: + best_configs["large"] = config + else: + best_configs["xlarge"] = config + + # Store configs for later analysis + all_configs.append(best_configs) + + # Save the final complete config (non-incremental) + save_configs(N, K, best_configs, save_path) + + # Clean up progress file since we completed successfully + device_name = "R9700" # TODO: Hardcoded, make it dynamic + progress_file = os.path.join( + save_path, + f"{device_name}-GEMM-A8W8_BLOCKSCALE-N={N}-K={K}_progress.json", + ) + if os.path.exists(progress_file): + os.remove(progress_file) + print(f" ๐Ÿงน [GPU {gpu_id}] Cleaned up progress file for N={N}, K={K}") + + # Create a default config file (without N,K parameters) by selecting the most common config + default_config = create_default_config(all_configs) + save_default_config(default_config, save_path) + + end = time.time() + + # Log session end time in GMT+8 + gmt8 = pytz.timezone("Asia/Shanghai") + end_time_gmt8 = datetime.now(gmt8).strftime("%Y-%m-%d %H:%M:%S") + duration = end - start + logger.info(f"=== TUNING SESSION COMPLETED AT {end_time_gmt8} [GMT+8] ===") + logger.info(f"Total duration: {duration:.2f} seconds") + + # Log summary of bad configurations + log_bad_config_summary(logger, total_tests) + + print(f"Tuning on GPU {gpu_id} took {duration:.2f} seconds") + + +def 
create_default_config( + all_configs: List[Dict[str, Dict[str, Union[int, str]]]], +) -> Dict[str, Dict[str, Union[int, str]]]: + """Create a default config by selecting the most common config across all shapes.""" + from collections import Counter + + # Collect all configs for each category + category_configs = { + "small": [], + "medium_M32": [], + "medium_M64": [], + "medium_M128": [], + "large": [], + "xlarge": [], + "any": [], + } + + for config in all_configs: + for category, params in config.items(): + if category in category_configs: + # Convert config to a hashable tuple for counting + config_tuple = tuple(sorted(params.items())) + category_configs[category].append(config_tuple) + + # Find the most common config for each category + default_config: Dict[str, Dict[str, Union[int, str]]] = {} + for category, configs in category_configs.items(): + if configs: + most_common = Counter(configs).most_common(1)[0][0] + default_config[category] = dict(most_common) + + return default_config + + +def save_default_config( + config: Dict[str, Dict[str, Union[int, str]]], save_path: str +) -> None: + """Save the default config file (without N,K parameters).""" + os.makedirs(save_path, exist_ok=True) + device_name = "R9700" # TODO: Hardcoded, make it dynamic + json_file_name = f"{device_name}-GEMM-A8W8_BLOCKSCALE.json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing default config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(config, f, indent=4) + f.write("\n") + + +def distribute_batch_sizes(batch_sizes: List[int], num_gpus: int) -> List[List[int]]: + """Distribute batch sizes across available GPUs.""" + batches_per_gpu: List[List[int]] = [] + for i in range(num_gpus): + start_idx = i * len(batch_sizes) // num_gpus + end_idx = (i + 1) * len(batch_sizes) // num_gpus + batches_per_gpu.append(batch_sizes[start_idx:end_idx]) + return batches_per_gpu + + +def main(args): + print(args) + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + raise RuntimeError("No GPU available for tuning") + print(f"Found {num_gpus} GPUs for parallel tuning") + + torch.cuda.init() + + if args.batch_size is None: + batch_sizes = [ + 16, # For small config + 32, # For medium_M32 config + 64, # For medium_M64 config + 128, # For medium_M128 config + 256, # For large config + 512, # For large config + 2048, # For xlarge config + 4096, # For xlarge config + ] + else: + batch_sizes = [args.batch_size] + num_gpus = 1 # If only one batch size, use only one GPU + + weight_shapes = get_weight_shapes(args.tp_size) + + batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus) + + # Prepare arguments for each GPU process + process_args = [] + for gpu_id in range(num_gpus): + process_args.append( + ( + gpu_id, + batches_per_gpu[gpu_id], + weight_shapes, # Each GPU processes all weight shapes + args.input_type, + args.resume, + args.log_filename, + ) + ) + + # Set up signal handler for main process to gracefully terminate workers + def main_sigint_handler(signum, frame): # type: ignore + print("\n" + "=" * 80) + print("๐Ÿ›‘ MAIN PROCESS INTERRUPTED BY USER (Ctrl+C)") + print("๐Ÿ“ก Sending termination signal to worker processes...") + print("โณ Giving workers 3 seconds to log their current state...") + print("=" * 80) + # Set a flag for workers to check and give them time to cleanup + global INTERRUPTED + INTERRUPTED = True + import time + + time.sleep(3) # Give workers time to handle the signal and log + sys.exit(1) + + # Register main process signal 
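`create_default_config` relies on a small `collections.Counter` idiom: each winning config dict is flattened to a sorted tuple so it becomes hashable, and the most frequent tuple per category is converted back to a dict. A self-contained illustration with made-up parameter values:

```python
from collections import Counter

# Hypothetical "small"-category winners from three tuned (N, K) shapes.
per_shape_small = [
    {"BLOCK_SIZE_M": 32, "num_warps": 4},
    {"BLOCK_SIZE_M": 32, "num_warps": 4},
    {"BLOCK_SIZE_M": 64, "num_warps": 8},
]

# Dicts are not hashable, so flatten each one into a sorted tuple of items.
counts = Counter(tuple(sorted(cfg.items())) for cfg in per_shape_small)
most_common_tuple, _ = counts.most_common(1)[0]
default_small = dict(most_common_tuple)
print(default_small)  # {'BLOCK_SIZE_M': 32, 'num_warps': 4}
```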
handler + if not signal.getsignal(signal.SIGINT) == main_sigint_handler: + signal.signal(signal.SIGINT, main_sigint_handler) + + ctx = mp.get_context("spawn") + try: + with ctx.Pool(num_gpus) as pool: + pool.starmap(tune_on_gpu, process_args) + except KeyboardInterrupt: + print("\n๐Ÿ›‘ Keyboard interrupt received in main process") + print("๐Ÿ“ก Worker processes terminated") + sys.exit(1) + except Exception as e: + print(f"\nโŒ Error in main process: {e}") + sys.exit(1) + + print("Multi-GPU tuning completed") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument("--tp-size", "-tp", type=int, default=1) + parser.add_argument( + "--input-type", type=str, choices=["bfloat16"], default="bfloat16" + ) + parser.add_argument( + "--out-dtype", + type=str, + choices=["float32", "float16", "bfloat16", "half"], + default="bfloat16", + ) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument( + "--log-filename", + type=str, + default=None, + help="Custom log filename (without .log extension). If not provided, timestamped filename will be used.", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Disable resume functionality and start fresh tuning", + ) + args = parser.parse_args() + + # Convert no_resume flag to resume boolean + args.resume = not args.no_resume + + main(args) diff --git a/aiter/ops/triton/tune_a8w8_per_token_scale.py b/aiter/ops/triton/tune_a8w8_per_token_scale.py new file mode 100644 index 0000000000..c232a0ff94 --- /dev/null +++ b/aiter/ops/triton/tune_a8w8_per_token_scale.py @@ -0,0 +1,1520 @@ +import argparse +import json +import logging +import multiprocessing as mp +import os +import signal +import sys +import time +import triton +from datetime import datetime +from typing import List, Dict, Union, Tuple, Optional, Any +import torch +import pytz +from rich.console import Console +from rich.table import Table +from rich.panel import Panel +from rich.columns import Columns +from rich.live import Live +from rich.layout import Layout +from rich.progress import Progress, BarColumn, TextColumn + + +from aiter.ops.triton.gemm_a8w8_per_token_scale import gemm_a8w8_per_token_scale # type: ignore +from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH # type: ignore +from aiter.ops.triton.utils.types import get_fp8_dtypes + +mp.set_start_method("spawn", force=True) + +# Get FP8 data types +e5m2_type, e4m3_type = get_fp8_dtypes() + + +class TimeoutError(Exception): + """Custom exception for timeout errors.""" + + pass + + +# Global variables to track bad configurations and current state +BAD_CONFIGS = { + "timeouts": [], + "out_of_resources": [], + "assert_errors": [], + "other_errors": [], +} + +# Global variables to track current state for SIGINT handling +CURRENT_CONFIG: Dict[str, Any] = { + "M": None, + "N": None, + "K": None, + "config": None, + "config_index": None, + "total_configs": None, + "batch_size": None, + "weight_shape_index": None, + "total_weight_shapes": None, + "gpu_id": None, +} + +INTERRUPTED = False + + +class GMT8Formatter(logging.Formatter): + """Custom formatter that uses GMT+8 timezone.""" + + def __init__(self, fmt=None, datefmt=None): + super().__init__(fmt, datefmt) + self.gmt8 = pytz.timezone("Asia/Shanghai") # GMT+8 + + def formatTime(self, record, datefmt=None): + # Convert timestamp to GMT+8 + dt = datetime.fromtimestamp(record.created, tz=self.gmt8) + if datefmt: + return dt.strftime(datefmt) + else: + 
return dt.strftime("%Y-%m-%d %H:%M:%S") + + def format(self, record): + # Add timezone info to the formatted message + original = super().format(record) + return original.replace("[GMT+8]", "") # Remove any existing timezone tag + + +def get_timestamped_filename(base_name: str, extension: str = ".log") -> str: + """Generate a filename with timestamp in GMT+8 timezone.""" + gmt8 = pytz.timezone("Asia/Shanghai") + timestamp = datetime.now(gmt8).strftime("%Y%m%d_%H%M%S") + return f"{base_name}_{timestamp}{extension}" + + +def sigint_handler(signum, frame): + """Handle SIGINT (Ctrl+C) gracefully by logging the current configuration.""" + global INTERRUPTED + global CURRENT_CONFIG + INTERRUPTED = True + + print("\n" + "=" * 80) + print("๐Ÿ›‘ TUNING INTERRUPTED BY USER (Ctrl+C)") + print("=" * 80) + + if CURRENT_CONFIG["M"] is not None: + print("๐Ÿ“ Last configuration being processed:") + print(f" ๐ŸŽฏ GPU: {CURRENT_CONFIG['gpu_id']}") + print( + f" ๐Ÿ“Š Matrix: M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}" + ) + print(f" ๐Ÿ“ฆ Batch Size: {CURRENT_CONFIG['batch_size']}") + print( + f" ๐Ÿ”„ Progress: Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}" + ) + print( + f" ๐Ÿ—๏ธ Weight Shape: {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}" + ) + + if CURRENT_CONFIG["config"]: + config = CURRENT_CONFIG["config"] + print(" โš™๏ธ Parameters:") + print(f" BLOCK_SIZE_M: {config.get('BLOCK_SIZE_M', 'N/A')}") + print(f" BLOCK_SIZE_N: {config.get('BLOCK_SIZE_N', 'N/A')}") + print(f" BLOCK_SIZE_K: {config.get('BLOCK_SIZE_K', 'N/A')}") + print(f" num_warps: {config.get('num_warps', 'N/A')}") + print(f" num_stages: {config.get('num_stages', 'N/A')}") + print(f" NUM_KSPLIT: {config.get('NUM_KSPLIT', 'N/A')}") + print(f" waves_per_eu: {config.get('waves_per_eu', 'N/A')}") + print(f" kpack: {config.get('kpack', 'N/A')}") + print(f" cache_modifier: {config.get('cache_modifier', 'N/A')}") + print(f" GROUP_SIZE_M: {config.get('GROUP_SIZE_M', 'N/A')}") + + # Show config in same format as console output for consistency + config_num = CURRENT_CONFIG["config_index"] + 1 + console_format = f" ๐Ÿ’ป Config {config_num} (INTERRUPTED): | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} kpack:{config.get('kpack')} cache:{config.get('cache_modifier')}" + print(console_format) + + # Log the interruption to the file if logger is available + try: + logger = logging.getLogger("gemm_a8w8_per_token_scale_tuning") + if logger.handlers: + # Use GMT+8 timestamp for consistency + gmt8 = pytz.timezone("Asia/Shanghai") + + # Create detailed log entry + detailed_log_entry = { + "timestamp": datetime.now(gmt8).isoformat(), + "event_type": "user_interrupt", + "gpu_id": CURRENT_CONFIG.get("gpu_id", "N/A"), + "batch_size": CURRENT_CONFIG["batch_size"], + "matrix_dims": f"M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}", + "config": CURRENT_CONFIG["config"], + "progress": f"Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}", + "weight_shape_progress": f"Shape {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}", + } + + # Log detailed interruption info + logger.info(f"=== USER INTERRUPT ===") + logger.info( + f"Interrupted while testing: Config {CURRENT_CONFIG['config_index'] + 1}/{CURRENT_CONFIG['total_configs']}" + ) + logger.info(f"GPU: 
{CURRENT_CONFIG.get('gpu_id', 'N/A')}") + logger.info( + f"Matrix: M={CURRENT_CONFIG['M']} N={CURRENT_CONFIG['N']} K={CURRENT_CONFIG['K']}" + ) + logger.info( + f"Weight Shape Progress: {CURRENT_CONFIG['weight_shape_index'] + 1}/{CURRENT_CONFIG['total_weight_shapes']}" + ) + + # Log config details in same format as console output for consistency + if CURRENT_CONFIG["config"]: + config = CURRENT_CONFIG["config"] + config_num = CURRENT_CONFIG["config_index"] + 1 + config_str = f"Config {config_num} (INTERRUPTED): | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} kpack:{config.get('kpack')} cache:{config.get('cache_modifier')} GROUP_SIZE_M:{config.get('GROUP_SIZE_M')}" + logger.info(f"CONFIG_DETAILS: {config_str}") + + logger.info(f"DETAILED_ENTRY: {detailed_log_entry}") + logger.info(f"=== END USER INTERRUPT ===") + + # Force flush to write immediately + for handler in logger.handlers: + if hasattr(handler, "stream"): + handler.stream.flush() + + print(" ๐Ÿ“ Interruption logged to tuning log file") + except Exception as e: + print(f" โš ๏ธ Could not log interruption: {e}") + + print("\n๐Ÿ’ก You can use this information to:") + print(" โ€ข Skip this problematic configuration in future runs") + print(" โ€ข Analyze why this specific config might be causing issues") + print(" โ€ข Adjust the search space to avoid similar parameter combinations") + print("=" * 80) + + # Exit gracefully + import sys + + sys.exit(1) + + +def setup_logger(log_file_path: str, mode: str = "a") -> logging.Logger: + """ + Setup logger for recording bad configurations during tuning. + + Args: + log_file_path: Path to the log file + mode: File write mode - 'a' to append to existing logs, 'w' to overwrite + + Returns: + Configured logger instance + """ + logger = logging.getLogger("gemm_a8w8_per_token_scale_tuning") + logger.setLevel(logging.INFO) + + # Clear existing handlers + logger.handlers.clear() + + # Create file handler with live writing (immediate flush) + # Default to append mode to preserve logs across resume sessions + file_handler = logging.FileHandler(log_file_path, mode=mode) + file_handler.setLevel(logging.INFO) + + # Create custom formatter that flushes immediately + file_handler.flush = lambda: file_handler.stream.flush() # type: ignore + + # Create console handler + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.WARNING) + + # Create GMT+8 formatter + formatter = GMT8Formatter( + "%(asctime)s [GMT+8] - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) + file_handler.setFormatter(formatter) + console_handler.setFormatter(formatter) + + # Add handlers to logger + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger + + +def log_bad_config( + logger: logging.Logger, + error_type: str, + M: int, + N: int, + K: int, + config: Dict[str, Union[str, int]], + error_msg: str = "", +): + """ + Log a bad configuration that failed during tuning. 
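As a condensed, standalone view of the GMT+8 logging wiring above (the logger name `demo` and the message are purely illustrative; `pytz` is already a dependency of this script):

```python
import logging
from datetime import datetime

import pytz

GMT8 = pytz.timezone("Asia/Shanghai")

class GMT8DemoFormatter(logging.Formatter):
    # Render record timestamps in GMT+8 instead of local time.
    def formatTime(self, record, datefmt=None):
        dt = datetime.fromtimestamp(record.created, tz=GMT8)
        return dt.strftime(datefmt or "%Y-%m-%d %H:%M:%S")

handler = logging.StreamHandler()
handler.setFormatter(
    GMT8DemoFormatter("%(asctime)s [GMT+8] - %(levelname)s - %(message)s")
)
log = logging.getLogger("demo")
log.addHandler(handler)
log.setLevel(logging.INFO)
log.info("tuning session started")  # e.g. "2024-01-01 20:00:00 [GMT+8] - INFO - ..."
```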
+ + Args: + logger: Logger instance + error_type: Type of error ('timeout', 'out_of_resources', 'assert_error', 'other_error') + M, N, K: Matrix dimensions + config: Configuration that failed + error_msg: Additional error message + """ + # Use GMT+8 timestamp for consistency + gmt8 = pytz.timezone("Asia/Shanghai") + log_entry = { + "timestamp": datetime.now(gmt8).isoformat(), + "error_type": error_type, + "batch_size": M, + "matrix_dims": f"M={M} N={N} K={K}", + "config": config, + "error_msg": str(error_msg), + } + + # Log to file + logger.info(f"BAD_CONFIG_{error_type.upper()}: {log_entry}") + + # Force flush to write immediately + for handler in logger.handlers: + if hasattr(handler, "stream"): + handler.stream.flush() + + # Store in global list for summary + if error_type == "timeout": + BAD_CONFIGS["timeouts"].append(log_entry) + elif error_type == "out_of_resources": + BAD_CONFIGS["out_of_resources"].append(log_entry) + elif error_type == "assert_error": + BAD_CONFIGS["assert_errors"].append(log_entry) + else: + BAD_CONFIGS["other_errors"].append(log_entry) + + +def log_bad_config_summary(logger: logging.Logger, total_configs_tested: int): + """ + Log a summary of all bad configurations encountered during tuning. + + Args: + logger: Logger instance + total_configs_tested: Total number of configurations tested + """ + total_bad = ( + len(BAD_CONFIGS["timeouts"]) + + len(BAD_CONFIGS["out_of_resources"]) + + len(BAD_CONFIGS["assert_errors"]) + + len(BAD_CONFIGS["other_errors"]) + ) + success_rate = ( + ((total_configs_tested - total_bad) / total_configs_tested * 100) + if total_configs_tested > 0 + else 0 + ) + + logger.info("=" * 80) + logger.info("BAD CONFIGURATION SUMMARY") + logger.info("=" * 80) + logger.info(f"Total configurations tested: {total_configs_tested}") + logger.info(f"Successful configurations: {total_configs_tested - total_bad}") + logger.info(f"Failed configurations: {total_bad}") + logger.info(f"Success rate: {success_rate:.2f}%") + logger.info("") + + logger.info(f"Timeouts: {len(BAD_CONFIGS['timeouts'])}") + logger.info(f"Out of Resources: {len(BAD_CONFIGS['out_of_resources'])}") + logger.info(f"Assert Errors: {len(BAD_CONFIGS['assert_errors'])}") + logger.info(f"Other Errors: {len(BAD_CONFIGS['other_errors'])}") + logger.info("") + + if BAD_CONFIGS["timeouts"]: + logger.info("TIMEOUT CONFIGS (most problematic):") + for entry in BAD_CONFIGS["timeouts"]: + config = entry["config"] + logger.info( + f" - Batch {entry['batch_size']} | {entry['matrix_dims']} | BM:{config.get('BLOCK_SIZE_M', 'N/A')}, BN:{config.get('BLOCK_SIZE_N', 'N/A')}, BK:{config.get('BLOCK_SIZE_K', 'N/A')}, W:{config.get('num_warps', 'N/A')}, S:{config.get('num_stages', 'N/A')}, KS:{config.get('NUM_KSPLIT', 'N/A')}" + ) + + if BAD_CONFIGS["out_of_resources"]: + logger.info("OUT OF RESOURCE CONFIGS:") + for entry in BAD_CONFIGS["out_of_resources"]: + config = entry["config"] + logger.info( + f" - Batch {entry['batch_size']} | {entry['matrix_dims']} | BM:{config.get('BLOCK_SIZE_M', 'N/A')}, BN:{config.get('BLOCK_SIZE_N', 'N/A')}, BK:{config.get('BLOCK_SIZE_K', 'N/A')}, W:{config.get('num_warps', 'N/A')}, S:{config.get('num_stages', 'N/A')}, KS:{config.get('NUM_KSPLIT', 'N/A')}" + ) + + logger.info("=" * 80) + + # Print summary to console as well + print("\n๐Ÿ“Š Bad Configuration Summary:") + print(f" Total tested: {total_configs_tested}") + print(f" โœ… Successful: {total_configs_tested - total_bad}") + print( + f" โŒ Failed: {total_bad} ({len(BAD_CONFIGS['timeouts'])} timeouts, 
{len(BAD_CONFIGS['out_of_resources'])} OOM, {len(BAD_CONFIGS['assert_errors'])} assert, {len(BAD_CONFIGS['other_errors'])} other)" + ) + print(f" ๐Ÿ“ˆ Success rate: {success_rate:.1f}%") + + +def timeout_handler(signum, frame): + """Signal handler for timeout.""" + raise TimeoutError("Kernel execution timed out") + + +def run_with_timeout(func, timeout_seconds=3, *args, **kwargs): + """ + Run a function with a timeout limit. + + Args: + func: Function to execute + timeout_seconds: Timeout in seconds (default: 3) + *args: Arguments to pass to the function + **kwargs: Keyword arguments to pass to the function + + Returns: + Result of the function call + + Raises: + TimeoutError: If function execution exceeds timeout + """ + # Set the signal handler + old_handler = signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(timeout_seconds) + + try: + result = func(*args, **kwargs) + return result + finally: + # Cancel the alarm and restore the old handler + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + +def generate_gemm_a8w8_per_token_scale_inputs(M, N, K, dtype, output=True, bias=False): + """ + Generate inputs for gemm_a8w8_per_token_scale kernel. + + Args: + M, N, K: Matrix dimensions + dtype: Output data type + output: Whether to generate output tensor + bias: Whether to generate bias tensor + + Returns: + Tuple of (x, w, x_scale, w_scale, bias, y) + """ + # Generate input matrix x (M, K) - FP8 E4M3 (matching test file pattern) + x = (torch.rand((M, K), dtype=torch.float16, device="cuda") / 10).to(e4m3_type) + + # Generate weight matrix w (N, K) - FP8 E4M3 (matching test file pattern) + w = (torch.rand((N, K), dtype=torch.float16, device="cuda") / 10).to(e4m3_type) + + # Generate per-token and per-channel scale tensors - 2D [M,1] and [N,1] + x_scale = torch.rand([M, 1], dtype=torch.float32, device="cuda") + w_scale = torch.rand([N, 1], dtype=torch.float32, device="cuda") + + # Generate bias tensor if needed + bias_tensor = None + if bias: + bias_tensor = torch.empty((N), dtype=dtype, device="cuda") + + # Generate output tensor if needed + y = None + if output: + y = torch.empty((M, N), dtype=dtype, device="cuda") + + return x, w, x_scale, w_scale, bias_tensor, y + + +def get_configs_compute_bound() -> List[Dict[str, int | str]]: + """ + Generate configuration space for tuning the gemm_a8w8_per_token_scale kernel. + Focus on parameters that affect performance for this specific kernel. + Comprehensive search space matching atomic kernel patterns. 
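`run_with_timeout` is a thin SIGALRM wrapper and therefore only works in the main thread on Unix, which is how the tuning loop uses it. A standalone usage sketch with a deliberately slow stand-in function:

```python
import signal
import time

class TimeoutError(Exception):
    pass

def timeout_handler(signum, frame):
    raise TimeoutError("Kernel execution timed out")

def run_with_timeout(func, timeout_seconds=3, *args, **kwargs):
    old_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout_seconds)
    try:
        return func(*args, **kwargs)
    finally:
        signal.alarm(0)                              # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)   # restore the previous handler

def too_slow():
    time.sleep(10)

try:
    run_with_timeout(too_slow, timeout_seconds=1)
except TimeoutError as e:
    print("skipped config:", e)  # fires after roughly 1 second
```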
+ """ + configs = [] + + # Explore optimized parameter space (removed large block sizes that cause slowdowns) + for num_stages in [1, 2, 3, 4]: + for block_m in [32, 64, 128]: # Removed 256 (causes slowdowns) + for block_n in [32, 64, 128]: # Removed 256 (causes slowdowns) + for block_k in [64, 128, 256]: + for group_size in [1, 8, 16]: + for num_warps in [2, 4, 8]: + for num_ksplit in [ + 1, + 2, + 4, + ]: # Key parameter for K-splitting + for waves_per_eu in [2, 4, 8]: + for kpack in [2]: + for cache_modifier in ["", ".cg"]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + "NUM_KSPLIT": num_ksplit, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + "matrix_instr_nonkdim": 16, # Fixed value from atomic kernel + "cache_modifier": cache_modifier, + } + ) + return configs + # for num_stages in [ + # 1, + # ]: + # for block_m in [ + # 32, + # ]: # Removed 256 (causes slowdowns) + # for block_n in [ + # 32, + # ]: # Removed 256 (causes slowdowns) + # for block_k in [ + # 64, + # ]: + # for group_size in [ + # 1, + # ]: + # for num_warps in [ + # 2, + # ]: + # for num_ksplit in [ + # 1, + # ]: # Key parameter for K-splitting + # for waves_per_eu in [ + # 2, + # ]: + # for kpack in [2]: + # for cache_modifier in ["", ".cg"]: + # configs.append( + # { + # "BLOCK_SIZE_M": block_m, + # "BLOCK_SIZE_N": block_n, + # "BLOCK_SIZE_K": block_k, + # "GROUP_SIZE_M": group_size, + # "num_warps": num_warps, + # "num_stages": num_stages, + # "NUM_KSPLIT": num_ksplit, + # "waves_per_eu": waves_per_eu, + # "kpack": kpack, + # "matrix_instr_nonkdim": 16, # Fixed value from atomic kernel + # "cache_modifier": cache_modifier, + # } + # ) + # return configs + + +def get_weight_shapes(tp_size: int) -> List[Tuple[int, int]]: + """Get weight shapes to test during tuning.""" + total = [ + # (1024, 1024), + # (4096, 1024), + # (1024, 2048), + # (6144, 1024), + (1024, 3072), + ] + + weight_shapes: List[Tuple[int, int]] = [] + for t in total: + weight_shapes.append(t) + + return weight_shapes + + +def run_torch_reference(x, w, x_scale, w_scale, bias, dtype=torch.bfloat16): + """ + Run reference implementation using PyTorch. + This is used for correctness verification. + """ + # Apply scaling as in test file: convert to scale dtype, multiply, then compute + x = x.to(x_scale.dtype) * x_scale + w = w.to(w_scale.dtype) * w_scale + + # Compute the matrix multiplication - note weights need to be transposed for torch linear + out = torch.nn.functional.linear(x.to(torch.float32), w.to(torch.float32)) + + return out.to(dtype) + + +def benchmark_config( + x: torch.Tensor, + w: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + bias: Optional[torch.Tensor], + dtype: torch.dtype, + config: Dict[str, Union[str, int]], + y: Optional[torch.Tensor] = None, + num_iters=10, +) -> float: + """ + Benchmark the performance of a GEMM operation with a specific configuration. + + This function measures the execution time of the gemm_a8w8_per_token_scale kernel by running + it multiple times with synchronization points to ensure accurate timing. It performs + warmup runs before the actual benchmarking to account for JIT compilation overhead. + + Args: + x (torch.Tensor): Input tensor of shape (M, K) representing the first matrix operand. + w (torch.Tensor): Weight tensor of shape (N, K) representing the second matrix operand. 
+ x_scale (torch.Tensor): Per-token scale tensor for x with shape (M, 1). + w_scale (torch.Tensor): Per-output-channel scale tensor for w with shape (N, 1). + dtype (torch.dtype): Data type for the computation (e.g., torch.bfloat16). + config (Dict[str, Union[str, int]]): Configuration dictionary containing kernel + parameters such as block sizes, number of warps, etc. + y (Optional[torch.Tensor], optional): Output tensor to store the result. If None, + a new tensor will be allocated. Defaults to None. + num_iters (int, optional): Number of benchmark iterations to run. Defaults to 10. + + Returns: + float: Average execution time in microseconds (us) per iteration. + + Note: + The function performs 5 warmup iterations before benchmarking to account for + JIT compilation and GPU warmup effects. The timing is measured using CUDA events + for accurate GPU kernel timing. + """ + # Get reference output for correctness verification + torch_out = run_torch_reference(x, w, x_scale, w_scale, bias, dtype) + + # Add SPLITK_BLOCK_SIZE computation as done in the kernel function + _, K = x.shape + _, K = w.shape + num_ksplit = int(config["NUM_KSPLIT"]) + block_k = int(config["BLOCK_SIZE_K"]) + splitk_block_size = triton.cdiv(K, num_ksplit) + + config["SPLITK_BLOCK_SIZE"] = splitk_block_size + if block_k > splitk_block_size: + block_k = triton.next_power_of_2(splitk_block_size) + if block_k > splitk_block_size: + block_k = block_k // 4 + block_k = max(block_k, 16) + config["BLOCK_SIZE_K"] = block_k + + # Run kernel + def run(): + return gemm_a8w8_per_token_scale(x, w, x_scale, w_scale, dtype, y, config) + + torch.cuda.synchronize() + + # JIT compilation & warmup with timeout for entire warmup phase + def run_warmup(): + for i in range(5): + run() + + try: + run_with_timeout(run_warmup, timeout_seconds=3) + except TimeoutError: + # If warmup times out, this config is likely bad, skip it + raise TimeoutError("Warmup phase timed out after 3 seconds") + torch.cuda.synchronize() + + start_event = torch.Event(enable_timing=True) + end_event = torch.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + try: + triton_out_raw = run_with_timeout(run, timeout_seconds=3) + except TimeoutError: + # If benchmark iteration times out, skip this config + raise TimeoutError(f"Benchmark iteration {i + 1} timed out after 3 seconds") + + # Convert to the same dtype as the reference for comparison + # Handle the case where triton_out_raw might be None + if triton_out_raw is not None: + triton_out = triton_out_raw.to(torch_out.dtype) + else: + triton_out = torch_out # Fallback to reference output + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg + + +# Global variable to store console output +console_output = [] + + +def create_live_display( + M: int, + N: int, + K: int, + current_config: Dict[str, Union[str, int]], + best_config: Dict[str, Union[str, int]], + best_time: float, + config_index: int, + total_configs: int, + console_messages: Optional[List[str]] = None, +) -> Layout: + """Create a live display layout with current and best configuration tables.""" + + layout = Layout() + + # Use global console_output if none provided + if console_messages is None: + global console_output + console_messages = console_output + + # Create progress bar + 
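The `BLOCK_SIZE_K` clamping above is easy to trace by hand. Below is a pure-Python restatement (no Triton needed; `triton.cdiv` is ceiling division and `triton.next_power_of_2` is the smallest power of two >= n), with a hypothetical small K added only to exercise the shrink path, since the (N, K) = (1024, 3072) shape tuned here with NUM_KSPLIT <= 4 never triggers it:

```python
import math

def adjust_block_k(K: int, num_ksplit: int, block_k: int):
    # Mirrors the logic in benchmark_config.
    splitk_block_size = math.ceil(K / num_ksplit)        # triton.cdiv(K, num_ksplit)
    if block_k > splitk_block_size:
        block_k = 1 << (splitk_block_size - 1).bit_length()  # next power of two
        if block_k > splitk_block_size:
            block_k //= 4
        block_k = max(block_k, 16)
    return splitk_block_size, block_k

print(adjust_block_k(K=3072, num_ksplit=4, block_k=256))  # (768, 256) -> unchanged
print(adjust_block_k(K=600, num_ksplit=4, block_k=256))   # (150, 64)  -> shrunk
```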
progress = Progress( + TextColumn("[bold blue]{task.description}"), + BarColumn(bar_width=40), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TextColumn("({task.completed}/{task.total})"), + ) + task = progress.add_task( + f"๐Ÿ”ง Tuning M={M} N={N} K={K} | Batch Size={M}", + total=total_configs, + completed=config_index, + ) + + # Create status information + status_text = "" + if best_time != float("inf"): + status_text = f"๐Ÿ† Best Performance: {best_time:.1f}ฮผs" + else: + status_text = "๐Ÿ” Searching for best configuration..." + + # Create console area + console_text = ( + "\n".join(console_messages[-10:]) + if console_messages + else "Waiting for results..." + ) + console_table = Table(show_header=False, box=None, padding=0) + console_table.add_column("Output", style="white") + console_table.add_row(console_text) + + # Create current config table + config_table = Table( + title="Current Configuration", + show_header=True, + header_style="bold magenta", + ) + config_table.add_column("Parameter", style="cyan", width=15) + config_table.add_column("Value", style="green", width=10) + + # Add matrix dimensions and batch size first + config_table.add_row("[bold yellow]Matrix M[/bold yellow]", str(M), style="yellow") + config_table.add_row("[bold yellow]Matrix N[/bold yellow]", str(N), style="yellow") + config_table.add_row("[bold yellow]Matrix K[/bold yellow]", str(K), style="yellow") + config_table.add_row( + "[bold yellow]Batch Size[/bold yellow]", str(M), style="yellow" + ) + config_table.add_row("", "") # Separator + config_table.add_row("BLOCK_SIZE_M", str(current_config.get("BLOCK_SIZE_M", "N/A"))) + config_table.add_row("BLOCK_SIZE_N", str(current_config.get("BLOCK_SIZE_N", "N/A"))) + config_table.add_row("BLOCK_SIZE_K", str(current_config.get("BLOCK_SIZE_K", "N/A"))) + config_table.add_row("num_warps", str(current_config.get("num_warps", "N/A"))) + config_table.add_row("num_stages", str(current_config.get("num_stages", "N/A"))) + config_table.add_row("NUM_KSPLIT", str(current_config.get("NUM_KSPLIT", "N/A"))) + config_table.add_row("waves_per_eu", str(current_config.get("waves_per_eu", "N/A"))) + config_table.add_row("kpack", str(current_config.get("kpack", "N/A"))) + config_table.add_row( + "cache_modifier", str(current_config.get("cache_modifier", "N/A")) + ) + config_table.add_row("GROUP_SIZE_M", str(current_config.get("GROUP_SIZE_M", "N/A"))) + + # Create best config table if we have a best configuration + best_config_table = None + if best_time != float("inf"): + best_config_table = Table( + title="๐Ÿ† Best Configuration So Far", + show_header=True, + header_style="bold green", + ) + best_config_table.add_column("Parameter", style="cyan", width=15) + best_config_table.add_column("Value", style="green", width=10) + + # Add performance and matrix dimensions + best_config_table.add_row( + "[bold green]Performance[/bold green]", f"{best_time:.1f}ฮผs", style="green" + ) + best_config_table.add_row("", "") # Separator + best_config_table.add_row( + "[bold yellow]BLOCK_SIZE_M[/bold yellow]", + str(best_config.get("BLOCK_SIZE_M", "N/A")), + style="yellow", + ) + best_config_table.add_row( + "[bold yellow]BLOCK_SIZE_N[/bold yellow]", + str(best_config.get("BLOCK_SIZE_N", "N/A")), + style="yellow", + ) + best_config_table.add_row( + "[bold yellow]BLOCK_SIZE_K[/bold yellow]", + str(best_config.get("BLOCK_SIZE_K", "N/A")), + style="yellow", + ) + best_config_table.add_row("num_warps", str(best_config.get("num_warps", "N/A"))) + best_config_table.add_row( + "num_stages", 
str(best_config.get("num_stages", "N/A")) + ) + best_config_table.add_row( + "NUM_KSPLIT", str(best_config.get("NUM_KSPLIT", "N/A")) + ) + best_config_table.add_row( + "waves_per_eu", str(best_config.get("waves_per_eu", "N/A")) + ) + best_config_table.add_row("kpack", str(best_config.get("kpack", "N/A"))) + best_config_table.add_row( + "cache_modifier", str(best_config.get("cache_modifier", "N/A")) + ) + best_config_table.add_row( + "GROUP_SIZE_M", str(best_config.get("GROUP_SIZE_M", "N/A")) + ) + + # Create combined layout + if best_config_table: + # Display tables side by side + tables = Columns([config_table, best_config_table], equal=True, expand=True) + layout.split_column( + Layout(Panel(progress, title="Progress", border_style="blue"), size=5), + Layout(Panel(status_text, title="Status", border_style="green"), size=3), + Layout(tables), + Layout( + Panel(console_table, title="Console Output", border_style="cyan"), + size=10, + ), + ) + else: + # Display only current config + layout.split_column( + Layout(Panel(progress, title="Progress", border_style="blue"), size=5), + Layout(Panel(status_text, title="Status", border_style="green"), size=3), + Layout(config_table), + Layout( + Panel(console_table, title="Console Output", border_style="cyan"), + size=10, + ), + ) + + return layout + + +def tune( + M: int, + N: int, + K: int, + search_space: List[Dict[str, int | str]], + input_type: str, + logger: logging.Logger, +): + """Tune the kernel for specific matrix dimensions.""" + # Register SIGINT handler if not already registered + if not signal.getsignal(signal.SIGINT) == sigint_handler: + signal.signal(signal.SIGINT, sigint_handler) + + if input_type == "bfloat16": + # Use the same input generation as test file + x, w, x_scale, w_scale, bias, y = generate_gemm_a8w8_per_token_scale_inputs( + M, N, K, torch.bfloat16, bias=True + ) + else: + raise RuntimeError( + "Currently, only support tune a8w8 per-token scale kernel with bfloat16 output." 
+ ) + + best_config: Dict[str, Union[int, str]] = {} + best_time = float("inf") + slow_config_threshold = ( + 1000 # microseconds - configs slower than this get highlighted + ) + + # Clear console output for fresh start + global console_output + console_output = [] + + # Initialize Rich console for better formatting + console = Console() + + # Create initial live display + initial_layout = create_live_display( + M, + N, + K, + search_space[0] if search_space else {}, + best_config, + best_time, + 0, + len(search_space), + console_output, + ) + + # Create progress display with Rich and Live + with Live(initial_layout, refresh_per_second=4, console=console) as live: + for i, config in enumerate(search_space): + # Check if we were interrupted + if INTERRUPTED: + break + + # Update global state for SIGINT handling + CURRENT_CONFIG.update( + { + "M": M, + "N": N, + "K": K, + "config": config, + "config_index": i, + "total_configs": len(search_space), + "batch_size": M, + } + ) + + # Update live display with current configuration + layout = create_live_display( + M, + N, + K, + config, + best_config, + best_time, + i + 1, + len(search_space), + console_output, + ) + live.update(layout) + + try: + kernel_time = benchmark_config( + x=x, + w=w, + x_scale=x_scale, + w_scale=w_scale, + bias=bias, + dtype=torch.bfloat16, + y=None, + config=config, + num_iters=10, + ) + + # Add kernel time to console output + if kernel_time > slow_config_threshold: + console_msg = f"[yellow]โš ๏ธ Config {i + 1} (SLOW): {kernel_time:.1f}ฮผs | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')}[/yellow]" + live.console.print(console_msg) + console_output.append(console_msg) + else: + console_msg = f"[green]โœ… Config {i + 1}: {kernel_time:.1f}ฮผs | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')}[/green]" + console_output.append(console_msg) + + # Update best time and config + if kernel_time < best_time: + best_time = kernel_time + best_config = config + best_msg = ( + f"[bold green]๐Ÿ† NEW BEST: {kernel_time:.1f}ฮผs![/bold green]" + ) + console_output.append(best_msg) + + # Update live display with current configuration and console output + layout = create_live_display( + M, + N, + K, + config, + best_config, + best_time, + i + 1, + len(search_space), + console_output, + ) + live.update(layout) + + except triton.runtime.autotuner.OutOfResources as e: + # Log and skip out of resources configurations + log_bad_config(logger, "out_of_resources", M, N, K, config, str(e)) + error_msg = f"[red]โŒ Config {i + 1} (OOM): Out of resources | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')}[/red]" + console_output.append(error_msg) + live.console.print(error_msg) + continue + except AssertionError as e: + # Log and skip assert error configurations + log_bad_config(logger, "assert_error", M, N, K, config, str(e)) + error_msg = f"[red]โŒ Config {i + 1} (ASSERT): Assert error | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')} | {e}[/red]" + console_output.append(error_msg) + live.console.print(error_msg) + 
continue + except TimeoutError as e: + # Log and skip timeout configurations + log_bad_config(logger, "timeout", M, N, K, config, str(e)) + error_msg = f"[orange1]โฑ๏ธ Config {i + 1} (TIMEOUT): {e} | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')}[/orange1]" + console_output.append(error_msg) + live.console.print(error_msg) + continue + except Exception as e: + # Log and skip other error configurations + log_bad_config(logger, "other_error", M, N, K, config, str(e)) + error_msg = f"[red]๐Ÿ’ฅ Config {i + 1} (ERROR): {e} | BM:{config.get('BLOCK_SIZE_M')} BN:{config.get('BLOCK_SIZE_N')} BK:{config.get('BLOCK_SIZE_K')} W:{config.get('num_warps')} S:{config.get('num_stages')} KS:{config.get('NUM_KSPLIT')}[/red]" + console_output.append(error_msg) + live.console.print(error_msg) + continue + + # Show final completion message with Rich + print("\n" + "=" * 70) + + # Create best config table with matrix dimensions + best_table = Table( + title="๐Ÿ† Best Configuration Found", show_header=True, header_style="bold green" + ) + best_table.add_column("Parameter", style="cyan", width=15) + best_table.add_column("Value", style="green", width=10) + + # Add matrix dimensions and batch size first + best_table.add_row( + "[bold yellow]Matrix M (Batch)[/bold yellow]", str(M), style="yellow" + ) + best_table.add_row("[bold yellow]Matrix N[/bold yellow]", str(N), style="yellow") + best_table.add_row("[bold yellow]Matrix K[/bold yellow]", str(K), style="yellow") + best_table.add_row("", "") # Separator + best_table.add_row( + "[bold green]Performance[/bold green]", f"{best_time:.1f}ฮผs", style="green" + ) + best_table.add_row("", "") # Separator + best_table.add_row( + "[bold yellow]BLOCK_SIZE_M[/bold yellow]", + str(best_config.get("BLOCK_SIZE_M", "N/A")), + style="yellow", + ) + best_table.add_row( + "[bold yellow]BLOCK_SIZE_N[/bold yellow]", + str(best_config.get("BLOCK_SIZE_N", "N/A")), + style="yellow", + ) + best_table.add_row( + "[bold yellow]BLOCK_SIZE_K[/bold yellow]", + str(best_config.get("BLOCK_SIZE_K", "N/A")), + style="yellow", + ) + best_table.add_row("num_warps", str(best_config.get("num_warps", "N/A"))) + best_table.add_row("num_stages", str(best_config.get("num_stages", "N/A"))) + best_table.add_row("NUM_KSPLIT", str(best_config.get("NUM_KSPLIT", "N/A"))) + best_table.add_row("waves_per_eu", str(best_config.get("waves_per_eu", "N/A"))) + best_table.add_row("kpack", str(best_config.get("kpack", "N/A"))) + best_table.add_row("cache_modifier", str(best_config.get("cache_modifier", "N/A"))) + best_table.add_row("GROUP_SIZE_M", str(best_config.get("GROUP_SIZE_M", "N/A"))) + best_table.add_row( + "matrix_instr_nonkdim", str(best_config.get("matrix_instr_nonkdim", "N/A")) + ) + + completion_panel = Panel( + best_table, + title=f"[bold green]โœ… Completed Tuning for M={M} N={N} K={K} (Batch Size={M})[/bold green]", + border_style="green", + ) + console.print(completion_panel) + print("=" * 70) + + assert best_config is not None + return best_config + + +def save_configs( + N, + K, + configs, + save_path, + is_incremental=False, + completed_batch_sizes=None, +) -> None: + """Save the best configurations to a JSON file.""" + os.makedirs(save_path, exist_ok=True) + device_name = "R9700" # TODO: Hardcoded, make it dynamic + + if is_incremental: + # Save incremental progress with batch size info in filename + batch_sizes_str = ( + "_".join(map(str, completed_batch_sizes)) 
+ if completed_batch_sizes + else "partial" + ) + json_file_name = f"{device_name}-GEMM-A8W8_PER_TOKEN_SCALE-N={N}-K={K}_batch_{batch_sizes_str}.json" + progress_file = os.path.join( + save_path, + f"{device_name}-GEMM-A8W8_PER_TOKEN_SCALE-N={N}-K={K}_progress.json", + ) + + # Save progress info + progress_info = { + "completed_batch_sizes": completed_batch_sizes or [], + "configs": configs, + "last_updated": datetime.now().isoformat(), + } + + with open(progress_file, "w") as f: + json.dump(progress_info, f, indent=4) + f.write("\n") + else: + json_file_name = f"{device_name}-GEMM-A8W8_PER_TOKEN_SCALE-N={N}-K={K}.json" + + config_file_path = os.path.join(save_path, json_file_name) + + # Add incremental flag to filename + action = "Updating incremental" if is_incremental else "Writing" + print(f"{action} config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def load_progress(N, K, save_path): + """Load previously saved progress for a given N,K configuration.""" + device_name = "R9700" # TODO: Hardcoded, make it dynamic + progress_file = os.path.join( + save_path, f"{device_name}-GEMM-A8W8_PER_TOKEN_SCALE-N={N}-K={K}_progress.json" + ) + + if os.path.exists(progress_file): + try: + with open(progress_file, "r") as f: + progress_info = json.load(f) + return progress_info.get("completed_batch_sizes", []), progress_info.get( + "configs", {} + ) + except Exception as e: + print(f"Warning: Could not load progress file {progress_file}: {e}") + return [], {} + return [], {} + + +def tune_on_gpu( + gpu_id: int, + batch_sizes: List[int], + weight_shapes: List[Tuple[int, int]], + input_type: str, + resume: bool = True, + log_filename: Optional[str] = None, +) -> None: + """Run tuning on a specific GPU.""" + # Register SIGINT handler and set GPU ID in global state + signal.signal(signal.SIGINT, sigint_handler) + CURRENT_CONFIG["gpu_id"] = gpu_id + + torch.cuda.set_device(gpu_id) + print(f"๐Ÿš€ Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") + + save_path = AITER_TRITON_CONFIGS_PATH + "/gemm/" + + # Setup logger for this GPU with custom or timestamped filename + if log_filename: + # Use custom filename, ensure it has .log extension + if not log_filename.endswith(".log"): + log_filename += ".log" + # If no path separator, assume it's just a filename + if "/" not in log_filename and "\\" not in log_filename: + log_filename = os.path.join(save_path, log_filename) + else: + log_filename = log_filename # Use full path as provided + else: + # Fall back to timestamped filename + log_filename = os.path.join( + save_path, + get_timestamped_filename( + f"tune_a8w8_per_token_scale_bad_configs_gpu{gpu_id}" + ), + ) + + # Choose appropriate logging mode: append for resume, overwrite for fresh start + log_mode = "a" if resume else "w" + logger = setup_logger(log_filename, mode=log_mode) + + # Log the start time in GMT+8 + gmt8 = pytz.timezone("Asia/Shanghai") + start_time_gmt8 = datetime.now(gmt8).strftime("%Y-%m-%d %H:%M:%S") + logger.info(f"=== TUNING SESSION STARTED AT {start_time_gmt8} [GMT+8] ===") + logger.info(f"GPU: {gpu_id}") + logger.info(f"Batch sizes: {batch_sizes}") + + search_space = get_configs_compute_bound() + total_configs = len(search_space) + total_tests = total_configs * len(batch_sizes) * len(weight_shapes) + + print(f" ๐Ÿ“Š Search space: {total_configs:,} configurations") + print(f" ๐ŸŽฏ Total tests to run: {total_tests:,}") + print( + f" โšก Estimated tests per weight shape: {total_configs * 
len(batch_sizes):,}" + ) + log_action = ( + "Appending to existing" + if resume and os.path.exists(log_filename) + else "Writing to new" + ) + print(f" ๐Ÿ“ Bad configurations will be logged to: {log_filename}") + print(f" ๐Ÿ“ Logging mode: {log_action}") + + start = time.time() + + # Collect all configs to determine the best overall config + all_configs: List[Dict[str, Dict[str, int | str]]] = [] + + for i, shape in enumerate(weight_shapes): + # Check if we were interrupted + if INTERRUPTED: + break + + # Update weight shape tracking + CURRENT_CONFIG.update( + {"weight_shape_index": i, "total_weight_shapes": len(weight_shapes)} + ) + + N, K = shape[0], shape[1] + print( + f"\n๐Ÿš€ [GPU {gpu_id}] Shape {i + 1}/{len(weight_shapes)}: Starting tuning for N:{N}, K:{K}" + ) + print( + f" ๐Ÿ“Š Testing {len(search_space):,} configurations across {len(batch_sizes)} batch sizes" + ) + + # Check for existing progress and resume from there (if resume is enabled) + if resume: + completed_batch_sizes, existing_configs = load_progress(N, K, save_path) + else: + completed_batch_sizes, existing_configs = [], {} + + # Filter batch_sizes to only those not yet completed + remaining_batch_sizes = [ + bs for bs in batch_sizes if bs not in completed_batch_sizes + ] + + if completed_batch_sizes and resume: + print(f"\n ๐Ÿ“‚ [GPU {gpu_id}] Found progress for N={N}, K={K}") + print(f" โœ… Already completed batch sizes: {completed_batch_sizes}") + print(f" ๐Ÿ”„ Remaining batch sizes to tune: {remaining_batch_sizes}") + elif not resume: + print( + f"\n ๐Ÿ”„ [GPU {gpu_id}] Starting fresh (resume disabled) for N={N}, K={K}" + ) + elif not remaining_batch_sizes: + print( + f"\n โœ… [GPU {gpu_id}] All batch sizes already completed for N={N}, K={K}" + ) + # Add existing configs to all_configs and continue to next shape + if existing_configs: + all_configs.append(existing_configs) + save_configs(N, K, existing_configs, save_path) + continue + + # Initialize benchmark_results with existing results if any + benchmark_results: List[Dict[str, str | int]] = [] + if existing_configs: + # Reconstruct benchmark_results from existing configs + # We need to map the configs back to their corresponding batch sizes + for i, batch_size in enumerate(batch_sizes): + if batch_size in completed_batch_sizes: + # Find the config for this batch size + config_to_add = None + + # Try to find matching config based on batch size category + if batch_size < 32 and "small" in existing_configs: + config_to_add = existing_configs["small"] + elif batch_size <= 128: + BLK_M = triton.next_power_of_2(batch_size) + if BLK_M == 32 and "medium_M32" in existing_configs: + config_to_add = existing_configs["medium_M32"] + elif BLK_M == 64 and "medium_M64" in existing_configs: + config_to_add = existing_configs["medium_M64"] + elif BLK_M == 128 and "medium_M128" in existing_configs: + config_to_add = existing_configs["medium_M128"] + elif batch_size <= 256 and "large" in existing_configs: + config_to_add = existing_configs["large"] + elif batch_size > 256 and "xlarge" in existing_configs: + config_to_add = existing_configs["xlarge"] + + if config_to_add: + benchmark_results.append(config_to_add) + else: + # If we couldn't find a matching config, we'll need to retune this batch size + remaining_batch_sizes.append(batch_size) + + for batch_size in remaining_batch_sizes: + # Check if we were interrupted + if INTERRUPTED: + break + + print( + f"\n ๐Ÿ” [GPU {gpu_id}] Testing batch size M={batch_size} for N={N}, K={K}" + ) + result = tune( + batch_size, + N, + K, + 
search_space, + input_type, + logger, + ) + + # Check if tune() was interrupted + if INTERRUPTED: + break + + benchmark_results.append(result) + + # Save incremental progress immediately after each batch size + updated_completed_batch_sizes = completed_batch_sizes + [batch_size] + + # Create configs for different M size categories as expected by the kernel + incremental_configs: Dict[str, Dict[str, int | str]] = {} + for i, (M, config) in enumerate( + zip(batch_sizes[: len(benchmark_results)], benchmark_results) + ): + if i == len(batch_sizes[: len(benchmark_results)]) - 1: + incremental_configs["any"] = config + elif M < 32: + incremental_configs["small"] = config + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32: + incremental_configs["medium_M32"] = config + elif BLK_M == 64: + incremental_configs["medium_M64"] = config + elif BLK_M == 128: + incremental_configs["medium_M128"] = config + elif M <= 256: + incremental_configs["large"] = config + else: + incremental_configs["xlarge"] = config + + # Save the incremental progress + save_configs( + N, + K, + incremental_configs, + save_path, + is_incremental=True, + completed_batch_sizes=updated_completed_batch_sizes, + ) + + print(f" ๐Ÿ’พ [GPU {gpu_id}] Saved progress for batch size {batch_size}") + + # Update completed_batch_sizes for next iteration + completed_batch_sizes = updated_completed_batch_sizes + + # Create final configs for different M size categories as expected by the kernel + best_configs: Dict[str, Dict[str, int | str]] = {} + for i, (M, config) in enumerate(zip(batch_sizes, benchmark_results)): + if i == len(batch_sizes) - 1: + best_configs["any"] = config + elif M < 32: + best_configs["small"] = config + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32: + best_configs["medium_M32"] = config + elif BLK_M == 64: + best_configs["medium_M64"] = config + elif BLK_M == 128: + best_configs["medium_M128"] = config + elif M <= 256: + best_configs["large"] = config + else: + best_configs["xlarge"] = config + + # Store configs for later analysis + all_configs.append(best_configs) + + # Save the final complete config (non-incremental) + save_configs(N, K, best_configs, save_path) + + # Clean up progress file since we completed successfully + device_name = "R9700" # TODO: Hardcoded, make it dynamic + progress_file = os.path.join( + save_path, + f"{device_name}-GEMM-A8W8_PER_TOKEN_SCALE-N={N}-K={K}_progress.json", + ) + if os.path.exists(progress_file): + os.remove(progress_file) + print(f" ๐Ÿงน [GPU {gpu_id}] Cleaned up progress file for N={N}, K={K}") + + # Create a default config file (without N,K parameters) by selecting the most common config + default_config = create_default_config(all_configs) + save_default_config(default_config, save_path) + + end = time.time() + + # Log session end time in GMT+8 + gmt8 = pytz.timezone("Asia/Shanghai") + end_time_gmt8 = datetime.now(gmt8).strftime("%Y-%m-%d %H:%M:%S") + duration = end - start + logger.info(f"=== TUNING SESSION COMPLETED AT {end_time_gmt8} [GMT+8] ===") + logger.info(f"Total duration: {duration:.2f} seconds") + + # Log summary of bad configurations + log_bad_config_summary(logger, total_tests) + + print(f"Tuning on GPU {gpu_id} took {duration:.2f} seconds") + + +def create_default_config( + all_configs: List[Dict[str, Dict[str, Union[int, str]]]], +) -> Dict[str, Dict[str, Union[int, str]]]: + """Create a default config by selecting the most common config across all shapes.""" + from collections import Counter + + # Collect all configs for 
each category + category_configs = { + "small": [], + "medium_M32": [], + "medium_M64": [], + "medium_M128": [], + "large": [], + "xlarge": [], + "any": [], + } + + for config in all_configs: + for category, params in config.items(): + if category in category_configs: + # Convert config to a hashable tuple for counting + config_tuple = tuple(sorted(params.items())) + category_configs[category].append(config_tuple) + + # Find the most common config for each category + default_config: Dict[str, Dict[str, Union[int, str]]] = {} + for category, configs in category_configs.items(): + if configs: + most_common = Counter(configs).most_common(1)[0][0] + default_config[category] = dict(most_common) + + return default_config + + +def save_default_config( + config: Dict[str, Dict[str, Union[int, str]]], save_path: str +) -> None: + """Save the default config file (without N,K parameters).""" + os.makedirs(save_path, exist_ok=True) + device_name = "R9700" # TODO: Hardcoded, make it dynamic + json_file_name = f"{device_name}-GEMM-A8W8_PER_TOKEN_SCALE.json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing default config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(config, f, indent=4) + f.write("\n") + + +def distribute_batch_sizes(batch_sizes: List[int], num_gpus: int) -> List[List[int]]: + """Distribute batch sizes across available GPUs.""" + batches_per_gpu: List[List[int]] = [] + for i in range(num_gpus): + start_idx = i * len(batch_sizes) // num_gpus + end_idx = (i + 1) * len(batch_sizes) // num_gpus + batches_per_gpu.append(batch_sizes[start_idx:end_idx]) + return batches_per_gpu + + +def main(args): + print(args) + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + raise RuntimeError("No GPU available for tuning") + print(f"Found {num_gpus} GPUs for parallel tuning") + + torch.cuda.init() + + if args.batch_size is None: + batch_sizes = [ + 16, # For small config + 32, # For medium_M32 config + 64, # For medium_M64 config + 128, # For medium_M128 config + 256, # For large config + 512, # For large config + 2048, # For xlarge config + 4096, # For xlarge config + ] + else: + batch_sizes = [args.batch_size] + num_gpus = 1 # If only one batch size, use only one GPU + + weight_shapes = get_weight_shapes(args.tp_size) + + batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus) + + # Prepare arguments for each GPU process + process_args = [] + for gpu_id in range(num_gpus): + process_args.append( + ( + gpu_id, + batches_per_gpu[gpu_id], + weight_shapes, # Each GPU processes all weight shapes + args.input_type, + args.resume, + args.log_filename, + ) + ) + + # Set up signal handler for main process to gracefully terminate workers + def main_sigint_handler(signum, frame): # type: ignore + print("\n" + "=" * 80) + print("๐Ÿ›‘ MAIN PROCESS INTERRUPTED BY USER (Ctrl+C)") + print("๐Ÿ“ก Sending termination signal to worker processes...") + print("โณ Giving workers 3 seconds to log their current state...") + print("=" * 80) + # Set a flag for workers to check and give them time to cleanup + global INTERRUPTED + INTERRUPTED = True + import time + + time.sleep(3) # Give workers time to handle the signal and log + sys.exit(1) + + # Register main process signal handler + if not signal.getsignal(signal.SIGINT) == main_sigint_handler: + signal.signal(signal.SIGINT, main_sigint_handler) + + ctx = mp.get_context("spawn") + try: + with ctx.Pool(num_gpus) as pool: + pool.starmap(tune_on_gpu, process_args) + except KeyboardInterrupt: + 
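Worked example of the `distribute_batch_sizes` split: with the default eight batch sizes and, say, three visible GPUs, the contiguous chunks come out as in the final print. This is a restatement of the function above, not a new API:

```python
def distribute_batch_sizes(batch_sizes, num_gpus):
    # Same integer-arithmetic split as above: contiguous, nearly equal chunks.
    chunks = []
    for i in range(num_gpus):
        start = i * len(batch_sizes) // num_gpus
        end = (i + 1) * len(batch_sizes) // num_gpus
        chunks.append(batch_sizes[start:end])
    return chunks

sizes = [16, 32, 64, 128, 256, 512, 2048, 4096]
print(distribute_batch_sizes(sizes, 3))
# [[16, 32], [64, 128, 256], [512, 2048, 4096]]
```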
print("\n๐Ÿ›‘ Keyboard interrupt received in main process") + print("๐Ÿ“ก Worker processes terminated") + sys.exit(1) + except Exception as e: + print(f"\nโŒ Error in main process: {e}") + sys.exit(1) + + print("Multi-GPU tuning completed") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument("--tp-size", "-tp", type=int, default=1) + parser.add_argument( + "--input-type", type=str, choices=["bfloat16"], default="bfloat16" + ) + parser.add_argument( + "--out-dtype", + type=str, + choices=["float32", "float16", "bfloat16", "half"], + default="bfloat16", + ) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument( + "--log-filename", + type=str, + default=None, + help="Custom log filename (without .log extension). If not provided, timestamped filename will be used.", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Disable resume functionality and start fresh tuning", + ) + args = parser.parse_args() + + # Convert no_resume flag to resume boolean + args.resume = not args.no_resume + + main(args) diff --git a/aiter/ops/triton/unified_attention.py b/aiter/ops/triton/unified_attention.py index b2231ee563..1b0ecf3861 100644 --- a/aiter/ops/triton/unified_attention.py +++ b/aiter/ops/triton/unified_attention.py @@ -1,5 +1,6 @@ # The kernels in this file are adapted from vLLM: # https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_unified_attention.py +from aiter.ops.triton.utils._triton import arch_info import triton import torch from aiter.ops.triton.utils.device_info import get_num_sms @@ -27,8 +28,12 @@ def select_2d_config( TILE_SIZE = 64 # in case head_size is large max_num_stages_2d = 4 + dev = arch_info.get_arch() if head_size > 128: - max_num_stages_2d = 2 + if block_size >= 64 and dev == "gfx1201": + max_num_stages_2d = 1 + else: + max_num_stages_2d = 2 if all_decode == False: num_stages_2d = 1 num_warps = 2 diff --git a/aiter/ops/triton/utils/_triton/arch_info.py b/aiter/ops/triton/utils/_triton/arch_info.py index d709d7c8f1..cec9af54bb 100644 --- a/aiter/ops/triton/utils/_triton/arch_info.py +++ b/aiter/ops/triton/utils/_triton/arch_info.py @@ -22,4 +22,4 @@ def is_fp4_avail(): def is_fp8_avail(): - return get_arch() in ("gfx942", "gfx950") + return get_arch() in ("gfx942", "gfx950", "gfx1201") diff --git a/aiter/ops/triton/utils/types.py b/aiter/ops/triton/utils/types.py index 92d18c6eed..6a07c63753 100644 --- a/aiter/ops/triton/utils/types.py +++ b/aiter/ops/triton/utils/types.py @@ -11,7 +11,7 @@ def get_dtype_max(dtype): def get_fp8_dtypes(): - if arch_info.get_arch() in ("gfx950"): + if arch_info.get_arch() in ("gfx950", "gfx1201"): e5m2_dtype = torch.float8_e5m2 e4m3_dtype = torch.float8_e4m3fn else: @@ -22,7 +22,7 @@ def get_fp8_dtypes(): def get_fp8_e4m3_dtype(): - if arch_info.get_arch() in ("gfx950"): + if arch_info.get_arch() in ("gfx950", "gfx1201"): e4m3_dtype = torch.float8_e4m3fn else: e4m3_dtype = torch.float8_e4m3fnuz diff --git a/aiter/utility/dtypes.py b/aiter/utility/dtypes.py index 5252f5f8b3..e1eb66a153 100644 --- a/aiter/utility/dtypes.py +++ b/aiter/utility/dtypes.py @@ -7,6 +7,7 @@ defaultDtypes = { "gfx942": {"fp8": torch.float8_e4m3fnuz}, "gfx950": {"fp8": torch.float8_e4m3fn}, + "gfx1201": {"fp8": torch.float8_e4m3fn}, } _8bit_fallback = torch.uint8 diff --git a/op_tests/triton_tests/README.md b/op_tests/triton_tests/README.md index a10e3275b5..7a9ef98b1a 100644 --- 
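The default-config step above is just a per-category majority vote over the per-shape results. A minimal, self-contained sketch of that selection behavior (the parameter names are real tuning parameters, but the values and the three per-shape results are made up for illustration):

```python
from collections import Counter

# Three hypothetical per-shape tuning results (values are illustrative only).
all_configs = [
    {"small": {"BLOCK_SIZE_M": 16, "NUM_KSPLIT": 1}, "large": {"BLOCK_SIZE_M": 128, "NUM_KSPLIT": 2}},
    {"small": {"BLOCK_SIZE_M": 16, "NUM_KSPLIT": 1}, "large": {"BLOCK_SIZE_M": 64, "NUM_KSPLIT": 4}},
    {"small": {"BLOCK_SIZE_M": 16, "NUM_KSPLIT": 1}, "large": {"BLOCK_SIZE_M": 128, "NUM_KSPLIT": 2}},
]

# Mirror of the selection logic: hash each per-category config as a sorted
# tuple of items, then keep the most frequent one per category.
default_config = {}
for category in {c for cfg in all_configs for c in cfg}:
    counts = Counter(
        tuple(sorted(cfg[category].items())) for cfg in all_configs if category in cfg
    )
    default_config[category] = dict(counts.most_common(1)[0][0])

print(default_config["small"])  # {'BLOCK_SIZE_M': 16, 'NUM_KSPLIT': 1}
print(default_config["large"])  # {'BLOCK_SIZE_M': 128, 'NUM_KSPLIT': 2} -- wins 2 of 3
```

Hashing each category's config as a sorted tuple of items is what makes the `Counter` lookup possible, since dicts themselves are not hashable.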
diff --git a/aiter/ops/triton/unified_attention.py b/aiter/ops/triton/unified_attention.py
index b2231ee563..1b0ecf3861 100644
--- a/aiter/ops/triton/unified_attention.py
+++ b/aiter/ops/triton/unified_attention.py
@@ -1,5 +1,6 @@
 # The kernels in this file are adapted from vLLM:
 # https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_unified_attention.py
+from aiter.ops.triton.utils._triton import arch_info
 import triton
 import torch
 from aiter.ops.triton.utils.device_info import get_num_sms
@@ -27,8 +28,12 @@ def select_2d_config(
     TILE_SIZE = 64  # in case head_size is large
 
     max_num_stages_2d = 4
+    dev = arch_info.get_arch()
     if head_size > 128:
-        max_num_stages_2d = 2
+        if block_size >= 64 and dev == "gfx1201":
+            max_num_stages_2d = 1
+        else:
+            max_num_stages_2d = 2
     if all_decode == False:
         num_stages_2d = 1
         num_warps = 2
diff --git a/aiter/ops/triton/utils/_triton/arch_info.py b/aiter/ops/triton/utils/_triton/arch_info.py
index d709d7c8f1..cec9af54bb 100644
--- a/aiter/ops/triton/utils/_triton/arch_info.py
+++ b/aiter/ops/triton/utils/_triton/arch_info.py
@@ -22,4 +22,4 @@ def is_fp4_avail():
 
 
 def is_fp8_avail():
-    return get_arch() in ("gfx942", "gfx950")
+    return get_arch() in ("gfx942", "gfx950", "gfx1201")
diff --git a/aiter/ops/triton/utils/types.py b/aiter/ops/triton/utils/types.py
index 92d18c6eed..6a07c63753 100644
--- a/aiter/ops/triton/utils/types.py
+++ b/aiter/ops/triton/utils/types.py
@@ -11,7 +11,7 @@ def get_dtype_max(dtype):
 
 
 def get_fp8_dtypes():
-    if arch_info.get_arch() in ("gfx950"):
+    if arch_info.get_arch() in ("gfx950", "gfx1201"):
         e5m2_dtype = torch.float8_e5m2
         e4m3_dtype = torch.float8_e4m3fn
     else:
@@ -22,7 +22,7 @@ def get_fp8_dtypes():
 
 
 def get_fp8_e4m3_dtype():
-    if arch_info.get_arch() in ("gfx950"):
+    if arch_info.get_arch() in ("gfx950", "gfx1201"):
         e4m3_dtype = torch.float8_e4m3fn
     else:
         e4m3_dtype = torch.float8_e4m3fnuz
diff --git a/aiter/utility/dtypes.py b/aiter/utility/dtypes.py
index 5252f5f8b3..e1eb66a153 100644
--- a/aiter/utility/dtypes.py
+++ b/aiter/utility/dtypes.py
@@ -7,6 +7,7 @@
 defaultDtypes = {
     "gfx942": {"fp8": torch.float8_e4m3fnuz},
     "gfx950": {"fp8": torch.float8_e4m3fn},
+    "gfx1201": {"fp8": torch.float8_e4m3fn},
 }
 
 _8bit_fallback = torch.uint8
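With these changes, gfx1201 follows the same FP8 path as gfx950, i.e. the OCP `torch.float8_e4m3fn` type rather than the `fnuz` variant. A quick sanity check of what the helpers report on the current device (a sketch, assuming an aiter install on a supported ROCm GPU):

```python
import torch
from aiter.ops.triton.utils._triton import arch_info
from aiter.ops.triton.utils.types import get_fp8_e4m3_dtype

arch = arch_info.get_arch()
print(f"arch={arch}, fp8 available: {arch_info.is_fp8_avail()}")

# On gfx950 and gfx1201 this should now report torch.float8_e4m3fn;
# other supported architectures (e.g. gfx942) keep the fnuz variant.
print(f"e4m3 dtype on this device: {get_fp8_e4m3_dtype()}")
assert get_fp8_e4m3_dtype() in (torch.float8_e4m3fn, torch.float8_e4m3fnuz)
```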
diff --git a/op_tests/triton_tests/README.md b/op_tests/triton_tests/README.md
index a10e3275b5..7a9ef98b1a 100644
--- a/op_tests/triton_tests/README.md
+++ b/op_tests/triton_tests/README.md
@@ -6,4 +6,6 @@
 Triton's implementation uses numpy under the hood.
 Switch back to Triton if you're experiencing OOM issues.
 2. When possible, generate test inputs directly on the GPU (e.g with `torch.randn((M, K), device="cuda")` as opposed to `torch.randn((M, K)).cuda()`).
-It's ~2 orders of magnitude faster for large test cases.
\ No newline at end of file
+It's ~2 orders of magnitude faster for large test cases.
+
+3. Run the tests with `AITER_TRITON_LOG_LEVEL=INFO` to enable logging output.
\ No newline at end of file
diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_a16w16.py b/op_tests/triton_tests/gemm/basic/test_gemm_a16w16.py
index 66f64b6119..bc63eeb4ff 100644
--- a/op_tests/triton_tests/gemm/basic/test_gemm_a16w16.py
+++ b/op_tests/triton_tests/gemm/basic/test_gemm_a16w16.py
@@ -4,10 +4,14 @@
 import torch
 import torch.nn.functional as F
 import pytest
+import logging
 from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
 from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
 from op_tests.triton_tests.utils.types import str_to_torch_dtype
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
 
 
 def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True, bias=False):
     if isinstance(dtype, str):
@@ -80,6 +84,7 @@ def get_x_vals():
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("output", [True, False])
 def test_gemm_a16_w16_activation(M: int, N: int, K: int, dtype, output, activation):
+    logger.info(f"Running test_gemm_a16_w16_activation with M={M}, N={N}, K={K}, dtype={dtype}, output={output}, activation={activation}")
     x, w, _, out_dtype, y = generate_gemm_a16w16_inputs(
         M,
         N,
@@ -123,6 +128,7 @@ def test_gemm_a16_w16_activati
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("output", [True, False])
 def test_gemm_a16_w16(M: int, N: int, K: int, dtype, output):
+    logger.info(f"Running test_gemm_a16_w16 with M={M}, N={N}, K={K}, dtype={dtype}, output={output}")
     torch.cuda.empty_cache()  # Helps avoid hangs in large tests
 
     x, w, bias, out_dtype, y = generate_gemm_a16w16_inputs(
@@ -144,6 +150,7 @@ def test_gemm_a16_w16(M: int, N: int, K: int, dtype, output):
 @pytest.mark.parametrize("layout", ["TT", "NN", "NT"])
 @pytest.mark.parametrize("output", [True, False])
 def test_gemm_a16_w16_layout(M: int, N: int, K: int, dtype, layout, output):
+    logger.info(f"Running test_gemm_a16_w16_layout with M={M}, N={N}, K={K}, dtype={dtype}, layout={layout}, output={output}")
     torch.cuda.empty_cache()  # Helps avoid hangs in large tests
 
     x, w, _, out_dtype, y = generate_gemm_a16w16_inputs(
@@ -164,6 +171,7 @@ def test_gemm_a16_w16_layout(M: int, N: int, K: int, dtype, layout, output):
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("output", [True, False])
 def test_gemm_a16_w16_atomic(M: int, N: int, K: int, dtype, output):
+    logger.info(f"Running test_gemm_a16_w16_atomic with M={M}, N={N}, K={K}, dtype={dtype}, output={output}")
     torch.cuda.empty_cache()  # Helps avoid hangs in large tests
 
     x, w, _, out_dtype, y = generate_gemm_a16w16_inputs(M, N, K, dtype, output=output)
@@ -185,6 +193,7 @@ def test_gemm_a16_w16_atomic(M: int, N: int, K: int, dtype, output):
 @pytest.mark.parametrize("layout", ["TT", "NN", "NT"])
 @pytest.mark.parametrize("output", [True, False])
 def test_gemm_a16_w16_atomic_layout(M: int, N: int, K: int, dtype, layout, output):
+    logger.info(f"Running test_gemm_a16_w16_atomic_layout with M={M}, N={N}, K={K}, dtype={dtype}, layout={layout}, output={output}")
     torch.cuda.empty_cache()  # Helps avoid hangs in large tests
 
     x, w, _, out_dtype, y = generate_gemm_a16w16_inputs(
diff --git a/op_tests/triton_tests/moe/test_moe.py b/op_tests/triton_tests/moe/test_moe.py
index 28a7eccd42..ce2c563a09 100644
--- a/op_tests/triton_tests/moe/test_moe.py
+++ b/op_tests/triton_tests/moe/test_moe.py
@@ -324,7 +324,7 @@ def quantize_fp8(
     tensor: torch.Tensor, dim=()
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     dev = arch_info.get_arch()
-    if dev == "gfx950":
+    if dev in ["gfx950", "gfx1201"]:
         fp8_type = torch.float8_e4m3fn
     else:
         fp8_type = torch.float8_e4m3fnuz
@@ -372,7 +372,6 @@ def quantize_int8(
 def quantize_int4(
     tensor: torch.Tensor, group_size: int, has_zp: bool
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    # reshape tensor
     k, n = tensor.shape
     tensor = tensor.reshape(-1, group_size, n)
 
@@ -486,7 +485,6 @@ def input_helper_int4_w4a16(
     group_size: int,
     has_zp: bool,
 ):
-
     a = torch.randn((M, K), dtype=dtype, device="cuda")
     b = torch.rand((E, N, K), dtype=dtype, device="cuda")
 
@@ -775,7 +773,6 @@ def test_fused_moe_int4_w4a16(
     persistent: bool,
     silu_fused: bool,
 ):
-
     torch.cuda.empty_cache()  # Helps avoid hangs in large tests
     if (
         M == 1
diff --git a/op_tests/triton_tests/moe/test_moe_routing.py b/op_tests/triton_tests/moe/test_moe_routing.py
index eae12dd40f..6b24bdfb7c 100644
--- a/op_tests/triton_tests/moe/test_moe_routing.py
+++ b/op_tests/triton_tests/moe/test_moe_routing.py
@@ -97,7 +97,7 @@ def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"):
 @pytest.mark.parametrize("use_expt_indx", [False, True])
 @pytest.mark.parametrize("sm_first", [True, False])
 def test_op(n_tokens, n_expts_tot, n_expts_act, sm_first, use_expt_indx):
-    if get_arch() != "gfx950":
+    if get_arch() not in ["gfx950", "gfx1201"]:
         pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.")
     device = "cuda"
 
diff --git a/op_tests/triton_tests/test_gemm_a16w16_qwen3.py b/op_tests/triton_tests/test_gemm_a16w16_qwen3.py
new file mode 100644
index 0000000000..9681e7f5fa
--- /dev/null
+++ b/op_tests/triton_tests/test_gemm_a16w16_qwen3.py
@@ -0,0 +1,179 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import torch
+import torch.nn.functional as F
+import pytest
+from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
+from op_tests.triton_tests.utils.types import str_to_torch_dtype
+
+
+def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True, bias=False):
+    if isinstance(dtype, str):
+        dtype = str_to_torch_dtype[dtype]
+
+    # TN is default layout
+    if layout[0] == "T":
+        x = torch.randn((M, K), dtype=dtype, device="cuda")
+    else:
+        x = torch.randn((K, M), dtype=dtype, device="cuda").T
+
+    if layout[1] == "T":
+        weight = torch.randn((K, N), dtype=dtype, device="cuda").T
+    else:
+        weight = torch.randn((N, K), dtype=dtype, device="cuda")
+
+    bias_tensor = None
+    if bias:
+        bias_tensor = torch.empty((N), dtype=dtype, device="cuda")
+
+    y = None
+    if output:
+        y = torch.empty((M, N), dtype=dtype, device="cuda")
+        out_dtype = (None,)
+    else:
+        out_dtype = dtype
+
+    return x, weight, bias_tensor, out_dtype, y
+
+
+def get_x_vals():
+    x_vals = []  # minimal case
+    for batch in [64, 128, 256, 512, 2048, 4096]:
+        x_vals += [
+            (batch, 1024, 1024),
+            (batch, 4096, 1024),
+            (batch, 1024, 2048),
+            (batch, 6144, 1024),
+            (batch, 1024, 3072),
+        ]
+    return x_vals
+
+
+@pytest.mark.parametrize("activation", ["gelu", "gelu_tanh", "silu", "silu_exp2"])
+@pytest.mark.parametrize("M, N, K", get_x_vals())
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("output", [True, False])
+def test_gemm_a16_w16_activation(M: int, N: int, K: int, dtype, output, activation):
+    x, w, _, out_dtype, y = generate_gemm_a16w16_inputs(
+        M,
+        N,
+        K,
+        dtype,
+        output=output,
+    )
+
+    torch_out = F.linear(x, w, bias=None)
+    if activation == "gelu":
+        torch_out = F.gelu(torch_out)
+    elif activation == "gelu_tanh":
+        torch_out = F.gelu(torch_out, approximate="tanh")
+    elif activation == "silu":
+        torch_out = F.silu(torch_out)
+    elif activation == "silu_exp2":
+        torch_out = F.silu(torch_out)
+
+    if output:
+        triton_out = gemm_a16w16(
+            x,
+            w,
+            None,
+            out_dtype,
+            y,
+            activation=activation,
+        )
+    else:
+        triton_out = gemm_a16w16(
+            x,
+            w,
+            None,
+            out_dtype,
+            activation=activation,
+        )
+
+    torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-2)
+
+
+@pytest.mark.parametrize("M, N, K", get_x_vals())
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("output", [True, False])
+def test_gemm_a16_w16(M: int, N: int, K: int, dtype, output):
+    torch.cuda.empty_cache()  # Helps avoid hangs in large tests
+
+    x, w, bias, out_dtype, y = generate_gemm_a16w16_inputs(
+        M, N, K, dtype, output=output, bias=True
+    )
+
+    torch_out = F.linear(x, w, bias=bias)
+
+    if output:
+        triton_out = gemm_a16w16(x, w, bias, out_dtype, y)
+    else:
+        triton_out = gemm_a16w16(x, w, bias, out_dtype)
+
+    torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1)
+
+
+@pytest.mark.parametrize("M, N, K", get_x_vals())
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("layout", ["TT", "NN", "NT"])
+@pytest.mark.parametrize("output", [True, False])
+def test_gemm_a16_w16_layout(M: int, N: int, K: int, dtype, layout, output):
+    torch.cuda.empty_cache()  # Helps avoid hangs in large tests
+
+    x, w, _, out_dtype, y = generate_gemm_a16w16_inputs(
+        M, N, K, dtype, layout=layout, output=output
+    )
+
+    torch_out = F.linear(x, w, bias=None)
+
+    if output:
+        triton_out = gemm_a16w16(x, w, None, out_dtype, y)
+    else:
+        triton_out = gemm_a16w16(x, w, None, out_dtype)
+
+    torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1)
+
+
+@pytest.mark.parametrize("M, N, K", get_x_vals())
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("output", [True, False])
+def test_gemm_a16_w16_atomic(M: int, N: int, K: int, dtype, output):
+    torch.cuda.empty_cache()  # Helps avoid hangs in large tests
+
+    x, w, _, out_dtype, y = generate_gemm_a16w16_inputs(M, N, K, dtype, output=output)
+
+    torch_out = F.linear(x, w, bias=None)
+
+    # Accumulation in bf16/fp16 leads to precision loss, cast y to fp32 to prevent that
+    if output:
+        y = y.to(torch.float32).zero_()
+        triton_out = gemm_a16w16_atomic(x, w, torch.float32, y).to(dtype)
+    else:
+        triton_out = gemm_a16w16_atomic(x, w, dtype=torch.float32).to(dtype)
+
+    torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1)
+
+
+@pytest.mark.parametrize("M, N, K", get_x_vals())
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("layout", ["TT", "NN", "NT"])
+@pytest.mark.parametrize("output", [True, False])
+def test_gemm_a16_w16_atomic_layout(M: int, N: int, K: int, dtype, layout, output):
+    torch.cuda.empty_cache()  # Helps avoid hangs in large tests
+
+    x, w, _, out_dtype, y = generate_gemm_a16w16_inputs(
+        M, N, K, dtype, layout=layout, output=output
+    )
+
+    torch_out = F.linear(x, w, bias=None)
+
+    # Accumulation in bf16/fp16 leads to precision loss, cast y to fp32 to prevent that
+    if output:
+        y = y.to(torch.float32).zero_()
+        triton_out = gemm_a16w16_atomic(x, w, torch.float32, y).to(dtype)
+    else:
+        triton_out = gemm_a16w16_atomic(x, w, dtype=torch.float32).to(dtype)
+
+    torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1)
diff --git a/pyproject.toml b/pyproject.toml
index 36ada796c1..e39cd81668 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,3 +14,7 @@ requires = [
 write_to = "aiter/_version.py"
 write_to_template = "__version__ = '{version}'\n"
 fallback_version = "0.1.0"
+
+[tool.pytest.ini_options]
+log_cli = true
+log_cli_level = "INFO"
\ No newline at end of file
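One usage detail worth pulling out of the atomic tests above: because the split-K reduction accumulates through atomic adds, the tests allocate the output in fp32 and cast back to the working dtype afterwards, rather than accumulating directly in bf16/fp16. A minimal sketch of that pattern (the shape values are arbitrary):

```python
import torch
from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic

M, N, K = 256, 1024, 1024  # arbitrary example shape
x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
w = torch.randn((N, K), dtype=torch.bfloat16, device="cuda")

# Accumulate in fp32, as the tests do, to avoid precision loss from
# atomic adds in bf16, then cast the result back to bf16.
y = torch.zeros((M, N), dtype=torch.float32, device="cuda")
out = gemm_a16w16_atomic(x, w, torch.float32, y).to(torch.bfloat16)
```

The tests zero `y` before the call (`y.to(torch.float32).zero_()`), since the atomic accumulation adds into whatever the output buffer already holds.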