From 9d5b020103555714128683c48704f6fd7a2ea3f0 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Mon, 22 Oct 2018 11:34:19 +0530 Subject: [PATCH 1/4] [RELAY]reshape_like --- docs/langref/relay_op.rst | 2 + include/tvm/relay/type.h | 3 ++ python/tvm/relay/op/transform.py | 23 ++++++++++++ src/relay/ir/type.cc | 8 ++++ src/relay/op/tensor/transform.cc | 55 ++++++++++++++++++++++++++++ tests/python/relay/test_op_level3.py | 17 +++++++++ 6 files changed, 108 insertions(+) diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 42883f5f77da..91fdaef6f14c 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -78,6 +78,7 @@ This level enables additional math and transform operators. tvm.relay.ones tvm.relay.ones_like tvm.relay.reshape + tvm.relay.reshape_like tvm.relay.copy tvm.relay.transpose tvm.relay.floor @@ -188,6 +189,7 @@ Level 3 Definitions .. autofunction:: tvm.relay.abs .. autofunction:: tvm.relay.negative .. autofunction:: tvm.relay.reshape +.. autofunction:: tvm.relay.reshape_like .. autofunction:: tvm.relay.copy .. autofunction:: tvm.relay.transpose .. autofunction:: tvm.relay.take diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 2bb9b3070270..6612cfaea88a 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -82,6 +82,9 @@ class TensorTypeNode : public BaseTensorTypeNode { v->Visit("span", &span); } + /*! \brief Return product of elements in the shape */ + TVM_DLL IndexExpr Size() const; + TVM_DLL static TensorType make(Array shape, DataType dtype); /*! \brief Construct an scalar containing elements of dtype. */ diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 84e2398f0a9e..1eb3d45ba95a 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -141,6 +141,29 @@ def reshape(data, newshape): return _make.reshape(data, list(newshape)) +def reshape_like(data, shape_like): + """Reshapes the input array by the size of another array. + For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes + the input array into an output array with the same shape as the second input array. + .. note:: + Sizes for both array should be compatible. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + shape_like : tuple of int + The new shape. Should be compatible with the original shape. + + Returns + ------- + ret : relay.Expr + The computed result. + """ + return _make.reshape_like(data, shape_like) + + def take(data, indices, axis=None): """Take elements from an array along an axis. diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index d6fc2e85b2d8..ec5ca713d70a 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -22,6 +22,14 @@ TensorType TensorTypeNode::Scalar(DataType dtype) { return TensorTypeNode::make({}, dtype); } +IndexExpr TensorTypeNode::Size() const { + IndexExpr size = make_const(Int(64), 1); + for (IndexExpr i : shape) { + size *= i; + } + return size; +} + TVM_REGISTER_NODE_TYPE(TensorTypeNode); TVM_REGISTER_API("relay._make.TensorType") diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index bab875fd190e..f9aff99e523f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -376,6 +376,61 @@ Example:: .set_support_level(3) .add_type_rel("Reshape", ReshapeRel); + +/*! +* \brief ReshapeLikeRel Output type and shape relation evaluation function. +* \param num_inputs Number of input types in the args. +* \param attrs The additional attributes of the operator. +* \param reporter The reporter to report solution to. +* \return false if This relation cannot be resolved. true if this relation has been resolved. +*/ +bool ReshapeLikeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) { + return false; + } + const auto* reshape_like = types[1].as(); + if (reshape_like == nullptr) { + return false; + } + CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size())) + << "Reshape inputs size should be compatible"; + reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype)); + return true; +} + + +Expr MakeReshapeLike(Expr data, + Expr shape_like) { + static const Op& op = Op::Get("reshape_like"); + return CallNode::make(op, {data, shape_like}, Attrs(), {}); +} + + +TVM_REGISTER_API("relay.op._make.reshape_like") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeReshapeLike, args, rv); +}); + + +RELAY_REGISTER_OP("reshape_like") +.describe(R"code(Reshapes the input array by the size of another array. +For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes +the input array into an output array with the same shape as the second input array. +.. note:: + Sizes for both array should be compatible. +)code" TVM_ADD_FILELINE) +.set_num_inputs(2) +.add_argument("data", "Tensor", "The input tensor.") +.add_argument("shape_like", "Tensor", "Shape tensor.") +.set_support_level(3) +.add_type_rel("ReshapeLike", ReshapeLikeRel); + + // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 8ab3c41c079d..0307eb782c7b 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -88,6 +88,22 @@ def test_reshape_infer_type(): (n, t, 2000), "float32") +def test_reshape_like(): + # concrete shape + x = relay.var("x", relay.TensorType((1, 2, 3), "float32")) + y = relay.var("y", relay.TensorType((1,6), "float32")) + z = relay.reshape_like(x, y) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((1, 6), "float32") + + # symbolic shape + n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w") + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + y = relay.var("y", relay.TensorType((1, 8, 8), "float32")) + z = relay.reshape_like(x, y) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((1, 8, 8), "float32") + def test_take_infer_type(): def verify_take(dshape, indices_shape, oshape, axis=None): @@ -155,6 +171,7 @@ def test_infer_type_leaky_relu(): test_clip_type() test_transpose_infer_type() test_reshape_infer_type() + test_reshape_like() test_take_infer_type() test_full() test_full_like() From dd8ddb14d4b5c914ee0175f33f66d986ff8219c7 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Tue, 23 Oct 2018 19:38:58 +0530 Subject: [PATCH 2/4] Review comment fix --- src/relay/ir/type.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index ec5ca713d70a..bbe6472609df 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -23,9 +23,13 @@ TensorType TensorTypeNode::Scalar(DataType dtype) { } IndexExpr TensorTypeNode::Size() const { - IndexExpr size = make_const(Int(64), 1); - for (IndexExpr i : shape) { - size *= i; + if (shape.size() == 0) { + return make_const(Int(64), 1); + } + + IndexExpr size = shape[0]; + for (size_t i = 1; i < shape.size(); ++i) { + size *= shape[i]; } return size; } From 5a87c45d0d921a3c1d17405a9d336858670e0d92 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Wed, 24 Oct 2018 08:43:43 +0530 Subject: [PATCH 3/4] Review comments --- src/relay/op/tensor/transform.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index f9aff99e523f..c821104f693a 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -378,11 +378,12 @@ Example:: /*! -* \brief ReshapeLikeRel Output type and shape relation evaluation function. +* \brief ReshapeLikeRel User defined type constraint function. * \param num_inputs Number of input types in the args. * \param attrs The additional attributes of the operator. * \param reporter The reporter to report solution to. -* \return false if This relation cannot be resolved. true if this relation has been resolved. +* \return False if the relation has not been resolved, it might be resolved later. +* True if this relation has been resolved. */ bool ReshapeLikeRel(const Array& types, int num_inputs, @@ -398,7 +399,7 @@ bool ReshapeLikeRel(const Array& types, return false; } CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size())) - << "Reshape inputs size should be compatible"; + << "Reshape inputs size should be compatible."; reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype)); return true; } From 57bfe43277d97b23b90c94b6403b646fb4b37fe7 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Thu, 25 Oct 2018 08:07:32 +0530 Subject: [PATCH 4/4] Review comment, added docstring --- include/tvm/relay/type.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 6612cfaea88a..e2a0c2a2a7ed 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -82,7 +82,9 @@ class TensorTypeNode : public BaseTensorTypeNode { v->Visit("span", &span); } - /*! \brief Return product of elements in the shape */ + /*! \brief Return product of elements in the shape. + * \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero. + */ TVM_DLL IndexExpr Size() const; TVM_DLL static TensorType make(Array shape, DataType dtype);