From a8285a4b8570068e9c0cf44c4b42cd82409aec9e Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Mon, 1 Feb 2021 22:35:57 +0000 Subject: [PATCH 1/3] onnx roi align --- .../contrib/onnx/mx2onnx/_op_translations.py | 39 +++++++++++++++++++ tests/python-pytest/onnx/test_operators.py | 17 ++++++++ 2 files changed, 56 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 39723321923e..07cb8742fa15 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3555,3 +3555,42 @@ def convert_broadcast_like(node, **kwargs): ] return nodes + + +@mx_op.register('_contrib_ROIAlign') +def convert_contrib_roialign(node, **kwargs): + """Map MXNet's _contrib_ROIAlign + """ + from onnx.helper import make_node + from onnx import TensorProto + name, input_nodes, attrs = get_inputs(node, kwargs) + + pooled_size = convert_string_to_list(str(attrs.get('pooled_size'))) + spatial_scale = float(attrs.get('spatial_scale')) + sample_ratio = int(attrs.get('sample_ratio', '0')) + position_sensitive = attrs.get('position_sensitive', 'False') + aligned = attrs.get('aligned', 'False') + + ''' + if position_sensitive != 'False': + raise NotImplementedError('_contrib_ROIAlign does not currently support \ + position_sensitive!=False') + if aligned != 'False': + raise NotImplementedError('_contrib_ROIAlign does not currently support \ + aligned!=False') + ''' + create_tensor([0], name+'_0', kwargs['initializer']), + create_tensor([1], name+'_1', kwargs['initializer']), + create_tensor([5], name+'_5', kwargs['initializer']), + + nodes = [ + make_node('Slice', [input_nodes[1], name+'_1', name+'_5', name+'_1'], [name+'_rois']), + make_node('Slice', [input_nodes[1], name+'_0', name+'_1', name+'_1'], [name+'_inds__']), + make_node('Squeeze', [name+'_inds__'], [name+'_inds_'], axes=(1,)), + make_node('Cast', [name+'_inds_'], [name+'_inds'], to=int(TensorProto.INT64)), + make_node('RoiAlign', [input_nodes[0], name+'_rois', name+'_inds'], [name], + mode='avg', output_height=pooled_size[0], output_width=pooled_size[1], + sampling_ratio=sample_ratio, spatial_scale=spatial_scale) + ] + + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 204fe57f5930..6b01a5dace9e 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -644,3 +644,20 @@ def test_onnx_export_broadcast_like(tmp_path, dtype, lhs_axes, rhs_axes): op_export_test('broadcast_like1', M1, [x, y], tmp_path) M2 = def_model('broadcast_like', lhs_axes=lhs_axes, rhs_axes=rhs_axes) op_export_test('broadcast_like2', M2, [x, y], tmp_path) + + +@pytest.mark.parametrize('dtype', ['float32']) +@pytest.mark.parametrize('pooled_size', [(1, 1), (3, 3), (14, 14), (5, 7)]) +@pytest.mark.parametrize('spatial_scale', [1, 0.5, 0.0625]) +@pytest.mark.parametrize('spatial_ratio', [1, 2, 3, 5]) +def test_onnx_export_contrib_ROIAlign(tmp_path, dtype, pooled_size, spatial_scale, spatial_ratio): + data = mx.random.uniform(0, 1, (5, 3, 128, 128)).astype(dtype) + rois = mx.nd.array([[0, 0, 0, 63, 63], + [1, 34, 52, 25, 85], + [2, 50, 50, 100, 100], + [3, 0, 0, 127, 127], + [4, 12, 84, 22, 94], + [0, 0, 0, 1, 1]]).astype(dtype) + M = def_model('contrib.ROIAlign', pooled_size=pooled_size, spatial_scale=spatial_scale, + sample_ratio=spatial_ratio) + op_export_test('_contrib_ROIAlign', M, [data, rois], tmp_path) From 6653e30cc3830a2ed0ca1ca8d02715af2fce1f37 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Mon, 1 Feb 2021 22:40:25 +0000 Subject: [PATCH 2/3] fix --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 07cb8742fa15..30e90ba82a43 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3571,14 +3571,13 @@ def convert_contrib_roialign(node, **kwargs): position_sensitive = attrs.get('position_sensitive', 'False') aligned = attrs.get('aligned', 'False') - ''' if position_sensitive != 'False': raise NotImplementedError('_contrib_ROIAlign does not currently support \ position_sensitive!=False') if aligned != 'False': raise NotImplementedError('_contrib_ROIAlign does not currently support \ aligned!=False') - ''' + create_tensor([0], name+'_0', kwargs['initializer']), create_tensor([1], name+'_1', kwargs['initializer']), create_tensor([5], name+'_5', kwargs['initializer']), From 5a3c3e7ff37ae37b39562670ef264530f5dd0b6e Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Mon, 1 Feb 2021 16:03:39 -0800 Subject: [PATCH 3/3] Update _op_translations.py --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 30e90ba82a43..da5a649a7b2e 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3578,9 +3578,9 @@ def convert_contrib_roialign(node, **kwargs): raise NotImplementedError('_contrib_ROIAlign does not currently support \ aligned!=False') - create_tensor([0], name+'_0', kwargs['initializer']), - create_tensor([1], name+'_1', kwargs['initializer']), - create_tensor([5], name+'_5', kwargs['initializer']), + _ = create_tensor([0], name+'_0', kwargs['initializer']), + _ = create_tensor([1], name+'_1', kwargs['initializer']), + _ = create_tensor([5], name+'_5', kwargs['initializer']), nodes = [ make_node('Slice', [input_nodes[1], name+'_1', name+'_5', name+'_1'], [name+'_rois']),