diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index 9c5848129397..fb97ef9b58e2 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -18,7 +18,8 @@ # pylint: disable=invalid-name import tvm from tvm import te, auto_scheduler -from ..utils import get_const_tuple +from tvm import tir +from ..util import get_const_tuple def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""): @@ -61,9 +62,17 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""): _, M, K = x.shape k = te.reduce_axis((0, K), name="k") if oshape is None: - assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match" - assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant" - batch = te.max(XB, YB) + if isinstance(XB, int) and isinstance(YB, int): + assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match" + batch = max(XB, YB) + elif isinstance(XB, tir.expr.Var): + batch = XB + else: + batch = YB + + if isinstance(x_shape[2], int) and isinstance(y_shape[2], int): + assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant" + N = y.shape[1] oshape = (batch, M, N)