Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
10 changes: 6 additions & 4 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,14 +872,14 @@ def convert_softmax(node, **kwargs):
data = input_nodes[0]

# use op set 11 ONNX Softmax
if axis == -1 and temperature == 'None':
if axis == -1 and temperature == 1.:
nodes = []
if use_length == "True":
nodes += [
create_const_scalar_node(name+"_0_s", np.int64(0), kwargs),
create_const_scalar_node(name+"_1_s", np.int64(1), kwargs),
create_tensor([np.finfo(dtype).min], name+"_mask_val", kwargs["initializer"],
dtype=dtype),
# magic number: approximately the fp16 minimum (-65504)
create_tensor([-65500.0], name+"_mask_val", kwargs["initializer"], dtype=dtype),
create_tensor([], name+"_void", kwargs["initializer"]),
create_tensor([1], name+"_1", kwargs["initializer"]),
make_node("Shape", [data], [name+"_shape"]),
Expand Down Expand Up @@ -2380,14 +2380,16 @@ def convert_layer_norm(node, **kwargs):
axes = int(attrs.get('axis', -1))
eps = attrs.get('eps', 9.99999975e-06)

input_type = int(kwargs['in_type'])
dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type]

nodes = [
create_tensor([axes], name+"_axes", kwargs["initializer"]),
create_tensor([axes+1], name+"_axes+1", kwargs["initializer"]),
create_tensor([], name+"_void", kwargs["initializer"]),
create_const_scalar_node(name+'_0_s', np.int64(0), kwargs),
create_const_scalar_node(name+'_1_s', np.int64(1), kwargs),
create_const_scalar_node(name+"_2_s", np.int64(2), kwargs),
create_const_scalar_node(name+"_2_s", np.array(2, dtype=dtype), kwargs),
create_const_scalar_node(name+"_eps", np.float32(eps), kwargs),
make_node("ReduceMean", [input_nodes[0]], [name+"_rm0_out"], axes=[axes]),
make_node("Sub", [input_nodes[0], name+"_rm0_out"], [name+"_sub0_out"]),
Expand Down
6 changes: 6 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,15 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False,
# Determine output shape
graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape, output_label, in_type)

appeared_names = set()
graph_input_idx = 0
for idx, node in enumerate(mx_graph):
op = node["op"]
# check if the current node has the same name as nodes before
if node["name"] in appeared_names:
node["name"] = 'idx_' + str(idx) + '_' + node["name"]
else:
appeared_names.add(node["name"])
name = node["name"]
if verbose:
logging.info("Converting idx: %d, op: %s, name: %s", idx, op, name)
Expand Down
111 changes: 55 additions & 56 deletions tests/python-pytest/onnx/test_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,61 @@ def download_test_images(image_urls, tmpdir):
paths.append(filename)
return paths


@with_seed()
@pytest.mark.parametrize('model', ['bert_12_768_12'])
def test_bert_inference_onnxruntime(tmp_path, model):
    """Check that a BERT model exported to ONNX matches MXNet under onnxruntime.

    A non-pretrained gluonnlp BERT is built, run once on synthetic inputs,
    exported to symbol/params files, converted with
    ``mx.contrib.onnx.export_model`` and executed with onnxruntime. Both
    outputs (sequence encoding and pooled encoding) must agree with the MXNet
    results within ``rtol=0.01`` / ``atol=0.01``.

    Parameters
    ----------
    tmp_path : pathlib.Path
        Scratch directory supplied by pytest; removed when the test ends.
    model : str
        Name of the gluonnlp model to build.
    """
    tmp_path = str(tmp_path)
    try:
        # imported lazily, inside the try, so the scratch directory is still
        # cleaned up if gluonnlp is unavailable
        import gluonnlp as nlp
        ctx = mx.cpu(0)
        dataset = 'book_corpus_wiki_en_uncased'
        net, _vocab = nlp.model.get_model(
            name=model,
            ctx=ctx,
            dataset_name=dataset,
            pretrained=False,
            use_pooler=True,
            use_decoder=False,
            use_classifier=False)
        net.initialize(ctx=ctx)
        net.hybridize(static_alloc=True)

        batch, seq_length = 5, 16
        matrix_shape = (batch, seq_length)
        # synthetic test data: token ids, segment ids and per-sample lengths
        inputs = mx.nd.random.uniform(0, 30522, shape=matrix_shape, dtype='float32')
        token_types = mx.nd.random.uniform(0, 2, shape=matrix_shape, dtype='float32')
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        # reference results from MXNet
        seq_encoding, cls_encoding = net(inputs, token_types, valid_length)

        prefix = "%s/bert" % tmp_path
        net.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [matrix_shape, matrix_shape, (batch,)]
        converted_model_path = mx.contrib.onnx.export_model(
            sym_file, params_file, input_shapes, np.float32, onnx_file)

        # run the generated ONNX file through onnxruntime
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
        onnx_inputs = [inputs, token_types, valid_length]
        # session inputs are matched to the MXNet inputs by position
        input_dict = {
            session.get_inputs()[pos].name: arr.asnumpy()
            for pos, arr in enumerate(onnx_inputs)
        }
        pred_onx, cls_onx = session.run(None, input_dict)

        assert_almost_equal(seq_encoding, pred_onx, rtol=0.01, atol=0.01)
        assert_almost_equal(cls_encoding, cls_onx, rtol=0.01, atol=0.01)

    finally:
        shutil.rmtree(tmp_path)


