diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 59082c961475..dd81b063252d 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3177,3 +3177,44 @@ def convert_reshape_like(node, **kwargs): ] return nodes + + +@mx_op.register("gather_nd") +def convert_gather_nd(node, **kwargs): + """Map MXNet's gather_ND operator attributes to onnx's operator. + """ + from onnx.helper import make_node + name, input_nodes, _ = get_inputs(node, kwargs) + + data = input_nodes[0] + indices = input_nodes[1] + + # Onnx Transpose operator takes perm as a parameter, so we need to 'pad' + # the input to a known dim (10 here) + perm = [9] + [i for i in range(1, 9)] + [0] + + nodes = [ + create_tensor([0], name+'_0', kwargs['initializer']), + create_tensor([1], name+'_1', kwargs['initializer']), + create_tensor([10], name+'_10', kwargs['initializer']), + # Generate 10-d filter + make_node('Shape', [indices], [name+'_indices_shape']), + make_node('Shape', [name+'_indices_shape'], [name+'_indices_dim']), + make_node('Sub', [name+'_10', name+'_indices_dim'], [name+'_sub0_out']), + make_node('Concat', [name+'_0', name+'_sub0_out'], [name+'_concat0_out'], axis=0), + make_node('Pad', [name+'_indices_shape', name+'_concat0_out', name+'_1'], [name+'_shape_10_dim']), + make_node('Reshape', [indices, name+'_shape_10_dim'], [name+'_indices_10_dim']), + make_node('Transpose', [name+'_indices_10_dim'], [name+'_transpose0_output'], perm=perm), + # Reshape filter to acutall dim for GatherND computation + make_node('Slice', [name+'_indices_shape', name+'_0', name+'_1'], + [name+'_slice0_out']), + make_node('Slice', [name+'_indices_shape', name+'_1', name+'_indices_dim'], + [name+'_slice1_out']), + make_node('Concat', [name+'_slice1_out', name+'_slice0_out'], [name+'_concat1_out'], axis=0), + make_node('Reshape', [name+'_transpose0_output', name+'_concat1_out'], [name+'_reshape0_out']), + # Cast data type for indicies + make_node('Cast', [name+'_reshape0_out'], [name+'_cast0_out'], to=int(onnx.TensorProto.INT64)), + make_node('GatherND', [data, name+'_cast0_out'], [name], name=name) + ] + + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 1c4629d6f871..9b0921be7f8b 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -463,3 +463,17 @@ def test_onnx_export_reshape_like(tmp_path, dtype): M4 = def_model('reshape_like', lhs_begin=0, lhs_end=None, rhs_begin=1, rhs_end=None) op_export_test('reshape_like4', M4, [x, y], tmp_path) + +@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) +def test_onnx_export_gather_nd(tmp_path, dtype): + # y[0] == dim(x) + x1 = mx.random.uniform(-100, 100, (4, 5, 6, 7)).astype(dtype) + y1 = mx.random.randint(-4, 4, (4, 4, 4)).astype(dtype) + M1 = def_model('gather_nd') + op_export_test('gather_nd1', M1, [x1, y1], tmp_path) + # y[0] < dim(x) + x2 = mx.random.uniform(-100, 100, (4, 5, 6, 7)).astype(dtype) + y2 = mx.random.randint(-4, 4, (2,3,4)).astype(dtype) + M2 = def_model('gather_nd') + op_export_test('gather_nd2', M2, [x2, y2], tmp_path) +