From 30c0e20c31c87128df9ab511f8bc54e75e53fbc3 Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Fri, 27 Nov 2020 13:19:23 +0000 Subject: [PATCH] [TFLite] pack operation extedned with const args pack operation now accepts constant arguments --- python/tvm/relay/frontend/tflite.py | 8 +++--- tests/python/frontend/tflite/test_forward.py | 26 ++++++++++++++------ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 623aeee358a6..57ae91307a4c 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2501,9 +2501,6 @@ def convert_pack(self, op): raise ImportError("The tflite package must be installed") input_tensors = self.get_input_tensors(op) - assert len(input_tensors) >= 1, "input tensors should greater than 1" - in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors] - output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" @@ -2512,8 +2509,11 @@ def convert_pack(self, op): pack_options = PackOptions() pack_options.Init(op_options.Bytes, op_options.Pos) pack_axis = pack_options.Axis() + pack_values_count = pack_options.ValuesCount() + assert len(input_tensors) == pack_values_count, "Discordance in input values count" - in_exprs_reshaped = [_op.expand_dims(i, axis=pack_axis, num_newaxis=1) for i in in_exprs] + in_exprs = [self.get_tensor_expr(_) for _ in input_tensors] + in_exprs_reshaped = [_op.expand_dims(_, axis=pack_axis, num_newaxis=1) for _ in in_exprs] out = _op.concatenate(in_exprs_reshaped, pack_axis) return out diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index b7f3b91f4243..5cfcee82e6b3 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2673,27 +2673,36 @@ def test_forward_one_hot(): # ---- -def _test_pack(data, axis): +def _test_pack(data, is_var, axis): """ One iteration of pack """ assert len(data) >= 1 + assert len(data) == len(is_var) with tf.Graph().as_default(): in_data = [ - array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx)) - for idx, tensor in enumerate(data) + array_ops.placeholder(shape=d.shape, dtype=d.dtype, name="in_" + str(idx)) + if is_var[idx] + else constant_op.constant( + d, shape=d.shape, dtype=d.dtype, name="in_constant_" + str(idx) + ) + for idx, d in enumerate(data) ] - out = array_ops.pack(in_data, axis=axis) - name = ["in_{}:0".format(idx) for idx in range(len(data))] - compare_tflite_with_tvm(data, name, in_data, [out]) + out = array_ops.pack(in_data, axis=axis) + name = [_.name for _ in in_data] + compare_tflite_with_tvm(data, name, in_data, [out], experimental_new_converter=True) def test_forward_pack(): """ Pack """ - _test_pack([np.arange(6).reshape((1, 2, 1, 3)), np.arange(6).reshape((1, 2, 1, 3))], 1) + _test_pack([np.int32(1), np.int32(5)], [False, False], 0) + _test_pack([np.array([1, 4]), np.array([2, 5]), np.array([3, 6])], [True, False, False], 0) + _test_pack( + [np.arange(6).reshape((1, 2, 1, 3)), np.arange(6).reshape((1, 2, 1, 3))], [True, True], 1 + ) - _test_pack([np.arange(6).reshape((3, 2)), np.arange(6).reshape((3, 2))], 1) + _test_pack([np.arange(6).reshape((3, 2)), np.arange(6).reshape((3, 2))], [True, True], 1) _test_pack( [ @@ -2701,6 +2710,7 @@ def test_forward_pack(): np.arange(6).reshape((2, 1, 1, 3)), np.arange(6).reshape((2, 1, 1, 3)), ], + [True, True, True], 1, )