Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit b219ac2

Browse files
authored
[v1.x] Add more ONNX export support to operators (#19625)
1 parent c37d5aa commit b219ac2

File tree

3 files changed

+321
-0
lines changed

3 files changed

+321
-0
lines changed

ci/docker/runtime_functions.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,6 +1281,7 @@ integrationtest_ubuntu_cpu_onnx() {
12811281
pytest tests/python-pytest/onnx/mxnet_export_test.py
12821282
pytest tests/python-pytest/onnx/test_models.py
12831283
pytest tests/python-pytest/onnx/test_node.py
1284+
pytest tests/python-pytest/onnx/test_operators.py
12841285
pytest tests/python-pytest/onnx/test_onnxruntime.py
12851286
}
12861287

python/mxnet/contrib/onnx/mx2onnx/_op_translations.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,29 @@ def create_basic_op_node(op_name, node, kwargs):
154154
)
155155
return [node]
156156

157+
def create_const_scalar_node(input_name, value, kwargs):
    """Register a rank-0 (scalar) constant and return its value-info node.

    A tensor named ``input_name`` holding ``value`` is appended to
    ``kwargs["initializer"]``; the matching tensor value info (same name,
    same element type, empty shape) is returned to the caller.

    Parameters
    ----------
    input_name : str
        Name shared by the initializer tensor and the value-info node.
    value : object exposing ``dtype``
        The constant; its ``dtype`` selects the ONNX element type.
    kwargs : dict
        Converter keyword arguments; must contain the ``"initializer"`` list.
    """
    from onnx import helper as onnx_helper

    initializer_list = kwargs["initializer"]
    elem_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype]
    scalar_dims = ()  # empty dims == scalar tensor

    info = onnx_helper.make_tensor_value_info(input_name, elem_type, scalar_dims)
    initializer_list.append(
        onnx_helper.make_tensor(input_name, elem_type, scalar_dims, (value,))
    )
    return info
167+
168+
def create_const_node(input_name, value, kwargs):
    """Register a constant tensor of arbitrary rank and return its value info.

    A tensor named ``input_name`` with the shape and element type of
    ``value`` is appended to ``kwargs["initializer"]``; the matching tensor
    value info is returned to the caller.

    Parameters
    ----------
    input_name : str
        Name shared by the initializer tensor and the value-info node.
    value : numpy.ndarray
        The constant data; ``value.dtype`` and ``value.shape`` define the
        ONNX element type and dims.
    kwargs : dict
        Converter keyword arguments; must contain the ``"initializer"`` list.
    """
    from onnx.helper import make_tensor, make_tensor_value_info
    initializer = kwargs["initializer"]
    input_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[value.dtype]
    input_shape = value.shape
    value_node = make_tensor_value_info(input_name, input_type, input_shape)
    # make_tensor expects a flat sequence of element values; handing it a
    # multi-dimensional ndarray directly would feed whole rows as "elements".
    flat_values = value.flatten().tolist()
    tensor_node = make_tensor(input_name, input_type, input_shape, flat_values)
    initializer.append(tensor_node)
    return value_node
