diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 861a73aa2ad8..4952cefaf0e3 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1599,6 +1599,9 @@ def _impl(inputs, attr, params, mod): data_shape = get_const_tuple(in_type.checked_type.shape) data_dim = len(data_shape) stride_dim = len(stride) + if data_dim == 0 and isinstance(inputs[0], _expr.Constant): + new_data = inputs[0].data.asnumpy().reshape(1) + return _expr.const(new_data, inputs[0].data.dtype) # This is a special routine to handle strided_slice after shape_of. # We need this since in some cases we want to do strided_slice on diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 1f9d150f6ec6..6daaf61fbce6 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -170,7 +170,8 @@ def run_tvm_graph( m = graph_runtime.create(graph, lib, ctx) # set inputs for e, i in zip(input_node, input_data): - m.set_input(e, tvm.nd.array(i)) + if e != "": + m.set_input(e, tvm.nd.array(i)) m.set_input(**params) # execute @@ -192,8 +193,10 @@ def run_tf_graph(sess, input_data, input_node, output_node): tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node] input_dict = {e: input_data[i] for i, e in enumerate(input_node)} - - output_data = sess.run(tensor, input_dict) + if len(input_node) == 1 and input_node[0] == "": + output_data = sess.run(tensor) + else: + output_data = sess.run(tensor, input_dict) return output_data @@ -1826,8 +1829,12 @@ def _test_stridedslice( """ One iteration of a Stridedslice """ tf.reset_default_graph() + np_data = np.random.uniform(size=ip_shape).astype(dtype) with tf.Graph().as_default(): - in_data = tf.placeholder(dtype, ip_shape, name="in_data") + if len(ip_shape) == 0: + in_data = tf.constant(np_data, dtype) + else: + in_data = tf.placeholder(dtype, ip_shape, name="in_data") tf.strided_slice( in_data, begin, @@ -1840,56 +1847,58 @@ def _test_stridedslice( ellipsis_mask=ellipsis_mask, name="strided_slice", ) - np_data = np.random.uniform(size=ip_shape).astype(dtype) - - compare_tf_with_tvm(np_data, "in_data:0", "strided_slice:0") + if len(ip_shape) == 0: + compare_tf_with_tvm(None, "", "strided_slice:0") + else: + compare_tf_with_tvm(np_data, "in_data:0", "strided_slice:0") def test_forward_stridedslice(): """test StridedSlice""" - _test_stridedslice((2), [1], [1], [1], "float32", shrink_axis_mask=1) - _test_stridedslice((2, 1), [0], [1], [1], "float32", shrink_axis_mask=1) - _test_stridedslice((2, 3, 4), [0], [1], [1], "float32", shrink_axis_mask=8) - _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32") - _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], "float32", ellipsis_mask=8) - _test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], "float32", ellipsis_mask=2) - _test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [2, 1], "float32", ellipsis_mask=2) - _test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [2, 1, 1], "float32", ellipsis_mask=2) - _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], "float32", new_axis_mask=5) + _test_stridedslice([], [0], [0], [1], "float32", new_axis_mask=1) + _test_stridedslice([2], [1], [1], [1], "float32", shrink_axis_mask=1) + _test_stridedslice([2, 1], [0], [1], [1], "float32", shrink_axis_mask=1) + _test_stridedslice([2, 3, 4], [0], [1], [1], "float32", shrink_axis_mask=8) + _test_stridedslice([3, 4, 3], [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32") + _test_stridedslice([3, 4, 3], [1, 0], [4, 3], [2, 1], "float32", ellipsis_mask=8) + _test_stridedslice([3, 4, 3], [1, 0], [4, 2], [2, 1], "float32", ellipsis_mask=2) + _test_stridedslice([3, 4, 5, 3], [1, 0], [4, 2], [2, 1], "float32", ellipsis_mask=2) + _test_stridedslice([3, 4, 5, 3], [1, 0, 1], [4, 2, 2], [2, 1, 1], "float32", ellipsis_mask=2) + _test_stridedslice([3, 4, 3], [1, 1, 0], [4, 4, 2], [2, 1, 1], "float32", new_axis_mask=5) _test_stridedslice( - (3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=4 + [3, 4, 3], [1, 1, 1], [4, 4, 1], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=4 ) _test_stridedslice( - (6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=5 + [6, 4, 5], [1, 1, 1], [6, 3, 4], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=5 ) _test_stridedslice( - (3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=4, new_axis_mask=2 + [3, 4, 3], [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=4, new_axis_mask=2 ) _test_stridedslice( - (3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=3 + [3, 4, 3], [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=3 ) _test_stridedslice( - (3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=3 + [3, 4, 3], [1, 1, 0], [4, 4, 1], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=3 ) _test_stridedslice( - (3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=2 + [3, 4, 3], [1, 1, 2], [4, 4, 3], [2, 1, 1], "float32", ellipsis_mask=2, new_axis_mask=2 ) _test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], "float32", shrink_axis_mask=2) _test_stridedslice( - (3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=2, new_axis_mask=2 + [3, 4, 3], [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=2, new_axis_mask=2 ) _test_stridedslice( - (3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=1, new_axis_mask=2 + [3, 4, 3], [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=1, new_axis_mask=2 ) _test_stridedslice( - (3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=2, new_axis_mask=1 + [3, 4, 3], [1, 1, 0], [4, 4, 3], [2, 1, 1], "float32", shrink_axis_mask=2, new_axis_mask=1 ) _test_stridedslice( - (3, 4, 5, 4, 5, 6), [0, 0], [2, 3], [1, 1], "float32", shrink_axis_mask=5, new_axis_mask=1 + [3, 4, 5, 4, 5, 6], [0, 0], [2, 3], [1, 1], "float32", shrink_axis_mask=5, new_axis_mask=1 ) _test_stridedslice( - (3, 4, 5, 4, 5, 6), + [3, 4, 5, 4, 5, 6], [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1], @@ -1901,7 +1910,7 @@ def test_forward_stridedslice(): end_mask=8, ) _test_stridedslice( - (3, 4, 5, 4, 5, 6), + [3, 4, 5, 4, 5, 6], [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1], @@ -1913,7 +1922,7 @@ def test_forward_stridedslice(): end_mask=5, ) _test_stridedslice( - (3, 4, 5, 4, 5, 6), + [3, 4, 5, 4, 5, 6], [0, 0, 1, 2, 1], [2, 3, 4, 5, 3], [1, 1, 2, 2, 1], @@ -1925,7 +1934,7 @@ def test_forward_stridedslice(): end_mask=5, ) _test_stridedslice( - (3, 4, 5, 4, 5, 6), + [3, 4, 5, 4, 5, 6], [1, 2, 0, -3], [4, 5, 3, 3], [2, 2, 1, 1], @@ -1937,7 +1946,7 @@ def test_forward_stridedslice(): end_mask=8, ) _test_stridedslice( - (1, 13, 13, 3, 2), + [1, 13, 13, 3, 2], [0, 0], [1, 1], [1, -1],