-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[RELAY][OP]Strided slice #1891
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RELAY][OP]Strided slice #1891
Conversation
4f8aae0 to
b0688f1
Compare
|
@MarisaKirisame @yzhliu @yuruofeifei @srkreddy1238 @tqchen please review. |
include/tvm/relay/attrs/transform.h
Outdated
| TVM_ATTR_FIELD(begin) | ||
| .describe("Indices for begin of slice"); | ||
| TVM_ATTR_FIELD(end) | ||
| .describe("Indices for end of the slice"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
describe inclusive, exclusive?
| } | ||
|
|
||
| reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); | ||
| return true; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
undef min/max somewhere
45efc2f to
32a2e69
Compare
|
please rebase against master after #1934 to make use of the newly introduced API and add test-case to make sure text format works. Thanks! |
15b6a24 to
4e2edf7
Compare
deb1df4 to
41b6e19
Compare
|
@tqchen this can be merged? Anything else need to be done? can you please review once again and let me know. Thanks. |
src/relay/op/tensor/transform.cc
Outdated
| std::vector<IndexExpr> oshape(dshape.size()); | ||
|
|
||
| for (size_t i = 0; i < num_axis; ++i) { | ||
| auto begin_range = reporter->Assert(stride_vec[i] < 0) ? -1 : 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using Assert here is not reliable, as if stride_vec is symbolic, then assert does not reflect anything.
include/tvm/relay/attrs/transform.h
Outdated
|
|
||
| /*! \brief Attributes for StridedSlice operator */ | ||
| struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> { | ||
| Array<IndexExpr> begin; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let us change Array<IndexExpr> -> Array<Integer> for now, as integer is really what we can do reliably so far
include/tvm/relay/attrs/transform.h
Outdated
| struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> { | ||
| Array<IndexExpr> begin; | ||
| Array<IndexExpr> end; | ||
| Array<IndexExpr> stride; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stride->strides as per https://www.tensorflow.org/api_docs/python/tf/strided_slice
| auto begin = reporter->Assert(begin_vec[i] < 0) ? dshape[i] + begin_vec[i] : begin_vec[i]; | ||
| auto end = reporter->Assert(end_vec[i] < 0) ? dshape[i] + end_vec[i] : end_vec[i]; | ||
|
|
||
| begin = min(max(begin, begin_range), end_range); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use std::min std::max,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
begin and end is derived from dshape which is symbolic, so cannot use std::min/max
All other comments are fixed.
| verify_reduce(func, (128, 24, 128), (0, 2), False, False, (24,)) | ||
| verify_reduce(func, (128, 24, 128), (0, 1), True, False, (1, 1, 128)) | ||
| verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1)) | ||
| def verify_strided_slice(data, begin, end, stride, output): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
newline between functions
c9e48d3 to
3c3f641
Compare
3c3f641 to
ba9864d
Compare
|
Given that there is still some gap and we need this OP in quickly, I am opening a followup #2094 which is based on this PR. |
|
Thanks @siju-samuel @MarisaKirisame |
#1799
Thanks for contributing to TVM! Please refer to guideline https://docs.tvm.ai/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from others in the community.