From e1116a37576152261e56b1b57a71428108a54db9 Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Tue, 20 Oct 2020 03:39:13 -0700 Subject: [PATCH 1/2] Update batch_matmul.py Update batch_matmul.py --- python/tvm/topi/cuda/batch_matmul.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index bb060b3ad8a7..34359bd46f1b 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -138,7 +138,7 @@ def _callback(op): return s -def batch_matmul_cublas(x, y): +def batch_matmul_cublas(x, y, _): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. @@ -149,6 +149,7 @@ def batch_matmul_cublas(x, y): y : tvm.te.Tensor 3-D with shape [batch, N, K] + _ : None Returns ------- From 5a91e8f7aedb8b0c7a0028e29c1d55c5c3545a5f Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Tue, 20 Oct 2020 10:48:41 -0700 Subject: [PATCH 2/2] fix --- python/tvm/topi/cuda/batch_matmul.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 34359bd46f1b..ee94420066dd 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -138,7 +138,7 @@ def _callback(op): return s -def batch_matmul_cublas(x, y, _): +def batch_matmul_cublas(x, y, out_shape=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. @@ -149,7 +149,9 @@ def batch_matmul_cublas(x, y, _): y : tvm.te.Tensor 3-D with shape [batch, N, K] - _ : None + + out_shape : None + The output shape Returns -------