From 879b585d2818604b28a2c6247cef8d99306bcb4c Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 18 Dec 2020 23:35:49 +0000 Subject: [PATCH 1/2] fix slice --- .../contrib/onnx/mx2onnx/_op_translations.py | 61 +++++++++++++------ tests/python-pytest/onnx/test_operators.py | 10 ++- 2 files changed, 49 insertions(+), 22 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 0a9d9aac71e4..c7636f4dda96 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -113,7 +113,9 @@ def convert_string_to_list(string_val): val = val.replace("L", "") val = val.replace("[", "") val = val.replace("]", "") - if val not in ("", "None"): + if val == "None": + result_list.append(None) + elif val != "": result_list.append(int(val)) return result_list @@ -2516,17 +2518,17 @@ def convert_sequencemask(node, **kwargs): def convert_embedding(node, **kwargs): """Map MXNet's Embedding operator attributes to onnx's Gather operator.""" + from onnx.helper import make_node + from onnx import TensorProto + name, input_nodes, attrs = get_inputs(node, kwargs) axis = int(attrs.get('axis', 0)) - node = onnx.helper.make_node( - "Gather", - [input_nodes[1], input_nodes[0]], - [name], - axis=axis, - name=name - ) - return [node] + nodes = [ + make_node('Cast', [input_nodes[0]], [name+'_indices_casted'], to=int(TensorProto.INT64)), + make_node('Gather', [input_nodes[1], name+'_indices_casted'], [name], axis=axis, name=name) + ] + return nodes @mx_op.register("stack") def convert_stack(node, **kwargs): @@ -2558,19 +2560,40 @@ def convert_stack(node, **kwargs): @mx_op.register("slice") def convert_slice(node, **kwargs): """Map MXNet's slice operator to onnx Slice operator.""" + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) - starts = convert_string_to_list(attrs.get("begin")) - ends = convert_string_to_list(attrs.get("end")) - steps = attrs.get("step", []) + + starts = convert_string_to_list(attrs.get('begin')) + ends = convert_string_to_list(attrs.get('end')) + steps = convert_string_to_list(attrs.get('step','[]')) + + assert len(starts) == len(ends) + if len(steps) == 0 or (len(steps) == 1 and steps[0] == None): + steps = [1 for x in starts] + else: + assert len(steps) == len(starts) + steps = [1 if x is None else x for x in steps] + for i, s in enumerate(steps): + if s < 0: + raise NotImplementedError('slice operator does not support negative steps yet') + if starts[i] is None: + starts[i] = 0 + if ends[i] is None: + ends[i] = 2**63-1 + nodes = [ - create_const_node(name+"_begin", np.array(starts), kwargs), - create_const_node(name+"_end", np.array(ends), kwargs) + create_const_scalar_node(name+'_0_s', np.int64(0), kwargs), + create_const_scalar_node(name+'_1_s', np.int64(1), kwargs), + create_const_scalar_node(name+'_len_s', np.int64(len(starts)), kwargs), + make_node('Range', [name+'_0_s', name+'_len_s', name+'_1_s'], [name+'_axes']), + create_tensor(starts, name+'_starts', kwargs['initializer']), + create_tensor(ends, name+'_ends', kwargs['initializer']), + create_tensor(steps, name+'_steps', kwargs['initializer']), + make_node("Slice", [input_nodes[0], name+'_starts', name+'_ends', name+'_axes', + name+'_steps'], [name], name=name) ] - inputs = [input_nodes[0], name+"_begin", name+"_end"] - if len(steps) > 0: - nodes.append(create_const_node(name+"_steps", np.array(steps, dtype='int64'), kwargs)) - inputs.append(name+"_steps") - nodes.append(onnx.helper.make_node("Slice", inputs, [name], name=name)) + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 87554ed795a8..8d1b40a66e84 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -71,9 +71,13 @@ def test_onnx_export_abs(tmp_path): op_export_test('abs', M, [x], tmp_path) -def test_onnx_export_slice(tmp_path): - M = def_model('slice', begin=(0,1), end=(2,4)) - x = mx.nd.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]], dtype='float32') +@pytest.mark.parametrize('dtype', ['float32', 'float64', 'float16', 'int32', 'int64']) +@pytest.mark.parametrize('params', [[(0, 1), (2,3), (1, 1)], + [(None, 1), (2, None), None], + [(0, 0, 0), (None, 4, 5), (None, 1, 2)]]) +def test_onnx_export_slice(tmp_path, dtype, params): + M = def_model('slice', begin=params[0], end=params[1], step=params[2]) + x = mx.nd.arange(start=0, stop=60, dtype=dtype).reshape((3, 4, 5)) op_export_test('slice', M, [x], tmp_path) From 51c653c230b4cc8ad7c870e88b49fe5277a97632 Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Fri, 18 Dec 2020 20:51:36 -0800 Subject: [PATCH 2/2] Update _op_translations.py --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index c7636f4dda96..8ef820b7989e 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2566,10 +2566,10 @@ def convert_slice(node, **kwargs): starts = convert_string_to_list(attrs.get('begin')) ends = convert_string_to_list(attrs.get('end')) - steps = convert_string_to_list(attrs.get('step','[]')) + steps = convert_string_to_list(attrs.get('step', '[]')) assert len(starts) == len(ends) - if len(steps) == 0 or (len(steps) == 1 and steps[0] == None): + if len(steps) == 0 or (len(steps) == 1 and steps[0] is None): steps = [1 for x in starts] else: assert len(steps) == len(starts)