Merged
60 changes: 32 additions & 28 deletions aiter/ops/triton/batched_gemm_a8w8.py
@@ -1,13 +1,15 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from typing import Optional
import functools
import json
import torch
import triton
import triton.language as tl
from typing import Optional
import aiter.ops.triton.utils.arch_info as arch_info
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH


@triton.heuristics(
@@ -181,6 +183,26 @@ def _batched_gemm_a8w8_kernel(
tl.store(c_ptrs, c, mask=c_mask)


@functools.lru_cache(maxsize=1024)
def _get_config(
M: int,
N: int,
K: int,
):
if not hasattr(_get_config, "_config_dict"):
dev = arch_info.get_device()
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM-A8W8.json"
print(f"fpath={fpath}")
with open(fpath, "r") as file:
config = json.load(file)
_get_config._config_dict = config

if M + N >= 4096:
return _get_config._config_dict["large"]
else:
return _get_config._config_dict["small"]
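
For context, _get_config reads the per-device JSON once, caches it on the function object, and afterwards only indexes the "large" and "small" keys. A minimal sketch of the dict shape json.load is expected to return, populated here with the tuning values the old inline heuristic hard-coded; the shipped {dev}-BATCHED_GEMM-A8W8.json may carry different numbers.

# Hypothetical contents of {dev}-BATCHED_GEMM-A8W8.json after json.load.
# Keys mirror the kernel launch parameters; the values below come from the
# removed inline heuristic, not from the tuned file itself.
_example_a8w8_config = {
    "large": {  # selected when M + N >= 4096
        "BLOCK_SIZE_M": 256,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 4,
        "waves_per_eu": 2,
        "matrix_instr_nonkdim": 16,
        "num_warps": 8,
        "num_stages": 2,
    },
    "small": {  # selected otherwise
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 1,
        "waves_per_eu": 2,
        "matrix_instr_nonkdim": 16,
        "num_warps": 8,
        "num_stages": 2,
    },
}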


def batched_gemm_a8w8(
XQ: torch.Tensor,
WQ: torch.Tensor,
@@ -190,6 +212,7 @@ def batched_gemm_a8w8(
dtype: Optional[torch.dtype] = torch.bfloat16,
splitK: Optional[int] = None,
YQ: Optional[torch.Tensor] = None,
config: Optional[dict] = None,
):
"""
Computes the matmul YQ[i] = XQ[i] x WQ[i]T and applies a conversion scale for every i in a given batch.
@@ -235,25 +258,12 @@ def batched_gemm_a8w8(
if YQ is None:
YQ = torch.empty((B, M, N), dtype=dtype, device=XQ.device)

if (M + N) >= 4096:
BLOCK_SIZE_M = 256
BLOCK_SIZE_N = 256
BLOCK_SIZE_K = 64
GROUP_SIZE_M = 4
else:
BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 32
GROUP_SIZE_M = 1
if config is None:
config = _get_config(M, N, K)

waves_per_eu = 2
matrix_instr_nonkdim = 16
num_warps = 8
num_stages = 2

grid = (
grid = lambda META: ( # noqa: E731
B,
triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)

_batched_gemm_a8w8_kernel[grid](
@@ -279,13 +289,7 @@ def batched_gemm_a8w8(
w_scale.stride(0),
bias.stride(0) if has_bias else 0,
has_bias,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
GROUP_SIZE_M,
waves_per_eu=waves_per_eu,
matrix_instr_nonkdim=matrix_instr_nonkdim,
num_warps=num_warps,
num_stages=num_stages,
**config,
)

return YQ
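
Because the launch parameters now travel as a single dict splatted into the kernel call, a caller can bypass the JSON lookup entirely by passing its own config. A hedged usage sketch: the tensor shapes and the collapsed x_scale/w_scale/bias arguments are assumptions inferred from the kernel launch above, and the override values are illustrative only.

import torch
from aiter.ops.triton.batched_gemm_a8w8 import batched_gemm_a8w8

# Illustrative int8 inputs; B, M, N, K and the scale/bias shapes are assumed,
# not something this diff pins down.
B, M, N, K = 4, 512, 1024, 2048
XQ = torch.randint(-128, 127, (B, M, K), dtype=torch.int8, device="cuda")
WQ = torch.randint(-128, 127, (B, N, K), dtype=torch.int8, device="cuda")
x_scale = torch.rand((B, M, 1), dtype=torch.float32, device="cuda")
w_scale = torch.rand((B, 1, N), dtype=torch.float32, device="cuda")
bias = torch.rand((B, 1, N), dtype=torch.float32, device="cuda")

# config=None (the default) goes through _get_config; passing a dict with the
# same keys pins the launch parameters and skips the JSON lookup.
custom_config = {
    "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32,
    "GROUP_SIZE_M": 1, "waves_per_eu": 2, "matrix_instr_nonkdim": 16,
    "num_warps": 8, "num_stages": 2,
}
YQ = batched_gemm_a8w8(XQ, WQ, x_scale, w_scale, bias, config=custom_config)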
176 changes: 78 additions & 98 deletions aiter/ops/triton/batched_gemm_afp4wfp4.py
@@ -1,9 +1,25 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from typing import Optional
import functools
import json
import os
import torch
import triton
import triton.language as tl
from aiter.ops.triton.utils.pid_preprocessing import pid_grid, remap_xcd
import aiter.ops.triton.utils.arch_info as arch_info
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH


global _USE_GEMM_SPLITK_BF16
_USE_GEMM_SPLITK_BF16 = False


def set_use_gemm_splitk_bf16(value: bool):
global _USE_GEMM_SPLITK_BF16
_USE_GEMM_SPLITK_BF16 = value


@triton.heuristics(
@@ -259,18 +275,57 @@ def get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int):
triton.cdiv((2 * triton.cdiv(K, NUM_KSPLIT)), BLOCK_SIZE_K) * BLOCK_SIZE_K
)

# print(K, SPLITK_BLOCK_SIZE // 2, BLOCK_SIZE_K // 2, NUM_KSPLIT)
# print(K % (SPLITK_BLOCK_SIZE // 2) == 0, SPLITK_BLOCK_SIZE % BLOCK_SIZE_K == 0, K % (BLOCK_SIZE_K // 2) == 0)
return SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT
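
The visible expression rounds 2 * ceil(K / NUM_KSPLIT) up to the next multiple of BLOCK_SIZE_K. Assuming that value is what get_splitk returns as SPLITK_BLOCK_SIZE, a quick worked example with made-up inputs:

# Illustrative arithmetic only; K, NUM_KSPLIT and BLOCK_SIZE_K are made up.
#   K = 2048, NUM_KSPLIT = 4, BLOCK_SIZE_K = 256
#   2 * cdiv(2048, 4)     = 1024
#   cdiv(1024, 256) * 256 = 1024  -> SPLITK_BLOCK_SIZE = 1024
# For a K that does not split evenly, e.g. K = 1000:
#   2 * cdiv(1000, 4)     = 500
#   cdiv(500, 256) * 256  = 512   -> SPLITK_BLOCK_SIZE = 512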


@functools.lru_cache(maxsize=1024)
def _get_config(
M: int,
N: int,
K: int,
):
if not hasattr(_get_config, "_config_dict"):
dev = arch_info.get_device()
fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-BATCHED_GEMM-AFP4WFP4.json"
with open(fpath, "r") as file:
config = json.load(file)
_get_config._config_dict = config

if M < 32:
config = _get_config._config_dict["M_LT_32"]
elif M < 64:
config = _get_config._config_dict["M_LT_64"]
elif M < 128:
config = _get_config._config_dict["M_LT_128"]
elif M == 128:
config = _get_config._config_dict["M_EQ_128"]
elif M <= 256:
config = _get_config._config_dict["M_LTE_256"]
else:
config = _get_config._config_dict["default"]

if M <= 128:
SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk(
K, config["BLOCK_SIZE_K"], config["NUM_KSPLIT"]
)
config["SPLITK_BLOCK_SIZE"] = SPLITK_BLOCK_SIZE
config["BLOCK_SIZE_K"] = BLOCK_SIZE_K
config["NUM_KSPLIT"] = NUM_KSPLIT

else:
config["SPLITK_BLOCK_SIZE"] = 2 * K

return config
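
As with the A8W8 variant, the JSON is cached once and then indexed, here by an M bucket. A sketch of the dict shape json.load is assumed to return, filled in with the tile sizes the removed inline branches used; the tuned {dev}-BATCHED_GEMM-AFP4WFP4.json may differ, and only two of the six buckets are spelled out.

# Hypothetical shape of {dev}-BATCHED_GEMM-AFP4WFP4.json after json.load.
# "M_LT_64", "M_LT_128", "M_EQ_128" and "M_LTE_256" follow the same shape.
# Values come from the removed inline branches, not from the tuned file.
_example_afp4wfp4_config = {
    "M_LT_32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "NUM_KSPLIT": 4,       # refined by get_splitk() when M <= 128
        "waves_per_eu": 6,
        "kpack": 1,
        "num_warps": 4,
        "num_stages": 2,
        "matrix_instr_nonkdim": 16,
        "cache_modifier": ".cg",
    },
    "default": {               # M > 256
        "BLOCK_SIZE_M": 256,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "NUM_KSPLIT": 1,       # SPLITK_BLOCK_SIZE is then forced to 2 * K
        "waves_per_eu": 1,
        "kpack": 1,
        "num_warps": 8,
        "num_stages": 2,
        "matrix_instr_nonkdim": 16,
        "cache_modifier": None,
    },
}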


def batched_gemm_afp4wfp4(
x: torch.Tensor,
w: torch.Tensor,
y: torch.Tensor,
x_scales: torch.Tensor,
w_scales: torch.Tensor,
dtype: Optional[float] = torch.bfloat16,
config: Optional[dict] = None,
):
"""
Computes the matmul Y = X x W
@@ -289,105 +344,43 @@ def batched_gemm_afp4wfp4(
- Y: The output matrix with shape (M, N).
"""

assert arch_info.is_fp4_avail(), "MXFP4 is not available on your device"

Bx, M, K = x.shape
Bw, K, N = w.shape
By, _, _ = y.shape
assert Bx == Bw == By
Batch = Bx

if M < 32:
BLOCK_SIZE_M = 16
BLOCK_SIZE_N = 64
BLOCK_SIZE_K = 256
GROUP_SIZE_M = 1
waves_per_eu = 6
kpack = 1
num_warps = 4
num_stages = 2
matrix_instr_nonkdim = 16
cache_modifier = ".cg"

NUM_KSPLIT = 4
SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk(
K, BLOCK_SIZE_K, NUM_KSPLIT
)

if os.getenv("VLLM_TRITON_FP4_GEMM_SPLITK_USE_BF16") == "1":
y_pp = torch.empty(
(Batch, NUM_KSPLIT, M, N), dtype=y.dtype, device=y.device
)
else:
y_pp = torch.empty(
(Batch, NUM_KSPLIT, M, N), dtype=torch.float32, device=y.device
)
elif M <= 128:
BLOCK_SIZE_M = triton.next_power_of_2(M)
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 256
GROUP_SIZE_M = 1
waves_per_eu = 4
kpack = 1
num_warps = 4
num_stages = 2
matrix_instr_nonkdim = 16
cache_modifier = ".cg"

NUM_KSPLIT = 4
SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk(
K, BLOCK_SIZE_K, NUM_KSPLIT
)
if config is None:
config = _get_config(M, N, K)

if os.getenv("VLLM_TRITON_FP4_GEMM_SPLITK_USE_BF16") == "1":
if M <= 128:
if _USE_GEMM_SPLITK_BF16:
y_pp = torch.empty(
(Batch, NUM_KSPLIT, M, N), dtype=y.dtype, device=y.device
(Batch, config["NUM_KSPLIT"], M, N), dtype=y.dtype, device=y.device
)
else:
y_pp = torch.empty(
(Batch, NUM_KSPLIT, M, N), dtype=torch.float32, device=y.device
(Batch, config["NUM_KSPLIT"], M, N),
dtype=torch.float32,
device=y.device,
)
elif M <= 256:
BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 256
GROUP_SIZE_M = 2
waves_per_eu = 4
kpack = 1
num_warps = 4
num_stages = 2
matrix_instr_nonkdim = 16
cache_modifier = ".cg"

NUM_KSPLIT = 1
SPLITK_BLOCK_SIZE = 2 * K
y_pp = None
else:
BLOCK_SIZE_M = 256
BLOCK_SIZE_N = 256
BLOCK_SIZE_K = 256
GROUP_SIZE_M = 64
waves_per_eu = 1
kpack = 1
num_warps = 8
num_stages = 2
matrix_instr_nonkdim = 16
cache_modifier = None

NUM_KSPLIT = 1
SPLITK_BLOCK_SIZE = 2 * K
y_pp = None

grid = lambda META: ( # noqa: E731
Batch,
(
NUM_KSPLIT
config["NUM_KSPLIT"]
* triton.cdiv(M, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"])
),
)
_batched_gemm_afp4_wfp4_kernel[grid](
x,
w,
y if NUM_KSPLIT == 1 else y_pp,
y if config["NUM_KSPLIT"] == 1 else y_pp,
x_scales,
w_scales,
M,
@@ -399,39 +392,26 @@ def batched_gemm_afp4wfp4(
w.stride(0),
w.stride(1),
w.stride(2),
y.stride(0) if NUM_KSPLIT == 1 else y_pp.stride(0),
0 if NUM_KSPLIT == 1 else y_pp.stride(1),
y.stride(1) if NUM_KSPLIT == 1 else y_pp.stride(2),
y.stride(2) if NUM_KSPLIT == 1 else y_pp.stride(3),
y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(0),
0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(1),
y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2),
y.stride(2) if config["NUM_KSPLIT"] == 1 else y_pp.stride(3),
x_scales.stride(0),
x_scales.stride(1),
x_scales.stride(2),
w_scales.stride(0),
w_scales.stride(1),
w_scales.stride(2),
BLOCK_SIZE_M,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
GROUP_SIZE_M,
NUM_KSPLIT,
SPLITK_BLOCK_SIZE,
cache_modifier=cache_modifier,
waves_per_eu=waves_per_eu,
kpack=kpack,
num_warps=num_warps,
num_stages=num_stages,
matrix_instr_nonkdim=matrix_instr_nonkdim,
**config,
)

if NUM_KSPLIT > 1:
if config["NUM_KSPLIT"] > 1:
REDUCE_BLOCK_SIZE_M = 16
# TODO: Need to debug - REDUCE_BLOCK_SIZE_N=128 with fp32 partials fails
# NOTE: REDUCE_BLOCK_SIZE_N=16 gives best perf with fp32 partials and
# REDUCE_BLOCK_SIZE_N=128 gives best perf with bf16 partials
REDUCE_BLOCK_SIZE_N = (
128 if os.getenv("VLLM_TRITON_FP4_GEMM_SPLITK_USE_BF16") == "1" else 64
)
ACTUAL_KSPLIT = triton.cdiv(K, (SPLITK_BLOCK_SIZE // 2))
REDUCE_BLOCK_SIZE_N = 128 if _USE_GEMM_SPLITK_BF16 else 64
ACTUAL_KSPLIT = triton.cdiv(K, (config["SPLITK_BLOCK_SIZE"] // 2))

grid_reduce = (
Batch,
@@ -453,5 +433,5 @@ def batched_gemm_afp4wfp4(
REDUCE_BLOCK_SIZE_M,
REDUCE_BLOCK_SIZE_N,
ACTUAL_KSPLIT,
NUM_KSPLIT,
config["NUM_KSPLIT"],
)
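
For readers following the split-K branch: when config["NUM_KSPLIT"] > 1 the main kernel writes partial results into y_pp with shape (Batch, NUM_KSPLIT, M, N), and the reduce launch above folds them into y. A minimal eager-mode sketch of what that reduction is assumed to compute; _reference_splitk_reduce is an illustrative name, and the actual work is done by the Triton reduce kernel.

import torch

def _reference_splitk_reduce(
    y_pp: torch.Tensor, actual_ksplit: int, out_dtype: torch.dtype
) -> torch.Tensor:
    # Sum the valid K-split partials along axis 1 and cast to the output
    # dtype; slices beyond actual_ksplit are assumed to hold no useful data.
    return y_pp[:, :actual_ksplit].sum(dim=1).to(out_dtype)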