diff --git a/aiter/ops/triton/batched_gemm_a8w8.py b/aiter/ops/triton/batched_gemm_a8w8.py index 1c99e8d7e5..446b835264 100644 --- a/aiter/ops/triton/batched_gemm_a8w8.py +++ b/aiter/ops/triton/batched_gemm_a8w8.py @@ -1,13 +1,15 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +from typing import Optional +import functools +import json import torch import triton import triton.language as tl from typing import Optional +import aiter.ops.triton.utils.arch_info as arch_info +from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH @triton.heuristics( @@ -181,6 +183,26 @@ def _batched_gemm_a8w8_kernel( tl.store(c_ptrs, c, mask=c_mask) +@functools.lru_cache(maxsize=1024) +def _get_config( + M: int, + N: int, + K: int, +): + if not hasattr(_get_config, "_config_dict"): + dev = arch_info.get_device() + fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM-A8W8.json" + print(f"fpath={fpath}") + with open(fpath, "r") as file: + config = json.load(file) + _get_config._config_dict = config + + if M + N >= 4096: + return _get_config._config_dict["large"] + else: + return _get_config._config_dict["small"] + + def batched_gemm_a8w8( XQ: torch.Tensor, WQ: torch.Tensor, @@ -190,6 +212,7 @@ def batched_gemm_a8w8( dtype: Optional[torch.dtype] = torch.bfloat16, splitK: Optional[int] = None, YQ: Optional[torch.Tensor] = None, + config: Optional[dict] = None, ): """ Computes the matmul YQ[i] = XQ[i] x WQ[i]T and applies a conversion scale for every i in a given batch. @@ -235,25 +258,12 @@ def batched_gemm_a8w8( if YQ is None: YQ = torch.empty((B, M, N), dtype=dtype, device=XQ.device) - if (M + N) >= 4096: - BLOCK_SIZE_M = 256 - BLOCK_SIZE_N = 256 - BLOCK_SIZE_K = 64 - GROUP_SIZE_M = 4 - else: - BLOCK_SIZE_M = 128 - BLOCK_SIZE_N = 128 - BLOCK_SIZE_K = 32 - GROUP_SIZE_M = 1 + if config is None: + config = _get_config(M, N, K) - waves_per_eu = 2 - matrix_instr_nonkdim = 16 - num_warps = 8 - num_stages = 2 - - grid = ( + grid = lambda META: ( # noqa: E731 B, - triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) _batched_gemm_a8w8_kernel[grid]( @@ -279,13 +289,7 @@ def batched_gemm_a8w8( w_scale.stride(0), bias.stride(0) if has_bias else 0, has_bias, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - BLOCK_SIZE_K, - GROUP_SIZE_M, - waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=matrix_instr_nonkdim, - num_warps=num_warps, - num_stages=num_stages, + **config, ) + return YQ diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4.py b/aiter/ops/triton/batched_gemm_afp4wfp4.py index 30570d81ab..65d2b4fa17 100644 --- a/aiter/ops/triton/batched_gemm_afp4wfp4.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4.py @@ -1,9 +1,25 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ from typing import Optional +import functools +import json import os import torch import triton import triton.language as tl from aiter.ops.triton.utils.pid_preprocessing import pid_grid, remap_xcd +import aiter.ops.triton.utils.arch_info as arch_info +from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH + + +global _USE_GEMM_SPLITK_BF16 +_USE_GEMM_SPLITK_BF16 = False + + +def set_use_gemm_splitk_bf16(value: bool): + global _USE_GEMM_SPLITK_BF16 + _USE_GEMM_SPLITK_BF16 = value @triton.heuristics( @@ -259,11 +275,49 @@ def get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int): triton.cdiv((2 * triton.cdiv(K, NUM_KSPLIT)), BLOCK_SIZE_K) * BLOCK_SIZE_K ) - # print(K, SPLITK_BLOCK_SIZE // 2, BLOCK_SIZE_K // 2, NUM_KSPLIT) - # print(K % (SPLITK_BLOCK_SIZE // 2) == 0, SPLITK_BLOCK_SIZE % BLOCK_SIZE_K == 0, K % (BLOCK_SIZE_K // 2) == 0) return SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT +@functools.lru_cache(maxsize=1024) +def _get_config( + M: int, + N: int, + K: int, +): + if not hasattr(_get_config, "_config_dict"): + dev = arch_info.get_device() + fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM-AFP4WFP4.json" + with open(fpath, "r") as file: + config = json.load(file) + _get_config._config_dict = config + + if M < 32: + config = _get_config._config_dict["M_LT_32"] + elif M < 64: + config = _get_config._config_dict["M_LT_64"] + elif M < 128: + config = _get_config._config_dict["M_LT_128"] + elif M == 128: + config = _get_config._config_dict["M_EQ_128"] + elif M <= 256: + config = _get_config._config_dict["M_LTE_256"] + else: + config = _get_config._config_dict["default"] + + if M <= 128: + SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( + K, config["BLOCK_SIZE_K"], config["NUM_KSPLIT"] + ) + config["SPLITK_BLOCK_SIZE"] = SPLITK_BLOCK_SIZE + config["BLOCK_SIZE_K"] = BLOCK_SIZE_K + config["NUM_KSPLIT"] = NUM_KSPLIT + + else: + config["SPLITK_BLOCK_SIZE"] = 2 * K + + return config + + def batched_gemm_afp4wfp4( x: torch.Tensor, w: torch.Tensor, @@ -271,6 +325,7 @@ def batched_gemm_afp4wfp4( x_scales: torch.Tensor, w_scales: torch.Tensor, dtype: Optional[float] = torch.bfloat16, + config: Optional[dict] = None, ): """ Computes the matmul Y = X x W @@ -289,97 +344,35 @@ def batched_gemm_afp4wfp4( - Y: The output matrix with shape (M, N). 
""" + assert arch_info.is_fp4_avail(), "MXFP4 is not available on your device" + Bx, M, K = x.shape Bw, K, N = w.shape By, _, _ = y.shape assert Bx == Bw == By Batch = Bx - if M < 32: - BLOCK_SIZE_M = 16 - BLOCK_SIZE_N = 64 - BLOCK_SIZE_K = 256 - GROUP_SIZE_M = 1 - waves_per_eu = 6 - kpack = 1 - num_warps = 4 - num_stages = 2 - matrix_instr_nonkdim = 16 - cache_modifier = ".cg" - - NUM_KSPLIT = 4 - SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( - K, BLOCK_SIZE_K, NUM_KSPLIT - ) - - if os.getenv("VLLM_TRITON_FP4_GEMM_SPLITK_USE_BF16") == "1": - y_pp = torch.empty( - (Batch, NUM_KSPLIT, M, N), dtype=y.dtype, device=y.device - ) - else: - y_pp = torch.empty( - (Batch, NUM_KSPLIT, M, N), dtype=torch.float32, device=y.device - ) - elif M <= 128: - BLOCK_SIZE_M = triton.next_power_of_2(M) - BLOCK_SIZE_N = 128 - BLOCK_SIZE_K = 256 - GROUP_SIZE_M = 1 - waves_per_eu = 4 - kpack = 1 - num_warps = 4 - num_stages = 2 - matrix_instr_nonkdim = 16 - cache_modifier = ".cg" - - NUM_KSPLIT = 4 - SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( - K, BLOCK_SIZE_K, NUM_KSPLIT - ) + if config is None: + config = _get_config(M, N, K) - if os.getenv("VLLM_TRITON_FP4_GEMM_SPLITK_USE_BF16") == "1": + if M <= 128: + if _USE_GEMM_SPLITK_BF16: y_pp = torch.empty( - (Batch, NUM_KSPLIT, M, N), dtype=y.dtype, device=y.device + (Batch, config["NUM_KSPLIT"], M, N), dtype=y.dtype, device=y.device ) else: y_pp = torch.empty( - (Batch, NUM_KSPLIT, M, N), dtype=torch.float32, device=y.device + (Batch, config["NUM_KSPLIT"], M, N), + dtype=torch.float32, + device=y.device, ) - elif M <= 256: - BLOCK_SIZE_M = 128 - BLOCK_SIZE_N = 128 - BLOCK_SIZE_K = 256 - GROUP_SIZE_M = 2 - waves_per_eu = 4 - kpack = 1 - num_warps = 4 - num_stages = 2 - matrix_instr_nonkdim = 16 - cache_modifier = ".cg" - - NUM_KSPLIT = 1 - SPLITK_BLOCK_SIZE = 2 * K - y_pp = None else: - BLOCK_SIZE_M = 256 - BLOCK_SIZE_N = 256 - BLOCK_SIZE_K = 256 - GROUP_SIZE_M = 64 - waves_per_eu = 1 - kpack = 1 - num_warps = 8 - num_stages = 2 - matrix_instr_nonkdim = 16 - cache_modifier = None - - NUM_KSPLIT = 1 - SPLITK_BLOCK_SIZE = 2 * K y_pp = None grid = lambda META: ( # noqa: E731 Batch, ( - NUM_KSPLIT + config["NUM_KSPLIT"] * triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]) ), @@ -387,7 +380,7 @@ def batched_gemm_afp4wfp4( _batched_gemm_afp4_wfp4_kernel[grid]( x, w, - y if NUM_KSPLIT == 1 else y_pp, + y if config["NUM_KSPLIT"] == 1 else y_pp, x_scales, w_scales, M, @@ -399,39 +392,26 @@ def batched_gemm_afp4wfp4( w.stride(0), w.stride(1), w.stride(2), - y.stride(0) if NUM_KSPLIT == 1 else y_pp.stride(0), - 0 if NUM_KSPLIT == 1 else y_pp.stride(1), - y.stride(1) if NUM_KSPLIT == 1 else y_pp.stride(2), - y.stride(2) if NUM_KSPLIT == 1 else y_pp.stride(3), + y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), + y.stride(2) if config["NUM_KSPLIT"] == 1 else y_pp.stride(3), x_scales.stride(0), x_scales.stride(1), x_scales.stride(2), w_scales.stride(0), w_scales.stride(1), w_scales.stride(2), - BLOCK_SIZE_M, - BLOCK_SIZE_N, - BLOCK_SIZE_K, - GROUP_SIZE_M, - NUM_KSPLIT, - SPLITK_BLOCK_SIZE, - cache_modifier=cache_modifier, - waves_per_eu=waves_per_eu, - kpack=kpack, - num_warps=num_warps, - num_stages=num_stages, - matrix_instr_nonkdim=matrix_instr_nonkdim, + **config, ) - if NUM_KSPLIT > 1: + if config["NUM_KSPLIT"] > 1: REDUCE_BLOCK_SIZE_M = 16 # TODO: Need to debug - REDUCE_BLOCK_SIZE_N=128 with fp32 
partials fails # NOTE: REDUCE_BLOCK_SIZE_N=16 gives best perf with fp32 partials and # REDUCE_BLOCK_SIZE_N=128 gives best perf with bf16 partials - REDUCE_BLOCK_SIZE_N = ( - 128 if os.getenv("VLLM_TRITON_FP4_GEMM_SPLITK_USE_BF16") == "1" else 64 - ) - ACTUAL_KSPLIT = triton.cdiv(K, (SPLITK_BLOCK_SIZE // 2)) + REDUCE_BLOCK_SIZE_N = 128 if _USE_GEMM_SPLITK_BF16 else 64 + ACTUAL_KSPLIT = triton.cdiv(K, (config["SPLITK_BLOCK_SIZE"] // 2)) grid_reduce = ( Batch, @@ -453,5 +433,5 @@ def batched_gemm_afp4wfp4( REDUCE_BLOCK_SIZE_M, REDUCE_BLOCK_SIZE_N, ACTUAL_KSPLIT, - NUM_KSPLIT, + config["NUM_KSPLIT"], ) diff --git a/aiter/ops/triton/batched_gemm_bf16.py b/aiter/ops/triton/batched_gemm_bf16.py index efa18465e5..417a2f4d70 100644 --- a/aiter/ops/triton/batched_gemm_bf16.py +++ b/aiter/ops/triton/batched_gemm_bf16.py @@ -1,13 +1,14 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +from typing import Optional +import functools +import json import torch import triton import triton.language as tl -from typing import Optional +import aiter.ops.triton.utils.arch_info as arch_info +from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH @triton.heuristics( @@ -160,6 +161,26 @@ def _batched_gemm_bf16_kernel( tl.store(c_ptrs, c, mask=c_mask) +@functools.lru_cache(maxsize=1024) +def _get_config( + M: int, + N: int, + K: int, +): + if not hasattr(_get_config, "_config_dict"): + dev = arch_info.get_device() + fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM-A16W16.json" + print(f"fpath={fpath}") + with open(fpath, "r") as file: + config = json.load(file) + _get_config._config_dict = config + + if M + N >= 4096: + return _get_config._config_dict["large"] + else: + return _get_config._config_dict["small"] + + def batched_gemm_bf16( XQ: torch.Tensor, WQ: torch.Tensor, @@ -167,6 +188,7 @@ def batched_gemm_bf16( dtype: Optional[torch.dtype] = torch.bfloat16, splitK: Optional[int] = None, YQ: Optional[torch.Tensor] = None, + config: Optional[dict] = None, ): """ Computes the matmul YQ[i] = XQ[i] x WQ[i]T for every i in a given batch and optionally adds a bias to each result. 
@@ -206,25 +228,12 @@ def batched_gemm_bf16( if YQ is None: YQ = torch.empty((B, M, N), dtype=dtype, device=XQ.device) - if (M + N) >= 4096: - BLOCK_SIZE_M = 256 - BLOCK_SIZE_N = 256 - BLOCK_SIZE_K = 64 - GROUP_SIZE_M = 4 - else: - BLOCK_SIZE_M = 128 - BLOCK_SIZE_N = 128 - BLOCK_SIZE_K = 32 - GROUP_SIZE_M = 1 + if config is None: + config = _get_config(M, N, K) - waves_per_eu = 2 - matrix_instr_nonkdim = 16 - num_warps = 8 - num_stages = 2 - - grid = ( + grid = lambda META: ( # noqa: E731 B, - triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) _batched_gemm_bf16_kernel[grid]( @@ -246,13 +255,7 @@ def batched_gemm_bf16( YQ.stride(2), bias.stride(0) if has_bias else 0, has_bias, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - BLOCK_SIZE_K, - GROUP_SIZE_M, - waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=matrix_instr_nonkdim, - num_warps=num_warps, - num_stages=num_stages, + **config, ) + return YQ diff --git a/aiter/ops/triton/configs/gemm/MI300X-BATCHED_GEMM-A16W16.json b/aiter/ops/triton/configs/gemm/MI300X-BATCHED_GEMM-A16W16.json new file mode 100644 index 0000000000..947fe026c7 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI300X-BATCHED_GEMM-A16W16.json @@ -0,0 +1,24 @@ +{ + "large": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + }, + "small" : { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + } +} + + \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI300X-BATCHED_GEMM-A8W8.json b/aiter/ops/triton/configs/gemm/MI300X-BATCHED_GEMM-A8W8.json new file mode 100644 index 0000000000..947fe026c7 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI300X-BATCHED_GEMM-A8W8.json @@ -0,0 +1,24 @@ +{ + "large": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + }, + "small" : { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + } +} + + \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI300X-GEMM_BLOCKSCALE-A8W8.json b/aiter/ops/triton/configs/gemm/MI300X-GEMM_BLOCKSCALE-A8W8.json new file mode 100644 index 0000000000..2c731811f4 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI300X-GEMM_BLOCKSCALE-A8W8.json @@ -0,0 +1,13 @@ +{ + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-A16W16.json b/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-A16W16.json new file mode 100644 index 0000000000..947fe026c7 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-A16W16.json @@ -0,0 +1,24 @@ +{ + "large": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + }, + "small" : { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + } +} + + \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-A8W8.json b/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-A8W8.json new file mode 100644 index 0000000000..947fe026c7 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-A8W8.json @@ -0,0 +1,24 @@ +{ + "large": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + }, + "small" : { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + } +} + + \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-AFP4WFP4.json b/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-AFP4WFP4.json new file mode 100644 index 0000000000..cdc837fb22 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-AFP4WFP4.json @@ -0,0 +1,82 @@ +{ + "M_LT_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LT_64" : { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LT_128" : { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_EQ_128" : { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LTE_256" : { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "default": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} + + \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI300X-GEMM-AFP4WFP4.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8WFP4.json similarity index 84% rename from aiter/ops/triton/configs/gemm/MI300X-GEMM-AFP4WFP4.json rename to aiter/ops/triton/configs/gemm/MI350X-GEMM-A8WFP4.json index eaf0465f22..03f8117ec3 100644 --- a/aiter/ops/triton/configs/gemm/MI300X-GEMM-AFP4WFP4.json +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8WFP4.json @@ -1,7 +1,7 @@ { - "small": { + "M_LT_32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, @@ -12,7 +12,7 @@ "cache_modifier": ".cg", "NUM_KSPLIT": 4 }, - "medium_M32": { + "M_EQ_32" : { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, @@ -25,7 +25,7 @@ "cache_modifier": ".cg", "NUM_KSPLIT": 4 }, - "medium_M64": { + "M_33_64" : { 
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, @@ -38,20 +38,20 @@ "cache_modifier": ".cg", "NUM_KSPLIT": 4 }, - "medium_M128": { + "M_65_128" : { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3, - "waves_per_eu": 1, + "num_stages": 2, + "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 1, "cache_modifier": ".cg", - "NUM_KSPLIT": 1 + "NUM_KSPLIT": 4 }, - "large": { + "M_129_256" : { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, @@ -63,8 +63,8 @@ "kpack": 1, "cache_modifier": ".cg", "NUM_KSPLIT": 1 - }, - "xlarge": { + }, + "default": { "BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 256, @@ -72,10 +72,11 @@ "num_warps": 8, "num_stages": 2, "waves_per_eu": 1, - "matrix_instr_nonkdim": 16, + "matrix_instr_nonkdim": 32, "kpack": 1, - "cache_modifier": null, + "cache_modifier": null, "NUM_KSPLIT": 1 } - } + + \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM_BLOCKSCALE-A8W8.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM_BLOCKSCALE-A8W8.json new file mode 100644 index 0000000000..5c24b1495a --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM_BLOCKSCALE-A8W8.json @@ -0,0 +1,13 @@ +{ + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/gemm_a16w16.py b/aiter/ops/triton/gemm_a16w16.py index 309d879a14..90fb20859b 100644 --- a/aiter/ops/triton/gemm_a16w16.py +++ b/aiter/ops/triton/gemm_a16w16.py @@ -118,7 +118,6 @@ def _get_config( return _get_config._config_dict["any"] -# Wrapper for gemm kernel. def gemm_a16w16( x, w, diff --git a/aiter/ops/triton/gemm_a8w8.py b/aiter/ops/triton/gemm_a8w8.py index a2f7f29833..d84c964af9 100644 --- a/aiter/ops/triton/gemm_a8w8.py +++ b/aiter/ops/triton/gemm_a8w8.py @@ -1,6 +1,3 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. @@ -190,7 +187,6 @@ def _get_config( config = json.load(file) _get_config._config_dict = config - # TODO: Update this logic return _get_config._config_dict["any"] diff --git a/aiter/ops/triton/gemm_a8w8_blockscale.py b/aiter/ops/triton/gemm_a8w8_blockscale.py index 04a97932d3..a6611b4f48 100644 --- a/aiter/ops/triton/gemm_a8w8_blockscale.py +++ b/aiter/ops/triton/gemm_a8w8_blockscale.py @@ -1,18 +1,15 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +from typing import Optional +import functools +import json import torch import triton import triton.language as tl from typing import Optional - - -# TODO Move this to a common folder. 
Will need to add future arch list -def get_arch(): - return triton.runtime.driver.active.get_current_target().arch +import aiter.ops.triton.utils.arch_info as arch_info +from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH @triton.heuristics( @@ -189,6 +186,23 @@ def _gemm_a8w8_blockscale_kernel( tl.store(c_ptrs, c, mask=c_mask) +@functools.lru_cache(maxsize=1024) +def _get_config( + M: int, + N: int, + K: int, +): + if not hasattr(_get_config, "_config_dict"): + dev = arch_info.get_device() + fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM_BLOCKSCALE-A8W8.json" + print(f"fpath={fpath}") + with open(fpath, "r") as file: + config = json.load(file) + _get_config._config_dict = config + + return _get_config._config_dict["any"] + + def gemm_a8w8_blockscale( x: torch.Tensor, w: torch.Tensor, @@ -196,6 +210,7 @@ def gemm_a8w8_blockscale( w_scale: torch.Tensor, dtype: Optional[float] = torch.bfloat16, y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, ): """ Computes the 8 bit matmul Y = X x WT using the block-scale quantization approach. @@ -221,28 +236,26 @@ def gemm_a8w8_blockscale( M, K = x.shape K, N = w.shape - # Scale block sizes - # TODO: need a better way to pass scale block sizes around - GROUP_K = triton.next_power_of_2(triton.cdiv(K, w_scale.shape[0])) - GROUP_N = triton.next_power_of_2(triton.cdiv(N, w_scale.shape[1])) - # Check constraints. assert x.shape[1] == w.shape[0], "Incompatible dimensions!!!" if y is None: y = torch.empty((M, N), dtype=dtype, device=x.device) - BLOCK_SIZE_M = 128 - BLOCK_SIZE_N = 128 - BLOCK_SIZE_K = 128 - GROUP_SIZE_M = 4 - waves_per_eu = 2 - kpack = 1 if get_arch() in ("gfx950") else 2 - matrix_instr_nonkdim = 16 - num_warps = 4 - num_stages = 2 - - grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),) + if config is None: + config = _get_config(M, N, K) + + # Scale block sizes + # TODO: need a better way to pass scale block sizes around + config["GROUP_K"] = triton.next_power_of_2(triton.cdiv(K, w_scale.shape[0])) + config["GROUP_N"] = triton.next_power_of_2(triton.cdiv(N, w_scale.shape[1])) + + grid = lambda META: ( # noqa: E731 + ( + triton.cdiv(M, config["BLOCK_SIZE_M"]) + * triton.cdiv(N, config["BLOCK_SIZE_N"]), + ) + ) _gemm_a8w8_blockscale_kernel[grid]( x, w, @@ -262,17 +275,7 @@ def gemm_a8w8_blockscale( x_scale.stride(1), w_scale.stride(0), w_scale.stride(1), - GROUP_K, - GROUP_N, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - BLOCK_SIZE_K, - GROUP_SIZE_M, - waves_per_eu=waves_per_eu, - kpack=kpack, - matrix_instr_nonkdim=matrix_instr_nonkdim, - num_warps=num_warps, - num_stages=num_stages, + **config, ) return y diff --git a/aiter/ops/triton/gemm_a8wfp4.py b/aiter/ops/triton/gemm_a8wfp4.py index 8e0cdd2c5d..d43190b0ff 100644 --- a/aiter/ops/triton/gemm_a8wfp4.py +++ b/aiter/ops/triton/gemm_a8wfp4.py @@ -1,14 +1,24 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
from typing import Optional +import functools +import json import os import torch import triton import triton.language as tl from aiter.ops.triton.utils.pid_preprocessing import pid_grid, remap_xcd +import aiter.ops.triton.utils.arch_info as arch_info +from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH -DEBUG = False +global _USE_GEMM_SPLITK_BF16 +_USE_GEMM_SPLITK_BF16 = False + + +def set_use_gemm_splitk_bf16(value: bool): + global _USE_GEMM_SPLITK_BF16 + _USE_GEMM_SPLITK_BF16 = value @triton.heuristics( @@ -239,9 +249,6 @@ def get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int): triton.cdiv((2 * triton.cdiv(K, NUM_KSPLIT)), BLOCK_SIZE_K) * BLOCK_SIZE_K ) while NUM_KSPLIT > 1 and BLOCK_SIZE_K > 16: - # print(K, SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT) - # print(K % (SPLITK_BLOCK_SIZE // 2) == 0, SPLITK_BLOCK_SIZE % BLOCK_SIZE_K == 0, K % (BLOCK_SIZE_K // 2) == 0) - if ( K % (SPLITK_BLOCK_SIZE // 2) == 0 and SPLITK_BLOCK_SIZE % BLOCK_SIZE_K == 0 @@ -264,12 +271,50 @@ def get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int): triton.cdiv((2 * triton.cdiv(K, NUM_KSPLIT)), BLOCK_SIZE_K) * BLOCK_SIZE_K ) - # print(K, SPLITK_BLOCK_SIZE // 2, BLOCK_SIZE_K // 2, NUM_KSPLIT) - # print(K % (SPLITK_BLOCK_SIZE // 2) == 0, SPLITK_BLOCK_SIZE % BLOCK_SIZE_K == 0, K % (BLOCK_SIZE_K // 2) == 0) return SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT -# Wrapper for gemm kernel. +@functools.lru_cache(maxsize=1024) +def _get_config( + M: int, + N: int, + K: int, +): + if not hasattr(_get_config, "_config_dict"): + dev = arch_info.get_device() + fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-A8WFP4.json" + + with open(fpath, "r") as file: + config = json.load(file) + _get_config._config_dict = config + + if M < 32: + config = _get_config._config_dict["M_LT_32"] + elif M == 32: + config = _get_config._config_dict["M_EQ_32"] + elif M <= 64: + config = _get_config._config_dict["M_33_64"] + elif M <= 128: + config = _get_config._config_dict["M_65_128"] + elif M <= 256: + config = _get_config._config_dict["M_129_256"] + else: + config = _get_config._config_dict["default"] + + if M <= 128: + SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( + K, config["BLOCK_SIZE_K"], config["NUM_KSPLIT"] + ) + config["SPLITK_BLOCK_SIZE"] = SPLITK_BLOCK_SIZE + config["BLOCK_SIZE_K"] = BLOCK_SIZE_K + config["NUM_KSPLIT"] = NUM_KSPLIT + + else: + config["SPLITK_BLOCK_SIZE"] = 2 * K + + return config + + def gemm_a8wfp4( x, w, @@ -277,6 +322,7 @@ def gemm_a8wfp4( x_scales, w_scales, dtype: Optional[float] = torch.bfloat16, + config: Optional[dict] = None, ): """ Computes the matmul Y = X @ W.T (where W.T is the logical transpose of unpacked W) @@ -303,9 +349,11 @@ def gemm_a8wfp4( - Every 32 consecutive elements in the K dimension of W share one e8m0 scale - X uses per-row scaling (not per-group scaling) """ - M, K = x.shape K_packed, N = w.shape + + assert arch_info.is_fp4_avail(), "MXFP4 is not available on your device" + assert ( K_packed == K // 2 ), f"Inconsistent shapes: x has K={K} but w has K_packed={K_packed}, expected {K//2}" @@ -314,127 +362,39 @@ def gemm_a8wfp4( K // 32, ), f"Scale shapes incorrect: x_scales should be ({M}, 1), got {x_scales.shape}; w_scales should be ({N}, {K//32}), got {w_scales.shape}" - if M < 32: - BLOCK_SIZE_M = 16 - BLOCK_SIZE_N = 64 - BLOCK_SIZE_K = 256 - GROUP_SIZE_M = 1 - waves_per_eu = 6 - kpack = 1 - num_warps = 4 - num_stages = 2 - matrix_instr_nonkdim = 16 - cache_modifier = ".cg" - - NUM_KSPLIT = 4 - SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( 
- K, BLOCK_SIZE_K, NUM_KSPLIT - ) - - if os.getenv("VLLM_TRITON_FP4_GEMM_SPLITK_USE_BF16") == "1": - y_pp = torch.empty((NUM_KSPLIT, M, N), dtype=y.dtype, device=y.device) - else: - y_pp = torch.empty((NUM_KSPLIT, M, N), dtype=torch.float32, device=y.device) - elif M <= 128: - BLOCK_SIZE_M = triton.next_power_of_2(M) - BLOCK_SIZE_N = 128 - BLOCK_SIZE_K = 256 - GROUP_SIZE_M = 1 - waves_per_eu = 4 - kpack = 1 - num_warps = 4 - num_stages = 2 - matrix_instr_nonkdim = 16 - cache_modifier = ".cg" - - NUM_KSPLIT = 4 - SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( - K, BLOCK_SIZE_K, NUM_KSPLIT - ) + if config is None: + config = _get_config(M, N, K) - if os.getenv("VLLM_TRITON_FP4_GEMM_SPLITK_USE_BF16") == "1": - y_pp = torch.empty((NUM_KSPLIT, M, N), dtype=y.dtype, device=y.device) + if M <= 128: + if _USE_GEMM_SPLITK_BF16: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), dtype=y.dtype, device=y.device + ) else: - y_pp = torch.empty((NUM_KSPLIT, M, N), dtype=torch.float32, device=y.device) - elif M <= 256: - BLOCK_SIZE_M = 128 - BLOCK_SIZE_N = 128 - BLOCK_SIZE_K = 256 - GROUP_SIZE_M = 2 - waves_per_eu = 4 - kpack = 1 - num_warps = 4 - num_stages = 2 - matrix_instr_nonkdim = 16 - cache_modifier = ".cg" - - NUM_KSPLIT = 1 - SPLITK_BLOCK_SIZE = 2 * K - y_pp = None + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), dtype=torch.float32, device=y.device + ) else: - BLOCK_SIZE_M = 256 - BLOCK_SIZE_N = 256 - BLOCK_SIZE_K = 256 - GROUP_SIZE_M = 32 - waves_per_eu = 1 - kpack = 1 - num_warps = 8 - num_stages = 2 - matrix_instr_nonkdim = 32 - cache_modifier = None - - NUM_KSPLIT = 1 SPLITK_BLOCK_SIZE = 2 * K y_pp = None grid = lambda META: ( # noqa: E731 ( - NUM_KSPLIT + config["NUM_KSPLIT"] * triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]) ), ) - y_final = y if NUM_KSPLIT == 1 else y_pp + y_final = y if config["NUM_KSPLIT"] == 1 else y_pp stride_am, stride_ak = x.stride() stride_bk, stride_bn = w.stride() stride_ck, stride_cm, stride_cn = ( - (0, y.stride(0), y.stride(1)) if NUM_KSPLIT == 1 else y_pp.stride() + (0, y.stride(0), y.stride(1)) if config["NUM_KSPLIT"] == 1 else y_pp.stride() ) stride_asm, stride_ask = x_scales.stride() stride_bsn, stride_bsk = w_scales.stride() - if DEBUG: - print( - "grid:", grid({"BLOCK_SIZE_M": BLOCK_SIZE_M, "BLOCK_SIZE_N": BLOCK_SIZE_N}) - ) - print("x:", x) - print("w:", w) - print("y_final:", y_final) - print("x_scales", x_scales) - print("w_scales:", w_scales) - print("M:", M) - print("N:", N) - print("K:", K) - print("stride_am:", stride_am) - print("stride_ak:", stride_ak) - print("stride_bk:", stride_bk) - print("stride_bn:", stride_bn) - print("stride_ck:", stride_ck) - print("stride_cm:", stride_cm) - print("stride_cn:", stride_cn) - print("stride_asm:", stride_asm) - print("stride_ask:", stride_ask) - print("stride_bsn:", stride_bsn) - print("stride_bsk:", stride_bsk) - print("BLOCK_SIZE_M:", BLOCK_SIZE_M) - print("BLOCK_SIZE_N:", BLOCK_SIZE_N) - print("BLOCK_SIZE_K:", BLOCK_SIZE_K) - print("GROUP_SIZE_M:", GROUP_SIZE_M) - print("NUM_KSPLIT:", NUM_KSPLIT) - print("SPLITK_BLOCK_SIZE:", SPLITK_BLOCK_SIZE) - print("cache_modifier:", cache_modifier) - _gemm_a8wfp4_kernel[grid]( x, w, @@ -455,30 +415,17 @@ def gemm_a8wfp4( stride_ask, stride_bsn, stride_bsk, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - BLOCK_SIZE_K, - GROUP_SIZE_M, - NUM_KSPLIT, - SPLITK_BLOCK_SIZE, RAW_MASKED_LOADS=True, - cache_modifier=cache_modifier, - waves_per_eu=waves_per_eu, - kpack=kpack, - num_warps=num_warps, - num_stages=num_stages, - 
matrix_instr_nonkdim=matrix_instr_nonkdim, + **config, ) - if NUM_KSPLIT > 1: + if config["NUM_KSPLIT"] > 1: REDUCE_BLOCK_SIZE_M = 16 # TODO: Need to debug - REDUCE_BLOCK_SIZE_N=128 with fp32 partials fails # NOTE: REDUCE_BLOCK_SIZE_N=16 gives best perf with fp32 partials and # REDUCE_BLOCK_SIZE_N=128 gives best perf with bf16 partials - REDUCE_BLOCK_SIZE_N = ( - 128 if os.getenv("VLLM_TRITON_FP4_GEMM_SPLITK_USE_BF16") == "1" else 64 - ) - ACTUAL_KSPLIT = triton.cdiv(K, (SPLITK_BLOCK_SIZE)) + REDUCE_BLOCK_SIZE_N = 128 if _USE_GEMM_SPLITK_BF16 else 64 + ACTUAL_KSPLIT = triton.cdiv(K, (config["SPLITK_BLOCK_SIZE"])) grid_reduce = ( triton.cdiv(M, REDUCE_BLOCK_SIZE_M), @@ -497,5 +444,5 @@ def gemm_a8wfp4( REDUCE_BLOCK_SIZE_M, REDUCE_BLOCK_SIZE_N, ACTUAL_KSPLIT, - NUM_KSPLIT, + config["NUM_KSPLIT"], ) diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index a8855f5588..e47c06e1ba 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -126,8 +126,7 @@ def _gemm_afp4_wfp4_kernel( for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): a_scales = tl.load(a_scale_ptrs) b_scales = tl.load(b_scale_ptrs) - # a_scales = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K//SCALE_GROUP_SIZE), 127, dtype=tl.uint8) - # b_scales = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K//SCALE_GROUP_SIZE), 127, dtype=tl.uint8) + # Load the next block of A and B, generate a mask by checking the K dimension. # If it is out of bounds, set it to 0. if EVEN_K: @@ -307,8 +306,6 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales( b_scales, (BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE) ) - # a_scales = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K//SCALE_GROUP_SIZE), 127, dtype=tl.uint8) - # b_scales = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K//SCALE_GROUP_SIZE), 127, dtype=tl.uint8) # Load the next block of A and B, generate a mask by checking the K dimension. # If it is out of bounds, set it to 0. if EVEN_K: @@ -403,9 +400,6 @@ def get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int): triton.cdiv((2 * triton.cdiv(K, NUM_KSPLIT)), BLOCK_SIZE_K) * BLOCK_SIZE_K ) while NUM_KSPLIT > 1 and BLOCK_SIZE_K > 16: - # print(K, SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT) - # print(K % (SPLITK_BLOCK_SIZE // 2) == 0, SPLITK_BLOCK_SIZE % BLOCK_SIZE_K == 0, K % (BLOCK_SIZE_K // 2) == 0) - if ( K % (SPLITK_BLOCK_SIZE // 2) == 0 and SPLITK_BLOCK_SIZE % BLOCK_SIZE_K == 0 @@ -462,7 +456,6 @@ def _get_config( return _get_config._config_dict["xlarge"] -# Wrapper for gemm kernel. def gemm_afp4wfp4( x, w, @@ -497,7 +490,7 @@ def gemm_afp4wfp4( if config is None: config = _get_config(M, N, K) - # print(f"AFP4WFP4_config={config}") + if config["NUM_KSPLIT"] > 1: SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( K, config["BLOCK_SIZE_K"], config["NUM_KSPLIT"] @@ -582,7 +575,6 @@ def gemm_afp4wfp4( return y -# Wrapper for gemm kernel. def gemm_afp4wfp4_preshuffled_scales( x, w, @@ -609,6 +601,8 @@ def gemm_afp4wfp4_preshuffled_scales( - Y: The output matrix with shape (M, N). 
""" + assert arch_info.is_fp4_avail(), "MXFP4 is not available on your device" + M, K = x.shape K, N = w.shape @@ -617,7 +611,7 @@ def gemm_afp4wfp4_preshuffled_scales( if config is None: config = _get_config(M, N, K) - # print(f"AFP4WFP4_config={config}") + if config["NUM_KSPLIT"] > 1: SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( K, config["BLOCK_SIZE_K"], config["NUM_KSPLIT"] diff --git a/aiter/ops/triton/utils/types.py b/aiter/ops/triton/utils/types.py index b3aa479874..ddaeb0a059 100644 --- a/aiter/ops/triton/utils/types.py +++ b/aiter/ops/triton/utils/types.py @@ -15,6 +15,8 @@ "float8_e4m3fn": e4m3_dtype, "e5m2fnuz": e5m2_dtype, "e4m3fnuz": e4m3_dtype, + "fp8e4m3": e4m3_dtype, + "fp8e5m2": e5m2_dtype, "int64": torch.int64, "int32": torch.int32, "int16": torch.int16, diff --git a/op_tests/op_benchmarks/triton/bench_gemm_a16w16.py b/op_tests/op_benchmarks/triton/bench_gemm_a16w16.py index 8c6b1657d2..7e142e286e 100644 --- a/op_tests/op_benchmarks/triton/bench_gemm_a16w16.py +++ b/op_tests/op_benchmarks/triton/bench_gemm_a16w16.py @@ -87,7 +87,9 @@ def run_benchmark(args): def bench_gemm_a16w16(M, N, K, layout, metric, provider): # NOTE: Assume bias and output has the same dtype c_dtype = torch.bfloat16 - x, w, y = generate_gemm_a16w16_inputs(M, N, K, c_dtype, layout, output=True) + x, w, out_dtype, y = generate_gemm_a16w16_inputs( + M, N, K, c_dtype, layout, output=True + ) # flops flops = 2.0 * M * N * K # memory transfer diff --git a/op_tests/op_benchmarks/triton/bench_gemm_a8w8.py b/op_tests/op_benchmarks/triton/bench_gemm_a8w8.py index f8e529bae7..194c9d99be 100644 --- a/op_tests/op_benchmarks/triton/bench_gemm_a8w8.py +++ b/op_tests/op_benchmarks/triton/bench_gemm_a8w8.py @@ -3,9 +3,9 @@ import torch import triton from aiter.ops.triton.gemm_a8w8 import gemm_a8w8 +from aiter.ops.triton.utils.types import str_to_torch_dtype from op_tests.triton_tests.test_gemm_a8w8 import ( generate_gemm_a8w8_inputs, - name_to_torch_types, ) from utils.benchmark_utils import get_model_configs, get_available_models @@ -83,9 +83,9 @@ def run_benchmark(args): @triton.testing.perf_report([benchmark]) def bench_gemm_a8w8(M, N, K, metric, provider): # NOTE: Assume bias and output has the same dtype - c_dtype = name_to_torch_types["bf16"] + c_dtype = str_to_torch_dtype["bf16"] x, weight, x_scale, w_scale, bias, y = generate_gemm_a8w8_inputs( - M, N, K, name_to_torch_types["fp8e4"], c_dtype, output=True + M, N, K, str_to_torch_dtype["fp8e4m3"], c_dtype, output=True ) # flops flops = 2.0 * M * N * K diff --git a/op_tests/triton_tests/test_batched_gemm_a8w8.py b/op_tests/triton_tests/test_batched_gemm_a8w8.py index 1a637e9e0f..6764e7ef02 100644 --- a/op_tests/triton_tests/test_batched_gemm_a8w8.py +++ b/op_tests/triton_tests/test_batched_gemm_a8w8.py @@ -5,6 +5,8 @@ import triton import pytest from aiter.ops.triton.batched_gemm_a8w8 import batched_gemm_a8w8 +from aiter.ops.triton.utils.arch_info import get_fp8_dtypes +from aiter.ops.triton.utils.types import str_to_torch_dtype import torch.nn.functional as F @@ -27,22 +29,7 @@ def run_triton(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16, y=N return batched_gemm_a8w8(x, weight, x_scale, w_scale, bias, dtype, YQ=y) -def is_cdna4(): - return triton.runtime.driver.active.get_current_target().arch == "gfx950" - - -e5m2_type = torch.float8_e5m2 if is_cdna4() else torch.float8_e5m2fnuz -e4m3_type = torch.float8_e4m3fn if is_cdna4() else torch.float8_e4m3fnuz - -name_to_torch_types = { - "int8": torch.int8, - "int32": torch.int32, - 
"fp16": torch.float16, - "fp32": torch.float32, - "bf16": torch.bfloat16, - "fp8e5": e5m2_type, - "fp8e4": e4m3_type, -} +e5m2_type, e4m3_type = get_fp8_dtypes() def get_x_vals(): @@ -92,7 +79,7 @@ def get_x_vals(): ) def test_batched_gemm_a8w8(dtype, b, m, n, k, output): - dtype = name_to_torch_types[dtype] + dtype = str_to_torch_dtype[dtype] x = torch.randint(-20, 20, (b, m, k), dtype=torch.int8).cuda() weight = torch.randint(-20, 20, (b, n, k), dtype=torch.int8).cuda() diff --git a/op_tests/triton_tests/test_batched_gemm_afp4wfp4.py b/op_tests/triton_tests/test_batched_gemm_afp4wfp4.py index 7668198a88..cc5d7d6d73 100644 --- a/op_tests/triton_tests/test_batched_gemm_afp4wfp4.py +++ b/op_tests/triton_tests/test_batched_gemm_afp4wfp4.py @@ -2,6 +2,7 @@ import triton import pytest from aiter.ops.triton.batched_gemm_afp4wfp4 import batched_gemm_afp4wfp4 +import aiter.ops.triton.utils.arch_info as arch_info # Note this is specified by the HW and cannot be changed. SCALE_GROUP_SIZE = 32 @@ -128,7 +129,7 @@ def run_torch(x, w, x_scales, w_scales, dtype): @pytest.mark.parametrize("B, M, N, K", get_x_vals()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_batched_gemm_afp4_wfp4(B: int, M: int, N: int, K: int, dtype): - if triton.runtime.driver.active.get_current_target().arch not in ("gfx950"): + if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") x, w, x_scales, w_scales = generate_batched_gemm_afp4wfp4_inputs(B, M, N, K) diff --git a/op_tests/triton_tests/test_batched_gemm_bf16.py b/op_tests/triton_tests/test_batched_gemm_bf16.py index ee769932c6..86fd74f572 100644 --- a/op_tests/triton_tests/test_batched_gemm_bf16.py +++ b/op_tests/triton_tests/test_batched_gemm_bf16.py @@ -5,6 +5,8 @@ import triton import pytest from aiter.ops.triton.batched_gemm_bf16 import batched_gemm_bf16 +from aiter.ops.triton.utils.arch_info import get_fp8_dtypes +from aiter.ops.triton.utils.types import str_to_torch_dtype import torch.nn.functional as F @@ -27,22 +29,7 @@ def run_triton(x, weight, bias=None, dtype=torch.bfloat16, y=None): return batched_gemm_bf16(x, weight, bias, dtype, YQ=y) -def is_cdna4(): - return triton.runtime.driver.active.get_current_target().arch == "gfx950" - - -e5m2_type = torch.float8_e5m2 if is_cdna4() else torch.float8_e5m2fnuz -e4m3_type = torch.float8_e4m3fn if is_cdna4() else torch.float8_e4m3fnuz - -name_to_torch_types = { - "int8": torch.int8, - "int32": torch.int32, - "fp16": torch.float16, - "fp32": torch.float32, - "bf16": torch.bfloat16, - "fp8e5": e5m2_type, - "fp8e4": e4m3_type, -} +e5m2_type, e4m3_type = get_fp8_dtypes() def get_x_vals(): @@ -92,7 +79,7 @@ def get_x_vals(): ) def test_batched_gemm_bf16(dtype, b, m, n, k, output): - dtype = name_to_torch_types[dtype] + dtype = str_to_torch_dtype[dtype] x = torch.randint(-20, 20, (b, m, k), dtype=dtype).cuda() weight = torch.randint(-20, 20, (b, n, k), dtype=dtype).cuda() diff --git a/op_tests/triton_tests/test_gemm_a8w8.py b/op_tests/triton_tests/test_gemm_a8w8.py index 7d6e52affa..fa0fe18363 100644 --- a/op_tests/triton_tests/test_gemm_a8w8.py +++ b/op_tests/triton_tests/test_gemm_a8w8.py @@ -6,6 +6,8 @@ import pytest import torch.nn.functional as F from aiter.ops.triton.gemm_a8w8 import gemm_a8w8 +from aiter.ops.triton.utils.arch_info import get_fp8_dtypes +from aiter.ops.triton.utils.types import str_to_torch_dtype def run_torch(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16): @@ -21,22 +23,7 @@ def run_triton(x, weight, x_scale, w_scale, 
bias=None, dtype=torch.bfloat16, y=N return gemm_a8w8(x, weight, x_scale, w_scale, bias, dtype, y) -def is_cdna4(): - return triton.runtime.driver.active.get_current_target().arch == "gfx950" - - -e5m2_type = torch.float8_e5m2 if is_cdna4() else torch.float8_e5m2fnuz -e4m3_type = torch.float8_e4m3fn if is_cdna4() else torch.float8_e4m3fnuz - -name_to_torch_types = { - "int8": torch.int8, - "int32": torch.int32, - "fp16": torch.float16, - "fp32": torch.float32, - "bf16": torch.bfloat16, - "fp8e5": e5m2_type, - "fp8e4": e4m3_type, -} +e5m2_type, e4m3_type = get_fp8_dtypes() dtype_max = { @@ -111,15 +98,15 @@ def generate_gemm_a8w8_inputs(M, N, K, in_dtype, out_dtype, output=False): "in_dtype, out_dtype, m, n, k, output", [ (in_dtype, out_dtype, *shape, output) - for in_dtype in ["fp8e4", "fp8e5", "int8"] + for in_dtype in ["fp8e4m3", "fp8e5m2", "int8"] for out_dtype in ["bf16"] for shape in get_x_vals() for output in [True, False] ], ) def test_gemm(in_dtype, out_dtype, m, n, k, output): - in_dtype = name_to_torch_types[in_dtype] - out_dtype = name_to_torch_types[out_dtype] + in_dtype = str_to_torch_dtype[in_dtype] + out_dtype = str_to_torch_dtype[out_dtype] x, weight, x_scale, w_scale, bias, y = generate_gemm_a8w8_inputs( m, n, k, in_dtype, out_dtype, output ) diff --git a/op_tests/triton_tests/test_gemm_a8w8_blockscale.py b/op_tests/triton_tests/test_gemm_a8w8_blockscale.py index e74c30a8a9..04817232b1 100644 --- a/op_tests/triton_tests/test_gemm_a8w8_blockscale.py +++ b/op_tests/triton_tests/test_gemm_a8w8_blockscale.py @@ -5,6 +5,8 @@ import triton import pytest from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale +from aiter.ops.triton.utils.arch_info import get_fp8_dtypes +from aiter.ops.triton.utils.types import str_to_torch_dtype import torch.nn.functional as F @@ -34,22 +36,7 @@ def run_triton(x, weight, x_scale, w_scale, dtype=torch.bfloat16, y=None): return gemm_a8w8_blockscale(x, weight, x_scale, w_scale, dtype, y) -def is_cdna4(): - return triton.runtime.driver.active.get_current_target().arch == "gfx950" - - -e5m2_type = torch.float8_e5m2 if is_cdna4() else torch.float8_e5m2fnuz -e4m3_type = torch.float8_e4m3fn if is_cdna4() else torch.float8_e4m3fnuz - -name_to_torch_types = { - "int8": torch.int8, - "int32": torch.int32, - "fp16": torch.float16, - "fp32": torch.float32, - "bf16": torch.bfloat16, - "fp8e5": e5m2_type, - "fp8e4": e4m3_type, -} +e5m2_type, e4m3_type = get_fp8_dtypes() def get_x_vals(): @@ -121,7 +108,7 @@ def generate_gemm_a8w8_blockscale_inputs( def test_gemm(dtype, M, N, K, output): block_shape_n, block_shape_k = block_shape - dtype = name_to_torch_types[dtype] + dtype = str_to_torch_dtype[dtype] x, weight, x_scale, w_scale, y = generate_gemm_a8w8_blockscale_inputs( M, N, diff --git a/op_tests/triton_tests/test_gemm_a8wfp4.py b/op_tests/triton_tests/test_gemm_a8wfp4.py index d25733dd42..ea78f24e41 100644 --- a/op_tests/triton_tests/test_gemm_a8wfp4.py +++ b/op_tests/triton_tests/test_gemm_a8wfp4.py @@ -6,6 +6,7 @@ import pytest from enum import Enum from aiter.ops.triton.gemm_a8wfp4 import gemm_a8wfp4 +import aiter.ops.triton.utils.arch_info as arch_info # Debug DEBUG = False @@ -291,12 +292,7 @@ def run_torch_emulation(x, w, x_scales, w_scales, dtype): return torch.mm(x_f32, w_f32.T).to(dtype) -def is_cdna4(): - return triton.runtime.driver.active.get_current_target().arch == "gfx950" - - -e5m2_type = torch.float8_e5m2 if is_cdna4() else torch.float8_e5m2fnuz -e4m3_type = torch.float8_e4m3fn if is_cdna4() else torch.float8_e4m3fnuz 
+e5m2_type, e4m3_type = arch_info.get_fp8_dtypes() @pytest.mark.parametrize("M, N, K", get_x_vals()) @@ -317,7 +313,7 @@ def is_cdna4(): @pytest.mark.parametrize("out_dtype", [torch.float16]) def test_gemm_a8wfp4(M: int, N: int, K: int, a_dtype, out_dtype, CLEAR_GPUS=True): torch.manual_seed(42) # for reproducibility - if not is_cdna4(): + if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") # clean up to avoid hangs in large tests diff --git a/op_tests/triton_tests/test_gemm_afp4wfp4.py b/op_tests/triton_tests/test_gemm_afp4wfp4.py index cf912f96ce..8522a24aea 100644 --- a/op_tests/triton_tests/test_gemm_afp4wfp4.py +++ b/op_tests/triton_tests/test_gemm_afp4wfp4.py @@ -8,7 +8,8 @@ gemm_afp4wfp4, gemm_afp4wfp4_preshuffled_scales, ) -from op_tests.triton_tests.utils.types import str_to_torch_dtype +import aiter.ops.triton.utils.arch_info as arch_info +from aiter.ops.triton.utils.types import str_to_torch_dtype TRITON_HIP_PRESHUFFLE_SCALES = ( os.environ.get("TRITON_HIP_PRESHUFFLE_SCALES", "0") == "1" @@ -178,7 +179,7 @@ def run_torch(x, w, x_scales, w_scales, dtype): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("output", [True, False]) def test_gemm_afp4_wfp4(M: int, N: int, K: int, dtype, output): - if triton.runtime.driver.active.get_current_target().arch not in ("gfx950"): + if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") if TRITON_HIP_PRESHUFFLE_SCALES:
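
Usage sketch (illustrative, not part of the diff above): the new `config` parameter lets a caller bypass the JSON-based `_get_config` lookup and supply Triton launch parameters directly; the dict keys mirror the entries of the bundled config files (e.g. MI300X-BATCHED_GEMM-A8W8.json) and are forwarded to the kernel via `**config`. Problem sizes, scale shapes, and dtypes below are assumptions chosen only for the example.

    import torch
    from aiter.ops.triton.batched_gemm_a8w8 import batched_gemm_a8w8

    # Launch parameters copied from the "small" entry of the shipped config;
    # any of these keys could be tuned for a specific workload.
    custom_config = {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 2,
        "matrix_instr_nonkdim": 16,
    }

    B, M, N, K = 2, 64, 1280, 8192  # assumed problem sizes
    XQ = torch.randint(-20, 20, (B, M, K), dtype=torch.int8, device="cuda")
    WQ = torch.randint(-20, 20, (B, N, K), dtype=torch.int8, device="cuda")
    x_scale = torch.rand((B, M, 1), dtype=torch.float32, device="cuda")  # assumed per-row scales
    w_scale = torch.rand((B, 1, N), dtype=torch.float32, device="cuda")  # assumed per-column scales

    YQ = batched_gemm_a8w8(
        XQ, WQ, x_scale, w_scale, bias=None, dtype=torch.bfloat16, config=custom_config
    )

For the split-K FP4 paths, the BF16-partials toggle that previously came from the VLLM_TRITON_FP4_GEMM_SPLITK_USE_BF16 environment variable is now set programmatically, e.g. `from aiter.ops.triton.gemm_a8wfp4 import set_use_gemm_splitk_bf16; set_use_gemm_splitk_bf16(True)`.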