From 58603324b55c5fa9267c776b7dfd8748aac6fc37 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Fri, 12 Oct 2018 10:25:52 +0530 Subject: [PATCH 1/6] [RELAY][OP]Strided slice --- docs/langref/relay_op.rst | 2 + include/tvm/relay/attrs/transform.h | 13 +++ python/tvm/relay/op/transform.py | 27 +++++++ src/relay/op/tensor/transform.cc | 116 +++++++++++++++++++++++++++ tests/python/relay/test_op_level4.py | 21 +++++ 5 files changed, 179 insertions(+) diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 405f071e3283..e99ac3c97f73 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -123,6 +123,7 @@ This level enables additional math and transform operators. tvm.relay.min tvm.relay.mean tvm.relay.prod + tvm.relay.strided_slice **Level 5: Vision/Image Operators** @@ -227,6 +228,7 @@ Level 4 Definitions .. autofunction:: tvm.relay.min .. autofunction:: tvm.relay.mean .. autofunction:: tvm.relay.prod +.. autofunction:: tvm.relay.strided_slice diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index cb87d358e966..2465777a6dc4 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -120,6 +120,19 @@ struct SplitAttrs : public tvm::AttrsNode { "the entries indicate where along axis the array is split."); TVM_ATTR_FIELD(axis).set_default(0) .describe("the axis to be splitted."); +/*! 
\brief Attributes for StridedSlice operator */ +struct StridedSliceAttrs : public tvm::AttrsNode { + Array begin; + Array end; + Array stride; + + TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { + TVM_ATTR_FIELD(begin) + .describe("Indices for begin of slice"); + TVM_ATTR_FIELD(end) + .describe("Indices for end of the slice"); + TVM_ATTR_FIELD(stride).set_default(Array({})) + .describe("Stride values of the slice"); } }; diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 909b175f08ca..960526acc194 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -334,3 +334,30 @@ def split(data, indices_or_sections, axis=0): else: ret_size = len(indices_or_sections) + 1 return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) + +def strided_slice(data, begin, end, stride=None): + """Strided slice of an array.. + + Parameters + ---------- + data : relay.Expr + The source array to be sliced. + + begin: list of int + The indices to begin with in the slicing. + + end: list of int + Indicies indicating end of the slice. + + stride: list of int, optional + Specifies the stride values, it can be negative in that case, + the input tensor will be reversed in that particular axis. + + Returns + ------- + ret : relay.Expr + The computed result. 
+ """ + stride = stride or [] + return _make.strided_slice(data, list(begin), list(end), list(stride)) +>>>>>>> [RELAY][OP]Strided slice diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 20e0e3adbfd3..349399456ebe 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -889,6 +889,122 @@ RELAY_REGISTER_OP("broadcast_to_like") .add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.") .set_support_level(10) .add_type_rel("BroadCastToLike", BroadCastToLikeRel); +// strided_slice +TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); +bool StridedSliceRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + CHECK(data != nullptr); + if (data->shape.size() == 0) return false; + + const StridedSliceAttrs *param = attrs.as(); + CHECK(param != nullptr); + + auto dshape = data->shape; + auto num_axis = dshape.size(); + + std::vector begin_vec; + for (auto i : param->begin) { + begin_vec.push_back(i); + } + for (auto i = begin_vec.size(); i < num_axis; ++i) { + begin_vec.push_back(0); + } + + std::vector end_vec; + for (auto i : param->end) { + end_vec.push_back(i); + } + for (auto i = end_vec.size(); i < num_axis; ++i) { + end_vec.push_back(dshape[i]); + } + + std::vector stride_vec; + for (auto i : param->stride) { + stride_vec.push_back(i); + } + for (auto i = stride_vec.size(); i < num_axis; ++i) { + stride_vec.push_back(1); + } + std::vector oshape(dshape.size()); + + #define MAX(a, b) (reporter->Assert((a) > (b)) ? (a) : (b)) + #define MIN(a, b) (reporter->Assert((a) < (b)) ? (a) : (b)) + + for (size_t i = 0; i < num_axis; ++i) { + auto begin_range = reporter->Assert(stride_vec[i] < 0) ? -1 : 0; + auto end_range = reporter->Assert(stride_vec[i] < 0) ? dshape[i] - 1 : dshape[i]; + auto begin = reporter->Assert(begin_vec[i] < 0) ? 
dshape[i] + begin_vec[i] : begin_vec[i]; + auto end = reporter->Assert(end_vec[i] < 0) ? dshape[i] + end_vec[i] : end_vec[i]; + + begin = MIN(MAX(begin, begin_range), end_range); + end = MIN(MAX(end, begin_range), end_range); + auto interval = abs((end - begin)); + auto slice_size = (interval + abs(stride_vec[i]) - 1) / abs(stride_vec[i]); + + CHECK(reporter->Assert(stride_vec[i] < 0) ? + reporter->Assert(end < begin) : reporter->Assert(begin < end)) + << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] + << "] is invalid for axis=" << i; + oshape[i] = slice_size; + } + + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + + +// Positional relay function to create StridedSlice operator used by frontend FFI. +Expr MakeStridedSlice(Expr data, + Array begin, + Array end, + Array stride) { + auto attrs = make_node(); + attrs->begin = std::move(begin); + attrs->end = std::move(end); + attrs->stride = std::move(stride); + static const Op& op = Op::Get("strided_slice"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op._make.strided_slice") + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeStridedSlice, args, rv); + }); + + +RELAY_REGISTER_OP("strided_slice") + .describe(R"code(Strided slice of an array. 
+ +Examples:: + + x = [[ 1., 4., 7., 10.], + [ 2., 5., 8., 11.], + [ 3., 6., 9., 12.]] + + strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4., 7., 10.], + [ 5., 8., 11.]] + + x = [[[ 1., 2.], + [ 3., 4.]], + + [[ 5., 6.], + [ 7., 8.]]] + + strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 2.], + [ 3., 4.]], + + [[ 5., 6.], + [ 7., 8.]]] +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(4) +.add_type_rel("StridedSlice", StridedSliceRel); // Split TVM_REGISTER_NODE_TYPE(SplitAttrs); diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 6fd70c386567..23a6e03b2902 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -141,6 +141,26 @@ def test_reduce_functions(): verify_reduce(func, (128, 24, 128), (0, 2), False, False, (24,)) verify_reduce(func, (128, 24, 128), (0, 1), True, False, (1, 1, 128)) verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1)) +def verify_strided_slice(data, begin, end, stride, output): + x = relay.var("x", relay.TensorType(data, "float32")) + z = relay.strided_slice(x, begin=begin, end=end, stride=stride) + zz = relay.ir_pass.infer_type(z) + assert "begin=" in z.astext() + assert "end=" in z.astext() + if stride: + assert "stride=" in z.astext() + assert zz.checked_type == relay.ty.TensorType(output, "float32") + +def test_strided_slice(): + verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3)) + verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) + verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2)) + verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3], 
None, (2, 3, 3)) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3)) + verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3)) if __name__ == "__main__": test_binary_op() @@ -148,3 +168,4 @@ def test_reduce_functions(): test_binary_int_broadcast() test_where() test_reduce_functions() + test_strided_slice() From 30390e3a928c13a7d8e6c57dca88d8fd0fcda38b Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Mon, 15 Oct 2018 16:21:13 +0530 Subject: [PATCH 2/6] Review comments --- src/relay/op/tensor/transform.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 349399456ebe..a1a6d27aed38 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -931,17 +931,14 @@ bool StridedSliceRel(const Array& types, } std::vector oshape(dshape.size()); - #define MAX(a, b) (reporter->Assert((a) > (b)) ? (a) : (b)) - #define MIN(a, b) (reporter->Assert((a) < (b)) ? (a) : (b)) - for (size_t i = 0; i < num_axis; ++i) { auto begin_range = reporter->Assert(stride_vec[i] < 0) ? -1 : 0; auto end_range = reporter->Assert(stride_vec[i] < 0) ? dshape[i] - 1 : dshape[i]; auto begin = reporter->Assert(begin_vec[i] < 0) ? dshape[i] + begin_vec[i] : begin_vec[i]; auto end = reporter->Assert(end_vec[i] < 0) ? 
dshape[i] + end_vec[i] : end_vec[i]; - begin = MIN(MAX(begin, begin_range), end_range); - end = MIN(MAX(end, begin_range), end_range); + begin = min(max(begin, begin_range), end_range); + end = min(max(end, begin_range), end_range); auto interval = abs((end - begin)); auto slice_size = (interval + abs(stride_vec[i]) - 1) / abs(stride_vec[i]); From 315df6222c8d811473dd267f54aebdaf75239ac2 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Mon, 15 Oct 2018 17:16:20 +0530 Subject: [PATCH 3/6] Description updated for begin&end strided slice --- include/tvm/relay/attrs/transform.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 2465777a6dc4..55e4d7da2b55 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -128,9 +128,9 @@ struct StridedSliceAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { TVM_ATTR_FIELD(begin) - .describe("Indices for begin of slice"); + .describe("Indices for begin of slice, begin index is also inclusive"); TVM_ATTR_FIELD(end) - .describe("Indices for end of the slice"); + .describe("Indices for end of slice, end index is also inclusive"); TVM_ATTR_FIELD(stride).set_default(Array({})) .describe("Stride values of the slice"); } From 17c18ab923ec72a03152c33d1ab0723ab7d7f6ca Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Sat, 20 Oct 2018 08:39:53 +0530 Subject: [PATCH 4/6] set_attrs_type_key and test_format testcase added --- include/tvm/relay/attrs/transform.h | 4 +++- python/tvm/relay/op/transform.py | 2 +- src/relay/op/tensor/transform.cc | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 55e4d7da2b55..564b5a9f513e 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -120,6 +120,9 @@ struct SplitAttrs : public 
tvm::AttrsNode { "the entries indicate where along axis the array is split."); TVM_ATTR_FIELD(axis).set_default(0) .describe("the axis to be splitted."); + } +}; + /*! \brief Attributes for StridedSlice operator */ struct StridedSliceAttrs : public tvm::AttrsNode { Array begin; @@ -135,7 +138,6 @@ struct StridedSliceAttrs : public tvm::AttrsNode { .describe("Stride values of the slice"); } }; - } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 960526acc194..2410f4288651 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -335,6 +335,7 @@ def split(data, indices_or_sections, axis=0): ret_size = len(indices_or_sections) + 1 return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) + def strided_slice(data, begin, end, stride=None): """Strided slice of an array.. @@ -360,4 +361,3 @@ def strided_slice(data, begin, end, stride=None): """ stride = stride or [] return _make.strided_slice(data, list(begin), list(end), list(stride)) ->>>>>>> [RELAY][OP]Strided slice diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index a1a6d27aed38..dd3f474c675b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1001,6 +1001,7 @@ Examples:: .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(4) +.set_attrs_type_key("relay.attrs.StridedSliceAttrs") .add_type_rel("StridedSlice", StridedSliceRel); // Split From 8cbe3a2f57f04031720ab323ee81e91ce89eb4e8 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Sun, 28 Oct 2018 21:17:43 +0530 Subject: [PATCH 5/6] Review comment fixed --- include/tvm/relay/attrs/transform.h | 8 +++---- python/tvm/relay/op/transform.py | 8 +++---- src/relay/op/tensor/transform.cc | 36 +++++++++++++++++++--------- tests/python/relay/test_op_level4.py | 12 +++++++--- 4 files changed, 42 insertions(+), 
22 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 564b5a9f513e..4d2008628d3a 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -125,16 +125,16 @@ struct SplitAttrs : public tvm::AttrsNode { /*! \brief Attributes for StridedSlice operator */ struct StridedSliceAttrs : public tvm::AttrsNode { - Array begin; - Array end; - Array stride; + Array begin; + Array end; + Array strides; TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { TVM_ATTR_FIELD(begin) .describe("Indices for begin of slice, begin index is also inclusive"); TVM_ATTR_FIELD(end) .describe("Indices for end of slice, end index is also inclusive"); - TVM_ATTR_FIELD(stride).set_default(Array({})) + TVM_ATTR_FIELD(strides).set_default(Array({})) .describe("Stride values of the slice"); } }; diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 2410f4288651..e43a4a573e54 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -336,7 +336,7 @@ def split(data, indices_or_sections, axis=0): return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) -def strided_slice(data, begin, end, stride=None): +def strided_slice(data, begin, end, strides=None): """Strided slice of an array.. Parameters @@ -350,7 +350,7 @@ def strided_slice(data, begin, end, stride=None): end: list of int Indicies indicating end of the slice. - stride: list of int, optional + strides: list of int, optional Specifies the stride values, it can be negative in that case, the input tensor will be reversed in that particular axis. @@ -359,5 +359,5 @@ def strided_slice(data, begin, end, stride=None): ret : relay.Expr The computed result. 
""" - stride = stride or [] - return _make.strided_slice(data, list(begin), list(end), list(stride)) + strides = strides or [] + return _make.strided_slice(data, list(begin), list(end), list(strides)) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index dd3f474c675b..31e7d2aaff84 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -889,6 +889,8 @@ RELAY_REGISTER_OP("broadcast_to_like") .add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.") .set_support_level(10) .add_type_rel("BroadCastToLike", BroadCastToLikeRel); + + // strided_slice TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); bool StridedSliceRel(const Array& types, @@ -906,7 +908,7 @@ bool StridedSliceRel(const Array& types, auto dshape = data->shape; auto num_axis = dshape.size(); - std::vector begin_vec; + std::vector begin_vec; for (auto i : param->begin) { begin_vec.push_back(i); } @@ -922,8 +924,8 @@ bool StridedSliceRel(const Array& types, end_vec.push_back(dshape[i]); } - std::vector stride_vec; - for (auto i : param->stride) { + std::vector stride_vec; + for (auto i : param->strides) { stride_vec.push_back(i); } for (auto i = stride_vec.size(); i < num_axis; ++i) { @@ -932,10 +934,22 @@ bool StridedSliceRel(const Array& types, std::vector oshape(dshape.size()); for (size_t i = 0; i < num_axis; ++i) { - auto begin_range = reporter->Assert(stride_vec[i] < 0) ? -1 : 0; - auto end_range = reporter->Assert(stride_vec[i] < 0) ? dshape[i] - 1 : dshape[i]; - auto begin = reporter->Assert(begin_vec[i] < 0) ? dshape[i] + begin_vec[i] : begin_vec[i]; - auto end = reporter->Assert(end_vec[i] < 0) ? 
dshape[i] + end_vec[i] : end_vec[i]; + const int64_t* stride_t = as_const_int(stride_vec[i]); + CHECK(stride_t != nullptr) << "Stride cannot be symbolic."; + int64_t stride_v = stride_t[0]; + + const int64_t* begin_t = as_const_int(begin_vec[i]); + CHECK(begin_t != nullptr) << "Begin index cannot be symbolic."; + int64_t begin_v = begin_t[0]; + + const int64_t* end_t = as_const_int(end_vec[i]); + CHECK(end_t != nullptr) << "End index cannot be symbolic."; + int64_t end_v = end_t[0]; + + auto begin_range = make_const(Int(64), (stride_v < 0) ? -1 : 0); + auto end_range = (stride_v < 0) ? dshape[i] - 1 : dshape[i]; + auto begin = (begin_v < 0) ? dshape[i] + begin_vec[i] : begin_vec[i]; + auto end = (end_v < 0) ? dshape[i] + end_vec[i] : end_vec[i]; begin = min(max(begin, begin_range), end_range); end = min(max(end, begin_range), end_range); @@ -956,13 +970,13 @@ bool StridedSliceRel(const Array& types, // Positional relay function to create StridedSlice operator used by frontend FFI. Expr MakeStridedSlice(Expr data, - Array begin, - Array end, - Array stride) { + Array begin, + Array end, + Array strides) { auto attrs = make_node(); attrs->begin = std::move(begin); attrs->end = std::move(end); - attrs->stride = std::move(stride); + attrs->strides = std::move(strides); static const Op& op = Op::Get("strided_slice"); return CallNode::make(op, {data}, Attrs(attrs), {}); } diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 23a6e03b2902..e98beaf07752 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -141,17 +141,22 @@ def test_reduce_functions(): verify_reduce(func, (128, 24, 128), (0, 2), False, False, (24,)) verify_reduce(func, (128, 24, 128), (0, 1), True, False, (1, 1, 128)) verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1)) + + def verify_strided_slice(data, begin, end, stride, output): x = relay.var("x", relay.TensorType(data, "float32")) - z = 
relay.strided_slice(x, begin=begin, end=end, stride=stride) + z = relay.strided_slice(x, begin=begin, end=end, strides=stride) zz = relay.ir_pass.infer_type(z) assert "begin=" in z.astext() assert "end=" in z.astext() if stride: - assert "stride=" in z.astext() - assert zz.checked_type == relay.ty.TensorType(output, "float32") + assert "strides=" in z.astext() + if output: + assert zz.checked_type == relay.ty.TensorType(output, "float32") def test_strided_slice(): + d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") + verify_strided_slice((d1, d2, d3), [0, 0, 0], [4, -5, 4], [1, -1, 2], None) verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3)) verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) @@ -162,6 +167,7 @@ def test_strided_slice(): verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3)) verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3)) + if __name__ == "__main__": test_binary_op() test_cmp_type() From a4fe198775f2d40df79a1bdef9ff8f560f1dee97 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 13 Nov 2018 11:58:36 -0800 Subject: [PATCH 6/6] [RELAY][OP] Improve strided_slice, based on siju's PR --- nnvm/src/top/tensor/transform.cc | 30 +++- python/tvm/_ffi/node_generic.py | 2 + python/tvm/relay/op/__init__.py | 1 + python/tvm/relay/op/_transform.py | 8 + src/api/api_lang.cc | 6 +- src/relay/ir/text_printer.cc | 6 +- src/relay/op/tensor/transform.cc | 146 +++++++++++------- tests/python/relay/test_op_level4.py | 55 ++++--- topi/include/topi/transform.h | 55 +++++-- topi/python/topi/testing/__init__.py | 1 + .../topi/testing/strided_slice_python.py | 32 ++++ topi/tests/python/test_topi_transform.py | 17 +- 12 files changed, 247 insertions(+), 112 deletions(-) create mode 100644 python/tvm/relay/op/_transform.py create mode 100644 topi/python/topi/testing/strided_slice_python.py diff 
--git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 4d08bf761326..2f42727d6083 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -980,23 +980,25 @@ Examples:: const Array& inputs, const Array& out_info) { const StridedSliceParam& param = nnvm::get(attrs.parsed); - Array begin; - Array end; - Array stride; + Array begin; + Array end; + Array stride; for (int64_t i : param.begin) { - begin.push_back(tvm::make_const(tvm::Int(32), i)); + begin.push_back(static_cast(i)); } for (int64_t i : param.end) { - end.push_back(tvm::make_const(tvm::Int(32), i)); + end.push_back(static_cast(i)); } for (int64_t i : param.stride) { - stride.push_back(tvm::make_const(tvm::Int(32), i)); + stride.push_back(static_cast(i)); } - return Array{ topi::strided_slice(inputs[0], begin, end, stride) }; + return Array{ + topi::strided_slice(inputs[0], begin, end, stride) + }; }) .set_support_level(1); @@ -1210,6 +1212,15 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs, return true; } +// Adapter function to make int array. +Array GetIntArray(Array arr) { + for (size_t i = 0; i < arr.size(); ++i) { + CHECK(!arr[i].defined() || arr[i].as()) + << "Expect an int array"; + } + return Array(arr.node_); +} + NNVM_REGISTER_OP(slice_like) .describe(R"code(Slice the first input respect to the second input. 
)code" NNVM_ADD_FILELINE) @@ -1261,7 +1272,10 @@ NNVM_REGISTER_OP(slice_like) } } return Array{ - topi::strided_slice(inputs[0], begin_idx, end_idx, strides) + topi::strided_slice(inputs[0], + GetIntArray(begin_idx), + GetIntArray(end_idx), + GetIntArray(strides)) }; }) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { diff --git a/python/tvm/_ffi/node_generic.py b/python/tvm/_ffi/node_generic.py index b7230f29da59..e86453499faa 100644 --- a/python/tvm/_ffi/node_generic.py +++ b/python/tvm/_ffi/node_generic.py @@ -56,6 +56,8 @@ def convert_to_node(value): return _api_internal._Map(*vlist) elif isinstance(value, NodeGeneric): return value.asnode() + elif value is None: + return None else: raise ValueError("don't know how to convert type %s to node" % type(value)) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 9b581486608b..30aef433d7c6 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -13,6 +13,7 @@ # operator registry from . import _tensor +from . import _transform from ..expr import Expr from ..base import register_relay_node diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py new file mode 100644 index 000000000000..7867336d033f --- /dev/null +++ b/python/tvm/relay/op/_transform.py @@ -0,0 +1,8 @@ +#pylint: disable=invalid-name, unused-argument +"""Backend compiler related feature registration""" +from __future__ import absolute_import +from . 
import op as _reg +from .op import schedule_injective + +# strided_slice +_reg.register_schedule("strided_slice", schedule_injective) diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 75365da5bf50..3525e23b8b20 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -47,7 +47,11 @@ TVM_REGISTER_API("_Array") .set_body([](TVMArgs args, TVMRetValue* ret) { std::vector > data; for (int i = 0; i < args.size(); ++i) { - data.push_back(args[i].node_sptr()); + if (args[i].type_code() != kNull) { + data.push_back(args[i].node_sptr()); + } else { + data.push_back(NodePtr(nullptr)); + } } auto node = make_node(); node->data = std::move(data); diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index f28db371706e..5bd381a8b4c1 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -401,7 +401,11 @@ class TextPrinter : * \param os The output type. */ void PrintAttr(const NodeRef& value, std::ostream& os) { // NOLINT(*) - this->VisitAttr(value, os); + if (value.defined()) { + this->VisitAttr(value, os); + } else { + os << "None"; + } } //------------------------------------ // Overload of Attr printing functions diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 31e7d2aaff84..98ac1c30b66c 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "../op_common.h" @@ -894,13 +895,12 @@ RELAY_REGISTER_OP("broadcast_to_like") // strided_slice TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); bool StridedSliceRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - CHECK(data != nullptr); - if (data->shape.size() == 0) return false; + if (data == nullptr) return false; const StridedSliceAttrs *param = attrs.as(); 
CHECK(param != nullptr); @@ -908,61 +908,87 @@ bool StridedSliceRel(const Array& types, auto dshape = data->shape; auto num_axis = dshape.size(); - std::vector begin_vec; - for (auto i : param->begin) { - begin_vec.push_back(i); + std::vector stride_vec; + for (Integer i : param->strides) { + CHECK(i.defined()); + stride_vec.push_back(i->value); } - for (auto i = begin_vec.size(); i < num_axis; ++i) { - begin_vec.push_back(0); + for (size_t i = stride_vec.size(); i < num_axis; ++i) { + stride_vec.push_back(1); } + const int64_t max_range = std::numeric_limits::max(); - std::vector end_vec; - for (auto i : param->end) { - end_vec.push_back(i); + std::vector begin_vec; + for (size_t i = 0; i < param->begin.size(); ++i) { + if (!param->begin[i].defined()) { + // value=None + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + } else { + begin_vec.push_back(param->begin[i]->value); + } } - for (auto i = end_vec.size(); i < num_axis; ++i) { - end_vec.push_back(dshape[i]); + for (size_t i = begin_vec.size(); i < num_axis; ++i) { + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); } - std::vector stride_vec; - for (auto i : param->strides) { - stride_vec.push_back(i); + std::vector end_vec; + for (size_t i = 0; i < param->end.size(); ++i) { + // allow end to be None + if (!param->end[i].defined()) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else { + end_vec.push_back(param->end[i]->value); + } } - for (auto i = stride_vec.size(); i < num_axis; ++i) { - stride_vec.push_back(1); + for (size_t i = end_vec.size(); i < num_axis; ++i) { + end_vec.push_back(stride_vec[i] < 0 ? 
0 : max_range); } - std::vector oshape(dshape.size()); + std::vector oshape(dshape.size()); for (size_t i = 0; i < num_axis; ++i) { - const int64_t* stride_t = as_const_int(stride_vec[i]); - CHECK(stride_t != nullptr) << "Stride cannot be symbolic."; - int64_t stride_v = stride_t[0]; - - const int64_t* begin_t = as_const_int(begin_vec[i]); - CHECK(begin_t != nullptr) << "Begin index cannot be symbolic."; - int64_t begin_v = begin_t[0]; - - const int64_t* end_t = as_const_int(end_vec[i]); - CHECK(end_t != nullptr) << "End index cannot be symbolic."; - int64_t end_v = end_t[0]; - - auto begin_range = make_const(Int(64), (stride_v < 0) ? -1 : 0); - auto end_range = (stride_v < 0) ? dshape[i] - 1 : dshape[i]; - auto begin = (begin_v < 0) ? dshape[i] + begin_vec[i] : begin_vec[i]; - auto end = (end_v < 0) ? dshape[i] + end_vec[i] : end_vec[i]; - - begin = min(max(begin, begin_range), end_range); - end = min(max(end, begin_range), end_range); - auto interval = abs((end - begin)); - auto slice_size = (interval + abs(stride_vec[i]) - 1) / abs(stride_vec[i]); - - CHECK(reporter->Assert(stride_vec[i] < 0) ? - reporter->Assert(end < begin) : reporter->Assert(begin < end)) - << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] - << "] is invalid for axis=" << i; - oshape[i] = slice_size; + int64_t stride_v = stride_vec[i]; + int64_t begin_v = begin_vec[i]; + int64_t end_v = end_vec[i]; + + if ((stride_v == 1 && + begin_v == 0 && + end_v == max_range) || + (stride_v == -1 && + begin_v == max_range && + end_v == 0)) { + // Quick path, do not slice this dimension. + oshape[i] = dshape[i]; + continue; + } + // Normal path, require the shape to be concrete integer. + // Require concrete integer as symbolic inference of min/max + // can get complicated and not very helpful. 
+ const int64_t* p_dim_size = as_const_int(dshape[i]); + CHECK(p_dim_size) + << "strided_slice requires sliced dimension to be concrete int"; + int64_t dim_size = p_dim_size[0]; + begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; + end_v = (end_v < 0) ? dim_size + end_v : end_v; + + int64_t slice_range, step; + if (stride_v < 0) { + if (end_v < -1) end_v = -1; + CHECK_LT(end_v, begin_v) + << "strided_slice get empty slice at axis " << i; + begin_v = std::min(dim_size - 1, begin_v); + slice_range = begin_v - end_v; + step = -stride_v; + } else { + if (begin_v < 0) begin_v = 0; + CHECK_GE(stride_v, 0); + CHECK_LT(begin_v, end_v) + << "strided_slice get empty slice at axis " << i; + end_v = std::min(dim_size, end_v); + slice_range = end_v - begin_v; + step = stride_v; + } + oshape[i] = make_const(dshape[i].type(), (slice_range + step - 1) / step); } - reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); return true; } @@ -970,9 +996,9 @@ bool StridedSliceRel(const Array& types, // Positional relay function to create StridedSlice operator used by frontend FFI. 
Expr MakeStridedSlice(Expr data, - Array begin, - Array end, - Array strides) { + Array begin, + Array end, + Array strides) { auto attrs = make_node(); attrs->begin = std::move(begin); attrs->end = std::move(end); @@ -981,6 +1007,17 @@ Expr MakeStridedSlice(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } +Array StridedSliceCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const StridedSliceAttrs *param = attrs.as(); + CHECK(param != nullptr); + return Array{ + topi::strided_slice(inputs[0], param->begin, param->end, param->strides) + }; +} + TVM_REGISTER_API("relay.op._make.strided_slice") .set_body([](const TVMArgs& args, TVMRetValue* rv) { @@ -1016,7 +1053,10 @@ Examples:: .add_argument("data", "Tensor", "The input tensor.") .set_support_level(4) .set_attrs_type_key("relay.attrs.StridedSliceAttrs") -.add_type_rel("StridedSlice", StridedSliceRel); +.add_type_rel("StridedSlice", StridedSliceRel) +.set_attr("FTVMCompute", StridedSliceCompute) +.set_attr("TOpPattern", kInjective); + // Split TVM_REGISTER_NODE_TYPE(SplitAttrs); diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index e98beaf07752..dd12dc7cff3a 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -2,7 +2,7 @@ import numpy as np from tvm import relay from tvm.relay.testing import ctx_list - +import topi.testing def test_binary_op(): def check_binary_op(opfunc, ref): @@ -143,35 +143,44 @@ def test_reduce_functions(): verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1)) -def verify_strided_slice(data, begin, end, stride, output): - x = relay.var("x", relay.TensorType(data, "float32")) - z = relay.strided_slice(x, begin=begin, end=end, strides=stride) - zz = relay.ir_pass.infer_type(z) - assert "begin=" in z.astext() - assert "end=" in z.astext() - if stride: - assert "strides=" in z.astext() - if output: - assert zz.checked_type == 
relay.ty.TensorType(output, "float32") - def test_strided_slice(): + def verify(dshape, begin, end, strides, output, test_ref=True): + x = relay.var("x", relay.TensorType(dshape, "float32")) + z = relay.strided_slice(x, begin=begin, end=end, strides=strides) + func = relay.Function([x], z) + func = relay.ir_pass.infer_type(func) + text = func.astext() + assert "begin=" in text + assert "end=" in text + if output: + assert func.body.checked_type == relay.ty.TensorType(output, "float32") + if not test_ref: + return + x_data = np.random.uniform(size=dshape).astype("float32") + ref_res = topi.testing.strided_slice_python( + x_data, begin, end, strides) + for target, ctx in ctx_list(): + intrp = relay.create_executor("graph", ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) + d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") - verify_strided_slice((d1, d2, d3), [0, 0, 0], [4, -5, 4], [1, -1, 2], None) - verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) - verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3)) - verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) - verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2)) - verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) - verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) - verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) - verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3)) - verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3)) + verify((d1, d2, 3), [None, None, 1], [None, None, 2], None, (d1, d2, 1), False) + verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3)) + verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) + verify((3, 4, 3), [1, 
0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2)) + verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3)) if __name__ == "__main__": + test_strided_slice() test_binary_op() test_cmp_type() test_binary_int_broadcast() test_where() test_reduce_functions() - test_strided_slice() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 7fc408c2c79c..cb09f1cb419e 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -10,6 +10,7 @@ #include #include #include +#include #include "topi/tags.h" #include "topi/detail/ravel_unravel.h" @@ -403,31 +404,51 @@ inline Array split(const Tensor& x, * \return A Tensor whose op member is the split operation */ inline Tensor strided_slice(const Tensor& x, - const Array& begin, - const Array& end, - const Array& strides, + const Array& begin, + const Array& end, + const Array& strides, std::string name = "tensor", std::string tag = kInjective) { size_t src_tensor_dim = static_cast(x->shape.size()); - std::vector begin_vec = GetConstInt64Values(begin, "begin"); - std::vector end_vec = GetConstInt64Values(end, "end"); - std::vector stride_vec = GetConstInt64Values(strides, "strides"); - // in case user has not provided begin indices for all the axes, - // then inflate it with default value = 0 - for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) { - begin_vec.push_back(0); - } - // in case user has not provided end indices for all the axes, - // then inflate it with default value = input_tensor.shape[axis] - for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) { - end_vec.push_back(GetConstInt(x->shape[i])); + // Setup the ranges. + // NOTE: this code duplicates the shape inference logic relay.op + // Consider to refactor in the future. 
+  std::vector<int64_t> stride_vec;
+  for (Integer i : strides) {
+    CHECK(i.defined());
+    stride_vec.push_back(i->value);
   }
-  // in case user has not provided stride values,
-  // then inflate it with default value = 1
   for (size_t i = stride_vec.size(); i < src_tensor_dim; ++i) {
     stride_vec.push_back(1);
   }
+  const int64_t max_range = std::numeric_limits<int64_t>::max();
+
+  std::vector<int64_t> begin_vec;
+  for (size_t i = 0; i < begin.size(); ++i) {
+    if (!begin[i].defined()) {
+      // value=None
+      begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
+    } else {
+      begin_vec.push_back(begin[i]->value);
+    }
+  }
+  for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) {
+    begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
+  }
+  std::vector<int64_t> end_vec;
+  for (size_t i = 0; i < end.size(); ++i) {
+    // allow end to be None
+    if (!end[i].defined()) {
+      end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
+    } else {
+      end_vec.push_back(end[i]->value);
+    }
+  }
+  for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) {
+    end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
+  }
+  // Compute
   Array<Expr> out_shape;
   Array<Expr> begin_expr;
   Array<Expr> strides_expr;
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 8a3269ba83ae..c496e08c1835 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -19,3 +19,4 @@
 from .lrn_python import lrn_python
 from .l2_normalize_python import l2_normalize_python
 from .gather_nd_python import gather_nd_python
