From 14c3e68c527a61cdf1fe84c5a19f6eed4e77ae96 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Wed, 10 Oct 2018 14:28:47 +0530 Subject: [PATCH 01/11] [RELAY][OP] Split --- include/tvm/relay/attrs/transform.h | 15 +++++ nnvm/src/top/tensor/transform.cc | 2 +- python/tvm/relay/op/transform.py | 32 +++++++++- src/relay/op/tensor/transform.cc | 89 ++++++++++++++++++++++++++++ tests/python/relay/test_op_level3.py | 37 ++++++++++++ 5 files changed, 173 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index b0150c4ac3d9..4493050b3ba9 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -106,6 +106,21 @@ struct SqueezeAttrs : public tvm::AttrsNode { } }; // struct SqueezeAttrs +struct SplitAttrs : public tvm::AttrsNode { + Array indices_or_sections; + int axis; + bool equal_split; + + TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { + TVM_ATTR_FIELD(indices_or_sections) + .describe("Number of outputs to be splitted"); + TVM_ATTR_FIELD(axis).set_lower_bound(0).set_default(1) + .describe("the axis to be splitted."); + TVM_ATTR_FIELD(equal_split).set_default(false) + .describe("Is it equal split of input"); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index a8159b539410..8e35039a8085 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -427,7 +427,7 @@ along which to split the array. return Array{ topi::split(inputs[0], indices, param.axis) }; } }) -.set_support_level(1); +.set_support_level(3); // cast DMLC_REGISTER_PARAMETER(CastParam); diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 84e2398f0a9e..198d22350cd0 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -146,7 +146,7 @@ def take(data, indices, axis=None): Parameters ---------- - a : relay.Expr + data : relay.Expr The source array. indices : rely.Expr @@ -280,3 +280,33 @@ def collapse_sum_like(data, collapse_type): The resulting tensor. """ return _make.collapse_sum_like(data, collapse_type) + + +def split(data, indices_or_sections, axis=0): + """Split input tensor along axis by sections or indices. + + If indices_or_sections is an integer, the input will be divided equally + along given axis. If such a split is not possible, an error is raised. + + If indices_or_sections is a tuple of sorted integers, + the entries indicate where along axis the array is split. + + Parameters + ---------- + data : relay.Expr + The source array. + + indices_or_sections : int or tuple of int + Indices or sections to split into. Accepts an int or a tuple + + axis : int, optional + The axis over which to split. + + Returns + ------- + ret : relay.Tuple([relay.Expr, relay.Expr]) + The computed result. + """ + ret_size = indices_or_sections if isinstance(indices_or_sections, int) + else len(indices_or_sections)+1 + return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 29dff1e4ba27..e9eac331d028 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -834,5 +834,94 @@ RELAY_REGISTER_OP("broadcast_to_like") .set_support_level(10) .add_type_rel("BroadCastToLike", BroadCastToLikeRel); +// Split +TVM_REGISTER_NODE_TYPE(SplitAttrs); + +bool SplitRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, result] + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + CHECK(data != nullptr); + CHECK_NE(data->shape.size(), 0); + const auto param = attrs.as(); + CHECK(param != nullptr); + + if (param->equal_split) { + const auto num_outputs = as_const_int(param->indices_or_sections[0]); + CHECK_LT(param->axis, data->shape.size()); + // CHECK(reporter->Assert(data->shape[param->axis] % + // param->indices_or_sections[0] == make_zero(Int(64)))) + // << "indices_or_sections need to be able to divide input.shape[axis]"; + + std::vector fields; + for (int i = 0; i < *num_outputs; ++i) { + std::vector&& oshape = AsVector(data->shape); + oshape[param->axis] /= param->indices_or_sections[0]; + auto vec_type = TensorTypeNode::make(oshape, data->dtype); + fields.push_back(vec_type); + } + reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); + } else { + const auto num_outputs = param->indices_or_sections.size() + 1; + CHECK_LT(param->axis, data->shape.size()); + auto begin = make_zero(Int(32)); + std::vector fields; + for (uint i = 0; i < num_outputs - 1; ++i) { + // CHECK(reporter->Assert(param->indices_or_sections[i] > begin)) + // << "indices_or_sections need to be a sorted ascending list"; + std::vector&& oshape = AsVector(data->shape); + oshape[param->axis] = param->indices_or_sections[i] - begin; + begin = param->indices_or_sections[i]; + auto vec_type = TensorTypeNode::make(oshape, data->dtype); + fields.push_back(vec_type); + } + // CHECK(reporter->Assert(begin < data->shape[param->axis])) + // << "The sum of sections must match the input.shape[axis]"; + std::vector&& oshape = AsVector(data->shape); + oshape[param->axis] = data->shape[param->axis] - begin; + auto vec_type = TensorTypeNode::make(oshape, data->dtype); + fields.push_back(vec_type); + + reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); + } + return true; +} + +Expr MakeSplit(Expr data, + Array indices_or_sections, + int axis, + bool equal_split) { + auto attrs = make_node(); + attrs->axis = axis; + attrs->indices_or_sections = std::move(indices_or_sections); + attrs->equal_split = equal_split; + static const Op& op = Op::Get("split"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.split") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeSplit, args, rv); +}); + +RELAY_REGISTER_OP("split") +.describe(R"code(Splits an array along a particular axis into multiple sub-arrays. + +While equal_split is true `indices_or_sections` should be of size 1 and it indicates +number of sections to solit into and the dimension along given axis should be a +multiple of indices_or_section[0]. + +With equal_split being false indices_or_section ia an ascending ordered list with in 0 +and dimention of given axis. Here the input is split at the given indices. + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(3) +.add_type_rel("Split", SplitRel); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 8ab3c41c079d..b77e5be19a80 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -107,6 +107,42 @@ def verify_take(dshape, indices_shape, oshape, axis=None): verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1) verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2) +def test_split_infer_type(): + def verify_split(dshape, indices_or_sections, ret_type, axis=None, equal_split=True): + ib = relay.ir_builder.IRBuilder() + x = ib.param("x", relay.ty.TensorType(dshape, "float32")) + with ib.function(x) as func: + ib.ret(relay.split(x, indices_or_sections, axis=axis, equal_split=equal_split)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == ret_type + + d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") + axis = tvm.var("axis") + verify_split((5, 4, 2, 2), (4,), + relay.ty.TupleType(tvm.convert([ + relay.ty.TensorType((5, 1, 2, 2), "float32"), + relay.ty.TensorType((5, 1, 2, 2), "float32"), + relay.ty.TensorType((5, 1, 2, 2), "float32"), + relay.ty.TensorType((5, 1, 2, 2), "float32")])), + axis=1, equal_split=True) + + verify_split((d1, d2, d3, d4), (4,), + relay.ty.TupleType(tvm.convert([ + relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), + relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), + relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), + relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])), + axis=2, equal_split=True) + + verify_split((d1, d2, d3, d4), (2, 4, 7), + relay.ty.TupleType(tvm.convert([ + relay.ty.TensorType((d1, 2, d3, d4), "float32"), + relay.ty.TensorType((d1, 2, d3, d4), "float32"), + relay.ty.TensorType((d1, 3, d3, d4), "float32"), + relay.ty.TensorType((d1, (d2-7), d3, d4), "float32")])), + axis=1, equal_split=False) def test_full(): # default settings: match input dtype @@ -161,3 +197,4 @@ def test_infer_type_leaky_relu(): test_infer_type_leaky_relu() test_squeeze_infer_type() test_squeeze_bad_axes_infer_type() + test_split_infer_type() From 584c326563b2d54903ea2acf5519a37286b6e31d Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Thu, 11 Oct 2018 10:40:57 +0530 Subject: [PATCH 02/11] * Review comments. --- include/tvm/relay/attrs/transform.h | 2 +- src/relay/op/tensor/transform.cc | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 4493050b3ba9..da009d7ce5c2 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -114,7 +114,7 @@ struct SplitAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { TVM_ATTR_FIELD(indices_or_sections) .describe("Number of outputs to be splitted"); - TVM_ATTR_FIELD(axis).set_lower_bound(0).set_default(1) + TVM_ATTR_FIELD(axis).set_default(0) .describe("the axis to be splitted."); TVM_ATTR_FIELD(equal_split).set_default(false) .describe("Is it equal split of input"); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index e9eac331d028..639056dbce4f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -845,43 +845,49 @@ bool SplitRel(const Array& types, CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); CHECK(data != nullptr); - CHECK_NE(data->shape.size(), 0); + CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty"; const auto param = attrs.as(); CHECK(param != nullptr); + auto axis = param->axis; + if (axis < 0) { + axis += data->shape.size(); + } + CHECK_LT(axis, data->shape.size()) + << "axis should be within the input dimension range."; + CHECK_GT(axis, 0) + << "axis should be within the input dimension range."; + if (param->equal_split) { const auto num_outputs = as_const_int(param->indices_or_sections[0]); - CHECK_LT(param->axis, data->shape.size()); - // CHECK(reporter->Assert(data->shape[param->axis] % + // CHECK(reporter->Assert(data->shape[axis] % // param->indices_or_sections[0] == make_zero(Int(64)))) // << "indices_or_sections need to be able to divide input.shape[axis]"; - std::vector fields; for (int i = 0; i < *num_outputs; ++i) { std::vector&& oshape = AsVector(data->shape); - oshape[param->axis] /= param->indices_or_sections[0]; + oshape[axis] /= param->indices_or_sections[0]; auto vec_type = TensorTypeNode::make(oshape, data->dtype); fields.push_back(vec_type); } reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); } else { const auto num_outputs = param->indices_or_sections.size() + 1; - CHECK_LT(param->axis, data->shape.size()); auto begin = make_zero(Int(32)); std::vector fields; for (uint i = 0; i < num_outputs - 1; ++i) { // CHECK(reporter->Assert(param->indices_or_sections[i] > begin)) // << "indices_or_sections need to be a sorted ascending list"; std::vector&& oshape = AsVector(data->shape); - oshape[param->axis] = param->indices_or_sections[i] - begin; + oshape[axis] = param->indices_or_sections[i] - begin; begin = param->indices_or_sections[i]; auto vec_type = TensorTypeNode::make(oshape, data->dtype); fields.push_back(vec_type); } - // CHECK(reporter->Assert(begin < data->shape[param->axis])) + // CHECK(reporter->Assert(begin < data->shape[axis])) // << "The sum of sections must match the input.shape[axis]"; std::vector&& oshape = AsVector(data->shape); - oshape[param->axis] = data->shape[param->axis] - begin; + oshape[axis] = data->shape[axis] - begin; auto vec_type = TensorTypeNode::make(oshape, data->dtype); fields.push_back(vec_type); From 60a8e364823374c7733992fbab5f498ac12f4629 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Fri, 12 Oct 2018 10:32:15 +0530 Subject: [PATCH 03/11] * Rebase & review comments addressed. --- src/relay/op/tensor/transform.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 639056dbce4f..9a151f9d6649 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -860,9 +860,9 @@ bool SplitRel(const Array& types, if (param->equal_split) { const auto num_outputs = as_const_int(param->indices_or_sections[0]); - // CHECK(reporter->Assert(data->shape[axis] % - // param->indices_or_sections[0] == make_zero(Int(64)))) - // << "indices_or_sections need to be able to divide input.shape[axis]"; + CHECK(reporter->Assert(data->shape[axis] % + param->indices_or_sections[0] == make_zero(Int(64)))) + << "indices_or_sections need to be able to divide input.shape[axis]"; std::vector fields; for (int i = 0; i < *num_outputs; ++i) { std::vector&& oshape = AsVector(data->shape); @@ -876,16 +876,16 @@ bool SplitRel(const Array& types, auto begin = make_zero(Int(32)); std::vector fields; for (uint i = 0; i < num_outputs - 1; ++i) { - // CHECK(reporter->Assert(param->indices_or_sections[i] > begin)) - // << "indices_or_sections need to be a sorted ascending list"; + CHECK(reporter->Assert(param->indices_or_sections[i] > begin)) + << "indices_or_sections need to be a sorted ascending list"; std::vector&& oshape = AsVector(data->shape); oshape[axis] = param->indices_or_sections[i] - begin; begin = param->indices_or_sections[i]; auto vec_type = TensorTypeNode::make(oshape, data->dtype); fields.push_back(vec_type); } - // CHECK(reporter->Assert(begin < data->shape[axis])) - // << "The sum of sections must match the input.shape[axis]"; + CHECK(reporter->Assert(begin < data->shape[axis])) + << "The sum of sections must match the input.shape[axis]"; std::vector&& oshape = AsVector(data->shape); oshape[axis] = data->shape[axis] - begin; auto vec_type = TensorTypeNode::make(oshape, data->dtype); From a72312e28a9aa48c26eccf6d59cc6d2a284132ae Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Fri, 19 Oct 2018 18:52:21 +0530 Subject: [PATCH 04/11] * Review comments --- python/tvm/relay/op/transform.py | 1 + tests/python/relay/test_op_level3.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 198d22350cd0..02de715cec1d 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1,6 +1,7 @@ """Transform operators.""" from . import _make +from ..expr import TupleWrapper def expand_dims(data, axis, num_newaxis=1): diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index b77e5be19a80..5db817df51b2 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -120,11 +120,12 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None, equal_split=T d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") axis = tvm.var("axis") - verify_split((5, 4, 2, 2), (4,), + verify_split((5, 5, 2, 2), (5,), relay.ty.TupleType(tvm.convert([ relay.ty.TensorType((5, 1, 2, 2), "float32"), relay.ty.TensorType((5, 1, 2, 2), "float32"), relay.ty.TensorType((5, 1, 2, 2), "float32"), + relay.ty.TensorType((5, 1, 2, 2), "float32"), relay.ty.TensorType((5, 1, 2, 2), "float32")])), axis=1, equal_split=True) From c2fdf065ba83c0327857af9e5fd3d0b3c981f45c Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Sun, 21 Oct 2018 13:06:22 +0530 Subject: [PATCH 05/11] * Rebase. --- src/relay/op/tensor/transform.cc | 1 + tests/python/relay/test_op_level3.py | 12 ++++-------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9a151f9d6649..d1c9a0f3b538 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -924,6 +924,7 @@ With equal_split being false indices_or_section ia an ascending ordered list wit and dimention of given axis. Here the input is split at the given indices. )code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.SplitAttrs") .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(3) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 5db817df51b2..1d8612288c0c 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -109,14 +109,10 @@ def verify_take(dshape, indices_shape, oshape, axis=None): def test_split_infer_type(): def verify_split(dshape, indices_or_sections, ret_type, axis=None, equal_split=True): - ib = relay.ir_builder.IRBuilder() - x = ib.param("x", relay.ty.TensorType(dshape, "float32")) - with ib.function(x) as func: - ib.ret(relay.split(x, indices_or_sections, axis=axis, equal_split=equal_split)) - ib.ret(func) - func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type - assert ftype.ret_type == ret_type + x = relay.var("x", relay.ty.TensorType(dshape, "float32")) + y = relay.split(x, indices_or_sections, axis=axis, equal_split=equal_split) + yy = relay.ir_pass.infer_type(y) + assert yy.checked_type == ret_type d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") axis = tvm.var("axis") From 9fbb9ee6e165f45ed9f7084305a8cdb76b35536f Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Mon, 22 Oct 2018 17:37:41 +0530 Subject: [PATCH 06/11] * review comments. Use NodeRef and keep numpy consistency. --- include/tvm/relay/attrs/transform.h | 11 ++++--- python/tvm/relay/op/transform.py | 8 +++-- src/relay/op/tensor/transform.cc | 48 +++++++++++++++------------- tests/python/relay/test_op_level3.py | 16 ++++------ 4 files changed, 44 insertions(+), 39 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index da009d7ce5c2..dfad1013701f 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -107,17 +107,18 @@ struct SqueezeAttrs : public tvm::AttrsNode { }; // struct SqueezeAttrs struct SplitAttrs : public tvm::AttrsNode { - Array indices_or_sections; + NodeRef indices_or_sections; int axis; - bool equal_split; TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { TVM_ATTR_FIELD(indices_or_sections) - .describe("Number of outputs to be splitted"); + .describe("Indices or sections to split into. Accepts an int or a tuple" + "If indices_or_sections is an integer, the input will be divided equally" + "along given axis. If such a split is not possible, an error is raised." + "If indices_or_sections is a tuple of sorted integers," + "the entries indicate where along axis the array is split."); TVM_ATTR_FIELD(axis).set_default(0) .describe("the axis to be splitted."); - TVM_ATTR_FIELD(equal_split).set_default(false) - .describe("Is it equal split of input"); } }; diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 02de715cec1d..3cf139c7dd86 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -300,7 +300,7 @@ def split(data, indices_or_sections, axis=0): indices_or_sections : int or tuple of int Indices or sections to split into. Accepts an int or a tuple - axis : int, optional + axis : int, optional The axis over which to split. Returns @@ -308,6 +308,8 @@ def split(data, indices_or_sections, axis=0): ret : relay.Tuple([relay.Expr, relay.Expr]) The computed result. """ - ret_size = indices_or_sections if isinstance(indices_or_sections, int) - else len(indices_or_sections)+1 + if isinstance(indices_or_sections, int): + ret_size = indices_or_sections + else: + ret_size = len(indices_or_sections) + 1 return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index d1c9a0f3b538..256798065a42 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -6,12 +6,14 @@ #include #include #include +#include #include #include "../op_common.h" namespace tvm { namespace relay { +using ir::IntImm; // relay.cast TVM_REGISTER_NODE_TYPE(CastAttrs); @@ -848,7 +850,6 @@ bool SplitRel(const Array& types, CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty"; const auto param = attrs.as(); CHECK(param != nullptr); - auto axis = param->axis; if (axis < 0) { axis += data->shape.size(); @@ -858,29 +859,31 @@ bool SplitRel(const Array& types, CHECK_GT(axis, 0) << "axis should be within the input dimension range."; - if (param->equal_split) { - const auto num_outputs = as_const_int(param->indices_or_sections[0]); + if (param->indices_or_sections.as()) { + const auto sections = make_const(Int(32), + param->indices_or_sections.as()->value); CHECK(reporter->Assert(data->shape[axis] % - param->indices_or_sections[0] == make_zero(Int(64)))) + sections == make_zero(Int(64)))) << "indices_or_sections need to be able to divide input.shape[axis]"; std::vector fields; - for (int i = 0; i < *num_outputs; ++i) { + for (int i = 0; i < *as_const_int(sections); ++i) { std::vector&& oshape = AsVector(data->shape); - oshape[axis] /= param->indices_or_sections[0]; + oshape[axis] /= sections; auto vec_type = TensorTypeNode::make(oshape, data->dtype); fields.push_back(vec_type); } reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); } else { - const auto num_outputs = param->indices_or_sections.size() + 1; - auto begin = make_zero(Int(32)); + auto indices = param->indices_or_sections.as()->data; + const auto num_outputs = indices.size() + 1; + auto begin = IndexExpr(make_zero(Int(32))); std::vector fields; for (uint i = 0; i < num_outputs - 1; ++i) { - CHECK(reporter->Assert(param->indices_or_sections[i] > begin)) + CHECK(reporter->Assert(IndexExpr(indices[i]) > begin)) << "indices_or_sections need to be a sorted ascending list"; std::vector&& oshape = AsVector(data->shape); - oshape[axis] = param->indices_or_sections[i] - begin; - begin = param->indices_or_sections[i]; + oshape[axis] = IndexExpr(indices[i]) - begin; + begin = IndexExpr(indices[i]); auto vec_type = TensorTypeNode::make(oshape, data->dtype); fields.push_back(vec_type); } @@ -890,38 +893,39 @@ bool SplitRel(const Array& types, oshape[axis] = data->shape[axis] - begin; auto vec_type = TensorTypeNode::make(oshape, data->dtype); fields.push_back(vec_type); - reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); } return true; } Expr MakeSplit(Expr data, - Array indices_or_sections, - int axis, - bool equal_split) { + NodeRef indices_or_sections, + int axis) { auto attrs = make_node(); attrs->axis = axis; attrs->indices_or_sections = std::move(indices_or_sections); - attrs->equal_split = equal_split; static const Op& op = Op::Get("split"); return CallNode::make(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_API("relay.op._make.split") .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeSplit, args, rv); + if (args.type_codes[1] == kDLInt) { + *rv = MakeSplit(args[0], make_const(Int(64), int64_t(args[1])), args[2]); + } else { + *rv = MakeSplit(args[0], args[1], args[2]); + } }); RELAY_REGISTER_OP("split") .describe(R"code(Splits an array along a particular axis into multiple sub-arrays. -While equal_split is true `indices_or_sections` should be of size 1 and it indicates -number of sections to solit into and the dimension along given axis should be a -multiple of indices_or_section[0]. +Indices or sections to split into. Accepts an int or a tuple +If indices_or_sections is an integer, the input will be divided equally +along given axis. If such a split is not possible, an error is raised. -With equal_split being false indices_or_section ia an ascending ordered list with in 0 -and dimention of given axis. Here the input is split at the given indices. +If indices_or_sections is a tuple of sorted integers, +the entries indicate where along axis the array is split. )code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.SplitAttrs") diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 1d8612288c0c..b5d90b9b24e2 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -108,38 +108,36 @@ def verify_take(dshape, indices_shape, oshape, axis=None): verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2) def test_split_infer_type(): - def verify_split(dshape, indices_or_sections, ret_type, axis=None, equal_split=True): + def verify_split(dshape, indices_or_sections, ret_type, axis=None): x = relay.var("x", relay.ty.TensorType(dshape, "float32")) - y = relay.split(x, indices_or_sections, axis=axis, equal_split=equal_split) + y = relay.split(x, indices_or_sections, axis=axis) yy = relay.ir_pass.infer_type(y) assert yy.checked_type == ret_type d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") axis = tvm.var("axis") - verify_split((5, 5, 2, 2), (5,), + verify_split((5, 5, 2, 2), 5, relay.ty.TupleType(tvm.convert([ relay.ty.TensorType((5, 1, 2, 2), "float32"), relay.ty.TensorType((5, 1, 2, 2), "float32"), relay.ty.TensorType((5, 1, 2, 2), "float32"), relay.ty.TensorType((5, 1, 2, 2), "float32"), relay.ty.TensorType((5, 1, 2, 2), "float32")])), - axis=1, equal_split=True) - - verify_split((d1, d2, d3, d4), (4,), + axis=1) + verify_split((d1, d2, d3, d4), 4, relay.ty.TupleType(tvm.convert([ relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])), - axis=2, equal_split=True) - + axis=2) verify_split((d1, d2, d3, d4), (2, 4, 7), relay.ty.TupleType(tvm.convert([ relay.ty.TensorType((d1, 2, d3, d4), "float32"), relay.ty.TensorType((d1, 2, d3, d4), "float32"), relay.ty.TensorType((d1, 3, d3, d4), "float32"), relay.ty.TensorType((d1, (d2-7), d3, d4), "float32")])), - axis=1, equal_split=False) + axis=1) def test_full(): # default settings: match input dtype From 5ac468a5b5a1908ee2e38682d261d1828ed5d28c Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Tue, 23 Oct 2018 11:17:05 +0530 Subject: [PATCH 07/11] * Review comments. --- src/relay/op/tensor/transform.cc | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 256798065a42..d7b4980f80b2 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -859,26 +859,23 @@ bool SplitRel(const Array& types, CHECK_GT(axis, 0) << "axis should be within the input dimension range."; - if (param->indices_or_sections.as()) { - const auto sections = make_const(Int(32), - param->indices_or_sections.as()->value); + if (const IntImm* sections = param->indices_or_sections.as()) { CHECK(reporter->Assert(data->shape[axis] % - sections == make_zero(Int(64)))) + sections->value == make_zero(Int(64)))) << "indices_or_sections need to be able to divide input.shape[axis]"; std::vector fields; - for (int i = 0; i < *as_const_int(sections); ++i) { + for (int i = 0; i < sections->value; ++i) { std::vector&& oshape = AsVector(data->shape); - oshape[axis] /= sections; + oshape[axis] /= int32_t(sections->value); auto vec_type = TensorTypeNode::make(oshape, data->dtype); fields.push_back(vec_type); } reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); } else { auto indices = param->indices_or_sections.as()->data; - const auto num_outputs = indices.size() + 1; auto begin = IndexExpr(make_zero(Int(32))); std::vector fields; - for (uint i = 0; i < num_outputs - 1; ++i) { + for (uint i = 0; i < indices.size(); ++i) { CHECK(reporter->Assert(IndexExpr(indices[i]) > begin)) << "indices_or_sections need to be a sorted ascending list"; std::vector&& oshape = AsVector(data->shape); From 3bd3f8e9cbca428bd1bac8da578ab58cd9842f85 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Wed, 24 Oct 2018 06:52:04 +0530 Subject: [PATCH 08/11] * Text printer testcase added. --- python/tvm/relay/expr.py | 11 +++++++++++ tests/python/relay/test_op_level3.py | 1 + 2 files changed, 12 insertions(+) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 655379066c74..0650a493d9a6 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -5,6 +5,7 @@ import numpy as _np from .base import RelayNode, register_relay_node from . import _make +from . import _expr from . import ty as _ty from .._ffi import base as _base from .. import nd as _nd @@ -284,6 +285,16 @@ def astuple(self): as an argument to an FFI function.""" return self.tuple_value + def astext(self): + """Get the text format of the tuple expression. + + Returns + ------- + text : str + The text format of the tuple expression. + """ + return _expr._text_print(self.tuple_value) + def __getitem__(self, index): if index >= len(self): raise IndexError("Tuple index out of range") diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index b5d90b9b24e2..518eb39c47f6 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -111,6 +111,7 @@ def test_split_infer_type(): def verify_split(dshape, indices_or_sections, ret_type, axis=None): x = relay.var("x", relay.ty.TensorType(dshape, "float32")) y = relay.split(x, indices_or_sections, axis=axis) + y.astext() yy = relay.ir_pass.infer_type(y) assert yy.checked_type == ret_type From 8cd8c0632f21b960750e01d39074736553e81bc4 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Wed, 24 Oct 2018 11:31:14 +0530 Subject: [PATCH 09/11] * Alpha equal for Div. --- src/lang/attr_functor.h | 4 ++++ src/lang/attrs.cc | 2 ++ 2 files changed, 6 insertions(+) diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index ef1d061015c3..9257ad3b5490 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -64,6 +64,7 @@ class AttrFunctor { virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT; @@ -96,6 +97,7 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(Add); ATTR_FUNCTOR_DISPATCH(Sub); ATTR_FUNCTOR_DISPATCH(Mul); + ATTR_FUNCTOR_DISPATCH(Div); ATTR_FUNCTOR_DISPATCH(Min); ATTR_FUNCTOR_DISPATCH(Max); ATTR_FUNCTOR_DISPATCH(GE); @@ -135,6 +137,7 @@ class AttrsEqualHandler : bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final; @@ -174,6 +177,7 @@ class AttrsHashHandler : size_t VisitAttr_(const ir::Add* op) final; size_t VisitAttr_(const ir::Sub* op) final; size_t VisitAttr_(const ir::Mul* op) final; + size_t VisitAttr_(const ir::Div* op) final; size_t VisitAttr_(const ir::Mod* op) final; size_t VisitAttr_(const ir::Min* op) final; size_t VisitAttr_(const ir::Max* op) final; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 9aa067c09679..3b273f4939ef 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -132,6 +132,7 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) TVM_DEFINE_ATTRS_BINOP_EQUAL(Add); TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub); TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul); +TVM_DEFINE_ATTRS_BINOP_EQUAL(Div); TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod); TVM_DEFINE_ATTRS_BINOP_EQUAL(Max); TVM_DEFINE_ATTRS_BINOP_EQUAL(Min); @@ -243,6 +244,7 @@ size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) { TVM_DEFINE_ATTRS_BINOP_HASH(Add); TVM_DEFINE_ATTRS_BINOP_HASH(Sub); TVM_DEFINE_ATTRS_BINOP_HASH(Mul); +TVM_DEFINE_ATTRS_BINOP_HASH(Div); TVM_DEFINE_ATTRS_BINOP_HASH(Mod); TVM_DEFINE_ATTRS_BINOP_HASH(Max); TVM_DEFINE_ATTRS_BINOP_HASH(Min); From d8889f63c2f3942257bd0de53dddf6bb9fadd862 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Wed, 24 Oct 2018 22:34:34 +0530 Subject: [PATCH 10/11] * docs. --- docs/langref/relay_op.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 42883f5f77da..11fb282abac5 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -94,6 +94,7 @@ This level enables additional math and transform operators. tvm.relay.full tvm.relay.full_like tvm.relay.cast + tvm.relay.split **Level 4: Broadcast and Reductions** @@ -198,6 +199,7 @@ Level 3 Definitions .. autofunction:: tvm.relay.full .. autofunction:: tvm.relay.full_like .. autofunction:: tvm.relay.cast +.. autofunction:: tvm.relay.split Level 4 Definitions From 09cbe730faf05390a17cd7d287bc378c3cd6468f Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Fri, 26 Oct 2018 06:41:05 +0530 Subject: [PATCH 11/11] * CI error --- tests/python/relay/test_op_level3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 518eb39c47f6..804d3c46ca36 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -112,7 +112,7 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None): x = relay.var("x", relay.ty.TensorType(dshape, "float32")) y = relay.split(x, indices_or_sections, axis=axis) y.astext() - yy = relay.ir_pass.infer_type(y) + yy = relay.ir_pass.infer_type(y.astuple()) assert yy.checked_type == ret_type d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")