From f26560163a1918c5951a4f8c8326de80305e3d64 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Tue, 31 Jul 2018 15:59:25 +0530 Subject: [PATCH] Onnx Gather operator added --- nnvm/python/nnvm/frontend/onnx.py | 16 ++++++++-- .../python/frontend/onnx/test_forward.py | 29 +++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py index ee7fcc3c57fd..cfef11d6a106 100644 --- a/nnvm/python/nnvm/frontend/onnx.py +++ b/nnvm/python/nnvm/frontend/onnx.py @@ -446,7 +446,6 @@ def _impl_v1(cls, inputs, attr, params): inputs[0] = _sym.expand_dims(inputs[0], axis=axes, num_newaxis=1) return inputs[0] - class Slice(OnnxOpConverter): """ Operator converter for Slice. """ @@ -487,6 +486,19 @@ def _impl_v1(cls, inputs, attr, params): 'ends': 'end'}, ignores=['axes'])(inputs, attr) +class Gather(OnnxOpConverter): + """ Operator converter for Gather. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + axis = attr['axis'] + indices = np.array(attr['indices'], dtype='int32') + name = 'gather_indices' + gather_indices = _sym.Variable(name=name, init=indices) + params[name] = indices + return _sym.take(inputs[0], gather_indices, axis=axis) + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -593,7 +605,7 @@ def _get_convert_map(opset): 'Split': AttrCvt('split', {'split': 'indices_or_sections'}), 'Slice': Slice.get_converter(opset), 'Transpose': AttrCvt('transpose', {'perm': 'axes'}), - # 'Gather' + 'Gather': Gather.get_converter(opset), 'Squeeze': Renamer('squeeze'), 'Unsqueeze': Unsqueeze.get_converter(opset), 'Pad': Pad.get_converter(opset), diff --git a/nnvm/tests/python/frontend/onnx/test_forward.py b/nnvm/tests/python/frontend/onnx/test_forward.py index 56f1e006265a..bddf4a87009c 100644 --- a/nnvm/tests/python/frontend/onnx/test_forward.py +++ b/nnvm/tests/python/frontend/onnx/test_forward.py @@ -189,6 +189,34 @@ def test_unsqueeze(): np.testing.assert_allclose(out_shape, tvm_out.shape) +def verify_gather(in_shape, indices, axis=0): + indices_src = np.array(indices, dtype="int32") + + x = np.random.uniform(size=in_shape) + out_np = np.take(x, indices_src, axis=axis) + + y = helper.make_node("Gather", ['in'], ['out'], indices=indices, axis=axis) + + graph = helper.make_graph([y], + 'gather_test', + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(out_np.shape))]) + + model = helper.make_model(graph, producer_name='gather_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, x, target, ctx, out_np.shape, 'float32') + + np.testing.assert_allclose(out_np, tvm_out) + +def test_gather(): + verify_gather((4,), [1]) + verify_gather((4,), [0, 1, 2, 3]) + verify_gather((4, 2), [1], 1) + verify_gather((4, 3, 5, 6), [2, 1, 0, 0], -2) + def _test_slice_iteration(indata, outdata, starts, ends, axes=None): if axes: y = helper.make_node("Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends) @@ -299,3 +327,4 @@ def test_matmul(): test_ceil() test_clip() test_matmul() + test_gather()