diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index d1549cd8326e..cbccd76b284b 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -122,6 +122,7 @@ This level enables additional math and transform operators. tvm.relay.min tvm.relay.mean tvm.relay.prod + tvm.relay.strided_slice **Level 5: Vision/Image Operators** @@ -225,6 +226,7 @@ Level 4 Definitions .. autofunction:: tvm.relay.min .. autofunction:: tvm.relay.mean .. autofunction:: tvm.relay.prod +.. autofunction:: tvm.relay.strided_slice diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index dfad1013701f..afb2030979ef 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -122,6 +122,21 @@ struct SplitAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for StridedSlice operator */ +struct StridedSliceAttrs : public tvm::AttrsNode { + Array begin; + Array end; + Array strides; + + TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { + TVM_ATTR_FIELD(begin) + .describe("Indices for begin of slice, begin index is also inclusive"); + TVM_ATTR_FIELD(end) + .describe("Indices for end of slice, end index is also inclusive"); + TVM_ATTR_FIELD(strides).set_default(Array({})) + .describe("Stride values of the slice"); + } +}; } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 9d14463a530c..e737ff9ed950 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -336,3 +336,30 @@ def split(data, indices_or_sections, axis=0): else: ret_size = len(indices_or_sections) + 1 return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) + + +def strided_slice(data, begin, end, strides=None): + """Strided slice of an array.. + + Parameters + ---------- + data : relay.Expr + The source array to be sliced. + + begin: list of int + The indices to begin with in the slicing. + + end: list of int + Indicies indicating end of the slice. + + strides: list of int, optional + Specifies the stride values, it can be negative in that case, + the input tensor will be reversed in that particular axis. + + Returns + ------- + ret : relay.Expr + The computed result. + """ + strides = strides or [] + return _make.strided_slice(data, list(begin), list(end), list(strides)) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 5faa0805426a..cbcee320b44d 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -892,6 +892,134 @@ RELAY_REGISTER_OP("broadcast_to_like") .set_support_level(10) .add_type_rel("BroadCastToLike", BroadCastToLikeRel); + +// strided_slice +TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); +bool StridedSliceRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + CHECK(data != nullptr); + if (data->shape.size() == 0) return false; + + const StridedSliceAttrs *param = attrs.as(); + CHECK(param != nullptr); + + auto dshape = data->shape; + auto num_axis = dshape.size(); + + std::vector begin_vec; + for (auto i : param->begin) { + begin_vec.push_back(i); + } + for (auto i = begin_vec.size(); i < num_axis; ++i) { + begin_vec.push_back(0); + } + + std::vector end_vec; + for (auto i : param->end) { + end_vec.push_back(i); + } + for (auto i = end_vec.size(); i < num_axis; ++i) { + end_vec.push_back(dshape[i]); + } + + std::vector stride_vec; + for (auto i : param->strides) { + stride_vec.push_back(i); + } + for (auto i = stride_vec.size(); i < num_axis; ++i) { + stride_vec.push_back(1); + } + std::vector oshape(dshape.size()); + + for (size_t i = 0; i < num_axis; ++i) { + const int64_t* stride_t = as_const_int(stride_vec[i]); + CHECK(stride_t != nullptr) << "Stride cannot be symbolic."; + int64_t stride_v = stride_t[0]; + + const int64_t* begin_t = as_const_int(begin_vec[i]); + CHECK(begin_t != nullptr) << "Begin index cannot be symbolic."; + int64_t begin_v = begin_t[0]; + + const int64_t* end_t = as_const_int(end_vec[i]); + CHECK(end_t != nullptr) << "End index cannot be symbolic."; + int64_t end_v = end_t[0]; + + auto begin_range = make_const(Int(64), (stride_v < 0) ? -1 : 0); + auto end_range = (stride_v < 0) ? dshape[i] - 1 : dshape[i]; + auto begin = (begin_v < 0) ? dshape[i] + begin_vec[i] : begin_vec[i]; + auto end = (end_v < 0) ? dshape[i] + end_vec[i] : end_vec[i]; + + begin = min(max(begin, begin_range), end_range); + end = min(max(end, begin_range), end_range); + auto interval = abs((end - begin)); + auto slice_size = (interval + abs(stride_vec[i]) - 1) / abs(stride_vec[i]); + + CHECK(reporter->Assert(stride_vec[i] < 0) ? + reporter->Assert(end < begin) : reporter->Assert(begin < end)) + << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] + << "] is invalid for axis=" << i; + oshape[i] = slice_size; + } + + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + + +// Positional relay function to create StridedSlice operator used by frontend FFI. +Expr MakeStridedSlice(Expr data, + Array begin, + Array end, + Array strides) { + auto attrs = make_node(); + attrs->begin = std::move(begin); + attrs->end = std::move(end); + attrs->strides = std::move(strides); + static const Op& op = Op::Get("strided_slice"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op._make.strided_slice") + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeStridedSlice, args, rv); + }); + + +RELAY_REGISTER_OP("strided_slice") + .describe(R"code(Strided slice of an array. + +Examples:: + + x = [[ 1., 4., 7., 10.], + [ 2., 5., 8., 11.], + [ 3., 6., 9., 12.]] + + strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4., 7., 10.], + [ 5., 8., 11.]] + + x = [[[ 1., 2.], + [ 3., 4.]], + + [[ 5., 6.], + [ 7., 8.]]] + + strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 2.], + [ 3., 4.]], + + [[ 5., 6.], + [ 7., 8.]]] +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(4) +.set_attrs_type_key("relay.attrs.StridedSliceAttrs") +.add_type_rel("StridedSlice", StridedSliceRel); + // Split TVM_REGISTER_NODE_TYPE(SplitAttrs); diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 2dc643cfd7e4..bb6c28f0cf62 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -91,9 +91,36 @@ def test_reduce_functions(): verify_reduce(func, (128, 24, 128), (0, 1), True, False, (1, 1, 128)) verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1)) + +def verify_strided_slice(data, begin, end, stride, output): + x = relay.var("x", relay.TensorType(data, "float32")) + z = relay.strided_slice(x, begin=begin, end=end, strides=stride) + zz = relay.ir_pass.infer_type(z) + assert "begin=" in z.astext() + assert "end=" in z.astext() + if stride: + assert "strides=" in z.astext() + if output: + assert zz.checked_type == relay.ty.TensorType(output, "float32") + +def test_strided_slice(): + d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") + verify_strided_slice((d1, d2, d3), [0, 0, 0], [4, -5, 4], [1, -1, 2], None) + verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3)) + verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) + verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2)) + verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3)) + verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3)) + + if __name__ == "__main__": test_binary_op() test_cmp_type() test_binary_int_broadcast() test_where() test_reduce_functions() + test_strided_slice()