diff --git a/ci/docker/install/ubuntu_onnx.sh b/ci/docker/install/ubuntu_onnx.sh index 3fbc1672a2f7..a6bca56518ff 100755 --- a/ci/docker/install/ubuntu_onnx.sh +++ b/ci/docker/install/ubuntu_onnx.sh @@ -31,4 +31,4 @@ apt-get update || true apt-get install -y libprotobuf-dev protobuf-compiler echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX, tabulate and onnxruntime..." -pip3 install pytest pytest-cov protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.4.0 +pip3 install pytest pytest-cov protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.4.0 gluonnlp diff --git a/tests/python-pytest/onnx/common.py b/tests/python-pytest/onnx/common.py new file mode 100644 index 000000000000..3f9e2642dbe2 --- /dev/null +++ b/tests/python-pytest/onnx/common.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import functools +import logging +import os +import random + +import mxnet as mx +import numpy as np + + +def with_seed(seed=None): + """ + A decorator for test functions that manages rng seeds. + + Parameters + ---------- + + seed : the seed to pass to np.random and mx.random + + + This tests decorator sets the np, mx and python random seeds identically + prior to each test, then outputs those seeds if the test fails or + if the test requires a fixed seed (as a reminder to make the test + more robust against random data). + + @with_seed() + def test_ok_with_random_data(): + ... + + @with_seed(1234) + def test_not_ok_with_random_data(): + ... + + Use of the @with_seed() decorator for all tests creates + tests isolation and reproducability of failures. When a + test fails, the decorator outputs the seed used. The user + can then set the environment variable MXNET_TEST_SEED to + the value reported, then rerun the test with: + + pytest --verbose --capture=no :: + + To run a test repeatedly, set MXNET_TEST_COUNT= in the environment. + To see the seeds of even the passing tests, add '--log-level=DEBUG' to pytest. + """ + def test_helper(orig_test): + @functools.wraps(orig_test) + def test_new(*args, **kwargs): + test_count = int(os.getenv('MXNET_TEST_COUNT', '1')) + env_seed_str = os.getenv('MXNET_TEST_SEED') + for i in range(test_count): + if seed is not None: + this_test_seed = seed + log_level = logging.INFO + elif env_seed_str is not None: + this_test_seed = int(env_seed_str) + log_level = logging.INFO + else: + this_test_seed = np.random.randint(0, np.iinfo(np.int32).max) + log_level = logging.DEBUG + post_test_state = np.random.get_state() + np.random.seed(this_test_seed) + mx.random.seed(this_test_seed) + random.seed(this_test_seed) + # 'pytest --logging-level=DEBUG' shows this msg even with an ensuing core dump. + test_count_msg = '{} of {}: '.format(i+1,test_count) if test_count > 1 else '' + pre_test_msg = ('{}Setting test np/mx/python random seeds, use MXNET_TEST_SEED={}' + ' to reproduce.').format(test_count_msg, this_test_seed) + on_err_test_msg = ('{}Error seen with seeded test, use MXNET_TEST_SEED={}' + ' to reproduce.').format(test_count_msg, this_test_seed) + logging.log(log_level, pre_test_msg) + try: + orig_test(*args, **kwargs) + except: + # With exceptions, repeat test_msg at WARNING level to be sure it's seen. + if log_level < logging.WARNING: + logging.warning(on_err_test_msg) + raise + finally: + # Provide test-isolation for any test having this decorator + mx.nd.waitall() + np.random.set_state(post_test_state) + return test_new + return test_helper + diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index 358d87658b05..2cb2b772861b 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -19,6 +19,9 @@ import numpy as np import onnxruntime +from mxnet.test_utils import assert_almost_equal +from common import with_seed + import json import os import pytest @@ -30,10 +33,11 @@ ['apron.jpg', [411,578,638,639,689,775]], ['dolphin.jpg', [2,3,4,146,147,148,395]], ['hammerheadshark.jpg', [3,4]], - ['lotus.jpg', [723,738,985]] + ['lotus.jpg', [716,723,738,985]] ] test_models = [ + 'alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'mobilenet1.0', 'mobilenet0.75', 'mobilenet0.5', 'mobilenet0.25', 'mobilenetv2_1.0', 'mobilenetv2_0.75', 'mobilenetv2_0.5', 'mobilenetv2_0.25', 'resnet18_v1', 'resnet18_v2', 'resnet34_v1', 'resnet34_v2', 'resnet50_v1', 'resnet50_v2', @@ -42,6 +46,7 @@ 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn' ] +@with_seed() @pytest.mark.parametrize('model', test_models) def test_cv_model_inference_onnxruntime(tmp_path, model): def get_gluon_cv_model(model_name, tmp): @@ -97,25 +102,80 @@ def download_test_images(tmpdir): tmp_path = str(tmp_path) - #labels = load_imgnet_labels(tmp_path) - test_images = download_test_images(tmp_path) - sym_file, params_file = get_gluon_cv_model(model, tmp_path) - onnx_file = export_model_to_onnx(sym_file, params_file) - - # create onnxruntime session using the generated onnx file - ses_opt = onnxruntime.SessionOptions() - ses_opt.log_severity_level = 3 - session = onnxruntime.InferenceSession(onnx_file, ses_opt) - input_name = session.get_inputs()[0].name - - for img, accepted_ids in test_images: - img_data = normalize_image(os.path.join(tmp_path,img)) - raw_result = session.run([], {input_name: img_data}) - res = softmax(np.array(raw_result)).tolist() - class_idx = np.argmax(res) - assert(class_idx in accepted_ids) - - shutil.rmtree(tmp_path) - + try: + #labels = load_imgnet_labels(tmp_path) + test_images = download_test_images(tmp_path) + sym_file, params_file = get_gluon_cv_model(model, tmp_path) + onnx_file = export_model_to_onnx(sym_file, params_file) + + # create onnxruntime session using the generated onnx file + ses_opt = onnxruntime.SessionOptions() + ses_opt.log_severity_level = 3 + session = onnxruntime.InferenceSession(onnx_file, ses_opt) + input_name = session.get_inputs()[0].name + + for img, accepted_ids in test_images: + img_data = normalize_image(os.path.join(tmp_path,img)) + raw_result = session.run([], {input_name: img_data}) + res = softmax(np.array(raw_result)).tolist() + class_idx = np.argmax(res) + assert(class_idx in accepted_ids) + + finally: + shutil.rmtree(tmp_path) + + +@with_seed() +@pytest.mark.parametrize('model', ['bert_12_768_12']) +def test_bert_inference_onnxruntime(tmp_path, model): + tmp_path = str(tmp_path) + try: + import gluonnlp as nlp + dataset = 'book_corpus_wiki_en_uncased' + ctx = mx.cpu(0) + model, vocab = nlp.model.get_model( + name=model, + ctx=ctx, + dataset_name=dataset, + pretrained=False, + use_pooler=True, + use_decoder=False, + use_classifier=False) + model.initialize(ctx=ctx) + model.hybridize(static_alloc=True) + + batch = 5 + seq_length = 16 + # create synthetic test data + inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32') + token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32') + valid_length = mx.nd.array([seq_length] * batch, dtype='float32') + + seq_encoding, cls_encoding = model(inputs, token_types, valid_length) + + prefix = "%s/bert" % tmp_path + model.export(prefix) + sym_file = "%s-symbol.json" % prefix + params_file = "%s-0000.params" % prefix + onnx_file = "%s.onnx" % prefix + + + input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)] + converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, np.float32, onnx_file) + + + # create onnxruntime session using the generated onnx file + ses_opt = onnxruntime.SessionOptions() + ses_opt.log_severity_level = 3 + session = onnxruntime.InferenceSession(onnx_file, ses_opt) + onnx_inputs = [inputs, token_types, valid_length] + input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs))) + pred_onx, cls_onx = session.run(None, input_dict) + + assert_almost_equal(seq_encoding, pred_onx) + assert_almost_equal(cls_encoding, cls_onx) + + finally: + shutil.rmtree(tmp_path)