diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py
index 5cb4fe7d89c4..3bcf27f7f1c9 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -596,8 +596,8 @@ def test_roberta_inference_onnxruntime(tmp_path, model_name):
         onnx_file = "%s.onnx" % prefix
         input_shapes = [(batch, seq_length), (batch,), (batch, num_masked_positions)]
         converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
-                                                       [np.float32, np.float32, np.int32],
-                                                       onnx_file, verbose=True)
+                                                            [np.float32, np.float32, np.int32],
+                                                            onnx_file, verbose=True)
 
         sess_options = onnxruntime.SessionOptions()
         sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
@@ -626,11 +626,11 @@ def test_bert_inference_onnxruntime(tmp_path, model):
             name=model,
             ctx=ctx,
             dataset_name=dataset,
-            pretrained=False,
+            pretrained=True,
             use_pooler=True,
             use_decoder=False,
             use_classifier=False)
-        model.initialize(ctx=ctx)
+        model.hybridize(static_alloc=True)
 
         batch = 5
 
@@ -669,3 +669,53 @@ def test_bert_inference_onnxruntime(tmp_path, model):
         shutil.rmtree(tmp_path)
 
 
+@with_seed()
+@pytest.mark.parametrize('model_name', ['distilbert_6_768_12'])
+def test_distilbert_inference_onnxruntime(tmp_path, model_name):
+    """Export pretrained DistilBERT to ONNX and check onnxruntime output against MXNet."""
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'distilbert_book_corpus_wiki_en_uncased'
+        ctx = mx.cpu(0)
+        model, _ = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize(static_alloc=True)
+
+        batch = 2
+        seq_length = 32
+        inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32', ctx=ctx)
+        valid_length = mx.nd.array([seq_length] * batch, dtype='float32', ctx=ctx)
+
+        # Reference output computed by the MXNet model on the same inputs.
+        sequence_outputs = model(inputs, valid_length)
+
+        prefix = "%s/distilbert" % tmp_path
+        model.export(prefix)
+        sym_file = "%s-symbol.json" % prefix
+        params_file = "%s-0000.params" % prefix
+        onnx_file = "%s.onnx" % prefix
+
+        input_shapes = [(batch, seq_length), (batch,)]
+        mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
+                                     [np.float32, np.float32],
+                                     onnx_file, verbose=True)
+        sess_options = onnxruntime.SessionOptions()
+        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        sess = onnxruntime.InferenceSession(onnx_file, sess_options)
+
+        in_tensors = [inputs, valid_length]
+        # Pair each session input with the tensor at the same position.
+        input_dict = {sess_input.name: tensor.asnumpy()
+                      for sess_input, tensor in zip(sess.get_inputs(), in_tensors)}
+        pred = sess.run(None, input_dict)
+
+        assert_almost_equal(sequence_outputs, pred[0])
+
+    finally:
+        shutil.rmtree(tmp_path)
+