From 67e8e6fae7f90f6f7654db253faaaf65317c4e44 Mon Sep 17 00:00:00 2001
From: valarLip
Date: Thu, 18 Dec 2025 05:24:18 +0000
Subject: [PATCH 1/3] Remove the input parameter "out" in gemm_a4w4

---
 aiter/ops/gemm_op_a4w4.py        | 28 ++++++++++++++++------------
 csrc/py_itfs_cu/asm_gemm_a4w4.cu |  3 ++-
 op_tests/test_gemm_a4w4.py       | 21 +++++++++++----------
 3 files changed, 29 insertions(+), 23 deletions(-)

diff --git a/aiter/ops/gemm_op_a4w4.py b/aiter/ops/gemm_op_a4w4.py
index bd3759f98c..f85c9b823d 100644
--- a/aiter/ops/gemm_op_a4w4.py
+++ b/aiter/ops/gemm_op_a4w4.py
@@ -4,20 +4,17 @@
 import functools
 from typing import Optional
 
-from aiter.jit.utils.torch_guard import torch_compile_guard
 import pandas as pd
 import torch
 from torch import Tensor
 
 from aiter import logger
+from aiter.jit.utils.torch_guard import torch_compile_guard
 
-from ..jit.core import (
-    AITER_CONFIGS,
-    AITER_LOG_TUNED_CONFIG,
-    compile_ops,
-)
+from ..jit.core import AITER_CONFIGS, AITER_LOG_TUNED_CONFIG, compile_ops
 from ..jit.utils.chip_info import get_cu_num, get_gfx
 from ..ops.gemm_op_common import get_padded_m
+from ..utility import dtypes
 
 
 @functools.lru_cache(maxsize=1024)
@@ -66,12 +63,15 @@ def gemm_a4w4_fake(
     B: Tensor,  # B:[N, K/2] f4x2
     A_scale: Tensor,  # A_scale:[M, K/32] e8m0 paded
     B_scale: Tensor,  # B_scale:[N, K/32] e8m0 paded
-    out: Tensor,  # Out:[M, N] bf16
     bias: Optional[Tensor] = None,  # bias:[1, N] f32
+    dtype: torch.dtype = dtypes.bf16,
     alpha: Optional[float] = 1.0,
     beta: Optional[float] = 0.0,
     bpreshuffle: Optional[bool] = True,
 ) -> torch.Tensor:
+    m = A.numel() // A.shape[-1]
+    n = B.shape[0]
+    out = torch.empty((m, n), dtype=dtype, device=A.device)
     return out
 
 
@@ -81,8 +81,8 @@ def gemm_a4w4(
     B: Tensor,  # B:[N, K/2] f4x2
     A_scale: Tensor,  # A_scale:[M, K/32] e8m0 paded
     B_scale: Tensor,  # B_scale:[N, K/32] e8m0 paded
-    out: Tensor,  # Out:[M, N] bf16
     bias: Optional[Tensor] = None,  # bias:[1, N] f32
+    dtype: torch.dtype = dtypes.bf16,
     alpha: Optional[float] = 1.0,
     beta: Optional[float] = 0.0,
     bpreshuffle: Optional[bool] = True,
@@ -93,9 +93,10 @@ def gemm_a4w4(
     It is used to perform matrix multiplication with 4-bit quantization.
     """
     # Load the A4W4 GEMM kernel
-    m = A.shape[0]
+    m = A.numel() // A.shape[-1]
     n = B.shape[0]
     k = A.shape[-1] * 2
+    out = torch.empty(((m + 31) // 32 * 32, n), dtype=dtype, device=A.device)
     gfx_arch = get_gfx()
     if gfx_arch in ["gfx942"]:
         raise RuntimeError(
@@ -114,12 +115,14 @@ def gemm_a4w4(
         # or bias is None
     ):
         splitK = 0 if splitK is None else splitK
-        return gemm_a4w4_blockscale(A, B, A_scale, B_scale, out, splitK=splitK)
+        return gemm_a4w4_blockscale(
+            A.view(-1, A.shape[-1]), B, A_scale, B_scale, out, splitK=splitK
+        )[:m]
     assert (
         out.shape[0] % 32 == 0
     ), "Dim0 of gemm_a4w4_asm output needs to be padded to multiples of 32!"
-    return gemm_a4w4_asm(
-        A,
+    gemm_a4w4_asm(
+        A.view(-1, A.shape[-1]),
         B,
         A_scale,
         B_scale,
@@ -131,6 +134,7 @@ def gemm_a4w4(
         bpreshuffle,
         log2_k_split=splitK,
     )
+    return out[:m].view(*A.shape[:-1], n)
 
 
 def gen_gemm_a4w4_asm_fake_tensors(
diff --git a/csrc/py_itfs_cu/asm_gemm_a4w4.cu b/csrc/py_itfs_cu/asm_gemm_a4w4.cu
index e68a18de27..1a55902607 100644
--- a/csrc/py_itfs_cu/asm_gemm_a4w4.cu
+++ b/csrc/py_itfs_cu/asm_gemm_a4w4.cu
@@ -113,7 +113,8 @@ std::tuple get_heuristic_kernel(int M,
         if(cfg.bpreshuffle == bpreshuffle_en &&
            (cfg.splitK >= log2_k_split_en))
         {
-            if((N % cfg.tile_N) == 0)
+            // tile128x512 may mot support N % cfg.tile_N != 0
+            if(cfg.tile_M != 128 || cfg.tile_N != 512 || (N % cfg.tile_N) == 0)
             {
                 std::vector splitK_list =
                     (log2_k_split.has_value() && cfg.splitK)
diff --git a/op_tests/test_gemm_a4w4.py b/op_tests/test_gemm_a4w4.py
index 8a1fed83f5..7640ef4b78 100644
--- a/op_tests/test_gemm_a4w4.py
+++ b/op_tests/test_gemm_a4w4.py
@@ -1,14 +1,16 @@
 # SPDX-License-Identifier: MIT
 # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 
+import argparse
+
+import pandas as pd
 import torch
+
 import aiter
-from aiter.test_common import checkAllclose, benchmark, perftest, run_perftest
 from aiter import dtypes
-from aiter.utility import fp4_utils
 from aiter.ops.shuffle import shuffle_weight
-import argparse
-import pandas as pd
+from aiter.test_common import benchmark, checkAllclose, perftest, run_perftest
+from aiter.utility import fp4_utils
 
 torch.set_default_device("cuda")
 torch.set_printoptions(sci_mode=False)
@@ -98,13 +100,11 @@ def test_gemm(dtype, M, N, K):
     x, x_scales_shuffle = quant_func(x, shuffle=True)
     w, w_scales_shuffle = quant_func(w, shuffle=True)
     wshuffle = shuffle_weight(w, layout=(16, 16))
-    out1 = torch.empty(M, N, dtype=dtype)
-    out2 = torch.empty((M + 31) // 32 * 32, N, dtype=dtype)
-    out3 = torch.empty((M + 31) // 32 * 32, N, dtype=dtype)
     bias_f32 = None
     x_scales = x_scales.view(torch.uint8)
     w_scales = w_scales.view(torch.uint8)
     a, avg_a = run_torch(x, w, x_scales, w_scales, dtype)
+    # out1 = torch.empty(M, N, dtype=dtype)
     # b, avg_b = run_triton(x, w.T, x_scales, w_scales, out1, dtype)
     # b, avg_b = a, 0
     # err_b = checkAllclose(a, b, msg="triton ")
@@ -115,7 +115,6 @@ def test_gemm(dtype, M, N, K):
         wshuffle,
         x_scales_shuffle,
         w_scales_shuffle,
-        out2,
         bpreshuffle=True,
     )
     err = checkAllclose(a, c[:M], msg="unified api")
@@ -124,14 +123,15 @@ def test_gemm(dtype, M, N, K):
     ret["TB/s"] = (x.nbytes + w.nbytes) / us / 1e6
     ret["err"] = err
 
-    # kernelName = ""  # "_ZN5aiter42f4gemm_bf16_per1x32Fp4_BpreShuffle_128x512E"
+    # kernelName = "" #"_ZN5aiter42f4gemm_bf16_per1x32Fp4_BpreShuffle_128x512E"
     # log2_k_split = 1
+    # out2 = torch.empty((M + 31) // 32 * 32, N, dtype=dtype)
     # d, us = run_gemm_asm(
    #     x,
     #     wshuffle,
     #     x_scales_shuffle,
     #     w_scales_shuffle,
-    #     out3,
+    #     out2,
     #     kernelName,
     #     bias_f32,
     #     bpreshuffle=True,
@@ -144,6 +144,7 @@ def test_gemm(dtype, M, N, K):
     # ret[f"TB/s {tag}"] = (x.nbytes + w.nbytes) / us / 1e6
     # ret[f"err {tag}"] = err
 
+    # out3 = torch.empty((M + 31) // 32 * 32, N, dtype=dtype)
    # e, us = run_gemm_ck(x, wshuffle, x_scales_shuffle, w_scales_shuffle, out3)
     # err = checkAllclose(a, e[:M], msg="ck ")
     # tag = "ck"
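For reference, a minimal usage sketch of the gemm_a4w4 entry point after PATCH 1: the padded output buffer is now allocated inside the op and returned already trimmed to M rows, so callers no longer pass "out" or slice the result. The M/N/K values and the randomly filled tensors below are illustrative placeholders only (not the real packed-fp4 / e8m0 buffers); real inputs come from the quant_func and shuffle_weight helpers used in op_tests/test_gemm_a4w4.py.

import torch
from aiter import dtypes
from aiter.ops.gemm_op_a4w4 import gemm_a4w4

M, N, K = 128, 7168, 4096  # placeholder sizes for illustration

# Stand-ins with the documented shapes: A/B pack two fp4 values per byte,
# scales are e8m0 bytes; numeric results are meaningless without the real
# quantization/preshuffle path from the test.
A = torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device="cuda")
B = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device="cuda")  # assumed preshuffled
A_scale = torch.randint(0, 256, (M, K // 32), dtype=torch.uint8, device="cuda")
B_scale = torch.randint(0, 256, (N, K // 32), dtype=torch.uint8, device="cuda")

# Old call: out = torch.empty(((M + 31) // 32 * 32, N), dtype=dtypes.bf16)
#           c = gemm_a4w4(A, B, A_scale, B_scale, out, bpreshuffle=True)[:M]
# New call: no "out" argument, optional dtype, result already has M rows.
c = gemm_a4w4(A, B, A_scale, B_scale, dtype=dtypes.bf16, bpreshuffle=True)
assert c.shape == (M, N)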
From a7f30986376b10c1eb02d4dce7e4c85012464c9a Mon Sep 17 00:00:00 2001
From: valarLip
Date: Thu, 18 Dec 2025 13:11:58 +0000
Subject: [PATCH 2/3] update

---
 aiter/ops/gemm_op_a4w4.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/aiter/ops/gemm_op_a4w4.py b/aiter/ops/gemm_op_a4w4.py
index f85c9b823d..752c724645 100644
--- a/aiter/ops/gemm_op_a4w4.py
+++ b/aiter/ops/gemm_op_a4w4.py
@@ -116,13 +116,13 @@ def gemm_a4w4(
     ):
         splitK = 0 if splitK is None else splitK
         return gemm_a4w4_blockscale(
-            A.view(-1, A.shape[-1]), B, A_scale, B_scale, out, splitK=splitK
+            A.view(m, k // 2), B, A_scale, B_scale, out, splitK=splitK
         )[:m]
     assert (
         out.shape[0] % 32 == 0
     ), "Dim0 of gemm_a4w4_asm output needs to be padded to multiples of 32!"
     gemm_a4w4_asm(
-        A.view(-1, A.shape[-1]),
+        A.view(m, k // 2),
         B,
         A_scale,
         B_scale,

From 38f277d205bb0a1b42b8c8922153330b69fc3c8f Mon Sep 17 00:00:00 2001
From: chenjun
Date: Fri, 19 Dec 2025 02:37:48 +0000
Subject: [PATCH 3/3] format

---
 csrc/py_itfs_cu/asm_gemm_a4w4.cu | 2 +-
 op_tests/test_gemm_a4w4.py       | 5 ++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/csrc/py_itfs_cu/asm_gemm_a4w4.cu b/csrc/py_itfs_cu/asm_gemm_a4w4.cu
index 1a55902607..c76d9107d7 100644
--- a/csrc/py_itfs_cu/asm_gemm_a4w4.cu
+++ b/csrc/py_itfs_cu/asm_gemm_a4w4.cu
@@ -113,7 +113,7 @@ std::tuple get_heuristic_kernel(int M,
         if(cfg.bpreshuffle == bpreshuffle_en &&
            (cfg.splitK >= log2_k_split_en))
         {
-            // tile128x512 may mot support N % cfg.tile_N != 0
+            // tile128x512 may not support N % cfg.tile_N != 0
             if(cfg.tile_M != 128 || cfg.tile_N != 512 || (N % cfg.tile_N) == 0)
             {
                 std::vector splitK_list =
diff --git a/op_tests/test_gemm_a4w4.py b/op_tests/test_gemm_a4w4.py
index 7640ef4b78..99f62455ff 100644
--- a/op_tests/test_gemm_a4w4.py
+++ b/op_tests/test_gemm_a4w4.py
@@ -100,7 +100,6 @@ def test_gemm(dtype, M, N, K):
     x, x_scales_shuffle = quant_func(x, shuffle=True)
     w, w_scales_shuffle = quant_func(w, shuffle=True)
     wshuffle = shuffle_weight(w, layout=(16, 16))
-    bias_f32 = None
     x_scales = x_scales.view(torch.uint8)
     w_scales = w_scales.view(torch.uint8)
     a, avg_a = run_torch(x, w, x_scales, w_scales, dtype)
@@ -117,13 +116,13 @@ def test_gemm(dtype, M, N, K):
         w_scales_shuffle,
         bpreshuffle=True,
     )
-    err = checkAllclose(a, c[:M], msg="unified api")
+    err = checkAllclose(a, c, msg="unified api")
     ret["us"] = us
     ret["TFLOPS"] = M * N * K * 2 / us / 1e6
     ret["TB/s"] = (x.nbytes + w.nbytes) / us / 1e6
     ret["err"] = err
 
-    # kernelName = "" #"_ZN5aiter42f4gemm_bf16_per1x32Fp4_BpreShuffle_128x512E"
+    # kernelName = ""  # "_ZN5aiter42f4gemm_bf16_per1x32Fp4_BpreShuffle_128x512E"
     # log2_k_split = 1
     # out2 = torch.empty((M + 31) // 32 * 32, N, dtype=dtype)
     # d, us = run_gemm_asm(