From 64dd23520f32d087339c6c0852a91b0a0b661033 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 11 May 2021 15:23:24 +0900 Subject: [PATCH 1/6] Supporting dynamic slice on first few axes --- include/tvm/topi/transform.h | 13 +++++++++++-- python/tvm/relay/frontend/onnx.py | 6 +----- python/tvm/relay/op/transform.py | 4 ++-- src/relay/op/dyn/tensor/transform.cc | 21 +++++++++++++++------ 4 files changed, 29 insertions(+), 15 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 379234a5c65a..dfb6326b567a 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -569,16 +569,25 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b std::string tag = topi::kInjective) { int64_t src_tensor_dim = x->shape.size(); Array out_shape; - for (int64_t i = 0; i < src_tensor_dim; ++i) { + const int64_t num_dynamic_axes = begin->shape[0].as()->value; + for (int64_t i = 0; i < num_dynamic_axes; ++i) { out_shape.push_back(tvm::tir::Var("dim")); } + for (int64_t i = num_dynamic_axes; i < src_tensor_dim; ++i) { + out_shape.push_back(x->shape[i]); + } return te::compute( out_shape, [&](const Array& indices) { Array real_indices; - for (int32_t i = 0; i < src_tensor_dim; ++i) { + // dynamic slicing + for (int32_t i = 0; i < num_dynamic_axes; ++i) { real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1)); } + // keep input dim + for (int64_t i = num_dynamic_axes; i < src_tensor_dim; ++i) { + real_indices.push_back(indices[i]); + } return x(real_indices); }, name, tag); diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 2a57cba53cd2..b9fabdebb330 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2703,11 +2703,7 @@ def conditionally_squeeze_scalar(x): boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ) - three = _op.const(np.array([3]), dtype="int64") - begin = _op.const(np.array([0, 0]), dtype="int64") - end = _op.concatenate([nms_out[1], three], axis=0) - strides = _op.const(np.array([1, 1]), dtype="int64") - return _op.strided_slice(nms_out[0], begin, end, strides) + return _op.strided_slice(nms_out[0], _op.const([0], dtype="int64"), nms_out[1]) class ATen(OnnxOpConverter): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 4e406e81ef68..68bcc4870cf0 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -912,10 +912,10 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): if isinstance(strides, (tuple, list)): strides = const(list(strides)) begin = _make.where( - begin < cast_like(const(0), begin), begin + cast_like(shape_of(data), begin), begin + begin < cast_like(const(0), begin), begin + cast_like(shape_of(begin), begin), begin ) begin = _make.where( - begin >= cast_like(shape_of(data), begin), cast_like(shape_of(data), begin), begin + begin >= cast_like(shape_of(begin), begin), cast_like(shape_of(begin), begin), begin ) return _dyn_make.strided_slice(data, begin, end, strides, slice_mode) return _make.strided_slice(data, begin, end, strides, slice_mode) diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index cf8f3689b045..5ed5351b7fb1 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -466,12 +466,20 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr auto dshape = data->shape; int64_t num_axis = dshape.size(); + const auto* begin = types[1].as(); + ICHECK(begin); + // calculate output shape std::vector oshape(num_axis); - for (int64_t i = 0; i < num_axis; ++i) { + int64_t num_dynamic_axes = begin->shape[0].as()->value; + for (size_t i = 0; i < num_dynamic_axes; ++i) { oshape[i] = Any(); } + for (int64_t i = num_dynamic_axes; i < num_axis; ++i) { + oshape[i] = dshape[i]; + } + reporter->Assign(types[4], TensorType(oshape, data->dtype)); return true; } @@ -484,11 +492,12 @@ Array StridedSliceCompute(const Attrs& attrs, const Arrayshape.size(); - ICHECK(begin->shape[0].as()->value == data_rank && - end->shape[0].as()->value == data_rank && - strides->shape[0].as()->value == data_rank) - << "begin, end, and strides are required to have the same length" - << " if they are dynamic variables."; + int64_t num_dynamic_axes = begin->shape[0].as()->value; + ICHECK(end->shape[0].as()->value == num_dynamic_axes && + strides->shape[0].as()->value == num_dynamic_axes) + << "begin, end, strides should have the same length if they are dynamic variables"; + ICHECK(num_dynamic_axes <= data_rank) + << "the number of dynamic axes to slice should be less than or equal to the data rank"; return Array{topi::dynamic_strided_slice(data, begin, end, strides)}; } From 9b36e7678ebcdd4a0c0c89bfda2f39404b7afc7b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 May 2021 19:20:44 +0900 Subject: [PATCH 2/6] fix index normalization --- python/tvm/relay/op/transform.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 68bcc4870cf0..8bd61b5b8a84 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -911,11 +911,14 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): end = const(list(end)) if isinstance(strides, (tuple, list)): strides = const(list(strides)) + + ishape = cast_like(shape_of(data), begin) + ishape_slice = slice_like(ishape, begin) begin = _make.where( - begin < cast_like(const(0), begin), begin + cast_like(shape_of(begin), begin), begin + begin < cast_like(const(0), begin), begin + ishape_slice, begin ) begin = _make.where( - begin >= cast_like(shape_of(begin), begin), cast_like(shape_of(begin), begin), begin + begin >= ishape_slice, ishape_slice, begin ) return _dyn_make.strided_slice(data, begin, end, strides, slice_mode) return _make.strided_slice(data, begin, end, strides, slice_mode) From bd3972e4fd9559c8dacfa1894e2aa8c255e830a5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 May 2021 19:39:57 +0900 Subject: [PATCH 3/6] update dynamic slice tests --- .../relay/dyn/test_dynamic_op_level4.py | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/tests/python/relay/dyn/test_dynamic_op_level4.py b/tests/python/relay/dyn/test_dynamic_op_level4.py index 43e5beba199f..cf5a20e3193f 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level4.py +++ b/tests/python/relay/dyn/test_dynamic_op_level4.py @@ -25,16 +25,19 @@ @tvm.testing.uses_gpu def test_dynamic_strided_slice(): - def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, dtype="int32"): + def verify(dshape, begin, end, strides, slice_mode="end", test_ref=True, dtype="int32"): x = relay.var("x", relay.TensorType(dshape, "float32")) ndim = len(dshape) + slice_dim = len(begin) begin = begin if begin else [0] * ndim - end = end if end else list(dshape) + end = end if end else list(dshape)[:slice_dim] if strides: if len(strides) == 1: - strides = strides * ndim + strides = strides * slice_dim else: - strides = [1] * ndim + strides = [1] * slice_dim + + num_static_axes = len(dshape) - len(begin) # target numpy result x_data = np.random.uniform(size=dshape).astype("float32") @@ -54,7 +57,10 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, func = relay.Function(inputs, z) func = run_infer_type(func) - text = func.astext() + + if num_static_axes > 0: + oshape = run_infer_type(z).checked_type.shape + assert tuple(oshape[-num_static_axes:]) == dshape[-num_static_axes:] if not test_ref: return @@ -69,22 +75,26 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, [0, 20, 20, 0], [1, 140, 140, 3], [1, 1, 1, 1], - (1, 120, 120, 3), dtype="int64", ) - verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16") - verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) - verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) - verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) - verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None, (2, 3, 3)) - verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) - verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) - verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) - verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2], (19, 3, 2)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], dtype="int16") + verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None) + verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None) + verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None) + verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1]) + verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1]) + verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2]) verify( - (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False + (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], slice_mode="size", test_ref=False ) - verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True) + verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], slice_mode="size", test_ref=True) + + # Slicing along first few axes, where the rest of axes remain static + verify((3, 4, 3), [0], [2], None) + verify((3, 4, 3), [1], [4], [2]) + verify((3, 4, 3), [1, 0], [4, 2], [2, 1]) if __name__ == "__main__": From 72eb44595505919a5b5e3cffc139c46e7e423381 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 May 2021 20:00:00 +0900 Subject: [PATCH 4/6] pylint fix --- python/tvm/relay/op/transform.py | 8 ++------ tests/python/relay/dyn/test_dynamic_op_level4.py | 4 +--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 8bd61b5b8a84..8744e7b5c6ad 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -914,12 +914,8 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): ishape = cast_like(shape_of(data), begin) ishape_slice = slice_like(ishape, begin) - begin = _make.where( - begin < cast_like(const(0), begin), begin + ishape_slice, begin - ) - begin = _make.where( - begin >= ishape_slice, ishape_slice, begin - ) + begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin) + begin = _make.where(begin >= ishape_slice, ishape_slice, begin) return _dyn_make.strided_slice(data, begin, end, strides, slice_mode) return _make.strided_slice(data, begin, end, strides, slice_mode) diff --git a/tests/python/relay/dyn/test_dynamic_op_level4.py b/tests/python/relay/dyn/test_dynamic_op_level4.py index cf5a20e3193f..01e5056c72cb 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level4.py +++ b/tests/python/relay/dyn/test_dynamic_op_level4.py @@ -86,9 +86,7 @@ def verify(dshape, begin, end, strides, slice_mode="end", test_ref=True, dtype=" verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1]) verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1]) verify((20, 10, 5), [20, 10, 4], [0, 0, 1], [-1, -3, -2]) - verify( - (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], slice_mode="size", test_ref=False - ) + verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], slice_mode="size", test_ref=False) verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], slice_mode="size", test_ref=True) # Slicing along first few axes, where the rest of axes remain static From 728dac571d42ad3e89906bbfc6c1c3d463b4bef2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 May 2021 20:46:18 +0900 Subject: [PATCH 5/6] fix loop index dtype --- include/tvm/topi/transform.h | 2 +- src/relay/op/dyn/tensor/transform.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index dfb6326b567a..441dd781d03c 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -581,7 +581,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b [&](const Array& indices) { Array real_indices; // dynamic slicing - for (int32_t i = 0; i < num_dynamic_axes; ++i) { + for (int64_t i = 0; i < num_dynamic_axes; ++i) { real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1)); } // keep input dim diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index 5ed5351b7fb1..d8ee1c84a99c 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -472,7 +472,7 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr // calculate output shape std::vector oshape(num_axis); int64_t num_dynamic_axes = begin->shape[0].as()->value; - for (size_t i = 0; i < num_dynamic_axes; ++i) { + for (int64_t i = 0; i < num_dynamic_axes; ++i) { oshape[i] = Any(); } From 4f1742a52cca87b26d99a51927f385249f6c82d1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 18 May 2021 21:15:38 +0900 Subject: [PATCH 6/6] fix more dtype issue --- include/tvm/topi/transform.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 441dd781d03c..b2132b75fab9 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -581,11 +581,11 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b [&](const Array& indices) { Array real_indices; // dynamic slicing - for (int64_t i = 0; i < num_dynamic_axes; ++i) { + for (int32_t i = 0; i < num_dynamic_axes; ++i) { real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1)); } // keep input dim - for (int64_t i = num_dynamic_axes; i < src_tensor_dim; ++i) { + for (int32_t i = num_dynamic_axes; i < src_tensor_dim; ++i) { real_indices.push_back(indices[i]); } return x(real_indices);