Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,17 @@ This level enables additional math and transform operators.
tvm.relay.image.resize


**Level 10: Temporary Operators**

This level supports backpropagation of broadcast operators. It is temporary.

.. autosummary::
:nosignatures:

tvm.relay.broadcast_to_like
tvm.relay.collapse_sum_like


Level 1 Definitions
-------------------
.. autofunction:: tvm.relay.log
Expand Down Expand Up @@ -199,6 +210,13 @@ Level 4 Definitions
.. autofunction:: tvm.relay.prod



Level 5 Definitions
-------------------
.. autofunction:: tvm.relay.image.resize


Level 10 Definitions
--------------------
.. autofunction:: tvm.relay.broadcast_to_like
.. autofunction:: tvm.relay.collapse_sum_like
38 changes: 38 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,41 @@ def where(condition, x, y):
Note that the shape of condition, x, and y needs to be the same.
"""
return _make.where(condition, x, y)


def broadcast_to_like(data, broadcast_type):
    """Broadcast the input tensor to match the shape of a second tensor.

    Parameters
    ----------
    data : relay.Expr
        The input tensor.

    broadcast_type : relay.Expr
        Provide the type to broadcast to.

    Returns
    -------
    result : relay.Expr
        The resulting tensor, with the type of ``broadcast_type``.
    """
    return _make.broadcast_to_like(data, broadcast_type)


def collapse_sum_like(data, collapse_type):
    """Sum-reduce (collapse) the input tensor so its shape matches a second tensor.

    Parameters
    ----------
    data : relay.Expr
        The input tensor.

    collapse_type : relay.Expr
        Provide the type to collapse to.

    Returns
    -------
    result : relay.Expr
        The resulting tensor, with the type of ``collapse_type``.
    """
    return _make.collapse_sum_like(data, collapse_type)
61 changes: 61 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,5 +718,66 @@ RELAY_REGISTER_OP("squeeze")
.set_support_level(3)
.add_type_rel("Squeeze", SqueezeRel);

// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
//
// Type relation for collapse_sum_like: the result type is exactly the type
// of the second ("like") argument.
//
// NOTE(review): the constraint BroadCast(A, B) = A is not enforced here --
// an input that is not broadcast-compatible with the target is not rejected
// by this relation. TODO: enforce once a shared reduce/broadcast shape
// helper is available.
bool CollapseSumLikeRel(const Array<Type>& types,
                        int num_inputs,
                        const Attrs& attrs,
                        const TypeReporter& reporter) {
  // types holds {data, collapse_type, result}.
  CHECK_EQ(types.size(), 3);
  // The output type is that of the "like" argument.
  reporter->Assign(types[2], types[1]);
  return true;
}

// Build a call node for the collapse_sum_like operator, which sum-reduces
// `data` until its shape matches that of `collapse_type`.
Expr MakeCollapseSumLike(Expr data,
                         Expr collapse_type) {
  // Look the operator up once; Op::Get is keyed by name, so a function-local
  // static caches the lookup across calls.
  static const Op& collapse_sum_like_op = Op::Get("collapse_sum_like");
  Array<Expr> call_args = {data, collapse_type};
  return CallNode::make(collapse_sum_like_op, call_args, Attrs(), {});
}

// Expose MakeCollapseSumLike to the frontend as
// relay.op._make.collapse_sum_like (called by the Python wrapper).
TVM_REGISTER_API("relay.op._make.collapse_sum_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 2>(MakeCollapseSumLike, args, rv);
  });

// Operator registration for collapse_sum_like. Support level 10 marks it as
// a temporary operator used for backpropagation of broadcast operators
// (see docs/langref/relay_op.rst).
RELAY_REGISTER_OP("collapse_sum_like")
.describe(R"code(Collapse the first input to match the shape of the second input.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("collapse_type", "Tensor", "Provide the type to collapse to.")
.set_support_level(10)
.add_type_rel("CollapseSumLike", CollapseSumLikeRel);

// BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B
//
// Type relation for broadcast_to_like: the result type is exactly the type
// of the second ("like") argument.
//
// NOTE(review): the constraint BroadCast(A, B) = B is not checked here --
// an input that cannot be broadcast to the target is not rejected by this
// relation; confirm whether a later pass enforces it.
bool BroadCastToLikeRel(const Array<Type>& types,
                        int num_inputs,
                        const Attrs& attrs,
                        const TypeReporter& reporter) {
  // types holds {data, broadcast_type, result}.
  CHECK_EQ(types.size(), 3);
  // The output type is that of the "like" argument.
  reporter->Assign(types[2], types[1]);
  return true;
}

// Build a call node for the broadcast_to_like operator, which broadcasts
// `data` to match the shape of `broadcast_type`.
Expr MakeBroadCastToLike(Expr data,
                         Expr broadcast_type) {
  // Look the operator up once; Op::Get is keyed by name, so a function-local
  // static caches the lookup across calls.
  static const Op& broadcast_to_like_op = Op::Get("broadcast_to_like");
  Array<Expr> call_args = {data, broadcast_type};
  return CallNode::make(broadcast_to_like_op, call_args, Attrs(), {});
}

// Expose MakeBroadCastToLike to the frontend as
// relay.op._make.broadcast_to_like (called by the Python wrapper).
TVM_REGISTER_API("relay.op._make.broadcast_to_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 2>(MakeBroadCastToLike, args, rv);
  });

// Operator registration for broadcast_to_like. Support level 10 marks it as
// a temporary operator used for backpropagation of broadcast operators
// (see docs/langref/relay_op.rst).
RELAY_REGISTER_OP("broadcast_to_like")
.describe(R"code(Broadcast the first input to match the shape of the second input.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.")
.set_support_level(10)
.add_type_rel("BroadCastToLike", BroadCastToLikeRel);

} // namespace relay
} // namespace tvm
23 changes: 23 additions & 0 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
""" Support level10 operator test cases.
"""
import tvm
from tvm import relay

def test_collapse_sum_like():
    # Collapse a (3, 4, 5, 6) tensor to the shape of a (4, 1, 6) tensor:
    # the inferred result type must equal the "like" tensor's type.
    x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8"))
    y = relay.Var("y", relay.ty.TensorType((4, 1, 6), "int8"))
    z = relay.collapse_sum_like(x, y)
    zz = relay.ir_pass.infer_type(z)
    assert zz.checked_type == relay.ty.TensorType((4, 1, 6), "int8")


def test_broadcast_to_like():
    # Broadcast a (4, 1, 6) tensor to the shape of a (3, 4, 5, 6) tensor:
    # the inferred result type must equal the "like" tensor's type.
    x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8"))
    y = relay.Var("y", relay.ty.TensorType((4, 1, 6), "int8"))
    z = relay.broadcast_to_like(y, x)
    zz = relay.ir_pass.infer_type(z)
    assert zz.checked_type == relay.ty.TensorType((3, 4, 5, 6), "int8")

# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_collapse_sum_like()
    test_broadcast_to_like()
1 change: 1 addition & 0 deletions tests/python/relay/test_pass_alpha_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,4 @@ def test_op_alpha_equal():
test_let_alpha_equal()
test_if_alpha_equal()
test_op_alpha_equal()
test_var_alpha_equal()