diff --git a/python/mxnet/onnx/README.md b/python/mxnet/onnx/README.md index b9f6ddaa70aa..118af59372b8 100644 --- a/python/mxnet/onnx/README.md +++ b/python/mxnet/onnx/README.md @@ -148,6 +148,7 @@ mx.onnx.export_model(mx_sym, mx_params, in_shapes, in_dtypes, onnx_file, |_rminus_scalar|1.7 1.8 | |_rnn_param_concat|1.7 1.8 | |_sample_multinomial|1.7 1.8 | +|_split_v2|1.7 1.8 | |_zeros|1.7 1.8 | |abs|1.7 1.8 | |add_n|1.7 1.8 | @@ -173,6 +174,7 @@ mx.onnx.export_model(mx_sym, mx_params, in_shapes, in_dtypes, onnx_file, |broadcast_minimum|1.7 1.8 | |broadcast_mod|1.7 1.8 | |broadcast_mul|1.7 1.8 | +|broadcast_not_equal|1.7 1.8 | |broadcast_power|1.7 1.8 | |broadcast_sub|1.7 1.8 | |broadcast_to|1.7 1.8 | @@ -405,4 +407,4 @@ mx.onnx.export_model(mx_sym, mx_params, in_shapes, in_dtypes, onnx_file, |standard_lstm_lm_200| |standard_lstm_lm_650| |standard_lstm_lm_1500| -|transformer_en_de_512| \ No newline at end of file +|transformer_en_de_512| diff --git a/python/mxnet/onnx/mx2onnx/_export_onnx.py b/python/mxnet/onnx/mx2onnx/_export_onnx.py index e3aa59ec773c..3af870e6a557 100644 --- a/python/mxnet/onnx/mx2onnx/_export_onnx.py +++ b/python/mxnet/onnx/mx2onnx/_export_onnx.py @@ -45,7 +45,7 @@ # coding: utf-8 # pylint: disable=invalid-name,too-many-locals,no-self-use,too-many-arguments, -# pylint: disable=maybe-no-member,too-many-nested-blocks +# pylint: disable=maybe-no-member,too-many-nested-blocks,logging-not-lazy """MXNet to ONNX graph converter functions""" import logging import json @@ -91,15 +91,23 @@ def convert_layer(node, **kwargs): op = str(node["op"]) opset_version = kwargs.get("opset_version", onnx_opset_version()) - # fallback to older opset versions if op is not registered in current version + if opset_version < 12: + logging.warning('Your ONNX op set version is %s, ' % str(opset_version) + + 'which is lower than then lowest tested op set (12), please consider ' + 'updating ONNX') + opset_version = 12 + # Fallback to older opset versions if op is not registered in current version + convert_func = None for op_version in range(opset_version, 11, -1): if op_version not in MXNetGraph.registry_ or op not in MXNetGraph.registry_[op_version]: - if opset_version == 12: - raise AttributeError("No conversion function registered for op type %s yet." % op) continue convert_func = MXNetGraph.registry_[op_version][op] break + # The conversion logic is not implemented + if convert_func is None: + raise AttributeError("No conversion function registered for op type %s yet." % op) + ret = convert_func(node, **kwargs) # in case the conversion function does not specify the returned dtype, we just return None # as the second value