Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions python/mxnet/contrib/onnx/mx2onnx/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def export_model(sym, params, input_shape, input_type=np.float32,
Path to the params file or params dictionary. (Including both arg_params and aux_params)
input_shape : List of tuple
Input shape of the model e.g [(1,3,224,224)]
input_type : data type
Input data type e.g. np.float32
input_type : data type or list of data types
Input data type e.g. np.float32, or [np.float32, np.int32]
onnx_file_path : str
Path where to save the generated onnx file
verbose : Boolean
Expand Down Expand Up @@ -73,17 +73,19 @@ def export_model(sym, params, input_shape, input_type=np.float32,
# default is to use latest opset version the onnx package supports
opset_version = onnx_opset_version()

data_format = np.dtype(input_type)
if not isinstance(input_type, list):
input_type = [input_type for _ in range(len(input_shape))]
input_dtype = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(inp_type)] for inp_type in input_type]
# if input parameters are strings(file paths), load files and create symbol parameter objects
if isinstance(sym, string_types) and isinstance(params, string_types):
logging.info("Converting json and weight file to sym and params")
sym_obj, params_obj = load_module(sym, params)
onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape,
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
input_dtype,
verbose=verbose, opset_version=opset_version)
elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape,
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
input_dtype,
verbose=verbose, opset_version=opset_version)
else:
raise ValueError("Input sym and params should either be files or objects")
Expand Down
22 changes: 16 additions & 6 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,19 @@ def get_outputs(sym, params, in_shape, in_label, in_type):

assert len(out_shapes) == len(out_names)

# infer output types
args = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[in_type] for n in sym.list_inputs()}
_, out_type, _ = sym.infer_type(**args)
## Infer output types
# Remove any input listed in params from sym.list_inputs() and bind them to the input types provided
# by user. Also remove in_label
in_dtype = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t]
for n, t in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_type)}
# Add params and their types to list of inputs
in_dtype.update({n: v.dtype for n, v in params.items() if n in sym.list_inputs()})
_, out_type, _ = sym.infer_type(**in_dtype)
out_types = [mapping.NP_TYPE_TO_TENSOR_TYPE[o(0).dtype] for o in out_type]

assert len(out_types) == len(out_names)

# bind output shapes with output names
# bind output shapes/types with output names
graph_outputs = {n: {'shape': s, 'dtype': d} for n, s, d in zip(out_names, out_shapes, out_types)}

return graph_outputs
Expand Down Expand Up @@ -256,21 +261,26 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False,
mx_graph=mx_graph,
weights=weights,
in_shape=in_shape[graph_input_idx],
in_type=in_type,
in_type=in_type[graph_input_idx],
proc_nodes=all_processed_nodes,
initializer=initializer,
outputs_lookup=outputs_lookup)
graph_input_idx += 1

else:
# Handle no input case
intype = 1 # Float32 in tensor type
if len(in_type) > 0:
intype = in_type[0]

# Handling graph layers
converted = MXNetGraph.convert_layer(
node,
is_input=False,
mx_graph=mx_graph,
weights=weights,
in_shape=in_shape,
in_type=in_type,
in_type=intype,
proc_nodes=all_processed_nodes,
initializer=initializer,
outputs_lookup=outputs_lookup,
Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def get_graph_metadata(self, graph):
for graph_input in graph.input:
if graph_input.name not in _params:
shape = [val.dim_value for val in graph_input.type.tensor_type.shape.dim]
input_data.append((graph_input.name, tuple(shape)))
dtype = graph_input.type.tensor_type.elem_type
input_data.append((graph_input.name, tuple(shape), dtype))

output_data = []
for graph_out in graph.output:
Expand Down
7 changes: 4 additions & 3 deletions tests/python-pytest/onnx/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def set_params(cls, backend, operation):
cls.operation = operation

@staticmethod
def perform_import_export(sym, arg_params, aux_params, input_shape):
def perform_import_export(sym, arg_params, aux_params, input_shape, input_dtype):
""" Import ONNX model to mxnet model and then export to ONNX model
and then import it back to mxnet for verifying the result"""
graph = GraphProto()
Expand All @@ -63,7 +63,7 @@ def perform_import_export(sym, arg_params, aux_params, input_shape):
# exporting to onnx graph proto format
converter = MXNetGraph()
graph_proto = converter.create_onnx_graph_proto(sym, params, in_shape=input_shape,
in_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')],
in_type=input_dtype,
opset_version=opset_version)

# importing back to MXNET for verifying result.
Expand Down Expand Up @@ -108,8 +108,9 @@ def prepare(cls, model, device='CPU', **kwargs):
metadata = graph.get_graph_metadata(model.graph)
input_data = metadata['input_tensor_data']
input_shape = [data[1] for data in input_data]
input_dtype = [data[2] for data in input_data]
sym, arg_params, aux_params = MXNetBackend.perform_import_export(sym, arg_params, aux_params,
input_shape)
input_shape, input_dtype)

