diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 5e3749397304..01ec2c2e6524 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2256,25 +2256,54 @@ def convert_layer_norm(node, **kwargs): """Map MXNet's LayerNorm operator attributes to onnx operators. """ from onnx.helper import make_node + from onnx import TensorProto name, input_nodes, attrs = get_inputs(node, kwargs) + axes = int(attrs.get('axis', -1)) + eps = attrs.get('eps', 9.99999975e-06) + + create_tensor([axes], name+"_axes", kwargs["initializer"]) + create_tensor([axes+1], name+"_axes+1", kwargs["initializer"]) + create_tensor([], name+"_void", kwargs["initializer"]) + create_const_scalar_node(name+'_0_s', np.int64(0), kwargs) + create_const_scalar_node(name+'_1_s', np.int64(1), kwargs) + create_const_scalar_node(name+"_2_s", np.int64(2), kwargs) + create_const_scalar_node(name+"_eps", np.float32(eps), kwargs) - in_shape = kwargs['in_shape'] - axes = [-i for i in range(len(in_shape[0]), 0, -1)] - eps = attrs.get('eps') nodes = [ - make_node("ReduceMean", [input_nodes[0]], [name+"_rm0_out"], axes=axes), + make_node("ReduceMean", [input_nodes[0]], [name+"_rm0_out"], axes=[axes]), make_node("Sub", [input_nodes[0], name+"_rm0_out"], [name+"_sub0_out"]), - create_const_scalar_node(name+"_two", np.float32(2.), kwargs), - make_node("Pow", [name+"_sub0_out", name+"_two"], [name+"_pow0_out"]), - make_node("ReduceMean", [name+"_pow0_out"], [name+"_rm1_out"], axes=axes), - create_const_scalar_node(name+"_eps", np.float32(eps), kwargs), + make_node("Pow", [name+"_sub0_out", name+"_2_s"], [name+"_pow0_out"]), + make_node("ReduceMean", [name+"_pow0_out"], [name+"_rm1_out"], axes=[axes]), make_node("Add", [name+"_rm1_out", name+"_eps"], [name+"_add0_out"]), make_node("Sqrt", [name+"_add0_out"], [name+"_sqrt0_out"]), make_node("Div", [name+"_sub0_out", name+"_sqrt0_out"], [name+"_div0_out"]), - make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]), - make_node("Add", [name+"_mul0_out", input_nodes[2]], [name], name) ] + if axes == -1: + nodes += [ + make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]), + make_node("Add", [name+"_mul0_out", input_nodes[2]], [name]) + ] + else: + nodes += [ + make_node("Shape", [input_nodes[0]], [name+"_shape0_out"]), + make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]), + make_node("Reshape", [name+"_in_dim", name+"_void"], [name+"_in_dim_s"]), + make_node("Range", [name+"_0_s", name+"_in_dim_s", name+"_1_s"], [name+"_range"]), + make_node("Equal", [name+"_range", name+"_axes"], [name+"_equal"]), + make_node("Cast", [name+"_equal"], [name+"_one_hot"], to=int(TensorProto.INT64)), + make_node("Slice", [name+"_shape0_out", name+"_axes", name+"_axes+1"], [name+"_slice_out"]), + make_node("Reshape", [name+"_slice_out", name+"_void"], [name+"_slice_out_s"]), + make_node("Sub", [name+"_slice_out_s", name+"_1_s"], [name+"_sub1_out"]), + make_node("Mul", [name+"_one_hot", name+"_sub1_out"], [name+"_mul0_out"]), + make_node("Add", [name+"_mul0_out", name+"_1_s"], [name+"_add1_out"]), + make_node('Reshape', [input_nodes[1], name+"_add1_out"], [name+"gamma_exp"]), + make_node('Reshape', [input_nodes[2], name+"_add1_out"], [name+"beta_exp"]), + make_node('Expand', [name+"gamma_exp", name+"_shape0_out"], [name+"gamma_exp1"]), + make_node('Expand', [name+"beta_exp", name+"_shape0_out"], [name+"beta_exp1"]), + make_node("Mul", [name+"_div0_out", name+"gamma_exp1"], [name+"_mul1_out"]), + make_node("Add", [name+"_mul1_out", name+"beta_exp1"], [name], name=name) + ] return nodes @@ -2345,6 +2374,52 @@ def convert_matmul_selfatt_qk(node, **kwargs): return nodes +@mx_op.register("_contrib_interleaved_matmul_selfatt_valatt") +def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs): + """Map MXNet's _contrib_interleaved_matmul_selfatt_valatt operator attributes to onnx's operator. + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + qkv = input_nodes[0] + att = input_nodes[1] + num_heads = int(attrs.get('heads')) + + create_tensor([num_heads], name+"_const_num_heads", kwargs["initializer"]) + create_tensor([0], name+"_const_0", kwargs["initializer"]) + create_tensor([1], name+"_const_1", kwargs["initializer"]) + create_tensor([2], name+"_const_2", kwargs["initializer"]) + create_tensor([3], name+"_const_3", kwargs["initializer"]) + create_tensor([4], name+"_const_4", kwargs["initializer"]) + create_tensor([5], name+"_const_5", kwargs["initializer"]) + create_tensor([0, 0, num_heads, 3, -1], name+"_reshape0_shape", kwargs["initializer"]) + create_tensor([0, 0, 0, 2, 0], name+"_slice_start", kwargs["initializer"]) + create_tensor([0, 0, 0, -1], name+"_reshape1_shape", kwargs["initializer"]) + create_tensor([0, 0, -1], name+"_reshape4_shape", kwargs["initializer"]) + + nodes = [ + make_node("Shape", [qkv], [name+"_shape_qkv"]), + make_node("Slice", [name+"_shape_qkv", name+"_const_0", name+"_const_1"], [name+"_qkv_d0"]), + make_node("Slice", [name+"_shape_qkv", name+"_const_1", name+"_const_2"], [name+"_qkv_d1"]), + make_node("Slice", [name+"_shape_qkv", name+"_const_2", name+"_const_3"], [name+"_qkv_d2"]), + make_node('Mul', [name+"_qkv_d1", name+'_const_num_heads'], [name+'_mul_out']), + make_node("Reshape", [qkv, name+"_reshape0_shape"], [name+"_reshape0_output"]), + make_node("Shape", [name+"_reshape0_output"], [name+"_shape_reshape0"]), + make_node("Slice", [name+"_shape_reshape0", name+"_const_4", name+"_const_5"], [name+"_d4"]), + make_node("Concat", [name+"_mul_out", name+"_qkv_d0", name+"_d4"], [name+"_reshape2_shape"], axis=0), + make_node("Concat", [name+"_qkv_d1", name+"_const_num_heads", name+"_qkv_d0", name+"_d4"], \ + [name+"_reshape3_shape"], axis=0), + make_node("Concat", [name+"_qkv_d0", name+"_qkv_d1", name+"_qkv_d2", name+"_const_3", name+"_d4"], \ + [name+"_slice_end"], axis=0), + make_node("Slice", [name+"_reshape0_output", name+"_slice_start", name+"_slice_end"], [name+"_slice_output"]), + make_node("Reshape", [name+"_slice_output", name+"_reshape1_shape"], [name+"_reshape1_output"]), + make_node("Transpose", [name+"_reshape1_output"], [name+"_transpose0_output"], perm=[1, 2, 0, 3]), + make_node("Reshape", [name+"_transpose0_output", name+"_reshape2_shape"], [name+"_reshape2_output"]), + make_node("MatMul", [att, name+"_reshape2_output"], [name+"_matmul_output"]), + make_node("Reshape", [name+"_matmul_output", name+"_reshape3_shape"], [name+"_reshape3_output"]), + make_node("Transpose", [name+"_reshape3_output"], [name+"_transpose2_output"], perm=[2, 0, 1, 3]), + make_node("Reshape", [name+"_transpose2_output", name+"_reshape4_shape"], [name], name=name) + ] + return nodes @mx_op.register("broadcast_axis") def convert_broadcast_axis(node, **kwargs): diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 4f6bacac8665..ed4060a66452 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -100,13 +100,16 @@ def test_onnx_export_arange_like(tmp_path, dtype, axis, start, step, test_data): x = mx.nd.array(test_data, dtype=dtype) op_export_test('arange_like', M, [x], tmp_path) - -def test_onnx_export_layernorm(tmp_path): - M = def_model('LayerNorm', axis=1) - x = mx.nd.array([[1,3],[2,4]], dtype='float32') - gamma = mx.random.uniform(0, 1, x[0].shape, dtype='float32') - beta = mx.random.uniform(0, 1, x[0].shape, dtype='float32') - op_export_test('LayerNorm', M, [x, gamma, beta], tmp_path) +@pytest.mark.parametrize('dtype', ['float32']) +def test_onnx_export_layernorm(tmp_path, dtype): + x = mx.nd.random.uniform(1, 2, (3, 4, 5), dtype=dtype) + axes = list(range(np.shape(np.shape(x))[0])) + axes.append(-1) + for axis in axes: + M = def_model('LayerNorm', axis=axis) + gamma = mx.random.uniform(0, 1, [np.shape(x)[axis]], dtype=dtype) + beta = mx.random.uniform(0, 1, [np.shape(x)[axis]], dtype=dtype) + op_export_test('LayerNorm', M, [x, gamma, beta], tmp_path) @pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32']) @@ -146,6 +149,13 @@ def test_onnx_export_contrib_interleaved_matmul_selfatt_qk(tmp_path, dtype): x2 = mx.nd.random.uniform(0, 1, (7, 5, 4*5*6)) op_export_test('contrib_interleaved_matmul_selfatt_qk_2', M2, [x2], tmp_path) +@pytest.mark.parametrize('dtype', ['float32']) +def test_onnx_export_contrib_interleaved_matmul_selfatt_valatt(tmp_path, dtype): + M = def_model('contrib.interleaved_matmul_selfatt_valatt', heads=6) + x = mx.nd.random.uniform(0, 1, (4, 5, 6*7*3), dtype=dtype) + att = mx.nd.random.uniform(0, 1, (5*6, 4, 4), dtype=dtype) + op_export_test('contrib_interleaved_matmul_selfatt_valatt', M, [x, att], tmp_path) + @pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64']) @pytest.mark.parametrize('num_hidden', [1, 5, 10, 20])