Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ This level enables additional math and transform operators.
tvm.relay.ones
tvm.relay.ones_like
tvm.relay.reshape
tvm.relay.reshape_like
tvm.relay.copy
tvm.relay.transpose
tvm.relay.floor
Expand Down Expand Up @@ -188,6 +189,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.abs
.. autofunction:: tvm.relay.negative
.. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.reshape_like
.. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.take
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ class TensorTypeNode : public BaseTensorTypeNode {
v->Visit("span", &span);
}

/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
TVM_DLL IndexExpr Size() const;

TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);

/*! \brief Construct an scalar containing elements of dtype. */
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,29 @@ def reshape(data, newshape):
return _make.reshape(data, list(newshape))


def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
Sizes for both array should be compatible.

Parameters
----------
data : relay.Expr
The input data to the operator.

shape_like : tuple of int
The new shape. Should be compatible with the original shape.

Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.reshape_like(data, shape_like)


def take(data, indices, axis=None):
"""Take elements from an array along an axis.

Expand Down
12 changes: 12 additions & 0 deletions src/relay/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
return TensorTypeNode::make({}, dtype);
}

IndexExpr TensorTypeNode::Size() const {
if (shape.size() == 0) {
return make_const(Int(64), 1);
}

IndexExpr size = shape[0];
for (size_t i = 1; i < shape.size(); ++i) {
size *= shape[i];
}
return size;
}

TVM_REGISTER_NODE_TYPE(TensorTypeNode);

TVM_REGISTER_API("relay._make.TensorType")
Expand Down
56 changes: 56 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,62 @@ Example::
.set_support_level(3)
.add_type_rel("Reshape", ReshapeRel);


/*!
* \brief ReshapeLikeRel User defined type constraint function.
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return False if the relation has not been resolved, it might be resolved later.
* True if this relation has been resolved.
*/
bool ReshapeLikeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
const auto* reshape_like = types[1].as<TensorTypeNode>();
if (reshape_like == nullptr) {
return false;
}
CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
<< "Reshape inputs size should be compatible.";
reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype));
return true;
}


Expr MakeReshapeLike(Expr data,
Expr shape_like) {
static const Op& op = Op::Get("reshape_like");
return CallNode::make(op, {data, shape_like}, Attrs(), {});
}


TVM_REGISTER_API("relay.op._make.reshape_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReshapeLike, args, rv);
});


RELAY_REGISTER_OP("reshape_like")
.describe(R"code(Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
Sizes for both array should be compatible.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape_like", "Tensor", "Shape tensor.")
.set_support_level(3)
.add_type_rel("ReshapeLike", ReshapeLikeRel);


// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);

Expand Down
17 changes: 17 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ def test_reshape_infer_type():
(n, t, 2000), "float32")


def test_reshape_like():
# concrete shape
x = relay.var("x", relay.TensorType((1, 2, 3), "float32"))
y = relay.var("y", relay.TensorType((1,6), "float32"))
z = relay.reshape_like(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((1, 6), "float32")

# symbolic shape
n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.var("y", relay.TensorType((1, 8, 8), "float32"))
z = relay.reshape_like(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((1, 8, 8), "float32")


def test_take_infer_type():
def verify_take(dshape, indices_shape, oshape, axis=None):
Expand Down Expand Up @@ -155,6 +171,7 @@ def test_infer_type_leaky_relu():
test_clip_type()
test_transpose_infer_type()
test_reshape_infer_type()
test_reshape_like()
test_take_infer_type()
test_full()
test_full_like()
Expand Down