From 84c9ae519b8da7628c5bfce57a6d9dc38f8fdf43 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 7 Sep 2021 19:16:46 -0600 Subject: [PATCH] support slicing with out of order axes --- python/tvm/relay/frontend/onnx.py | 32 ++++++++-------------- tests/python/frontend/onnx/test_forward.py | 2 ++ 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 29221884702c..1a3bf09b164f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1382,20 +1382,13 @@ class Slice(OnnxOpConverter): @classmethod def _common(cls, starts, ends, axes): - new_axes = [] - new_starts = [] - new_ends = [] - pop_index = 0 - for i in range(max(axes) + 1): - if i in axes: - new_axes.append(i) - new_starts.append(starts[pop_index]) - new_ends.append(ends[pop_index]) - pop_index += 1 - else: - new_axes.append(i) - new_starts.append(0) - new_ends.append(np.iinfo(np.int32).max) + N = max(axes) + 1 + new_axes = list(range(N)) + new_starts = [0] * N + new_ends = [np.iinfo(np.int32).max] * N + for i, axis in enumerate(axes): + new_starts[axis] = starts[i] + new_ends[axis] = ends[i] return new_starts, new_ends, new_axes @classmethod @@ -1408,13 +1401,10 @@ def _impl_v1(cls, inputs, attr, params): # Update the starts and ends according to axes if required. if isinstance(attr["axes"], int): attr["axes"] = (attr["axes"],) - if (max(attr["axes"]) + 1) != len(attr["axes"]): - new_starts, new_ends, new_axes = cls._common( - attr["starts"], attr["ends"], attr["axes"] - ) - attr["axes"] = new_axes - attr["starts"] = new_starts - attr["ends"] = new_ends + new_starts, new_ends, new_axes = cls._common(attr["starts"], attr["ends"], attr["axes"]) + attr["axes"] = new_axes + attr["starts"] = new_starts + attr["ends"] = new_ends except KeyError: pass begin = list(attr["starts"]) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 9e3b48de8764..687dbeaefb4b 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -885,10 +885,12 @@ def add_noop_to_input_attr(attr_name, attr): x = np.random.randn(20, 10, 5).astype(np.float32) _test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1)) + _test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(10, 3), axes=(1, 0)) _test_slice_iteration_v1(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4)) _test_slice_iteration_v1(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,)) _test_slice_iteration_v1(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,)) _test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1)) + _test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(10, 3), axes=(1, 0)) _test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4)) _test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,)) _test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,))