This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
49 changes: 49 additions & 0 deletions docs/api/python/contrib/onnx.md
@@ -0,0 +1,49 @@
# ONNX-MXNet API

## Overview

[ONNX](https://onnx.ai/) is an open format to represent deep learning models. With ONNX as an intermediate representation, it is easier to move models between state-of-the-art tools and frameworks for training and inference.

The `mxnet.contrib.onnx` package provides the APIs and interfaces that implement ONNX model format support for Apache MXNet.

With ONNX format support for MXNet, developers can build and train models with a [variety of deep learning frameworks](http://onnx.ai/supported-tools), and import these models into MXNet to run them for inference and training using MXNet’s highly optimized engine.
**Reviewer:** also for re-training and transfer learning use cases, right?

**Author:** yes

```eval_rst
.. warning:: This package contains experimental APIs and may change in the near future.
```

### Installation Instructions
- To use this module, developers need to **install ONNX**, which requires the protobuf compiler to be installed separately. Please follow the [instructions to install ONNX and its dependencies](https://github.com/onnx/onnx#installation). Once installed, you can go through the tutorials on how to use this module.
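A quick way to confirm the dependencies are importable (a sketch; assumes `onnx` and `mxnet` were installed per the instructions above):

```python
import importlib.util

# Check that the required packages can be found on the Python path
# without actually importing them.
for mod in ("onnx", "mxnet"):
    spec = importlib.util.find_spec(mod)
    print(mod, "ok" if spec is not None else "missing")
```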


This document describes all the ONNX-MXNet APIs.

```eval_rst
.. autosummary::
:nosignatures:

mxnet.contrib.onnx.import_model
```

## ONNX Tutorials

```eval_rst
.. toctree::
:maxdepth: 1

/tutorials/onnx/super_resolution.md
/tutorials/onnx/inference_on_onnx_model.md
```

## API Reference

<script type="text/javascript" src='../../_static/js/auto_module_index.js'></script>

```eval_rst

.. automodule:: mxnet.contrib.onnx
:members: import_model

```

<script>auto_index("api-reference");</script>
1 change: 1 addition & 0 deletions docs/api/python/index.md
@@ -151,4 +151,5 @@ imported by running:

contrib/contrib.md
contrib/text.md
contrib/onnx.md
```
2 changes: 2 additions & 0 deletions docs/tutorials/index.md
@@ -188,6 +188,8 @@ The Gluon and Module tutorials are in Python, but you can also find a variety of

- [Text classification (NLP) on Movie Reviews](http://mxnet.incubator.apache.org/tutorials/nlp/cnn.html)

- [Importing an ONNX model into MXNet](http://mxnet.incubator.apache.org/tutorials/onnx/super_resolution.html)

</div> <!--end of applications-->

</div> <!--end of module-->
1 change: 0 additions & 1 deletion docs/tutorials/onnx/inference_on_onnx_model.md
@@ -14,7 +14,6 @@ In this tutorial we will:
To run the tutorial you will need to have installed the following python modules:
- [MXNet](http://mxnet.incubator.apache.org/install/index.html)
- [onnx](https://github.com/onnx/onnx) (follow the install guide)
- [onnx-mxnet](https://github.com/onnx/onnx-mxnet)
- matplotlib
- wget

114 changes: 114 additions & 0 deletions docs/tutorials/onnx/super_resolution.md
@@ -0,0 +1,114 @@
# Importing an ONNX model into MXNet

In this tutorial we will:

- learn how to load a pre-trained ONNX model file into MXNet.
- run inference in MXNet.

## Prerequisites
This example assumes that the following python packages are installed:
- [mxnet](http://mxnet.incubator.apache.org/install/index.html)
- [onnx](https://github.com/onnx/onnx) (follow the install guide)
- Pillow - a Python image processing package, required for input pre-processing. It can be installed with ```pip install Pillow```.
- matplotlib


```python
from PIL import Image
import numpy as np
import mxnet as mx
import mxnet.contrib.onnx as onnx_mxnet
from mxnet.test_utils import download
from matplotlib.pyplot import imshow
```

### Fetching the required files


```python
img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg'
download(img_url, 'super_res_input.jpg')
model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx'
onnx_model_file = download(model_url, 'super_resolution.onnx')
```

## Loading the model into MXNet

To completely describe a pre-trained model in MXNet, we need two elements: a symbolic graph containing the model's network definition, and a binary file containing the model weights. You can import the ONNX model and get the symbol and parameter objects using the ``import_model`` API. The parameter object is split into argument parameters and auxiliary parameters.


```python
sym, arg, aux = onnx_mxnet.import_model(onnx_model_file)
```

We can now visualize the imported model (graphviz needs to be installed)


```python
mx.viz.plot_network(sym, node_attrs={"shape":"oval","fixedsize":"false"})
```




![svg](https://s3.amazonaws.com/onnx-mxnet/examples/super_res_mxnet_model.png)



## Input Pre-processing

We will transform the previously downloaded input image into an input tensor.


```python
img = Image.open('super_res_input.jpg').resize((224, 224))
img_ycbcr = img.convert("YCbCr")
img_y, img_cb, img_cr = img_ycbcr.split()
test_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]
```
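The model expects NCHW input - ``(batch, channel, height, width)``. A minimal sketch of the same reshaping with a synthetic luminance plane (no image file needed):

```python
import numpy as np

# A synthetic 224x224 single-channel "image" standing in for img_y.
img_y = np.zeros((224, 224), dtype=np.uint8)

# Indexing with np.newaxis twice prepends the batch and channel
# dimensions, producing the NCHW layout the model expects.
test_image = img_y[np.newaxis, np.newaxis, :, :]
print(test_image.shape)  # (1, 1, 224, 224)
```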

## Run Inference using MXNet's Module API

We will use MXNet's Module API to run the inference. For this we will need to create the module, bind it to the input data, and assign the loaded weights from the two parameter objects - argument parameters and auxiliary parameters.


```python
mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[('input_0', test_image.shape)], label_shapes=None)
mod.set_params(arg_params=arg, aux_params=aux, allow_missing=True, allow_extra=True)
```

The Module API's ``forward`` method requires a batch of data as input. We will prepare the data in that format and feed it to ``forward``.


```python
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])

# forward on the provided data batch
mod.forward(Batch([mx.nd.array(test_image)]))
```
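``Module.forward`` only needs an object with a ``data`` attribute holding a list of arrays, and the ``namedtuple`` above is the lightest way to provide one. A stdlib-only sketch of that shape (a real call would pass MXNet ``NDArray``s, not plain lists):

```python
from collections import namedtuple

# The same one-field batch container the tutorial uses.
Batch = namedtuple('Batch', ['data'])

# Illustrative payload only; MXNet expects NDArrays here.
batch = Batch(data=[[1, 2, 3]])
print(batch.data[0])  # [1, 2, 3]
```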

To get the output of the previous forward computation, use the ``module.get_outputs()`` method. It returns an ``ndarray`` that we convert to a ``numpy`` array and then to Pillow's image format:


```python
output = mod.get_outputs()[0][0][0]
img_out_y = Image.fromarray(np.uint8(output.asnumpy().clip(0, 255)), mode='L')
result_img = Image.merge(
"YCbCr", [
img_out_y,
img_cb.resize(img_out_y.size, Image.BICUBIC),
img_cr.resize(img_out_y.size, Image.BICUBIC)
]).convert("RGB")
result_img.save("super_res_output.jpg")
```
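The ``clip(0, 255)`` before the ``uint8`` cast matters: raw model outputs are floats that can fall outside the valid pixel range, and casting without clipping would produce garbage pixel values. A small sketch with synthetic values:

```python
import numpy as np

# Out-of-range float outputs, as a model might produce them.
output = np.array([-10.0, 128.7, 300.0])

# Clip into [0, 255] first, then truncate to uint8 pixel values.
pixels = np.uint8(output.clip(0, 255))
print(pixels.tolist())  # [0, 128, 255]
```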

Here's the input image and the resulting output image compared. As you can see, the model was able to increase the spatial resolution from ``224x224`` to ``672x672``.

| Input Image | Output Image |
| ----------- | ------------ |
| ![input](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_input.jpg?raw=true) | ![output](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/super_res_output.jpg?raw=true) |

<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
2 changes: 1 addition & 1 deletion python/mxnet/contrib/onnx/__init__.py
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.

"""Module for importing and exporting ONNX models."""
"""Module for ONNX model format support for Apache MXNet."""

from ._import.import_model import import_model
20 changes: 12 additions & 8 deletions python/mxnet/contrib/onnx/_import/import_model.py
@@ -22,7 +22,8 @@
from .import_onnx import GraphProto

def import_model(model_file):
"""Imports the ONNX model file passed as a parameter into MXNet symbol and parameters.
"""Imports the ONNX model file, passed as a parameter, into MXNet symbol and parameters.
Operator support and coverage - https://cwiki.apache.org/confluence/display/MXNET/ONNX

Parameters
----------
@@ -31,20 +32,23 @@ def import_model(model_file):

Returns
-------
Mxnet symbol and parameter objects.
sym : :class:`~mxnet.symbol.Symbol`
MXNet symbol object

sym : mxnet.symbol
Mxnet symbol
params : dict of str to mx.ndarray
Dict of converted parameters stored in mxnet.ndarray format
arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format

aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
"""
graph = GraphProto()

# loads model file and returns ONNX protobuf object
try:
import onnx
except ImportError:
raise ImportError("Onnx and protobuf need to be installed")
raise ImportError("Onnx and protobuf need to be installed. "
+ "Instructions to install - https://github.com/onnx/onnx")
# loads model file and returns ONNX protobuf object
model_proto = onnx.load(model_file)
sym, arg_params, aux_params = graph.from_onnx(model_proto.graph)
return sym, arg_params, aux_params
5 changes: 3 additions & 2 deletions python/mxnet/contrib/onnx/_import/import_onnx.py
@@ -147,8 +147,9 @@ def _parse_array(self, tensor_proto):
"""Grab data in TensorProto and convert to numpy array."""
try:
from onnx.numpy_helper import to_array
except ImportError as e:
raise ImportError("Unable to import onnx which is required {}".format(e))
except ImportError:
raise ImportError("Onnx and protobuf need to be installed. "
+ "Instructions to install - https://github.com/onnx/onnx")
np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
return nd.array(np_array)
