From 3066267c39fc5ae57a1af7ca96288d8aedd076c1 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Tue, 29 Dec 2020 22:50:40 +0000 Subject: [PATCH 01/13] Add onnx export support for ones_like operator. --- .../contrib/onnx/mx2onnx/_op_translations.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index d30197560175..4927541a74fa 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2703,6 +2703,22 @@ def convert_zeros_like(node, **kwargs): return nodes +@mx_op.register("ones_like") +def convert_ones_like(node, **kwargs): + """Map MXNet's ones_like operator attributes to onnx's ConstantOfShape operator. + """ + from onnx.helper import make_node, make_tensor + name, input_nodes, _ = get_inputs(node, kwargs) + + # create tensor with shape of input + tensor_value = make_tensor(name+"_one", kwargs['in_type'], [1], [1]) + nodes = [ + make_node("Shape", [input_nodes[0]], [name+"_shape"]), + make_node("ConstantOfShape", [name+"_shape"], [name], name=name, value=tensor_value) + ] + return nodes + + @mx_op.register("_contrib_arange_like") def convert_arange_like(node, **kwargs): """Map MXNet's arange_like operator attributes to onnx's Range and Reshape operators. From c5dbc5dc572df58691756984e8814adbd382d84b Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 30 Dec 2020 00:30:32 +0000 Subject: [PATCH 02/13] Clean up dropout, clip and topk export functions. --- .../contrib/onnx/mx2onnx/_op_translations.py | 113 ++++-------------- 1 file changed, 26 insertions(+), 87 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 4927541a74fa..ddbe024abe1d 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1103,6 +1103,7 @@ def convert_dropout(node, **kwargs): """Map MXNet's Dropout operator attributes to onnx's Dropout operator and return the created node. """ + from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) opset_version = kwargs["opset_version"] @@ -1110,29 +1111,13 @@ def convert_dropout(node, **kwargs): if opset_version >= 12: # opset >= 12 requires the ratio to be an input - initializer = kwargs["initializer"] - ratio_input_name = name + "_ratio" - value_node = onnx.helper.make_tensor_value_info(ratio_input_name, - onnx.TensorProto.FLOAT, ()) - tensor_node = onnx.helper.make_tensor(ratio_input_name, onnx.TensorProto.FLOAT, - (), [probability]) - initializer.append(tensor_node) - dropout_node = onnx.helper.make_node( - "Dropout", - [input_nodes[0], ratio_input_name], - [name], - name=name - ) - return [value_node, dropout_node] + create_const_scalar_node(name+"_ratio0", np.float32(probability), kwargs) + nodes = [ + make_node("Dropout", [input_nodes[0], name+"_ratio0"], [name], name=name) + ] + return nodes else: - dropout_node = onnx.helper.make_node( - "Dropout", - input_nodes, - [name], - ratio=probability, - name=name - ) - return [dropout_node] + return [make_node("Dropout", input_nodes, [name], ratio=probability, name=name)] @mx_op.register("Flatten") @@ -1147,6 +1132,7 @@ def convert_clip(node, **kwargs): """Map MXNet's Clip operator attributes to onnx's Clip operator and return the created node. """ + from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) opset_version = kwargs["opset_version"] @@ -1155,39 +1141,16 @@ def convert_clip(node, **kwargs): if opset_version >= 11: # opset >= 11 requires min/max to be inputs - initializer = kwargs["initializer"] - min_input_name = name + "_min" - max_input_name = name + "_max" - min_value_node = onnx.helper.make_tensor_value_info(min_input_name, - onnx.TensorProto.FLOAT, ()) - max_value_node = onnx.helper.make_tensor_value_info(max_input_name, - onnx.TensorProto.FLOAT, ()) - min_tensor_node = onnx.helper.make_tensor(min_input_name, onnx.TensorProto.FLOAT, - (), [a_min]) - max_tensor_node = onnx.helper.make_tensor(max_input_name, onnx.TensorProto.FLOAT, - (), [a_max]) - initializer.append(min_tensor_node) - initializer.append(max_tensor_node) - input_nodes.append(min_input_name) - input_nodes.append(max_input_name) - clip_node = onnx.helper.make_node( - "Clip", - input_nodes, - [name], - name=name - ) - return [min_value_node, max_value_node, clip_node] - + create_const_scalar_node(name+"_min", np.float32(a_min), kwargs) + create_const_scalar_node(name+"_max", np.float32(a_max), kwargs) + nodes = [ + make_node("Clip", [input_nodes[0], name+"_min", name+"_max"], [name], name=name) + ] else: - clip_node = onnx.helper.make_node( - "Clip", - input_nodes, - [name], - name=name, - min=a_min, - max=a_max - ) - return [clip_node] + nodes = [ + make_node("Clip", input_nodes, [name], name=name, min=a_min, max=a_max) + ] + return nodes def scalar_op_helper(node, op_name, **kwargs): @@ -2267,52 +2230,28 @@ def convert_topk(node, **kwargs): """Map MXNet's topk operator attributes to onnx's TopK operator and return the created node. """ + from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) axis = int(attrs.get('axis', '-1')) k = int(attrs.get('k', '1')) ret_type = attrs.get('ret_typ') - dtype = attrs.get('dtype') - outputs = [name + '_output0'] + outputs = [name] if ret_type and ret_type == 'both': - if dtype and dtype == 'int64': - outputs.append(name + '_output1') - else: - raise NotImplementedError("ONNX expects indices to be of type int64") + outputs.append(name + '_output1') else: raise NotImplementedError("ONNX expects both value and indices as output") opset_version = kwargs['opset_version'] if opset_version >= 10: - from onnx.helper import make_tensor, make_tensor_value_info - initializer = kwargs["initializer"] - k_input_name = name + "_k" - k_input_type = onnx.TensorProto.INT64 - k_value_node = make_tensor_value_info(k_input_name, k_input_type, ()) - k_tensor_node = make_tensor(k_input_name, k_input_type, (), k) - initializer.append(k_tensor_node) - input_nodes.append(k_input_name) - - topk_node = onnx.helper.make_node( - "TopK", - input_nodes, - outputs, - axis=axis, - name=name - ) - return [k_value_node, topk_node] + create_const_scalar_node(name+"_k", np.int64(k), kwargs) + nodes = [ + make_node("TopK", [input_nodes[0], name+"_k"], outputs, axis=axis, name=name) + ] + return nodes else: - topk_node = onnx.helper.make_node( - "TopK", - input_nodes, - outputs, - axis=axis, - k=k, - name=name - ) - - return [topk_node] + return [make_node("TopK", input_nodes, outputs, axis=axis, k=k, name=name)] @mx_op.register("take") From 88d31834583f2ecc519abafe88695ee35f6e8c9e Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 30 Dec 2020 00:58:59 +0000 Subject: [PATCH 03/13] Clean up pad export function. --- .../contrib/onnx/mx2onnx/_op_translations.py | 43 +++++-------------- 1 file changed, 10 insertions(+), 33 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index ddbe024abe1d..0f8017862c25 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -515,40 +515,18 @@ def convert_pad(node, **kwargs): if opset_version >= 11: # starting with opset 11, pads and constant_value are inputs instead of attributes - from onnx.helper import make_tensor, make_tensor_value_info - initializer = kwargs["initializer"] - pads_input_name = name + "_pads" - pads_input_type = onnx.TensorProto.INT64 - pads_input_shape = np.shape(np.array(onnx_pad_width)) - pads_value_node = make_tensor_value_info(pads_input_name, pads_input_type, pads_input_shape) - pads_tensor_node = make_tensor(pads_input_name, pads_input_type, pads_input_shape, onnx_pad_width) - initializer.append(pads_tensor_node) - input_nodes.append(pads_input_name) + create_const_scalar_node(name+"_pads", np.int64(onnx_pad_width), kwargs) if pad_mode == "constant": - const_input_name = name + "_constant" - const_input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[pad_value.dtype] - const_value_node = make_tensor_value_info(const_input_name, const_input_type, ()) - const_tensor_node = make_tensor(const_input_name, const_input_type, (), [pad_value]) - initializer.append(const_tensor_node) - input_nodes.append(const_input_name) - pad_node = onnx.helper.make_node( - "Pad", - input_nodes, - [name], - mode=pad_mode, - name=name - ) - return [pads_value_node, const_value_node, pad_node] + create_const_scalar_node(name+"_const", pad_value, kwargs) + nodes = [ + make_node("Pad", [input_nodes[0], name+"_pads", name+"_const"], [name], mode=pad_mode, name=name) + ] else: - pad_node = onnx.helper.make_node( - "Pad", - input_nodes, - [name], - mode=pad_mode, - name=name - ) - return [pads_value_node, pad_node] + nodes = [ + make_node("Pad", [input_nodes[0], name+"_pads"], [name], mode=pad_mode, name=name) + ] + return nodes else: if pad_mode == "constant": node = onnx.helper.make_node( @@ -560,7 +538,6 @@ def convert_pad(node, **kwargs): pads=onnx_pad_width, name=name ) - return [node] else: node = onnx.helper.make_node( 'Pad', @@ -570,7 +547,7 @@ def convert_pad(node, **kwargs): pads=onnx_pad_width, name=name ) - return [node] + return [node] def create_helper_trans_node(node_name, input_node): From 097b21c1f8913b7d3271e3533d74b5bd64839d99 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Tue, 5 Jan 2021 20:36:50 +0000 Subject: [PATCH 04/13] Add unit test for ones_like onnx export. --- tests/python-pytest/onnx/test_operators.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 057a279880a4..91b967f03db7 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -94,6 +94,10 @@ def test_onnx_export_zeros_like(tmp_path): x = mx.nd.array([[-2,-1,0],[0,50,99],[4,5,6],[7,8,9]], dtype='float32') op_export_test('zeros_like', M, [x], tmp_path) +def test_onnx_export_ones_like(tmp_path): + M = def_model('ones_like') + x = mx.nd.array([[-2,-1,0],[0,50,99],[4,5,6],[7,8,9]], dtype='float32') + op_export_test('ones_like', M, [x], tmp_path) @pytest.mark.parametrize("dtype", ["float32", "float64"]) @pytest.mark.parametrize("axis", [None,0,1]) From cc20adb1ea8981397a50388f90d52660c6159e7d Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 6 Jan 2021 00:05:56 +0000 Subject: [PATCH 05/13] Add onnx export function for arange operator. --- .../contrib/onnx/mx2onnx/_op_translations.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 0f8017862c25..bf644326848f 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2691,3 +2691,32 @@ def convert_arange_like(node, **kwargs): ] return nodes + +@mx_op.register("_arange") +def convert_arange(node, **kwargs): + """Map MXNet's arange operator attributes to onnx's Range operator. + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + opset_version = kwargs['opset_version'] + if opset_version < 11: + raise AttributeError("ONNX opset 11 or greater is required to export this operator") + + input_type = kwargs['in_type'] + dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type] + start = attrs.get('start', 0.) + stop = attrs.get('stop') + step = attrs.get('step', 1.) + repeat = int(attrs.get('repeat', 1)) + if repeat != 1: + raise NotImplementedError("arange operator with repeat != 1 not yet implemented.") + + nodes = [ + create_const_scalar_node(name+"_start", np.array([start], dtype=dtype), kwargs), + create_const_scalar_node(name+"_stop", np.array([stop], dtype=dtype), kwargs), + create_const_scalar_node(name+"_step", np.array([step], dtype=dtype), kwargs), + make_node("Range", [name+"_start", name+"_stop", name+"_step"], [name]) + ] + + return nodes From 9aaabb4ec8d62e6365f4816eb3498073ae97e663 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 6 Jan 2021 00:32:56 +0000 Subject: [PATCH 06/13] Fix lint. --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index bf644326848f..b53aa4a31785 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -504,6 +504,7 @@ def convert_pad(node, **kwargs): """Map MXNet's pad operator attributes to onnx's Pad operator and return the created node. """ + from onnx.helper import make_node opset_version = kwargs["opset_version"] name, input_nodes, attrs = get_inputs(node, kwargs) @@ -2697,7 +2698,7 @@ def convert_arange(node, **kwargs): """Map MXNet's arange operator attributes to onnx's Range operator. """ from onnx.helper import make_node - name, input_nodes, attrs = get_inputs(node, kwargs) + name, _, attrs = get_inputs(node, kwargs) opset_version = kwargs['opset_version'] if opset_version < 11: From 084faa535da8fcd7556361bc4e5365a8263f2077 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 6 Jan 2021 03:47:04 +0000 Subject: [PATCH 07/13] Make sure to return all nodes created. --- .../contrib/onnx/mx2onnx/_op_translations.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index b53aa4a31785..d0dc54fa6b3f 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -516,15 +516,17 @@ def convert_pad(node, **kwargs): if opset_version >= 11: # starting with opset 11, pads and constant_value are inputs instead of attributes - create_const_scalar_node(name+"_pads", np.int64(onnx_pad_width), kwargs) + nodes = [ + create_const_node(name+"_pads", np.array(onnx_pad_width, dtype='int64'), kwargs) + ] if pad_mode == "constant": - create_const_scalar_node(name+"_const", pad_value, kwargs) - nodes = [ + nodes += [ + create_const_scalar_node(name+"_const", pad_value, kwargs), make_node("Pad", [input_nodes[0], name+"_pads", name+"_const"], [name], mode=pad_mode, name=name) ] else: - nodes = [ + nodes += [ make_node("Pad", [input_nodes[0], name+"_pads"], [name], mode=pad_mode, name=name) ] return nodes @@ -1089,8 +1091,8 @@ def convert_dropout(node, **kwargs): if opset_version >= 12: # opset >= 12 requires the ratio to be an input - create_const_scalar_node(name+"_ratio0", np.float32(probability), kwargs) nodes = [ + create_const_scalar_node(name+"_ratio0", np.float32(probability), kwargs), make_node("Dropout", [input_nodes[0], name+"_ratio0"], [name], name=name) ] return nodes @@ -2713,10 +2715,10 @@ def convert_arange(node, **kwargs): if repeat != 1: raise NotImplementedError("arange operator with repeat != 1 not yet implemented.") + create_const_scalar_node(name+"_start", np.array([start], dtype=dtype), kwargs) + create_const_scalar_node(name+"_stop", np.array([stop], dtype=dtype), kwargs) + create_const_scalar_node(name+"_step", np.array([step], dtype=dtype), kwargs) nodes = [ - create_const_scalar_node(name+"_start", np.array([start], dtype=dtype), kwargs), - create_const_scalar_node(name+"_stop", np.array([stop], dtype=dtype), kwargs), - create_const_scalar_node(name+"_step", np.array([step], dtype=dtype), kwargs), make_node("Range", [name+"_start", name+"_stop", name+"_step"], [name]) ] From e1f150da7a07f3e66c8201720d50e4f9e63b1d5e Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 6 Jan 2021 03:48:41 +0000 Subject: [PATCH 08/13] Extend operator test to work with no inputs, add unit test for arange. --- tests/python-pytest/onnx/test_operators.py | 23 ++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 91b967f03db7..d2f276f7e781 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -23,7 +23,7 @@ import pytest import tempfile -def def_model(op_name, **params): +def def_model(op_name, dummy_input=False, **params): class Model(HybridBlock): def __init__(self, **kwargs): super(Model, self).__init__(**kwargs) @@ -33,11 +33,13 @@ def hybrid_forward(self, F, *inputs): func = F for name in names: func = getattr(func, name) - out = func(*inputs, **params) - return out + if dummy_input: + return func(**params), inputs[0] + else: + return func(*inputs, **params) return Model -def op_export_test(model_name, Model, inputs, tmp_path): +def op_export_test(model_name, Model, inputs, tmp_path, dummy_input=False): def export_to_onnx(model, model_name, inputs): model_path = '{}/{}'.format(tmp_path, model_name) model.export(model_path, epoch=0) @@ -63,6 +65,8 @@ def onnx_rt(onnx_file, inputs): pred_nat = model(*inputs) onnx_file = export_to_onnx(model, model_name, inputs) pred_onx = onnx_rt(onnx_file, inputs) + if dummy_input: + pred_nat = pred_nat[0] assert_almost_equal(pred_nat, pred_onx) @@ -109,6 +113,17 @@ def test_onnx_export_arange_like(tmp_path, dtype, axis, start, step, test_data): x = mx.nd.array(test_data, dtype=dtype) op_export_test('arange_like', M, [x], tmp_path) + +@pytest.mark.parametrize("stop", [2, 50, 5000]) +@pytest.mark.parametrize("step", [0.25, 0.5, 1, 5]) +@pytest.mark.parametrize("start", [0., 1.]) +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +def test_onnx_export_arange(tmp_path, dtype, start, stop, step): + M = def_model('arange', dummy_input=True, start=start, stop=stop, step=step, dtype=dtype) + x = mx.nd.array([1], dtype=dtype) + op_export_test('arange', M, [x], tmp_path, dummy_input=True) + + @pytest.mark.parametrize('dtype', ['float32']) def test_onnx_export_layernorm(tmp_path, dtype): x = mx.nd.random.uniform(1, 2, (3, 4, 5), dtype=dtype) From 363273ec8875fc3c2cddf049b324590053dea507 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 6 Jan 2021 03:53:09 +0000 Subject: [PATCH 09/13] Extent arange test to also test int32 and int64 dtypes. --- tests/python-pytest/onnx/test_operators.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index d2f276f7e781..096a449bd5e9 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -117,8 +117,14 @@ def test_onnx_export_arange_like(tmp_path, dtype, axis, start, step, test_data): @pytest.mark.parametrize("stop", [2, 50, 5000]) @pytest.mark.parametrize("step", [0.25, 0.5, 1, 5]) @pytest.mark.parametrize("start", [0., 1.]) -@pytest.mark.parametrize("dtype", ["float32", "float64"]) +@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64"]) def test_onnx_export_arange(tmp_path, dtype, start, stop, step): + if "int" in dtype: + start = int(start) + stop = int(stop) + step = int(step) + if step == 0: + step = 1 M = def_model('arange', dummy_input=True, start=start, stop=stop, step=step, dtype=dtype) x = mx.nd.array([1], dtype=dtype) op_export_test('arange', M, [x], tmp_path, dummy_input=True) From 5e1d583b2e581a155be731925fd3bcb423d6f325 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 6 Jan 2021 04:00:31 +0000 Subject: [PATCH 10/13] Return scalar nodes in clip conversion function. --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index d0dc54fa6b3f..6ff2800321e6 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1121,9 +1121,9 @@ def convert_clip(node, **kwargs): if opset_version >= 11: # opset >= 11 requires min/max to be inputs - create_const_scalar_node(name+"_min", np.float32(a_min), kwargs) - create_const_scalar_node(name+"_max", np.float32(a_max), kwargs) nodes = [ + create_const_scalar_node(name+"_min", np.float32(a_min), kwargs), + create_const_scalar_node(name+"_max", np.float32(a_max), kwargs), make_node("Clip", [input_nodes[0], name+"_min", name+"_max"], [name], name=name) ] else: From 4890786b881814d506bacf9ea4da6a04d4672ab0 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 6 Jan 2021 04:12:31 +0000 Subject: [PATCH 11/13] Make sure to return all graph nodes created in export ops. --- .../contrib/onnx/mx2onnx/_op_translations.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 6ff2800321e6..176080c3b2e9 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1648,25 +1648,28 @@ def convert_slice_axis(node, **kwargs): begin = int(attrs.get("begin")) end = attrs.get("end", None) - nodes = [] - create_tensor([axis], name+'_axis', kwargs["initializer"]) - create_tensor([begin], name+'_begin', kwargs["initializer"]) + nodes = [ + create_tensor([axis], name+'_axis', kwargs["initializer"]), + create_tensor([begin], name+'_begin', kwargs["initializer"]) + ] if not end or end == 'None': # ONNX doesn't support None for ends. Since ends=None depicts # length of dimension, passing dimension in this case. - create_tensor([axis+1], name+"_axis_plus_1", kwargs["initializer"]) nodes += [ + create_tensor([axis+1], name+"_axis_plus_1", kwargs["initializer"]), make_node('Shape', [input_nodes[0]], [name+"_data_shape"]), make_node('Slice', [name+'_data_shape', name+'_axis', name+'_axis_plus_1'], [name+"_end"]) ] else: - create_tensor([int(end)], name+'_end', kwargs["initializer"]) + nodes += [ + create_tensor([int(end)], name+'_end', kwargs["initializer"]) + ] nodes += [ make_node('Slice', [input_nodes[0], name+'_begin', name+'_end', name+'_axis'], [name], name=name) - ] + ] return nodes @@ -2225,8 +2228,8 @@ def convert_topk(node, **kwargs): opset_version = kwargs['opset_version'] if opset_version >= 10: - create_const_scalar_node(name+"_k", np.int64(k), kwargs) nodes = [ + create_const_scalar_node(name+"_k", np.int64(k), kwargs), make_node("TopK", [input_nodes[0], name+"_k"], outputs, axis=axis, name=name) ] return nodes @@ -2444,7 +2447,7 @@ def convert_broadcast_axis(node, **kwargs): make_node('Shape', [shape_name], [name+'_in_dim']), make_node('Reshape', [name+'_in_dim', name+'_void'], [name+'_in_dim_s']), make_node('Range', [name+'_0_s', name+'_in_dim_s', name+'_1_s'], [name+'_range']), - ] + ] for i, axis in enumerate(axis): if axis not in (0, 1): @@ -2456,7 +2459,7 @@ def convert_broadcast_axis(node, **kwargs): make_node('Mul', [name+'_size_'+str(i), name+'_cast_'+str(i)], [name+'_mul_'+str(i)]), make_node('Add', [name+'_mul_'+str(i), name+'_1'], [name+'_add_'+str(i)]), make_node('Mul', [name+'_add_'+str(i), shape_name], [name+'_shape_'+str(i+1)]) - ] + ] shape_name = name+'_shape_'+str(i+1) nodes += [make_node('Expand', [input_nodes[0], shape_name], [name], name=name)] @@ -2498,7 +2501,7 @@ def convert_sequencemask(node, **kwargs): make_node('Range', [name+'_0_s', name+'_in_dim_s', name+'_1_s'], [name+'_range_0']), make_node('Less', [name+'_range_0', name+'_2'], [name+'_less_0']), make_node('Where', [name+'_less_0', name+'_in_shape', name+'_1'], [name+'_shape_1']) - ] + ] if(axis == 0): nodes += [ @@ -2715,10 +2718,10 @@ def convert_arange(node, **kwargs): if repeat != 1: raise NotImplementedError("arange operator with repeat != 1 not yet implemented.") - create_const_scalar_node(name+"_start", np.array([start], dtype=dtype), kwargs) - create_const_scalar_node(name+"_stop", np.array([stop], dtype=dtype), kwargs) - create_const_scalar_node(name+"_step", np.array([step], dtype=dtype), kwargs) nodes = [ + create_const_scalar_node(name+"_start", np.array([start], dtype=dtype), kwargs), + create_const_scalar_node(name+"_stop", np.array([stop], dtype=dtype), kwargs), + create_const_scalar_node(name+"_step", np.array([step], dtype=dtype), kwargs), make_node("Range", [name+"_start", name+"_stop", name+"_step"], [name]) ] From 51f18397092feafaaa2c93da358923bca9a89a1e Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 6 Jan 2021 04:41:39 +0000 Subject: [PATCH 12/13] Properly obey dtype attribute instead of using input type for arange. --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 176080c3b2e9..07537a3d99f5 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2709,11 +2709,10 @@ def convert_arange(node, **kwargs): if opset_version < 11: raise AttributeError("ONNX opset 11 or greater is required to export this operator") - input_type = kwargs['in_type'] - dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type] start = attrs.get('start', 0.) stop = attrs.get('stop') step = attrs.get('step', 1.) + dtype = attrs.get('dtype', 'float32') repeat = int(attrs.get('repeat', 1)) if repeat != 1: raise NotImplementedError("arange operator with repeat != 1 not yet implemented.") From a19f4c326e32a90cb76c1c497618ca315ec7a9f3 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 6 Jan 2021 04:42:58 +0000 Subject: [PATCH 13/13] Use static dtype for parameter to catch errors when dtype != input type. --- tests/python-pytest/onnx/test_operators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 096a449bd5e9..34838a081996 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -126,7 +126,7 @@ def test_onnx_export_arange(tmp_path, dtype, start, stop, step): if step == 0: step = 1 M = def_model('arange', dummy_input=True, start=start, stop=stop, step=step, dtype=dtype) - x = mx.nd.array([1], dtype=dtype) + x = mx.nd.array([1], dtype='float32') op_export_test('arange', M, [x], tmp_path, dummy_input=True)