@with_seed()
@pytest.mark.parametrize('model_name', ['ernie_12_768_12'])
def test_ernie_inference_onnxruntime(tmp_path, model_name):
    """Check that an ONNX-exported ERNIE model matches MXNet inference.

    A pretrained GluonNLP ERNIE model (overridden to 3 layers) is run on
    synthetic inputs, exported to symbol/params files, converted to ONNX and
    executed with ONNX Runtime. Both the sequence encoding and the pooled
    (CLS) encoding must agree with the MXNet outputs.

    Parameters
    ----------
    tmp_path : pathlib.Path
        Directory supplied by pytest; it is removed when the test finishes.
    model_name : str
        GluonNLP model-zoo name passed to ``nlp.model.get_model``.
    """
    tmp_path = str(tmp_path)
    try:
        import gluonnlp as nlp
        dataset = 'baidu_ernie_uncased'
        ctx = mx.cpu(0)
        # num_layers overrides the pretrained hyper-parameter (allowed by
        # hparam_allow_override) - presumably to keep the test small.
        model, _ = nlp.model.get_model(
            name=model_name,
            ctx=ctx,
            dataset_name=dataset,
            pretrained=True,
            use_pooler=True,
            use_decoder=False,
            num_layers=3,
            hparam_allow_override=True,
            use_classifier=False)

        model.hybridize(static_alloc=True)

        batch = 5
        seq_length = 16
        # create synthetic test data
        # NOTE(review): 17964 is presumably the ERNIE vocabulary size - confirm.
        inputs = mx.nd.random.uniform(0, 17964, shape=(batch, seq_length), dtype='float32')
        token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
        valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

        # A hybridized block must run forward once before it can be exported;
        # the same outputs serve as the MXNet reference below.
        seq_encoding, cls_encoding = model(inputs, token_types, valid_length)

        prefix = "%s/ernie" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
        input_types = [np.float32, np.float32, np.float32]
        mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
                                     input_types, onnx_file)

        # create onnxruntime session using the generated onnx file
        ses_opt = onnxruntime.SessionOptions()
        # level 3 is "error" in ONNX Runtime's severity scale; hides warnings
        ses_opt.log_severity_level = 3
        session = onnxruntime.InferenceSession(onnx_file, ses_opt)

        onnx_inputs = [inputs, token_types, valid_length]
        session_inputs = session.get_inputs()
        # zip() would silently truncate on a mismatch, so fail loudly instead.
        assert len(session_inputs) == len(onnx_inputs)
        input_dict = {meta.name: data.asnumpy()
                      for meta, data in zip(session_inputs, onnx_inputs)}
        pred_onx, cls_onx = session.run(None, input_dict)

        assert_almost_equal(seq_encoding, pred_onx)
        assert_almost_equal(cls_encoding, cls_onx)

    finally:
        shutil.rmtree(tmp_path)