diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 3c7574562676..694001f612e7 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -1003,16 +1003,26 @@ struct DenseAttrs : public tvm::AttrsNode { } }; -/*! \brief Attributes for batch matmul operator */ +/*! \brief Attributes for batch matmul operator. */ struct BatchMatmulAttrs : public tvm::AttrsNode { - tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite DataType out_dtype; + bool transpose_a; + bool transpose_b; + tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite TVM_DECLARE_ATTRS(BatchMatmulAttrs, "relay.attrs.BatchMatmulAttrs") { // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); + + TVM_ATTR_FIELD(transpose_a) + .set_default(false) + .describe("Whether the first input tensor is in transposed format."); + + TVM_ATTR_FIELD(transpose_b) + .set_default(false) + .describe("Whether the second input tensor is in transposed format."); } }; diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index c6857e7773d4..d35e0e1c203d 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -52,6 +52,11 @@ # However, please note that `nn.matmul` is in experimental so it may have some performance # issues. "use_dense": True, + # By default, TVM converts `tf.batch_matmul` to `transpose(weight) + nn.batch_matmul_NT`. + # Change this flag to False to directly convert to `nn.batch_matmul`. + # Note that `nn.batch_matmul` with a format other than NT is experimental and may have some + # performance issues. + "use_nt_batch_matmul": True, } # compatible operators that do NOT require any conversion. @@ -1214,7 +1219,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): return func, self._params -def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op=True): +def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, convert_config=None): """Load tensorflow graph which is a python tensorflow graph object into relay. The companion parameters will be handled automatically. @@ -1232,10 +1237,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op outputs : List of output tensor names (Optional) if not specified then the last node is assumed as graph output. - use_dense_op : bool (Optional) = True - Ture to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`. - The `nn.dense` op requires the data tensor to be non-transposed and weight tensor to be - transposed, may insert extra `transpose` to the original graph. + convert_config : Optional[Dict[str, Any]] + Default config: + use_dense : bool = True + True to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`. + The `nn.dense` op requires the data tensor to be non-transposed and the weight tensor + to be transposed, so extra `transpose` ops may be inserted into the original graph. + use_nt_batch_matmul : bool = True + True to convert `tf.batch_matmul` to `nn.batch_matmul` strictly in NT format + (transpose_a=False, transpose_b=True).
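For reference, a minimal usage sketch of the new `convert_config` argument described above (illustration only, not part of the patch; the frozen `graph_def` and the input name/shape are hypothetical placeholders):

```python
# Illustrative sketch only: `graph_def` is assumed to be a frozen TensorFlow
# GraphDef obtained elsewhere; the input name and shape below are placeholders.
from tvm import relay

def convert(graph_def):
    mod, params = relay.frontend.from_tensorflow(
        graph_def,
        layout="NHWC",
        shape={"input": (1, 8, 16)},
        # Keep tf.matmul -> nn.dense, but let tf.batch_matmul keep its adj_x/adj_y
        # flags instead of being rewritten into NT-format nn.batch_matmul.
        convert_config={"use_dense": True, "use_nt_batch_matmul": False},
    )
    return mod, params
```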
Returns ------- @@ -1246,7 +1256,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op Dict of converted parameters stored in tvm.nd.NDArray format """ global TF_DEFAULT_CONFIGS - TF_DEFAULT_CONFIGS["use_dense"] = use_dense_op + if convert_config is not None: + TF_DEFAULT_CONFIGS.update(convert_config) g = GraphProto() mod, params = g.from_tensorflow(graph, layout, shape, outputs) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index ba0fcca0197d..2d8188202a66 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -1137,6 +1137,8 @@ def _impl(inputs, attr, params, mod): def _batch_matmul(): def _impl(inputs, attr, params, mod): + from .tensorflow import TF_DEFAULT_CONFIGS + input_x = inputs[0] input_y = inputs[1] orig_shape_x = _infer_shape(input_x, mod) @@ -1173,9 +1175,16 @@ def _impl(inputs, attr, params, mod): input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1])) adj_x = attr["adj_x"] adj_y = attr["adj_y"] - input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x - input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y - ret = get_relay_op("batch_matmul")(input_x, input_y) + + if TF_DEFAULT_CONFIGS["use_nt_batch_matmul"]: + # Strictly convert all batch_matmul to NT format + input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x + input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y + ret = get_relay_op("batch_matmul")(input_x, input_y) + else: + ret = get_relay_op("batch_matmul")( + input_x, input_y, transpose_a=adj_x, transpose_b=adj_y + ) # reshape result back to n-dimensional if ndim > 3: diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index fa2772c1299f..3793f947c5cc 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -590,11 +590,59 @@ def batch_matmul_grad(orig, grad): GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk """ lhs, rhs = orig.args + if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True): + # ki, jk -> ij + # jk, ij -> ki + # ij, ki -> jk + return [ + collapse_sum_like(_nn.batch_matmul(rhs, grad, transpose_a=True, transpose_b=True), lhs), + collapse_sum_like(_nn.batch_matmul(grad, lhs, transpose_a=True, transpose_b=True), rhs), + ] + if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False): + # ki, kj -> ij + # kj, ij -> ki + # ki, ij -> kj + return [ + collapse_sum_like( + _nn.batch_matmul(rhs, grad, transpose_a=False, transpose_b=True), lhs + ), + collapse_sum_like( + _nn.batch_matmul(lhs, grad, transpose_a=False, transpose_b=False), rhs + ), + ] + if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True): + # ik, jk -> ij + # ij, jk -> ik + # ij, ik -> jk + # Keep using NT format batch_matmul here for not involving extra ops + # TODO(jcf94): Merge all to normal batch_matmul when it is finally ready + return [ + collapse_sum_like( + _nn.batch_matmul( + grad, + transpose(rhs, [0, 2, 1]), + transpose_a=False, + transpose_b=True, + ), + lhs, + ), + collapse_sum_like( + _nn.batch_matmul( + transpose(grad, [0, 2, 1]), + transpose(lhs, [0, 2, 1]), + transpose_a=False, + transpose_b=True, + ), + rhs, + ), + ] + # (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False) + # ik, kj -> ij + # ij, kj -> ik + # ik, ij -> kj return [ - collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs), - 
collapse_sum_like( - _nn.batch_matmul(transpose(grad, [0, 2, 1]), transpose(lhs, [0, 2, 1])), rhs - ), + collapse_sum_like(_nn.batch_matmul(grad, rhs, transpose_a=False, transpose_b=True), lhs), + collapse_sum_like(_nn.batch_matmul(lhs, grad, transpose_a=True, transpose_b=False), rhs), ] diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 6e757ea6fa74..96cef8bc3588 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -1276,14 +1276,11 @@ def dense_pack_shape_func(attrs, inputs, _): @script -def _batch_matmul_shape_func(data_shape, weight_shape): - out = output_tensor((data_shape.shape[0],), "int64") - for i in const_range(out.shape[0] - 1): - if i == 0: - out[i] = max(data_shape[i], weight_shape[i]) - else: - out[i] = data_shape[i] - out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2] +def _batch_matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b): + out = output_tensor((tensor_a_shape.shape[0],), "int64") + out[0] = max(tensor_a_shape[0], tensor_b_shape[0]) + out[1] = tensor_a_shape[2] if transpose_a else tensor_a_shape[1] + out[2] = tensor_b_shape[1] if transpose_b else tensor_b_shape[2] return out @@ -1291,9 +1288,16 @@ def _batch_matmul_shape_func(data_shape, weight_shape): @reg.register_shape_func("nn.batch_matmul", False) def batch_matmul_shape_func(attrs, inputs, _): """ - Shape function for dense op. + Shape function for batch matmul op. """ - ret = [_batch_matmul_shape_func(inputs[0], inputs[1])] + ret = [ + _batch_matmul_shape_func( + inputs[0], + inputs[1], + expr.IntImm("bool", attrs.transpose_a), + expr.IntImm("bool", attrs.transpose_b), + ) + ] return ret diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 4c94102275bb..64b397a4d4f9 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -2137,32 +2137,40 @@ def group_norm(data, gamma, beta, num_groups, axis=1, epsilon=1e-5, center=True, return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale) -def batch_matmul(x, y, out_dtype=""): +def batch_matmul(tensor_a, tensor_b, out_dtype="", transpose_a=False, transpose_b=True): r""" - Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data - in batch. + Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + + Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. .. math:: - \mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T) + \mbox{batch_matmul}(A, B)[i, :, :] = \mbox{matmul}(A[i, :, :], B[i, :, :]) Parameters ---------- - x : tvm.relay.Expr + tensor_a : tvm.relay.Expr The first input. - y : tvm.relay.Expr + tensor_b : tvm.relay.Expr The second input. - out_dtype : str, optional - Specifies the output data type for mixed precision batch matmul + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. Returns ------- result: tvm.relay.Expr The computed result. 
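The (False, False) gradient rules introduced earlier in this hunk can be sanity-checked with NumPy (illustrative only, not part of the patch): for C = A·B, dA = dC·Bᵀ and dB = Aᵀ·dC, which is exactly `batch_matmul(grad, rhs, transpose_b=True)` and `batch_matmul(lhs, grad, transpose_a=True)`.

```python
# NumPy sanity check (illustrative, not part of the patch) of the NN-case
# gradient rules in batch_matmul_grad:
#   forward:  C[b] = A[b] @ B[b]          (ik, kj -> ij)
#   backward: dA[b] = dC[b] @ B[b].T      (ij, kj -> ik) == batch_matmul(grad, rhs, NT)
#             dB[b] = A[b].T @ dC[b]      (ik, ij -> kj) == batch_matmul(lhs, grad, TN)
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4))      # lhs  [batch, i, k]
b = rng.standard_normal((2, 4, 5))      # rhs  [batch, k, j]
grad = rng.standard_normal((2, 3, 5))   # upstream gradient [batch, i, j]

grad_a = np.matmul(grad, b.transpose(0, 2, 1))   # [batch, i, k]
grad_b = np.matmul(a.transpose(0, 2, 1), grad)   # [batch, k, j]

assert np.allclose(grad_a, np.einsum("bij,bkj->bik", grad, b))
assert np.allclose(grad_b, np.einsum("bik,bij->bkj", a, grad))
```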
""" - return _make.batch_matmul(x, y, out_dtype) + return _make.batch_matmul(tensor_a, tensor_b, out_dtype, transpose_a, transpose_b) # pylint: disable=no-else-return,inconsistent-return-statements diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 2d185bcee798..507dd9371a97 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -74,6 +74,11 @@ class DenseAttrs(Attrs): """Attributes for nn.dense""" +@tvm._ffi.register_object("relay.attrs.BatchMatmulAttrs") +class BatchMatmulAttrs(Attrs): + """Attributes for nn.batch_matmul""" + + @tvm._ffi.register_object("relay.attrs.SoftmaxAttrs") class SoftmaxAttrs(Attrs): """Attributes for nn.softmax""" diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 1f999a810164..ba47ae7bc4f1 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -819,7 +819,13 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): """batch_matmul cuda strategy""" strategy = _op.OpStrategy() x, y = inputs - if x.dtype == "int8" and y.dtype == "int8" and out_type.dtype == "int32": + if ( + x.dtype == "int8" + and y.dtype == "int8" + and out_type.dtype == "int32" + and not attrs["transpose_a"] + and attrs["transpose_b"] + ): strategy.add_implementation( wrap_compute_batch_matmul(topi.cuda.batch_matmul_int8, need_out_dtype=True), wrap_topi_schedule(topi.cuda.schedule_batch_matmul_int8), @@ -840,7 +846,12 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): name="batch_matmul_cublas.cuda", plevel=15, ) - if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target): + if ( + target.kind.name == "cuda" + and nvcc.have_tensorcore(target=target) + and not attrs["transpose_a"] + and attrs["transpose_b"] + ): x, y = inputs _, M, K = get_const_tuple(x.shape) _, N, K = get_const_tuple(y.shape) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 3348d8033904..9c756f201721 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -799,10 +799,11 @@ def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, ne def _compute_batch_matmul(attrs, inputs, out_type): args = [inputs[0], inputs[1], out_type.shape] + args.append(out_type.dtype if need_out_dtype else None) + args.append(attrs.transpose_a) + args.append(attrs.transpose_b) if need_auto_scheduler_layout: args.append(get_auto_scheduler_rewritten_layout(attrs)) - if need_out_dtype: - args.append(out_type.dtype) return [topi_compute(*args)] return _compute_batch_matmul diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index fb91912f29a0..3fc8a584b557 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -27,9 +27,49 @@ @autotvm.register_topi_compute("batch_matmul.cuda") -def batch_matmul(cfg, x, y, out_shape=None): - """Compute conv2d with NCHW layout""" - return nn.batch_matmul(x, y) +def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True): + """Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + + Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. + + Parameters + ---------- + cfg : ConfigSpace + Autotvm tuning space config file. + + tensor_a : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. 
+ + tensor_b : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. + + Returns + ------- + output : tvm.te.Tensor + 3-D with shape [batch, M, N] + """ + return nn.batch_matmul( + x, + y, + oshape=out_shape, + out_dtype=out_dtype, + transpose_a=transpose_a, + transpose_b=transpose_b, + ) @autotvm.register_topi_schedule("batch_matmul.cuda") @@ -140,31 +180,54 @@ def _callback(op): @autotvm.register_topi_compute("batch_matmul_cublas.cuda") -def batch_matmul_cublas(cfg, x, y, out_shape=None): - """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. +def batch_matmul_cublas( + cfg, x, y, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): + """Compute batch matrix multiplication of `x` and `y`. + + Both `x` and `y` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. Parameters ---------- + cfg : ConfigSpace + Autotvm tuning space config file. + x : tvm.te.Tensor - 3-D with shape [batch, M, K] + 3-D with shape [batch, M, K] or [batch, K, M]. y : tvm.te.Tensor - 3-D with shape [batch, N, K] + 3-D with shape [batch, K, N] or [batch, N, K]. - out_shape : None - The output shape + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. Returns ------- output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - b, m, k = get_const_tuple(x.shape) - b, n, k = get_const_tuple(y.shape) + if transpose_a: + b, k, m = get_const_tuple(x.shape) + else: + b, m, k = get_const_tuple(x.shape) + if transpose_b: + b, n, k = get_const_tuple(y.shape) + else: + b, k, n = get_const_tuple(y.shape) if all([isinstance(s, int) for s in [b, m, n, k]]): cfg.add_flop(b * m * k * n * 2) - return cublas.batch_matmul(x, y, False, True) + return cublas.batch_matmul(x, y, transa=transpose_a, transb=transpose_b) @autotvm.register_topi_schedule("batch_matmul_cublas.cuda") @@ -175,7 +238,31 @@ def schedule_batch_matmul_cublas(_, outs): @autotvm.register_topi_compute("batch_matmul_int8.cuda") def batch_matmul_int8(cfg, x, y, out_shape=None, out_dtype=None): - """Batch Matmul operator for int8 on CUDA""" + """Batch Matmul operator for int8 on CUDA. + + Parameters + ---------- + cfg : ConfigSpace + Autotvm tuning space config file. + + x : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. + + y : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. 
+ + Returns + ------- + output : tvm.te.Tensor + 3-D with shape [batch, M, N] + """ if out_dtype is None: out_dtype = x.dtype diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py index fffb0d6d48fc..50bcafd9f9a7 100644 --- a/python/tvm/topi/cuda/tensorcore_alter_op.py +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -19,7 +19,7 @@ import logging import math -from tvm import relay +from tvm import relay, tir from .. import nn @@ -56,6 +56,15 @@ def _batch_matmul_legalize(attrs, inputs, arg_types): B, M, K = x_tensor.shape B, N, K = y_tensor.shape + if ( + isinstance(B, tir.expr.Any) + or isinstance(M, tir.expr.Any) + or isinstance(K, tir.expr.Any) + or isinstance(N, tir.expr.Any) + ): + # Dynamic shape do not support alter op layout now + return None + M = M.value K = K.value N = N.value diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index a1212668affa..26d45feb0387 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -16,28 +16,50 @@ # under the License. """Batch matrix multiplication""" # pylint: disable=invalid-name +import logging import tvm from tvm import te, auto_scheduler from ..utils import get_const_tuple +logger = logging.getLogger("topi") -def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout="", out_dtype=None): - """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. Supports broadcasting for batch dimension. + +def batch_matmul( + tensor_a, + tensor_b, + oshape=None, + out_dtype=None, + transpose_a=False, + transpose_b=True, + auto_scheduler_rewritten_layout="", +): + """Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + + Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. Parameters ---------- - x : tvm.te.Tensor - 3-D with shape [batch, M, K] + tensor_a : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. - y : tvm.te.Tensor - 3-D with shape [batch, N, K] + tensor_b : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. oshape : List[Optional] Explicit intended output shape of the computation. Can be useful in cases with dynamic input shapes. - auto_scheduler_rewritten_layout: str = "" + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. + + auto_scheduler_rewritten_layout: Optional[str] = "" The layout after auto-scheduler's layout rewrite pass. 
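The four compute variants (`T_batch_matmul_{NN,NT,TN,TT}`) defined below reduce over K with the inputs read in different layouts; the following NumPy snippet (illustrative only, not part of the patch) spells out the equivalent contractions, ignoring batch broadcasting and the `out_dtype` cast:

```python
# Illustrative NumPy equivalent (not part of the patch) of the four
# transpose combinations handled by topi.nn.batch_matmul below.
import numpy as np

def batch_matmul_ref(a, b, transpose_a=False, transpose_b=True):
    if transpose_a:             # a: [batch, K, M] -> [batch, M, K]
        a = a.transpose(0, 2, 1)
    if transpose_b:             # b: [batch, N, K] -> [batch, K, N]
        b = b.transpose(0, 2, 1)
    return np.matmul(a, b)      # [batch, M, N]

a = np.random.rand(2, 4, 3)     # interpreted as [batch, K, M] when transpose_a=True
b = np.random.rand(2, 4, 5)     # interpreted as [batch, K, N] when transpose_b=False
print(batch_matmul_ref(a, b, transpose_a=True, transpose_b=False).shape)  # (2, 3, 5)
```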
Returns @@ -45,49 +67,79 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout="", out_dtyp output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - x_shape = get_const_tuple(x.shape) + assert len(tensor_a.shape) == 3, "tensor_a only support 3-dim" + if transpose_a: + XB, XK, XI = get_const_tuple(tensor_a.shape) + else: + XB, XI, XK = get_const_tuple(tensor_a.shape) if auto_scheduler_rewritten_layout: # Infer shape for the rewritten layout - y_shape = auto_scheduler.get_shape_from_rewritten_layout( - auto_scheduler_rewritten_layout, ["b", "j", "k"] + YB, YK, YJ = auto_scheduler.get_shape_from_rewritten_layout( + auto_scheduler_rewritten_layout, ["b", "k", "j"] ) - auto_scheduler.remove_index_check(y) + auto_scheduler.remove_index_check(tensor_b) else: - y_shape = get_const_tuple(y.shape) - assert len(x_shape) == 3 and len(y_shape) == 3, "only support 3-dim batch_matmul" + assert len(tensor_b.shape) == 3, "tensor_b only support 3-dim" + if transpose_b: + YB, YJ, YK = get_const_tuple(tensor_b.shape) + else: + YB, YK, YJ = get_const_tuple(tensor_b.shape) - XB = x_shape[0] - YB = y_shape[0] - _, M, K = x.shape - k = te.reduce_axis((0, K), name="k") + assert XK == YK or isinstance(YK, tvm.tir.expr.Var), "shapes of x and y are inconsistent" + k = te.reduce_axis((0, XK), name="k") if oshape is None: assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match" - assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent" - batch = te.max(XB, YB) - N = y.shape[1] - oshape = (batch, M, N) - - if out_dtype is None or out_dtype == x.dtype: - output = te.compute( - oshape, - lambda b, i, j: te.sum( - x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k - ), - tag="batch_matmul", - attrs={"layout_free_placeholders": [y]}, + batch = ( + tvm.tir.expr.SizeVar("batch", "int32") + if isinstance(XB, tvm.tir.expr.Var) or isinstance(YB, tvm.tir.expr.Var) + else te.max(XB, YB) ) - else: - output = te.compute( - oshape, - lambda b, i, j: te.sum( - x[b if XB != 1 else 0, i, k].astype(out_dtype) - * y[b if YB != 1 else 0, j, k].astype(out_dtype), - axis=k, - ), - tag="batch_matmul", - attrs={"layout_free_placeholders": [y]}, + oshape = (batch, XI, YJ) + if out_dtype is None: + out_dtype = tensor_a.dtype + if tensor_a.dtype != tensor_b.dtype: + logger.warning( + "tensor_a has different data type with tensor_b: %s, %s", + tensor_a.dtype, + tensor_b.dtype, + ) + + if (transpose_a, transpose_b) == (True, True): + compute_lambda = lambda b, i, j: te.sum( + tensor_a[b if XB != 1 else 0, k, i].astype(out_dtype) + * tensor_b[b if YB != 1 else 0, j, k].astype(out_dtype), + axis=k, + ) + compute_name = "T_batch_matmul_TT" + elif (transpose_a, transpose_b) == (True, False): + compute_lambda = lambda b, i, j: te.sum( + tensor_a[b if XB != 1 else 0, k, i].astype(out_dtype) + * tensor_b[b if YB != 1 else 0, k, j].astype(out_dtype), + axis=k, + ) + compute_name = "T_batch_matmul_TN" + elif (transpose_a, transpose_b) == (False, True): + compute_lambda = lambda b, i, j: te.sum( + tensor_a[b if XB != 1 else 0, i, k].astype(out_dtype) + * tensor_b[b if YB != 1 else 0, j, k].astype(out_dtype), + axis=k, + ) + compute_name = "T_batch_matmul_NT" + else: # (transpose_a, transpose_b) == (False, False): + compute_lambda = lambda b, i, j: te.sum( + tensor_a[b if XB != 1 else 0, i, k].astype(out_dtype) + * tensor_b[b if YB != 1 else 0, k, j].astype(out_dtype), + axis=k, ) + compute_name = "T_batch_matmul_NN" + output = te.compute( + oshape, + compute_lambda, + name=compute_name, + 
tag="batch_matmul", + attrs={"layout_free_placeholders": [tensor_b]}, + ) if auto_scheduler_rewritten_layout: output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout) diff --git a/python/tvm/topi/testing/batch_matmul.py b/python/tvm/topi/testing/batch_matmul.py index 96d1fcbb5bc3..18fa7e8c4b33 100644 --- a/python/tvm/topi/testing/batch_matmul.py +++ b/python/tvm/topi/testing/batch_matmul.py @@ -19,7 +19,7 @@ import numpy as np -def batch_matmul(x, y, out_dtype=None): +def batch_matmul(x, y, out_dtype=None, trans_x=False, trans_y=True): """batch_matmul operator implemented in numpy. Parameters @@ -38,13 +38,22 @@ def batch_matmul(x, y, out_dtype=None): out : numpy.ndarray 3-D with shape [batch, M, N] """ - XB, M, _ = x.shape - YB, N, _ = y.shape + if trans_x: + XB, _, M = x.shape + else: + XB, M, _ = x.shape + if trans_y: + YB, N, _ = y.shape + else: + YB, _, N = y.shape batch = max(XB, YB) dtype = x.dtype if out_dtype is None else out_dtype out = np.zeros((batch, M, N)).astype(dtype) for i in range(batch): + xx = x[i if XB != 1 else 0].astype(dtype) + yy = y[i if YB != 1 else 0].astype(dtype) out[i] = np.dot( - x[i if XB != 1 else 0].astype(dtype), y[i if YB != 1 else 0].T.astype(dtype) + xx.T if trans_x else xx, + yy.T if trans_y else yy, ) return out diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 35f4a9aba456..13ca851f0e38 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -20,65 +20,66 @@ from tvm import autotvm from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, mkl -from .. import generic +from .. import generic, nn from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor @autotvm.register_topi_compute("batch_matmul.x86") -def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None): - """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. Supports broadcasting in batch dimension. +def batch_matmul( + cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): + """Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + + Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format + (transpose_a=False, transpose_b=True) by default. Parameters ---------- cfg : ConfigSpace - Autotvm tuning space config file - x : tvm.te.Tensor - 3-D with shape [batch, M, K] - y : tvm.te.Tensor - 3-D with shape [batch, N, K] - out_shape : tuple or None - Shape of the outputs + Autotvm tuning space config file. + + tensor_a : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. + + tensor_b : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + out_dtype : Optional[str] + Specifies the output data type for mixed precision batch matmul. + + transpose_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + transpose_b : Optional[bool] = True + Whether the second tensor is in transposed format. 
Returns ------- output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" - XB, M, XK = get_const_tuple(x.shape) - YB, N, YK = get_const_tuple(y.shape) - assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match" - assert XK == YK, "shapes of x and y is inconsistent" - B = te.max(XB, YB) - K = XK - if out_shape is not None: - assert out_shape[0] == B, "got invalid output shape" - assert out_shape[1] == M, "got invalid output shape" - assert out_shape[2] == N, "got invalid output shape" if cfg.is_fallback: + if transpose_a: + _, K, M = get_const_tuple(tensor_a.shape) + else: + _, M, K = get_const_tuple(tensor_a.shape) + if transpose_b: + _, N, _ = get_const_tuple(tensor_b.shape) + else: + _, _, N = get_const_tuple(tensor_b.shape) _default_batch_matmul_config(cfg, M, N, K) - - k = te.reduce_axis((0, K), name="k") - if out_dtype is None or out_dtype == x.dtype: - C = te.compute( - (B, M, N), - lambda b, i, j: te.sum( - x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k - ), - tag="batch_matmul", - ) - else: - C = te.compute( - (B, M, N), - lambda b, i, j: te.sum( - x[b if XB != 1 else 0, i, k].astype(out_dtype) - * y[b if YB != 1 else 0, j, k].astype(out_dtype), - axis=k, - ), - tag="batch_matmul", - ) - return C + return nn.batch_matmul( + tensor_a, + tensor_b, + out_shape, + out_dtype, + transpose_a, + transpose_b, + ) @autotvm.register_topi_schedule("batch_matmul.x86") @@ -150,20 +151,32 @@ def _default_batch_matmul_config(cfg, M, N, K): cfg["tile_y"] = SplitEntity([M // y_bn, y_bn]) -def batch_matmul_blas_common(cfg, x, y, out_shape, lib): - """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch, using one of BLAS libraries. Supports broadcasting in batch dimension. +def batch_matmul_blas_common(cfg, tensor_a, tensor_b, out_shape, trans_a, trans_b, lib): + """Computes batch matrix multiplication of `tensor_a` and `tensor_b` when `tensor_a` and + `tensor_b` are data in batch, using one of BLAS libraries. Supports broadcasting in batch + dimension. Parameters ---------- cfg : ConfigSpace Autotvm tuning space config file - x : tvm.te.Tensor - 3-D with shape [batch, M, K] - y : tvm.te.Tensor - 3-D with shape [batch, N, K] - out_shape : tuple or None - Shape of the output + + tensor_a : tvm.te.Tensor + 3-D with shape [batch, M, K] or [batch, K, M]. + + tensor_b : tvm.te.Tensor + 3-D with shape [batch, K, N] or [batch, N, K]. + + out_shape : List[Optional] + Explicit intended output shape of the computation. Can be useful in cases + with dynamic input shapes. + + trans_a : Optional[bool] = False + Whether the first tensor is in transposed format. + + trans_b : Optional[bool] = True + Whether the second tensor is in transposed format. 
+ lib : A contrib module which implements batch_matmul function cblas and mkl are supported @@ -172,9 +185,15 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib): output : tvm.te.Tensor 3-D with shape [batch, M, N] """ - assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul" - XB, M, XK = get_const_tuple(x.shape) - YB, N, YK = get_const_tuple(y.shape) + assert len(tensor_a.shape) == 3 and len(tensor_b.shape) == 3, "only support 3-dim batch_matmul" + if trans_a: + XB, XK, M = get_const_tuple(tensor_a.shape) + else: + XB, M, XK = get_const_tuple(tensor_a.shape) + if trans_b: + YB, N, YK = get_const_tuple(tensor_b.shape) + else: + YB, YK, N = get_const_tuple(tensor_b.shape) assert (XB == YB) or (YB == 1) or (XB == 1), "batch dimension doesn't match" assert XK == YK, "shapes of x and y is inconsistent" if out_shape is not None: @@ -182,13 +201,18 @@ def batch_matmul_blas_common(cfg, x, y, out_shape, lib): assert out_shape[1] == M, "got invalid output shape" assert out_shape[2] == N, "got invalid output shape" cfg.add_flop(XB * M * N * XK * 2) - return lib.batch_matmul(x, y, False, True) + return lib.batch_matmul(tensor_a, tensor_b, trans_a, trans_b) @autotvm.register_topi_compute("batch_matmul_cblas.x86") -def batch_matmul_cblas(cfg, x, y, out_shape=None): +def batch_matmul_cblas( + cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """Compute batch_matmul using cblas""" - return batch_matmul_blas_common(cfg, x, y, out_shape, cblas) + del out_dtype # Unused argument + return batch_matmul_blas_common( + cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, cblas + ) @autotvm.register_topi_schedule("batch_matmul_cblas.x86") @@ -198,9 +222,14 @@ def schedule_batch_matmul_cblas(_, outs): @autotvm.register_topi_compute("batch_matmul_mkl.x86") -def batch_matmul_mkl(cfg, x, y, out_shape=None): +def batch_matmul_mkl( + cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True +): """Compute batch_matmul using mkl""" - return batch_matmul_blas_common(cfg, x, y, out_shape, mkl) + del out_dtype # Unused argument + return batch_matmul_blas_common( + cfg, tensor_a, tensor_b, out_shape, transpose_a, transpose_b, mkl + ) @autotvm.register_topi_schedule("batch_matmul_mkl.x86") diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 1a47193bb91a..43ce6656cdc0 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -49,7 +49,7 @@ Expr MakeMatmul(Expr tensor_a, Expr tensor_b, IndexExpr units, DataType out_dtyp Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype); -Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype); +Expr MakeBatchMatmul(Expr lhs, Expr rhs, DataType out_dtype, bool transpose_a, bool transpose_b); Expr MakeExpandDims(Expr data, int axis, int num_newaxis); diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 76a12e27c361..a96f167df2bb 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -932,37 +932,43 @@ If the input has size k on axis 1, then both gamma and beta have shape (k,). .set_support_level(1) .add_type_rel("GroupNorm", GroupNormRel); -// relay.nn.batch_matmul +// ------------------- relay.nn.batch_matmul TVM_REGISTER_NODE_TYPE(BatchMatmulAttrs); // Positional relay function to create batch_matmul operator used by frontend FFI.
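The C++ `MakeBatchMatmul` and `BatchMatmulRel` changes that follow are what the Python API reaches through `_make.batch_matmul`; an illustrative type-inference check (not part of the patch) shows the new output-shape rule in action:

```python
# Illustrative only (not part of the patch): the type relation below infers
# out = (max(batch_a, batch_b), M, N) with M/N read from the transposed axes.
import tvm
from tvm import relay

a = relay.var("a", shape=(2, 16, 8), dtype="float32")    # [batch, K, M] (transpose_a=True)
b = relay.var("b", shape=(2, 16, 32), dtype="float32")   # [batch, K, N] (transpose_b=False)
out = relay.nn.batch_matmul(a, b, transpose_a=True, transpose_b=False)

mod = tvm.IRModule.from_expr(relay.Function([a, b], out))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)   # expected: Tensor[(2, 8, 32), float32]
```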
-Expr MakeBatchMatmul(Expr x, Expr y, DataType out_dtype) { +Expr MakeBatchMatmul(Expr tensor_a, Expr tensor_b, DataType out_dtype, bool transpose_a, + bool transpose_b) { auto attrs = make_object(); attrs->out_dtype = out_dtype; + attrs->transpose_a = transpose_a; + attrs->transpose_b = transpose_b; static const Op& op = Op::Get("nn.batch_matmul"); - return Call(op, {x, y}, Attrs(attrs), {}); + return Call(op, {tensor_a, tensor_b}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul").set_body_typed(MakeBatchMatmul); RELAY_REGISTER_OP("nn.batch_matmul") - .describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y` -are data in batch. + .describe(R"code(Compute batch matrix multiplication of `tensor_a` and `tensor_b`. + +Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format +(transpose_a=False, transpose_b=True) by default. .. math:: - batch\_matmul(x, y)[i, :, :] = matmul(x[i, :, :], y[i, :, :]^T) + batch\_matmul(A, B)[i, :, :] = matmul(A[i, :, :], B[i, :, :]^T) -- **x**: `(b, m, k)` -- **y**: `(b, n, k)` +- **tensor_a**: `(b, m, k)` or `(b, k, m)` +- **tensor_b**: `(b, k, n)` or `(b, n, k)` - **out**: `(b, m, n)`. )code" TVM_ADD_FILELINE) .set_num_inputs(2) - .add_argument("x", "3D Tensor", "First input.") - .add_argument("y", "3D Tensor", "Second input.") + .add_argument("tensor_a", "3D Tensor", "The first input.") + .add_argument("tensor_b", "3D Tensor", "The second input.") .set_support_level(10) .add_type_rel("BatchMatmul", BatchMatmulRel); +// ------------------- relay.nn.batch_matmul // relay.nn.cross_entropy bool CrossEntropyRel(const Array& types, int num_inputs, const Attrs& attrs, diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index cf2ec84d1a6e..3dc63b31a205 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -148,46 +148,47 @@ bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs if (x == nullptr || y == nullptr) return false; const AttrType* param = attrs.as(); - Array y_shape; - if (param->auto_scheduler_rewritten_layout.size() == 0) { - y_shape = y->shape; - } else { - y_shape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout, - {"b", "j", "k"}); - } - + ICHECK(param != nullptr); + bool transpose_a = param->transpose_a; + bool transpose_b = param->transpose_b; + const Array& y_shape = + param->auto_scheduler_rewritten_layout.size() == 0 + ? y->shape + : auto_scheduler::GetShapeFromRewrittenLayout( + param->auto_scheduler_rewritten_layout, + transpose_b ? tvm::runtime::Array({"b", "j", "k"}) + : tvm::runtime::Array({"b", "k", "j"})); ICHECK(x->shape.size() == 3 && y_shape.size() == 3); + const PrimExpr& xb = x->shape[0]; + const PrimExpr& xi = x->shape[transpose_a ? 2 : 1]; + const PrimExpr& xk = x->shape[transpose_a ? 1 : 2]; + const PrimExpr& yb = y_shape[0]; + const PrimExpr& yk = y_shape[transpose_b ? 2 : 1]; + const PrimExpr& yj = y_shape[transpose_b ? 
1 : 2]; + bool is_dyn = false; - Array oshape; for (size_t i = 0; i < 3; ++i) { if (x->shape[i].as() != nullptr || y_shape[i].as() != nullptr) { is_dyn = true; - oshape.push_back(Any()); - } else { - if (i == 0) { - oshape.push_back(max(x->shape[i], y_shape[i])); - } else { - oshape.push_back(x->shape[i]); - } + break; } } if (!is_dyn) { - ICHECK(reporter->AssertEQ(x->shape[0], y_shape[0]) || reporter->AssertEQ(x->shape[0], 1) || - reporter->AssertEQ(y_shape[0], 1)) + ICHECK(reporter->AssertEQ(xb, yb) || reporter->AssertEQ(xb, 1) || reporter->AssertEQ(yb, 1)) << "BatchDot: batch dimensions don't match, " << " x shape=" << x->shape << ", y shape=" << y_shape; - ICHECK(reporter->AssertEQ(x->shape[2], y_shape[2])) - << "BatchDot: shapes of x and y is inconsistent, " - << " x shape=" << x->shape << ", y shape=" << y_shape; + ICHECK(reporter->AssertEQ(xk, yk)) << "BatchDot: shapes of x and y is inconsistent, " + << " x shape=" << x->shape << ", y shape=" << y_shape; } - oshape.Set(2, y_shape[1]); DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { out_dtype = x->dtype; } // assign output type - reporter->Assign(types[2], TensorType(oshape, out_dtype)); + const auto& out_b = + xb->IsInstance() || yb->IsInstance() ? tir::Any() : max(xb, yb); + reporter->Assign(types[2], TensorType(Array({out_b, xi, yj}), out_dtype)); return true; } diff --git a/src/relay/qnn/op/batch_matmul.cc b/src/relay/qnn/op/batch_matmul.cc index bb2b73141afc..4b0bcacacaa1 100644 --- a/src/relay/qnn/op/batch_matmul.cc +++ b/src/relay/qnn/op/batch_matmul.cc @@ -78,13 +78,21 @@ Expr MakeQuantizedBatchMatmul(Expr x, Expr y, Expr x_zero_point, Expr y_zero_poi Expr y_scale, DataType out_dtype) { auto attrs = make_object(); attrs->out_dtype = out_dtype; + // For legacy reason, currently `qnn.batch_matmul` only supports + // (transpose_a=false, transpose_b=true) + // TODO(jcf94): extent to support all tensor format + attrs->transpose_a = false; + attrs->transpose_b = true; static const Op& op = Op::Get("qnn.batch_matmul"); return Call(op, {x, y, x_zero_point, y_zero_point, x_scale, y_scale}, Attrs(attrs), {}); } Expr BatchMatmulFirstTerm(const Expr& quantized_x, const Expr& quantized_y, const BatchMatmulAttrs* attrs) { - return MakeBatchMatmul(quantized_x, quantized_y, attrs->out_dtype); + ICHECK(attrs->transpose_a == false && attrs->transpose_b == true) + << "Currently qnn.batch_matmul only supports (transpose_a=false, transpose_b=true)."; + return MakeBatchMatmul(quantized_x, quantized_y, attrs->out_dtype, attrs->transpose_a, + attrs->transpose_b); } Expr BatchMatmulSecondTerm(const Expr& x_quantized_data, const Expr& y_zero_point) { diff --git a/src/relay/transforms/combine_parallel_batch_matmul.cc b/src/relay/transforms/combine_parallel_batch_matmul.cc index f8c46d93c675..ddab87a4893e 100644 --- a/src/relay/transforms/combine_parallel_batch_matmul.cc +++ b/src/relay/transforms/combine_parallel_batch_matmul.cc @@ -68,6 +68,16 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner { // shape[2] is the contraction axis and automatically consistent // if it were valid batch_matmul ops + // TODO(jcf94): Add full support of layout format + if (!(attrs_a->transpose_a == false && attrs_a->transpose_b == true && + attrs_b->transpose_a == false && attrs_b->transpose_b == true)) { + LOG(WARNING) << "For legacy reason, this pass only supports" + << " (transpose_a=false, transpose_b=true) now, skip combining these two with:" + << " batch_matmul_a: " << attrs_a->transpose_a << ", " << attrs_a->transpose_b + << " 
batch_matmul_b: " << attrs_b->transpose_a << ", " << attrs_b->transpose_b; + return false; + } + auto res = eq(rhs_a->dtype, rhs_b->dtype) && eq(restype_a->dtype, restype_b->dtype) && (rhs_a->shape.size() == 3) && (rhs_b->shape.size() == 3) && eq(rhs_a->shape[0], rhs_b->shape[0]) && eq(attrs_a->out_dtype, attrs_b->out_dtype); @@ -86,7 +96,8 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner { const auto* origin_attrs = branches[0][0]->attrs.as(); ICHECK(origin_attrs); - return Downcast(MakeBatchMatmul(data, new_weight, origin_attrs->out_dtype)); + return Downcast(MakeBatchMatmul(data, new_weight, origin_attrs->out_dtype, + origin_attrs->transpose_a, origin_attrs->transpose_b)); } bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { return true; } diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index 3cd9cca4fec4..d5404ba30f90 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -72,7 +72,8 @@ class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner { CHECK_EQ(num_args, 2); const auto* origin_attrs = branches[0][0]->attrs.as(); ICHECK(origin_attrs); - return Downcast(MakeBatchMatmul(new_args[0], new_args[1], origin_attrs->out_dtype)); + return Downcast( + MakeBatchMatmul(new_args[0], new_args[1], origin_attrs->out_dtype, false, true)); } virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index a57c402d212c..6733b326c395 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -124,7 +124,7 @@ def run_tvm_graph( disabled_pass=None, ignore_in_shape=False, serialize=False, - use_dense_op=True, + convert_config=None, ): """Generic function to compile on relay and execute on tvm""" input_data = convert_to_list(input_data) @@ -143,7 +143,7 @@ def run_tvm_graph( layout=layout, shape=shape_dict, outputs=out_names, - use_dense_op=use_dense_op, + convert_config=convert_config, ) dev = tvm.device(target, 0) if mode == "debug": @@ -225,7 +225,7 @@ def compare_tf_with_tvm( add_shapes_to_graph_def=True, targets=None, ignore_in_shape=False, - use_dense_op=True, + convert_config=None, ): """Generic function to generate and compare tensorflow and TVM output""" @@ -273,7 +273,7 @@ def name_without_num(name): mode=mode, cuda_layout=cuda_layout, ignore_in_shape=ignore_in_shape, - use_dense_op=use_dense_op, + convert_config=convert_config, ) # since the names from tensorflow and relay runs are not exactly same, # first len(tf_output) will be compared @@ -1811,8 +1811,12 @@ def _test_matmul(i, j, k, dtype, outer=None): A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) - compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, use_dense_op=True) - compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, use_dense_op=False) + compare_tf_with_tvm( + [A_np, B_np], [A.name, B.name], result.name, convert_config={"use_dense": True} + ) + compare_tf_with_tvm( + [A_np, B_np], [A.name, B.name], result.name, convert_config={"use_dense": False} + ) def test_forward_matmul(): @@ -1830,7 +1834,18 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) B_np = np.random.uniform(high=5.0, 
size=B_shape).astype(dtype) - compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name) + compare_tf_with_tvm( + [A_np, B_np], + [A.name, B.name], + result.name, + convert_config={"use_nt_batch_matmul": True}, + ) + compare_tf_with_tvm( + [A_np, B_np], + [A.name, B.name], + result.name, + convert_config={"use_nt_batch_matmul": False}, + ) def _test_batch_matmul_dynamic( @@ -1843,10 +1858,23 @@ def _test_batch_matmul_dynamic( A_np = np.random.uniform(high=5.0, size=A_np_shape).astype(dtype) B_np = np.random.uniform(high=5.0, size=B_np_shape).astype(dtype) - # for now, in TOPI, only cublas's implementation support dynamic shape + # for now, in TOPI, only llvm & cublas's implementation support dynamic shape # TODO add more backends support in TOPI compare_tf_with_tvm( - [A_np, B_np], [A.name, B.name], result.name, mode="vm", targets=["cuda -libs=cublas"] + [A_np, B_np], + [A.name, B.name], + result.name, + mode="vm", + targets=["llvm", "cuda -libs=cublas"], + convert_config={"use_nt_batch_matmul": True}, + ) + compare_tf_with_tvm( + [A_np, B_np], + [A.name, B.name], + result.name, + mode="vm", + targets=["llvm", "cuda -libs=cublas"], + convert_config={"use_nt_batch_matmul": False}, ) @@ -1865,7 +1893,6 @@ def test_forward_batch_matmul(): _test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False) -@tvm.testing.requires_cuda def test_forward_batch_matmul_dynamic(): _test_batch_matmul_dynamic((None, 5, 4), (None, 4, 5), (3, 5, 4), (3, 4, 5), "int32") _test_batch_matmul_dynamic( diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 7b1d416d8cef..3f53c11fa36a 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -920,6 +920,98 @@ def test_any_dense_dynamic_batch(): verify_any_dense((relay.Any(), 40), (50, 40), 50, (4, 40), (50, 40), (4, 50), use_cublas=True) +def verify_any_batch_matmul( + x_shape, + y_shape, + out_shape, + x_var_shape, + y_var_shape, + dtype="float32", + trans_x=False, + trans_y=True, +): + x = relay.var("x", relay.TensorType(x_var_shape, dtype)) + y = relay.var("y", relay.TensorType(y_var_shape, dtype)) + z = relay.nn.batch_matmul(x, y, transpose_a=trans_x, transpose_b=trans_y) + + func = relay.Function([x, y], z) + x_np = np.random.uniform(size=x_shape).astype(dtype) + y_np = np.random.uniform(size=y_shape).astype(dtype) + z_np = tvm.topi.testing.batch_matmul(x_np, y_np, trans_x=trans_x, trans_y=trans_y) + + for target, dev in tvm.testing.enabled_targets(): + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) + z = intrp.evaluate()(x_np, y_np) + tvm.testing.assert_allclose(z.numpy(), z_np, rtol=1e-5) + + +# TODO(mbrookhart): enable once VM supports heterogenous execution +# @tvm.testing.uses_gpu +def test_any_batch_matmul(): + verify_any_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16), (1, 16, 32), (relay.Any(),) * 3) + verify_any_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16), (5, 16, 32), (relay.Any(),) * 3) + verify_any_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20), (5, 16, 32), (relay.Any(),) * 3) + verify_any_batch_matmul( + (30, 16, 32), (30, 20, 32), (30, 16, 20), (30, 16, 32), (relay.Any(),) * 3 + ) + + verify_any_batch_matmul( + (1, 16, 32), (1, 16, 32), (1, 16, 16), (relay.Any(), 16, 32), (relay.Any(), 16, 32) + ) + verify_any_batch_matmul( + (5, 16, 32), (5, 16, 32), (5, 16, 16), (relay.Any(), 16, 32), (relay.Any(), 16, 32) + ) + verify_any_batch_matmul( + (5, 16, 32), (5, 20, 32), (5, 16, 20), 
(relay.Any(), 16, 32), (relay.Any(), 20, 32) + ) + verify_any_batch_matmul( + (30, 16, 32), (30, 20, 32), (30, 16, 20), (relay.Any(), 16, 32), (relay.Any(), 20, 32) + ) + + verify_any_batch_matmul( + (1, 32, 16), (1, 16, 32), (1, 16, 16), (1, 32, 16), (relay.Any(),) * 3, trans_x=True + ) + verify_any_batch_matmul( + (5, 16, 32), (5, 32, 16), (5, 16, 16), (5, 16, 32), (relay.Any(),) * 3, trans_y=False + ) + verify_any_batch_matmul( + (5, 32, 16), + (5, 32, 20), + (5, 16, 20), + (5, 32, 16), + (relay.Any(),) * 3, + trans_x=True, + trans_y=False, + ) + verify_any_batch_matmul( + (1, 32, 16), + (1, 16, 32), + (1, 16, 16), + (relay.Any(), 32, 16), + (relay.Any(), 16, 32), + trans_x=True, + ) + verify_any_batch_matmul( + (5, 16, 32), + (5, 32, 16), + (5, 16, 16), + (relay.Any(), 16, 32), + (relay.Any(), 32, 16), + trans_y=False, + ) + verify_any_batch_matmul( + (5, 32, 16), + (5, 32, 20), + (5, 16, 20), + (relay.Any(), 32, 16), + (relay.Any(), 32, 20), + trans_x=True, + trans_y=False, + ) + + @tvm.testing.uses_gpu def verify_any_pad(data_shape, pad_width, static_data_shape): mod = tvm.IRModule() diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py index e2145f77b366..8d961eb60b18 100644 --- a/tests/python/relay/test_op_grad_level10.py +++ b/tests/python/relay/test_op_grad_level10.py @@ -62,10 +62,24 @@ def test_checkpoint(): check_grad(relay.Function(inputs, out_single)) +def verify_batch_matmul_grad(a_shape, b_shape, transpose_a, transpose_b): + tensor_a = relay.var("tensor_a", relay.TensorType(a_shape, "float32")) + tensor_b = relay.var("tensor_b", relay.TensorType(b_shape, "float32")) + check_grad( + relay.Function( + [tensor_a, tensor_b], + relay.op.nn.batch_matmul( + tensor_a, tensor_b, transpose_a=transpose_a, transpose_b=transpose_b + ), + ) + ) + + def test_batch_matmul_grad(): - x = relay.var("x", shape=(2, 3, 5), dtype="float64") - y = relay.var("y", shape=(2, 4, 5), dtype="float64") - check_grad(relay.Function([x, y], relay.op.nn.batch_matmul(x, y))) + verify_batch_matmul_grad((2, 3, 5), (2, 5, 4), False, False) + verify_batch_matmul_grad((2, 3, 5), (2, 4, 5), False, True) + verify_batch_matmul_grad((2, 5, 3), (2, 5, 4), True, False) + verify_batch_matmul_grad((2, 5, 3), (2, 4, 5), True, True) def test_reverse_reshape_grad(): diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 71598e61f694..eda7eac1b025 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -17,6 +17,7 @@ """ Support level10 operator test cases. 
""" import numpy as np +import pytest import tvm import tvm.testing import tvm.topi.testing @@ -325,17 +326,17 @@ def verify_reverse_reshape(shape, newshape, oshape): verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12)) -def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"): +def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32", trans_x=False, trans_y=True): x = relay.var("x", relay.TensorType(x_shape, dtype)) y = relay.var("y", relay.TensorType(y_shape, dtype)) - z = relay.nn.batch_matmul(x, y) + z = relay.nn.batch_matmul(x, y, transpose_a=trans_x, transpose_b=trans_y) zz = run_infer_type(z) assert zz.checked_type == relay.ty.TensorType(out_shape, dtype) func = relay.Function([x, y], z) x_np = np.random.uniform(size=x_shape).astype(dtype) y_np = np.random.uniform(size=y_shape).astype(dtype) - z_np = tvm.topi.testing.batch_matmul(x_np, y_np) + z_np = tvm.topi.testing.batch_matmul(x_np, y_np, trans_x=trans_x, trans_y=trans_y) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: @@ -353,60 +354,13 @@ def test_batch_matmul(): zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((b, m, n), "float32") - verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16)) - verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16)) - verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20)) - verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20)) - - -def verify_dynamic_batch_matmul( - x_shape, y_shape, out_shape, x_var_shape, y_var_shape, dtype="float32" -): - x = relay.var("x", relay.TensorType(x_var_shape, dtype)) - y = relay.var("y", relay.TensorType(y_var_shape, dtype)) - z = relay.nn.batch_matmul(x, y) - - func = relay.Function([x, y], z) - x_np = np.random.uniform(size=x_shape).astype(dtype) - y_np = np.random.uniform(size=y_shape).astype(dtype) - z_np = tvm.topi.testing.batch_matmul(x_np, y_np) - - for target, dev in tvm.testing.enabled_targets(): - for kind in ["vm", "debug"]: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - z = intrp.evaluate()(x_np, y_np) - tvm.testing.assert_allclose(z.numpy(), z_np, rtol=1e-5) - - -# TODO(mbrookhart): enable once VM supports heterogenous execution -# @tvm.testing.uses_gpu -def test_dynamic_batch_matmul(): - verify_dynamic_batch_matmul( - (1, 16, 32), (1, 16, 32), (1, 16, 16), (1, 16, 32), (relay.Any(),) * 3 - ) - verify_dynamic_batch_matmul( - (5, 16, 32), (5, 16, 32), (5, 16, 16), (5, 16, 32), (relay.Any(),) * 3 - ) - verify_dynamic_batch_matmul( - (5, 16, 32), (5, 20, 32), (5, 16, 20), (5, 16, 32), (relay.Any(),) * 3 - ) - verify_dynamic_batch_matmul( - (30, 16, 32), (30, 20, 32), (30, 16, 20), (30, 16, 32), (relay.Any(),) * 3 - ) - - verify_dynamic_batch_matmul( - (1, 16, 32), (1, 16, 32), (1, 16, 16), (relay.Any(), 16, 32), (relay.Any(), 16, 32) - ) - verify_dynamic_batch_matmul( - (5, 16, 32), (5, 16, 32), (5, 16, 16), (relay.Any(), 16, 32), (relay.Any(), 16, 32) - ) - verify_dynamic_batch_matmul( - (5, 16, 32), (5, 20, 32), (5, 16, 20), (relay.Any(), 16, 32), (relay.Any(), 20, 32) - ) - verify_dynamic_batch_matmul( - (30, 16, 32), (30, 20, 32), (30, 16, 20), (relay.Any(), 16, 32), (relay.Any(), 20, 32) - ) + verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16), trans_x=False, trans_y=True) + verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16), trans_x=False, trans_y=True) + verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20), trans_x=False, trans_y=True) + verify_batch_matmul((30, 16, 32), 
(30, 20, 32), (30, 16, 20), trans_x=False, trans_y=True) + verify_batch_matmul((1, 32, 16), (1, 16, 32), (1, 16, 16), trans_x=True, trans_y=True) + verify_batch_matmul((5, 16, 32), (5, 32, 16), (5, 16, 16), trans_x=False, trans_y=False) + verify_batch_matmul((5, 32, 16), (5, 32, 20), (5, 16, 20), trans_x=True, trans_y=False) @tvm.testing.uses_gpu @@ -639,15 +593,4 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float3 if __name__ == "__main__": - test_adaptive_pool() - test_collapse_sum_like() - test_broadcast_to() - test_broadcast_to_like() - test_slice_like() - test_reverse_reshape() - test_batch_matmul() - test_shape_of() - test_sequence_mask() - test_one_hot() - test_ndarray_size() - test_matrix_set_diag() + pytest.main([__file__])
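Beyond the unit tests above, a standalone end-to-end sketch (illustrative, not part of the test suite) that mirrors `verify_batch_matmul` for a TN-format case on the LLVM target:

```python
# Illustrative end-to-end check (not part of the test suite): build the relay
# op with transpose flags, run it on LLVM, and compare against the NumPy reference.
import numpy as np
import tvm
import tvm.testing
import tvm.topi.testing
from tvm import relay

x = relay.var("x", shape=(5, 32, 16), dtype="float32")   # [batch, K, M]
y = relay.var("y", shape=(5, 32, 20), dtype="float32")   # [batch, K, N]
func = relay.Function([x, y], relay.nn.batch_matmul(x, y, transpose_a=True, transpose_b=False))

x_np = np.random.uniform(size=(5, 32, 16)).astype("float32")
y_np = np.random.uniform(size=(5, 32, 20)).astype("float32")
ref = tvm.topi.testing.batch_matmul(x_np, y_np, trans_x=True, trans_y=False)

mod = tvm.ir.IRModule.from_expr(func)
intrp = relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm")
out = intrp.evaluate()(x_np, y_np)
tvm.testing.assert_allclose(out.numpy(), ref, rtol=1e-5)
```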