diff --git a/csrc/py_itfs_cu/asm_gemm_a16w16.cu b/csrc/py_itfs_cu/asm_gemm_a16w16.cu index 4d6b723e4a..051fde35fb 100644 --- a/csrc/py_itfs_cu/asm_gemm_a16w16.cu +++ b/csrc/py_itfs_cu/asm_gemm_a16w16.cu @@ -254,7 +254,7 @@ torch::Tensor gemm_a16w16_asm(torch::Tensor& A, int gdy = (Mdim + SUBM - 1) / SUBM; int gdz = selectedksplit; - TORCH_CHECK(gdx <= 16, __func__, " gdx (", gdx, ") must be <= 16"); // 16 = 512/32 + TORCH_CHECK(gdy <= 16, __func__, " gdy (", gdy, ") must be <= 16"); // 16 = 512/32 // semaphore.fill_(selectedksplit); args.ptr_semaphore = (void*)semaphore.data_ptr(); diff --git a/gradlib/gradlib/GemmTuner.py b/gradlib/gradlib/GemmTuner.py index 8b83ca742b..2809aab50c 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/gradlib/GemmTuner.py @@ -24,13 +24,13 @@ import torch.nn.functional as F import aiter -from aiter import dtypes, logger +from aiter import dtypes, get_semaphore_workspace, logger from aiter.jit.core import AITER_CONFIG_GEMM_BF16, get_asm_dir from aiter.jit.utils.chip_info import get_cu_num, get_gfx from aiter.ops.shuffle import shuffle_weight +from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 as triton_gemm_a16w16 from aiter.utility.base_tuner import GemmCommonTuner from aiter.utility.mp_tuner import mp_tuner -from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 as triton_gemm_a16w16 aiter.hipb_create_extension() @@ -59,10 +59,12 @@ def call_hipb_mm( def run_gemm_bf16_asm( inp, w, out, bias=None, splitK=None, kernelName=None, bpreshuffle=False ): + sema = get_semaphore_workspace(inp.device) return aiter.gemm_a16w16_asm( inp, w, out, + sema, bias=bias, splitK=splitK, kernelName=kernelName,