+from .strided_slice_python import strided_slice_python
diff --git a/topi/python/topi/testing/strided_slice_python.py b/topi/python/topi/testing/strided_slice_python.py
new file mode 100644
index 000000000000..4407b3bec1c7
--- /dev/null
+++ b/topi/python/topi/testing/strided_slice_python.py
@@ -0,0 +1,32 @@
+"""strided_slice in python"""
+
+def strided_slice_python(data, begin, end, strides):
+    """Python version of strided slice operator.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        Input data
+
+    begin : list
+        Beginning of the slices.
+
+    end : list
+        End of the slices.
+
+    strides : list
+        The stride of each slice.
+
+    Returns
+    -------
+    result : numpy.ndarray
+        The sliced result.
+    """
+    strides = [] if strides is None else strides
+    slices = []
+    for i in range(len(data.shape)):
+        slices.append(slice(
+            begin[i] if i < len(begin) else None,
+            end[i] if i < len(end) else None,
+            strides[i] if i < len(strides) else None))
+    return data[tuple(slices)]
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index 75e4d3b675b0..dc3c3fb70b24 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -249,13 +249,11 @@ def check_device(device):
     for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
         check_device(device)
 
-def verify_strided_slice(in_shape, begin, end, stride=None):
-    stride = stride if stride else [1, 1, 1]
+def verify_strided_slice(in_shape, begin, end, strides=None):
     A = tvm.placeholder(shape=in_shape, name="A")
