From 2986cf7c30330a3f98eb09a3690d56f418af2954 Mon Sep 17 00:00:00 2001 From: padreofthegame Date: Fri, 2 Sep 2022 08:59:13 +0200 Subject: [PATCH] [Relay] Bug fix in relay.squeeze function. Also added functionality for parameter axis of type int --- python/tvm/relay/op/transform.py | 26 +++++++++++++++++++++----- tests/python/relay/test_op_level3.py | 7 ++++++- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index e7ae5f7d8315..024da84cbfd8 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -204,23 +204,39 @@ def squeeze(data, axis=None): Parameters ---------- - data : tvm.relay.Expr + data : relay.Expr The input data to the operator. - axis : None or List[int] or Expr + axis : Union[None, int, Tuple[int], List[int]] or Expr The set of axes to remove. - If axis = None, remove all axis of dimensions 1. + If axis = None, remove all axes of dimension 1. If any specified axis has dimension that does not equal 1, it is an error. Returns ------- - result : tvm.relay.Expr + result : relay.Expr The squeezed result. """ if isinstance(axis, Constant): - axis = list(axis.data.numpy()) + if axis.data.shape: + axis = list(axis.data.numpy()) + else: + axis = [axis.data.numpy().item()] if isinstance(axis, Expr): return _dyn_make.squeeze(data, axis) + if isinstance(axis, int): + axis = [axis] + if isinstance(axis, (tuple, list)): + tempaxis = [] + for tmpax in axis: + if isinstance(tmpax, _expr.IntImm): + tempaxis.append(tmpax.value) + else: + try: + tempaxis.append(int(tmpax)) + except ValueError as err: + raise RuntimeError("Unrecognized axis type: %s" % err) + axis = tempaxis return _make.squeeze(data, axis) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 400f7dcf0b42..0b1d750fb2f6 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -210,13 +210,18 @@ class TestSqueeze: ((1, 3, 2, 5), "float32", None), ((1, 3, 1), "float32", [0]), ((1, 2, 1, 2, 1), "float32", [0, 2]), + ((1, 3, 1), "float32", 2), + ((1, 3, 1), "float32", []), ) def test_squeeze(self, shape, dtype, axis): x = relay.var("x", relay.TensorType(shape, dtype)) squeeze = relay.squeeze(x, axis=axis) - np_axis = tuple(axis) if axis is not None else None + if isinstance(axis, int): + np_axis = (axis,) + else: + np_axis = tuple(axis) if axis is not None else None data = np.random.random_sample(shape).astype(dtype) op_res = create_executor().evaluate(squeeze, {x: relay.const(data)})