From 55e90760acc96bd6d05f024a099c8d1d91cb134d Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 19 Jan 2021 17:34:01 -0800 Subject: [PATCH 1/3] gather_nd --- .../contrib/onnx/mx2onnx/_op_translations.py | 43 +++++++++++++++++++ tests/python-pytest/onnx/test_operators.py | 11 +++++ 2 files changed, 54 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 59082c961475..8b858cc164d4 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3177,3 +3177,46 @@ def convert_reshape_like(node, **kwargs): ] return nodes + + +@mx_op.register("gather_nd") +def convert_gather_nd(node, **kwargs): + """Map MXNet's gather_ND operator attributes to onnx's operator. + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + data = input_nodes[0] + indices = input_nodes[1] + + # Onnx Transpose operator takes perm as a parameter, so we need to 'pad' + # the input to a known dim (10 here) + perm = [9] + [i for i in range(1, 9)] + [0] + + nodes = [ + create_tensor([0], name+'_0', kwargs['initializer']), + create_tensor([1], name+'_1', kwargs['initializer']), + create_tensor([10], name+'_10', kwargs['initializer']), + # Generate 10-d filter + make_node('Shape', [indices], [name+'_indices_shape']), + make_node('Shape', [name+'_indices_shape'], [name+'_indices_dim']), + make_node('Sub', [name+'_10', name+'_indices_dim'], [name+'_sub0_out']), + make_node('Concat', [name+'_0', name+'_sub0_out'], [name+'_concat0_out'], axis=0), + make_node('Pad', [name+'_indices_shape', name+'_concat0_out', name+'_1'], [name+'_shape_10_dim']), + make_node('Reshape', [indices, name+'_shape_10_dim'], [name+'_indices_10_dim']), + make_node('Transpose', [name+'_indices_10_dim'], [name+'_transpose0_output'], perm=perm), + # Reshape filter to acutall dim for GatherND computation + make_node('Sub', [name+'_indices_dim', name+'_1'], [name+'_sub1_out']), + make_node('Slice', [name+'_indices_shape', name+'_0', name+'_sub1_out'], + [name+'_slice0_out']), + make_node('Slice', [name+'_indices_shape', name+'_sub1_out', name+'_indices_dim'], + [name+'_slice1_out']), + make_node('Concat', [name+'_slice1_out', name+'_slice0_out'], [name+'_concat1_out'], axis=0), + make_node('Reshape', [name+'_transpose0_output', name+'_concat1_out'], [name+'_reshape0_out']), + # Cast data type for indicies + make_node('Cast', [name+'_reshape0_out'], [name+'_cast0_out'], to=int(onnx.TensorProto.INT64)), + make_node('GatherND', [data, name+'_cast0_out'], [name], name=name), + ] + + return nodes + diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 1c4629d6f871..e1236d6eec40 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -463,3 +463,14 @@ def test_onnx_export_reshape_like(tmp_path, dtype): M4 = def_model('reshape_like', lhs_begin=0, lhs_end=None, rhs_begin=1, rhs_end=None) op_export_test('reshape_like4', M4, [x, y], tmp_path) + +@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) +def test_onnx_export_gather_nd(tmp_path, dtype): + x1 = mx.nd.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype) + y1 = mx.nd.array([[0, 1], [1, 0]], dtype=dtype) + M1 = def_model('gather_nd') + op_export_test('gather_nd1', M1, [x1, y1], tmp_path) + x2 = mx.nd.array([[0, 1], [2, 3]], dtype=dtype) + y2 = mx.nd.array([[1, 1, 0], [0, 1, 0]], dtype=dtype) + M2 = def_model('gather_nd') + op_export_test('gather_nd2', M2, [x2, y2], tmp_path) From 63a42dca62147d879ff71de234e6e746b79cc38c Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 19 Jan 2021 19:48:23 -0800 Subject: [PATCH 2/3] fix sanity --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 8b858cc164d4..df4ddf058da5 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3184,12 +3184,12 @@ def convert_gather_nd(node, **kwargs): """Map MXNet's gather_ND operator attributes to onnx's operator. """ from onnx.helper import make_node - name, input_nodes, attrs = get_inputs(node, kwargs) + name, input_nodes, _ = get_inputs(node, kwargs) data = input_nodes[0] indices = input_nodes[1] - # Onnx Transpose operator takes perm as a parameter, so we need to 'pad' + # Onnx Transpose operator takes perm as a parameter, so we need to 'pad' # the input to a known dim (10 here) perm = [9] + [i for i in range(1, 9)] + [0] @@ -3197,7 +3197,7 @@ def convert_gather_nd(node, **kwargs): create_tensor([0], name+'_0', kwargs['initializer']), create_tensor([1], name+'_1', kwargs['initializer']), create_tensor([10], name+'_10', kwargs['initializer']), - # Generate 10-d filter + # Generate 10-d filter make_node('Shape', [indices], [name+'_indices_shape']), make_node('Shape', [name+'_indices_shape'], [name+'_indices_dim']), make_node('Sub', [name+'_10', name+'_indices_dim'], [name+'_sub0_out']), @@ -3208,9 +3208,9 @@ def convert_gather_nd(node, **kwargs): # Reshape filter to acutall dim for GatherND computation make_node('Sub', [name+'_indices_dim', name+'_1'], [name+'_sub1_out']), make_node('Slice', [name+'_indices_shape', name+'_0', name+'_sub1_out'], - [name+'_slice0_out']), + [name+'_slice0_out']), make_node('Slice', [name+'_indices_shape', name+'_sub1_out', name+'_indices_dim'], - [name+'_slice1_out']), + [name+'_slice1_out']), make_node('Concat', [name+'_slice1_out', name+'_slice0_out'], [name+'_concat1_out'], axis=0), make_node('Reshape', [name+'_transpose0_output', name+'_concat1_out'], [name+'_reshape0_out']), # Cast data type for indicies @@ -3219,4 +3219,3 @@ def convert_gather_nd(node, **kwargs): ] return nodes - From 225831b012a2d8e045a635998c18852f831b34b4 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Thu, 21 Jan 2021 15:08:43 -0800 Subject: [PATCH 3/3] update tests --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 7 +++---- tests/python-pytest/onnx/test_operators.py | 11 +++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index df4ddf058da5..dd81b063252d 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3206,16 +3206,15 @@ def convert_gather_nd(node, **kwargs): make_node('Reshape', [indices, name+'_shape_10_dim'], [name+'_indices_10_dim']), make_node('Transpose', [name+'_indices_10_dim'], [name+'_transpose0_output'], perm=perm), # Reshape filter to acutall dim for GatherND computation - make_node('Sub', [name+'_indices_dim', name+'_1'], [name+'_sub1_out']), - make_node('Slice', [name+'_indices_shape', name+'_0', name+'_sub1_out'], + make_node('Slice', [name+'_indices_shape', name+'_0', name+'_1'], [name+'_slice0_out']), - make_node('Slice', [name+'_indices_shape', name+'_sub1_out', name+'_indices_dim'], + make_node('Slice', [name+'_indices_shape', name+'_1', name+'_indices_dim'], [name+'_slice1_out']), make_node('Concat', [name+'_slice1_out', name+'_slice0_out'], [name+'_concat1_out'], axis=0), make_node('Reshape', [name+'_transpose0_output', name+'_concat1_out'], [name+'_reshape0_out']), # Cast data type for indicies make_node('Cast', [name+'_reshape0_out'], [name+'_cast0_out'], to=int(onnx.TensorProto.INT64)), - make_node('GatherND', [data, name+'_cast0_out'], [name], name=name), + make_node('GatherND', [data, name+'_cast0_out'], [name], name=name) ] return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index e1236d6eec40..9b0921be7f8b 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -466,11 +466,14 @@ def test_onnx_export_reshape_like(tmp_path, dtype): @pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) def test_onnx_export_gather_nd(tmp_path, dtype): - x1 = mx.nd.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype) - y1 = mx.nd.array([[0, 1], [1, 0]], dtype=dtype) + # y[0] == dim(x) + x1 = mx.random.uniform(-100, 100, (4, 5, 6, 7)).astype(dtype) + y1 = mx.random.randint(-4, 4, (4, 4, 4)).astype(dtype) M1 = def_model('gather_nd') op_export_test('gather_nd1', M1, [x1, y1], tmp_path) - x2 = mx.nd.array([[0, 1], [2, 3]], dtype=dtype) - y2 = mx.nd.array([[1, 1, 0], [0, 1, 0]], dtype=dtype) + # y[0] < dim(x) + x2 = mx.random.uniform(-100, 100, (4, 5, 6, 7)).astype(dtype) + y2 = mx.random.randint(-4, 4, (2,3,4)).astype(dtype) M2 = def_model('gather_nd') op_export_test('gather_nd2', M2, [x2, y2], tmp_path) +