From 8ae1a5dafe26dcbadfa3917e416ccfba0263e043 Mon Sep 17 00:00:00 2001 From: masa Date: Fri, 11 Sep 2020 05:23:49 +0900 Subject: [PATCH 1/3] support onnx GatherElements --- python/tvm/relay/frontend/onnx.py | 12 +++++++ tests/python/frontend/onnx/test_forward.py | 40 ++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 8269ebfb198d..f852e0fdeb2b 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1050,6 +1050,17 @@ def _impl_v1(cls, inputs, attr, params): return AttrCvt("take", extras={"axis": axis})(inputs, {}) +class GatherElements(OnnxOpConverter): + """ Operator converter for GatherElements. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + data = inputs[0] + indices = inputs[1] + axis = attr.get('axis', 0) + return _op.gather(data, axis, indices) + + class GatherND(OnnxOpConverter): """Operator converter for GatherND.""" @@ -2014,6 +2025,7 @@ def _get_convert_map(opset): "DepthToSpace": DepthToSpace.get_converter(opset), "SpaceToDepth": SpaceToDepth.get_converter(opset), "Gather": Gather.get_converter(opset), + "GatherElements": GatherElements.get_converter(opset), "GatherND": GatherND.get_converter(opset), "Scatter": Scatter.get_converter(opset), "ScatterElements": Scatter.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6f63cbf0413a..3c69b8b95812 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -435,6 +435,45 @@ def test_gather(): verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, "float32") +def verify_gatherelements(in_shape, indices, axis): + x = np.random.uniform(size=in_shape).astype("float32") + indices = np.array(indices, dtype="int32") + print(x.shape) + print(indices.shape) + + y = helper.make_node("GatherElements", ['data', 'indices'], ['output'], axis=axis) + + graph = helper.make_graph([y], + 'gather_elements_test', + inputs=[helper.make_tensor_value_info("data", + TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", + TensorProto.INT32, list(indices.shape))], + outputs=[helper.make_tensor_value_info("output", + TensorProto.FLOAT, list(in_shape))]) + model = helper.make_model(graph, producer_name='gather_elements_test') + onnx_out = get_onnxruntime_output(model, [x, indices]) + + for target, ctx in tvm.testing.enabled_targets(): + tvm_out = get_tvm_output( + model, [x, indices], target, ctx, onnx_out[0].shape) + tvm.testing.assert_allclose(onnx_out[0], tvm_out) + + +@tvm.testing.uses_gpu +def test_gatherelements(): + verify_gatherelements((4,), [3, 0, 2, 1], 0) + verify_gatherelements((2, 2), [[1, 0], [0, 1]], 0) + verify_gatherelements((2, 2), [[0, 0], [1, 0]], 1) + verify_gatherelements((2, 2), [[1, 0], [0, 1]], 1) + + indices = [[[1, 0, 0], [1, 0, 1], [0, 1, 1]], + [[1, 1, 1], [1, 2, 1], [1, 0, 1]], + [[1, 2, 1], [1, 2, 1], [1, 2, 1]]] + + verify_gatherelements((3, 3, 3), indices, 2) + + def verify_scatter(in_shape, indices, axis): x = np.random.uniform(size=in_shape).astype("float32") indices = np.array(indices, dtype="int32") @@ -3384,6 +3423,7 @@ def verify_roi_align( test_matmul() test_batch_matmul() test_gather() + test_gatherelements() test_gather_nd() test_scatter() test_lrn() From 0b5e22fa5aaa1ff1e4e3cac9718734698354c9d3 Mon Sep 17 00:00:00 2001 From: masa Date: Fri, 11 Sep 2020 12:02:12 +0900 Subject: [PATCH 2/3] remove print --- tests/python/frontend/onnx/test_forward.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 3c69b8b95812..342f038d7873 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -438,11 +438,8 @@ def test_gather(): def verify_gatherelements(in_shape, indices, axis): x = np.random.uniform(size=in_shape).astype("float32") indices = np.array(indices, dtype="int32") - print(x.shape) - print(indices.shape) y = helper.make_node("GatherElements", ['data', 'indices'], ['output'], axis=axis) - graph = helper.make_graph([y], 'gather_elements_test', inputs=[helper.make_tensor_value_info("data", From 37fb5c4553ae427aa559f892e1bf3e6b1c84217c Mon Sep 17 00:00:00 2001 From: masa Date: Sat, 12 Sep 2020 05:30:57 +0900 Subject: [PATCH 3/3] run black --- python/tvm/relay/frontend/onnx.py | 6 ++-- tests/python/frontend/onnx/test_forward.py | 32 ++++++++++++---------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f852e0fdeb2b..5f31724ec3ea 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1051,13 +1051,13 @@ def _impl_v1(cls, inputs, attr, params): class GatherElements(OnnxOpConverter): - """ Operator converter for GatherElements. - """ + """Operator converter for GatherElements.""" + @classmethod def _impl_v1(cls, inputs, attr, params): data = inputs[0] indices = inputs[1] - axis = attr.get('axis', 0) + axis = attr.get("axis", 0) return _op.gather(data, axis, indices) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 342f038d7873..894a6b6d40ce 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -439,21 +439,21 @@ def verify_gatherelements(in_shape, indices, axis): x = np.random.uniform(size=in_shape).astype("float32") indices = np.array(indices, dtype="int32") - y = helper.make_node("GatherElements", ['data', 'indices'], ['output'], axis=axis) - graph = helper.make_graph([y], - 'gather_elements_test', - inputs=[helper.make_tensor_value_info("data", - TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("indices", - TensorProto.INT32, list(indices.shape))], - outputs=[helper.make_tensor_value_info("output", - TensorProto.FLOAT, list(in_shape))]) - model = helper.make_model(graph, producer_name='gather_elements_test') + y = helper.make_node("GatherElements", ["data", "indices"], ["output"], axis=axis) + graph = helper.make_graph( + [y], + "gather_elements_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], + ) + model = helper.make_model(graph, producer_name="gather_elements_test") onnx_out = get_onnxruntime_output(model, [x, indices]) for target, ctx in tvm.testing.enabled_targets(): - tvm_out = get_tvm_output( - model, [x, indices], target, ctx, onnx_out[0].shape) + tvm_out = get_tvm_output(model, [x, indices], target, ctx, onnx_out[0].shape) tvm.testing.assert_allclose(onnx_out[0], tvm_out) @@ -464,9 +464,11 @@ def test_gatherelements(): verify_gatherelements((2, 2), [[0, 0], [1, 0]], 1) verify_gatherelements((2, 2), [[1, 0], [0, 1]], 1) - indices = [[[1, 0, 0], [1, 0, 1], [0, 1, 1]], - [[1, 1, 1], [1, 2, 1], [1, 0, 1]], - [[1, 2, 1], [1, 2, 1], [1, 2, 1]]] + indices = [ + [[1, 0, 0], [1, 0, 1], [0, 1, 1]], + [[1, 1, 1], [1, 2, 1], [1, 0, 1]], + [[1, 2, 1], [1, 2, 1], [1, 2, 1]], + ] verify_gatherelements((3, 3, 3), indices, 2)