From afe2853a35b52315a1eef547537b6ccd94329a05 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 29 Jan 2021 00:56:40 +0000 Subject: [PATCH 1/9] fix Softmax --- .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 6 +++--- tests/python-pytest/onnx/test_operators.py | 12 +++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 0ca1ef50ca2f..fbc21a3d8c9a 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -872,14 +872,14 @@ def convert_softmax(node, **kwargs): data = input_nodes[0] # use op set 11 ONNX Softmax - if axis == -1 and temperature == 'None': + if axis == -1 and temperature == 1.: nodes = [] if use_length == "True": nodes += [ create_const_scalar_node(name+"_0_s", np.int64(0), kwargs), create_const_scalar_node(name+"_1_s", np.int64(1), kwargs), - create_tensor([np.finfo(dtype).min], name+"_mask_val", kwargs["initializer"], - dtype=dtype), + # magic number, this is fp16 min + create_tensor([-65500.0], name+"_mask_val", kwargs["initializer"], dtype=dtype), create_tensor([], name+"_void", kwargs["initializer"]), create_tensor([1], name+"_1", kwargs["initializer"]), make_node("Shape", [data], [name+"_shape"]), diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 71629728d14c..ef649c7fa482 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -342,19 +342,21 @@ def test_onnx_export_cast(tmp_path, src_dtype, dst_dtype, shape): @pytest.mark.parametrize('dtype', ['float16', 'float32']) -@pytest.mark.parametrize('temperature', [.1, 1., 10.]) +@pytest.mark.parametrize('temperature', [None, .1, 1., 10.]) def test_onnx_export_softmax(tmp_path, dtype, temperature): - x = mx.nd.random.uniform(0, 1, (2, 3, 4), dtype=dtype) + x = mx.nd.random.uniform(0, 1, (4, 5, 6), dtype=dtype) M1 = def_model('softmax') op_export_test('softmax_1', M1, [x], tmp_path) M2 = def_model('softmax', use_length=True, axis=0, temperature=temperature) - l2 = mx.nd.array([[2,0,2,1],[1,1,2,1], [0,0,0,1]], dtype=int) + l2 = mx.random.uniform(0, 4, (5, 6)).astype('int32') op_export_test('softmax_2', M2, [x, l2], tmp_path) M3 = def_model('softmax', use_length=True, axis=-1, temperature=temperature) - l3 = mx.nd.array([[2,0,4],[0,0,0]], dtype=int) + # note that the axis==-1 case uses negative value masking + ONNX softmax + # when valid_len==0 the masked values will NOT be 0 + l3 = mx.random.uniform(1, 6, (4, 5)).astype('int32') op_export_test('softmax_3', M3, [x, l3], tmp_path) M4 = def_model('softmax', use_length=True, axis=1, temperature=temperature) - l4 = mx.nd.array([[2,0,3,1],[0,1,0,0]], dtype=int) + l4 = mx.random.uniform(0, 5, (4, 6)).astype('int32') op_export_test('softmax_4', M4, [x, l4], tmp_path) From a0dd82f655a023bfddd20b1b652f36f245f3dfdd Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 29 Jan 2021 01:10:49 +0000 Subject: [PATCH 2/9] fix compatibility issues with onnx 1.6 --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 6 ++++-- tests/python-pytest/onnx/test_operators.py | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index fbc21a3d8c9a..8dbd922f0b9c 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2333,6 +2333,8 @@ def convert_layer_norm(node, **kwargs): axes = int(attrs.get('axis', -1)) eps = attrs.get('eps', 9.99999975e-06) + input_type = int(kwargs['in_type']) + dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type] nodes = [ create_tensor([axes], name+"_axes", kwargs["initializer"]), @@ -2340,7 +2342,7 @@ def convert_layer_norm(node, **kwargs): create_tensor([], name+"_void", kwargs["initializer"]), create_const_scalar_node(name+'_0_s', np.int64(0), kwargs), create_const_scalar_node(name+'_1_s', np.int64(1), kwargs), - create_const_scalar_node(name+"_2_s", np.int64(2), kwargs), + create_const_scalar_node(name+"_2_s", np.array(2, dtype=dtype), kwargs), create_const_scalar_node(name+"_eps", np.float32(eps), kwargs), make_node("ReduceMean", [input_nodes[0]], [name+"_rm0_out"], axes=[axes]), make_node("Sub", [input_nodes[0], name+"_rm0_out"], [name+"_sub0_out"]), @@ -3025,7 +3027,7 @@ def convert_maximum_scalar(node, **kwargs): input_type = int(kwargs['in_type']) dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type] - + scalar = None if 'float' in str(dtype): scalar = float(attrs.get('scalar', '0')) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index ef649c7fa482..a9bfb4241902 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -464,7 +464,9 @@ def hybrid_forward(self, F, x): op_export_test('link_op_with_multiple_outputs_case3', Model3, [A], tmp_path) -@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) +# opset 8 MAX only supports float types +# opset 12 and up suppots float and int +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64']) @pytest.mark.parametrize('shape', [(3, 4, 5), (1, 4, 1, 7)]) def test_onnx_maximum_scalar(tmp_path, dtype, shape): x = mx.random.uniform(0, 10, shape).astype(dtype) From 8c9cf9c961cc4b2720449cbc7d59268c121d00f1 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 2 Feb 2021 20:05:59 +0000 Subject: [PATCH 3/9] fix for when multiple op nodes have the same name --- python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index ec5ea2d4a273..e6e7d3ca2ceb 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -253,6 +253,10 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, graph_input_idx += 1 else: + # If this node is not weight, then we add a prefix to the name to avoid name + # clashing issue in case some op nodes have the same name + if name not in params: + node["name"] = 'op_node_' + str(idx) + '_' + node["name"] # Handling graph layers converted = MXNetGraph.convert_layer( node, From df37bfe4c855ceb18b8e7644d5fe05fe6d38ffaa Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Tue, 2 Feb 2021 13:48:45 -0800 Subject: [PATCH 4/9] Update _op_translations.py --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 8dbd922f0b9c..84b57819baec 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3027,7 +3027,7 @@ def convert_maximum_scalar(node, **kwargs): input_type = int(kwargs['in_type']) dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type] - + scalar = None if 'float' in str(dtype): scalar = float(attrs.get('scalar', '0')) From d5024a80cbfd0a5e638199549afe83d0e8dfa87f Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Tue, 2 Feb 2021 20:04:17 -0800 Subject: [PATCH 5/9] Skip legacy tests --- ci/docker/runtime_functions.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index 0bbc6d79b9b8..225b0396ce39 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -1280,9 +1280,9 @@ integrationtest_ubuntu_cpu_onnx() { export DMLC_LOG_STACK_TRACE_DEPTH=10 #tests/python-pytest/onnx/backend_test.py COV_ARG="--cov=./ --cov-report=xml --cov-append" - pytest $COV_ARG --verbose tests/python-pytest/onnx/mxnet_export_test.py + #pytest $COV_ARG --verbose tests/python-pytest/onnx/mxnet_export_test.py pytest $COV_ARG --verbose tests/python-pytest/onnx/test_models.py - pytest $COV_ARG --verbose tests/python-pytest/onnx/test_node.py + #pytest $COV_ARG --verbose tests/python-pytest/onnx/test_node.py pytest $COV_ARG --verbose tests/python-pytest/onnx/test_operators.py pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py } From 5918f92c3d7b3096d955966d39faf3153e6b9846 Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Wed, 3 Feb 2021 10:49:05 -0800 Subject: [PATCH 6/9] Update runtime_functions.sh --- ci/docker/runtime_functions.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index 225b0396ce39..b58eb629b824 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -1281,7 +1281,7 @@ integrationtest_ubuntu_cpu_onnx() { #tests/python-pytest/onnx/backend_test.py COV_ARG="--cov=./ --cov-report=xml --cov-append" #pytest $COV_ARG --verbose tests/python-pytest/onnx/mxnet_export_test.py - pytest $COV_ARG --verbose tests/python-pytest/onnx/test_models.py + #pytest $COV_ARG --verbose tests/python-pytest/onnx/test_models.py #pytest $COV_ARG --verbose tests/python-pytest/onnx/test_node.py pytest $COV_ARG --verbose tests/python-pytest/onnx/test_operators.py pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py From b05e0895cd66ef419cfa14865a486776a8233ba5 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Wed, 3 Feb 2021 22:34:30 +0000 Subject: [PATCH 7/9] fix --- python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index e6e7d3ca2ceb..af6af8b738a7 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -224,9 +224,15 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, # Determine output shape graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape, output_label, in_type) + appeared_names = set() graph_input_idx = 0 for idx, node in enumerate(mx_graph): op = node["op"] + # check if the current node has the same name as nodes before + if node["name"] in appeared_names: + node["name"] = 'idx_' + str(idx) + '_' + node["name"] + else: + appeared_names.add(node["name"]) name = node["name"] if verbose: logging.info("Converting idx: %d, op: %s, name: %s", idx, op, name) @@ -253,10 +259,6 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, graph_input_idx += 1 else: - # If this node is not weight, then we add a prefix to the name to avoid name - # clashing issue in case some op nodes have the same name - if name not in params: - node["name"] = 'op_node_' + str(idx) + '_' + node["name"] # Handling graph layers converted = MXNetGraph.convert_layer( node, From 1b1b5bb8d422feb1796d1cec3ec5aff8a36041e7 Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Wed, 3 Feb 2021 15:13:17 -0800 Subject: [PATCH 8/9] Update runtime_functions.sh --- ci/docker/runtime_functions.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index b58eb629b824..0bbc6d79b9b8 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -1280,9 +1280,9 @@ integrationtest_ubuntu_cpu_onnx() { export DMLC_LOG_STACK_TRACE_DEPTH=10 #tests/python-pytest/onnx/backend_test.py COV_ARG="--cov=./ --cov-report=xml --cov-append" - #pytest $COV_ARG --verbose tests/python-pytest/onnx/mxnet_export_test.py - #pytest $COV_ARG --verbose tests/python-pytest/onnx/test_models.py - #pytest $COV_ARG --verbose tests/python-pytest/onnx/test_node.py + pytest $COV_ARG --verbose tests/python-pytest/onnx/mxnet_export_test.py + pytest $COV_ARG --verbose tests/python-pytest/onnx/test_models.py + pytest $COV_ARG --verbose tests/python-pytest/onnx/test_node.py pytest $COV_ARG --verbose tests/python-pytest/onnx/test_operators.py pytest $COV_ARG --verbose tests/python-pytest/onnx/test_onnxruntime.py } From 0f8c5609c11a5a3ad24b45e72dbe9dbfdc34f7c3 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 4 Feb 2021 22:48:06 +0000 Subject: [PATCH 9/9] move bert test up and make the tolerance larger --- tests/python-pytest/onnx/test_onnxruntime.py | 111 +++++++++---------- 1 file changed, 55 insertions(+), 56 deletions(-) diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index ad00bddf240f..60e84d047439 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -68,6 +68,61 @@ def download_test_images(image_urls, tmpdir): paths.append(filename) return paths + +@with_seed() +@pytest.mark.parametrize('model', ['bert_12_768_12']) +def test_bert_inference_onnxruntime(tmp_path, model): + tmp_path = str(tmp_path) + try: + import gluonnlp as nlp + dataset = 'book_corpus_wiki_en_uncased' + ctx = mx.cpu(0) + model, vocab = nlp.model.get_model( + name=model, + ctx=ctx, + dataset_name=dataset, + pretrained=False, + use_pooler=True, + use_decoder=False, + use_classifier=False) + model.initialize(ctx=ctx) + model.hybridize(static_alloc=True) + + batch = 5 + seq_length = 16 + # create synthetic test data + inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32') + token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32') + valid_length = mx.nd.array([seq_length] * batch, dtype='float32') + + seq_encoding, cls_encoding = model(inputs, token_types, valid_length) + + prefix = "%s/bert" % tmp_path + model.export(prefix) + sym_file = "%s-symbol.json" % prefix + params_file = "%s-0000.params" % prefix + onnx_file = "%s.onnx" % prefix + + + input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)] + converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, np.float32, onnx_file) + + + # create onnxruntime session using the generated onnx file + ses_opt = onnxruntime.SessionOptions() + ses_opt.log_severity_level = 3 + session = onnxruntime.InferenceSession(onnx_file, ses_opt) + onnx_inputs = [inputs, token_types, valid_length] + input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs))) + pred_onx, cls_onx = session.run(None, input_dict) + + assert_almost_equal(seq_encoding, pred_onx, rtol=0.01, atol=0.01) + assert_almost_equal(cls_encoding, cls_onx, rtol=0.01, atol=0.01) + + finally: + shutil.rmtree(tmp_path) + + @pytest.mark.parametrize('model', [ 'alexnet', 'cifar_resnet20_v1', @@ -399,59 +454,3 @@ def load_video(filepath): finally: shutil.rmtree(tmp_path) - - -@with_seed() -@pytest.mark.parametrize('model', ['bert_12_768_12']) -def test_bert_inference_onnxruntime(tmp_path, model): - tmp_path = str(tmp_path) - try: - import gluonnlp as nlp - dataset = 'book_corpus_wiki_en_uncased' - ctx = mx.cpu(0) - model, vocab = nlp.model.get_model( - name=model, - ctx=ctx, - dataset_name=dataset, - pretrained=False, - use_pooler=True, - use_decoder=False, - use_classifier=False) - model.initialize(ctx=ctx) - model.hybridize(static_alloc=True) - - batch = 5 - seq_length = 16 - # create synthetic test data - inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32') - token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32') - valid_length = mx.nd.array([seq_length] * batch, dtype='float32') - - seq_encoding, cls_encoding = model(inputs, token_types, valid_length) - - prefix = "%s/bert" % tmp_path - model.export(prefix) - sym_file = "%s-symbol.json" % prefix - params_file = "%s-0000.params" % prefix - onnx_file = "%s.onnx" % prefix - - - input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)] - converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, np.float32, onnx_file) - - - # create onnxruntime session using the generated onnx file - ses_opt = onnxruntime.SessionOptions() - ses_opt.log_severity_level = 3 - session = onnxruntime.InferenceSession(onnx_file, ses_opt) - onnx_inputs = [inputs, token_types, valid_length] - input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs))) - pred_onx, cls_onx = session.run(None, input_dict) - - assert_almost_equal(seq_encoding, pred_onx) - assert_almost_equal(cls_encoding, cls_onx) - - finally: - shutil.rmtree(tmp_path) - -