diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 9a6f290cdf80..843d5e2a2873 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3400,6 +3400,50 @@ def convert_gather_nd(node, **kwargs): return nodes +@mx_op.register('SwapAxis') +def convert_swapaxis(node, **kwargs): + """Map MXNet's SwapAxis operator + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + dim1 = int(attrs.get('dim1', '0')) + dim2 = int(attrs.get('dim2', '0')) + + if dim1 < 0 or dim2 < 0: + raise NotImplementedError('SwapAxis conversion does not support dim1 < 0\ + or dim2 < 0') + + indices = [[dim1], [dim2]] + vals = [dim2, dim1] + perm = [i for i in range(10)] + perm[dim1], perm[dim2] = dim2, dim1 + + nodes = [ + create_tensor(indices, name+'_ind', kwargs['initializer']), + create_tensor(indices[::-1], name+'_ind_rev', kwargs['initializer']), + create_tensor(vals, name+'_vals', kwargs['initializer']), + create_tensor(perm, name+'_perm', kwargs['initializer']), + create_tensor([0], name+'_0', kwargs['initializer']), + create_tensor([1], name+'_1', kwargs['initializer']), + create_tensor([10], name+'_10', kwargs['initializer']), + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Shape', [name+'_shape'], [name+'_dim']), + make_node('Sub', [name+'_10', name+'_dim'], [name+'_sub']), + make_node('ScatterND', [name+'_perm', name+'_ind', name+'_vals'], + [name+'_perm_new']), + make_node('GatherND', [name+'_shape', name+'_ind'], [name+'_gather']), + make_node('ScatterND', [name+'_shape', name+'_ind_rev', name+'_gather'], + [name+'_shape_new']), + make_node('Concat', [name+'_0', name+'_sub'], [name+'_pad'], axis=0), + make_node('Pad', [name+'_shape', name+'_pad', name+'_1'], [name+'_shape_padded']), + make_node('Reshape', [input_nodes[0], name+'_shape_padded'], [name+'_data_padded']), + make_node('Transpose', [name+'_data_padded'], [name+'_trans'], perm=perm), + make_node('Reshape', [name+'_trans', name+'_shape_new'], [name]) + ] + + return nodes + @mx_op.register('slice_like') def convert_slice_like(node, **kwargs): diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 74e850be995e..038541abe953 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -595,6 +595,18 @@ def test_onnx_export_gather_nd(tmp_path, dtype): op_export_test('gather_nd2', M2, [x2, y2], tmp_path) +@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) +@pytest.mark.parametrize('params', [((4, 5, 6), (0, 2)), ((4, 5, 6), (0, 1)), + ((1, 2, 3, 4, 1), (0, 4)), + ((4, 5, 1, 6), (0, 2))]) +def test_onnx_export_swap_axis(tmp_path, dtype, params): + shape = params[0] + dim1, dim2 = params[1] + x = mx.random.uniform(-100, 100, shape).astype(dtype) + M = def_model('SwapAxis', dim1=dim1, dim2=dim2) + op_export_test('SwapAxis', M, [x], tmp_path) + + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) @pytest.mark.parametrize('axes', [None, (0, 1, 2), (-2, -3), (-2, 0)]) def test_onnx_export_slice_like(tmp_path, dtype, axes):