diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 73ca07be76ee..e2aab6b1efa7 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1545,3 +1545,45 @@ def convert_sum(node, **kwargs): name=name ) return [node] + +@mx_op.register("hard_sigmoid") +def convert_hardsigmoid(node, **kwargs): + """Map MXNet's hard_sigmoid operator attributes to onnx's HardSigmoid operator + and return the created node. + """ + name, input_nodes, attrs = get_inputs(node, kwargs) + + # Converting to float32 + alpha = float(attrs.get("alpha", 0.2)) + beta = float(attrs.get("beta", 0.5)) + + node = onnx.helper.make_node( + 'HardSigmoid', + input_nodes, + [name], + alpha=alpha, + beta=beta, + name=name + ) + return [node] + +@mx_op.register("broadcast_lesser") +def convert_broadcast_lesser(node, **kwargs): + """Map MXNet's broadcast_lesser operator attributes to onnx's Less operator + and return the created node. + """ + return create_basic_op_node('Less', node, kwargs) + +@mx_op.register("broadcast_greater") +def convert_broadcast_greater(node, **kwargs): + """Map MXNet's broadcast_greater operator attributes to onnx's Greater operator + and return the created node. + """ + return create_basic_op_node('Greater', node, kwargs) + +@mx_op.register("broadcast_equal") +def convert_broadcast_equal(node, **kwargs): + """Map MXNet's broadcast_equal operator attributes to onnx's Equal operator + and return the created node. + """ + return create_basic_op_node('Equal', node, kwargs) diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py index bbff7833fe20..f4144fd6c7fa 100644 --- a/tests/python-pytest/onnx/export/mxnet_export_test.py +++ b/tests/python-pytest/onnx/export/mxnet_export_test.py @@ -241,6 +241,32 @@ def test_square(): npt.assert_almost_equal(result, numpy_op) +@with_seed() +def test_comparison_ops(): + """Test greater, lesser, equal""" + def test_ops(op_name, inputs, input_tensors, numpy_op): + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))] + nodes = [helper.make_node(op_name, ["input"+str(i+1) for i in range(len(inputs))], ["output"])] + graph = helper.make_graph(nodes, + op_name + "_test", + input_tensors, + outputs) + model = helper.make_model(graph) + bkd_rep = backend.prepare(model) + output = bkd_rep.run(inputs) + npt.assert_almost_equal(output[0], numpy_op) + input_data = [np.random.rand(1, 3, 4, 5).astype("float32"), + np.random.rand(1, 5).astype("float32")] + input_tensor = [] + for idx, ip in enumerate(input_data): + input_tensor.append(helper.make_tensor_value_info("input" + str(idx + 1), + TensorProto.FLOAT, shape=np.shape(ip))) + test_ops("Greater", input_data, input_tensor, + np.greater(input_data[0], input_data[1]).astype(np.float32)) + test_ops("Less", input_data, input_tensor, + np.less(input_data[0], input_data[1]).astype(np.float32)) + test_ops("Equal", input_data, input_tensor, + np.equal(input_data[0], input_data[1]).astype(np.float32)) def _assert_sym_equal(lhs, rhs): assert lhs.list_inputs() == rhs.list_inputs() # input names must be identical diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py index 678435d92357..ec9ddf23c252 100644 --- a/tests/python-pytest/onnx/export/onnx_backend_test.py +++ b/tests/python-pytest/onnx/export/onnx_backend_test.py @@ -94,7 +94,8 @@ 'test_operator_permute2', 'test_clip' 'test_cast', - 'test_depthtospace' + 'test_depthtospace', + 'test_hardsigmoid' ] BASIC_MODEL_TESTS = [ diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py index 7f34247c94e2..aed68ffa114c 100644 --- a/tests/python-pytest/onnx/import/test_cases.py +++ b/tests/python-pytest/onnx/import/test_cases.py @@ -71,7 +71,7 @@ 'test_sin', 'test_tan', 'test_shape', - 'test_hardsigmoid_', + 'test_hardsigmoid', 'test_averagepool_1d', 'test_averagepool_2d_pads_count_include_pad', 'test_averagepool_2d_precomputed_pads_count_include_pad',