179+
157180
@mx_op.register("null")
158181
def convert_weights_and_inputs(node, **kwargs):
159182
"""Helper function to convert weights and inputs.
@@ -802,6 +825,7 @@ def convert_leakyrelu(node, **kwargs):
802825
"""Map MXNet's LeakyReLU operator attributes to onnx's Elu/LeakyRelu/PRelu operators
803826
based on the input node's attributes and return the created node.
804827
"""
828+
from onnx.helper import make_node
805829
name, input_nodes, attrs = get_inputs(node, kwargs)
806830

807831
act_type = attrs.get("act_type", "leaky")
@@ -816,6 +840,19 @@ def convert_leakyrelu(node, **kwargs):
816840
inputs=input_nodes,
817841
outputs=[name],
818842
name=name)
843+
elif act_type in ('gelu'):
844+
sqrt2 = np.float32(1.4142135623730951)
845+
nodes = [
846+
create_const_scalar_node(name+"_sqrt2", sqrt2, kwargs),
847+
make_node("Div", [input_nodes[0], name+"_sqrt2"], [name+"_div0_out"]),
848+
make_node("Erf", [name+"_div0_out"], [name+"_erf0_out"]),
849+
create_const_scalar_node(name+"_one", np.float32(1.0), kwargs),
850+
create_const_scalar_node(name+"_half", np.float32(0.5), kwargs),
851+
make_node("Add", [name+"_erf0_out", name+"_one"], [name+"_add0_out"]),
852+
make_node("Mul", [input_nodes[0], name+"_add0_out"], [name+"_mul0_out"]),
853+
make_node("Mul", [name+"_mul0_out", name+"_half"], [name])
854+
]
855+
return nodes
819856
else:
820857
node = onnx.helper.make_node(
821858
act_name[act_type],
@@ -2214,3 +2251,152 @@ def convert_take(node, **kwargs):
22142251
name=name,
22152252
)
22162253
return [node]
2254+
2255+
2256+
@mx_op.register("LayerNorm")
def convert_layer_norm(node, **kwargs):
    """Map MXNet's LayerNorm operator to a sequence of primitive ONNX nodes.

    The emitted graph computes
    ``(x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps) * gamma + beta``
    where both means are taken along the operator's ``axis`` attribute
    (default ``-1``) with dimensions kept.

    Returns the list of created nodes; the final ``Add`` produces the output
    named after the MXNet node.
    """
    from onnx.helper import make_node
    name, input_nodes, attrs = get_inputs(node, kwargs)

    # Normalize along the operator's own axis rather than every axis of the
    # first graph input; fall back to MXNet's documented defaults when the
    # attributes are not serialized in the symbol.
    axes = [int(attrs.get('axis', -1))]
    eps = float(attrs.get('eps', 1e-5))

    # NOTE(review): gamma/beta are applied through plain ONNX broadcasting,
    # which lines them up with the trailing dimension. That matches MXNet
    # when `axis` is the last axis; other axes presumably need an explicit
    # reshape of gamma/beta -- confirm before relying on it.
    nodes = [
        make_node("ReduceMean", [input_nodes[0]], [name+"_rm0_out"], axes=axes),
        make_node("Sub", [input_nodes[0], name+"_rm0_out"], [name+"_sub0_out"]),
        create_const_scalar_node(name+"_two", np.float32(2.), kwargs),
        make_node("Pow", [name+"_sub0_out", name+"_two"], [name+"_pow0_out"]),
        make_node("ReduceMean", [name+"_pow0_out"], [name+"_rm1_out"], axes=axes),
        create_const_scalar_node(name+"_eps", np.float32(eps), kwargs),
        make_node("Add", [name+"_rm1_out", name+"_eps"], [name+"_add0_out"]),
        make_node("Sqrt", [name+"_add0_out"], [name+"_sqrt0_out"]),
        make_node("Div", [name+"_sub0_out", name+"_sqrt0_out"], [name+"_div0_out"]),
        make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]),
        make_node("Add", [name+"_mul0_out", input_nodes[2]], [name], name=name)
    ]

    return nodes
2281+
2282+
2283+
@mx_op.register("Embedding")
def convert_embedding(node, **kwargs):
    """Map MXNet's Embedding operator to ONNX Cast + Gather.

    MXNet's Embedding takes ``(data, weight)`` where ``data`` holds the
    indices, whereas ONNX Gather takes ``(data, indices)`` where ``data`` is
    the table being indexed. The inputs are therefore swapped, and the
    indices are cast to int64 because Gather only accepts integer indices.

    Returns the list of created nodes; the ``Gather`` node produces the
    output named after the MXNet node.
    """
    from onnx import TensorProto
    from onnx.helper import make_node
    name, input_nodes, attrs = get_inputs(node, kwargs)
    axis = int(attrs.get('axis', 0))

    indices_name = input_nodes[0]
    weight_name = input_nodes[1]
    casted_indices = name + "_indices_int64"

    nodes = [
        make_node("Cast", [indices_name], [casted_indices], to=int(TensorProto.INT64)),
        make_node("Gather", [weight_name, casted_indices], [name], axis=axis, name=name)
    ]
    return nodes
2297+
2298+
2299+
@mx_op.register("stack")
def convert_stack(node, **kwargs):
    """Map MXNet's stack operator to ONNX Unsqueeze + Concat.

    Every input gets one ``Unsqueeze`` node inserting a new dimension at
    ``axis`` (default 0); a final ``Concat`` along the same axis joins the
    unsqueezed tensors into the output named after the MXNet node.
    """
    name, input_nodes, attrs = get_inputs(node, kwargs)
    axis = int(attrs.get('axis', 0))

    # One intermediate tensor name per input, numbered in input order.
    unsqueezed_names = [name+"_unsqueeze"+str(pos) for pos in range(len(input_nodes))]

    nodes = [
        onnx.helper.make_node(
            "Unsqueeze",
            inputs=[source],
            outputs=[target],
            axes=[axis]
        )
        for source, target in zip(input_nodes, unsqueezed_names)
    ]

    concat_node = onnx.helper.make_node(
        "Concat",
        inputs=unsqueezed_names,
        outputs=[name],
        name=name,
        axis=axis
    )
    nodes.append(concat_node)
    return nodes
2324+
2325+
2326+
@mx_op.register("slice")
def convert_slice(node, **kwargs):
    """Map MXNet's slice operator to the ONNX Slice operator.

    ``begin``, ``end`` and the optional ``step`` attributes are parsed from
    their string form and emitted as int64 initializers. ``None`` entries
    are replaced by explicit values: a missing step becomes 1, and missing
    bounds become "from the start / to the end" for the direction of the
    step. The ``axes`` input is always emitted so that ``steps`` lands in
    its correct (fifth) input slot.

    Raises
    ------
    ValueError
        If ``begin`` and ``end`` differ in length, or ``step`` has more
        entries than ``begin``.
    """
    name, input_nodes, attrs = get_inputs(node, kwargs)
    starts = convert_string_to_list(attrs.get("begin"))
    ends = convert_string_to_list(attrs.get("end"))
    # Attributes arrive as strings, so step must be parsed like begin/end.
    steps = convert_string_to_list(attrs.get("step", "()"))

    if len(starts) != len(ends):
        raise ValueError("slice operator requires begin and end of equal length")
    if len(steps) > len(starts):
        raise ValueError("slice operator received more step entries than begin entries")

    # Missing or None steps mean a stride of 1.
    steps = [1 if step is None else step for step in steps]
    steps.extend([1] * (len(starts) - len(steps)))

    # ONNX recommends INT64 max/min as "to the end" markers when slicing
    # forward/backward; Slice clamps them to the real dimension size.
    int64_max = 2**63 - 1
    int64_min = -2**63
    for i, step in enumerate(steps):
        forward = step > 0
        if starts[i] is None:
            starts[i] = 0 if forward else int64_max
        if ends[i] is None:
            ends[i] = int64_max if forward else int64_min

    # begin/end describe the leading axes of the input, in order.
    axes = list(range(len(starts)))

    nodes = [
        create_const_node(name+"_begin", np.array(starts, dtype='int64'), kwargs),
        create_const_node(name+"_end", np.array(ends, dtype='int64'), kwargs),
        create_const_node(name+"_axes", np.array(axes, dtype='int64'), kwargs),
        create_const_node(name+"_steps", np.array(steps, dtype='int64'), kwargs)
    ]
    inputs = [input_nodes[0], name+"_begin", name+"_end", name+"_axes", name+"_steps"]
    nodes.append(onnx.helper.make_node("Slice", inputs, [name], name=name))
    return nodes
2343+
2344+
2345+
@mx_op.register("zeros_like")
def convert_zeros_like(node, **kwargs):
    """Map MXNet's zeros_like operator to ONNX Shape + ConstantOfShape.

    The shape is read from the operator's own input at run time with a
    ``Shape`` node, instead of being baked in from the first graph input,
    so the conversion stays correct when the operator is fed an
    intermediate tensor. The fill value is a single zero of the element
    type given by ``kwargs['in_type']``.

    Returns the list of created nodes; ``ConstantOfShape`` produces the
    output named after the MXNet node.
    """
    from onnx.helper import make_node, make_tensor
    name, input_nodes, _ = get_inputs(node, kwargs)

    # one-element tensor holding the fill value
    tensor_value = make_tensor(name+"_zero", kwargs['in_type'], [1], [0])
    nodes = [
        make_node("Shape", [input_nodes[0]], [name+"_shape"]),
        make_node("ConstantOfShape", [name+"_shape"], [name], name=name, value=tensor_value)
    ]
    return nodes
2359+
2360+
2361+
@mx_op.register("_contrib_arange_like")
def convert_arange_like(node, **kwargs):
    """Map MXNet's arange_like operator to ONNX Range + Reshape.

    ``start``, ``limit`` and ``step`` are emitted as true scalar
    initializers in the element type given by ``kwargs['in_type']``. The
    range is then reshaped to the full shape of ``kwargs['in_shape'][0]``,
    or to the length of the selected axis when ``axis`` is set.

    Raises
    ------
    AttributeError
        If the target opset is older than 11.
    NotImplementedError
        If ``repeat`` is not 1.
    """
    from onnx.helper import make_node
    name, _, attrs = get_inputs(node, kwargs)

    opset_version = kwargs['opset_version']
    if opset_version < 11:
        raise AttributeError("ONNX opset 11 or greater is required to export this operator")

    input_type = kwargs['in_type']
    dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type]
    scalar_type = np.dtype(dtype).type
    in_shape = kwargs['in_shape']
    axis = attrs.get('axis')

    # An unset axis may be absent or serialized as the string "None".
    if axis is None or str(axis) == 'None':
        # output will be same shape as input
        output_shape = in_shape[0]
    else:
        # determine shape of axis
        output_shape = [in_shape[0][int(axis)]]

    repeat = int(attrs.get('repeat', 1))
    if repeat != 1:
        raise NotImplementedError("arange_like operator with repeat != 1 not yet implemented.")

    # Attribute values may arrive as strings; go through float() and build
    # genuine 0-d scalars, which is what the scalar initializer helper
    # expects (rather than 1-element or nested arrays).
    start = scalar_type(float(attrs.get('start', 0.)))
    step = scalar_type(float(attrs.get('step', 1.)))

    tot_elements = int(np.prod(output_shape))
    limit = scalar_type(start + (tot_elements * step))

    # create constant inputs
    nodes = [
        create_const_scalar_node(name+"_start", start, kwargs),
        create_const_scalar_node(name+"_limit", limit, kwargs),
        create_const_scalar_node(name+"_step", step, kwargs),
        create_const_node(name+"_shape", np.array(output_shape, dtype='int64'), kwargs),
        make_node("Range", [name+"_start", name+"_limit", name+"_step"], [name+"_range0_out"]),
        make_node("Reshape", [name+"_range0_out", name+"_shape"], [name])
    ]
    return nodes
tests/python-pytest/onnx/test_operators.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import mxnet as mx
19+
from mxnet.gluon import HybridBlock, nn
20+
import numpy as np
21+
import onnxruntime as rt
22+
from mxnet.test_utils import assert_almost_equal
23+
import pytest
24+
import tempfile
25+
26+
def op_export_test(op_name, Model, inputs, tmp_path):
    """Round-trip a single-operator model through ONNX and compare outputs.

    ``Model`` is instantiated, initialized on CPU, hybridized and run on
    ``inputs`` to obtain the native prediction. The model is then exported
    under ``tmp_path``, converted to ONNX, executed with onnxruntime, and
    the first ONNX output is compared against the native prediction.
    """
    # Native MXNet prediction.
    net = Model()
    net.initialize(ctx=mx.cpu(0))
    net.hybridize()
    pred_nat = net(*inputs)

    # Export symbol/params, then convert them to an ONNX file.
    prefix = '{}/{}'.format(tmp_path, op_name)
    net.export(prefix, epoch=0)
    sym_file = '{}-symbol.json'.format(prefix)
    params_file = '{}-0000.params'.format(prefix)
    onnx_file = '{}/{}.onnx'.format(tmp_path, op_name)
    input_shapes = [i.shape for i in inputs]
    mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
                                 inputs[0].dtype, onnx_file)

    # Run the exported model; inputs are matched to session inputs by position.
    sess = rt.InferenceSession(onnx_file)
    session_inputs = sess.get_inputs()
    feed = {session_inputs[pos].name: arr.asnumpy() for pos, arr in enumerate(inputs)}
    pred_onx = sess.run(None, feed)[0]

    assert_almost_equal(pred_nat, pred_onx)
51+
52+
53+
def test_onnx_export_abs():
    """Export a model applying ``abs`` and check it against onnxruntime."""
    class Model(HybridBlock):
        def hybrid_forward(self, F, x):
            return F.abs(x)

    with tempfile.TemporaryDirectory() as tmp_path:
        data = mx.nd.array([[-2, -1], [0, 99]], dtype='float32')
        op_export_test('abs', Model, [data], tmp_path)
63+
64+
def test_onnx_export_slice():
    """Export a model applying ``slice`` and check it against onnxruntime."""
    class Model(HybridBlock):
        def hybrid_forward(self, F, x):
            return F.slice(x, begin=(0,1), end=(2,4))

    with tempfile.TemporaryDirectory() as tmp_path:
        data = mx.nd.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]], dtype='float32')
        op_export_test('slice', Model, [data], tmp_path)
74+
75+
def test_onnx_export_stack():
    """Export a model stacking two vectors and check it against onnxruntime."""
    class Model(HybridBlock):
        def hybrid_forward(self, F, x, y):
            return F.stack(x, y)

    with tempfile.TemporaryDirectory() as tmp_path:
        elem_type = 'float32'
        first = mx.nd.array([1, 2], dtype=elem_type)
        second = mx.nd.array([3, 4], dtype=elem_type)
        op_export_test('stack', Model, [first, second], tmp_path)
87+
88+
def test_onnx_export_zeros_like():
    """Export a model applying ``zeros_like`` and check it against onnxruntime."""
    class Model(HybridBlock):
        def hybrid_forward(self, F, x):
            return F.zeros_like(x)

    with tempfile.TemporaryDirectory() as tmp_path:
        data = mx.nd.array([[-2,-1,0],[0,50,99],[4,5,6],[7,8,9]], dtype='float32')
        op_export_test('zeros_like', Model, [data], tmp_path)
98+
99+
@pytest.mark.parametrize("dtype", ["float32", "double"])
def test_onnx_export_arange_like(dtype):
    """Export a model applying ``contrib.arange_like`` for the given dtype
    and check it against onnxruntime."""
    class Model(HybridBlock):
        def hybrid_forward(self, F, x):
            return F.contrib.arange_like(x)

    with tempfile.TemporaryDirectory() as tmp_path:
        data = mx.nd.array([[-2,-1,0],[0,50,99],[4,5,6],[7,8,9]], dtype=dtype)
        op_export_test('arange_like', Model, [data], tmp_path)
110+
111+
def test_onnx_export_layernorm():
    """Export a model applying ``LayerNorm`` along axis 1 and check it
    against onnxruntime."""
    class Model(HybridBlock):
        def hybrid_forward(self, F, x, gamma, beta):
            return F.LayerNorm(x, gamma, beta, axis=1)

    with tempfile.TemporaryDirectory() as tmp_path:
        elem_type = 'float32'
        data = mx.nd.array([[1,3],[2,4]], dtype=elem_type)
        # Scale and shift are random vectors shaped like one row of the data.
        row_shape = data[0].shape
        scale = mx.random.uniform(0, 1, row_shape).astype(elem_type)
        shift = mx.random.uniform(0, 1, row_shape).astype(elem_type)
        op_export_test('LayerNorm', Model, [data, scale, shift], tmp_path)
124+
125+
126+
if __name__ == '__main__':
    # Run every export test in sequence when executed as a script;
    # the parametrized arange_like test is invoked once per dtype.
    test_onnx_export_abs()
    test_onnx_export_slice()
    test_onnx_export_stack()
    test_onnx_export_zeros_like()
    for arange_dtype in ('float32', 'double'):
        test_onnx_export_arange_like(arange_dtype)
    test_onnx_export_layernorm()
134+

0 commit comments

Comments
 (0)