95 changes: 85 additions & 10 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -2256,25 +2256,54 @@ def convert_layer_norm(node, **kwargs):
     """Map MXNet's LayerNorm operator attributes to onnx operators.
     """
     from onnx.helper import make_node
+    from onnx import TensorProto
     name, input_nodes, attrs = get_inputs(node, kwargs)
+    axes = int(attrs.get('axis', -1))
+    eps = attrs.get('eps', 9.99999975e-06)
+
+    create_tensor([axes], name+"_axes", kwargs["initializer"])
+    create_tensor([axes+1], name+"_axes+1", kwargs["initializer"])
+    create_tensor([], name+"_void", kwargs["initializer"])
+    create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
+    create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)
+    create_const_scalar_node(name+"_2_s", np.int64(2), kwargs)
+    create_const_scalar_node(name+"_eps", np.float32(eps), kwargs)
 
-    in_shape = kwargs['in_shape']
-    axes = [-i for i in range(len(in_shape[0]), 0, -1)]
-    eps = attrs.get('eps')
     nodes = [
-        make_node("ReduceMean", [input_nodes[0]], [name+"_rm0_out"], axes=axes),
+        make_node("ReduceMean", [input_nodes[0]], [name+"_rm0_out"], axes=[axes]),
         make_node("Sub", [input_nodes[0], name+"_rm0_out"], [name+"_sub0_out"]),
-        create_const_scalar_node(name+"_two", np.float32(2.), kwargs),
-        make_node("Pow", [name+"_sub0_out", name+"_two"], [name+"_pow0_out"]),
-        make_node("ReduceMean", [name+"_pow0_out"], [name+"_rm1_out"], axes=axes),
-        create_const_scalar_node(name+"_eps", np.float32(eps), kwargs),
+        make_node("Pow", [name+"_sub0_out", name+"_2_s"], [name+"_pow0_out"]),
+        make_node("ReduceMean", [name+"_pow0_out"], [name+"_rm1_out"], axes=[axes]),
         make_node("Add", [name+"_rm1_out", name+"_eps"], [name+"_add0_out"]),
         make_node("Sqrt", [name+"_add0_out"], [name+"_sqrt0_out"]),
         make_node("Div", [name+"_sub0_out", name+"_sqrt0_out"], [name+"_div0_out"]),
-        make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]),
-        make_node("Add", [name+"_mul0_out", input_nodes[2]], [name], name)
     ]
 
+    if axes == -1:
+        nodes += [
+            make_node("Mul", [name+"_div0_out", input_nodes[1]], [name+"_mul0_out"]),
+            make_node("Add", [name+"_mul0_out", input_nodes[2]], [name])
+        ]
+    else:
+        nodes += [
+            make_node("Shape", [input_nodes[0]], [name+"_shape0_out"]),
+            make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]),
+            make_node("Reshape", [name+"_in_dim", name+"_void"], [name+"_in_dim_s"]),
+            make_node("Range", [name+"_0_s", name+"_in_dim_s", name+"_1_s"], [name+"_range"]),
+            make_node("Equal", [name+"_range", name+"_axes"], [name+"_equal"]),
+            make_node("Cast", [name+"_equal"], [name+"_one_hot"], to=int(TensorProto.INT64)),
+            make_node("Slice", [name+"_shape0_out", name+"_axes", name+"_axes+1"], [name+"_slice_out"]),
+            make_node("Reshape", [name+"_slice_out", name+"_void"], [name+"_slice_out_s"]),
+            make_node("Sub", [name+"_slice_out_s", name+"_1_s"], [name+"_sub1_out"]),
+            make_node("Mul", [name+"_one_hot", name+"_sub1_out"], [name+"_mul0_out"]),
+            make_node("Add", [name+"_mul0_out", name+"_1_s"], [name+"_add1_out"]),
+            make_node('Reshape', [input_nodes[1], name+"_add1_out"], [name+"gamma_exp"]),
+            make_node('Reshape', [input_nodes[2], name+"_add1_out"], [name+"beta_exp"]),
+            make_node('Expand', [name+"gamma_exp", name+"_shape0_out"], [name+"gamma_exp1"]),
+            make_node('Expand', [name+"beta_exp", name+"_shape0_out"], [name+"beta_exp1"]),
+            make_node("Mul", [name+"_div0_out", name+"gamma_exp1"], [name+"_mul1_out"]),
+            make_node("Add", [name+"_mul1_out", name+"beta_exp1"], [name], name=name)
+        ]
     return nodes
 
 
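For intuition, the subgraph above encodes the textbook layer-normalization formula. A minimal NumPy sketch (not part of the patch; the helper name `layer_norm_reference` is made up here) of what the emitted ONNX nodes compute:

```python
import numpy as np

def layer_norm_reference(x, gamma, beta, axis=-1, eps=9.99999975e-06):
    # ReduceMean -> Sub: center the inputs along the normalized axis
    mean = x.mean(axis=axis, keepdims=True)
    # Pow -> ReduceMean -> Add -> Sqrt -> Div: scale by the standard deviation
    var = ((x - mean) ** 2).mean(axis=axis, keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    # gamma/beta are 1-D of length x.shape[axis]; for axis != -1 the converter's
    # Range/Equal/Reshape/Expand subgraph builds this broadcast shape dynamically
    shape = [1] * x.ndim
    shape[axis] = x.shape[axis]
    return x_hat * gamma.reshape(shape) + beta.reshape(shape)
```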
@@ -2345,6 +2374,52 @@ def convert_matmul_selfatt_qk(node, **kwargs):
 
     return nodes
 
+@mx_op.register("_contrib_interleaved_matmul_selfatt_valatt")
+def convert_contrib_interleaved_matmul_selfatt_valatt(node, **kwargs):
+    """Map MXNet's _contrib_interleaved_matmul_selfatt_valatt operator attributes to onnx's operator.
+    """
+    from onnx.helper import make_node
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+    qkv = input_nodes[0]
+    att = input_nodes[1]
+    num_heads = int(attrs.get('heads'))
+
+    create_tensor([num_heads], name+"_const_num_heads", kwargs["initializer"])
+    create_tensor([0], name+"_const_0", kwargs["initializer"])
+    create_tensor([1], name+"_const_1", kwargs["initializer"])
+    create_tensor([2], name+"_const_2", kwargs["initializer"])
+    create_tensor([3], name+"_const_3", kwargs["initializer"])
+    create_tensor([4], name+"_const_4", kwargs["initializer"])
+    create_tensor([5], name+"_const_5", kwargs["initializer"])
+    create_tensor([0, 0, num_heads, 3, -1], name+"_reshape0_shape", kwargs["initializer"])
+    create_tensor([0, 0, 0, 2, 0], name+"_slice_start", kwargs["initializer"])
+    create_tensor([0, 0, 0, -1], name+"_reshape1_shape", kwargs["initializer"])
+    create_tensor([0, 0, -1], name+"_reshape4_shape", kwargs["initializer"])
+
+    nodes = [
+        make_node("Shape", [qkv], [name+"_shape_qkv"]),
+        make_node("Slice", [name+"_shape_qkv", name+"_const_0", name+"_const_1"], [name+"_qkv_d0"]),
+        make_node("Slice", [name+"_shape_qkv", name+"_const_1", name+"_const_2"], [name+"_qkv_d1"]),
+        make_node("Slice", [name+"_shape_qkv", name+"_const_2", name+"_const_3"], [name+"_qkv_d2"]),
+        make_node('Mul', [name+"_qkv_d1", name+'_const_num_heads'], [name+'_mul_out']),
+        make_node("Reshape", [qkv, name+"_reshape0_shape"], [name+"_reshape0_output"]),
+        make_node("Shape", [name+"_reshape0_output"], [name+"_shape_reshape0"]),
+        make_node("Slice", [name+"_shape_reshape0", name+"_const_4", name+"_const_5"], [name+"_d4"]),
+        make_node("Concat", [name+"_mul_out", name+"_qkv_d0", name+"_d4"], [name+"_reshape2_shape"], axis=0),
+        make_node("Concat", [name+"_qkv_d1", name+"_const_num_heads", name+"_qkv_d0", name+"_d4"], \
+                  [name+"_reshape3_shape"], axis=0),
+        make_node("Concat", [name+"_qkv_d0", name+"_qkv_d1", name+"_qkv_d2", name+"_const_3", name+"_d4"], \
+                  [name+"_slice_end"], axis=0),
+        make_node("Slice", [name+"_reshape0_output", name+"_slice_start", name+"_slice_end"], [name+"_slice_output"]),
+        make_node("Reshape", [name+"_slice_output", name+"_reshape1_shape"], [name+"_reshape1_output"]),
+        make_node("Transpose", [name+"_reshape1_output"], [name+"_transpose0_output"], perm=[1, 2, 0, 3]),
+        make_node("Reshape", [name+"_transpose0_output", name+"_reshape2_shape"], [name+"_reshape2_output"]),
+        make_node("MatMul", [att, name+"_reshape2_output"], [name+"_matmul_output"]),
+        make_node("Reshape", [name+"_matmul_output", name+"_reshape3_shape"], [name+"_reshape3_output"]),
+        make_node("Transpose", [name+"_reshape3_output"], [name+"_transpose2_output"], perm=[2, 0, 1, 3]),
+        make_node("Reshape", [name+"_transpose2_output", name+"_reshape4_shape"], [name], name=name)
+    ]
+    return nodes
+
 @mx_op.register("broadcast_axis")
 def convert_broadcast_axis(node, **kwargs):
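The Reshape/Transpose chain in the new valatt converter is easier to follow in plain NumPy. Here is a sketch of the equivalent computation, assuming MXNet's interleaved (seq_len, batch, num_heads * 3 * head_dim) layout; the helper and variable names are illustrative:

```python
import numpy as np

def selfatt_valatt_reference(qkv, att, num_heads):
    # qkv: (seq_len, batch, num_heads * 3 * head_dim), interleaved Q/K/V
    # att: (batch * num_heads, seq_len, seq_len) attention weights
    seq_len, batch = qkv.shape[0], qkv.shape[1]
    head_dim = qkv.shape[2] // (num_heads * 3)
    # Reshape0 + Slice: keep only the value projection (index 2 on the Q/K/V axis)
    v = qkv.reshape(seq_len, batch, num_heads, 3, head_dim)[:, :, :, 2, :]
    # Transpose0 + Reshape2: fold batch and heads into one batched-MatMul axis
    v = v.transpose(1, 2, 0, 3).reshape(batch * num_heads, seq_len, head_dim)
    # MatMul: (B*H, S, S) @ (B*H, S, D) -> (B*H, S, D)
    out = att @ v
    # Reshape3 + Transpose2 + Reshape4: back to (seq_len, batch, heads * head_dim)
    out = out.reshape(batch, num_heads, seq_len, head_dim).transpose(2, 0, 1, 3)
    return out.reshape(seq_len, batch, num_heads * head_dim)
```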
24 changes: 17 additions & 7 deletions tests/python-pytest/onnx/test_operators.py
@@ -100,13 +100,16 @@ def test_onnx_export_arange_like(tmp_path, dtype, axis, start, step, test_data):
     x = mx.nd.array(test_data, dtype=dtype)
     op_export_test('arange_like', M, [x], tmp_path)
 
-
-def test_onnx_export_layernorm(tmp_path):
-    M = def_model('LayerNorm', axis=1)
-    x = mx.nd.array([[1,3],[2,4]], dtype='float32')
-    gamma = mx.random.uniform(0, 1, x[0].shape, dtype='float32')
-    beta = mx.random.uniform(0, 1, x[0].shape, dtype='float32')
-    op_export_test('LayerNorm', M, [x, gamma, beta], tmp_path)
+@pytest.mark.parametrize('dtype', ['float32'])
+def test_onnx_export_layernorm(tmp_path, dtype):
+    x = mx.nd.random.uniform(1, 2, (3, 4, 5), dtype=dtype)
+    axes = list(range(np.shape(np.shape(x))[0]))
+    axes.append(-1)
+    for axis in axes:
+        M = def_model('LayerNorm', axis=axis)
+        gamma = mx.random.uniform(0, 1, [np.shape(x)[axis]], dtype=dtype)
+        beta = mx.random.uniform(0, 1, [np.shape(x)[axis]], dtype=dtype)
+        op_export_test('LayerNorm', M, [x, gamma, beta], tmp_path)
 
 
 @pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32'])
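A note on the rewritten LayerNorm test: `np.shape(np.shape(x))[0]` is simply the rank of x, so the loop sweeps every axis of the (3, 4, 5) input plus the default -1, exercising both the axis == -1 fast path and the Reshape/Expand path in the converter. Restated plainly:

```python
x_shape = (3, 4, 5)
axes = list(range(len(x_shape))) + [-1]          # [0, 1, 2, -1]
# gamma/beta are 1-D with the length of the normalized axis: 3, 4, 5, then 5
lengths = [x_shape[axis] for axis in axes]
```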
@@ -146,6 +149,13 @@ def test_onnx_export_contrib_interleaved_matmul_selfatt_qk(tmp_path, dtype):
     x2 = mx.nd.random.uniform(0, 1, (7, 5, 4*5*6))
     op_export_test('contrib_interleaved_matmul_selfatt_qk_2', M2, [x2], tmp_path)
 
+@pytest.mark.parametrize('dtype', ['float32'])
+def test_onnx_export_contrib_interleaved_matmul_selfatt_valatt(tmp_path, dtype):
+    M = def_model('contrib.interleaved_matmul_selfatt_valatt', heads=6)
+    x = mx.nd.random.uniform(0, 1, (4, 5, 6*7*3), dtype=dtype)
+    att = mx.nd.random.uniform(0, 1, (5*6, 4, 4), dtype=dtype)
+    op_export_test('contrib_interleaved_matmul_selfatt_valatt', M, [x, att], tmp_path)
+
 
 @pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
 @pytest.mark.parametrize('num_hidden', [1, 5, 10, 20])
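The shapes in the new valatt test decode as seq_len=4, batch=5, heads=6, head_dim=7 (my reading of the test inputs, not stated in the patch):

```python
seq_len, batch, heads, head_dim = 4, 5, 6, 7
qkv_shape = (seq_len, batch, heads * 3 * head_dim)  # (4, 5, 126), matches 6*7*3
att_shape = (batch * heads, seq_len, seq_len)       # (30, 4, 4), matches (5*6, 4, 4)
out_shape = (seq_len, batch, heads * head_dim)      # expected output: (4, 5, 42)
```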