From 8720e1939ad8a391a866cf6f517edd26b34934cc Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 4 Nov 2020 14:41:46 -0800 Subject: [PATCH 1/5] add MXNet-style reshape_like attrs support --- include/tvm/relay/attrs/transform.h | 22 +++++++++++ python/tvm/relay/op/op_attrs.py | 5 +++ python/tvm/relay/op/transform.py | 4 +- src/relay/op/make_op.h | 3 ++ src/relay/op/tensor/transform.cc | 55 +++++++++++++++++++++++++--- src/relay/transforms/pattern_utils.h | 6 +-- tests/python/relay/test_op_level3.py | 31 +++++++++++++--- 7 files changed, 111 insertions(+), 15 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 683f5a28b4f4..45a95fe4787e 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -93,6 +93,28 @@ struct ReshapeAttrs : public tvm::AttrsNode { } }; // struct ReshapeAttrs +/*! \brief Attributes used in MXNet-style reshape_like operators */ +struct ReshapeLikeAttrs : public tvm::AttrsNode { + int lhs_begin; + Integer lhs_end; // can be None + int rhs_begin; + Integer rhs_end; // can be None + TVM_DECLARE_ATTRS(ReshapeLikeAttrs, "relay.attrs.ReshapeLikeAttrs") { + TVM_ATTR_FIELD(lhs_begin) + .set_default(0) + .describe("The axis of the input where reshaping should begin."); + TVM_ATTR_FIELD(lhs_end) + .set_default(NullValue()) + .describe("The axis of the input where reshaping should end, exclusive."); + TVM_ATTR_FIELD(rhs_begin) + .set_default(0) + .describe("The axis of the shape_like tensor to begin taking dimensions from."); + TVM_ATTR_FIELD(rhs_end) + .set_default(NullValue()) + .describe("The axis of the shape_like tensor to end taking dimensions from, exclusive."); + } +}; // struct ReshapeLikeAttrs + struct ScatterAttrs : public tvm::AttrsNode { Integer axis; diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 5dc2c2402c08..2c5f046bb7e8 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -194,6 +194,11 @@ class ReshapeAttrs(Attrs): """Attributes for transform.reshape""" +@tvm._ffi.register_object("relay.attrs.ReshapeLikeAttrs") +class ReshapeLikeAttrs(Attrs): + """Attributes for transform.reshape_like""" + + @tvm._ffi.register_object("relay.attrs.GatherAttrs") class GatherAttrs(Attrs): """Attributes for transform.gather""" diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 855fd9369c34..bba4d4990d9a 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -308,7 +308,7 @@ def scatter_add(data, indices, updates, axis): return _make.scatter_add(data, indices, updates, axis) -def reshape_like(data, shape_like): +def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None): """Reshapes the input array by the size of another array. For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes the input array into an output array with the same shape as the second input array. @@ -329,7 +329,7 @@ def reshape_like(data, shape_like): ret : relay.Expr The computed result. """ - return _make.reshape_like(data, shape_like) + return _make.reshape_like(data, shape_like, lhs_begin, lhs_end, rhs_begin, rhs_end) def take(data, indices, axis=None, mode="clip"): diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 631ec4c0d2f5..01650e6c45c0 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -62,6 +62,9 @@ Expr MakeRepeat(Expr data, int repeats, int axis); Expr MakeReshape(Expr data, Array newshape); +Expr MakeReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin, Integer lhs_end, int64_t rhs_begin, + Integer rhs_end); + Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis); Expr MakeSqueeze(Expr data, Array axis); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index c0af0876fccb..9ce6a7aba828 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -453,6 +453,7 @@ RELAY_REGISTER_OP("transpose") /* relay.reshape */ TVM_REGISTER_NODE_TYPE(ReshapeAttrs); +TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs); Array infer_newshape(const Array& data_shape, const Attrs& attrs) { const auto* param = attrs.as(); @@ -641,11 +642,45 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } +Array infer_reshape_like(const Array& lhs_shape, + const Array& rhs_shape, const Attrs& attrs) { + const auto* like_attrs = attrs.as(); + CHECK(!like_attrs->lhs_end.defined() || like_attrs->lhs_end.as()) + << "lhs_end must be a concrete integer or None"; + CHECK(!like_attrs->rhs_end.defined() || like_attrs->rhs_end.as()) + << "rhs_end must be a concrete integer or None"; + int64_t lhs_shape_size = static_cast(lhs_shape.size()); + int64_t rhs_shape_size = static_cast(rhs_shape.size()); + int64_t lhs_begin = like_attrs->lhs_begin; + int64_t lhs_end = like_attrs->lhs_end.defined() ? like_attrs->lhs_end.as()->value + : lhs_shape_size; + int64_t rhs_begin = like_attrs->rhs_begin; + int64_t rhs_end = like_attrs->rhs_end.defined() ? like_attrs->rhs_end.as()->value + : rhs_shape_size; + lhs_begin = lhs_begin < 0 ? lhs_begin + lhs_shape_size : lhs_begin; + lhs_end = lhs_end < 0 ? lhs_end + lhs_shape_size : lhs_end; + rhs_begin = rhs_begin < 0 ? rhs_begin + rhs_shape_size : rhs_begin; + rhs_end = rhs_end < 0 ? rhs_end + rhs_shape_size : rhs_end; + Array shape_like; + for (auto i = 0; i < lhs_begin; i++) { + shape_like.push_back(lhs_shape[i]); + } + for (auto i = rhs_begin; i < rhs_end; i++) { + shape_like.push_back(rhs_shape[i]); + } + for (auto i = lhs_end; i < lhs_shape_size; i++) { + shape_like.push_back(lhs_shape[i]); + } + return shape_like; +} + Array ReshapeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { // Quick path for reshape_like if (!attrs.as()) { - return {topi::reshape(inputs[0], inputs[1]->shape)}; + ICHECK(attrs.as() != nullptr); + auto shape_like = infer_reshape_like(inputs[0]->shape, inputs[1]->shape, attrs); + return {topi::reshape(inputs[0], shape_like)}; } const auto* out_ttype = out_type.as(); @@ -746,6 +781,7 @@ Example:: */ bool ReshapeLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { + ICHECK(attrs.as() != nullptr); ICHECK_EQ(types.size(), 3); const auto* data = types[0].as(); if (data == nullptr) { @@ -755,6 +791,7 @@ bool ReshapeLikeRel(const Array& types, int num_inputs, const Attrs& attrs if (reshape_like == nullptr) { return false; } + auto shape_like = infer_reshape_like(data->shape, reshape_like->shape, attrs); // Only check When input data has static shape. bool is_static_shape = true; for (size_t i = 0; i < data->shape.size(); ++i) { @@ -763,17 +800,24 @@ bool ReshapeLikeRel(const Array& types, int num_inputs, const Attrs& attrs break; } } + auto output_type = TensorType(shape_like, data->dtype); if (is_static_shape) { - ICHECK(reporter->AssertEQ(data->Size(), reshape_like->Size())) + ICHECK(reporter->AssertEQ(data->Size(), output_type->Size())) << "Reshape inputs size should be compatible."; } - reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype)); + reporter->Assign(types[2], output_type); return true; } -Expr MakeReshapeLike(Expr data, Expr shape_like) { +Expr MakeReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin, Integer lhs_end, + int64_t rhs_begin, Integer rhs_end) { + auto attrs = make_object(); + attrs->lhs_begin = std::move(lhs_begin); + attrs->lhs_end = std::move(lhs_end); + attrs->rhs_begin = std::move(rhs_begin); + attrs->rhs_end = std::move(rhs_end); static const Op& op = Op::Get("reshape_like"); - return Call(op, {data, shape_like}, Attrs(), {}); + return Call(op, {lhs, rhs}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.reshape_like").set_body_typed(MakeReshapeLike); @@ -785,6 +829,7 @@ the input array into an output array with the same shape as the second input arr .. note:: Sizes for both array should be compatible. )code" TVM_ADD_FILELINE) + .set_attrs_type() .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("shape_like", "Tensor", "Shape tensor.") diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 555391a27e4b..60867f4faaa0 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -594,9 +594,9 @@ inline Expr LeftShift(Expr x, Expr nbit) { return Call(op, {x, nbit}, Attrs(), {}); } -inline Expr ReshapeLike(Expr lhs, Expr rhs) { - static const Op& op = Op::Get("reshape_like"); - return Call(op, {lhs, rhs}, Attrs(), {}); +inline Expr ReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin, Integer lhs_end, + int64_t rhs_begin, Integer rhs_end) { + return MakeReshapeLike(lhs, rhs, lhs_begin, lhs_end, rhs_begin, rhs_end); } inline Expr Copy(Expr data) { diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 3ea0777df8ca..d773e1037430 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -316,17 +316,36 @@ def test_reshape_like_infer_type(): zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((1, 8, 8), "float32") + # partial reshaping + x = relay.var("x", relay.TensorType((1, 2, 3, 4), "float32")) + y = relay.var("y", relay.TensorType((1, 6, 5), "float32")) + z = relay.reshape_like(x, y, lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((1, 6, 4), "float32") + + # symbolic partial reshaping + n, c, h, w = te.size_var("n"), 2, 3, te.size_var("w") + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + y = relay.var("y", relay.TensorType((5, 6), "float32")) + z = relay.var("z", relay.TensorType((4,), "float32")) + w = relay.reshape_like(x, y, lhs_end=3) + w = relay.reshape_like(w, z, lhs_begin=2) + w = run_infer_type(w) + assert w.checked_type == relay.TensorType((5, 6, 4), "float32") + @tvm.testing.uses_gpu def test_reshape_like(): - def verify_reshape_like(shape, oshape): + def verify_reshape_like(shape, oshape, shape_like=None, reshape_like_kwargs={}): + if shape_like is None: + shape_like = oshape x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") - y_data = np.random.uniform(low=-1, high=1, size=oshape).astype("float32") - ref_res = np.reshape(x_data, y_data.shape) + y_data = np.random.uniform(low=-1, high=1, size=shape_like).astype("float32") + ref_res = np.reshape(x_data, oshape) x = relay.var("x", relay.TensorType(shape, "float32")) - y = relay.var("x", relay.TensorType(oshape, "float32")) - z = relay.reshape_like(x, y) + y = relay.var("x", relay.TensorType(shape_like, "float32")) + z = relay.reshape_like(x, y, **reshape_like_kwargs) zz = run_infer_type(z) assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32") @@ -340,6 +359,8 @@ def verify_reshape_like(shape, oshape): verify_reshape_like((2, 3, 4), (1, 8, 3)) verify_reshape_like((4, 7), (2, 7, 2)) + verify_reshape_like((1, 2, 3, 4), (1, 6, 4), (1, 6, 5), + dict(lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2)) def test_take_infer_type(): From a38bec34e9d31fedd81207130d6be7d690cf65de Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 4 Nov 2020 14:55:17 -0800 Subject: [PATCH 2/5] lint --- include/tvm/relay/attrs/transform.h | 10 ++++------ src/relay/op/tensor/transform.cc | 18 +++++++++--------- src/relay/transforms/pattern_utils.h | 4 ++-- tests/python/relay/test_op_level3.py | 5 +++-- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 45a95fe4787e..a7830cf61647 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -100,15 +100,13 @@ struct ReshapeLikeAttrs : public tvm::AttrsNode { int rhs_begin; Integer rhs_end; // can be None TVM_DECLARE_ATTRS(ReshapeLikeAttrs, "relay.attrs.ReshapeLikeAttrs") { - TVM_ATTR_FIELD(lhs_begin) - .set_default(0) - .describe("The axis of the input where reshaping should begin."); + TVM_ATTR_FIELD(lhs_begin).set_default(0).describe( + "The axis of the input where reshaping should begin."); TVM_ATTR_FIELD(lhs_end) .set_default(NullValue()) .describe("The axis of the input where reshaping should end, exclusive."); - TVM_ATTR_FIELD(rhs_begin) - .set_default(0) - .describe("The axis of the shape_like tensor to begin taking dimensions from."); + TVM_ATTR_FIELD(rhs_begin).set_default(0).describe( + "The axis of the shape_like tensor to begin taking dimensions from."); TVM_ATTR_FIELD(rhs_end) .set_default(NullValue()) .describe("The axis of the shape_like tensor to end taking dimensions from, exclusive."); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9ce6a7aba828..78f68baf6d09 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -642,21 +642,21 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -Array infer_reshape_like(const Array& lhs_shape, +Array infer_reshape_like(const Array& lhs_shape, const Array& rhs_shape, const Attrs& attrs) { const auto* like_attrs = attrs.as(); CHECK(!like_attrs->lhs_end.defined() || like_attrs->lhs_end.as()) - << "lhs_end must be a concrete integer or None"; + << "lhs_end must be a concrete integer or None"; CHECK(!like_attrs->rhs_end.defined() || like_attrs->rhs_end.as()) - << "rhs_end must be a concrete integer or None"; + << "rhs_end must be a concrete integer or None"; int64_t lhs_shape_size = static_cast(lhs_shape.size()); int64_t rhs_shape_size = static_cast(rhs_shape.size()); int64_t lhs_begin = like_attrs->lhs_begin; - int64_t lhs_end = like_attrs->lhs_end.defined() ? like_attrs->lhs_end.as()->value - : lhs_shape_size; + int64_t lhs_end = + like_attrs->lhs_end.defined() ? like_attrs->lhs_end.as()->value : lhs_shape_size; int64_t rhs_begin = like_attrs->rhs_begin; - int64_t rhs_end = like_attrs->rhs_end.defined() ? like_attrs->rhs_end.as()->value - : rhs_shape_size; + int64_t rhs_end = + like_attrs->rhs_end.defined() ? like_attrs->rhs_end.as()->value : rhs_shape_size; lhs_begin = lhs_begin < 0 ? lhs_begin + lhs_shape_size : lhs_begin; lhs_end = lhs_end < 0 ? lhs_end + lhs_shape_size : lhs_end; rhs_begin = rhs_begin < 0 ? rhs_begin + rhs_shape_size : rhs_begin; @@ -809,8 +809,8 @@ bool ReshapeLikeRel(const Array& types, int num_inputs, const Attrs& attrs return true; } -Expr MakeReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin, Integer lhs_end, - int64_t rhs_begin, Integer rhs_end) { +Expr MakeReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin, Integer lhs_end, int64_t rhs_begin, + Integer rhs_end) { auto attrs = make_object(); attrs->lhs_begin = std::move(lhs_begin); attrs->lhs_end = std::move(lhs_end); diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 60867f4faaa0..c63ba0ce9326 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -594,8 +594,8 @@ inline Expr LeftShift(Expr x, Expr nbit) { return Call(op, {x, nbit}, Attrs(), {}); } -inline Expr ReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin, Integer lhs_end, - int64_t rhs_begin, Integer rhs_end) { +inline Expr ReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin, Integer lhs_end, int64_t rhs_begin, + Integer rhs_end) { return MakeReshapeLike(lhs, rhs, lhs_begin, lhs_end, rhs_begin, rhs_end); } diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index d773e1037430..69fb0abf71b8 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -359,8 +359,9 @@ def verify_reshape_like(shape, oshape, shape_like=None, reshape_like_kwargs={}): verify_reshape_like((2, 3, 4), (1, 8, 3)) verify_reshape_like((4, 7), (2, 7, 2)) - verify_reshape_like((1, 2, 3, 4), (1, 6, 4), (1, 6, 5), - dict(lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2)) + verify_reshape_like( + (1, 2, 3, 4), (1, 6, 4), (1, 6, 5), dict(lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2) + ) def test_take_infer_type(): From e0d39a187577506bf3f50e71ea1458ccad9777b5 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Thu, 5 Nov 2020 10:58:49 -0800 Subject: [PATCH 3/5] document, switch to int, add more tests, style --- python/tvm/relay/op/transform.py | 29 ++++++++++++++++++++++------ src/relay/op/make_op.h | 2 +- src/relay/op/tensor/transform.cc | 10 +++++++--- src/relay/transforms/pattern_utils.h | 2 +- tests/python/relay/test_op_level3.py | 9 +++++++++ 5 files changed, 41 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index bba4d4990d9a..5d4e51d50127 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -309,20 +309,37 @@ def scatter_add(data, indices, updates, axis): def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None): - """Reshapes the input array by the size of another array. - For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes - the input array into an output array with the same shape as the second input array. + """Reshapes the input tensor by the size of another tensor. + For an input tensor with shape ``(d0, d1, ..., d(k-1))``, `reshape_like` operation reshapes + the input tensor into an output tensor with the same shape as the second input tensor, + in particular reshaping the dimensions of `data` in `[lhs_begin, lhs_end)` using the dimensions + from `shape_like` in `[rhs_begin, rhs_end)`. .. note:: - Sizes for both array should be compatible. + Sizes for `data` and the output tensor should be compatible. Parameters ---------- data : relay.Expr The input data to the operator. - shape_like : tuple of int - The new shape. Should be compatible with the original shape. + shape_like : relay.Expr + The tensor to reshape data like. Should be compatible with the original shape on the + reshaped dimensions. + + lhs_begin : int, optional + The axis of data to begin reshaping. Default is 0. + + lhs_end : int or None, optional + The axis of data where reshaping should stop, exclusive. Default is None which reshapes to + the end. + + rhs_begin : int, optional + The axis of shape_like where the target shape begins. Default is 0. + + rhs_end : int or None, optional + The axis of shape_like where the target shape ends, exclusive. Default is None which extends + to the end. Returns ------- diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 01650e6c45c0..0e1f5c560081 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -62,7 +62,7 @@ Expr MakeRepeat(Expr data, int repeats, int axis); Expr MakeReshape(Expr data, Array newshape); -Expr MakeReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin, Integer lhs_end, int64_t rhs_begin, +Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin, Integer rhs_end); Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 78f68baf6d09..c3634e9d860d 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -649,18 +649,22 @@ Array infer_reshape_like(const Array& lhs_shape, << "lhs_end must be a concrete integer or None"; CHECK(!like_attrs->rhs_end.defined() || like_attrs->rhs_end.as()) << "rhs_end must be a concrete integer or None"; + int64_t lhs_shape_size = static_cast(lhs_shape.size()); int64_t rhs_shape_size = static_cast(rhs_shape.size()); - int64_t lhs_begin = like_attrs->lhs_begin; + int64_t lhs_begin = static_cast(like_attrs->lhs_begin); int64_t lhs_end = like_attrs->lhs_end.defined() ? like_attrs->lhs_end.as()->value : lhs_shape_size; - int64_t rhs_begin = like_attrs->rhs_begin; + int64_t rhs_begin = static_cast(like_attrs->rhs_begin); int64_t rhs_end = like_attrs->rhs_end.defined() ? like_attrs->rhs_end.as()->value : rhs_shape_size; + + // handle negative axes lhs_begin = lhs_begin < 0 ? lhs_begin + lhs_shape_size : lhs_begin; lhs_end = lhs_end < 0 ? lhs_end + lhs_shape_size : lhs_end; rhs_begin = rhs_begin < 0 ? rhs_begin + rhs_shape_size : rhs_begin; rhs_end = rhs_end < 0 ? rhs_end + rhs_shape_size : rhs_end; + Array shape_like; for (auto i = 0; i < lhs_begin; i++) { shape_like.push_back(lhs_shape[i]); @@ -809,7 +813,7 @@ bool ReshapeLikeRel(const Array& types, int num_inputs, const Attrs& attrs return true; } -Expr MakeReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin, Integer lhs_end, int64_t rhs_begin, +Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin, Integer rhs_end) { auto attrs = make_object(); attrs->lhs_begin = std::move(lhs_begin); diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index c63ba0ce9326..8ef86e088193 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -594,7 +594,7 @@ inline Expr LeftShift(Expr x, Expr nbit) { return Call(op, {x, nbit}, Attrs(), {}); } -inline Expr ReshapeLike(Expr lhs, Expr rhs, int64_t lhs_begin, Integer lhs_end, int64_t rhs_begin, +inline Expr ReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin, Integer rhs_end) { return MakeReshapeLike(lhs, rhs, lhs_begin, lhs_end, rhs_begin, rhs_end); } diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 69fb0abf71b8..c677ce354435 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -323,6 +323,15 @@ def test_reshape_like_infer_type(): zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((1, 6, 4), "float32") + x = relay.var("x", relay.TensorType((1, 2, 3, 4), "float32")) + y = relay.var("y", relay.TensorType((2, 3, 4, 1, 6), "float32")) + z = relay.reshape_like(x, y, rhs_end=3) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((2, 3, 4), "float32") + z = relay.reshape_like(x, y, rhs_begin=2) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((4, 1, 6), "float32") + # symbolic partial reshaping n, c, h, w = te.size_var("n"), 2, 3, te.size_var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) From 81c2162f27644c08004dab1e2ca21ad31382d1a7 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Thu, 5 Nov 2020 16:05:27 -0800 Subject: [PATCH 4/5] add example usage in documentation --- python/tvm/relay/op/transform.py | 9 +++++++++ src/relay/op/tensor/transform.cc | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 5d4e51d50127..c951c9e77ca5 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -345,6 +345,15 @@ def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_e ------- ret : relay.Expr The computed result. + + Examples + -------- + .. code-block:: python + data.shape == (1, 2, 3, 4) + shape_like.shape == (6, 2, 2, 3) + + ret = relay.reshape_like(data, shape_like, lhs_begin=1, rhs_end=3) + ret.shape == (1, 6, 2, 2) """ return _make.reshape_like(data, shape_like, lhs_begin, lhs_end, rhs_begin, rhs_end) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index c3634e9d860d..71f88b2f258e 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -832,6 +832,13 @@ For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation re the input array into an output array with the same shape as the second input array. .. note:: Sizes for both array should be compatible. +Example:: + + data.shape == (1, 2, 3, 4) + shape_like.shape == (6, 2, 2, 3) + + ret = reshape_like(data, shape_like, lhs_begin=1, rhs_end=3) + ret.shape == (1, 6, 2, 2) )code" TVM_ADD_FILELINE) .set_attrs_type() .set_num_inputs(2) From 46d7f64627d1a0399e663cec41ea29db23409919 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Thu, 5 Nov 2020 19:33:13 -0800 Subject: [PATCH 5/5] fix doc formatting --- python/tvm/relay/op/transform.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index c951c9e77ca5..a3f97392e36e 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -349,6 +349,7 @@ def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_e Examples -------- .. code-block:: python + data.shape == (1, 2, 3, 4) shape_like.shape == (6, 2, 2, 3)