103 changes: 103 additions & 0 deletions aiter/ops/triton/README_tune_atomic.md
@@ -0,0 +1,103 @@
# GEMM A16W16 Atomic Kernel Tuning

This document explains how to use the tuning script for the atomic GEMM kernel.

## Overview

The atomic GEMM kernel (`gemm_a16w16_atomic`) is a specialized kernel that uses atomic operations for split-K reduction. This allows for better parallelization in certain scenarios, especially for larger matrix dimensions.
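
For intuition, the split-K scheme can be sketched in plain PyTorch (a reference illustration only, not the Triton kernel itself; `splitk_reference` and the even K partitioning are assumptions made for this sketch):

```python
import torch

def splitk_reference(x: torch.Tensor, w: torch.Tensor, num_ksplit: int = 4) -> torch.Tensor:
    """Reference for the split-K idea: each K slice produces a partial product
    that is accumulated into the same output buffer. In the Triton kernel this
    accumulation is performed with atomic adds on the output tile."""
    M, K = x.shape
    N = w.shape[0]  # w is (N, K), matching the usage example below
    assert K % num_ksplit == 0, "sketch assumes K splits evenly"
    y = torch.zeros(M, N, dtype=torch.float32, device=x.device)
    split = K // num_ksplit
    for s in range(num_ksplit):
        ks = slice(s * split, (s + 1) * split)
        y += x[:, ks].float() @ w[:, ks].float().t()
    return y.to(x.dtype)
```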

## Tuning Script

The tuning script is located at `aiter/ops/triton/tune_a16w16_atomic.py`. It follows a similar pattern to the regular GEMM tuning script but with specific adaptations for the atomic kernel.

### Key Differences from Regular GEMM Tuning

1. **NUM_KSPLIT Parameter**: The atomic kernel includes a `NUM_KSPLIT` parameter that controls how the K dimension is split across multiple thread blocks for parallel reduction.

2. **Configuration Categories**: The atomic kernel uses different configuration categories based on the M dimension (see the selector sketch after this list):
- `small`: M < 32
- `medium_M32`: M ≤ 128 with BLOCK_SIZE_M = 32
- `medium_M64`: M ≤ 128 with BLOCK_SIZE_M = 64
- `medium_M128`: M ≤ 128 with BLOCK_SIZE_M = 128
- `large`: M ≤ 256
- `xlarge`: M > 256

3. **No Bias Parameter**: The atomic kernel doesn't support a bias parameter.
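
A hypothetical selector that mirrors the bucketing in item 2 (the category names come from the generated config files; the function itself is illustrative, not part of the library):

```python
def select_config_category(M: int, block_size_m: int) -> str:
    # Illustrative only: maps an M dimension (and the tuned BLOCK_SIZE_M for
    # the medium buckets) onto the category keys used in the config files.
    if M < 32:
        return "small"
    if M <= 128:
        return f"medium_M{block_size_m}"  # medium_M32 / medium_M64 / medium_M128
    if M <= 256:
        return "large"
    return "xlarge"
```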

### Running the Tuning Script

To run the tuning script:

```bash
cd aiter/ops/triton
python tune_a16w16_atomic.py
```

You can also specify parameters explicitly:

```bash
python tune_a16w16_atomic.py --batch-size 512 --input-type bfloat16 --out-dtype bfloat16
```

### Output

The tuning script will generate configuration files in the `aiter/ops/triton/configs/gemm/` directory with names like:
- `R9700-GEMM-A16W16-ATOMIC-N={N}-K={K}.json` for specific N,K dimensions
- `R9700-GEMM-A16W16-ATOMIC.json` - A default config file (without N,K parameters) that contains the most common optimal configurations across all tested shapes

The default config file is created by analyzing all the specific N,K configurations and selecting the most common optimal configuration for each category (small, medium_M32, etc.). This provides a good general-purpose configuration when a specific N,K configuration is not available.
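
As an illustration of how a caller might resolve a configuration, the lookup can fall back from the shape-specific file to the default one (the helper below is hypothetical; only the file-name patterns come from this README):

```python
import json
import os

def load_atomic_gemm_config(config_dir: str, N: int, K: int) -> dict:
    # Hypothetical lookup: prefer the shape-specific file, otherwise fall back
    # to the general-purpose default described above.
    specific = os.path.join(config_dir, f"R9700-GEMM-A16W16-ATOMIC-N={N}-K={K}.json")
    default = os.path.join(config_dir, "R9700-GEMM-A16W16-ATOMIC.json")
    path = specific if os.path.exists(specific) else default
    with open(path) as f:
        return json.load(f)
```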

### Configuration Parameters

The tuning script searches through these parameters (a sketch of enumerating the space follows the list):
- `BLOCK_SIZE_M`: [16, 32, 64, 128]
- `BLOCK_SIZE_N`: [32, 64, 128, 256]
- `BLOCK_SIZE_K`: [64, 128]
- `GROUP_SIZE_M`: [1, 8, 16]
- `NUM_KSPLIT`: [1, 2, 4, 8] (atomic kernel specific)
- `num_warps`: [4, 8]
- `num_stages`: [2]
- `waves_per_eu`: [3]
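
A sketch of how this search space can be enumerated (the actual script may prune combinations, e.g. `NUM_KSPLIT` values that do not divide K; `iter_configs` is illustrative):

```python
from itertools import product

SEARCH_SPACE = {
    "BLOCK_SIZE_M": [16, 32, 64, 128],
    "BLOCK_SIZE_N": [32, 64, 128, 256],
    "BLOCK_SIZE_K": [64, 128],
    "GROUP_SIZE_M": [1, 8, 16],
    "NUM_KSPLIT": [1, 2, 4, 8],
    "num_warps": [4, 8],
    "num_stages": [2],
    "waves_per_eu": [3],
}

def iter_configs():
    # Yields every combination of the values above as a config dictionary.
    keys = list(SEARCH_SPACE)
    for values in product(*(SEARCH_SPACE[k] for k in keys)):
        yield dict(zip(keys, values))
```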

### Batch Sizes Tested

The script tests these batch sizes (M dimensions):
- 16 (small)
- 32 (medium_M32)
- 64 (medium_M64)
- 128 (medium_M128)
- 256 (large)
- 512 (large)
- 2048 (xlarge)
- 4096 (xlarge)

### Weight Shapes Tested

The script tunes for these weight shapes (N, K); see the sketch after this list for how they combine with the batch sizes above:
- (1024, 1024)
- (4096, 1024)
- (1024, 2048)
- (6144, 1024)
- (1024, 3072)
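
Combined with the batch sizes above, these shapes define the full problem grid that gets tuned; assembled hypothetically:

```python
# Illustrative only: the actual script may iterate in a different order.
BATCH_SIZES = [16, 32, 64, 128, 256, 512, 2048, 4096]        # M
WEIGHT_SHAPES = [(1024, 1024), (4096, 1024), (1024, 2048),
                 (6144, 1024), (1024, 3072)]                  # (N, K)

PROBLEM_SIZES = [(m, n, k) for m in BATCH_SIZES for n, k in WEIGHT_SHAPES]
```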

## Usage in Code

Once the tuning is complete, the atomic kernel can be used in your code:

```python
import torch

from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic

# Basic usage with example shapes
M, N, K = 512, 4096, 1024
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# The kernel will automatically use the tuned configurations
y = gemm_a16w16_atomic(x, w, dtype=torch.bfloat16)
```

## Notes

1. The atomic kernel is particularly useful for larger matrices where split-K parallelization provides benefits.
2. For smaller matrices, the regular GEMM kernel might be more efficient.
3. The tuning process can take considerable time as it tests many configurations.
4. Make sure you have sufficient GPU memory for the tuning process.
@@ -632,7 +632,7 @@ def is_fp8(x) -> bool:
 def _is_fp8_single(t: torch.Tensor) -> bool:
     if is_dtype_fp8(t.dtype):
         arch = get_arch()
-        if arch not in ("gfx942", "gfx950"):
+        if arch not in ("gfx942", "gfx950", "gfx1201"):
             raise RuntimeError(
                 f"{arch} is not in the list of supported architectures for FP8"
             )
@@ -0,0 +1,93 @@
{
"small": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 1024
},
"medium_M32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3,
"waves_per_eu": 8,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 1024
},
"medium_M64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 1024
},
"medium_M128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 1024
},
"large": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 1024
},
"xlarge": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 1024
},
"any": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 1024
}
}
@@ -0,0 +1,93 @@
{
"small": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 2048
},
"medium_M32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
"waves_per_eu": 8,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 2048
},
"medium_M64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 2048
},
"medium_M128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 2048
},
"large": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 2048
},
"xlarge": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 2048
},
"any": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 2048
}
}
@@ -0,0 +1,93 @@
{
"small": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 3072
},
"medium_M32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 3072
},
"medium_M64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 8,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 3072
},
"medium_M128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 3072
},
"large": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 3072
},
"xlarge": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 3072
},
"any": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": "",
"NUM_KSPLIT": 1,
"SPLITK_BLOCK_SIZE": 3072
}
}