From 4fc444c936e46b88a85ac461bf730b849c9e9a95 Mon Sep 17 00:00:00 2001 From: tkclimb Date: Tue, 3 Dec 2019 15:21:50 +0900 Subject: [PATCH 1/5] Add Expand to onnx.py --- python/tvm/relay/frontend/onnx.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 07e69564dde9..7e05e5905d00 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1080,6 +1080,18 @@ class Or(Elemwise): def _impl_v7(cls, inputs, attr, params): return _op.logical_or(inputs[0], inputs[1]) +class Expand(OnnxOpConverter): + """ Operator converter for Expand. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + assert len(inputs) == 2, "ONNX Expand must have 2 inputs." + assert isinstance(inputs[1], tvm.relay.expr.Var), "2nd operand of ONNX Expand must be a constant." + shape_data = params[inputs[1].name_hint] + shape = tuple(shape_data.asnumpy()) + return _op.broadcast_to(inputs[0], shape=shape) + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1187,6 +1199,7 @@ def _get_convert_map(opset): # defs/tensor 'Cast': Cast.get_converter(opset), 'Reshape': Reshape.get_converter(opset), + 'Expand': Expand.get_converter(opset), 'Concat': Concat.get_converter(opset), 'Split': Split.get_converter(opset), 'Slice': Slice.get_converter(opset), From 15a556d985299603342ab05c95f1f6f11a426a39 Mon Sep 17 00:00:00 2001 From: tkclimb Date: Mon, 9 Dec 2019 17:28:33 +0900 Subject: [PATCH 2/5] add test function for expand --- tests/python/frontend/onnx/test_forward.py | 31 ++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index e074bac90f2a..64b01199c9c6 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -142,6 +142,37 @@ def test_reshape(): tvm.testing.assert_allclose(ref_shape, tvm_out.shape) +def test_expand(): + in_shape = (1, 3, 1, 4) + shape = (4, 3, 3, 4) + ref_shape = (4, 3, 3, 4) + + shape_array = np.array(shape) + ref_node = onnx.helper.make_node('Constant', + inputs=[], + outputs=['shape'], + value=onnx.helper.make_tensor(name = 'const_tensor', + data_type = onnx.TensorProto.INT32, + dims = shape_array.shape, + vals = shape_array.flatten().astype(int))) + expand_node = helper.make_node("Expand", ["in", "shape"], ["out"]) + + graph = helper.make_graph([ref_node, expand_node], + "expand_node", + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(ref_shape))]) + + model = helper.make_model(graph, producer_name='expand_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=in_shape).astype('int32') + tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') + + tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + + def verify_depth_to_space(inshape, outshape, mode, blockSize): node = onnx.helper.make_node('DepthToSpace', inputs=['x'], From 239cfd8e1123f40e5bd40836f5ff32881c9e8ee9 Mon Sep 17 00:00:00 2001 From: tkclimb Date: Mon, 16 Dec 2019 16:00:59 +0900 Subject: [PATCH 3/5] Fix a onnx frontend test --- python/tvm/relay/frontend/onnx.py | 4 +- tests/python/frontend/onnx/test_forward.py | 57 ++++++++++++---------- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7e05e5905d00..8df722f9ebfa 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1085,9 +1085,7 @@ class Expand(OnnxOpConverter): """ @classmethod - def _impl_v1(cls, inputs, attr, params): - assert len(inputs) == 2, "ONNX Expand must have 2 inputs." - assert isinstance(inputs[1], tvm.relay.expr.Var), "2nd operand of ONNX Expand must be a constant." + def _impl_v8(cls, inputs, attr, params): shape_data = params[inputs[1].name_hint] shape = tuple(shape_data.asnumpy()) return _op.broadcast_to(inputs[0], shape=shape) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 64b01199c9c6..d26aed68a86e 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -143,34 +143,40 @@ def test_reshape(): def test_expand(): - in_shape = (1, 3, 1, 4) - shape = (4, 3, 3, 4) - ref_shape = (4, 3, 3, 4) - shape_array = np.array(shape) - ref_node = onnx.helper.make_node('Constant', - inputs=[], - outputs=['shape'], - value=onnx.helper.make_tensor(name = 'const_tensor', - data_type = onnx.TensorProto.INT32, - dims = shape_array.shape, - vals = shape_array.flatten().astype(int))) - expand_node = helper.make_node("Expand", ["in", "shape"], ["out"]) - - graph = helper.make_graph([ref_node, expand_node], - "expand_node", - inputs = [helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape))], - outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(ref_shape))]) - - model = helper.make_model(graph, producer_name='expand_test') + def _test(name, in_shap, shape, ref_shape): + shape_array = np.array(shape) + ref_node = onnx.helper.make_node('Constant', + inputs=[], + outputs=['shape'], + value=onnx.helper.make_tensor(name = 'const_tensor', + data_type = onnx.TensorProto.INT32, + dims = shape_array.shape, + vals = shape_array.flatten().astype(int))) + expand_node = helper.make_node("Expand", ["in", "shape"], ["out"]) + + graph = helper.make_graph([ref_node, expand_node], + "expand_node", + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(ref_shape))]) + + model = helper.make_model(graph, producer_name=name) - for target, ctx in ctx_list(): - x = np.random.uniform(size=in_shape).astype('int32') - tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') + for target, ctx in ctx_list(): + x = np.random.uniform(size=in_shape).astype('int32') + tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') - tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + + in_shape = (1, 3, 1, 4) + shape = (4, 3, 3, 4) + _test('expand_with_dim_unchanged_test', in_shape, shape, shape) + + in_shape = (3, 2, 1) + shape = (4, 3, 2, 4) + _test('expand_with_dim_changed_test', in_shape, shape, shape) def verify_depth_to_space(inshape, outshape, mode, blockSize): @@ -1741,6 +1747,7 @@ def test_or(): test_flatten() test_reshape() test_shape() + test_expand() test_power() test_squeeze() test_unsqueeze() From 2e49fc1cb49dd6d25720557423f41a43468a3f3e Mon Sep 17 00:00:00 2001 From: tkclimb Date: Tue, 17 Dec 2019 15:58:33 +0900 Subject: [PATCH 4/5] Add tests for the value itself instead of shape only on test_expand --- python/tvm/relay/frontend/onnx.py | 50 +++++++- tests/python/frontend/onnx/test_forward.py | 141 +++++++++++---------- 2 files changed, 115 insertions(+), 76 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 8df722f9ebfa..40112a7169b4 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1083,13 +1083,49 @@ def _impl_v7(cls, inputs, attr, params): class Expand(OnnxOpConverter): """ Operator converter for Expand. """ - @classmethod def _impl_v8(cls, inputs, attr, params): - shape_data = params[inputs[1].name_hint] - shape = tuple(shape_data.asnumpy()) - return _op.broadcast_to(inputs[0], shape=shape) + in_shape = np.array(infer_shape(inputs[0])).astype('int32') + if get_name(inputs[1]) in params: + shape = params[inputs[1].name_hint].asnumpy().astype('int32') + else: + shape = infer_value_simulated(inputs[1], params).asnumpy().astype('int32') + + # Currently 'op.broadcast_to' expect the rank of the given 'shape' + # (the 2nd input) is always higher than that of the given 'input' (the 1st input) + # However, ONNX Expand supports multi-directional broadcasting, which allows + # above pattern and also some extent of 'shape' can be smaller than the corresponding + # extent of 'input'. In this case, the extent of 'shape' must be 1. + # https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md + # In above cases, we cannot directorly apply 'op.broadcast_to' instead of 'expand' + # so, here we solved this problem by expanding the given 'shape' itself. + def expand_shape(in_shape, shape): + """ A function expands the shape when the rank is lower than that of the given + intput. Also it replaces the extent of the shape with the corresponding extent + of the intput when it is 1. + """ + + # here we flip the shapes because this can be more simply written + # when the innermost dimension is located at the index 0. + in_shape = np.flip(in_shape, axis=0) + shape = np.flip(shape, axis=0) + + if in_shape.size < shape.size: + for i in range(shape.size): + if i < in_shape.size and in_shape[i] > shape[i]: + shape[i] = in_shape[i] + else: + for i in range(in_shape.size): + if i >= shape.size: + np.append(shape, in_shape[i]) + elif shape[i] == 1: + shape[i] = in_shape[i] + + new_shape = np.flip(shape, axis=0) + return new_shape + shape = expand_shape(in_shape, shape) + return _op.broadcast_to(inputs[0], shape=tuple(shape)) # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1239,7 +1275,7 @@ def __init__(self, shape, dtype): self._renames = {} self._num_input = 0 self._num_param = 0 - self._shape = shape if shape else {} + self.shape_list = shape if shape else {} self._dtype = dtype def from_onnx(self, graph, opset): @@ -1289,8 +1325,8 @@ def from_onnx(self, graph, opset): dtype=self._params[i_name].dtype) else: self._num_input += 1 - if i_name in self._shape: - tshape = self._shape[i_name] + if i_name in self.shape_list: + tshape = self.shape_list[i_name] else: raise ValueError("Must provide an input shape for `{0}`.".format(i_name)) if isinstance(self._dtype, dict): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d26aed68a86e..6bf23c52c0e4 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -144,39 +144,42 @@ def test_reshape(): def test_expand(): - def _test(name, in_shap, shape, ref_shape): + def _test_expand(name, data, shape, ref_data): shape_array = np.array(shape) - ref_node = onnx.helper.make_node('Constant', + shape_node = onnx.helper.make_node('Constant', inputs=[], outputs=['shape'], value=onnx.helper.make_tensor(name = 'const_tensor', - data_type = onnx.TensorProto.INT32, - dims = shape_array.shape, - vals = shape_array.flatten().astype(int))) + data_type = onnx.TensorProto.INT32, + dims = shape_array.shape, + vals = shape_array.flatten().astype('int32'))) expand_node = helper.make_node("Expand", ["in", "shape"], ["out"]) - graph = helper.make_graph([ref_node, expand_node], - "expand_node", + graph = helper.make_graph([shape_node, expand_node], + "expand_test", inputs = [helper.make_tensor_value_info("in", - TensorProto.FLOAT, list(in_shape))], + TensorProto.FLOAT, list(data.shape))], outputs = [helper.make_tensor_value_info("out", - TensorProto.FLOAT, list(ref_shape))]) + TensorProto.FLOAT, list(ref_data.shape))]) model = helper.make_model(graph, producer_name=name) for target, ctx in ctx_list(): - x = np.random.uniform(size=in_shape).astype('int32') - tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') + tvm_out = get_tvm_output(model, data, target, ctx, ref_data.shape, 'float32') - tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + tvm.testing.assert_allclose(ref_data, tvm_out) - in_shape = (1, 3, 1, 4) - shape = (4, 3, 3, 4) - _test('expand_with_dim_unchanged_test', in_shape, shape, shape) + in_shape = (3, 1) + shape = (3, 4) + data = np.random.uniform(size=in_shape).astype(np.float32) + ref_data = np.tile(data, 4) + _test_expand('expand_with_dim_unchanged_test', data, shape, ref_data) - in_shape = (3, 2, 1) - shape = (4, 3, 2, 4) - _test('expand_with_dim_changed_test', in_shape, shape, shape) + in_shape = (3, 1) + shape = (2, 1, 6) + data = np.random.uniform(size=in_shape).astype(np.float32) + ref_data = data * np.ones(shape, dtype=np.float32) + _test_expand('expand_with_dim_changed_test', data, shape, ref_data) def verify_depth_to_space(inshape, outshape, mode, blockSize): @@ -1745,55 +1748,55 @@ def test_or(): if __name__ == '__main__': test_flatten() - test_reshape() - test_shape() + # test_reshape() + # test_shape() test_expand() - test_power() - test_squeeze() - test_unsqueeze() - test_slice() - test_floor() - test_ceil() - test_clip() - test_onehot() - test_matmul() - test_batch_matmul() - test_gather() - test_lrn() - test_instance_norm() - test_upsample() - test_forward_min() - test_forward_max() - test_forward_mean() - test_forward_hardsigmoid() - test_forward_arg_min_max() - test_softmax() - test_constantofshape() - test_reduce_max() - test_reduce_min() - test_reduce_sum() - test_reduce_mean() - test_pad() - test_split() - test_binary_ops() - test_single_ops() - test_leaky_relu() - test_elu() - test_selu() - test_ThresholdedRelu() - test_ScaledTanh() - test_ParametricSoftplus() - test_Scale() - test_LogSoftmax() - test_resnet() - test_inception() - test_densenet() - test_sign() - test_not() - test_and() - test_tile() - test_erf() - test_where() - test_or() - test_depth_to_space() - test_space_to_depth() + # test_power() + # test_squeeze() + # test_unsqueeze() + # test_slice() + # test_floor() + # test_ceil() + # test_clip() + # test_onehot() + # test_matmul() + # test_batch_matmul() + # test_gather() + # test_lrn() + # test_instance_norm() + # test_upsample() + # test_forward_min() + # test_forward_max() + # test_forward_mean() + # test_forward_hardsigmoid() + # test_forward_arg_min_max() + # test_softmax() + # test_constantofshape() + # test_reduce_max() + # test_reduce_min() + # test_reduce_sum() + # test_reduce_mean() + # test_pad() + # test_split() + # test_binary_ops() + # test_single_ops() + # test_leaky_relu() + # test_elu() + # test_selu() + # test_ThresholdedRelu() + # test_ScaledTanh() + # test_ParametricSoftplus() + # test_Scale() + # test_LogSoftmax() + # test_resnet() + # test_inception() + # test_densenet() + # test_sign() + # test_not() + # test_and() + # test_tile() + # test_erf() + # test_where() + # test_or() + # test_depth_to_space() + # test_space_to_depth() From 453f3122827b986655d1b73003d5461a897c294b Mon Sep 17 00:00:00 2001 From: tkclimb Date: Wed, 18 Dec 2019 14:35:02 +0900 Subject: [PATCH 5/5] Cleaned up some unnecessary modifications. --- python/tvm/relay/frontend/onnx.py | 6 +- tests/python/frontend/onnx/test_forward.py | 102 ++++++++++----------- 2 files changed, 54 insertions(+), 54 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 40112a7169b4..2dec7e2e1ede 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1275,7 +1275,7 @@ def __init__(self, shape, dtype): self._renames = {} self._num_input = 0 self._num_param = 0 - self.shape_list = shape if shape else {} + self._shape = shape if shape else {} self._dtype = dtype def from_onnx(self, graph, opset): @@ -1325,8 +1325,8 @@ def from_onnx(self, graph, opset): dtype=self._params[i_name].dtype) else: self._num_input += 1 - if i_name in self.shape_list: - tshape = self.shape_list[i_name] + if i_name in self._shape: + tshape = self._shape[i_name] else: raise ValueError("Must provide an input shape for `{0}`.".format(i_name)) if isinstance(self._dtype, dict): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6bf23c52c0e4..f7029bb58079 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1748,55 +1748,55 @@ def test_or(): if __name__ == '__main__': test_flatten() - # test_reshape() - # test_shape() + test_reshape() + test_shape() test_expand() - # test_power() - # test_squeeze() - # test_unsqueeze() - # test_slice() - # test_floor() - # test_ceil() - # test_clip() - # test_onehot() - # test_matmul() - # test_batch_matmul() - # test_gather() - # test_lrn() - # test_instance_norm() - # test_upsample() - # test_forward_min() - # test_forward_max() - # test_forward_mean() - # test_forward_hardsigmoid() - # test_forward_arg_min_max() - # test_softmax() - # test_constantofshape() - # test_reduce_max() - # test_reduce_min() - # test_reduce_sum() - # test_reduce_mean() - # test_pad() - # test_split() - # test_binary_ops() - # test_single_ops() - # test_leaky_relu() - # test_elu() - # test_selu() - # test_ThresholdedRelu() - # test_ScaledTanh() - # test_ParametricSoftplus() - # test_Scale() - # test_LogSoftmax() - # test_resnet() - # test_inception() - # test_densenet() - # test_sign() - # test_not() - # test_and() - # test_tile() - # test_erf() - # test_where() - # test_or() - # test_depth_to_space() - # test_space_to_depth() + test_power() + test_squeeze() + test_unsqueeze() + test_slice() + test_floor() + test_ceil() + test_clip() + test_onehot() + test_matmul() + test_batch_matmul() + test_gather() + test_lrn() + test_instance_norm() + test_upsample() + test_forward_min() + test_forward_max() + test_forward_mean() + test_forward_hardsigmoid() + test_forward_arg_min_max() + test_softmax() + test_constantofshape() + test_reduce_max() + test_reduce_min() + test_reduce_sum() + test_reduce_mean() + test_pad() + test_split() + test_binary_ops() + test_single_ops() + test_leaky_relu() + test_elu() + test_selu() + test_ThresholdedRelu() + test_ScaledTanh() + test_ParametricSoftplus() + test_Scale() + test_LogSoftmax() + test_resnet() + test_inception() + test_densenet() + test_sign() + test_not() + test_and() + test_tile() + test_erf() + test_where() + test_or() + test_depth_to_space() + test_space_to_depth()