From 34d4a4667ad1037f4ea2e3269ae503dde2490b7b Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 26 Jan 2021 00:49:55 +0000 Subject: [PATCH 1/2] swap axis --- .../contrib/onnx/mx2onnx/_op_translations.py | 49 ++++++++++++++++++- tests/python-pytest/onnx/test_operators.py | 11 +++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index a1d5320d9f94..0445ed0640d4 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -186,13 +186,13 @@ def create_tensor(tensor_list, tensor_name, initializer, dtype='int64'): dims = np.shape(tensor_np) tensor_node = onnx.helper.make_tensor_value_info(tensor_name, data_type, dims) if dtype == np.float16: - tensor_list = tensor_np.view(dtype=np.uint16).flatten().tolist() + tensor_np = tensor_np.view(dtype=np.uint16) initializer.append( onnx.helper.make_tensor( name=tensor_name, data_type=data_type, dims=dims, - vals=tensor_list, + vals=tensor_np.flatten().tolist(), raw=False ) ) @@ -3275,3 +3275,48 @@ def convert_gather_nd(node, **kwargs): ] return nodes + + +@mx_op.register('SwapAxis') +def convert_reshape_like(node, **kwargs): + """Map MXNet's SwapAxis operator + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + dim1 = int(attrs.get('dim1', '0')) + dim2 = int(attrs.get('dim2', '0')) + + if dim1 < 0 or dim2 < 0: + raise NotImplementedError('SwapAxis conversion does not support dim1 < 0\ + or dim2 < 0') + + indices = [[dim1], [dim2]] + vals = [dim2, dim1] + perm = [i for i in range(10)] + perm[dim1], perm[dim2] = dim2, dim1 + + nodes = [ + create_tensor(indices, name+'_ind', kwargs['initializer']), + create_tensor(indices[::-1], name+'_ind_rev', kwargs['initializer']), + create_tensor(vals, name+'_vals', kwargs['initializer']), + create_tensor(perm, name+'_perm', kwargs['initializer']), + create_tensor([0], name+'_0', kwargs['initializer']), + create_tensor([1], name+'_1', kwargs['initializer']), + create_tensor([10], name+'_10', kwargs['initializer']), + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Shape', [name+'_shape'], [name+'_dim']), + make_node('Sub', [name+'_10', name+'_dim'], [name+'_sub']), + make_node('ScatterND', [name+'_perm', name+'_ind', name+'_vals'], + [name+'_perm_new']), + make_node('GatherND', [name+'_shape', name+'_ind'], [name+'_gather']), + make_node('ScatterND', [name+'_shape', name+'_ind_rev', name+'_gather'], + [name+'_shape_new']), + make_node('Concat', [name+'_0', name+'_sub'], [name+'_pad'], axis=0), + make_node('Pad', [name+'_shape', name+'_pad', name+'_1'], [name+'_shape_padded']), + make_node('Reshape', [input_nodes[0], name+'_shape_padded'], [name+'_data_padded']), + make_node('Transpose', [name+'_data_padded'], [name+'_trans'], perm=perm), + make_node('Reshape', [name+'_trans', name+'_shape_new'], [name]) + ] + + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 2c363a96ba04..81deb01fb13e 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -539,3 +539,14 @@ def test_onnx_export_gather_nd(tmp_path, dtype): M2 = def_model('gather_nd') op_export_test('gather_nd2', M2, [x2, y2], tmp_path) + +@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) +@pytest.mark.parametrize('params', [((4, 5, 6), (0, 2)), ((4, 5, 6), (0, 1)), + ((1, 2, 3, 4, 1), (0, 4)), + ((4, 5, 1, 6), (0, 2))]) +def test_onnx_export_swap_axis(tmp_path, dtype, params): + shape = params[0] + dim1, dim2 = params[1] + x = mx.random.uniform(-100, 100, shape).astype(dtype) + M = def_model('SwapAxis', dim1=dim1, dim2=dim2) + op_export_test('SwapAxis', M, [x], tmp_path) From c6796276523fb6691db3068d1bb91fdbd204d0a8 Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Tue, 26 Jan 2021 16:40:20 -0800 Subject: [PATCH 2/2] Update _op_translations.py --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 0445ed0640d4..bd9ba35c7475 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3278,7 +3278,7 @@ def convert_gather_nd(node, **kwargs): @mx_op.register('SwapAxis') -def convert_reshape_like(node, **kwargs): +def convert_swapaxis(node, **kwargs): """Map MXNet's SwapAxis operator """ from onnx.helper import make_node