diff --git a/python/tvm/contrib/cblas.py b/python/tvm/contrib/cblas.py
index 58bf933d44b8..1dfeb801b370 100644
--- a/python/tvm/contrib/cblas.py
+++ b/python/tvm/contrib/cblas.py
@@ -72,7 +72,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs
     C: Tensor
         The result tensor.
     """
-    b = lhs.shape[0]
+    b = te.max(lhs.shape[0], rhs.shape[0])
     n = lhs.shape[2] if transa else lhs.shape[1]
     m = rhs.shape[1] if transb else rhs.shape[2]
     return te.extern(
diff --git a/python/tvm/contrib/mkl.py b/python/tvm/contrib/mkl.py
index c6e340619ef8..449d660c9027 100644
--- a/python/tvm/contrib/mkl.py
+++ b/python/tvm/contrib/mkl.py
@@ -105,7 +105,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs
     C: Tensor
         The result tensor.
     """
-    b = lhs.shape[0]
+    b = te.max(lhs.shape[0], rhs.shape[0])
     n = lhs.shape[2] if transa else lhs.shape[1]
     m = rhs.shape[1] if transb else rhs.shape[2]
     return te.extern(
diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py
index df480123375d..37bdd09d6ca6 100644
--- a/python/tvm/topi/x86/batch_matmul.py
+++ b/python/tvm/topi/x86/batch_matmul.py
@@ -139,7 +139,7 @@ def _default_batch_matmul_config(cfg, M, N, K):
 def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
-    data in batch, using one of BLAS libraries.
+    data in batch, using one of BLAS libraries. Supports broadcasting in batch dimension.
 
     Parameters
     ----------
@@ -162,10 +162,10 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
     assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
     XB, M, XK = get_const_tuple(x.shape)
     YB, N, YK = get_const_tuple(y.shape)
-    assert XB == YB, "batch dimension doesn't match"
+    assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match"
     assert XK == YK, "shapes of x and y is inconsistent"
     if out_shape is not None:
-        assert out_shape[0] == XB, "got invalid output shape"
+        assert out_shape[0] in (XB, YB), "got invalid output shape"
         assert out_shape[1] == M, "got invalid output shape"
         assert out_shape[2] == N, "got invalid output shape"
     cfg.add_flop(XB * M * N * XK * 2)
diff --git a/tests/python/contrib/test_cblas.py b/tests/python/contrib/test_cblas.py
index b4fc2b283369..2b99879d8227 100644
--- a/tests/python/contrib/test_cblas.py
+++ b/tests/python/contrib/test_cblas.py
@@ -158,13 +158,14 @@ def test_quantized_matmul_add():
 def verify_batch_matmul(
-    batch, m, l, n, lib, transa=False, transb=False, iterative=False, dtype="float32"
+    batch_a, batch_b, m, l, n, lib, transa=False, transb=False, iterative=False, dtype="float32"
 ):
-    ashape = (batch, l, n) if transa else (batch, n, l)
-    bshape = (batch, m, l) if transb else (batch, l, m)
+    batch = max(batch_a, batch_b)
+    ashape = (batch_a, l, n) if transa else (batch_a, n, l)
+    bshape = (batch_b, m, l) if transb else (batch_b, l, m)
     A = te.placeholder(ashape, name="A", dtype=dtype)
     B = te.placeholder(bshape, name="B", dtype=dtype)
-    C = cblas.batch_matmul(A, B, transa, transb)
+    C = lib.batch_matmul(A, B, transa, transb)
     D = te.compute(C.shape, lambda k, i, j: C[k, i, j], name="D")
     s = te.create_schedule(D.op)
@@ -207,24 +208,32 @@ def verify(target="llvm"):
 def test_batch_matmul():
-    verify_batch_matmul(16, 235, 128, 1024, cblas)
-    verify_batch_matmul(16, 235, 128, 1024, cblas, True, False)
-    verify_batch_matmul(16, 235, 128, 1024, cblas, False, True)
-    verify_batch_matmul(16, 235, 128, 1024, cblas, True, True)
-    verify_batch_matmul(16, 235, 128, 1024, mkl)
-    verify_batch_matmul(16, 235, 128, 1024, mkl, True, False)
-    verify_batch_matmul(16, 235, 128, 1024, mkl, False, True)
-    verify_batch_matmul(16, 235, 128, 1024, mkl, True, True)
-    verify_batch_matmul(1, 1, 16, 3, cblas)
-    verify_batch_matmul(1, 1, 16, 3, cblas, True, False)
-    verify_batch_matmul(1, 1, 16, 3, cblas, False, False)
-    verify_batch_matmul(1, 1, 16, 3, cblas, True, True)
-    verify_batch_matmul(1, 1, 16, 3, cblas, iterative=True)
-    verify_batch_matmul(1, 1, 16, 3, mkl)
-    verify_batch_matmul(1, 1, 16, 3, mkl, True, False)
-    verify_batch_matmul(1, 1, 16, 3, mkl, False, False)
-    verify_batch_matmul(1, 1, 16, 3, mkl, True, True)
-    verify_batch_matmul(1, 1, 16, 3, mkl, iterative=True)
+    verify_batch_matmul(16, 16, 235, 128, 1024, cblas)
+    verify_batch_matmul(16, 16, 235, 128, 1024, cblas, True, False)
+    verify_batch_matmul(16, 16, 235, 128, 1024, cblas, False, True)
+    verify_batch_matmul(16, 16, 235, 128, 1024, cblas, True, True)
+    verify_batch_matmul(16, 16, 235, 128, 1024, mkl)
+    verify_batch_matmul(16, 16, 235, 128, 1024, mkl, True, False)
+    verify_batch_matmul(16, 16, 235, 128, 1024, mkl, False, True)
+    verify_batch_matmul(16, 16, 235, 128, 1024, mkl, True, True)
+    verify_batch_matmul(16, 1, 235, 128, 1024, cblas)
+    verify_batch_matmul(1, 16, 235, 128, 1024, cblas)
+    verify_batch_matmul(16, 1, 235, 128, 1024, cblas, iterative=True)
+    verify_batch_matmul(1, 16, 235, 128, 1024, cblas, iterative=True)
+    verify_batch_matmul(16, 1, 235, 128, 1024, mkl)
+    verify_batch_matmul(1, 16, 235, 128, 1024, mkl)
+    verify_batch_matmul(16, 1, 235, 128, 1024, mkl, iterative=True)
+    verify_batch_matmul(1, 16, 235, 128, 1024, mkl, iterative=True)
+    verify_batch_matmul(1, 1, 1, 16, 3, cblas)
+    verify_batch_matmul(1, 1, 1, 16, 3, cblas, True, False)
+    verify_batch_matmul(1, 1, 1, 16, 3, cblas, False, False)
+    verify_batch_matmul(1, 1, 1, 16, 3, cblas, True, True)
+    verify_batch_matmul(1, 1, 1, 16, 3, cblas, iterative=True)
+    verify_batch_matmul(1, 1, 1, 16, 3, mkl)
+    verify_batch_matmul(1, 1, 1, 16, 3, mkl, True, False)
+    verify_batch_matmul(1, 1, 1, 16, 3, mkl, False, False)
+    verify_batch_matmul(1, 1, 1, 16, 3, mkl, True, True)
+    verify_batch_matmul(1, 1, 1, 16, 3, mkl, iterative=True)
 
 
 if __name__ == "__main__":
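Not part of the patch: a minimal sketch of the batch-dimension broadcasting this change enables, mirroring the new verify_batch_matmul(16, 1, 235, 128, 1024, cblas) test case. It assumes TVM is built with a CBLAS backend so the tvm.contrib.cblas.batch_matmul packed function is registered; otherwise the runtime check below skips the run.

import numpy as np
import tvm
import tvm.testing
from tvm import te
from tvm.contrib import cblas

# lhs has batch 16, rhs has batch 1; with this patch the extern op's output
# batch is te.max(lhs.shape[0], rhs.shape[0]), i.e. (16, 235, 1024).
batch_a, batch_b, m, k, n = 16, 1, 235, 128, 1024
A = te.placeholder((batch_a, m, k), name="A", dtype="float32")
B = te.placeholder((batch_b, k, n), name="B", dtype="float32")
C = cblas.batch_matmul(A, B, False, False)

s = te.create_schedule(C.op)
if tvm.get_global_func("tvm.contrib.cblas.batch_matmul", True) is None:
    print("skip: TVM was not built with a CBLAS library")
else:
    f = tvm.build(s, [A, B, C], "llvm")
    dev = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=(batch_a, m, k)).astype("float32"), dev)
    b = tvm.nd.array(np.random.uniform(size=(batch_b, k, n)).astype("float32"), dev)
    c = tvm.nd.array(np.zeros((batch_a, m, n), dtype="float32"), dev)
    f(a, b, c)
    # np.matmul broadcasts the unit batch the same way, so it serves as the reference.
    tvm.testing.assert_allclose(c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)

The unit batch is expanded rather than rejected because the topi assert now accepts XB == 1 or YB == 1, and the extern output shape uses the larger of the two batch extents.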