From 253e062ee4c86e9e61fed8aa882f1e643c6c256c Mon Sep 17 00:00:00 2001 From: jcf94 Date: Wed, 18 Aug 2021 23:58:52 +0800 Subject: [PATCH 1/2] [FIX] Bug fix for batch_matmul parameters mismatch --- python/tvm/topi/cuda/batch_matmul.py | 13 ++++++++++++- python/tvm/topi/cuda/batch_matmul_tensorcore.py | 9 +++++++-- python/tvm/topi/rocm/batch_matmul.py | 7 +++++-- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 3fc8a584b557..d7d64f7aa2f8 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -237,7 +237,9 @@ def schedule_batch_matmul_cublas(_, outs): @autotvm.register_topi_compute("batch_matmul_int8.cuda") -def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None): +def batch_matmul_int8( + cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """Batch Matmul operator for int8 on CUDA. Parameters ---------- @@ -258,11 +260,20 @@ def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None): out_dtype : Optional[str] Specifies the output data type for mixed precision batch matmul. + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format.
+ Returns ------- output : tvm.te.Tensor 3-D with shape [batch, M, N] """ + del out_shape + # TODO(jcf94): Deal with different transpose combinations + assert transpose_a == False and transpose_b == True if out_dtype is None: out_dtype = x.dtype diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py index a56d3c36ba33..289f300d1d05 100644 --- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py +++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py @@ -29,9 +29,14 @@ @autotvm.register_topi_compute("batch_matmul_tensorcore.cuda") -def batch_matmul_tensorcore(cfg, x, y, out_shape=None, out_dtype=None): +def batch_matmul_tensorcore( + cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """batch matmul tensorcore operator on cuda""" - # todo: deal with out_shape for broadcast, liuxin.ai + # TODO(jcf94): Deal with different transpose combinations + assert transpose_a == False and transpose_b == True + # TODO(liuxin.ai): Deal with out_shape for broadcast + del out_shape return batch_matmul_tensorcore_cuda(x, y, out_dtype) diff --git a/python/tvm/topi/rocm/batch_matmul.py b/python/tvm/topi/rocm/batch_matmul.py index 7f35f4b55620..53b51eedf6d9 100644 --- a/python/tvm/topi/rocm/batch_matmul.py +++ b/python/tvm/topi/rocm/batch_matmul.py @@ -23,7 +23,9 @@ @autotvm.register_topi_compute("batch_matmul_rocblas.rocm") -def batch_matmul_rocblas( + cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """Computes matrix multiplication of `x` and `y` via rocblas when `x` and `y` are batched matrices.
@@ -40,12 +42,13 @@ def batch_matmul_rocblas(cfg, x, y, out_shape=None): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ + del out_dtype batch, M, K = get_const_tuple(x.shape) _, N, _ = get_const_tuple(y.shape) if out_shape is not None: assert out_shape[0] == batch, "Input and output batch sizes must match" assert out_shape[1] == M and out_shape[2] == N, "Invalid output shape" - result = rocblas.batch_matmul(x, y, False, True) + result = rocblas.batch_matmul(x, y, transpose_a, transpose_b) cfg.add_flop(batch * M * N * K * 2) return result From e89d8303151eb95f02b65079a3f77b032d1b6bdd Mon Sep 17 00:00:00 2001 From: jcf94 Date: Thu, 19 Aug 2021 00:10:43 +0800 Subject: [PATCH 2/2] Update --- python/tvm/topi/cuda/batch_matmul.py | 2 +- python/tvm/topi/cuda/batch_matmul_tensorcore.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index d7d64f7aa2f8..bd556d2976da 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -273,7 +273,7 @@ def batch_matmul_int8( """ del out_shape # TODO(jcf94): Deal with different transpose combinations - assert transpose_a == False and transpose_b == True + assert not transpose_a and transpose_b if out_dtype is None: out_dtype = x.dtype diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py index 289f300d1d05..5324302051ba 100644 --- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py +++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py @@ -34,7 +34,7 @@ def batch_matmul_tensorcore( ): """batch matmul tensorcore operator on cuda""" # TODO(jcf94): Deal with different transpose combinations - assert transpose_a == False and transpose_b == True + assert not transpose_a and transpose_b # TODO(liuxin.ai): Deal with out_shape for broadcast del out_shape return batch_matmul_tensorcore_cuda(x, y, out_dtype)