From 71d1f27c6a323595090a8d7766d2f16fbfcb897d Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Sat, 20 Mar 2021 00:22:40 +0000 Subject: [PATCH 1/4] add gpt support --- .../contrib/onnx/mx2onnx/_op_translations.py | 24 ++++++ tests/python-pytest/onnx/test_onnxruntime.py | 74 +++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 462564a85071..a4e0bd426b5c 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1709,6 +1709,30 @@ def convert_reshape(node, **kwargs): ] return nodes + if targ_shape == [0, -4, 12, -1, 0] and reverse != 'True': + create_tensor([-1], name+'_m1', kwargs['initializer']) + create_tensor([12], name+'_12', kwargs['initializer']) + nodes = [ + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2'], axis=0), + make_node('Concat', [name+'_dim0', name+'_12', name+'_m1', name+'_dim2'], + [name+'_shape_new'], axis=0), + make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) + ] + return nodes + + if targ_shape == [0, -4, 16, -1, 0] and reverse != 'True': + create_tensor([-1], name+'_m1', kwargs['initializer']) + create_tensor([16], name+'_16', kwargs['initializer']) + nodes = [ + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2'], axis=0), + make_node('Concat', [name+'_dim0', name+'_16', name+'_m1', name+'_dim2'], + [name+'_shape_new'], axis=0), + make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) + ] + return nodes + not_supported_shape = [-2, -3, -4] for val in targ_shape: if val in not_supported_shape: diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index e2a8329dd45d..9ccf602ca16f 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -988,3 +988,77 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name): finally: shutil.rmtree(tmp_path) + + +@with_seed() +@pytest.mark.parametrize('model_params', [('gpt2_117m', 24), ('gpt2_345m', 48)]) +def test_gpt_pretrained_inference_onnxruntime(tmp_path, model_params): + tmp_path = str(tmp_path) + try: + import gluonnlp as nlp + import urllib.request + from zipfile import ZipFile + import importlib.util + import sys + + url = 'https://nlp.gluon.ai/_downloads/77d227fbc8f1613e6802acc7253cc090/text_generation.zip' + urllib.request.urlretrieve(url, tmp_path + 'text_generation.zip') + + with ZipFile(tmp_path + 'text_generation.zip', 'r') as zipObj: + zipObj.extractall(tmp_path) + + # load in the text_generation module, refer to: + # https://github.com/dmlc/gluon-nlp/tree/v0.10.x/scripts/text_generation + spec = importlib.util.spec_from_file_location( + 'text_generation', + tmp_path + '/text_generation/__init__.py') + mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = mod + spec.loader.exec_module(mod) + + ctx = mx.cpu(0) + model_name= model_params[0] + dataset= 'openai_webtext' + # get_model() is overridden in here: + # https://github.com/dmlc/gluon-nlp/blob/v0.10.x/scripts/text_generation/model/__init__.py#L23 + model, _ = mod.model.get_model( + name=model_name, + ctx=ctx, + pretrained=True, + dataset_name=dataset) + + model.hybridize() + batch = 4 + seq_length = 64 + inputs = mx.nd.random.uniform(0, 50257, shape=(batch, seq_length), dtype='float32', + ctx=ctx) + + pred = model(inputs) + + prefix = "%s/%s" % (tmp_path, model_name) + model.export(prefix) + sym_file = "%s-symbol.json" % prefix + params_file = "%s-0000.params" % prefix + onnx_file = "%s.onnx" % prefix + + input_shapes = [(batch, seq_length)] + input_types = [np.float32] + converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, + input_types, onnx_file) + + ses_opt = onnxruntime.SessionOptions() + ses_opt.log_severity_level = 3 + session = onnxruntime.InferenceSession(onnx_file, ses_opt) + onnx_inputs = [inputs] + input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs))) + pred_onx = session.run(None, input_dict) + + # checkout output + assert_almost_equal(pred[0], pred_onx[0]) + # chckout states + num_states = model_params[1] + for i in range(num_states): + assert_almost_equal(pred[1][i], pred_onx[i+1]) + + finally: + shutil.rmtree(tmp_path) From 8f0277a1a550a80f106d9880586dab1a19a77b07 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 23 Mar 2021 17:30:55 +0000 Subject: [PATCH 2/4] add reshape tests --- tests/python-pytest/onnx/test_onnxruntime.py | 4 ++++ tests/python-pytest/onnx/test_operators.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index 9ebc317e225a..826d00fff899 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -990,6 +990,7 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name): shutil.rmtree(tmp_path) +@with_seed() @pytest.mark.parametrize('model_name', ['transformer_en_de_512']) def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name): tmp_path = str(tmp_path) @@ -1002,6 +1003,7 @@ def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name): ctx=ctx, pretrained=True, dataset_name=dataset) + model.hybridize(static_alloc=False) batch = 7 @@ -1193,6 +1195,7 @@ def test_gpt_pretrained_inference_onnxruntime(tmp_path, model_params): dataset_name=dataset) model.hybridize() + batch = 4 seq_length = 64 inputs = mx.nd.random.uniform(0, 50257, shape=(batch, seq_length), dtype='float32', @@ -1227,3 +1230,4 @@ def test_gpt_pretrained_inference_onnxruntime(tmp_path, model_params): finally: shutil.rmtree(tmp_path) + diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 220f259beb46..520ac407e49b 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -276,6 +276,14 @@ def test_onnx_export_reshape_special_cases(tmp_path, dtype): M9 = def_model('reshape', shape=(-4, 1, 1000, 0, 0)) op_export_test('reshape_spec_9', M9, [x7], tmp_path) + x8 = mx.nd.ones((3, 96, 5), dtype=dtype) + M10 = def_model('reshape', shape=(0, -4, 12, -1, 0)) + op_export_test('reshape_spec_10', M10, [x8], tmp_path) + + x9 = mx.nd.ones((3, 96, 5), dtype=dtype) + M11 = def_model('reshape', shape=(0, -4, 16, -1, 0)) + op_export_test('reshape_spec_11', M11, [x9], tmp_path) + @pytest.mark.parametrize('dtype', ['int32', 'int64']) def test_onnx_export_embedding(tmp_path, dtype): From 839359c7c2bfb6b02b89828df75ef99b614afc25 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 23 Mar 2021 17:33:22 +0000 Subject: [PATCH 3/4] fix comment --- tests/python-pytest/onnx/test_onnxruntime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index 826d00fff899..6ad0794f875d 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -1221,9 +1221,9 @@ def test_gpt_pretrained_inference_onnxruntime(tmp_path, model_params): input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs))) pred_onx = session.run(None, input_dict) - # checkout output + # check output assert_almost_equal(pred[0], pred_onx[0]) - # chckout states + # check states num_states = model_params[1] for i in range(num_states): assert_almost_equal(pred[1][i], pred_onx[i+1]) From 40922383e7a74635ac2d42eb5f1267adb38a3b1c Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 23 Mar 2021 17:55:19 +0000 Subject: [PATCH 4/4] fix sanity --- .../contrib/onnx/mx2onnx/_op_translations.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index feaa0604bec6..22e6282b43a4 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1645,7 +1645,9 @@ def convert_reshape(node, **kwargs): targ_shape = [-1, 0] reverse = 'True' + special_case = False if targ_shape == [0, 0, -3, -3] and reverse != 'True': + special_case = True nodes = [ make_node('Shape', [input_nodes[0]], [name+'_shape']), make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2', @@ -1657,9 +1659,9 @@ def convert_reshape(node, **kwargs): [name+'_shape_new'], axis=0), make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) ] - return nodes if targ_shape == [0, -4, -1, 4, 0, 0] and reverse != 'True': + special_case = True create_tensor([4], name+'_4', kwargs['initializer']) nodes = [ make_node('Shape', [input_nodes[0]], [name+'_shape']), @@ -1670,9 +1672,9 @@ def convert_reshape(node, **kwargs): [name+'_shape_new'], axis=0), make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) ] - return nodes if targ_shape == [0, 0, -4, 2, 2, 0, 0] and reverse != 'True': + special_case = True create_tensor([2], name+'_2', kwargs['initializer']) nodes = [ make_node('Shape', [input_nodes[0]], [name+'_shape']), @@ -1682,9 +1684,9 @@ def convert_reshape(node, **kwargs): name+'_dim3', name+'_dim4'], [name+'_shape_new'], axis=0), make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) ] - return nodes if targ_shape == [-4, 1, -1, 0, 0, 0] and reverse != 'True': + special_case = True create_tensor([1], name+'_1', kwargs['initializer']) create_tensor([-1], name+'_m1', kwargs['initializer']) nodes = [ @@ -1695,9 +1697,9 @@ def convert_reshape(node, **kwargs): [name+'_shape_new'], axis=0), make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) ] - return nodes if targ_shape == [-4, 1, 1000, 0, 0] and reverse != 'True': + special_case = True create_tensor([1], name+'_1', kwargs['initializer']) create_tensor([1000], name+'_1000', kwargs['initializer']) nodes = [ @@ -1707,9 +1709,9 @@ def convert_reshape(node, **kwargs): [name+'_shape_new'], axis=0), make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) ] - return nodes if targ_shape == [0, -4, 12, -1, 0] and reverse != 'True': + special_case = True create_tensor([-1], name+'_m1', kwargs['initializer']) create_tensor([12], name+'_12', kwargs['initializer']) nodes = [ @@ -1719,9 +1721,9 @@ def convert_reshape(node, **kwargs): [name+'_shape_new'], axis=0), make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) ] - return nodes if targ_shape == [0, -4, 16, -1, 0] and reverse != 'True': + special_case = True create_tensor([-1], name+'_m1', kwargs['initializer']) create_tensor([16], name+'_16', kwargs['initializer']) nodes = [ @@ -1731,6 +1733,8 @@ def convert_reshape(node, **kwargs): [name+'_shape_new'], axis=0), make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) ] + + if special_case: return nodes not_supported_shape = [-2, -3, -4]