Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/contrib/cblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs
C: Tensor
The result tensor.
"""
b = lhs.shape[0]
b = te.max(lhs.shape[0], rhs.shape[0])
n = lhs.shape[2] if transa else lhs.shape[1]
m = rhs.shape[1] if transb else rhs.shape[2]
return te.extern(
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/mkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs
C: Tensor
The result tensor.
"""
b = lhs.shape[0]
b = te.max(lhs.shape[0], rhs.shape[0])
n = lhs.shape[2] if transa else lhs.shape[1]
m = rhs.shape[1] if transb else rhs.shape[2]
return te.extern(
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/topi/x86/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _default_batch_matmul_config(cfg, M, N, K):

def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch, using one of BLAS libraries.
data in batch, using one of BLAS libraries. Supports broadcasting in batch dimension.

Parameters
----------
Expand All @@ -162,10 +162,10 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
XB, M, XK = get_const_tuple(x.shape)
YB, N, YK = get_const_tuple(y.shape)
assert XB == YB, "batch dimension doesn't match"
assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match"
assert XK == YK, "shapes of x and y is inconsistent"
if out_shape is not None:
assert out_shape[0] == XB, "got invalid output shape"
assert out_shape[0] in (XB, YB), "got invalid output shape"
assert out_shape[1] == M, "got invalid output shape"
assert out_shape[2] == N, "got invalid output shape"
cfg.add_flop(XB * M * N * XK * 2)
Expand Down
53 changes: 31 additions & 22 deletions tests/python/contrib/test_cblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,14 @@ def test_quantized_matmul_add():


def verify_batch_matmul(
batch, m, l, n, lib, transa=False, transb=False, iterative=False, dtype="float32"
batch_a, batch_b, m, l, n, lib, transa=False, transb=False, iterative=False, dtype="float32"
):
ashape = (batch, l, n) if transa else (batch, n, l)
bshape = (batch, m, l) if transb else (batch, l, m)
batch = max(batch_a, batch_b)
ashape = (batch_a, l, n) if transa else (batch_a, n, l)
bshape = (batch_b, m, l) if transb else (batch_b, l, m)
A = te.placeholder(ashape, name="A", dtype=dtype)
B = te.placeholder(bshape, name="B", dtype=dtype)
C = cblas.batch_matmul(A, B, transa, transb)
C = lib.batch_matmul(A, B, transa, transb)
D = te.compute(C.shape, lambda k, i, j: C[k, i, j], name="D")
s = te.create_schedule(D.op)

Expand Down Expand Up @@ -207,24 +208,32 @@ def verify(target="llvm"):


def test_batch_matmul():
verify_batch_matmul(16, 235, 128, 1024, cblas)
verify_batch_matmul(16, 235, 128, 1024, cblas, True, False)
verify_batch_matmul(16, 235, 128, 1024, cblas, False, True)
verify_batch_matmul(16, 235, 128, 1024, cblas, True, True)
verify_batch_matmul(16, 235, 128, 1024, mkl)
verify_batch_matmul(16, 235, 128, 1024, mkl, True, False)
verify_batch_matmul(16, 235, 128, 1024, mkl, False, True)
verify_batch_matmul(16, 235, 128, 1024, mkl, True, True)
verify_batch_matmul(1, 1, 16, 3, cblas)
verify_batch_matmul(1, 1, 16, 3, cblas, True, False)
verify_batch_matmul(1, 1, 16, 3, cblas, False, False)
verify_batch_matmul(1, 1, 16, 3, cblas, True, True)
verify_batch_matmul(1, 1, 16, 3, cblas, iterative=True)
verify_batch_matmul(1, 1, 16, 3, mkl)
verify_batch_matmul(1, 1, 16, 3, mkl, True, False)
verify_batch_matmul(1, 1, 16, 3, mkl, False, False)
verify_batch_matmul(1, 1, 16, 3, mkl, True, True)
verify_batch_matmul(1, 1, 16, 3, mkl, iterative=True)
verify_batch_matmul(16, 16, 235, 128, 1024, cblas)
verify_batch_matmul(16, 16, 235, 128, 1024, cblas, True, False)
verify_batch_matmul(16, 16, 235, 128, 1024, cblas, False, True)
verify_batch_matmul(16, 16, 235, 128, 1024, cblas, True, True)
verify_batch_matmul(16, 16, 235, 128, 1024, mkl)
verify_batch_matmul(16, 16, 235, 128, 1024, mkl, True, False)
verify_batch_matmul(16, 16, 235, 128, 1024, mkl, False, True)
verify_batch_matmul(16, 16, 235, 128, 1024, mkl, True, True)
verify_batch_matmul(16, 1, 235, 128, 1024, cblas)
verify_batch_matmul(1, 16, 235, 128, 1024, cblas)
verify_batch_matmul(16, 1, 235, 128, 1024, cblas, iterative=True)
verify_batch_matmul(1, 16, 235, 128, 1024, cblas, iterative=True)
verify_batch_matmul(16, 1, 235, 128, 1024, mkl)
verify_batch_matmul(1, 16, 235, 128, 1024, mkl)
verify_batch_matmul(16, 1, 235, 128, 1024, mkl, iterative=True)
verify_batch_matmul(1, 16, 235, 128, 1024, mkl, iterative=True)
verify_batch_matmul(1, 1, 1, 16, 3, cblas)
verify_batch_matmul(1, 1, 1, 16, 3, cblas, True, False)
verify_batch_matmul(1, 1, 1, 16, 3, cblas, False, False)
verify_batch_matmul(1, 1, 1, 16, 3, cblas, True, True)
verify_batch_matmul(1, 1, 1, 16, 3, cblas, iterative=True)
verify_batch_matmul(1, 1, 1, 16, 3, mkl)
verify_batch_matmul(1, 1, 1, 16, 3, mkl, True, False)
verify_batch_matmul(1, 1, 1, 16, 3, mkl, False, False)
verify_batch_matmul(1, 1, 1, 16, 3, mkl, True, True)
verify_batch_matmul(1, 1, 1, 16, 3, mkl, iterative=True)


if __name__ == "__main__":
Expand Down