diff --git a/docs/api/python/contrib/onnx.md b/docs/api/python/contrib/onnx.md new file mode 100644 index 000000000000..2a8f1fcea0a6 --- /dev/null +++ b/docs/api/python/contrib/onnx.md @@ -0,0 +1,49 @@ +# ONNX-MXNet API + +## Overview + +[ONNX](https://onnx.ai/) is an open format to represent deep learning models. With ONNX as an intermediate representation, it is easier to move models between state-of-the-art tools and frameworks for training and inference. + +The `mxnet.contrib.onnx` package provides the APIs and interfaces that implement ONNX model format support for Apache MXNet. + +With ONNX format support for MXNet, developers can build and train models with a [variety of deep learning frameworks](http://onnx.ai/supported-tools), and import these models into MXNet to run them for inference and training using MXNet’s highly optimized engine. + +```eval_rst +.. warning:: This package contains experimental APIs and may change in the near future. +``` + +### Installation Instructions +- To use this module, developers need to **install ONNX**, which requires the protobuf compiler to be installed separately. Please follow the [instructions to install ONNX and its dependencies](https://github.com/onnx/onnx#installation). Once installed, you can go through the tutorials on how to use this module. + + +This document describes all the ONNX-MXNet APIs. + +```eval_rst +.. autosummary:: + :nosignatures: + + mxnet.contrib.onnx.import_model +``` + +## ONNX Tutorials + +```eval_rst +.. toctree:: + :maxdepth: 1 + + /tutorials/onnx/super_resolution.md + /tutorials/onnx/inference_on_onnx_model.md +``` + +## API Reference + + + +```eval_rst + +.. 
automodule:: mxnet.contrib.onnx + :members: import_model + +``` + + \ No newline at end of file diff --git a/docs/api/python/index.md b/docs/api/python/index.md index f65d3abfb15f..b097e2045b14 100644 --- a/docs/api/python/index.md +++ b/docs/api/python/index.md @@ -151,4 +151,5 @@ imported by running: contrib/contrib.md contrib/text.md + contrib/onnx.md ``` diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md index 8a597e95bfb7..ff767064d7c9 100644 --- a/docs/tutorials/index.md +++ b/docs/tutorials/index.md @@ -188,6 +188,8 @@ The Gluon and Module tutorials are in Python, but you can also find a variety of - [Text classification (NLP) on Movie Reviews](http://mxnet.incubator.apache.org/tutorials/nlp/cnn.html) +- [Importing an ONNX model into MXNet](http://mxnet.incubator.apache.org/tutorials/onnx/super_resolution.html) + diff --git a/docs/tutorials/onnx/inference_on_onnx_model.md b/docs/tutorials/onnx/inference_on_onnx_model.md index 182a2ae74cde..2b64945e4e9a 100644 --- a/docs/tutorials/onnx/inference_on_onnx_model.md +++ b/docs/tutorials/onnx/inference_on_onnx_model.md @@ -14,7 +14,6 @@ In this tutorial we will: To run the tutorial you will need to have installed the following python modules: - [MXNet](http://mxnet.incubator.apache.org/install/index.html) - [onnx](https://github.com/onnx/onnx) (follow the install guide) -- [onnx-mxnet](https://github.com/onnx/onnx-mxnet) - matplotlib - wget diff --git a/docs/tutorials/onnx/super_resolution.md b/docs/tutorials/onnx/super_resolution.md new file mode 100644 index 000000000000..dc75b6606f20 --- /dev/null +++ b/docs/tutorials/onnx/super_resolution.md @@ -0,0 +1,114 @@ +# Importing an ONNX model into MXNet + +In this tutorial we will: + +- learn how to load a pre-trained ONNX model file into MXNet. +- run inference in MXNet. 
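As background for the pre-processing done below, the super-resolution model in this tutorial operates only on the luma (Y) channel of a YCbCr image, which Pillow produces with ``img.convert("YCbCr")``. Under the ITU-R BT.601 convention that this conversion follows, luma is a fixed weighted sum of the RGB channels; a minimal pure-Python sketch (the helper name ``rgb_to_y`` is hypothetical, for illustration only):

```python
def rgb_to_y(r, g, b):
    # BT.601 luma weights; they sum to exactly 1.0, so full white
    # (255, 255, 255) maps to full luma and black maps to 0.
    return 0.299 * r + 0.587 * g + 0.114 * b

print(round(rgb_to_y(255, 255, 255)))  # 255
print(round(rgb_to_y(0, 0, 0)))        # 0
```

Pillow applies this conversion per pixel; the tutorial then feeds only the resulting Y channel to the network and reuses the Cb/Cr channels when reassembling the color output.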
+ +## Prerequisites +This example assumes that the following Python packages are installed: +- [mxnet](http://mxnet.incubator.apache.org/install/index.html) +- [onnx](https://github.com/onnx/onnx) (follow the install guide) +- Pillow - a Python image processing package, required for input pre-processing. It can be installed with ```pip install Pillow```. +- matplotlib + + +```python +from PIL import Image +import numpy as np +import mxnet as mx +import mxnet.contrib.onnx as onnx_mxnet +from mxnet.test_utils import download +from matplotlib.pyplot import imshow +``` + +### Fetching the required files + + +```python +img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg' +download(img_url, 'super_res_input.jpg') +model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx' +onnx_model_file = download(model_url, 'super_resolution.onnx') +``` + +## Loading the model into MXNet + +To completely describe a pre-trained model in MXNet, we need two elements: a symbolic graph, containing the model's network definition, and a binary file containing the model weights. You can import the ONNX model and get the symbol and parameter objects using the ``import_model`` API. The parameter object is split into argument parameters and auxiliary parameters. + + +```python +sym, arg, aux = onnx_mxnet.import_model(onnx_model_file) +``` + +We can now visualize the imported model (graphviz needs to be installed): + + +```python +mx.viz.plot_network(sym, node_attrs={"shape":"oval","fixedsize":"false"}) +``` + + + + +![svg](https://s3.amazonaws.com/onnx-mxnet/examples/super_res_mxnet_model.png) + + + +## Input Pre-processing + +We will transform the previously downloaded input image into an input tensor.
+ + +```python +img = Image.open('super_res_input.jpg').resize((224, 224)) +img_ycbcr = img.convert("YCbCr") +img_y, img_cb, img_cr = img_ycbcr.split() +test_image = np.array(img_y)[np.newaxis, np.newaxis, :, :] +``` + +## Run Inference using MXNet's Module API + +We will use MXNet's Module API to run the inference. For this we will need to create the module, bind it to the input data and assign the loaded weights from the two parameter objects - argument parameters and auxiliary parameters. + + +```python +mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None) +mod.bind(for_training=False, data_shapes=[('input_0',test_image.shape)], label_shapes=None) +mod.set_params(arg_params=arg, aux_params=aux, allow_missing=True, allow_extra=True) +``` + +The Module API's ``forward`` method requires a batch of data as input. We will prepare the data in that format and feed it to the ``forward`` method. + + +```python +from collections import namedtuple +Batch = namedtuple('Batch', ['data']) + +# forward on the provided data batch +mod.forward(Batch([mx.nd.array(test_image)])) +``` + +To get the output of the previous forward computation, use the ``module.get_outputs()`` method. +It returns an ``ndarray`` that we convert to a ``numpy`` array and then to Pillow's image format: + + +```python +output = mod.get_outputs()[0][0][0] +img_out_y = Image.fromarray(np.uint8(output.asnumpy().clip(0, 255)), mode='L') +result_img = Image.merge( +"YCbCr", [ + img_out_y, + img_cb.resize(img_out_y.size, Image.BICUBIC), + img_cr.resize(img_out_y.size, Image.BICUBIC) +]).convert("RGB") +result_img.save("super_res_output.jpg") +``` + +Here are the input image and the resulting output image compared. As you can see, the model was able to increase the spatial resolution from ``224x224`` to ``672x672``. 
+ +| Input Image | Output Image | +| ----------- | ------------ | +| ![input](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_input.jpg?raw=true) | ![output](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_output.jpg?raw=true) | + + \ No newline at end of file diff --git a/python/mxnet/contrib/onnx/__init__.py b/python/mxnet/contrib/onnx/__init__.py index eff91206298f..169ac673455c 100644 --- a/python/mxnet/contrib/onnx/__init__.py +++ b/python/mxnet/contrib/onnx/__init__.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. -"""Module for importing and exporting ONNX models.""" +"""Module for ONNX model format support for Apache MXNet.""" from ._import.import_model import import_model diff --git a/python/mxnet/contrib/onnx/_import/import_model.py b/python/mxnet/contrib/onnx/_import/import_model.py index d8d32a96a216..1bd4b418bc35 100644 --- a/python/mxnet/contrib/onnx/_import/import_model.py +++ b/python/mxnet/contrib/onnx/_import/import_model.py @@ -22,7 +22,8 @@ from .import_onnx import GraphProto def import_model(model_file): - """Imports the ONNX model file passed as a parameter into MXNet symbol and parameters. + """Imports the ONNX model file, passed as a parameter, into MXNet symbol and parameters. + Operator support and coverage - https://cwiki.apache.org/confluence/display/MXNET/ONNX Parameters ---------- @@ -31,20 +32,23 @@ def import_model(model_file): Returns ------- - Mxnet symbol and parameter objects. 
+ sym : :class:`~mxnet.symbol.Symbol` + MXNet symbol object - sym : mxnet.symbol - Mxnet symbol - params : dict of str to mx.ndarray - Dict of converted parameters stored in mxnet.ndarray format + arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray` + Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format + + aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray` + Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format """ graph = GraphProto() - # loads model file and returns ONNX protobuf object try: import onnx except ImportError: - raise ImportError("Onnx and protobuf need to be installed") + raise ImportError("Onnx and protobuf need to be installed. " + + "Instructions to install - https://github.com/onnx/onnx") + # loads model file and returns ONNX protobuf object model_proto = onnx.load(model_file) sym, arg_params, aux_params = graph.from_onnx(model_proto.graph) return sym, arg_params, aux_params diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py b/python/mxnet/contrib/onnx/_import/import_onnx.py index 037790c80806..92e7cb9c64e8 100644 --- a/python/mxnet/contrib/onnx/_import/import_onnx.py +++ b/python/mxnet/contrib/onnx/_import/import_onnx.py @@ -147,8 +147,9 @@ def _parse_array(self, tensor_proto): """Grab data in TensorProto and convert to numpy array.""" try: from onnx.numpy_helper import to_array - except ImportError as e: - raise ImportError("Unable to import onnx which is required {}".format(e)) + except ImportError: + raise ImportError("Onnx and protobuf need to be installed. " + + "Instructions to install - https://github.com/onnx/onnx") np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims)) return nd.array(np_array)