Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ This level enables additional math and transform operators.
tvm.relay.full
tvm.relay.full_like
tvm.relay.cast
tvm.relay.split


**Level 4: Broadcast and Reductions**
Expand Down Expand Up @@ -198,6 +199,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.full
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast
.. autofunction:: tvm.relay.split


Level 4 Definitions
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
}
}; // struct SqueezeAttrs

struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
NodeRef indices_or_sections;
int axis;

TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
TVM_ATTR_FIELD(indices_or_sections)
.describe("Indices or sections to split into. Accepts an int or a tuple"
"If indices_or_sections is an integer, the input will be divided equally"
"along given axis. If such a split is not possible, an error is raised."
"If indices_or_sections is a tuple of sorted integers,"
"the entries indicate where along axis the array is split.");
TVM_ATTR_FIELD(axis).set_default(0)
.describe("the axis to be splitted.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
2 changes: 1 addition & 1 deletion nnvm/src/top/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ along which to split the array.
return Array<Tensor>{ topi::split(inputs[0], indices, param.axis) };
}
})
.set_support_level(1);
.set_support_level(3);

// cast
DMLC_REGISTER_PARAMETER(CastParam);
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make
from . import _expr
from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
Expand Down Expand Up @@ -284,6 +285,16 @@ def astuple(self):
as an argument to an FFI function."""
return self.tuple_value

def astext(self):
"""Get the text format of the tuple expression.

Returns
-------
text : str
The text format of the tuple expression.
"""
return _expr._text_print(self.tuple_value)

def __getitem__(self, index):
if index >= len(self):
raise IndexError("Tuple index out of range")
Expand Down
35 changes: 34 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Transform operators."""

from . import _make
from ..expr import TupleWrapper


def expand_dims(data, axis, num_newaxis=1):
Expand Down Expand Up @@ -146,7 +147,7 @@ def take(data, indices, axis=None):

Parameters
----------
a : relay.Expr
data : relay.Expr
The source array.

indices : rely.Expr
Expand Down Expand Up @@ -280,3 +281,35 @@ def collapse_sum_like(data, collapse_type):
The resulting tensor.
"""
return _make.collapse_sum_like(data, collapse_type)


def split(data, indices_or_sections, axis=0):
"""Split input tensor along axis by sections or indices.

If indices_or_sections is an integer, the input will be divided equally
along given axis. If such a split is not possible, an error is raised.

If indices_or_sections is a tuple of sorted integers,
the entries indicate where along axis the array is split.

Parameters
----------
data : relay.Expr
The source array.

indices_or_sections : int or tuple of int
Indices or sections to split into. Accepts an int or a tuple

axis : int, optional
The axis over which to split.

