From d9ce049834d1c716dc70d817d1c002019351986e Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 15 Dec 2020 13:09:19 -0800 Subject: [PATCH 1/6] convert ops and add tests --- .../contrib/onnx/mx2onnx/_op_translations.py | 97 ++++++++++++++++--- tests/python-pytest/onnx/test_operators.py | 20 +++- 2 files changed, 99 insertions(+), 18 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 5e3749397304..8a3c2a6c34eb 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -53,6 +53,7 @@ """ import re +import sys import logging import numpy as np from .export_onnx import MXNetGraph as mx_op @@ -1631,8 +1632,7 @@ def convert_slice_axis(node, **kwargs): if not ends or ends == 'None': # ONNX doesn't support None for ends. Since ends=None depicts # length of dimension, passing dimension in this case. - in_shape = kwargs['in_shape'][0] - ends = in_shape[axes] + ends = sys.maxsize node = onnx.helper.make_node( "Slice", @@ -1640,7 +1640,7 @@ def convert_slice_axis(node, **kwargs): [name], axes=[axes], starts=[starts], - ends=[int(ends)], + ends=[ends], name=name, ) return [node] @@ -2257,24 +2257,52 @@ def convert_layer_norm(node, **kwargs): """ from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) + axes = int(attrs.get('axis', -1)) + eps = attrs.get('eps', 9.99999975e-06) + + make_tensor([axes], name+"_axes", kwargs["initializer"]) + make_tensor([axes+1], name+"_axes+1", kwargs["initializer"]) + make_tensor([], name+"_void", kwargs["initializer"]) + create_const_scalar_node(name+'_0_s', np.int64(0), kwargs) + create_const_scalar_node(name+'_1_s', np.int64(1), kwargs) + create_const_scalar_node(name+"_2_s", np.int64(2), kwargs) + create_const_scalar_node(name+"_eps", np.float32(eps), kwargs) - in_shape = kwargs['in_shape'] - axes = [-i for i in range(len(in_shape[0]), 0, -1)] - eps = attrs.get('eps') nodes = [ - make_node("ReduceMean", [input_nodes[0]], [name+"_rm0_out"], axes=axes), + make_node("ReduceMean", [input_nodes[0]], [name+"_rm0_out"], axes=[axes]), make_node("Sub", [input_nodes[0], name+"_rm0_out"], [name+"_sub0_out"]), - create_const_scalar_node(name+"_two", np.float32(2.), kwargs), - make_node("Pow", [name+"_sub0_out", name+"_two"], [name+"_pow0_out"]), - make_node("ReduceMean", [name+"_pow0_out"], [name+"_rm1_out"], axes=axes), - create_const_scalar_node(name+"_eps", np.float32(eps), kwargs), + make_node("Pow", [name+"_sub0_out", name+"_2_s"], [name+"_pow0_out"]), + make_node("ReduceMean", [name+"_pow0_out"], [name+"_rm1_out"], axes=[axes]), make_node("Add", [name+"_rm1_out", name+"_eps"], [name+"_add0_out"]), make_node("Sqrt", [name+"_add0_out"], [name+"_sqrt0_out"]), make_node("Div", [name+"_sub0_out", name+"_sqrt0_out"], [name+"_div0_out"]), - make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]), - make_node("Add", [name+"_mul0_out", input_nodes[2]], [name], name) ] + if axes == -1: + nodes += [ + make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]), + make_node("Add", [name+"_mul0_out", input_nodes[2]], [name]) + ] + else: + nodes += [ + make_node("Shape", [input_nodes[0]], [name+"_shape0_out"]), + make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]), + make_node("Reshape", [name+"_in_dim", name+"_void"], [name+"_in_dim_s"]), + make_node("Range", [name+"_0_s", name+"_in_dim_s", name+"_1_s"], [name+"_range"]), + make_node("Equal", [name+"_range", name+"_axes"], [name+"_equal"]), + make_node("Cast", [name+"_equal"], [name+"_one_hot"], to=int(TensorProto.INT64)), + make_node("Slice", [name+"_shape0_out", name+"_axes", name+"_axes+1"], [name+"_slice_out"]), + make_node("Reshape", [name+"_slice_out", name+"_void"], [name+"_slice_out_s"]), + make_node("Sub", [name+"_slice_out_s", name+"_1_s"], [name+"_sub1_out"]), + make_node("Mul", [name+"_one_hot", name+"_sub1_out"], [name+"_mul0_out"]), + make_node("Add", [name+"_mul0_out", name+"_1_s"], [name+"_add1_out"]), + make_node('Reshape', [input_nodes[1], name+"_add1_out"], [name+"gamma_exp"]), + make_node('Reshape', [input_nodes[2], name+"_add1_out"], [name+"beta_exp"]), + make_node('Expand', [name+"gamma_exp", name+"_shape0_out"], [name+"gamma_exp1"]), + make_node('Expand', [name+"beta_exp", name+"_shape0_out"], [name+"beta_exp1"]), + make_node("Mul", [name+"_div0_out", name+"gamma_exp1"], [name+"_mul1_out"]), + make_node("Add", [name+"_mul1_out", name+"beta_exp1"], [name]) + ] return nodes @@ -2345,6 +2373,49 @@ def convert_matmul_selfatt_qk(node, **kwargs): return nodes +@mx_op.register("_contrib_interleaved_matmul_selfatt_valatt") +def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs): + """Map MXNet's _contrib_interleaved_matmul_selfatt_valatt operator attributes to onnx's operator. + """ + name, input_nodes, attrs = get_inputs(node, kwargs) + qkv, att = input_nodes + num_heads = int(attrs.get('heads')) + + create_tensor([num_heads], name+"_const_num_heads", kwargs["initializer"]) + create_tensor([0], name+"_const_0", kwargs["initializer"]) + create_tensor([1], name+"_const_1", kwargs["initializer"]) + create_tensor([2], name+"_const_2", kwargs["initializer"]) + create_tensor([3], name+"_const_3", kwargs["initializer"]) + create_tensor([4], name+"_const_4", kwargs["initializer"]) + create_tensor([5], name+"_const_5", kwargs["initializer"]) + create_tensor([0, 0, num_heads, 3, -1], name+"_reshape0_shape", kwargs["initializer"]) + create_tensor([0, 0, 0, 2, 0], name+"_slice_start", kwargs["initializer"]) + create_tensor([sys.maxsize, sys.maxsize, sys.maxsize, 3, sys.maxsize], name+"_slice_end", kwargs["initializer"]) + create_tensor([0, 0, 0, -1], name+"_reshape1_shape", kwargs["initializer"]) + create_tensor([0, 0, -1], name+"_reshape4_shape", kwargs["initializer"]) + + nodes = [ + make_node("Shape", [qkv], [name+"_shape_qkv"]), + make_node("Slice", [name+"_shape_qkv", name+"_const_0", name+"_const_1"], [name+"_qkv_d0"]), + make_node("Slice", [name+"_shape_qkv", name+"_const_1", name+"_const_2"], [name+"_qkv_d1"]), + make_node("Slice", [name+"_shape_qkv", name+"_const_2", name+"_const_3"], [name+"_qkv_d2"]), + make_node('Mul', [name+"_qkv_d1", name+'_const_num_heads'], [name+'_mul_out']), + make_node("Reshape", [qkv, name+"_reshape0_shape"], [name+"_reshape0_output"]), + make_node("Shape", [name+"_reshape0_output"], [name+"_shape_reshape0"]), + make_node("Slice", [name+"_shape_reshape0", name+"_const_4", name+"_const_5"], [name+"_d4"]), + make_node("Concat", [name+"_mul_out", name+"_qkv_d0", name+"_d4"], [name+"_reshape2_shape"], axis=0), + make_node("Concat", [name+"_qkv_d1", name+"_const_num_heads", name+"_qkv_d0", name+"_d4"], \ + [name+"_reshape3_shape"], axis=0), + make_node("Slice", [name+"_reshape0_output", name+"_slice_start", name+"_slice_end"], [name+"_slice_output"]), + make_node("Reshape", [name+"_slice_output", name+"_reshape1_shape"], [name+"_reshape1_output"]), + make_node("Transpose", [name+"_reshape1_output"], [name+"_transpose0_output"], perm=[1, 2, 0, 3]), + make_node("Reshape", [name+"_transpose0_output", name+"_reshape2_shape"], [name+"_reshape2_output"]), + make_node("MatMul", [att, name+"_reshape2_output"], [name+"_matmul_output"]), + make_node("Reshape", [name+"_matmul_output", name+"_reshape3_shape"], [name+"_reshape3_output"]), + make_node("Transpose", [name+"_reshape3_output"], [name+"_transpose2_output"], perm=[2, 0, 1, 3]), + make_node("Reshape", [name+"_transpose2_output", name+"_reshape4_shape"], [name]) + ] + return nodes @mx_op.register("broadcast_axis") def convert_broadcast_axis(node, **kwargs): diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 4f6bacac8665..5dc20f567830 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -102,11 +102,14 @@ def test_onnx_export_arange_like(tmp_path, dtype, axis, start, step, test_data): def test_onnx_export_layernorm(tmp_path): - M = def_model('LayerNorm', axis=1) - x = mx.nd.array([[1,3],[2,4]], dtype='float32') - gamma = mx.random.uniform(0, 1, x[0].shape, dtype='float32') - beta = mx.random.uniform(0, 1, x[0].shape, dtype='float32') - op_export_test('LayerNorm', M, [x, gamma, beta], tmp_path) + x = mx.nd.random.uniform(1, 2, (3, 4, 5), dtype=dtype) + axes = list(range(np.shape(np.shape(x))[0])) + axes.append(-1) + for axis in axes: + M = def_model('LayerNorm', axis=axis) + gamma = mx.random.uniform(0, 1, [np.shape(x)[axis]], dtype=dtype) + beta = mx.random.uniform(0, 1, [np.shape(x)[axis]], dtype=dtype) + op_export_test('LayerNorm', M, [x, gamma, beta], tmp_path) @pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32']) @@ -146,6 +149,13 @@ def test_onnx_export_contrib_interleaved_matmul_selfatt_qk(tmp_path, dtype): x2 = mx.nd.random.uniform(0, 1, (7, 5, 4*5*6)) op_export_test('contrib_interleaved_matmul_selfatt_qk_2', M2, [x2], tmp_path) +@pytest.mark.parametrize('dtype', ['float32']) +def test_onnx_export_contrib_interleaved_matmul_selfatt_valatt(tmp_path, dtype): + M = def_model('contrib.interleaved_matmul_selfatt_valatt', heads=6) + x = mx.nd.random.uniform(0, 1, (4, 5, 6*7*3), dtype=dtype) + att = mx.nd.random.uniform(0, 1, (5*6, 4, 4), dtype=dtype) + op_export_test('contrib_interleaved_matmul_selfatt_valatt', M, [x, att], tmp_path) + @pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64']) @pytest.mark.parametrize('num_hidden', [1, 5, 10, 20]) From e2bc51e2602c89ce148e7e1492accc5f9de5bf24 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 15 Dec 2020 13:20:38 -0800 Subject: [PATCH 2/6] revert slice fix, update function name --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 8a3c2a6c34eb..5e71b2e982e4 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1632,7 +1632,8 @@ def convert_slice_axis(node, **kwargs): if not ends or ends == 'None': # ONNX doesn't support None for ends. Since ends=None depicts # length of dimension, passing dimension in this case. - ends = sys.maxsize + in_shape = kwargs['in_shape'][0] + ends = in_shape[axes] node = onnx.helper.make_node( "Slice", @@ -1640,7 +1641,7 @@ def convert_slice_axis(node, **kwargs): [name], axes=[axes], starts=[starts], - ends=[ends], + ends=[int(ends)], name=name, ) return [node] @@ -2260,9 +2261,9 @@ def convert_layer_norm(node, **kwargs): axes = int(attrs.get('axis', -1)) eps = attrs.get('eps', 9.99999975e-06) - make_tensor([axes], name+"_axes", kwargs["initializer"]) - make_tensor([axes+1], name+"_axes+1", kwargs["initializer"]) - make_tensor([], name+"_void", kwargs["initializer"]) + create_tensor([axes], name+"_axes", kwargs["initializer"]) + create_tensor([axes+1], name+"_axes+1", kwargs["initializer"]) + create_tensor([], name+"_void", kwargs["initializer"]) create_const_scalar_node(name+'_0_s', np.int64(0), kwargs) create_const_scalar_node(name+'_1_s', np.int64(1), kwargs) create_const_scalar_node(name+"_2_s", np.int64(2), kwargs) From 451c066510c6e12438f4a2db94fc13484aa1074d Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 15 Dec 2020 13:49:43 -0800 Subject: [PATCH 3/6] fix sanity, update test --- .../contrib/onnx/mx2onnx/_op_translations.py | 44 ++++++++++--------- tests/python-pytest/onnx/test_operators.py | 4 +- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 5e71b2e982e4..c8f408c753c0 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2257,6 +2257,7 @@ def convert_layer_norm(node, **kwargs): """Map MXNet's LayerNorm operator attributes to onnx operators. """ from onnx.helper import make_node + from onnx import TensorProto name, input_nodes, attrs = get_inputs(node, kwargs) axes = int(attrs.get('axis', -1)) eps = attrs.get('eps', 9.99999975e-06) @@ -2281,29 +2282,29 @@ def convert_layer_norm(node, **kwargs): if axes == -1: nodes += [ - make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]), - make_node("Add", [name+"_mul0_out", input_nodes[2]], [name]) - ] + make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]), + make_node("Add", [name+"_mul0_out", input_nodes[2]], [name]) + ] else: nodes += [ - make_node("Shape", [input_nodes[0]], [name+"_shape0_out"]), - make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]), - make_node("Reshape", [name+"_in_dim", name+"_void"], [name+"_in_dim_s"]), - make_node("Range", [name+"_0_s", name+"_in_dim_s", name+"_1_s"], [name+"_range"]), - make_node("Equal", [name+"_range", name+"_axes"], [name+"_equal"]), - make_node("Cast", [name+"_equal"], [name+"_one_hot"], to=int(TensorProto.INT64)), - make_node("Slice", [name+"_shape0_out", name+"_axes", name+"_axes+1"], [name+"_slice_out"]), - make_node("Reshape", [name+"_slice_out", name+"_void"], [name+"_slice_out_s"]), - make_node("Sub", [name+"_slice_out_s", name+"_1_s"], [name+"_sub1_out"]), - make_node("Mul", [name+"_one_hot", name+"_sub1_out"], [name+"_mul0_out"]), - make_node("Add", [name+"_mul0_out", name+"_1_s"], [name+"_add1_out"]), - make_node('Reshape', [input_nodes[1], name+"_add1_out"], [name+"gamma_exp"]), - make_node('Reshape', [input_nodes[2], name+"_add1_out"], [name+"beta_exp"]), - make_node('Expand', [name+"gamma_exp", name+"_shape0_out"], [name+"gamma_exp1"]), - make_node('Expand', [name+"beta_exp", name+"_shape0_out"], [name+"beta_exp1"]), - make_node("Mul", [name+"_div0_out", name+"gamma_exp1"], [name+"_mul1_out"]), - make_node("Add", [name+"_mul1_out", name+"beta_exp1"], [name]) - ] + make_node("Shape", [input_nodes[0]], [name+"_shape0_out"]), + make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]), + make_node("Reshape", [name+"_in_dim", name+"_void"], [name+"_in_dim_s"]), + make_node("Range", [name+"_0_s", name+"_in_dim_s", name+"_1_s"], [name+"_range"]), + make_node("Equal", [name+"_range", name+"_axes"], [name+"_equal"]), + make_node("Cast", [name+"_equal"], [name+"_one_hot"], to=int(TensorProto.INT64)), + make_node("Slice", [name+"_shape0_out", name+"_axes", name+"_axes+1"], [name+"_slice_out"]), + make_node("Reshape", [name+"_slice_out", name+"_void"], [name+"_slice_out_s"]), + make_node("Sub", [name+"_slice_out_s", name+"_1_s"], [name+"_sub1_out"]), + make_node("Mul", [name+"_one_hot", name+"_sub1_out"], [name+"_mul0_out"]), + make_node("Add", [name+"_mul0_out", name+"_1_s"], [name+"_add1_out"]), + make_node('Reshape', [input_nodes[1], name+"_add1_out"], [name+"gamma_exp"]), + make_node('Reshape', [input_nodes[2], name+"_add1_out"], [name+"beta_exp"]), + make_node('Expand', [name+"gamma_exp", name+"_shape0_out"], [name+"gamma_exp1"]), + make_node('Expand', [name+"beta_exp", name+"_shape0_out"], [name+"beta_exp1"]), + make_node("Mul", [name+"_div0_out", name+"gamma_exp1"], [name+"_mul1_out"]), + make_node("Add", [name+"_mul1_out", name+"beta_exp1"], [name]) + ] return nodes @@ -2378,6 +2379,7 @@ def convert_matmul_selfatt_qk(node, **kwargs): def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs): """Map MXNet's _contrib_interleaved_matmul_selfatt_valatt operator attributes to onnx's operator. """ + from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) qkv, att = input_nodes num_heads = int(attrs.get('heads')) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 5dc20f567830..ed4060a66452 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -100,8 +100,8 @@ def test_onnx_export_arange_like(tmp_path, dtype, axis, start, step, test_data): x = mx.nd.array(test_data, dtype=dtype) op_export_test('arange_like', M, [x], tmp_path) - -def test_onnx_export_layernorm(tmp_path): +@pytest.mark.parametrize('dtype', ['float32']) +def test_onnx_export_layernorm(tmp_path, dtype): x = mx.nd.random.uniform(1, 2, (3, 4, 5), dtype=dtype) axes = list(range(np.shape(np.shape(x))[0])) axes.append(-1) From 4fb08caefde469b0bc17e58cd81056feab99c126 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 15 Dec 2020 14:29:44 -0800 Subject: [PATCH 4/6] add name to output node --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index c8f408c753c0..05952c2b741c 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2285,7 +2285,7 @@ def convert_layer_norm(node, **kwargs): make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]), make_node("Add", [name+"_mul0_out", input_nodes[2]], [name]) ] - else: + else: nodes += [ make_node("Shape", [input_nodes[0]], [name+"_shape0_out"]), make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]), @@ -2303,7 +2303,7 @@ def convert_layer_norm(node, **kwargs): make_node('Expand', [name+"gamma_exp", name+"_shape0_out"], [name+"gamma_exp1"]), make_node('Expand', [name+"beta_exp", name+"_shape0_out"], [name+"beta_exp1"]), make_node("Mul", [name+"_div0_out", name+"gamma_exp1"], [name+"_mul1_out"]), - make_node("Add", [name+"_mul1_out", name+"beta_exp1"], [name]) + make_node("Add", [name+"_mul1_out", name+"beta_exp1"], [name], name=name) ] return nodes @@ -2381,7 +2381,8 @@ def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs): """ from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) - qkv, att = input_nodes + qkv = input_nodes[0] + att = input_nodes[1] num_heads = int(attrs.get('heads')) create_tensor([num_heads], name+"_const_num_heads", kwargs["initializer"]) @@ -2416,7 +2417,7 @@ def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs): make_node("MatMul", [att, name+"_reshape2_output"], [name+"_matmul_output"]), make_node("Reshape", [name+"_matmul_output", name+"_reshape3_shape"], [name+"_reshape3_output"]), make_node("Transpose", [name+"_reshape3_output"], [name+"_transpose2_output"], perm=[2, 0, 1, 3]), - make_node("Reshape", [name+"_transpose2_output", name+"_reshape4_shape"], [name]) + make_node("Reshape", [name+"_transpose2_output", name+"_reshape4_shape"], [name], name=name) ] return nodes From 8b1addaa2045d4c3111095313dbd260a2a2f782b Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 15 Dec 2020 22:35:32 -0800 Subject: [PATCH 5/6] update slice end --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 05952c2b741c..d50bd1105300 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2394,7 +2394,6 @@ def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs): create_tensor([5], name+"_const_5", kwargs["initializer"]) create_tensor([0, 0, num_heads, 3, -1], name+"_reshape0_shape", kwargs["initializer"]) create_tensor([0, 0, 0, 2, 0], name+"_slice_start", kwargs["initializer"]) - create_tensor([sys.maxsize, sys.maxsize, sys.maxsize, 3, sys.maxsize], name+"_slice_end", kwargs["initializer"]) create_tensor([0, 0, 0, -1], name+"_reshape1_shape", kwargs["initializer"]) create_tensor([0, 0, -1], name+"_reshape4_shape", kwargs["initializer"]) @@ -2410,6 +2409,8 @@ def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs): make_node("Concat", [name+"_mul_out", name+"_qkv_d0", name+"_d4"], [name+"_reshape2_shape"], axis=0), make_node("Concat", [name+"_qkv_d1", name+"_const_num_heads", name+"_qkv_d0", name+"_d4"], \ [name+"_reshape3_shape"], axis=0), + make_node("Concat", [name+"_qkv_d0", name+"_qkv_d1", name+"_qkv_d2", name+"_const_3", name+"_d4"], \ + [name+"_slice_end"], axis=0), make_node("Slice", [name+"_reshape0_output", name+"_slice_start", name+"_slice_end"], [name+"_slice_output"]), make_node("Reshape", [name+"_slice_output", name+"_reshape1_shape"], [name+"_reshape1_output"]), make_node("Transpose", [name+"_reshape1_output"], [name+"_transpose0_output"], perm=[1, 2, 0, 3]), From bde51469f87b344847f9e38754f9fa9740f88ed4 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 16 Dec 2020 10:05:06 -0800 Subject: [PATCH 6/6] remove sys --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index d50bd1105300..01ec2c2e6524 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -53,7 +53,6 @@ """ import re -import sys import logging import numpy as np from .export_onnx import MXNetGraph as mx_op