Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/py_itfs_cu/asm_gemm_a16w16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ torch::Tensor gemm_a16w16_asm(torch::Tensor& A,
int gdy = (Mdim + SUBM - 1) / SUBM;
int gdz = selectedksplit;

TORCH_CHECK(gdx <= 16, __func__, " gdx (", gdx, ") must be <= 16"); // 16 = 512/32
TORCH_CHECK(gdy <= 16, __func__, " gdy (", gdy, ") must be <= 16"); // 16 = 512/32

// semaphore.fill_(selectedksplit);
args.ptr_semaphore = (void*)semaphore.data_ptr<uint32_t>();
Expand Down
6 changes: 4 additions & 2 deletions gradlib/gradlib/GemmTuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
import torch.nn.functional as F

import aiter
from aiter import dtypes, logger
from aiter import dtypes, get_semaphore_workspace, logger
from aiter.jit.core import AITER_CONFIG_GEMM_BF16, get_asm_dir
from aiter.jit.utils.chip_info import get_cu_num, get_gfx
from aiter.ops.shuffle import shuffle_weight
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 as triton_gemm_a16w16
from aiter.utility.base_tuner import GemmCommonTuner
from aiter.utility.mp_tuner import mp_tuner
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 as triton_gemm_a16w16

aiter.hipb_create_extension()

Expand Down Expand Up @@ -59,10 +59,12 @@ def call_hipb_mm(
def run_gemm_bf16_asm(
inp, w, out, bias=None, splitK=None, kernelName=None, bpreshuffle=False
):
sema = get_semaphore_workspace(inp.device)
return aiter.gemm_a16w16_asm(
inp,
w,
out,
sema,
bias=bias,
splitK=splitK,
kernelName=kernelName,
Expand Down