28 changes: 16 additions & 12 deletions aiter/ops/gemm_op_a4w4.py
@@ -4,20 +4,17 @@
import functools
from typing import Optional

from aiter.jit.utils.torch_guard import torch_compile_guard
import pandas as pd
import torch
from torch import Tensor

from aiter import logger
from aiter.jit.utils.torch_guard import torch_compile_guard

from ..jit.core import (
AITER_CONFIGS,
AITER_LOG_TUNED_CONFIG,
compile_ops,
)
from ..jit.core import AITER_CONFIGS, AITER_LOG_TUNED_CONFIG, compile_ops
from ..jit.utils.chip_info import get_cu_num, get_gfx
from ..ops.gemm_op_common import get_padded_m
from ..utility import dtypes


@functools.lru_cache(maxsize=1024)
@@ -66,12 +63,15 @@ def gemm_a4w4_fake(
B: Tensor, # B:[N, K/2] f4x2
A_scale: Tensor, # A_scale:[M, K/32] e8m0 padded
B_scale: Tensor, # B_scale:[N, K/32] e8m0 padded
out: Tensor, # Out:[M, N] bf16
bias: Optional[Tensor] = None, # bias:[1, N] f32
dtype: torch.dtype = dtypes.bf16,
alpha: Optional[float] = 1.0,
beta: Optional[float] = 0.0,
bpreshuffle: Optional[bool] = True,
) -> torch.Tensor:
m = A.numel() // A.shape[-1]
n = B.shape[0]
out = torch.empty((m, n), dtype=dtype, device=A.device)
return out


@@ -81,8 +81,8 @@ def gemm_a4w4(
B: Tensor, # B:[N, K/2] f4x2
A_scale: Tensor, # A_scale:[M, K/32] e8m0 padded
B_scale: Tensor, # B_scale:[N, K/32] e8m0 padded
out: Tensor, # Out:[M, N] bf16
bias: Optional[Tensor] = None, # bias:[1, N] f32
dtype: torch.dtype = dtypes.bf16,
alpha: Optional[float] = 1.0,
beta: Optional[float] = 0.0,
bpreshuffle: Optional[bool] = True,
@@ -93,9 +93,10 @@
It is used to perform matrix multiplication with 4-bit quantization.
"""
# Load the A4W4 GEMM kernel
m = A.shape[0]
m = A.numel() // A.shape[-1]
n = B.shape[0]
k = A.shape[-1] * 2
out = torch.empty(((m + 31) // 32 * 32, n), dtype=dtype, device=A.device)
gfx_arch = get_gfx()
if gfx_arch in ["gfx942"]:
raise RuntimeError(
@@ -114,12 +115,14 @@
# or bias is None
):
splitK = 0 if splitK is None else splitK
return gemm_a4w4_blockscale(A, B, A_scale, B_scale, out, splitK=splitK)
return gemm_a4w4_blockscale(
A.view(m, k // 2), B, A_scale, B_scale, out, splitK=splitK
)[:m]
assert (
out.shape[0] % 32 == 0
), "Dim0 of gemm_a4w4_asm output needs to be padded to multiples of 32!"
return gemm_a4w4_asm(
A,
gemm_a4w4_asm(
A.view(m, k // 2),
B,
A_scale,
B_scale,
@@ -131,6 +134,7 @@
bpreshuffle,
log2_k_split=splitK,
)
return out[:m].view(*A.shape[:-1], n)


def gen_gemm_a4w4_asm_fake_tensors(
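For context, a minimal usage sketch of the updated unified API: the caller no longer preallocates a 32-row-padded `out` buffer, since `gemm_a4w4` now allocates the padded output internally and returns the trimmed `[M, N]` view. The import path and the `uint8` packing dtypes below are assumptions for illustration; shapes follow the comments in the signature above.

import torch

from aiter import dtypes
from aiter.ops.gemm_op_a4w4 import gemm_a4w4  # assumed import path

M, N, K = 128, 7168, 4096
A = torch.empty((M, K // 2), dtype=torch.uint8, device="cuda")  # f4x2-packed activations (dtype assumed)
B = torch.empty((N, K // 2), dtype=torch.uint8, device="cuda")  # f4x2-packed, preshuffled weights
A_scale = torch.empty((M, K // 32), dtype=torch.uint8, device="cuda")  # e8m0 scales, padded
B_scale = torch.empty((N, K // 32), dtype=torch.uint8, device="cuda")  # e8m0 scales, padded

# Before this change the caller allocated out = torch.empty(((M + 31) // 32 * 32, N), ...)
# and sliced the result to [:M]; now the op handles padding and trimming itself:
out = gemm_a4w4(A, B, A_scale, B_scale, dtype=dtypes.bf16, bpreshuffle=True)
assert out.shape == (M, N)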
3 changes: 2 additions & 1 deletion csrc/py_itfs_cu/asm_gemm_a4w4.cu
@@ -113,7 +113,8 @@ std::tuple<std::string, int> get_heuristic_kernel(int M,
if(cfg.bpreshuffle == bpreshuffle_en &&
(cfg.splitK >= log2_k_split_en))
{
if((N % cfg.tile_N) == 0)
// the 128x512 tile kernel may not support N % cfg.tile_N != 0
if(cfg.tile_M != 128 || cfg.tile_N != 512 || (N % cfg.tile_N) == 0)
{
std::vector<int> splitK_list =
(log2_k_split.has_value() && cfg.splitK)
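To make the new selection rule concrete, here is a Python rendering of the updated filter: only the 128x512 tile keeps the old divisibility requirement, while other tile shapes are assumed to handle a ragged N. Field names mirror the C++ config struct in the diff; this is an illustrative sketch, not the actual implementation.

def tile_is_selectable(cfg, N: int) -> bool:
    # The 128x512 tile may not support N % tile_N != 0, so keep the
    # old divisibility check for that shape only.
    if cfg.tile_M == 128 and cfg.tile_N == 512:
        return N % cfg.tile_N == 0
    # All other tile shapes are assumed to accept any N.
    return True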
24 changes: 12 additions & 12 deletions op_tests/test_gemm_a4w4.py
@@ -1,14 +1,16 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import argparse

import pandas as pd
import torch

import aiter
from aiter.test_common import checkAllclose, benchmark, perftest, run_perftest
from aiter import dtypes
from aiter.utility import fp4_utils
from aiter.ops.shuffle import shuffle_weight
import argparse
import pandas as pd
from aiter.test_common import benchmark, checkAllclose, perftest, run_perftest
from aiter.utility import fp4_utils

torch.set_default_device("cuda")
torch.set_printoptions(sci_mode=False)
@@ -98,13 +100,10 @@ def test_gemm(dtype, M, N, K):
x, x_scales_shuffle = quant_func(x, shuffle=True)
w, w_scales_shuffle = quant_func(w, shuffle=True)
wshuffle = shuffle_weight(w, layout=(16, 16))
out1 = torch.empty(M, N, dtype=dtype)
out2 = torch.empty((M + 31) // 32 * 32, N, dtype=dtype)
out3 = torch.empty((M + 31) // 32 * 32, N, dtype=dtype)
bias_f32 = None
x_scales = x_scales.view(torch.uint8)
w_scales = w_scales.view(torch.uint8)
a, avg_a = run_torch(x, w, x_scales, w_scales, dtype)
# out1 = torch.empty(M, N, dtype=dtype)
# b, avg_b = run_triton(x, w.T, x_scales, w_scales, out1, dtype)
# b, avg_b = a, 0
# err_b = checkAllclose(a, b, msg="triton ")
@@ -115,23 +114,23 @@
wshuffle,
x_scales_shuffle,
w_scales_shuffle,
out2,
bpreshuffle=True,
)
err = checkAllclose(a, c[:M], msg="unified api")
err = checkAllclose(a, c, msg="unified api")
ret["us"] = us
ret["TFLOPS"] = M * N * K * 2 / us / 1e6
ret["TB/s"] = (x.nbytes + w.nbytes) / us / 1e6
ret["err"] = err

# kernelName = "" # "_ZN5aiter42f4gemm_bf16_per1x32Fp4_BpreShuffle_128x512E"
# log2_k_split = 1
# out2 = torch.empty((M + 31) // 32 * 32, N, dtype=dtype)
# d, us = run_gemm_asm(
# x,
# wshuffle,
# x_scales_shuffle,
# w_scales_shuffle,
# out3,
# out2,
# kernelName,
# bias_f32,
# bpreshuffle=True,
@@ -144,6 +143,7 @@
# ret[f"TB/s {tag}"] = (x.nbytes + w.nbytes) / us / 1e6
# ret[f"err {tag}"] = err

# out3 = torch.empty((M + 31) // 32 * 32, N, dtype=dtype)
# e, us = run_gemm_ck(x, wshuffle, x_scales_shuffle, w_scales_shuffle, out3)
# err = checkAllclose(a, e[:M], msg="ck ")
# tag = "ck"
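Note that the commented-out asm/CK paths above still preallocate their own outputs, rounded up to a multiple of 32 rows. As a quick reference, the round-up used in those comments is:

def pad_to_multiple_of_32(m: int) -> int:
    # Matches (M + 31) // 32 * 32 in the test comments, e.g. 120 -> 128, 128 -> 128.
    return (m + 31) // 32 * 32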