Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/mxnet/onnx/mx2onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,4 @@
"""ONNX Export module"""

from ._export_model import export_model, get_operator_support
from ._op_translations import _op_translations_opset12
from ._op_translations import _op_translations_opset13
from ._op_translations import *
15 changes: 12 additions & 3 deletions python/mxnet/onnx/mx2onnx/_export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_operator_support(opset_version=None):
def export_model(sym, params, in_shapes=None, in_types=np.float32,
onnx_file_path='model.onnx', verbose=False, dynamic=False,
dynamic_input_shapes=None, run_shape_inference=False, input_type=None,
input_shape=None):
input_shape=None, model_specific_logics=None, cheat_sheet=None):
"""Exports the MXNet model file, passed as a parameter, into ONNX model.
Accepts both symbol,parameter objects as well as json and params filepaths as input.
Operator support and coverage -
Expand Down Expand Up @@ -83,6 +83,11 @@ def export_model(sym, params, in_shapes=None, in_types=np.float32,
This is the old name of in_types. We keep this parameter name for backward compatibility
in_shapes : List of tuple
This is the old name of in_shapes. We keep this parameter name for backward compatibility
model_specific_logics : str
Specifies if model-specific conversion logic should be used. Refer to ./_op_translations/
cheat_sheet : dict of str to str
This is a dict that stors some hyperparameters values or additional info about the model that
would be used in model-specific conversion functions
Comment on lines +86 to +90
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these options are semantically unclear and hard to maintain


Returns
-------
Expand Down Expand Up @@ -122,12 +127,16 @@ def export_model(sym, params, in_shapes=None, in_types=np.float32,
onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, in_shapes,
in_types_t,
verbose=verbose, opset_version=opset_version,
dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes)
dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes,
model_specific_logics=model_specific_logics,
cheat_sheet=cheat_sheet)
elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
onnx_graph = converter.create_onnx_graph_proto(sym, params, in_shapes,
in_types_t,
verbose=verbose, opset_version=opset_version,
dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes)
dynamic=dynamic, dynamic_input_shapes=dynamic_input_shapes,
model_specific_logics=model_specific_logics,
cheat_sheet=cheat_sheet)
elif isinstance(sym, symbol.Symbol) and isinstance(params, list) and len(params) == 2:
# when params contains arg_params and aux_params
p = {}
Expand Down
32 changes: 26 additions & 6 deletions python/mxnet/onnx/mx2onnx/_export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,25 @@ def convert_layer(node, **kwargs):
op = str(node["op"])
opset_version = kwargs.get("opset_version", onnx_opset_version())
# fallback to older opset versions if op is not registered in current version
convert_func = None
for op_version in range(opset_version, 11, -1):
if op_version not in MXNetGraph.registry_ or op not in MXNetGraph.registry_[op_version]:
if opset_version == 12:
raise AttributeError("No conversion function registered for op type %s yet." % op)
continue
convert_func = MXNetGraph.registry_[op_version][op]
break

model_specific_logics = kwargs.get("model_specific_logics", None)
if model_specific_logics:
assert model_specific_logics in MXNetGraph.registry_,\
"Model specific converion logics for %s is not found" % model_specific_logics
if op in MXNetGraph.registry_[model_specific_logics]:
logging.info("Found model-specific %s conversion logic for model %s",
op, model_specific_logics)
convert_func = MXNetGraph.registry_[model_specific_logics][op]

if convert_func is None:
raise AttributeError("No conversion function registered for op type %s yet." % op)

ret = convert_func(node, **kwargs)
# in case the conversion function does not specify the returned dtype, we just return None
# as the second value
Expand Down Expand Up @@ -238,8 +249,9 @@ def convert_weights_to_numpy(weights_dict):
return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
for k, v in weights_dict.items()])

def create_onnx_graph_proto(self, sym, params, in_shapes, in_types, verbose=False, opset_version=None,
dynamic=True, dynamic_input_shapes=None):
def create_onnx_graph_proto(self, sym, params, in_shapes, in_types, verbose=False,
opset_version=None, dynamic=True, dynamic_input_shapes=None,
model_specific_logics=None, cheat_sheet=None):
"""Convert MXNet graph to ONNX graph

Parameters
Expand All @@ -260,6 +272,11 @@ def create_onnx_graph_proto(self, sym, params, in_shapes, in_types, verbose=Fals
If True will allow for dynamic input shapes to the model
dynamic_input_shapes: list of tuple
Specifies the dynamic input_shapes. If None then all dimensions are set to None
model_specific_logics: str
Specifies if model-specific conversion logic should be used. Refer to ./_op_translations/
cheat_sheet : dict of str to str
This is a dict that stors some hyperparameters values or additional info about the model that
would be used in model-specific conversion functions

Returns
-------
Expand Down Expand Up @@ -335,7 +352,8 @@ def __init__(self, name, dtype):
in_type=in_types[graph_input_idx],
proc_nodes=all_processed_nodes,
initializer=initializer,
outputs_lookup=outputs_lookup)
outputs_lookup=outputs_lookup,
)
graph_input_idx += 1
else:
# Handle graph layers
Expand All @@ -348,7 +366,9 @@ def __init__(self, name, dtype):
initializer=initializer,
outputs_lookup=outputs_lookup,
idx=idx,
opset_version=opset_version
opset_version=opset_version,
model_specific_logics=model_specific_logics,
cheat_sheet=cheat_sheet
)
if isinstance(converted, list):
# Collect all the node's output names
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/onnx/mx2onnx/_op_translations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
# coding: utf-8
"""ONNX export op translation"""

from . import _gluonnlp_bert_uninterleaved
from . import _op_translations_opset12
from . import _op_translations_opset13
Loading