Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,14 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# Deriving the output_label name.
output_label = sym.get_internals()[len(sym.get_internals()) - 1].name + "_label"

# Determine output shape
output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label)
# Determine outputs shapes
input_names = [n for n in sym.list_inputs() if n not in params]
input_pairs = {n: in_shape[i] for i, n in enumerate(input_names)}
_, output_shapes, _ = sym.get_internals().infer_shape(**input_pairs)

output_suffix = '_output'
output_names = [
o[:-len(output_suffix)] for o in sym.list_outputs() if o.endswith(output_suffix)]

weights = MXNetGraph.convert_weights_to_numpy(params)

Expand Down Expand Up @@ -265,6 +271,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
weights=weights,
in_shape=in_shape[graph_input_idx],
in_type=in_type,
out_shape=output_shapes[idx],
proc_nodes=all_processed_nodes,
initializer=initializer,
index_lookup=index_lookup)
Expand All @@ -279,6 +286,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
weights=weights,
in_shape=in_shape,
in_type=in_type,
out_shape=output_shapes[idx],
proc_nodes=all_processed_nodes,
initializer=initializer,
index_lookup=index_lookup,
Expand All @@ -294,22 +302,22 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# If converted node is NodeProto, add it in processed nodes list
elif isinstance(converted_node, NodeProto):
onnx_processed_nodes.append(converted_node)
if idx == (len(mx_graph) - 1):
if idx == (len(mx_graph) - 1) or converted_node.name in output_names:
# If converted node doesnt have name, use it from output field
if not converted_node.name:
onnx_processed_outputs.append(
make_tensor_value_info(
name=converted_node.output[0],
elem_type=in_type,
shape=output_shape
shape=output_shapes[idx]
)
)
else:
onnx_processed_outputs.append(
make_tensor_value_info(
name=converted_node.name,
elem_type=in_type,
shape=output_shape
shape=output_shapes[idx]
)
)
if verbose:
Expand Down
27 changes: 17 additions & 10 deletions tests/python-pytest/onnx/export/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def forward_pass(sym, arg, aux, data_names, input_data):
batch = namedtuple('Batch', ['data'])
mod.forward(batch([mx.nd.array(input_data)]), is_train=False)

return mod.get_outputs()[0].asnumpy()
return [output.asnumpy() for output in mod.get_outputs()]


def test_models(model_name, input_shape, output_shape):
def test_models(model_name, input_shape, output_shape, test_extra_output=False):
""" Tests Googlenet model for both onnx import and export"""
model_path, inputs, outputs = get_test_files(model_name)
logging.info("Translating model from ONNX model zoo to Mxnet")
Expand All @@ -117,6 +117,12 @@ def test_models(model_name, input_shape, output_shape):
new_model_name = "exported_" + model_name + ".onnx"
onnx_file = os.path.join(dir_path, new_model_name)

if test_extra_output:
logging.info("Adding extra output to model")
sym_output = sym.get_internals()[sym.list_outputs()[0]]
id_output = mx.sym.identity(data=sym_output)
sym = mx.symbol.Group([sym_output, id_output])

logging.info("Translating converted model from mxnet to ONNX")
converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file)

Expand All @@ -133,11 +139,11 @@ def test_models(model_name, input_shape, output_shape):
logging.info("Running inference on onnx re-import model in mxnet")
# run test for each test file
for input_data, output_data in zip(inputs, outputs):
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)

# verify the results
npt.assert_equal(result.shape, output_data.shape)
npt.assert_almost_equal(output_data, result, decimal=3)
results = forward_pass(sym, arg_params, aux_params, data_names, input_data)
for result in results:
# verify the results
npt.assert_equal(result.shape, output_data.shape)
npt.assert_almost_equal(output_data, result, decimal=3)
logging.info(model_name + " conversion successful")


Expand All @@ -153,7 +159,7 @@ def test_model_accuracy(model_name, input_shape):

expected_result= []
for input_data, output_data in zip(inputs, outputs):
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)[0]
expected_result.append(result)

params = {}
Expand All @@ -175,7 +181,7 @@ def test_model_accuracy(model_name, input_shape):

actual_result = []
for input_data, output_data in zip(inputs, outputs):
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)[0]
actual_result.append(result)

# verify the results
Expand Down Expand Up @@ -232,7 +238,7 @@ def test_square():
converted_model = onnx_mxnet.export_model(square, params, [np.shape(input1)], np.float32, "square.onnx")

sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)[0]

numpy_op = np.square(input1)

Expand All @@ -242,6 +248,7 @@ def test_square():
test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))
test_models("bvlc_reference_rcnn_ilsvrc13", (1, 3, 224, 224), (1, 200))
test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000), test_extra_output=True)

# Comparing MXNet inference result, since MXNet results don't match
# ONNX expected results due to AveragePool issue github issue(#10194)
Expand Down