This repository was archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.7k
[MXNET-34] Onnx Module to import onnx models into mxnet #9963
Merged
Merged
Changes from all commits
Commits
Show all changes
36 commits
Select commit
Hold shift + click to select a range
d80f95a
Onnx Module to import onnx models into mxnet
23c409a
Changing the translation and utils file.
2d0c4e4
- Fixed Pylint issues
a350505
Add UTs for reduce ops.
678cd7a
pylint - newline, whitespace.
0c66f32
Added operators:
rajanksin 5f91c6c
Added operators:
rajanksin 9a615b2
- Added Pad operator support.
rajanksin 0be8eae
RandomUniform,Normal,Sub,Mul,Div,Tanh,Relu,Reciprocal,Sqrt operators
5936755
lint fix
rajanksin 96d967f
Add protobuf-compile to CI bash script. Add MatMul and Pow operator.
df61e4f
Max,Min,Sum,Reduce operators.
8e74bd8
BatchNorm,SpatialBN, Split
40d9d13
Slice,Transpose and Squeeze Operators.
a1f3782
Onnx tests in CI integration tests.
65f627b
Addressing Marco's comments
495169e
Floor, LeakyRelu, Elu, PRelu, Softmax, Exp, Log operator.
48b2a7c
Added operators:
rajanksin 51189a8
lint fix
rajanksin 69bf6f8
Rebase fix
rajanksin ee0393a
Added Maxpool operator
rajanksin 0ebd5b0
Adding FullyConnected operator
rajanksin 40cbe11
Adding operator- GlobalPooling - max and avg
rajanksin 71517b5
Adding operator - Gemm
rajanksin 3dd7a2e
Change test Path, LRN and Dropout operator.
11be77b
Add asserts for the super_res example.
d98de71
Fixing conv test failures.
rajanksin 4ef7c36
Update Jenkins job.
e96b41b
Nits: Removing commented out code
rajanksin f979d6c
Rebase after Docker PR
c6c8038
Merge branch 'onnx1' of
1d02490
Fetch test files by version number. Verify the high resolution example.
b4b6f9a
Fix method arguments for Python3.5+
612e2a6
Remove logging configuration from test files.
a9b2f62
Verify result image in example by hash
b24baba
Remove fetching test files by ETag.
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| #!/usr/bin/env bash | ||
|
|
||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| ###################################################################### | ||
| # This script installs ONNX for Python along with all required dependencies | ||
| # on a Ubuntu Machine. | ||
| # Tested on Ubuntu 16.04 distro. | ||
| ###################################################################### | ||
|
|
||
| set -e | ||
| set -x | ||
|
|
||
| echo "Installing libprotobuf-dev and protobuf-compiler ..." | ||
| apt-get install -y libprotobuf-dev protobuf-compiler | ||
|
|
||
| echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX and tabulate ..." | ||
| pip2 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.0.0 onnx==1.0.1 Pillow==5.0.0 tabulate==0.7.5 | ||
| pip3 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.0.0 onnx==1.0.1 Pillow==5.0.0 tabulate==0.7.5 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| """Testing super_resolution model conversion""" | ||
| from __future__ import absolute_import as _abs | ||
| from __future__ import print_function | ||
| from collections import namedtuple | ||
| import logging | ||
| import numpy as np | ||
| from PIL import Image | ||
| import mxnet as mx | ||
| from mxnet.test_utils import download | ||
| import mxnet.contrib.onnx as onnx_mxnet | ||
|
|
||
| # set up logger | ||
| logging.basicConfig() | ||
| LOGGER = logging.getLogger() | ||
| LOGGER.setLevel(logging.INFO) | ||
|
|
||
| def import_onnx(): | ||
| """Import the onnx model into mxnet""" | ||
| model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx' | ||
| download(model_url, 'super_resolution.onnx') | ||
|
|
||
| LOGGER.info("Converting onnx format to mxnet's symbol and params...") | ||
| sym, params = onnx_mxnet.import_model('super_resolution.onnx') | ||
| LOGGER.info("Successfully Converted onnx format to mxnet's symbol and params...") | ||
| return sym, params | ||
|
|
||
| def get_test_image(): | ||
| """Download and process the test image""" | ||
| # Load test image | ||
| input_image_dim = 224 | ||
| img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg' | ||
| download(img_url, 'super_res_input.jpg') | ||
| img = Image.open('super_res_input.jpg').resize((input_image_dim, input_image_dim)) | ||
| img_ycbcr = img.convert("YCbCr") | ||
| img_y, img_cb, img_cr = img_ycbcr.split() | ||
| input_image = np.array(img_y)[np.newaxis, np.newaxis, :, :] | ||
| return input_image, img_cb, img_cr | ||
|
|
||
| def perform_inference(sym, params, input_img, img_cb, img_cr): | ||
| """Perform inference on image using mxnet""" | ||
| # create module | ||
| mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None) | ||
| mod.bind(for_training=False, data_shapes=[('input_0', input_img.shape)]) | ||
| mod.set_params(arg_params=params, aux_params=None) | ||
|
|
||
| # run inference | ||
| batch = namedtuple('Batch', ['data']) | ||
| mod.forward(batch([mx.nd.array(input_img)])) | ||
|
|
||
| # Save the result | ||
| img_out_y = Image.fromarray(np.uint8(mod.get_outputs()[0][0][0]. | ||
| asnumpy().clip(0, 255)), mode='L') | ||
|
|
||
| result_img = Image.merge( | ||
| "YCbCr", [img_out_y, | ||
| img_cb.resize(img_out_y.size, Image.BICUBIC), | ||
| img_cr.resize(img_out_y.size, Image.BICUBIC)]).convert("RGB") | ||
| output_img_dim = 672 | ||
| assert result_img.size == (output_img_dim, output_img_dim) | ||
| LOGGER.info("Super Resolution example success.") | ||
| result_img.save("super_res_output.jpg") | ||
| return result_img | ||
|
|
||
| if __name__ == '__main__': | ||
| MX_SYM, MX_PARAM = import_onnx() | ||
| INPUT_IMG, IMG_CB, IMG_CR = get_test_image() | ||
| perform_inference(MX_SYM, MX_PARAM, INPUT_IMG, IMG_CB, IMG_CR) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,5 +28,5 @@ | |
| from . import tensorboard | ||
|
|
||
| from . import text | ||
|
|
||
| from . import onnx | ||
| from . import io | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| """Module for importing and exporting ONNX models.""" | ||
|
|
||
| from ._import.import_model import import_model |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| # coding: utf-8 | ||
| """ONNX Import module""" | ||
| from . import import_model | ||
| from . import import_onnx |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| # coding: utf-8 | ||
| # pylint: disable=invalid-name | ||
| """Operator attributes conversion""" | ||
| from .op_translations import identity, random_uniform, random_normal | ||
| from .op_translations import add, subtract, multiply, divide, absolute, negative, add_n | ||
| from .op_translations import tanh | ||
| from .op_translations import ceil, floor | ||
| from .op_translations import concat | ||
| from .op_translations import leaky_relu, _elu, _prelu, softmax, fully_connected | ||
| from .op_translations import global_avgpooling, global_maxpooling, linalg_gemm | ||
| from .op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm | ||
| from .op_translations import dropout, local_response_norm, conv, deconv | ||
|
anirudhacharya marked this conversation as resolved.
|
||
| from .op_translations import reshape, cast, split, _slice, transpose, squeeze | ||
| from .op_translations import reciprocal, squareroot, power, exponent, _log | ||
| from .op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum | ||
| from .op_translations import reduce_prod, avg_pooling, max_pooling | ||
| from .op_translations import argmax, argmin, maximum, minimum | ||
|
|
||
| # convert_map defines maps of ONNX operator names to converter functor(callable) | ||
|
anirudhacharya marked this conversation as resolved.
|
||
| # defined in the op_translations module. | ||
| _convert_map = { | ||
| # Generator Functions | ||
| 'Constant' : identity, | ||
| 'RandomUniform' : random_uniform, | ||
| 'RandomNormal' : random_normal, | ||
| 'RandomUniformLike' : random_uniform, | ||
| 'RandomNormalLike' : random_normal, | ||
| # Arithmetic Operators | ||
| 'Add' : add, | ||
| 'Sub' : subtract, | ||
| 'Mul' : multiply, | ||
| 'Div' : divide, | ||
| 'Abs' : absolute, | ||
| 'Neg' : negative, | ||
| 'Sum' : add_n, #elemwise sum | ||
| #Hyperbolic functions | ||
| 'Tanh' : tanh, | ||
| # Rounding | ||
| 'Ceil' : ceil, | ||
| 'Floor' : floor, | ||
| # Joining and spliting | ||
| 'Concat' : concat, | ||
| # Basic neural network functions | ||
| 'Sigmoid' : sigmoid, | ||
| 'Relu' : relu, | ||
| 'Pad' : pad, | ||
| 'MatMul' : matrix_multiplication, #linalg_gemm2 | ||
| 'Conv' : conv, | ||
| 'ConvTranspose' : deconv, | ||
| 'BatchNormalization': batch_norm, | ||
| 'SpatialBN' : batch_norm, | ||
| 'LeakyRelu' : leaky_relu, | ||
| 'Elu' : _elu, | ||
| 'PRelu' : _prelu, | ||
| 'Softmax' : softmax, | ||
| 'FC' : fully_connected, | ||
| 'GlobalAveragePool' : global_avgpooling, | ||
| 'GlobalMaxPool' : global_maxpooling, | ||
| 'Gemm' : linalg_gemm, | ||
| 'LRN' : local_response_norm, | ||
| 'Dropout' : dropout, | ||
| # Changing shape and type. | ||
| 'Reshape' : reshape, | ||
| 'Cast' : cast, | ||
| 'Split' : split, | ||
| 'Slice' : _slice, | ||
| 'Transpose' : transpose, | ||
| 'Squeeze' : squeeze, | ||
| #Powers | ||
| 'Reciprocal' : reciprocal, | ||
| 'Sqrt' : squareroot, | ||
| 'Pow' : power, | ||
| 'Exp' : exponent, | ||
| 'Log' : _log, | ||
| # Reduce Functions | ||
| 'ReduceMax' : reduce_max, | ||
| 'ReduceMean' : reduce_mean, | ||
| 'ReduceMin' : reduce_min, | ||
| 'ReduceSum' : reduce_sum, | ||
| 'ReduceProd' : reduce_prod, | ||
| 'AveragePool' : avg_pooling, | ||
| 'MaxPool' : max_pooling, | ||
| # Sorting and Searching | ||
| 'ArgMax' : argmax, | ||
| 'ArgMin' : argmin, | ||
| 'Max' : maximum, #elemwise maximum | ||
| 'Min' : minimum #elemwise minimum | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| # coding: utf-8 | ||
| """import function""" | ||
| # pylint: disable=no-member | ||
|
|
||
| from .import_onnx import GraphProto | ||
|
|
||
| def import_model(model_file): | ||
| """Imports the ONNX model file passed as a parameter into MXNet symbol and parameters. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model_file : str | ||
| ONNX model file name | ||
|
|
||
| Returns | ||
| ------- | ||
| Mxnet symbol and parameter objects. | ||
|
|
||
| sym : mxnet.symbol | ||
| Mxnet symbol | ||
| params : dict of str to mx.ndarray | ||
| Dict of converted parameters stored in mxnet.ndarray format | ||
| """ | ||
| graph = GraphProto() | ||
|
|
||
| # loads model file and returns ONNX protobuf object | ||
| try: | ||
| import onnx | ||
| except ImportError: | ||
| raise ImportError("Onnx and protobuf need to be installed") | ||
|
anirudhacharya marked this conversation as resolved.
|
||
| model_proto = onnx.load(model_file) | ||
| sym, params = graph.from_onnx(model_proto.graph) | ||
| return sym, params | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.