275 changes: 259 additions & 16 deletions aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py
@@ -124,7 +124,7 @@ def _gemm_afp4_wfp4_kernel(

        for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter):
            a_scales = tl.load(a_scale_ptrs)
-           b_scales = tl.load(b_scale_ptrs)
+           b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier)

            # Load the next block of A and B, generate a mask by checking the K dimension.
            # If it is out of bounds, set it to 0.
@@ -297,14 +297,36 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales(
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter):
-           a_scales = tl.load(a_scale_ptrs)
-           b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier)
-           if BLOCK_SIZE_M >= 32:
-               a_scales = tl.reshape(
-                   a_scales, (BLOCK_SIZE_M, BLOCK_SIZE_K // SCALE_GROUP_SIZE)
-               )
-           b_scales = tl.reshape(
-               b_scales, (BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE)
-           )
+           if BLOCK_SIZE_M < 32:
+               a_scales = tl.load(a_scale_ptrs)
+           else:
+               a_scales = (
+                   tl.load(a_scale_ptrs)
+                   .reshape(
+                       BLOCK_SIZE_M // 32,
+                       BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8,
+                       4,
+                       16,
+                       2,
+                       2,
+                       1,
+                   )
+                   .permute(0, 5, 3, 1, 4, 2, 6)
+                   .reshape(BLOCK_SIZE_M, BLOCK_SIZE_K // SCALE_GROUP_SIZE)
+               )
+           b_scales = (
+               tl.load(b_scale_ptrs, cache_modifier=cache_modifier)
+               .reshape(
+                   BLOCK_SIZE_N // 32,
+                   BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8,
+                   4,
+                   16,
+                   2,
+                   2,
+                   1,
+               )
+               .permute(0, 5, 3, 1, 4, 2, 6)
+               .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE)
+           )

            # Load the next block of A and B, generate a mask by checking the K dimension.
@@ -346,6 +368,222 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales(
tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt")


@triton.heuristics(
{
"EVEN_K": lambda args: (args["K"] % (args["BLOCK_SIZE_K"] // 2) == 0)
and (args["SPLITK_BLOCK_SIZE"] % args["BLOCK_SIZE_K"] == 0)
and (args["K"] % (args["SPLITK_BLOCK_SIZE"] // 2) == 0),
"GRID_MN": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"])
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
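# EVEN_K (see the heuristic above) holds only when the packed K extent divides
# evenly into both the per-iteration BLOCK_SIZE_K chunks and the split-K
# slices, so the main loop can load A and B without K-dimension masks.
# GRID_MN counts the (M, N) output tiles; the launch covers
# GRID_MN * NUM_KSPLIT programs.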
@triton.jit
def _gemm_afp4_wfp4_kernel_preshuffled_weight_scales(
a_ptr,
b_ptr,
c_ptr,
a_scales_ptr,
b_scales_ptr,
M,
N,
K,
stride_am,
stride_ak,
stride_bn,
stride_bk,
stride_ck,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bsn,
stride_bsk,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
NUM_KSPLIT: tl.constexpr,
SPLITK_BLOCK_SIZE: tl.constexpr,
EVEN_K: tl.constexpr,
GRID_MN: tl.constexpr,
cache_modifier: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A and B inputs are in the microscale fp4 (mxfp4) format.
A_scales and B_scales are in e8m0 format.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""

tl.assume(stride_am > 0)
tl.assume(stride_ak > 0)
tl.assume(stride_bk > 0)
tl.assume(stride_bn > 0)
tl.assume(stride_cm > 0)
tl.assume(stride_cn > 0)
tl.assume(stride_asm > 0)
tl.assume(stride_ask > 0)
tl.assume(stride_bsk > 0)
tl.assume(stride_bsn > 0)

# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid_unified = tl.program_id(axis=0)
pid_unified = remap_xcd(pid_unified, GRID_MN * NUM_KSPLIT, NUM_XCDS=8)
pid_k = pid_unified % NUM_KSPLIT
pid = pid_unified // NUM_KSPLIT
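    # Consecutive program ids differ in pid_k first, so the NUM_KSPLIT partial
    # products of one (M, N) tile stay close together after remap_xcd spreads
    # programs across the 8 XCDs for better L2 locality.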
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

if NUM_KSPLIT == 1:
pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M)
else:
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n

tl.assume(pid_m >= 0)
tl.assume(pid_n >= 0)
# We assume 32 elements along K share the same scale.
SCALE_GROUP_SIZE: tl.constexpr = 32

if (pid_k * SPLITK_BLOCK_SIZE // 2) < K:

num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE // 2, BLOCK_SIZE_K // 2)
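        # Packed bytes per split-K slice divided by packed bytes per loop
        # iteration; each slice covers SPLITK_BLOCK_SIZE fp4 elements.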

        # Create pointers for the first blocks of the A and B input matrices.
        # BLOCK sizes count fp4 elements; two fp4 values are packed per uint8 container.
offs_k = tl.arange(0, BLOCK_SIZE_K // 2)
offs_k_shuffle_arr = tl.arange(0, (BLOCK_SIZE_K // 2) * 16)
offs_k_split = pid_k * (SPLITK_BLOCK_SIZE // 2) + offs_k
offs_k_shuffle = pid_k * (SPLITK_BLOCK_SIZE // 2) * 16 + offs_k_shuffle_arr

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * (BLOCK_SIZE_N // 16) + tl.arange(0, BLOCK_SIZE_N // 16)) % N
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak
)
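        # B is preshuffled: 16 N-rows are interleaved along K, so addressing
        # uses N // 16 row groups with 16x-expanded K offsets (the commented
        # line below is the plain, unshuffled form).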
b_ptrs = b_ptr + (
# offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn
offs_bn[:, None] * stride_bn
+ offs_k_shuffle[None, :] * stride_bk
)
# Create pointers for the first block of A and B scales

offs_asn = (
pid_n * (BLOCK_SIZE_N // 32) + tl.arange(0, (BLOCK_SIZE_N // 32))
) % N
offs_ks = (pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE) * 32) + tl.arange(
0, BLOCK_SIZE_K // SCALE_GROUP_SIZE * 32
)
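        # The scales are preshuffled as well: each group of 32 rows has its
        # e8m0 values flattened along K, hence the 32x-expanded offs_ks.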
# B scales are N x K even though B operand is K x N.
b_scale_ptrs = (
b_scales_ptr
+ offs_asn[:, None] * stride_bsn
+ offs_ks[None, :] * stride_bsk
)

if BLOCK_SIZE_M < 32:
offs_ks_non_shufl = (
pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE)
) + tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE)
a_scale_ptrs = (
a_scales_ptr
+ offs_am[:, None] * stride_asm
+ offs_ks_non_shufl[None, :] * stride_ask
)
else:
offs_asm = (
pid_m * (BLOCK_SIZE_M // 32) + tl.arange(0, (BLOCK_SIZE_M // 32))
) % M
a_scale_ptrs = (
a_scales_ptr
+ offs_asm[:, None] * stride_asm
+ offs_ks[None, :] * stride_ask
)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter):
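            # Un-shuffle the scales in-register: the reshape/permute chains
            # below invert the host-side preshuffle so each row of the result
            # holds the per-K-group scales that tl.dot_scaled expects.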
if BLOCK_SIZE_M < 32:
a_scales = tl.load(a_scale_ptrs)
else:
a_scales = (
tl.load(a_scale_ptrs)
.reshape(
BLOCK_SIZE_M // 32,
BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8,
4,
16,
2,
2,
1,
)
.permute(0, 5, 3, 1, 4, 2, 6)
.reshape(BLOCK_SIZE_M, BLOCK_SIZE_K // SCALE_GROUP_SIZE)
)

b_scales = (
tl.load(b_scale_ptrs, cache_modifier=cache_modifier)
.reshape(
BLOCK_SIZE_N // 32,
BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8,
4,
16,
2,
2,
1,
)
.permute(0, 5, 3, 1, 4, 2, 6)
.reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE)
)

# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
if EVEN_K:
a = tl.load(a_ptrs)
b = tl.load(b_ptrs, cache_modifier=cache_modifier)

b = (
b.reshape(
1,
BLOCK_SIZE_N // 16,
BLOCK_SIZE_K // 64,
2,
16,
16,
)
.permute(0, 1, 4, 2, 3, 5)
.reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // 2)
.trans(1, 0)
)
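                # The reshape/permute above undoes the 16x16 tile interleave of
                # the preshuffled weights; trans(1, 0) restores the (K // 2, N)
                # operand layout expected by tl.dot_scaled.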

accumulator += tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1")

# Advance the ptrs to the next K block.
a_ptrs += (BLOCK_SIZE_K // 2) * stride_ak
b_ptrs += (BLOCK_SIZE_K // 2) * 16 * stride_bk
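            # Shuffled scale pointers advance by BLOCK_SIZE_K per iteration:
            # (BLOCK_SIZE_K // SCALE_GROUP_SIZE) K-groups times the 32x row
            # flattening; the small-M path keeps the plain per-group stride.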
if BLOCK_SIZE_M < 32:
a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask
else:
a_scale_ptrs += BLOCK_SIZE_K * stride_ask
b_scale_ptrs += BLOCK_SIZE_K * stride_bsk

c = accumulator.to(c_ptr.type.element_ty)

# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
c_ptrs = (
c_ptr
+ stride_cm * offs_cm[:, None]
+ stride_cn * offs_cn[None, :]
+ pid_k * stride_ck
)
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt")
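
# A minimal host-side NumPy sketch (editorial illustration, not part of this
# diff) of the scale un-shuffle the kernels above perform in-register: a
# preshuffled tile arrives as (rows // 32, k_groups * 32), and the same
# reshape/permute/reshape chain restores the logical (rows, k_groups) layout.
# The function name and sizes are assumptions chosen to make the shapes concrete.
def _unshuffle_scales_demo():
    import numpy as np

    rows, k_groups = 64, 16  # e.g. BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE
    flat = np.arange(rows * k_groups).reshape(rows // 32, k_groups * 32)
    unshuffled = (
        flat.reshape(rows // 32, k_groups // 8, 4, 16, 2, 2, 1)
        .transpose(0, 5, 3, 1, 4, 2, 6)  # same axis order as the kernel's permute
        .reshape(rows, k_groups)
    )
    assert unshuffled.shape == (rows, k_groups)
    return unshuffled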


@triton.jit
def _gemm_afp4_wfp4_reduce_kernel(
c_in_ptr,
@@ -398,29 +636,34 @@ def _get_config(
    M: int,
    N: int,
    K: int,
+   shuffle: bool = False,
):
-   if not hasattr(_get_config, "_config_dict"):
+   shuffle_filename_suffix = "" if not shuffle else "_PRESHUFFLED"
+   if not hasattr(_get_config, "_config_dict") or not hasattr(
+       _get_config._config_dict, f"default{shuffle_filename_suffix}"
+   ):
        dev = arch_info.get_device()
        _get_config._config_dict = {}
-       fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-AFP4WFP4.json"
+       fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-AFP4WFP4{shuffle_filename_suffix}.json"
        with open(fpath, "r") as file:
            config = json.load(file)
-       _get_config._config_dict["default"] = config
+       _get_config._config_dict[f"default{shuffle_filename_suffix}"] = config

-   key = f"{N}_{K}"
+   key = f"{N}_{K}{shuffle_filename_suffix}"
    if key not in _get_config._config_dict.keys():
        dev = arch_info.get_device()
-       fpath = (
-           f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-AFP4WFP4-N={N}-K={2*K}.json"
-       )
+       fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-AFP4WFP4{shuffle_filename_suffix}-N={N}-K={2*K}.json"
        if os.path.exists(fpath):
            with open(fpath, "r") as file:
                config = json.load(file)
            _get_config._config_dict[key] = config
        else:
-           key = "default"  # fall back to default config
+           key = f"default{shuffle_filename_suffix}"  # fall back to default config

    if M < 32:
        BLK_M = triton.next_power_of_2(M)
        if BLK_M >= 16 and "small_M16" in _get_config._config_dict[key]:
            return _get_config._config_dict[key]["small_M16"]
        return _get_config._config_dict[key]["small"]
    elif M <= 128:
        BLK_M = triton.next_power_of_2(M)
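
# Illustration of the new config lookup (hypothetical device and shapes):
# _get_config(M=64, N=8192, K=512, shuffle=True) on device "gfx950" would try
# .../gemm/gfx950-GEMM-AFP4WFP4_PRESHUFFLED-N=8192-K=1024.json (K doubles in
# the filename because K counts packed containers) before falling back to
# gfx950-GEMM-AFP4WFP4_PRESHUFFLED.json.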