@pytest.mark.parametrize('model', [
'alexnet',
'cifar_resnet20_v1',
Expand Down Expand Up @@ -399,59 +454,3 @@ def load_video(filepath):
finally:
shutil.rmtree(tmp_path)



@with_seed()
@pytest.mark.parametrize('model', ['bert_12_768_12'])
def test_bert_inference_onnxruntime(tmp_path, model):
    """Check that a BERT model exported to ONNX matches MXNet under onnxruntime.

    A non-pretrained gluonnlp BERT is built, run once on synthetic inputs,
    exported to symbol/params files, converted with
    ``mx.contrib.onnx.export_model`` and executed with onnxruntime. Both
    outputs are compared against the MXNet results using the default
    tolerances of ``assert_almost_equal``.

    Parameters
    ----------
    tmp_path : pathlib.Path
        Scratch directory supplied by pytest; removed when the test ends.
    model : str
        Name of the gluonnlp model to build.
    """
    tmp_path = str(tmp_path)
    try:
        # imported lazily, inside the try, so the scratch directory is still
        # cleaned up if gluonnlp is unavailable
        import gluonnlp as nlp
        ctx = mx.cpu(0)
        dataset = 'book_corpus_wiki_en_uncased'
        net, _vocab = nlp.model.get_model(
            name=model,
            ctx=ctx,
            dataset_name=dataset,
            pretrained=False,
            use_pooler=True,
            use_decoder=False,
            use_classifier=False)
        net.initialize(ctx=ctx)
        net.hybridize(static_alloc=True)

        batch, seq_length = 5, 16
        matrix_shape = (batch, seq_length)
        # synthetic test data: token ids, segment ids and per-sample lengths
        inputs = mx.nd.random.uniform(0, 30522, shape=matrix_shape, dtype='float32')
        token_types = mx.nd.random.uniform(0, 2, shape=matrix_shape, dtype='float32')
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        # reference results from MXNet
        seq_encoding, cls_encoding = net(inputs, token_types, valid_length)

        prefix = "%s/bert" % tmp_path
        net.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [matrix_shape, matrix_shape, (batch,)]
        converted_model_path = mx.contrib.onnx.export_model(
            sym_file, params_file, input_shapes, np.float32, onnx_file)

        # run the generated ONNX file through onnxruntime
        ses_opt = onnxruntime.SessionOptions()
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)
        onnx_inputs = [inputs, token_types, valid_length]
        # session inputs are matched to the MXNet inputs by position
        input_dict = {
            session.get_inputs()[pos].name: arr.asnumpy()
            for pos, arr in enumerate(onnx_inputs)
        }
        pred_onx, cls_onx = session.run(None, input_dict)

        assert_almost_equal(seq_encoding, pred_onx)
        assert_almost_equal(cls_encoding, cls_onx)

    finally:
        shutil.rmtree(tmp_path)


16 changes: 10 additions & 6 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,19 +381,21 @@ def test_onnx_export_cast(tmp_path, src_dtype, dst_dtype, shape):


@pytest.mark.parametrize('dtype', ['float16', 'float32'])
@pytest.mark.parametrize('temperature', [None, .1, 1., 10.])
def test_onnx_export_softmax(tmp_path, dtype, temperature):
    """Export softmax in four configurations and run each through op_export_test.

    Case 1 uses the operator defaults. Cases 2-4 set ``use_length=True`` with
    ``axis`` 0, -1 and 1 and the parametrized ``temperature``. Each length
    tensor has the shape of ``x`` with the softmax axis removed.

    Parameters
    ----------
    tmp_path : pathlib.Path
        Scratch directory supplied by pytest, forwarded to ``op_export_test``.
    dtype : str
        Data type of the softmax input.
    temperature : float or None
        Softmax temperature; ``None`` leaves the operator default in place.
    """
    # 'temperature' is parametrized exactly once: pytest rejects a test that
    # parametrizes the same argument twice. Each tensor is likewise created
    # once instead of being assigned and immediately overwritten.
    x = mx.nd.random.uniform(0, 1, (4, 5, 6), dtype=dtype)
    M1 = def_model('softmax')
    op_export_test('softmax_1', M1, [x], tmp_path)
    M2 = def_model('softmax', use_length=True, axis=0, temperature=temperature)
    l2 = mx.random.uniform(0, 4, (5, 6)).astype('int32')
    op_export_test('softmax_2', M2, [x, l2], tmp_path)
    M3 = def_model('softmax', use_length=True, axis=-1, temperature=temperature)
    # note that the axis==-1 case uses negative value masking + ONNX softmax;
    # when valid_len==0 the masked values will NOT be 0, so the lengths drawn
    # here start at 1
    l3 = mx.random.uniform(1, 6, (4, 5)).astype('int32')
    op_export_test('softmax_3', M3, [x, l3], tmp_path)
    M4 = def_model('softmax', use_length=True, axis=1, temperature=temperature)
    l4 = mx.random.uniform(0, 5, (4, 6)).astype('int32')
    op_export_test('softmax_4', M4, [x, l4], tmp_path)


Expand Down Expand Up @@ -580,7 +582,9 @@ def hybrid_forward(self, F, x):
op_export_test('link_op_with_multiple_outputs_case3', Model3, [A], tmp_path)


@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
# opset 8 MAX only supports float types
# opset 12 and up supports float and int
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
@pytest.mark.parametrize('shape', [(3, 4, 5), (1, 4, 1, 7)])
def test_onnx_maximum_scalar(tmp_path, dtype, shape):
x = mx.random.uniform(0, 10, shape).astype(dtype)
Expand Down