From 0f0a57e0880db1a689af2f4438d4ea98adb03795 Mon Sep 17 00:00:00 2001
From: monklof
Date: Mon, 22 Feb 2021 16:04:44 +0800
Subject: [PATCH 01/13] [TOPI] Dense cuda schedule support dynamic dimension

---
 python/tvm/topi/cuda/dense.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py
index f8abe4d4d799..7f698ecfe7b8 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/cuda/dense.py
@@ -42,11 +42,8 @@ def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
     batch, in_dim = data.shape
     out_dim, _ = weight.shape
     matmul = cublas.matmul(data, weight, False, True)
-    if isinstance(batch, int):
+    if isinstance(batch, int) and isinstance(in_dim, int) and isinstance(out_dim, int):
         cfg.add_flop(batch * in_dim * out_dim * 2)
-    elif isinstance(batch, tir.IntImm):
-        cfg.add_flop(batch.value * in_dim * out_dim * 2)
-    # if we get a te.Var, we cannot add flop counts
     if bias is not None:
         matmul = te.compute(
             (batch, out_dim), lambda i, j: matmul[i, j] + bias[j], tag=tag.BROADCAST

From d1ae24fa3be00849e32c102301a654384cef6925 Mon Sep 17 00:00:00 2001
From: monklof
Date: Mon, 22 Feb 2021 16:11:04 +0800
Subject: [PATCH 02/13] [TOPI] batch_matmul cublas te computation support dynamism

---
 python/tvm/topi/cuda/batch_matmul.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py
index 006b866d6bad..f9f332fdaccb 100644
--- a/python/tvm/topi/cuda/batch_matmul.py
+++ b/python/tvm/topi/cuda/batch_matmul.py
@@ -161,7 +161,8 @@ def batch_matmul_cublas(cfg, x, y, out_shape=None):
     """
     b, m, k = x.shape
     b, n, k = y.shape
-    cfg.add_flop(b * m * k * n * 2)
+    if isinstance(b, int) and isinstance(m, int) and isinstance(n, int) and isinstance(k, int):
+        cfg.add_flop(b * m * k * n * 2)
     return cublas.batch_matmul(x, y, False, True)

From 04a7a5c100fcab23506486e44e2d7a7f45f85ae5 Mon Sep 17 00:00:00 2001
From: monklof
Date: Mon, 22 Feb 2021 16:20:07 +0800
Subject: [PATCH 03/13] [Frontend] tensorflow frontend: dynamic support for BatchMatmul

---
 python/tvm/relay/frontend/tensorflow.py | 57 ++++++++++++++++++++++---
 1 file changed, 50 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index ac52ab768066..21bbabd6d3d2 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -44,6 +44,24 @@
 
 __all__ = ["from_tensorflow"]
 
+
+def is_symbolic_shape(shape):
+    return not all([isinstance(dim, (int, tvm.tir.IntImm)) for dim in shape])
+
+def list_shape_of(tensor, ndim):
+    shape_tensor = _op.shape_of(tensor)
+    return [_op.strided_slice(shape_tensor, begin=[i], end=[i+1], strides=[1]) for i in range(ndim)]
+
+def concat_dynamic_shape(shape_list):
+    new_shape = []
+    for dim in shape_list:
+        if isinstance(dim, (int, tvm.tir.IntImm)):
+            new_shape.append(_op.expand_dims(_op.const(dim, 'int32'), axis=0))
+        else: # expected to be tensor[1]
+            new_shape.append(dim)
+
+    return _op.concatenate(_op.Tuple(new_shape), axis=0)
+
 def _get_pad_pair(input1d, kernel1d, stride1d):
     if input1d % stride1d == 0:
         pad = max(kernel1d - stride1d, 0)
@@ -919,13 +937,31 @@ def _impl(inputs, attr, params, mod):
         input_y = inputs[1]
         orig_shape_x = _infer_shape(input_x, mod)
         orig_shape_y = _infer_shape(input_y, mod)
+        ndim = len(orig_shape_x)
+
+        is_static = not is_symbolic_shape(orig_shape_x)
+
+        if len(orig_shape_x) > 3 and not is_static:
+            shape_of_x = list_shape_of(inputs[0], ndim)
+            shape_of_y = list_shape_of(inputs[1], ndim)

         # reshape n-dimensional batch matmul into 3d
         if len(orig_shape_x) > 3:
             outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)]
-            num_outer_elts = np.prod(outer_dims)
-            new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1])
-            new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
+            if is_static:
+                num_outer_elts = np.prod(outer_dims)
+                new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1])
+                new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
+            else: # handle dynamic shape (dyn.reshape op)
+                # new shape = [prod(shape[:-2]), -2, -1]
+                new_shape_x = [_op.const(1), shape_of_x[-2], shape_of_x[-1]]
+                new_shape_y = [_op.const(1), shape_of_y[-2], shape_of_y[-1]]
+                for i in range(ndim-2):
+                    new_shape_x[0] *= shape_of_x[i]
+                    new_shape_y[0] *= shape_of_y[i]
+                new_shape_x = _op.concatenate(_op.Tuple(new_shape_x), axis=0)
+                new_shape_y = _op.concatenate(_op.Tuple(new_shape_y), axis=0)
+
             input_x = _op.reshape(input_x, newshape=new_shape_x)
             input_y = _op.reshape(input_y, newshape=new_shape_y)
@@ -937,11 +973,18 @@ def _impl(inputs, attr, params, mod):

         # reshape result back to n-dimensional
         if len(orig_shape_x) > 3:
-            final_shape = list(orig_shape_x)
-            final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2]
-            final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1]
-            ret = _op.reshape(ret, newshape=final_shape)
+            if is_static:
+                final_shape = list(orig_shape_x)
+                final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2]
+                final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1]
+            else:
+                # calculate the resulting shape = [shape[:-2], 0, 0]
+                final_shape = list(shape_of_x)
+                final_shape[-2] = shape_of_x[-1] if adj_x else shape_of_x[-2]
+                final_shape[-1] = shape_of_y[-2] if adj_y else shape_of_y[-1]
+                final_shape = _op.concatenate(_op.Tuple(final_shape), axis=0)
+
+            ret = _op.reshape(ret, newshape=final_shape)

         return ret

     return _impl

From 8698f05870d5f8d3a457242ffce44904d932f98f Mon Sep 17 00:00:00 2001
From: monklof
Date: Mon, 22 Feb 2021 16:26:42 +0800
Subject: [PATCH 04/13] [TOPI] nn batch_matmul te computation support dynamism

---
 python/tvm/topi/nn/batch_matmul.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 9c5848129397..fb97ef9b58e2 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -18,7 +18,8 @@
 # pylint: disable=invalid-name
 import tvm
 from tvm import te, auto_scheduler
-from ..utils import get_const_tuple
+from tvm import tir
+from ..util import get_const_tuple


 def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
@@ -61,9 +62,17 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
     _, M, K = x.shape
     k = te.reduce_axis((0, K), name="k")
     if oshape is None:
-        assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
-        assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
-        batch = te.max(XB, YB)
+        if isinstance(XB, int) and isinstance(YB, int):
+            assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
+            batch = max(XB, YB)
+        elif isinstance(XB, tir.expr.Var):
+            batch = XB
+        else:
+            batch = YB
+
+        if isinstance(x_shape[2], int) and isinstance(y_shape[2], int):
+            assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
+
         N = y.shape[1]
         oshape = (batch, M, N)
From a6048a74660de020caba7134653affab59803924 Mon Sep 17 00:00:00 2001
From: monklof
Date: Tue, 23 Feb 2021 11:27:45 +0800
Subject: [PATCH 05/13] fix CI

---
 python/tvm/relay/frontend/tensorflow.py | 20 ++++++++++++--------
 python/tvm/topi/cuda/dense.py           |  2 +-
 python/tvm/topi/nn/batch_matmul.py      |  2 +-
 3 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 21bbabd6d3d2..2d7383bd42d9 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -44,24 +44,28 @@
 
 __all__ = ["from_tensorflow"]
 
-
-def is_symbolic_shape(shape):
+def check_symbolic_shape(shape):
     return not all([isinstance(dim, (int, tvm.tir.IntImm)) for dim in shape])
 
+
 def list_shape_of(tensor, ndim):
     shape_tensor = _op.shape_of(tensor)
-    return [_op.strided_slice(shape_tensor, begin=[i], end=[i+1], strides=[1]) for i in range(ndim)]
+    return [
+        _op.strided_slice(shape_tensor, begin=[i], end=[i + 1], strides=[1]) for i in range(ndim)
+    ]
+
 
 def concat_dynamic_shape(shape_list):
     new_shape = []
     for dim in shape_list:
         if isinstance(dim, (int, tvm.tir.IntImm)):
-            new_shape.append(_op.expand_dims(_op.const(dim, 'int32'), axis=0))
-        else: # expected to be tensor[1]
+            new_shape.append(_op.expand_dims(_op.const(dim, "int32"), axis=0))
+        else:  # expected to be tensor[1]
             new_shape.append(dim)
 
     return _op.concatenate(_op.Tuple(new_shape), axis=0)
 
+
 def _get_pad_pair(input1d, kernel1d, stride1d):
     if input1d % stride1d == 0:
         pad = max(kernel1d - stride1d, 0)
@@ -939,7 +943,7 @@ def _impl(inputs, attr, params, mod):
         orig_shape_y = _infer_shape(input_y, mod)
         ndim = len(orig_shape_x)
 
-        is_static = not is_symbolic_shape(orig_shape_x)
+        is_static = not check_symbolic_shape(orig_shape_x)
 
         if len(orig_shape_x) > 3 and not is_static:
             shape_of_x = list_shape_of(inputs[0], ndim)
@@ -952,11 +956,11 @@ def _impl(inputs, attr, params, mod):
                 num_outer_elts = np.prod(outer_dims)
                 new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1])
                 new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
-            else: # handle dynamic shape (dyn.reshape op)
+            else:  # handle dynamic shape (dyn.reshape op)
                 # new shape = [prod(shape[:-2]), -2, -1]
                 new_shape_x = [_op.const(1), shape_of_x[-2], shape_of_x[-1]]
                 new_shape_y = [_op.const(1), shape_of_y[-2], shape_of_y[-1]]
-                for i in range(ndim-2):
+                for i in range(ndim - 2):
                     new_shape_x[0] *= shape_of_x[i]
                     new_shape_y[0] *= shape_of_y[i]
                 new_shape_x = _op.concatenate(_op.Tuple(new_shape_x), axis=0)
diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py
index 7f698ecfe7b8..431436307d26 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/cuda/dense.py
@@ -17,7 +17,7 @@
 # pylint: disable=invalid-name, unused-argument
 """Schedule for dense operator"""
 import logging
-from tvm import te, tir
+from tvm import te
 import tvm.autotvm as autotvm
 from tvm.autotvm.task.space import SplitEntity
 from tvm.contrib import cublas
diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index fb97ef9b58e2..637f2fae75ee 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -19,7 +19,7 @@
 import tvm
 from tvm import te, auto_scheduler
 from tvm import tir
-from ..util import get_const_tuple
+from ..utils import get_const_tuple


 def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):

From 42f79a42f22a806f59637afa24fa4016c61f667b Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Fri, 26 Feb 2021 11:04:43 +0800
Subject: [PATCH 06/13] Update python/tvm/topi/nn/batch_matmul.py

Co-authored-by: Cody Yu
---
 python/tvm/topi/nn/batch_matmul.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 637f2fae75ee..40197c57f108 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -71,7 +71,7 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
             batch = YB
 
         if isinstance(x_shape[2], int) and isinstance(y_shape[2], int):
-            assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
+            assert x_shape[2] == y_shape[2], "shapes of x and y are inconsistant"
 
         N = y.shape[1]
         oshape = (batch, M, N)

From 8cf2b9fe81fb3093b9c39c3aa3ea6f1d39cd5f60 Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Fri, 26 Feb 2021 11:05:28 +0800
Subject: [PATCH 07/13] Update python/tvm/topi/cuda/batch_matmul.py

Co-authored-by: Cody Yu
---
 python/tvm/topi/cuda/batch_matmul.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py
index f9f332fdaccb..3225a2aa52d3 100644
--- a/python/tvm/topi/cuda/batch_matmul.py
+++ b/python/tvm/topi/cuda/batch_matmul.py
@@ -161,7 +161,7 @@ def batch_matmul_cublas(cfg, x, y, out_shape=None):
     """
     b, m, k = x.shape
     b, n, k = y.shape
-    if isinstance(b, int) and isinstance(m, int) and isinstance(n, int) and isinstance(k, int):
+    if all([isinstance(s, int) for s in [b, m, n, k]]):
         cfg.add_flop(b * m * k * n * 2)
     return cublas.batch_matmul(x, y, False, True)

From 1d724d0410cf0822b9fd1a3e9341f80725188902 Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Fri, 26 Feb 2021 11:17:39 +0800
Subject: [PATCH 08/13] remove concat_dynamic_shape function

---
 python/tvm/relay/frontend/tensorflow.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 2d7383bd42d9..339d15516b22 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -55,17 +55,6 @@ def list_shape_of(tensor, ndim):
     ]
 
 
-def concat_dynamic_shape(shape_list):
-    new_shape = []
-    for dim in shape_list:
-        if isinstance(dim, (int, tvm.tir.IntImm)):
-            new_shape.append(_op.expand_dims(_op.const(dim, "int32"), axis=0))
-        else:  # expected to be tensor[1]
-            new_shape.append(dim)
-
-    return _op.concatenate(_op.Tuple(new_shape), axis=0)
-
-
 def _get_pad_pair(input1d, kernel1d, stride1d):
     if input1d % stride1d == 0:
         pad = max(kernel1d - stride1d, 0)

From f8861528cfee8ba627b2defcd5b7d44db887c4b5 Mon Sep 17 00:00:00 2001
From: monklof
Date: Tue, 16 Mar 2021 13:47:46 +0800
Subject: [PATCH 09/13] update topi dense op integer checking

---
 python/tvm/topi/cuda/dense.py      |  6 +++---
 python/tvm/topi/nn/batch_matmul.py | 14 +++-----------
 2 files changed, 6 insertions(+), 14 deletions(-)

diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py
index 431436307d26..b237c6cd6ceb 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/cuda/dense.py
@@ -39,10 +39,10 @@ def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
     if out_dtype is None:
         out_dtype = data.dtype
     assert out_dtype == data.dtype, "Mixed precision not supported."
-    batch, in_dim = data.shape
-    out_dim, _ = weight.shape
+    batch, in_dim = get_const_tuple(data.shape)
+    out_dim, _ = get_const_tuple(weight.shape)
     matmul = cublas.matmul(data, weight, False, True)
-    if isinstance(batch, int) and isinstance(in_dim, int) and isinstance(out_dim, int):
+    if all(isinstance(d, int) for d in [batch, in_dim, out_dim]):
         cfg.add_flop(batch * in_dim * out_dim * 2)
     if bias is not None:
         matmul = te.compute(
diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 40197c57f108..fe91adc710bc 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -62,17 +62,9 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
     _, M, K = x.shape
     k = te.reduce_axis((0, K), name="k")
     if oshape is None:
-        if isinstance(XB, int) and isinstance(YB, int):
-            assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
-            batch = max(XB, YB)
-        elif isinstance(XB, tir.expr.Var):
-            batch = XB
-        else:
-            batch = YB
-
-        if isinstance(x_shape[2], int) and isinstance(y_shape[2], int):
-            assert x_shape[2] == y_shape[2], "shapes of x and y are inconsistant"
-
+        assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
+        assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent"
+        batch = te.max(XB, YB)
         N = y.shape[1]
         oshape = (batch, M, N)

From 4085cb3dd94308626a5a1b0ba72c05cdcfd12ad9 Mon Sep 17 00:00:00 2001
From: monklof
Date: Tue, 16 Mar 2021 16:03:45 +0800
Subject: [PATCH 10/13] fix ci

---
 python/tvm/topi/nn/batch_matmul.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index fe91adc710bc..b6ed5a373e81 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -18,7 +18,6 @@
 # pylint: disable=invalid-name
 import tvm
 from tvm import te, auto_scheduler
-from tvm import tir
 from ..utils import get_const_tuple


 def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):

From 971be64960c2aa43317f04b4a519b29505c09570 Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Wed, 17 Mar 2021 11:25:44 +0800
Subject: [PATCH 11/13] Update python/tvm/relay/frontend/tensorflow.py

Co-authored-by: Cody Yu
---
 python/tvm/relay/frontend/tensorflow.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 339d15516b22..9b4627959cdd 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -934,7 +934,7 @@ def _impl(inputs, attr, params, mod):
 
         is_static = not check_symbolic_shape(orig_shape_x)
 
-        if len(orig_shape_x) > 3 and not is_static:
+        if ndim > 3 and not is_static:
             shape_of_x = list_shape_of(inputs[0], ndim)
             shape_of_y = list_shape_of(inputs[1], ndim)

From d7709c45216a3186dbb3ab00790f785eb71a41a5 Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Wed, 17 Mar 2021 11:32:10 +0800
Subject: [PATCH 12/13] Update batch_matmul.py

---
 python/tvm/topi/cuda/batch_matmul.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py
index 3225a2aa52d3..04e484f526d2 100644
--- a/python/tvm/topi/cuda/batch_matmul.py
+++ b/python/tvm/topi/cuda/batch_matmul.py
@@ -159,8 +159,8 @@ def batch_matmul_cublas(cfg, x, y, out_shape=None):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    b, m, k = x.shape
-    b, n, k = y.shape
+    b, m, k = get_const_tuple(x.shape)
+    b, n, k = get_const_tuple(y.shape)
     if all([isinstance(s, int) for s in [b, m, n, k]]):
         cfg.add_flop(b * m * k * n * 2)
     return cublas.batch_matmul(x, y, False, True)

From 69ab8ada78517735d40b3c49042c41b21466cbb3 Mon Sep 17 00:00:00 2001
From: monklof
Date: Wed, 17 Mar 2021 14:11:01 +0800
Subject: [PATCH 13/13] [Frontend] add test for batch_matmul in dynamic shaped case

---
 python/tvm/relay/frontend/tensorflow.py      |  4 +-
 .../frontend/tensorflow/test_forward.py      | 52 ++++++++++++++++++-
 2 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 9b4627959cdd..146f5646f30c 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -939,7 +939,7 @@ def _impl(inputs, attr, params, mod):
             shape_of_y = list_shape_of(inputs[1], ndim)
 
         # reshape n-dimensional batch matmul into 3d
-        if len(orig_shape_x) > 3:
+        if ndim > 3:
             outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)]
             if is_static:
                 num_outer_elts = np.prod(outer_dims)
@@ -965,7 +965,7 @@ def _impl(inputs, attr, params, mod):
         ret = get_relay_op("batch_matmul")(input_x, input_y)
 
         # reshape result back to n-dimensional
-        if len(orig_shape_x) > 3:
+        if ndim > 3:
             if is_static:
                 final_shape = list(orig_shape_x)
                 final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2]
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index ecf6441bc6b9..11b212dbbadc 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -210,6 +210,7 @@ def compare_tf_with_tvm(
     mode="graph_runtime",
     cuda_layout="NCHW",
     add_shapes_to_graph_def=True,
+    targets=None,
 ):
     """Generic function to generate and compare tensorflow and TVM output"""
 
@@ -233,13 +234,18 @@ def name_without_num(name):
 
     tf_output = run_tf_graph(sess, in_data, in_name, out_name)
 
-    for device in ["llvm", "cuda"]:
+    devices = targets if targets else ["llvm", "cuda"]
+
+    for device in devices:
         ctx = tvm.context(device, 0)
         if not tvm.testing.device_enabled(device):
             print("Skip because %s is not enabled" % device)
             continue
         if no_gpu and device == "cuda":
             continue
+        if "cublas" in device and not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
+            print("Skip because cublas is not enabled: %s" % device)
+            continue
 
         tvm_output = run_tvm_graph(
             final_graph_def,
@@ -1781,6 +1787,23 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False
         compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
 
 
+def _test_batch_matmul_dynamic(
+    A_shape, B_shape, A_np_shape, B_np_shape, dtype, adjoint_a=False, adjoint_b=False
+):
+    with tf.Graph().as_default():
+        A = tf.placeholder(shape=A_shape, dtype=dtype, name="A")
+        B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")
+        result = tf.matmul(A, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name="batchmatmul")
+
+        A_np = np.random.uniform(high=5.0, size=A_np_shape).astype(dtype)
+        B_np = np.random.uniform(high=5.0, size=B_np_shape).astype(dtype)
+        # for now, in TOPI, only cublas's implementation support dynamic shape
+        # TODO add more backends support in TOPI
+        compare_tf_with_tvm(
+            [A_np, B_np], [A.name, B.name], result.name, mode="vm", targets=["cuda -libs=cublas"]
+        )
+
+
 def test_forward_batch_matmul():
     """ TF op BatchMatMul, BatchMatMulV2 test"""
     _test_batch_matmul((3, 5, 4), (3, 4, 5), "int32")
@@ -1793,6 +1816,33 @@ def test_forward_batch_matmul():
     _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True)
 
 
+@tvm.testing.requires_cuda
+def test_forward_batch_matmul_dynamic():
+    _test_batch_matmul_dynamic((None, 5, 4), (None, 4, 5), (3, 5, 4), (3, 4, 5), "int32")
+    _test_batch_matmul_dynamic(
+        (None, 5, 4), (None, 4, 5), (3, 5, 4), (3, 4, 5), "float32", True, True
+    )
+    _test_batch_matmul_dynamic(
+        (None, 5, 4), (None, 5, 4), (3, 5, 4), (3, 5, 4), "int32", True, False
+    )
+    _test_batch_matmul_dynamic(
+        (None, 5, 4), (None, 5, 4), (3, 5, 4), (3, 5, 4), "float32", False, True
+    )
+    _test_batch_matmul_dynamic(
+        (None, 4, 5, 6), (None, 4, 6, 5), (3, 4, 5, 6), (3, 4, 6, 5), "float32"
+    )
+    _test_batch_matmul_dynamic(
+        (None, None, 5, 6), (None, None, 6, 5), (3, 4, 5, 6), (3, 4, 6, 5), "float32"
+    )
+    _test_batch_matmul_dynamic(
+        (None, None, None, 5, 6),
+        (None, None, None, 6, 5),
+        (2, 3, 4, 5, 6),
+        (2, 3, 4, 6, 5),
+        "float32",
+    )
+
+
 #######################################################################
 # SparseTensorDenseMatMul
 # ----------------------------------
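
The following is a minimal usage sketch appended for illustration only; it is not one of the 13 patches above. It drives the same dynamically shaped batch_matmul path that the new test exercises (VM executor, targets=["cuda -libs=cublas"]), but it builds the Relay module by hand rather than going through the TensorFlow importer. The variable names, shapes, target string, and tolerance are assumptions, and running it needs a TVM build at this revision with CUDA and cuBLAS enabled.

import numpy as np

import tvm
from tvm import relay

# The batch dimension is symbolic (relay.Any()), which is what the TensorFlow
# importer now produces when BatchMatMul inputs have unknown leading dimensions.
x = relay.var("x", shape=(relay.Any(), 5, 4), dtype="float32")
y = relay.var("y", shape=(relay.Any(), 6, 4), dtype="float32")

# relay.nn.batch_matmul computes x @ transpose(y, [0, 2, 1]) -> (batch, 5, 6).
out = relay.nn.batch_matmul(x, y)
mod = tvm.IRModule.from_expr(relay.Function([x, y], out))

# Per the patches, only the cuBLAS TOPI kernel accepts the non-integer shapes,
# so target "cuda -libs=cublas" and run through the VM executor; the graph
# runtime cannot execute dynamically shaped workloads.
target = "cuda -libs=cublas"
ctx = tvm.gpu(0)
runner = relay.create_executor("vm", mod=mod, ctx=ctx, target=target).evaluate()

x_np = np.random.uniform(size=(3, 5, 4)).astype("float32")
y_np = np.random.uniform(size=(3, 6, 4)).astype("float32")
res = runner(x_np, y_np)
np.testing.assert_allclose(res.asnumpy(), np.matmul(x_np, y_np.transpose(0, 2, 1)), rtol=1e-5)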