-    B = topi.strided_slice(A, begin, end, stride) + 1
-    def test_forward(x, begin, end, stride):
-        return x[begin[0]:end[0]:stride[0],
-               begin[1]:end[1]:stride[1], begin[2]:end[2]:stride[2]] + 1
+    strides = [1,1,1] if strides is None else strides
+    B = topi.strided_slice(A, begin, end, strides) + 1
+
     def check_device(device):
         ctx = tvm.context(device, 0)
         if not ctx.exist:
@@ -267,7 +265,8 @@ def check_device(device):
         foo = tvm.build(s, [A, B], device, name="stride_slice")
         x_np = np.random.uniform(size=in_shape).astype(A.dtype)
-        out_npy = test_forward(x_np, begin, end, stride)
+        out_npy = topi.testing.strided_slice_python(
+            x_np, begin, end, strides) + 1
         data_nd = tvm.nd.array(x_np, ctx)
         out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype)
         foo(data_nd, out_nd)
@@ -298,7 +297,7 @@ def check_device(device):
         shape_size = shape_size * src_shape[i]
     data_npy = 
np.arange(shape_size, dtype=src_dtype).reshape((src_shape)) out_npys = topi.testing.gather_nd_python(data_npy, indices_src) - + data_nd = tvm.nd.array(data_npy, ctx) indices_nd = tvm.nd.array(indices_src, ctx) out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype) @@ -412,6 +411,7 @@ def test_gather_nd(): indices_dtype) if __name__ == "__main__": + test_strided_slice() test_concatenate() test_tranpose() test_expand_dims() @@ -421,5 +421,4 @@ def test_gather_nd(): test_flip() test_expand_like() test_take() - test_strided_slice() test_gather_nd()