diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index ecfb0320e593..1c50abe7dc90 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2946,3 +2946,80 @@ def convert_where(node, **kwargs): make_node("Where", [name+"_bool", input_nodes[1], input_nodes[2]], [name], name=name) ] return nodes + +@mx_op.register("_contrib_box_decode") +def convert_contrib_box_decode(node, **kwargs): + """Map MXNet's _contrib_box_decode operator attributes to onnx's operator. + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + data = input_nodes[0] + anchors = input_nodes[1] + input_type = kwargs['in_type'] + fmt = attrs.get('format', 'center') + std0 = float(attrs.get('std0', '1.')) + std1 = float(attrs.get('std1', '1.')) + std2 = float(attrs.get('std2', '1.')) + std3 = float(attrs.get('std3', '1.')) + clip = float(attrs.get('clip', '-1.')) + + if fmt not in ['center', 'corner']: + raise NotImplementedError("format must be either corner or center.") + + nodes = [ + create_tensor([0], name+'_0', kwargs["initializer"]), + create_tensor([2], name+'_2', kwargs["initializer"]), + create_tensor([4], name+'_4', kwargs["initializer"]), + create_tensor([2], name+'_2f', kwargs["initializer"], dtype='float32'), + create_tensor([clip], name+'_clip', kwargs["initializer"], dtype='float32'), + create_tensor([std0, std1, std2, std3], name+'_std_1d', kwargs["initializer"], dtype='float32'), + create_tensor([1, 4], name+'_std_shape', kwargs["initializer"]), + make_node("Cast", [data], [name+'_data'], to=int(onnx.TensorProto.FLOAT)), + make_node("Cast", [anchors], [name+'_anchors'], to=int(onnx.TensorProto.FLOAT)), + make_node('Reshape', [name+'_std_1d', name+'_std_shape'], [name+'_std']), + make_node("Mul", [name+'_data', name+'_std'], [name+'_mul0_out']), + make_node('Slice', [name+'_mul0_out', name+'_0', name+'_2', name+'_2'], [name+'_data_xy']), + make_node('Slice', [name+'_mul0_out', name+'_2', name+'_4', name+'_2'], [name+'_data_wh']), + ] + + if fmt == 'corner': + nodes += [ + make_node('Slice', [name+'_anchors', name+'_0', name+'_2', name+'_2'], [name+'_slice0_out']), + make_node('Slice', [name+'_anchors', name+'_2', name+'_4', name+'_2'], [name+'_slice1_out']), + make_node('Sub', [name+'_slice1_out', name+'_slice0_out'], [name+'_anchor_wh']), + make_node('Div', [name+'_anchor_wh', name+'_2f'], [name+'_div0_out']), + make_node("Add", [name+'_slice0_out', name+'_div0_out'], [name+'_anchor_xy']), + ] + else: + nodes += [ + make_node('Slice', [name+'_anchors', name+'_0', name+'_2', name+'_2'], [name+'_anchor_xy']), + make_node('Slice', [name+'_anchors', name+'_2', name+'_4', name+'_2'], [name+'_anchor_wh']), + ] + + nodes += [ + make_node("Mul", [name+'_data_xy', name+'_anchor_wh'], [name+'_mul1_out']), + make_node("Add", [name+'_mul1_out', name+'_anchor_xy'], [name+'_add0_out']), + ] + + if clip > 0.: + nodes += [ + make_node("Less", [name+"_data_wh", name+"_clip"], [name+"_less0_out"]), + make_node('Where', [name+'_less0_out', name+'_data_wh', name+'_clip'], [name+'_where0_out']), + make_node("Exp", [name+'_where0_out'], [name+'_exp0_out']), + ] + else: + nodes += [ + make_node("Exp", [name+'_data_wh'], [name+'_exp0_out']), + ] + + nodes += [ + make_node("Mul", [name+'_exp0_out', name+'_anchor_wh'], [name+'_mul2_out']), + make_node('Div', [name+'_mul2_out', name+'_2f'], [name+'_div1_out']), + make_node('Sub', [name+'_add0_out', name+'_div1_out'], [name+'_sub0_out']), + make_node('Add', [name+'_add0_out', name+'_div1_out'], [name+'_add1_out']), + make_node('Concat', [name+'_sub0_out', name+'_add1_out'], [name+'concat0_out'], axis=2), + make_node("Cast", [name+'concat0_out'], [name], to=input_type, name=name) + ] + + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index f9a6abcaf14e..4882321e39a6 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -408,3 +408,16 @@ def test_onnx_export_where(tmp_path, dtype, shape): y = mx.nd.ones(shape, dtype=dtype) cond = mx.nd.random.randint(low=0, high=1, shape=shape, dtype='int32') op_export_test('where', M, [cond, x, y], tmp_path) + +@pytest.mark.parametrize('dtype', ['float16', 'float32']) +@pytest.mark.parametrize('fmt', ['corner', 'center']) +@pytest.mark.parametrize('clip', [-1., 0., .5, 5.]) +def test_onnx_export_contrib_box_decode(tmp_path, dtype, fmt, clip): + # ensure data[0] < data[2] and data[1] < data[3] for corner format + mul = mx.nd.array([-1, -1, 1, 1], dtype=dtype) + data = mx.nd.random.uniform(0, 1, (2, 3, 4), dtype=dtype) * mul + anchors = mx.nd.random.uniform(0, 1, (1, 3, 4), dtype=dtype) * mul + M1 = def_model('contrib.box_decode', format=fmt, clip=clip) + op_export_test('contrib_box_decode', M1, [data, anchors], tmp_path) + M2 = def_model('contrib.box_decode', format=fmt, clip=clip, std0=0.3, std1=1.4, std2=0.5, std3=1.6) + op_export_test('contrib_box_decode', M1, [data, anchors], tmp_path)