return MXNetBackendRep(sym, arg_params, aux_params, device)
elif backend == 'gluon':
Expand Down
1 change: 1 addition & 0 deletions tests/python-pytest/onnx/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params=
sym=net_sym,
params=net_params,
input_shape=[shape_type(data.shape)],
input_type=[data.dtype],
onnx_file_path=onnx_file_path)
assert export_path == onnx_file_path
# Try importing the model to symbol
Expand Down
6 changes: 3 additions & 3 deletions tests/python-pytest/onnx/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_import_export(self):

if mxnet_specific:
onnxmodelfile = onnx_mxnet.export_model(test_op, {}, [np.shape(ip) for ip in inputs],
np.float32,
[ip.dtype for ip in inputs],
onnx_name + ".onnx")
onnxmodel = load_model(onnxmodelfile)
else:
Expand Down Expand Up @@ -190,9 +190,9 @@ def test_import_export(self):
onnx_file_path=outsym.name + ".onnx")

sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)

npt.assert_almost_equal(result, forward_op)
npt.assert_almost_equal(result, forward_op)

def test_imports(self):
for test in import_test_cases:
Expand Down
55 changes: 55 additions & 0 deletions tests/python-pytest/onnx/test_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,58 @@ def load_video(filepath):
finally:
shutil.rmtree(tmp_path)

@with_seed()
@pytest.mark.parametrize('model', ['bert_12_768_12'])
def test_bert_inference_onnxruntime(tmp_path, model):
tmp_path = str(tmp_path)
try:
import gluonnlp as nlp
dataset = 'book_corpus_wiki_en_uncased'
ctx = mx.cpu(0)
model, vocab = nlp.model.get_model(
name=model,
ctx=ctx,
dataset_name=dataset,
pretrained=False,
use_pooler=True,
use_decoder=False,
use_classifier=False)
model.initialize(ctx=ctx)
model.hybridize(static_alloc=True)

batch = 5
seq_length = 16
# create synthetic test data
inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

seq_encoding, cls_encoding = model(inputs, token_types, valid_length)

prefix = "%s/bert" % tmp_path
model.export(prefix)
sym_file = "%s-symbol.json" % prefix
params_file = "%s-0000.params" % prefix
onnx_file = "%s.onnx" % prefix


input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
input_types = [np.float32, np.float32, np.float32]
converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, input_types, onnx_file)


# create onnxruntime session using the generated onnx file
ses_opt = onnxruntime.SessionOptions()
ses_opt.log_severity_level = 3
session = onnxruntime.InferenceSession(onnx_file, ses_opt)
onnx_inputs = [inputs, token_types, valid_length]
input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
pred_onx, cls_onx = session.run(None, input_dict)

assert_almost_equal(seq_encoding, pred_onx)
assert_almost_equal(cls_encoding, cls_onx)

finally:
shutil.rmtree(tmp_path)


7 changes: 3 additions & 4 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,15 @@ def export_to_onnx(model, model_name, inputs):
model.export(model_path, epoch=0)
sym_file = '{}-symbol.json'.format(model_path)
params_file = '{}-0000.params'.format(model_path)
dtype = inputs[0].dtype
onnx_file = '{}/{}.onnx'.format(tmp_path, model_name)
mx.contrib.onnx.export_model(sym_file, params_file, [inp.shape for inp in inputs],
dtype, onnx_file)
[inp.dtype for inp in inputs], onnx_file)
return onnx_file

def onnx_rt(onnx_file, inputs):
sess = rt.InferenceSession(onnx_file)
dtype_0 = inputs[0].asnumpy().dtype
input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy().astype(dtype_0)) for i in range(len(inputs)))
input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy()) for i in range(len(inputs)))
pred = sess.run(None, input_dict)
return pred

Expand Down Expand Up @@ -560,7 +559,7 @@ def test_onnx_export_equal_scalar(tmp_path, dtype, scalar):
op_export_test('_internal._equal_scalar', M, [x], tmp_path)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("dtype", ["float16", "float32", "int32", "int64"])
@pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)])
def test_onnx_export_where(tmp_path, dtype, shape):
M = def_model('where')
Expand Down