Returns
-------
ret : relay.Tuple([relay.Expr, relay.Expr])
The computed result.
"""
if isinstance(indices_or_sections, int):
ret_size = indices_or_sections
else:
ret_size = len(indices_or_sections) + 1
return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size)
4 changes: 4 additions & 0 deletions src/lang/attr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -96,6 +97,7 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH(Add);
ATTR_FUNCTOR_DISPATCH(Sub);
ATTR_FUNCTOR_DISPATCH(Mul);
ATTR_FUNCTOR_DISPATCH(Div);
ATTR_FUNCTOR_DISPATCH(Min);
ATTR_FUNCTOR_DISPATCH(Max);
ATTR_FUNCTOR_DISPATCH(GE);
Expand Down Expand Up @@ -135,6 +137,7 @@ class AttrsEqualHandler :
bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;
Expand Down Expand Up @@ -174,6 +177,7 @@ class AttrsHashHandler :
size_t VisitAttr_(const ir::Add* op) final;
size_t VisitAttr_(const ir::Sub* op) final;
size_t VisitAttr_(const ir::Mul* op) final;
size_t VisitAttr_(const ir::Div* op) final;
size_t VisitAttr_(const ir::Mod* op) final;
size_t VisitAttr_(const ir::Min* op) final;
size_t VisitAttr_(const ir::Max* op) final;
Expand Down
2 changes: 2 additions & 0 deletions src/lang/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other)
TVM_DEFINE_ATTRS_BINOP_EQUAL(Add);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Div);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
Expand Down Expand Up @@ -243,6 +244,7 @@ size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
TVM_DEFINE_ATTRS_BINOP_HASH(Add);
TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
TVM_DEFINE_ATTRS_BINOP_HASH(Div);
TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
TVM_DEFINE_ATTRS_BINOP_HASH(Max);
TVM_DEFINE_ATTRS_BINOP_HASH(Min);
Expand Down
97 changes: 97 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h>
#include <tvm/ir.h>
#include <vector>
#include "../op_common.h"


namespace tvm {
namespace relay {
using ir::IntImm;

// relay.cast
TVM_REGISTER_NODE_TYPE(CastAttrs);
Expand Down Expand Up @@ -834,5 +836,100 @@ RELAY_REGISTER_OP("broadcast_to_like")
.set_support_level(10)
.add_type_rel("BroadCastToLike", BroadCastToLikeRel);

// Split
TVM_REGISTER_NODE_TYPE(SplitAttrs);

bool SplitRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
const auto param = attrs.as<SplitAttrs>();
CHECK(param != nullptr);
auto axis = param->axis;
if (axis < 0) {
axis += data->shape.size();
}
CHECK_LT(axis, data->shape.size())
<< "axis should be within the input dimension range.";
CHECK_GT(axis, 0)
<< "axis should be within the input dimension range.";

if (const IntImm* sections = param->indices_or_sections.as<IntImm>()) {
CHECK(reporter->Assert(data->shape[axis] %
sections->value == make_zero(Int(64))))
<< "indices_or_sections need to be able to divide input.shape[axis]";
std::vector<Type> fields;
for (int i = 0; i < sections->value; ++i) {
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
oshape[axis] /= int32_t(sections->value);
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
}
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
} else {
auto indices = param->indices_or_sections.as<ArrayNode>()->data;
auto begin = IndexExpr(make_zero(Int(32)));
std::vector<Type> fields;
for (uint i = 0; i < indices.size(); ++i) {
CHECK(reporter->Assert(IndexExpr(indices[i]) > begin))
<< "indices_or_sections need to be a sorted ascending list";
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
oshape[axis] = IndexExpr(indices[i]) - begin;
begin = IndexExpr(indices[i]);
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
}
CHECK(reporter->Assert(begin < data->shape[axis]))
<< "The sum of sections must match the input.shape[axis]";
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
oshape[axis] = data->shape[axis] - begin;
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
}
return true;
}

Expr MakeSplit(Expr data,
NodeRef indices_or_sections,
int axis) {
auto attrs = make_node<SplitAttrs>();
attrs->axis = axis;
attrs->indices_or_sections = std::move(indices_or_sections);
static const Op& op = Op::Get("split");
return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.split")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
if (args.type_codes[1] == kDLInt) {
*rv = MakeSplit(args[0], make_const(Int(64), int64_t(args[1])), args[2]);
} else {
*rv = MakeSplit(args[0], args[1], args[2]);
}
});

RELAY_REGISTER_OP("split")
.describe(R"code(Splits an array along a particular axis into multiple sub-arrays.

Indices or sections to split into. Accepts an int or a tuple
If indices_or_sections is an integer, the input will be divided equally
along given axis. If such a split is not possible, an error is raised.

If indices_or_sections is a tuple of sorted integers,
the entries indicate where along axis the array is split.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SplitAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Split", SplitRel);

} // namespace relay
} // namespace tvm
33 changes: 33 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,38 @@ def verify_take(dshape, indices_shape, oshape, axis=None):
verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1)
verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)

def test_split_infer_type():
def verify_split(dshape, indices_or_sections, ret_type, axis=None):
x = relay.var("x", relay.ty.TensorType(dshape, "float32"))
y = relay.split(x, indices_or_sections, axis=axis)
y.astext()
yy = relay.ir_pass.infer_type(y.astuple())
assert yy.checked_type == ret_type

d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
axis = tvm.var("axis")
verify_split((5, 5, 2, 2), 5,
relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((5, 1, 2, 2), "float32"),
relay.ty.TensorType((5, 1, 2, 2), "float32"),
relay.ty.TensorType((5, 1, 2, 2), "float32"),
relay.ty.TensorType((5, 1, 2, 2), "float32"),
relay.ty.TensorType((5, 1, 2, 2), "float32")])),
axis=1)
verify_split((d1, d2, d3, d4), 4,
relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])),
axis=2)
verify_split((d1, d2, d3, d4), (2, 4, 7),
relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((d1, 2, d3, d4), "float32"),
relay.ty.TensorType((d1, 2, d3, d4), "float32"),
relay.ty.TensorType((d1, 3, d3, d4), "float32"),
relay.ty.TensorType((d1, (d2-7), d3, d4), "float32")])),
axis=1)

def test_full():
# default settings: match input dtype
Expand Down Expand Up @@ -161,3 +193,4 @@ def test_infer_type_leaky_relu():
test_infer_type_leaky_relu()
test_squeeze_infer_type()
test_squeeze_bad_axes_infer_type()
test_split_infer_type()