196 changes: 196 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -177,6 +177,20 @@ def create_const_node(input_name, value, kwargs):
    initializer.append(tensor_node)
    return value_node

def create_tensor(shape_list, shape_name, initializer, dtype='int64'):
    """Helper to create a constant tensor and append it to the graph initializer."""
    shape_np = np.array(shape_list, dtype=dtype)
    data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[shape_np.dtype]
    dims = np.shape(shape_np)
    initializer.append(
        onnx.helper.make_tensor(
            name=shape_name,
            data_type=data_type,
            dims=dims,
            vals=shape_list,
            raw=False,
        )
    )
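
# (Editor's note: illustrative only, not part of this PR.) create_tensor
# registers a named constant in the graph initializer; the converters below
# lean on it for the small int64 shape/index tensors that ONNX ops such as
# Slice, Concat, Range and Expand consume as inputs. For example,
#     create_tensor([0], some_name + "_0", kwargs["initializer"])
# makes the 1-element int64 tensor [0] addressable as some_name + "_0".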

@mx_op.register("null")
def convert_weights_and_inputs(node, **kwargs):
"""Helper function to convert weights and inputs.
@@ -1556,7 +1570,16 @@ def convert_reshape(node, **kwargs):
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

reverse = attrs.get('reverse', 'False')
output_shape_list = convert_string_to_list(attrs["shape"])
data_shape = list(kwargs['in_shape'][0])
if reverse == 'True':
output_shape_list.reverse()
data_shape.reverse()
for i, dim in enumerate(output_shape_list):
if dim == 0:
output_shape_list[i] = data_shape[i]
output_shape_list.reverse()

initializer = kwargs["initializer"]
output_shape_np = np.array(output_shape_list, dtype='int64')
@@ -2280,6 +2303,179 @@ def convert_layer_norm(node, **kwargs):
    return nodes


@mx_op.register("_contrib_interleaved_matmul_selfatt_qk")
def convert_matmul_selfatt_qk(node, **kwargs):
"""Map MXNet's _contrib_interleaved_matmul_selfatt_qk operator
"""
from onnx.helper import make_node
from onnx import TensorProto
name, input_nodes, attrs = get_inputs(node, kwargs)

heads = int(attrs.get('heads'))

# a, b, c, d, e are seq_len, batch_size, num_heads, 3, head_dim respectively
create_tensor([0], name+"_0", kwargs["initializer"])
create_tensor([1], name+"_1", kwargs["initializer"])
create_tensor([1], name+"_1_f", kwargs["initializer"], dtype='float32')
create_tensor([2], name+"_2", kwargs["initializer"])
create_tensor([3], name+"_3", kwargs["initializer"])
create_tensor([heads], name+"_c", kwargs["initializer"])
create_tensor([3], name+"_d", kwargs["initializer"])

nodes = [
make_node('Shape', [input_nodes[0]], [name+"_data_shape"]),
make_node('Slice', [name+'_data_shape', name+'_0', name+'_1'], [name+"_a"]),
make_node('Slice', [name+'_data_shape', name+'_1', name+'_2'], [name+"_b"]),
make_node('Slice', [name+'_data_shape', name+'_2', name+'_3'], [name+"_cde"]),
make_node('Div', [name+'_cde', name+'_c'], [name+'_de']),
make_node('Div', [name+'_de', name+'_d'], [name+'_e']),
make_node('Cast', [name+'_e'], [name+'_e_f'], to=int(TensorProto.FLOAT)),
make_node('Sqrt', [name+'_e_f'], [name+'_sqrt_e']),
make_node('Div', [name+'_1_f', name+'_sqrt_e'], [name+'_1_over_sqrt_e']),
make_node('Mul', [name+'_b', name+'_c'], [name+'_bc']),

make_node("Concat", [name+'_a', name+'_b', name+'_c', name+'_d', name+'_e'], \
[name+'_shape0'], axis=0),
make_node("Concat", [name+'_0', name+'_0', name+'_0', name+'_0', name+'_0'], \
[name+'_slice_start0'], axis=0),
make_node("Concat", [name+'_a', name+'_b', name+'_c', name+'_1', name+'_e'], \
[name+'_slice_end0'], axis=0),
make_node("Concat", [name+'_a', name+'_b', name+'_c', name+'_e'], \
[name+'_shape1'], axis=0),
make_node("Concat", [name+'_bc', name+'_a', name+'_e'], \
[name+'_shape2'], axis=0),
make_node("Concat", [name+'_0', name+'_0', name+'_0', name+'_1', name+'_0'], \
[name+'_slice_start1'], axis=0),
make_node("Concat", [name+'_a', name+'_b', name+'_c', name+'_2', name+'_e'], \
[name+'_slice_end1'], axis=0),

make_node('Reshape', [input_nodes[0], name+'_shape0'], [name+'_reshape0_out']),
make_node('Slice', [name+'_reshape0_out', name+'_slice_start0', name+'_slice_end0'], \
[name+'_slice0_out']),
make_node('Reshape', [name+'_slice0_out', name+'_shape1'], [name+'_reshape1_out']),
make_node('Transpose', [name+'_reshape1_out'], [name+'_transpose0_out'], \
perm=(1, 2, 0, 3)),
make_node('Reshape', [name+'_transpose0_out', name+'_shape2'], [name+'_reshape2_out']),
make_node('Mul', [name+'_reshape2_out', name+'_1_over_sqrt_e'], [name+'_mul0_out']),
make_node('Slice', [name+'_reshape0_out', name+'_slice_start1', name+'_slice_end1'], \
[name+'_slice1_out']),
make_node('Reshape', [name+'_slice1_out', name+'_shape1'], [name+'_reshape3_out']),
make_node('Transpose', [name+'_reshape3_out'], [name+'_transpose1_out'], \
perm=(1, 2, 0, 3)),
make_node('Reshape', [name+'_transpose1_out', name+'_shape2'], [name+'_reshape4_out']),
make_node('Transpose', [name+'_reshape4_out'], [name+'_transpose2_out'], \
perm=(0, 2, 1)),
make_node('MatMul', [name+'_mul0_out', name+'_transpose2_out'], [name], name=name)
]

return nodes

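# (Editor's note: illustrative only, not part of this PR.) A NumPy sketch of
# what the subgraph above computes, assuming the documented MXNet semantics:
# the input packs the Q/K/V projections as (seq_len, batch, heads*3*head_dim)
# and the output is the scaled score matrix Q @ K^T of shape
# (batch*heads, seq_len, seq_len).
import numpy as np

def interleaved_matmul_selfatt_qk_ref(data, heads):
    seq_len, batch, hidden = data.shape
    head_dim = hidden // heads // 3
    qkv = data.reshape(seq_len, batch, heads, 3, head_dim)
    q = qkv[:, :, :, 0, :].transpose(1, 2, 0, 3).reshape(batch*heads, seq_len, head_dim)
    k = qkv[:, :, :, 1, :].transpose(1, 2, 0, 3).reshape(batch*heads, seq_len, head_dim)
    return (q / np.sqrt(head_dim)) @ k.transpose(0, 2, 1)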

@mx_op.register("broadcast_axis")
def convert_broadcast_axis(node, **kwargs):
"""Map MXNet's broadcast_axis
"""
from onnx.helper import make_node
from onnx import TensorProto
name, input_nodes, attrs = get_inputs(node, kwargs)

axis = convert_string_to_list(attrs.get('axis', '()'))
size = convert_string_to_list(attrs.get('size', '()'))
assert len(axis) == len(size)

create_tensor([0], name+'_0', kwargs["initializer"])
create_tensor([1], name+'_1', kwargs["initializer"])
create_tensor([], name+'_void', kwargs["initializer"])
create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)

shape_name = name+'_shape_0'
nodes = [
make_node('Shape', [input_nodes[0]], [shape_name]),
make_node('Shape', [shape_name], [name+'_in_dim']),
make_node('Reshape', [name+'_in_dim', name+'_void'], [name+'_in_dim_s']),
make_node('Range', [name+'_0_s', name+'_in_dim_s', name+'_1_s'], [name+'_range']),
]

for i, axis in enumerate(axis):
if axis not in (0, 1):
create_tensor([axis], name+'_'+str(axis), kwargs["initializer"])
create_tensor([size[i]-1], name+'_size_'+str(i), kwargs["initializer"])
_ = [
make_node('Equal', [name+'_range', name+'_'+str(axis)], [name+'_equal_'+str(i)]),
make_node('Cast', [name+'_equal_'+str(i)], [name+'_cast_'+str(i)], to=int(TensorProto.INT64)),
make_node('Mul', [name+'_size_'+str(i), name+'_cast_'+str(i)], [name+'_mul_'+str(i)]),
make_node('Add', [name+'_mul_'+str(i), name+'_1'], [name+'_add_'+str(i)]),
make_node('Mul', [name+'_add_'+str(i), shape_name], [name+'_shape_'+str(i+1)])
]
shape_name = name+'_shape_'+str(i+1)
nodes += _

nodes += [make_node('Expand', [input_nodes[0], shape_name], [name], name=name)]

return nodes

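# (Editor's note: illustrative only, not part of this PR.) The loop above
# builds the Expand target shape with pure arithmetic instead of scatter ops:
# for each broadcast axis, ((size - 1) * (arange(ndim) == axis) + 1) * shape
# rewrites that one entry and leaves the rest untouched. A NumPy rendering:
import numpy as np

in_shape = np.array([1, 2, 1], dtype=np.int64)
axis, size = 2, 4
mask = (np.arange(in_shape.size) == axis).astype(np.int64)  # [0, 0, 1]
out_shape = ((size - 1) * mask + 1) * in_shape              # [1, 2, 4]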

@mx_op.register("SequenceMask")
def convert_sequencemask(node, **kwargs):
"""Map MXNet's SequenceMask operator
"""
from onnx.helper import make_node
from onnx import TensorProto

name, input_nodes, attrs = get_inputs(node, kwargs)

use_sequence_length = attrs.get('use_sequence_length', 'False')
mask_val = float(attrs.get('value', '0'))
axis = int(attrs.get('axis', '0'))

if(use_sequence_length == 'False'):
return [make_node('Identity', [input_nodes[0]], [name], name=name)]

create_tensor([], name+'_void', kwargs["initializer"])
create_tensor([0], name+'_0', kwargs["initializer"])
create_tensor([1], name+'_1', kwargs["initializer"])
create_tensor([2], name+'_2', kwargs["initializer"])
create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)
create_const_scalar_node(name+'_2_s', np.int64(2), kwargs)
create_tensor([mask_val], name+'_mask_val', kwargs["initializer"], dtype='float32')

nodes = [
make_node('Shape', [input_nodes[0]], [name+'_in_shape']),
make_node('Slice', [name+'_in_shape', name+'_0', name+'_1'], [name+'_slice_0']),
make_node('Slice', [name+'_in_shape', name+'_1', name+'_2'], [name+'_slice_1']),
make_node('Concat', [name+'_slice_0', name+'_1'], [name+'_shape_0'], axis=0),
make_node('Shape', [name+'_in_shape'], [name+'_in_dim']),
make_node('Reshape', [name+'_in_dim', name+'_void'], [name+'_in_dim_s']),
make_node('Range', [name+'_0_s', name+'_in_dim_s', name+'_1_s'], [name+'_range_0']),
make_node('Less', [name+'_range_0', name+'_2'], [name+'_less_0']),
make_node('Where', [name+'_less_0', name+'_in_shape', name+'_1'], [name+'_shape_1'])
]

if(axis == 0):
nodes += [
make_node('Reshape', [name+'_slice_0', name+'_void'], [name+'_max_len']),
make_node('Range', [name+'_0_s', name+'_max_len', name+'_1_s'], [name+'_range_1']),
make_node('Reshape', [name+'_range_1', name+'_shape_0'], [name+"_reshape_0"]),
make_node('Cast', [input_nodes[1]], [name+'_cast'], to=int(TensorProto.INT64)),
make_node('Less', [name+'_reshape_0', name+'_cast'], [name+'_less_1']),
make_node('Reshape', [name+'_less_1', name+'_shape_1'], [name+"_reshape_1"]),
make_node('Where', [name+'_reshape_1', input_nodes[0], name+'_mask_val'], [name], name=name),
]
else:
nodes += [
make_node('Reshape', [name+'_slice_1', name+'_void'], [name+'_max_len']),
make_node('Range', [name+'_0_s', name+'_max_len', name+'_1_s'], [name+'_range_1']),
make_node('Reshape', [input_nodes[1], name+'_shape_0'], [name+"_reshape_0"]),
make_node('Cast', [name+"_reshape_0"], [name+'_cast'], to=int(TensorProto.INT64)),
make_node('Less', [name+'_range_1', name+'_cast'], [name+'_less_1']),
make_node('Reshape', [name+'_less_1', name+'_shape_1'], [name+"_reshape_1"]),
make_node('Where', [name+'_reshape_1', input_nodes[0], name+'_mask_val'], [name], name=name),
]
return nodes

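# (Editor's note: illustrative only, not part of this PR.) A NumPy reference
# for the masking semantics the subgraph above implements, assuming MXNet's
# SequenceMask contract: along `axis` (0 or 1), positions at or beyond each
# sequence's length are replaced with `value`; seq_len has shape (batch,).
import numpy as np

def sequence_mask_ref(data, seq_len, axis=0, value=0.0):
    max_len = data.shape[axis]
    keep = np.arange(max_len)[:, None] < seq_len[None, :]  # (max_len, batch)
    if axis == 1:
        keep = keep.T                                      # (batch, max_len)
    keep = keep.reshape(keep.shape + (1,) * (data.ndim - 2))
    return np.where(keep, data, value)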

@mx_op.register("Embedding")
def convert_embedding(node, **kwargs):
"""Map MXNet's Embedding operator attributes to onnx's
39 changes: 38 additions & 1 deletion tests/python-pytest/onnx/test_operators.py
@@ -90,7 +90,7 @@ def test_onnx_export_zeros_like(tmp_path):
    op_export_test('zeros_like', M, [x], tmp_path)


@pytest.mark.parametrize("dtype", ["float32", "double"])
@pytest.mark.parametrize("dtype", ["float32", "float64"])
def test_onnx_export_arange_like(tmp_path, dtype):
    M = def_model('contrib.arange_like')
    x = mx.nd.array([[-2,-1,0],[0,50,99],[4,5,6],[7,8,9]], dtype=dtype)
@@ -104,3 +104,40 @@ def test_onnx_export_layernorm(tmp_path):
    beta = mx.random.uniform(0, 1, x[0].shape, dtype='float32')
    op_export_test('LayerNorm', M, [x, gamma, beta], tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32'])
def test_onnx_export_broadcast_axis(tmp_path, dtype):
    M1 = def_model('broadcast_axis', axis=(0, 2), size=(3, 4))
    M2 = def_model('broadcast_axis', axis=(0, 2), size=(1, 5))
    x1 = mx.nd.array([[[1], [2]]], dtype=dtype)
    op_export_test('broadcast_axis_1', M1, [x1], tmp_path)
    op_export_test('broadcast_axis_2', M2, [x1], tmp_path)
    M3 = def_model('broadcast_axis', axis=(1, 4), size=(3, 5))
    x2 = mx.nd.ones((1, 1, 3, 1, 1, 1), dtype=dtype)
    op_export_test('broadcast_axis_3', M3, [x2], tmp_path)


@pytest.mark.parametrize('dtype', ['float32'])
def test_onnx_export_SequenceMask(tmp_path, dtype):
    M1 = def_model('SequenceMask', use_sequence_length=True, axis=1, value=-5)
    M2 = def_model('SequenceMask', use_sequence_length=True, axis=0, value=-99)
    x = mx.nd.array([[[[  1.,   2.,   3.,  3.5]],
                      [[  4.,   5.,   6.,  6.5]]],
                     [[[  7.,   8.,   9.,  9.5]],
                      [[ 10.,  11.,  12., 12.5]]],
                     [[[ 13.,  14.,  15., 15.5]],
                      [[ 16.,  17.,  18., 18.5]]]], dtype=dtype)
    seq_len1 = mx.nd.array([1, 2, 1], dtype=dtype)
    seq_len2 = mx.nd.array([1, 2], dtype=dtype)
    op_export_test('SequenceMask_1', M1, [x, seq_len1], tmp_path)
    op_export_test('SequenceMask_2', M2, [x, seq_len2], tmp_path)


# dtype was previously parametrized but never used (the inputs were always
# float32); thread it through and keep only the case the data exercised
@pytest.mark.parametrize('dtype', ['float32'])
def test_onnx_export_contrib_interleaved_matmul_selfatt_qk(tmp_path, dtype):
    M1 = def_model('contrib.interleaved_matmul_selfatt_qk', heads=3)
    x1 = mx.nd.random.uniform(0, 1, (3, 3, 3*3*3), dtype=dtype)
    op_export_test('contrib_interleaved_matmul_selfatt_qk_1', M1, [x1], tmp_path)
    M2 = def_model('contrib.interleaved_matmul_selfatt_qk', heads=5)
    x2 = mx.nd.random.uniform(0, 1, (7, 5, 4*5*6), dtype=dtype)
    op_export_test('contrib_interleaved_matmul_selfatt_qk_2', M2, [x2], tmp_path)
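
# (Editor's note: illustrative only, not part of this PR.) Shape sanity check
# for the second case above: x2 is (seq_len=7, batch=5, hidden=4*5*6=120) with
# heads=5, so head_dim = 120 // 5 // 3 = 8 and the exported graph produces
# scores of shape (batch*heads, seq_len, seq_len) = (25, 7, 7).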