From f0b824e2f65ee9eb338ac955959fb89def12e46f Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 26 Jan 2021 20:13:35 -0800 Subject: [PATCH 01/17] add test --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 77179bd46966..a1dcd7d454c9 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -4035,6 +4035,7 @@ def convert_argsort(node, **kwargs): to=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]) ] + return nodes From d10ebd574599701b1d53c34a99fe0ff932279220 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 26 Jan 2021 20:29:13 -0800 Subject: [PATCH 02/17] support multiple input nodes --- python/mxnet/contrib/onnx/mx2onnx/export_model.py | 8 ++++---- python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py index 2fc77604b9b6..0697832d6803 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py @@ -28,7 +28,7 @@ from ._export_helper import load_module -def export_model(sym, params, input_shape, input_type=np.float32, +def export_model(sym, params, input_shape, input_type=[np.float32], onnx_file_path='model.onnx', verbose=False, opset_version=None): """Exports the MXNet model file, passed as a parameter, into ONNX model. Accepts both symbol,parameter objects as well as json and params filepaths as input. @@ -73,17 +73,17 @@ def export_model(sym, params, input_shape, input_type=np.float32, # default is to use latest opset version the onnx package supports opset_version = onnx_opset_version() - data_format = np.dtype(input_type) + data_types = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(inp_type)] for inp_type in input_type] # if input parameters are strings(file paths), load files and create symbol parameter objects if isinstance(sym, string_types) and isinstance(params, string_types): logging.info("Converting json and weight file to sym and params") sym_obj, params_obj = load_module(sym, params) onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape, - mapping.NP_TYPE_TO_TENSOR_TYPE[data_format], + data_types, verbose=verbose, opset_version=opset_version) elif isinstance(sym, symbol.Symbol) and isinstance(params, dict): onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape, - mapping.NP_TYPE_TO_TENSOR_TYPE[data_format], + data_types, verbose=verbose, opset_version=opset_version) else: raise ValueError("Input sym and params should either be files or objects") diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index 89f061d4c161..4fda0bdb1314 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -156,7 +156,7 @@ def get_outputs(sym, params, in_shape, in_label, in_type): assert len(out_shapes) == len(out_names) # infer output types - args = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[in_type] for n in sym.list_inputs()} + args = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[in_type[i]] for i, n in enumerate(sym.list_inputs())} _, out_type, _ = sym.infer_type(**args) out_types = [mapping.NP_TYPE_TO_TENSOR_TYPE[o(0).dtype] for o in out_type] @@ -256,7 +256,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, mx_graph=mx_graph, weights=weights, in_shape=in_shape[graph_input_idx], - in_type=in_type, + in_type=in_type[graph_input_idx], proc_nodes=all_processed_nodes, initializer=initializer, outputs_lookup=outputs_lookup) @@ -270,7 +270,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, mx_graph=mx_graph, weights=weights, in_shape=in_shape, - in_type=in_type, + in_type=in_type[0], proc_nodes=all_processed_nodes, initializer=initializer, outputs_lookup=outputs_lookup, From 8cf81ecf58658a0ee14d88d412543353e6a96d30 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 26 Jan 2021 21:57:18 -0800 Subject: [PATCH 03/17] fix sanity --- python/mxnet/contrib/onnx/mx2onnx/export_model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py index 0697832d6803..086918fd56c9 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py @@ -28,7 +28,7 @@ from ._export_helper import load_module -def export_model(sym, params, input_shape, input_type=[np.float32], +def export_model(sym, params, input_shape, input_type, onnx_file_path='model.onnx', verbose=False, opset_version=None): """Exports the MXNet model file, passed as a parameter, into ONNX model. Accepts both symbol,parameter objects as well as json and params filepaths as input. @@ -43,8 +43,8 @@ def export_model(sym, params, input_shape, input_type=[np.float32], Path to the params file or params dictionary. (Including both arg_params and aux_params) input_shape : List of tuple Input shape of the model e.g [(1,3,224,224)] - input_type : data type - Input data type e.g. np.float32 + input_type : List of dtype + Input data type e.g. [np.float32] onnx_file_path : str Path where to save the generated onnx file verbose : Boolean @@ -73,17 +73,17 @@ def export_model(sym, params, input_shape, input_type=[np.float32], # default is to use latest opset version the onnx package supports opset_version = onnx_opset_version() - data_types = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(inp_type)] for inp_type in input_type] + input_dtype = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(inp_type)] for inp_type in input_type] # if input parameters are strings(file paths), load files and create symbol parameter objects if isinstance(sym, string_types) and isinstance(params, string_types): logging.info("Converting json and weight file to sym and params") sym_obj, params_obj = load_module(sym, params) onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape, - data_types, + input_dtype, verbose=verbose, opset_version=opset_version) elif isinstance(sym, symbol.Symbol) and isinstance(params, dict): onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape, - data_types, + input_dtype, verbose=verbose, opset_version=opset_version) else: raise ValueError("Input sym and params should either be files or objects") From 7c4d1ea00348a13df466fbcb05b94a9344a12fc0 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 27 Jan 2021 00:25:24 -0800 Subject: [PATCH 04/17] update input dtype --- tests/python-pytest/onnx/mxnet_export_test.py | 1 + tests/python-pytest/onnx/test_models.py | 2 +- tests/python-pytest/onnx/test_node.py | 2 +- tests/python-pytest/onnx/test_operators.py | 3 +-- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python-pytest/onnx/mxnet_export_test.py b/tests/python-pytest/onnx/mxnet_export_test.py index 947fa2f6bf97..07d2fc4e954a 100644 --- a/tests/python-pytest/onnx/mxnet_export_test.py +++ b/tests/python-pytest/onnx/mxnet_export_test.py @@ -62,6 +62,7 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params= sym=net_sym, params=net_params, input_shape=[shape_type(data.shape)], + input_type=[data.dtype] onnx_file_path=onnx_file_path) assert export_path == onnx_file_path # Try importing the model to symbol diff --git a/tests/python-pytest/onnx/test_models.py b/tests/python-pytest/onnx/test_models.py index f85786141d6e..18730eb48ea8 100644 --- a/tests/python-pytest/onnx/test_models.py +++ b/tests/python-pytest/onnx/test_models.py @@ -136,7 +136,7 @@ def get_model_results(modelpath): onnx_file = os.path.join(dir_path, new_model_name) logging.info("Translating converted model from mxnet to ONNX") - converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file) + converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], [np.float32], onnx_file) sym, arg_params, aux_params, actual_result, metadata = get_model_results(converted_model_path) diff --git a/tests/python-pytest/onnx/test_node.py b/tests/python-pytest/onnx/test_node.py index 0b7fd9c7a970..28a5fe22bdc0 100644 --- a/tests/python-pytest/onnx/test_node.py +++ b/tests/python-pytest/onnx/test_node.py @@ -154,7 +154,7 @@ def test_import_export(self): if mxnet_specific: onnxmodelfile = onnx_mxnet.export_model(test_op, {}, [np.shape(ip) for ip in inputs], - np.float32, + [ip.dtype for ip in inputs], onnx_name + ".onnx") onnxmodel = load_model(onnxmodelfile) else: diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 8ebfbab2bb47..3ddcb6240b1e 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -45,10 +45,9 @@ def export_to_onnx(model, model_name, inputs): model.export(model_path, epoch=0) sym_file = '{}-symbol.json'.format(model_path) params_file = '{}-0000.params'.format(model_path) - dtype = inputs[0].dtype onnx_file = '{}/{}.onnx'.format(tmp_path, model_name) mx.contrib.onnx.export_model(sym_file, params_file, [inp.shape for inp in inputs], - dtype, onnx_file) + [inp.dtype for inp in inputs], onnx_file) return onnx_file def onnx_rt(onnx_file, inputs): From acd3beca19894c82880c4630cd2bbc45c64aeca6 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 27 Jan 2021 10:22:01 -0800 Subject: [PATCH 05/17] fix typo --- tests/python-pytest/onnx/mxnet_export_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python-pytest/onnx/mxnet_export_test.py b/tests/python-pytest/onnx/mxnet_export_test.py index 07d2fc4e954a..82e628b156c7 100644 --- a/tests/python-pytest/onnx/mxnet_export_test.py +++ b/tests/python-pytest/onnx/mxnet_export_test.py @@ -62,7 +62,7 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params= sym=net_sym, params=net_params, input_shape=[shape_type(data.shape)], - input_type=[data.dtype] + input_type=[data.dtype], onnx_file_path=onnx_file_path) assert export_path == onnx_file_path # Try importing the model to symbol From 9118314a0e540d92654e1dd31d5bb359a3d196ef Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 27 Jan 2021 13:39:45 -0800 Subject: [PATCH 06/17] update export_onnx --- python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index 4fda0bdb1314..4306cca960d2 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -155,9 +155,14 @@ def get_outputs(sym, params, in_shape, in_label, in_type): assert len(out_shapes) == len(out_names) - # infer output types - args = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[in_type[i]] for i, n in enumerate(sym.list_inputs())} - _, out_type, _ = sym.infer_type(**args) + ## Infer output types + # Remove any input listed in params from sym.list_inputs() and bind them to the input types provided + # by user. Also remove in_label + in_dtype = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t] for n, t in zip([n for n in sym.list_inputs() + if n not in params and n != in_label], in_type)} + # Add params and their shape to list of inputs + in_dtype.update({n: v.dtype for n, v in params.items() if n in sym.list_inputs()}) + _, out_type, _ = sym.infer_type(**in_dtype) out_types = [mapping.NP_TYPE_TO_TENSOR_TYPE[o(0).dtype] for o in out_type] assert len(out_types) == len(out_names) From 066f0610c524cb2c367188838885787d2c5f0c3f Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 27 Jan 2021 15:54:50 -0800 Subject: [PATCH 07/17] fix sanity --- python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index 4306cca960d2..9c1ae058a4b4 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -158,8 +158,8 @@ def get_outputs(sym, params, in_shape, in_label, in_type): ## Infer output types # Remove any input listed in params from sym.list_inputs() and bind them to the input types provided # by user. Also remove in_label - in_dtype = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t] for n, t in zip([n for n in sym.list_inputs() - if n not in params and n != in_label], in_type)} + in_dtype = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t] + for n, t in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_type)} # Add params and their shape to list of inputs in_dtype.update({n: v.dtype for n, v in params.items() if n in sym.list_inputs()}) _, out_type, _ = sym.infer_type(**in_dtype) From 2653a129955d2c2b4ebc194f3510df6ed5ffef6a Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 27 Jan 2021 21:57:42 -0800 Subject: [PATCH 08/17] fix space --- python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index 9c1ae058a4b4..a71c0c7d3633 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -158,7 +158,7 @@ def get_outputs(sym, params, in_shape, in_label, in_type): ## Infer output types # Remove any input listed in params from sym.list_inputs() and bind them to the input types provided # by user. Also remove in_label - in_dtype = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t] + in_dtype = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t] for n, t in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_type)} # Add params and their shape to list of inputs in_dtype.update({n: v.dtype for n, v in params.items() if n in sym.list_inputs()}) From de04aaabb868298d418af30a4e6d85057e69d2a3 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 2 Feb 2021 19:34:00 -0800 Subject: [PATCH 09/17] update import --- python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 7 ++++++- python/mxnet/contrib/onnx/onnx2mx/import_onnx.py | 3 ++- tests/python-pytest/onnx/backend.py | 7 ++++--- tests/python-pytest/onnx/test_node.py | 8 ++++---- tests/python-pytest/onnx/test_onnxruntime.py | 2 +- tests/python-pytest/onnx/test_operators.py | 2 +- 6 files changed, 18 insertions(+), 11 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index a71c0c7d3633..4560533fcdfc 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -268,6 +268,11 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, graph_input_idx += 1 else: + # Handle no input case + intype = 1 # Float32 in tensor type + if len(in_type) > 0: + intype = in_type[0] + # Handling graph layers converted = MXNetGraph.convert_layer( node, @@ -275,7 +280,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, mx_graph=mx_graph, weights=weights, in_shape=in_shape, - in_type=in_type[0], + in_type=intype, proc_nodes=all_processed_nodes, initializer=initializer, outputs_lookup=outputs_lookup, diff --git a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py index c2be83d8f12e..d51c51cc70f3 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py +++ b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py @@ -147,7 +147,8 @@ def get_graph_metadata(self, graph): for graph_input in graph.input: if graph_input.name not in _params: shape = [val.dim_value for val in graph_input.type.tensor_type.shape.dim] - input_data.append((graph_input.name, tuple(shape))) + dtype = graph_input.type.tensor_type.elem_type + input_data.append((graph_input.name, tuple(shape), dtype)) output_data = [] for graph_out in graph.output: diff --git a/tests/python-pytest/onnx/backend.py b/tests/python-pytest/onnx/backend.py index eb803f790332..6d8b1af6baff 100644 --- a/tests/python-pytest/onnx/backend.py +++ b/tests/python-pytest/onnx/backend.py @@ -50,7 +50,7 @@ def set_params(cls, backend, operation): cls.operation = operation @staticmethod - def perform_import_export(sym, arg_params, aux_params, input_shape): + def perform_import_export(sym, arg_params, aux_params, input_shape, input_dtype): """ Import ONNX model to mxnet model and then export to ONNX model and then import it back to mxnet for verifying the result""" graph = GraphProto() @@ -63,7 +63,7 @@ def perform_import_export(sym, arg_params, aux_params, input_shape): # exporting to onnx graph proto format converter = MXNetGraph() graph_proto = converter.create_onnx_graph_proto(sym, params, in_shape=input_shape, - in_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')], + in_type=input_dtype, opset_version=opset_version) # importing back to MXNET for verifying result. @@ -108,8 +108,9 @@ def prepare(cls, model, device='CPU', **kwargs): metadata = graph.get_graph_metadata(model.graph) input_data = metadata['input_tensor_data'] input_shape = [data[1] for data in input_data] + input_dtype = [data[2] for data in input_data] sym, arg_params, aux_params = MXNetBackend.perform_import_export(sym, arg_params, aux_params, - input_shape) + input_shape, input_dtype) return MXNetBackendRep(sym, arg_params, aux_params, device) elif backend == 'gluon': diff --git a/tests/python-pytest/onnx/test_node.py b/tests/python-pytest/onnx/test_node.py index 28a5fe22bdc0..325ae35e2a04 100644 --- a/tests/python-pytest/onnx/test_node.py +++ b/tests/python-pytest/onnx/test_node.py @@ -186,13 +186,13 @@ def test_import_export(self): if test == "Pow": outsym = ipsym ** 2 forward_op = forward_pass(outsym, None, None, ['input1'], input1) - converted_model = onnx_mxnet.export_model(outsym, {}, [np.shape(input1)], np.float32, + converted_model = onnx_mxnet.export_model(outsym, {}, [np.shape(input1)], [np.float32], onnx_file_path=outsym.name + ".onnx") sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model) - result = forward_pass(sym, arg_params, aux_params, ['input1'], input1) + result = forward_pass(sym, arg_params, aux_params, ['input1'], input1) - npt.assert_almost_equal(result, forward_op) + npt.assert_almost_equal(result, forward_op) def test_imports(self): for test in import_test_cases: @@ -212,7 +212,7 @@ def test_exports(self): test_name, onnx_name, mx_op, attrs = test input_sym = mx.sym.var('data') outsym = mx_op(input_sym, **attrs) - converted_model = onnx_mxnet.export_model(outsym, {}, [input_shape], np.float32, + converted_model = onnx_mxnet.export_model(outsym, {}, [input_shape], [np.float32], onnx_file_path=outsym.name + ".onnx") model = load_model(converted_model) checker.check_model(model) diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index 86e19fa121af..ad9115eb36b1 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -52,7 +52,7 @@ def export(self): def export_onnx(self): onnx_file = self.modelpath + ".onnx" mx.contrib.onnx.export_model(self.modelpath + "-symbol.json", self.modelpath + "-0000.params", - [self.input_shape], self.input_dtype, onnx_file) + [self.input_shape], [self.input_dtype], onnx_file) return onnx_file def predict(self, data): diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 3ddcb6240b1e..fce5781ad92a 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -53,7 +53,7 @@ def export_to_onnx(model, model_name, inputs): def onnx_rt(onnx_file, inputs): sess = rt.InferenceSession(onnx_file) dtype_0 = inputs[0].asnumpy().dtype - input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy().astype(dtype_0)) for i in range(len(inputs))) + input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy()) for i in range(len(inputs))) pred = sess.run(None, input_dict) return pred From 0444194a546bf78c15b0306b24bb03275507f3f0 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 2 Feb 2021 20:03:10 -0800 Subject: [PATCH 10/17] fix sanity --- python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index 4560533fcdfc..e6ed4df435af 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -270,7 +270,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, else: # Handle no input case intype = 1 # Float32 in tensor type - if len(in_type) > 0: + if len(in_type) > 0: intype = in_type[0] # Handling graph layers From 2b9e48f89b40cc9ed06b18dcb01a2739105912f9 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 3 Feb 2021 14:58:17 -0800 Subject: [PATCH 11/17] remove float64 from test_where --- tests/python-pytest/onnx/test_operators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index fce5781ad92a..3bf3cf691333 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -559,7 +559,7 @@ def test_onnx_export_equal_scalar(tmp_path, dtype, scalar): op_export_test('_internal._equal_scalar', M, [x], tmp_path) -@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"]) +@pytest.mark.parametrize("dtype", ["float16", "float32", "int32", "int64"]) @pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)]) def test_onnx_export_where(tmp_path, dtype, shape): M = def_model('where') From bb12126e56fc6a2c56d2fa0e474d1e803ceeb2e1 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 3 Feb 2021 15:23:14 -0800 Subject: [PATCH 12/17] update test --- python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index e6ed4df435af..59f9e201b32f 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -160,14 +160,14 @@ def get_outputs(sym, params, in_shape, in_label, in_type): # by user. Also remove in_label in_dtype = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t] for n, t in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_type)} - # Add params and their shape to list of inputs + # Add params and their types to list of inputs in_dtype.update({n: v.dtype for n, v in params.items() if n in sym.list_inputs()}) _, out_type, _ = sym.infer_type(**in_dtype) out_types = [mapping.NP_TYPE_TO_TENSOR_TYPE[o(0).dtype] for o in out_type] assert len(out_types) == len(out_names) - # bind output shapes with output names + # bind output shapes/types with output names graph_outputs = {n: {'shape': s, 'dtype': d} for n, s, d in zip(out_names, out_shapes, out_types)} return graph_outputs From f2da05bf13fa22809f14d9dfb0b7cea1e64920c1 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Thu, 4 Feb 2021 14:44:19 -0800 Subject: [PATCH 13/17] fix bert test input type --- tests/python-pytest/onnx/test_onnxruntime.py | 57 ++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index ad9115eb36b1..8bdbd4255adb 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -555,3 +555,60 @@ def load_video(filepath): finally: shutil.rmtree(tmp_path) + + +@with_seed() +@pytest.mark.parametrize('model', ['bert_12_768_12']) +def test_bert_inference_onnxruntime(tmp_path, model): + tmp_path = str(tmp_path) + try: + import gluonnlp as nlp + dataset = 'book_corpus_wiki_en_uncased' + ctx = mx.cpu(0) + model, vocab = nlp.model.get_model( + name=model, + ctx=ctx, + dataset_name=dataset, + pretrained=False, + use_pooler=True, + use_decoder=False, + use_classifier=False) + model.initialize(ctx=ctx) + model.hybridize(static_alloc=True) + + batch = 5 + seq_length = 16 + # create synthetic test data + inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32') + token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32') + valid_length = mx.nd.array([seq_length] * batch, dtype='float32') + + seq_encoding, cls_encoding = model(inputs, token_types, valid_length) + + prefix = "%s/bert" % tmp_path + model.export(prefix) + sym_file = "%s-symbol.json" % prefix + params_file = "%s-0000.params" % prefix + onnx_file = "%s.onnx" % prefix + + + input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)] + input_types = [np.float32, np.float32, np.float32] + converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, input_types, onnx_file) + + + # create onnxruntime session using the generated onnx file + ses_opt = onnxruntime.SessionOptions() + ses_opt.log_severity_level = 3 + session = onnxruntime.InferenceSession(onnx_file, ses_opt) + onnx_inputs = [inputs, token_types, valid_length] + input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs))) + pred_onx, cls_onx = session.run(None, input_dict) + + assert_almost_equal(seq_encoding, pred_onx) + assert_almost_equal(cls_encoding, cls_onx) + + finally: + shutil.rmtree(tmp_path) + + From 00bc04bd7395fe206373608575c5b753b2d85381 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Thu, 4 Feb 2021 16:59:54 -0800 Subject: [PATCH 14/17] enable defalut input_type --- python/mxnet/contrib/onnx/mx2onnx/export_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py index 086918fd56c9..aa68e639d3b7 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py @@ -28,7 +28,7 @@ from ._export_helper import load_module -def export_model(sym, params, input_shape, input_type, +def export_model(sym, params, input_shape, input_type=np.float32, onnx_file_path='model.onnx', verbose=False, opset_version=None): """Exports the MXNet model file, passed as a parameter, into ONNX model. Accepts both symbol,parameter objects as well as json and params filepaths as input. @@ -73,6 +73,8 @@ def export_model(sym, params, input_shape, input_type, # default is to use latest opset version the onnx package supports opset_version = onnx_opset_version() + if not isinstance(input_type, list) then: + input_type = [input_type for _ in range(len(input_shapes))] input_dtype = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(inp_type)] for inp_type in input_type] # if input parameters are strings(file paths), load files and create symbol parameter objects if isinstance(sym, string_types) and isinstance(params, string_types): From c4608ca1532a6569099279d909d8f6857097999c Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Thu, 4 Feb 2021 19:37:32 -0800 Subject: [PATCH 15/17] more default fix --- python/mxnet/contrib/onnx/mx2onnx/export_model.py | 6 +++--- tests/python-pytest/onnx/test_models.py | 2 +- tests/python-pytest/onnx/test_node.py | 4 ++-- tests/python-pytest/onnx/test_onnxruntime.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py index aa68e639d3b7..852737a3398c 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py @@ -43,8 +43,8 @@ def export_model(sym, params, input_shape, input_type=np.float32, Path to the params file or params dictionary. (Including both arg_params and aux_params) input_shape : List of tuple Input shape of the model e.g [(1,3,224,224)] - input_type : List of dtype - Input data type e.g. [np.float32] + input_type : data type or list of data types + Input data type e.g. np.float32, or [np.float32, np.int32] onnx_file_path : str Path where to save the generated onnx file verbose : Boolean @@ -73,7 +73,7 @@ def export_model(sym, params, input_shape, input_type=np.float32, # default is to use latest opset version the onnx package supports opset_version = onnx_opset_version() - if not isinstance(input_type, list) then: + if not isinstance(input_type, list): input_type = [input_type for _ in range(len(input_shapes))] input_dtype = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(inp_type)] for inp_type in input_type] # if input parameters are strings(file paths), load files and create symbol parameter objects diff --git a/tests/python-pytest/onnx/test_models.py b/tests/python-pytest/onnx/test_models.py index 18730eb48ea8..f85786141d6e 100644 --- a/tests/python-pytest/onnx/test_models.py +++ b/tests/python-pytest/onnx/test_models.py @@ -136,7 +136,7 @@ def get_model_results(modelpath): onnx_file = os.path.join(dir_path, new_model_name) logging.info("Translating converted model from mxnet to ONNX") - converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], [np.float32], onnx_file) + converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file) sym, arg_params, aux_params, actual_result, metadata = get_model_results(converted_model_path) diff --git a/tests/python-pytest/onnx/test_node.py b/tests/python-pytest/onnx/test_node.py index 325ae35e2a04..686ea4c485b9 100644 --- a/tests/python-pytest/onnx/test_node.py +++ b/tests/python-pytest/onnx/test_node.py @@ -186,7 +186,7 @@ def test_import_export(self): if test == "Pow": outsym = ipsym ** 2 forward_op = forward_pass(outsym, None, None, ['input1'], input1) - converted_model = onnx_mxnet.export_model(outsym, {}, [np.shape(input1)], [np.float32], + converted_model = onnx_mxnet.export_model(outsym, {}, [np.shape(input1)], np.float32, onnx_file_path=outsym.name + ".onnx") sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model) @@ -212,7 +212,7 @@ def test_exports(self): test_name, onnx_name, mx_op, attrs = test input_sym = mx.sym.var('data') outsym = mx_op(input_sym, **attrs) - converted_model = onnx_mxnet.export_model(outsym, {}, [input_shape], [np.float32], + converted_model = onnx_mxnet.export_model(outsym, {}, [input_shape], np.float32, onnx_file_path=outsym.name + ".onnx") model = load_model(converted_model) checker.check_model(model) diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index 8bdbd4255adb..fa2799e89a84 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -52,7 +52,7 @@ def export(self): def export_onnx(self): onnx_file = self.modelpath + ".onnx" mx.contrib.onnx.export_model(self.modelpath + "-symbol.json", self.modelpath + "-0000.params", - [self.input_shape], [self.input_dtype], onnx_file) + [self.input_shape], self.input_dtype, onnx_file) return onnx_file def predict(self, data): From 691c70272397d940b96056aab13d9400beac1a4f Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 5 Feb 2021 10:18:52 -0800 Subject: [PATCH 16/17] fix typo --- python/mxnet/contrib/onnx/mx2onnx/export_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py index 852737a3398c..f4b0c909b439 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py @@ -74,7 +74,7 @@ def export_model(sym, params, input_shape, input_type=np.float32, opset_version = onnx_opset_version() if not isinstance(input_type, list): - input_type = [input_type for _ in range(len(input_shapes))] + input_type = [input_type for _ in range(len(input_shape))] input_dtype = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(inp_type)] for inp_type in input_type] # if input parameters are strings(file paths), load files and create symbol parameter objects if isinstance(sym, string_types) and isinstance(params, string_types): From 4bd8aa195391063e4f20b3f14e3fba8a4eb67416 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Mon, 15 Feb 2021 19:11:14 -0800 Subject: [PATCH 17/17] fix empty lines --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 1 - tests/python-pytest/onnx/test_onnxruntime.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index a1dcd7d454c9..77179bd46966 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -4035,7 +4035,6 @@ def convert_argsort(node, **kwargs): to=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]) ] - return nodes diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index fa2799e89a84..48a8a436384b 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -555,8 +555,6 @@ def load_video(filepath): finally: shutil.rmtree(tmp_path) - - @with_seed() @pytest.mark.parametrize('model', ['bert_12_768_12']) def test_bert_inference_onnxruntime(tmp_path, model):