diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 3fc8a584b557..bd556d2976da 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -237,7 +237,9 @@ def schedule_batch_matmul_cublas(_, outs): @autotvm.register_topi_compute("batch_matmul_int8.cuda") -def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None): +def batch_matmul_int8( + cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """Batch Matmul operator for int8 on CUDA. Parameters @@ -258,11 +260,20 @@ def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None): out_dtype : Optional[str] Specifies the output data type for mixed precision batch matmul. + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. + Returns ------- output : tvm.te.Tensor 3-D with shape [batch, M, N] """ + del out_shape + # TODO(jcf94): Deal with different transpose combinations + assert not transpose_a and transpose_b if out_dtype is None: out_dtype = x.dtype diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py index a56d3c36ba33..5324302051ba 100644 --- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py +++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py @@ -29,9 +29,14 @@ @autotvm.register_topi_compute("batch_matmul_tensorcore.cuda") -def batch_matmul_tensorcore(cfg, x, y, out_shape=None, out_dtype=None): +def batch_matmul_tensorcore( + cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """batch matmul tensorcore operator on cuda""" - # todo: deal with out_shape for broadcast, liuxin.ai + # TODO(jcf94): Deal with different transpose combinations + assert not transpose_a and transpose_b + # TODO(liuxin.ai): Deal with out_shape for broadcast + del out_shape return batch_matmul_tensorcore_cuda(x, y, out_dtype) diff --git a/python/tvm/topi/rocm/batch_matmul.py b/python/tvm/topi/rocm/batch_matmul.py index 7f35f4b55620..53b51eedf6d9 100644 --- a/python/tvm/topi/rocm/batch_matmul.py +++ b/python/tvm/topi/rocm/batch_matmul.py @@ -23,7 +23,9 @@ @autotvm.register_topi_compute("batch_matmul_rocblas.rocm") -def batch_matmul_rocblas(cfg, x, y, out_shape=None): +def batch_matmul_rocblas( + cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """Computes matrix multiplication of `x` and `y` via rocblas when `x` and `y` are batched matrices. @@ -40,12 +42,13 @@ def batch_matmul_rocblas(cfg, x, y, out_shape=None): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ + del out_dtype batch, M, K = get_const_tuple(x.shape) _, N, _ = get_const_tuple(y.shape) if out_shape is not None: assert out_shape[0] == batch, "Input and output batch sizes must match" assert out_shape[1] == M and out_shape[2] == N, "Invalid output shape" - result = rocblas.batch_matmul(x, y, False, True) + result = rocblas.batch_matmul(x, y, transpose_a, transpose_b) cfg.add_flop(batch * M * N * K * 2) return result