From 16220254d976d76b93234a91bf823dceda605041 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Sep 2021 20:52:44 +0900 Subject: [PATCH 01/26] Add relay definition --- include/tvm/relay/attrs/algorithm.h | 13 +++++ src/relay/op/algorithm/searchsorted.cc | 66 ++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 src/relay/op/algorithm/searchsorted.cc diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index 83b4ddaead43..b15c2e687fdc 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -76,6 +76,19 @@ struct TopKAttrs : public tvm::AttrsNode { } }; +struct SearchSortedAttrs : public tvm::AttrsNode { + std::string side; + DataType dtype; + + TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs") { + TVM_ATTR_FIELD(side).set_default("left").describe( + "Controls which index is returned if a value lands exactly on one of sorted values."); + TVM_ATTR_FIELD(dtype) + .set_default(DataType::Int(64)) + .describe("Data type of the output indices."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_ALGORITHM_H_ diff --git a/src/relay/op/algorithm/searchsorted.cc b/src/relay/op/algorithm/searchsorted.cc new file mode 100644 index 000000000000..0bf964e344ac --- /dev/null +++ b/src/relay/op/algorithm/searchsorted.cc @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file topk.cc + * \brief TopK operators + */ +#include +#include +#include + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(SearchSortedAttrs); + +bool SearchSortedRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const SearchSortedAttrs* param = attrs.as(); + ICHECK_EQ(types.size(), 3); + const auto* sorted_sequence = types[0].as(); + const auto* values = types[1].as(); + ICHECK(sorted_sequence) << "Expects TensorType in the first input"; + ICHECK(values) << "Expects TensorType in the second input"; + + return true; +} + +Expr MakeSearchSorted(Expr sorted_sequence, Expr values, String side, DataType dtype) { + auto attrs = make_object(); + static const Op& op = Op::Get("searchsorted"); + return Call(op, {sorted_sequence, values}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.searchsorted").set_body_typed(MakeSearchSorted); + +RELAY_REGISTER_OP("searchsorted") + .describe(R"doc(Find indices where elements should be inserted to maintain order. +)doc" TVM_ADD_FILELINE) + .set_num_inputs(2) + .set_attrs_type() + .add_argument("sorted_sequence", "Tensor", + "Monotonically increasing sequence on the innermost dimension.") + .add_argument("values", "Tensor", "Values to search for.") + .set_support_level(6) + .add_type_rel("SearchSorted", SearchSortedRel); + + +} // namespace relay +} // namespace tvm From a43e0c84d2904c6fee5e69c92017ceaeebc9f770 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 30 Sep 2021 21:12:27 +0900 Subject: [PATCH 02/26] 1D cpu test working --- python/tvm/topi/__init__.py | 1 + python/tvm/topi/searchsorted.py | 67 +++++++++++++++++++ .../topi/python/test_topi_searchsorted.py | 64 ++++++++++++++++++ 3 files changed, 132 insertions(+) create mode 100644 python/tvm/topi/searchsorted.py create mode 100644 tests/python/topi/python/test_topi_searchsorted.py diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 6b22cf13f5b9..e243d6ee3bc7 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -45,6 +45,7 @@ from .scan import * from .einsum import * from .unique import * +from .searchsorted import * from . import generic from . import nn from . import x86 diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py new file mode 100644 index 000000000000..6706930c271b --- /dev/null +++ b/python/tvm/topi/searchsorted.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""searchsorted operator""" +from . import utils +from . import te +from ..tir import ir_builder +from .math import cast + + +def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): + def binary_search(ib, sequence_offset, search_range, sorted_sequence, i, values, out_indices): + lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") + hi = ib.allocate(out_dtype, (1,), name="hi", scope="local") + + v = values[i] + lo[0] = cast(0, out_dtype) + hi[0] = cast(search_range, out_dtype) + + with ib.while_loop(lo[0] < hi[0]): + mid = lo[0] + (hi[0] - lo[0] >> 1) + with ib.if_scope(sorted_sequence[sequence_offset + mid] < v): + lo[0] = mid + 1 + with ib.else_scope(): + hi[0] = mid + + out_indices[i] = lo[0] + + def ir(sorted_sequence, values, indices): + ib = ir_builder.create() + sorted_sequence_shape = sorted_sequence.shape + values_shape = values.shape + num_search = utils.prod(values_shape) + num_sequence = utils.prod(sorted_sequence_shape[:-1]) + search_range = sorted_sequence_shape[-1] + + sorted_sequence = ib.buffer_ptr(sorted_sequence) + values = ib.buffer_ptr(values) + indices = ib.buffer_ptr(indices) + + with ib.for_range(0, num_search, name="i", kind="parallel") as i: + sequence_id = i // values_shape[-1] + sequence_offset = sequence_id * search_range + binary_search(ib, sequence_offset, search_range, sorted_sequence, i, values, indices) + + return ib.get() + + return te.extern( + values.shape, + [sorted_sequence, values], + lambda ins, outs: ir(ins[0], ins[1], outs[0]), + name="searchsorted_ir", + dtype=out_dtype, + ) diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py new file mode 100644 index 000000000000..60b2e4501e59 --- /dev/null +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import te, topi + +topi_funcs = {"searchsorted": {"generic": topi.searchsorted}} + + +def get_implementations(name, axis, dtype, exclusive): + topi_func_generic = topi_funcs[name]["generic"] + # topi_func_cuda = topi_funcs[name]["cuda"] + + return { + "generic": ( + lambda x: topi_func_generic(x, axis, dtype, exclusive=exclusive), + topi.generic.schedule_extern, + ), + } + + +@tvm.testing.parametrize_targets +def test_cumsum(dev, target): + n = 1024 + A = te.placeholder((n,), name="A", dtype="float32") + B = te.placeholder((n,), name="B", dtype="float32") + C = topi.searchsorted(A, B) + s = te.create_schedule(C.op) + + with tvm.transform.PassContext(opt_level=3): + func = tvm.build(s, [A, B, C], target) + + dev = tvm.device(target, 0) + a_np = np.random.uniform(size=n).astype(A.dtype) + b_np = np.random.uniform(size=n).astype(B.dtype) + a_np = np.sort(a_np) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + func(a, b, c) + ref = np.searchsorted(a_np, b_np) + tvm.testing.assert_allclose(c.numpy(), ref) + print("ok") + + +if __name__ == "__main__": + target = "llvm" + test_cumsum(tvm.device(target, 0), target) From 36cb3bee3c949fcdac98a7c94dfa9f88a7ef9bbe Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 30 Sep 2021 21:29:44 +0900 Subject: [PATCH 03/26] multi dim working --- .../topi/python/test_topi_searchsorted.py | 36 +++++++++++++------ 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py index 60b2e4501e59..d3da73616db3 100644 --- a/tests/python/topi/python/test_topi_searchsorted.py +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -35,28 +35,42 @@ def get_implementations(name, axis, dtype, exclusive): } +def searchsorted_ref(sorted_sequence, values): + sorted_sequence_2d = np.reshape(sorted_sequence, (-1, sorted_sequence.shape[-1])) + values_2d = np.reshape(values, (-1, values.shape[-1])) + indices = np.zeros(values_2d.shape) + + for i in range(indices.shape[0]): + indices[i] = np.searchsorted(sorted_sequence_2d[i], values_2d[i]) + + return np.reshape(indices, values.shape) + + @tvm.testing.parametrize_targets def test_cumsum(dev, target): - n = 1024 - A = te.placeholder((n,), name="A", dtype="float32") - B = te.placeholder((n,), name="B", dtype="float32") + sequence_len = 1024 + num_search = 1000 + outer_axes = (10, 5, 3) + sorted_sequence_shape = outer_axes + (sequence_len,) + values_shape = outer_axes + (num_search,) + A = te.placeholder(sorted_sequence_shape, name="A", dtype="float32") + B = te.placeholder(values_shape, name="B", dtype="float32") C = topi.searchsorted(A, B) s = te.create_schedule(C.op) with tvm.transform.PassContext(opt_level=3): - func = tvm.build(s, [A, B, C], target) + func = tvm.build(s, [A, B, C], target=target) dev = tvm.device(target, 0) - a_np = np.random.uniform(size=n).astype(A.dtype) - b_np = np.random.uniform(size=n).astype(B.dtype) - a_np = np.sort(a_np) + a_np = np.random.randn(*sorted_sequence_shape).astype(A.dtype) + b_np = np.random.randn(*values_shape).astype(B.dtype) + a_np = np.sort(a_np, axis=-1) a = tvm.nd.array(a_np, dev) b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev) + c = tvm.nd.array(np.zeros(values_shape, dtype=C.dtype), dev) func(a, b, c) - ref = np.searchsorted(a_np, b_np) - tvm.testing.assert_allclose(c.numpy(), ref) - print("ok") + ref = searchsorted_ref(a_np, b_np) + np.testing.assert_equal(c.numpy(), ref) if __name__ == "__main__": From 8f1f0109ed03b461939bad9d3c1030a4f510ce0e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 30 Sep 2021 22:05:39 +0900 Subject: [PATCH 04/26] gpu version working --- python/tvm/topi/cuda/__init__.py | 1 + python/tvm/topi/cuda/searchsorted.py | 57 +++++++++++++++++++ python/tvm/topi/searchsorted.py | 32 +++++------ .../topi/python/test_topi_searchsorted.py | 29 ++++++---- 4 files changed, 93 insertions(+), 26 deletions(-) create mode 100644 python/tvm/topi/cuda/searchsorted.py diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index 21ddf57ca1d0..88d306761310 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -59,3 +59,4 @@ from .sparse_reshape import * from .transform import * from .unique import * +from .searchsorted import * diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py new file mode 100644 index 000000000000..2a27d8c2f1ef --- /dev/null +++ b/python/tvm/topi/cuda/searchsorted.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""searchsorted operator for GPU""" +import tvm +from tvm import te +from .. import utils +from ..searchsorted import binary_search + + +def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): + def ir(sorted_sequence, values, indices): + ib = tvm.tir.ir_builder.create() + sorted_sequence_shape = sorted_sequence.shape + values_shape = values.shape + num_search = utils.prod(values_shape) + search_range = sorted_sequence_shape[-1] + + sorted_sequence = ib.buffer_ptr(sorted_sequence) + values = ib.buffer_ptr(values) + indices = ib.buffer_ptr(indices) + + max_threads = 256 + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(num_search + max_threads - 1, max_threads)) + ib.scope_attr(tx, "thread_extent", max_threads) + tid = bx * max_threads + tx + + with ib.new_scope(): + with ib.if_scope(tid < num_search): + sequence_id = tid // values_shape[-1] + sequence_offset = sequence_id * search_range + binary_search(ib, sequence_offset, search_range, sorted_sequence, tid, values, indices, out_dtype) + + return ib.get() + + return te.extern( + values.shape, + [sorted_sequence, values], + lambda ins, outs: ir(ins[0], ins[1], outs[0]), + name="searchsorted_ir", + dtype=out_dtype, + ) diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index 6706930c271b..0ce4edf043e4 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -21,30 +21,30 @@ from .math import cast -def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): - def binary_search(ib, sequence_offset, search_range, sorted_sequence, i, values, out_indices): - lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") - hi = ib.allocate(out_dtype, (1,), name="hi", scope="local") +def binary_search(ib, sequence_offset, search_range, sorted_sequence, i, values, out_indices, out_dtype): + lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") + hi = ib.allocate(out_dtype, (1,), name="hi", scope="local") + + v = values[i] + lo[0] = cast(0, out_dtype) + hi[0] = cast(search_range, out_dtype) - v = values[i] - lo[0] = cast(0, out_dtype) - hi[0] = cast(search_range, out_dtype) + with ib.while_loop(lo[0] < hi[0]): + mid = lo[0] + (hi[0] - lo[0] >> 1) + with ib.if_scope(sorted_sequence[sequence_offset + mid] < v): + lo[0] = mid + 1 + with ib.else_scope(): + hi[0] = mid - with ib.while_loop(lo[0] < hi[0]): - mid = lo[0] + (hi[0] - lo[0] >> 1) - with ib.if_scope(sorted_sequence[sequence_offset + mid] < v): - lo[0] = mid + 1 - with ib.else_scope(): - hi[0] = mid + out_indices[i] = lo[0] - out_indices[i] = lo[0] +def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): def ir(sorted_sequence, values, indices): ib = ir_builder.create() sorted_sequence_shape = sorted_sequence.shape values_shape = values.shape num_search = utils.prod(values_shape) - num_sequence = utils.prod(sorted_sequence_shape[:-1]) search_range = sorted_sequence_shape[-1] sorted_sequence = ib.buffer_ptr(sorted_sequence) @@ -54,7 +54,7 @@ def ir(sorted_sequence, values, indices): with ib.for_range(0, num_search, name="i", kind="parallel") as i: sequence_id = i // values_shape[-1] sequence_offset = sequence_id * search_range - binary_search(ib, sequence_offset, search_range, sorted_sequence, i, values, indices) + binary_search(ib, sequence_offset, search_range, sorted_sequence, i, values, indices, out_dtype) return ib.get() diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py index d3da73616db3..a9b4223450c0 100644 --- a/tests/python/topi/python/test_topi_searchsorted.py +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -20,18 +20,22 @@ import tvm.topi.testing from tvm import te, topi -topi_funcs = {"searchsorted": {"generic": topi.searchsorted}} +topi_funcs = {"generic": topi.searchsorted, "cuda": topi.cuda.searchsorted} -def get_implementations(name, axis, dtype, exclusive): - topi_func_generic = topi_funcs[name]["generic"] - # topi_func_cuda = topi_funcs[name]["cuda"] +def get_implementations(): + topi_func_generic = topi_funcs["generic"] + topi_func_cuda = topi_funcs["cuda"] return { "generic": ( - lambda x: topi_func_generic(x, axis, dtype, exclusive=exclusive), + lambda x, y: topi_func_generic(x, y), topi.generic.schedule_extern, ), + "vulkan": ( + lambda x, y: topi_func_cuda(x, y), + topi.cuda.schedule_extern, + ), } @@ -47,7 +51,7 @@ def searchsorted_ref(sorted_sequence, values): @tvm.testing.parametrize_targets -def test_cumsum(dev, target): +def test_searchsorted(dev, target): sequence_len = 1024 num_search = 1000 outer_axes = (10, 5, 3) @@ -55,8 +59,12 @@ def test_cumsum(dev, target): values_shape = outer_axes + (num_search,) A = te.placeholder(sorted_sequence_shape, name="A", dtype="float32") B = te.placeholder(values_shape, name="B", dtype="float32") - C = topi.searchsorted(A, B) - s = te.create_schedule(C.op) + + implementations = get_implementations() + fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) + + C = fcompute(A, B) + s = fschedule([C]) with tvm.transform.PassContext(opt_level=3): func = tvm.build(s, [A, B, C], target=target) @@ -71,8 +79,9 @@ def test_cumsum(dev, target): func(a, b, c) ref = searchsorted_ref(a_np, b_np) np.testing.assert_equal(c.numpy(), ref) + print("ok") if __name__ == "__main__": - target = "llvm" - test_cumsum(tvm.device(target, 0), target) + target = "vulkan -from_device=0" + test_searchsorted(tvm.device(target, 0), target) From 445852c69c0fecdc3b60f5c2603bc910beb30447 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 1 Oct 2021 12:09:22 +0900 Subject: [PATCH 05/26] check shape in type rel --- src/relay/op/algorithm/searchsorted.cc | 9 ++++++++- tests/python/topi/python/test_topi_searchsorted.py | 5 ++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/relay/op/algorithm/searchsorted.cc b/src/relay/op/algorithm/searchsorted.cc index 0bf964e344ac..b5a791ed8865 100644 --- a/src/relay/op/algorithm/searchsorted.cc +++ b/src/relay/op/algorithm/searchsorted.cc @@ -39,6 +39,14 @@ bool SearchSortedRel(const Array& types, int num_inputs, const Attrs& attr ICHECK(sorted_sequence) << "Expects TensorType in the first input"; ICHECK(values) << "Expects TensorType in the second input"; + ICHECK_EQ(sorted_sequence->shape.size(), values->shape.size()) + << "Ranks of sorted sequence and values must be the same"; + for (size_t i = 0; i < values->shape.size() - 1; ++i) { + ICHECK_EQ(sorted_sequence->shape[i], values->shape[i]) + << "sorted sequence and values do not have the same shape along outer axes"; + } + + reporter->Assign(types[2], TensorType(values->shape, param->dtype)); return true; } @@ -61,6 +69,5 @@ RELAY_REGISTER_OP("searchsorted") .set_support_level(6) .add_type_rel("SearchSorted", SearchSortedRel); - } // namespace relay } // namespace tvm diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py index a9b4223450c0..d99bbe88005a 100644 --- a/tests/python/topi/python/test_topi_searchsorted.py +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -79,9 +79,8 @@ def test_searchsorted(dev, target): func(a, b, c) ref = searchsorted_ref(a_np, b_np) np.testing.assert_equal(c.numpy(), ref) - print("ok") if __name__ == "__main__": - target = "vulkan -from_device=0" - test_searchsorted(tvm.device(target, 0), target) + for target in ["llvm", "vulkan -from_device=0"]: + test_searchsorted(tvm.device(target, 0), target) From 6b1adca59cd4d3f01242360493adccc940464eca Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 1 Oct 2021 19:24:31 +0900 Subject: [PATCH 06/26] support side --- python/tvm/topi/cuda/searchsorted.py | 27 ++++++++--- python/tvm/topi/searchsorted.py | 29 ++++++++++-- .../topi/python/test_topi_searchsorted.py | 45 ++++++++++--------- 3 files changed, 69 insertions(+), 32 deletions(-) diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py index 2a27d8c2f1ef..868c530cfb83 100644 --- a/python/tvm/topi/cuda/searchsorted.py +++ b/python/tvm/topi/cuda/searchsorted.py @@ -22,6 +22,8 @@ def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): + """TODO""" + def ir(sorted_sequence, values, indices): ib = tvm.tir.ir_builder.create() sorted_sequence_shape = sorted_sequence.shape @@ -36,15 +38,26 @@ def ir(sorted_sequence, values, indices): max_threads = 256 bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") - ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(num_search + max_threads - 1, max_threads)) + ib.scope_attr( + bx, "thread_extent", tvm.tir.indexdiv(num_search + max_threads - 1, max_threads) + ) ib.scope_attr(tx, "thread_extent", max_threads) tid = bx * max_threads + tx - with ib.new_scope(): - with ib.if_scope(tid < num_search): - sequence_id = tid // values_shape[-1] - sequence_offset = sequence_id * search_range - binary_search(ib, sequence_offset, search_range, sorted_sequence, tid, values, indices, out_dtype) + with ib.if_scope(tid < num_search): + sequence_id = tid // values_shape[-1] + sequence_offset = sequence_id * search_range + binary_search( + ib, + sequence_offset, + search_range, + sorted_sequence, + tid, + values, + indices, + side, + out_dtype, + ) return ib.get() @@ -52,6 +65,6 @@ def ir(sorted_sequence, values, indices): values.shape, [sorted_sequence, values], lambda ins, outs: ir(ins[0], ins[1], outs[0]), - name="searchsorted_ir", + name="searchsorted", dtype=out_dtype, ) diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index 0ce4edf043e4..53e4bc48c2c3 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -21,7 +21,10 @@ from .math import cast -def binary_search(ib, sequence_offset, search_range, sorted_sequence, i, values, out_indices, out_dtype): +def binary_search( + ib, sequence_offset, search_range, sorted_sequence, i, values, out_indices, side, out_dtype +): + """TODO""" lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") hi = ib.allocate(out_dtype, (1,), name="hi", scope="local") @@ -29,9 +32,15 @@ def binary_search(ib, sequence_offset, search_range, sorted_sequence, i, values, lo[0] = cast(0, out_dtype) hi[0] = cast(search_range, out_dtype) + # Reference: pytorch/aten/src/ATen/native/cuda/Bucketization.cu + def condition(current_val, target_val): + if side == "left": + return current_val < target_val + return current_val <= target_val + with ib.while_loop(lo[0] < hi[0]): mid = lo[0] + (hi[0] - lo[0] >> 1) - with ib.if_scope(sorted_sequence[sequence_offset + mid] < v): + with ib.if_scope(condition(sorted_sequence[sequence_offset + mid], v)): lo[0] = mid + 1 with ib.else_scope(): hi[0] = mid @@ -40,6 +49,8 @@ def binary_search(ib, sequence_offset, search_range, sorted_sequence, i, values, def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): + """TODO""" + def ir(sorted_sequence, values, indices): ib = ir_builder.create() sorted_sequence_shape = sorted_sequence.shape @@ -54,7 +65,17 @@ def ir(sorted_sequence, values, indices): with ib.for_range(0, num_search, name="i", kind="parallel") as i: sequence_id = i // values_shape[-1] sequence_offset = sequence_id * search_range - binary_search(ib, sequence_offset, search_range, sorted_sequence, i, values, indices, out_dtype) + binary_search( + ib, + sequence_offset, + search_range, + sorted_sequence, + i, + values, + indices, + side, + out_dtype, + ) return ib.get() @@ -62,6 +83,6 @@ def ir(sorted_sequence, values, indices): values.shape, [sorted_sequence, values], lambda ins, outs: ir(ins[0], ins[1], outs[0]), - name="searchsorted_ir", + name="searchsorted", dtype=out_dtype, ) diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py index d99bbe88005a..cd5cd03bf28d 100644 --- a/tests/python/topi/python/test_topi_searchsorted.py +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -29,23 +29,23 @@ def get_implementations(): return { "generic": ( - lambda x, y: topi_func_generic(x, y), + lambda x, y, side, out_dtype: topi_func_generic(x, y, side, out_dtype), topi.generic.schedule_extern, ), "vulkan": ( - lambda x, y: topi_func_cuda(x, y), + lambda x, y, side, out_dtype: topi_func_cuda(x, y, side, out_dtype), topi.cuda.schedule_extern, ), } -def searchsorted_ref(sorted_sequence, values): +def searchsorted_ref(sorted_sequence, values, side, out_dtype): sorted_sequence_2d = np.reshape(sorted_sequence, (-1, sorted_sequence.shape[-1])) values_2d = np.reshape(values, (-1, values.shape[-1])) - indices = np.zeros(values_2d.shape) + indices = np.zeros(values_2d.shape, dtype=out_dtype) for i in range(indices.shape[0]): - indices[i] = np.searchsorted(sorted_sequence_2d[i], values_2d[i]) + indices[i] = np.searchsorted(sorted_sequence_2d[i], values_2d[i], side=side) return np.reshape(indices, values.shape) @@ -59,28 +59,31 @@ def test_searchsorted(dev, target): values_shape = outer_axes + (num_search,) A = te.placeholder(sorted_sequence_shape, name="A", dtype="float32") B = te.placeholder(values_shape, name="B", dtype="float32") - + out_dtype = "int32" implementations = get_implementations() fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) - C = fcompute(A, B) - s = fschedule([C]) + for side in ["left", "right"]: + C = fcompute(A, B, side, out_dtype) + s = fschedule([C]) + + with tvm.transform.PassContext(opt_level=3): + func = tvm.build(s, [A, B, C], target=target) - with tvm.transform.PassContext(opt_level=3): - func = tvm.build(s, [A, B, C], target=target) + dev = tvm.device(target, 0) - dev = tvm.device(target, 0) - a_np = np.random.randn(*sorted_sequence_shape).astype(A.dtype) - b_np = np.random.randn(*values_shape).astype(B.dtype) - a_np = np.sort(a_np, axis=-1) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros(values_shape, dtype=C.dtype), dev) - func(a, b, c) - ref = searchsorted_ref(a_np, b_np) - np.testing.assert_equal(c.numpy(), ref) + for i in range(100): + a_np = np.random.randn(*sorted_sequence_shape).astype(A.dtype) + b_np = np.random.randn(*values_shape).astype(B.dtype) + a_np = np.sort(a_np, axis=-1) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(values_shape, dtype=C.dtype), dev) + func(a, b, c) + ref = searchsorted_ref(a_np, b_np, side, out_dtype) + np.testing.assert_equal(c.numpy(), ref) if __name__ == "__main__": - for target in ["llvm", "vulkan -from_device=0"]: + for target in ["vulkan -from_device=0"]: test_searchsorted(tvm.device(target, 0), target) From 98b88fc3f42d7d15f9815e080a9ef89038525fa5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 1 Oct 2021 19:29:56 +0900 Subject: [PATCH 07/26] use target specfic max threads --- python/tvm/topi/cuda/searchsorted.py | 2 +- .../topi/python/test_topi_searchsorted.py | 21 +++++++++---------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py index 868c530cfb83..28a2381b566c 100644 --- a/python/tvm/topi/cuda/searchsorted.py +++ b/python/tvm/topi/cuda/searchsorted.py @@ -35,7 +35,7 @@ def ir(sorted_sequence, values, indices): values = ib.buffer_ptr(values) indices = ib.buffer_ptr(indices) - max_threads = 256 + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") ib.scope_attr( diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py index cd5cd03bf28d..2fd65a430714 100644 --- a/tests/python/topi/python/test_topi_searchsorted.py +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -57,33 +57,32 @@ def test_searchsorted(dev, target): outer_axes = (10, 5, 3) sorted_sequence_shape = outer_axes + (sequence_len,) values_shape = outer_axes + (num_search,) - A = te.placeholder(sorted_sequence_shape, name="A", dtype="float32") - B = te.placeholder(values_shape, name="B", dtype="float32") + sorted_sequence = te.placeholder(sorted_sequence_shape, name="A", dtype="float32") + values = te.placeholder(values_shape, name="B", dtype="float32") out_dtype = "int32" implementations = get_implementations() fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) for side in ["left", "right"]: - C = fcompute(A, B, side, out_dtype) - s = fschedule([C]) - - with tvm.transform.PassContext(opt_level=3): - func = tvm.build(s, [A, B, C], target=target) + with tvm.target.Target(target): + indices = fcompute(sorted_sequence, values, side, out_dtype) + s = fschedule([indices]) + func = tvm.build(s, [sorted_sequence, values, indices], target=target) dev = tvm.device(target, 0) for i in range(100): - a_np = np.random.randn(*sorted_sequence_shape).astype(A.dtype) - b_np = np.random.randn(*values_shape).astype(B.dtype) + a_np = np.random.randn(*sorted_sequence_shape).astype(sorted_sequence.dtype) + b_np = np.random.randn(*values_shape).astype(values.dtype) a_np = np.sort(a_np, axis=-1) a = tvm.nd.array(a_np, dev) b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros(values_shape, dtype=C.dtype), dev) + c = tvm.nd.array(np.zeros(values_shape, dtype=indices.dtype), dev) func(a, b, c) ref = searchsorted_ref(a_np, b_np, side, out_dtype) np.testing.assert_equal(c.numpy(), ref) if __name__ == "__main__": - for target in ["vulkan -from_device=0"]: + for target in ["llvm", "vulkan -from_device=0"]: test_searchsorted(tvm.device(target, 0), target) From 5ca105cff9a97a2d6a8cf1cd4da462639262eba9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 1 Oct 2021 19:52:14 +0900 Subject: [PATCH 08/26] add relay boilerplate --- include/tvm/relay/attrs/algorithm.h | 2 +- python/tvm/relay/op/_algorithm.py | 3 +++ python/tvm/relay/op/algorithm.py | 17 +++++++++++++++++ python/tvm/relay/op/op_attrs.py | 3 +++ python/tvm/relay/op/strategy/cuda.py | 12 ++++++++++++ python/tvm/relay/op/strategy/generic.py | 25 +++++++++++++++++++++++++ src/relay/op/algorithm/searchsorted.cc | 10 ++++++++-- 7 files changed, 69 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index b15c2e687fdc..19f88d9bffc6 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -84,7 +84,7 @@ struct SearchSortedAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(side).set_default("left").describe( "Controls which index is returned if a value lands exactly on one of sorted values."); TVM_ATTR_FIELD(dtype) - .set_default(DataType::Int(64)) + .set_default(DataType::Int(32)) .describe("Data type of the output indices."); } }; diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 817f96b696df..ace3de5fff06 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -41,6 +41,9 @@ register_strategy("topk", strategy.topk_strategy) register_pattern("topk", OpPattern.OPAQUE) +# searchsorted +register_strategy("searchsorted", strategy.searchsorted_strategy) +register_pattern("searchsorted", OpPattern.OPAQUE) @script def _topk_shape_func_input_shape(data_shape, k, axis): diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 119936f632f8..9639a6800433 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -115,3 +115,20 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): if ret_type == "both": return TupleWrapper(out, 2) return out + + +def searchsorted(sorted_sequence, values, side="left", dtype="int32"): + """TODO + + Parameters + ---------- + dtype : string, optional + The data type of the output indices. + + Returns + ------- + indices : relay.Expr + Tensor with same shape as values, representing the indices of + elements of values if they are inserted in sorted_sequence. + """ + return _make.searchsorted(sorted_sequence, values, side, dtype) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 8fd46817b817..f9a4051f6aeb 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -563,6 +563,9 @@ class SparseConv2DAttrs(Attrs): class TopkAttrs(Attrs): """Attributes used in topk operators""" +@tvm._ffi.register_object("relay.attrs.SearchSortedAttrs") +class SearchSortedAttrs(Attrs): + """Attributes used in searchsorted operators""" @tvm._ffi.register_object("relay.attrs.TupleGetItemAttrs") class TupleGetItemAttrs(Attrs): diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index da7cbd5cec10..5f24dbda9d35 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1022,6 +1022,18 @@ def topk_strategy_cuda(attrs, inputs, out_type, target): return strategy +@searchsorted_strategy.register(["cuda", "gpu"]) +def searchsorted_strategy_cuda(attrs, inputs, out_type, target): + """searchsorted cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_searchsorted(topi.cuda.searchsorted), + wrap_topi_schedule(topi.cuda.schedule_extern), + name="searchsorted.cuda", + ) + return strategy + + @multibox_prior_strategy.register(["cuda", "gpu"]) def multibox_prior_strategy_cuda(attrs, inputs, out_type, target): """multibox_prior cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index d021b5d9d84d..fc6ba8fda86d 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1002,6 +1002,31 @@ def topk_strategy(attrs, inputs, out_type, target): return strategy +# searchsorted +def wrap_compute_searchsorted(topi_compute): + """Wrap searchsorted compute""" + + def _compute_searchsorted(attrs, inputs, out_type): + side = attrs.side + dtype = attrs.dtype + return [topi_compute(inputs[0], inputs[1], side, dtype)] + + return _compute_searchsorted + + +# searchsorted_strategy +@override_native_generic_func("searchsorted_strategy") +def searchsorted_strategy(attrs, inputs, out_type, target): + """searchsorted generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_searchsorted(topi.searchsorted), + wrap_topi_schedule(topi.generic.schedule_extern), + name="searchsorted.generic", + ) + return strategy + + # multibox_prior def wrap_compute_multibox_prior(topi_compute): """Wrap multibox_prior compute""" diff --git a/src/relay/op/algorithm/searchsorted.cc b/src/relay/op/algorithm/searchsorted.cc index b5a791ed8865..91e6bbc8b549 100644 --- a/src/relay/op/algorithm/searchsorted.cc +++ b/src/relay/op/algorithm/searchsorted.cc @@ -38,12 +38,18 @@ bool SearchSortedRel(const Array& types, int num_inputs, const Attrs& attr const auto* values = types[1].as(); ICHECK(sorted_sequence) << "Expects TensorType in the first input"; ICHECK(values) << "Expects TensorType in the second input"; + ICHECK(param->side == "left" || param->side == "right") + << "'side' parameter must be either 'left' or 'right'"; ICHECK_EQ(sorted_sequence->shape.size(), values->shape.size()) << "Ranks of sorted sequence and values must be the same"; + for (size_t i = 0; i < values->shape.size() - 1; ++i) { - ICHECK_EQ(sorted_sequence->shape[i], values->shape[i]) - << "sorted sequence and values do not have the same shape along outer axes"; + if (sorted_sequence->shape[i].as() && values->shape[i].as()) { + ICHECK_EQ(sorted_sequence->shape[i].as()->value, + values->shape[i].as()->value) + << "sorted sequence and values do not have the same shape along outer axes"; + } } reporter->Assign(types[2], TensorType(values->shape, param->dtype)); From a055dd3e5f772a4d471cd2203a25cf1811187116 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 1 Oct 2021 20:20:17 +0900 Subject: [PATCH 09/26] relay test working --- python/tvm/topi/testing/__init__.py | 1 + python/tvm/topi/testing/searchsorted.py | 12 ++++++++++++ src/relay/op/algorithm/searchsorted.cc | 6 ++++-- tests/python/relay/test_op_level6.py | 24 ++++++++++++++++++++++++ 4 files changed, 41 insertions(+), 2 deletions(-) create mode 100644 python/tvm/topi/testing/searchsorted.py diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index d10c49f5c084..2d7d0a4b9e11 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -73,3 +73,4 @@ from .batch_to_space_nd import batch_to_space_nd_python from .nll_loss import nll_loss from .dense import dense +from .searchsorted import searchsorted_ref diff --git a/python/tvm/topi/testing/searchsorted.py b/python/tvm/topi/testing/searchsorted.py new file mode 100644 index 000000000000..045e00df7eeb --- /dev/null +++ b/python/tvm/topi/testing/searchsorted.py @@ -0,0 +1,12 @@ +import numpy as np + + +def searchsorted_ref(sorted_sequence, values, side, out_dtype): + sorted_sequence_2d = np.reshape(sorted_sequence, (-1, sorted_sequence.shape[-1])) + values_2d = np.reshape(values, (-1, values.shape[-1])) + indices = np.zeros(values_2d.shape, dtype=out_dtype) + + for i in range(indices.shape[0]): + indices[i] = np.searchsorted(sorted_sequence_2d[i], values_2d[i], side=side) + + return np.reshape(indices, values.shape) diff --git a/src/relay/op/algorithm/searchsorted.cc b/src/relay/op/algorithm/searchsorted.cc index 91e6bbc8b549..99566ef43330 100644 --- a/src/relay/op/algorithm/searchsorted.cc +++ b/src/relay/op/algorithm/searchsorted.cc @@ -18,8 +18,8 @@ */ /*! - * \file topk.cc - * \brief TopK operators + * \file searchsorted.cc + * \brief SearchSorted operators */ #include #include @@ -59,6 +59,8 @@ bool SearchSortedRel(const Array& types, int num_inputs, const Attrs& attr Expr MakeSearchSorted(Expr sorted_sequence, Expr values, String side, DataType dtype) { auto attrs = make_object(); static const Op& op = Op::Get("searchsorted"); + attrs->dtype = dtype; + attrs->side = side; return Call(op, {sorted_sequence, values}, Attrs(attrs), {}); } diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index ea640c62dfeb..f1666ecd005d 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -20,6 +20,7 @@ import numpy as np import tvm from tvm import relay +from tvm.topi.testing import searchsorted_ref import tvm.testing @@ -149,5 +150,28 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"): verify_topk(k, axis, ret_type, False, "int64", "float16") +@tvm.testing.uses_gpu +def test_searchsorted(): + def verify_searchsorted(side, dtype): + shape = (10, 20, 100) + values_shape = shape[:-1] + (50,) + sorted_sequence = relay.var("sorted_sequence", relay.TensorType(shape, "float32")) + values = relay.var("sorted_sequence", relay.TensorType(values_shape, "float32")) + out = relay.searchsorted(sorted_sequence, values, side, dtype) + func = relay.Function([sorted_sequence, values], out) + sorted_sequence_np = np.sort(np.random.randn(*shape).astype("float32"), axis=-1) + values_np = np.random.randn(*values_shape).astype("float32") + np_indices = searchsorted_ref(sorted_sequence_np, values_np, side, dtype) + + for target, dev in tvm.testing.enabled_targets(): + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + sorted_sequence_np, values_np + ) + np.testing.assert_equal(op_res.numpy(), np_indices) + + verify_searchsorted("left", "int32") + verify_searchsorted("right", "int64") + + if __name__ == "__main__": pytest.main([__file__]) From 2584f45f06ea4272cd2fe5493faed8dc4853babb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 1 Oct 2021 20:29:23 +0900 Subject: [PATCH 10/26] cleanup topi test --- .../topi/python/test_topi_searchsorted.py | 63 ++++++++----------- 1 file changed, 27 insertions(+), 36 deletions(-) diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py index 2fd65a430714..acef6f071001 100644 --- a/tests/python/topi/python/test_topi_searchsorted.py +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -18,6 +18,7 @@ import tvm import tvm.testing import tvm.topi.testing +from tvm.topi.testing import searchsorted_ref from tvm import te, topi topi_funcs = {"generic": topi.searchsorted, "cuda": topi.cuda.searchsorted} @@ -32,6 +33,10 @@ def get_implementations(): lambda x, y, side, out_dtype: topi_func_generic(x, y, side, out_dtype), topi.generic.schedule_extern, ), + "cuda": ( + lambda x, y, side, out_dtype: topi_func_cuda(x, y, side, out_dtype), + topi.cuda.schedule_extern, + ), "vulkan": ( lambda x, y, side, out_dtype: topi_func_cuda(x, y, side, out_dtype), topi.cuda.schedule_extern, @@ -39,31 +44,17 @@ def get_implementations(): } -def searchsorted_ref(sorted_sequence, values, side, out_dtype): - sorted_sequence_2d = np.reshape(sorted_sequence, (-1, sorted_sequence.shape[-1])) - values_2d = np.reshape(values, (-1, values.shape[-1])) - indices = np.zeros(values_2d.shape, dtype=out_dtype) - - for i in range(indices.shape[0]): - indices[i] = np.searchsorted(sorted_sequence_2d[i], values_2d[i], side=side) - - return np.reshape(indices, values.shape) - - @tvm.testing.parametrize_targets def test_searchsorted(dev, target): - sequence_len = 1024 - num_search = 1000 - outer_axes = (10, 5, 3) - sorted_sequence_shape = outer_axes + (sequence_len,) - values_shape = outer_axes + (num_search,) - sorted_sequence = te.placeholder(sorted_sequence_shape, name="A", dtype="float32") - values = te.placeholder(values_shape, name="B", dtype="float32") - out_dtype = "int32" - implementations = get_implementations() - fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) + def verify(sequence_len, num_search, outer_axes, side): + sorted_sequence_shape = outer_axes + (sequence_len,) + values_shape = outer_axes + (num_search,) + sorted_sequence = te.placeholder(sorted_sequence_shape, dtype="float32") + values = te.placeholder(values_shape, dtype="float32") + out_dtype = "int32" + implementations = get_implementations() + fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) - for side in ["left", "right"]: with tvm.target.Target(target): indices = fcompute(sorted_sequence, values, side, out_dtype) s = fschedule([indices]) @@ -71,18 +62,18 @@ def test_searchsorted(dev, target): func = tvm.build(s, [sorted_sequence, values, indices], target=target) dev = tvm.device(target, 0) - for i in range(100): - a_np = np.random.randn(*sorted_sequence_shape).astype(sorted_sequence.dtype) - b_np = np.random.randn(*values_shape).astype(values.dtype) - a_np = np.sort(a_np, axis=-1) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros(values_shape, dtype=indices.dtype), dev) - func(a, b, c) - ref = searchsorted_ref(a_np, b_np, side, out_dtype) - np.testing.assert_equal(c.numpy(), ref) - + a_np = np.random.randn(*sorted_sequence_shape).astype(sorted_sequence.dtype) + b_np = np.random.randn(*values_shape).astype(values.dtype) + a_np = np.sort(a_np, axis=-1) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(values_shape, dtype=indices.dtype), dev) + func(a, b, c) + ref = searchsorted_ref(a_np, b_np, side, out_dtype) + np.testing.assert_equal(c.numpy(), ref) -if __name__ == "__main__": - for target in ["llvm", "vulkan -from_device=0"]: - test_searchsorted(tvm.device(target, 0), target) + # The first argument is the range of binary search + verify(1024, 1000, (10, 5, 3), "left") + verify(999, 2000, (10, 5, 3), "right") + verify(1000, 1000, (), "left") + verify(2001, 100, (500), "right") From 686e2222d7637a52728d00cf379d05f48368b7fe Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 2 Oct 2021 17:51:06 +0900 Subject: [PATCH 11/26] fix test --- src/te/operation/create_primfunc.cc | 2 +- tests/python/relay/test_op_level5.py | 1 - tests/python/topi/python/test_topi_searchsorted.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index a47556bac101..1885de0848cf 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -48,7 +48,7 @@ class ProducerToBufferTransformer : public StmtExprMutator { const std::unordered_map& tensor2buffers_; }; -/*! \brief Helper data structural to store informations. */ +/*! \brief Helper data structure to store information. */ struct CreateFuncInfo { /*! \brief The Tensor arg_list. */ Array arg_list; diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index eb4eee379b08..c968c5a7f19f 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -773,7 +773,6 @@ def verify_roi_align( mode=mode, ) for target, dev in tvm.testing.enabled_targets(): - print("test on", target) op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( np_data, np_rois ) diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py index acef6f071001..7592354e9bb8 100644 --- a/tests/python/topi/python/test_topi_searchsorted.py +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -76,4 +76,4 @@ def verify(sequence_len, num_search, outer_axes, side): verify(1024, 1000, (10, 5, 3), "left") verify(999, 2000, (10, 5, 3), "right") verify(1000, 1000, (), "left") - verify(2001, 100, (500), "right") + verify(2001, 100, (500,), "right") From 1c7a0ffc7b6d3bbddcbf8e4d5183ac769c081ceb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 2 Oct 2021 20:31:26 +0900 Subject: [PATCH 12/26] add torch converter --- python/tvm/relay/frontend/pytorch.py | 8 ++++++++ tests/python/frontend/pytorch/test_forward.py | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 76cd0455661b..da9ae270294e 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2774,6 +2774,13 @@ def all_any_common(self, op, inputs, input_types): inp = inputs[0] return op(inp, axis=dim, keepdims=keepdim) + def searchsorted(self, inputs, input_types): + out_int32 = inputs[2] + right = inputs[3] + dtype = "int32" if out_int32 else "int64" + side = "right" if right else "left" + return _op.searchsorted(inputs[0], inputs[1], side=side, dtype=dtype) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2999,6 +3006,7 @@ def create_convert_map(self): "aten::lstm": self.lstm, "aten::all": functools.partial(self.all_any_common, _op.all), "aten::any": functools.partial(self.all_any_common, _op.any), + "aten::searchsorted": self.searchsorted, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 3a3889d5cfb7..4a4aaec91b9c 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3962,5 +3962,17 @@ def test_fn(f, dim=None, keepdim=False): verify_model(test_fn(f, 0, keepdim=True), [torch.rand(4, 2).bool()]) +@tvm.testing.uses_gpu +def test_searchsorted(): + def test_fn(out_int32=False, right=False): + return lambda x, y: torch.searchsorted(x, y, out_int32=out_int32, right=right) + + sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + verify_model(test_fn(), [sorted_sequence, values]) + verify_model(test_fn(out_int32=True), [sorted_sequence[0], values[0]]) + verify_model(test_fn(right=True), [sorted_sequence, values]) + + if __name__ == "__main__": pytest.main([__file__]) From a57b0814df5718a0d9cd3e7dbf139a5be7ad5128 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 2 Oct 2021 20:57:32 +0900 Subject: [PATCH 13/26] handle other cases --- python/tvm/relay/frontend/pytorch.py | 13 +++++++++++- python/tvm/topi/cuda/searchsorted.py | 8 ++++++-- python/tvm/topi/searchsorted.py | 8 ++++++-- src/relay/op/algorithm/searchsorted.cc | 20 +++++++++++-------- tests/python/frontend/pytorch/test_forward.py | 6 ++++++ 5 files changed, 42 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index da9ae270294e..997ef6a077af 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2775,11 +2775,22 @@ def all_any_common(self, op, inputs, input_types): return op(inp, axis=dim, keepdims=keepdim) def searchsorted(self, inputs, input_types): + values = inputs[1] out_int32 = inputs[2] right = inputs[3] dtype = "int32" if out_int32 else "int64" side = "right" if right else "left" - return _op.searchsorted(inputs[0], inputs[1], side=side, dtype=dtype) + values_shape = _infer_shape(values) + + if len(values_shape) == 0: + values = _op.expand_dims(values, 0) + + out = _op.searchsorted(inputs[0], values, side=side, dtype=dtype) + + if len(values_shape) == 0: + return _op.squeeze(out) + + return out # Operator mappings def create_convert_map(self): diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py index 28a2381b566c..3de35cf129e5 100644 --- a/python/tvm/topi/cuda/searchsorted.py +++ b/python/tvm/topi/cuda/searchsorted.py @@ -45,8 +45,12 @@ def ir(sorted_sequence, values, indices): tid = bx * max_threads + tx with ib.if_scope(tid < num_search): - sequence_id = tid // values_shape[-1] - sequence_offset = sequence_id * search_range + if len(sorted_sequence_shape) == 1: + sequence_offset = 0 + else: + sequence_id = tid // values_shape[-1] + sequence_offset = sequence_id * search_range + binary_search( ib, sequence_offset, diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index 53e4bc48c2c3..bdd3ec370094 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -63,8 +63,12 @@ def ir(sorted_sequence, values, indices): indices = ib.buffer_ptr(indices) with ib.for_range(0, num_search, name="i", kind="parallel") as i: - sequence_id = i // values_shape[-1] - sequence_offset = sequence_id * search_range + if len(sorted_sequence_shape) == 1: + sequence_offset = 0 + else: + sequence_id = i // values_shape[-1] + sequence_offset = sequence_id * search_range + binary_search( ib, sequence_offset, diff --git a/src/relay/op/algorithm/searchsorted.cc b/src/relay/op/algorithm/searchsorted.cc index 99566ef43330..90a1be06c89a 100644 --- a/src/relay/op/algorithm/searchsorted.cc +++ b/src/relay/op/algorithm/searchsorted.cc @@ -38,17 +38,21 @@ bool SearchSortedRel(const Array& types, int num_inputs, const Attrs& attr const auto* values = types[1].as(); ICHECK(sorted_sequence) << "Expects TensorType in the first input"; ICHECK(values) << "Expects TensorType in the second input"; + ICHECK_GT(values->shape.size(), 0) << "The rank of `values` must be greater than one"; ICHECK(param->side == "left" || param->side == "right") - << "'side' parameter must be either 'left' or 'right'"; + << "`side` parameter must be either `left` or `right`"; - ICHECK_EQ(sorted_sequence->shape.size(), values->shape.size()) - << "Ranks of sorted sequence and values must be the same"; + if (sorted_sequence->shape.size() > 1) { + ICHECK_EQ(sorted_sequence->shape.size(), values->shape.size()) + << "Ranks of `sorted_sequence` and values must be the same if `sorted_sequence` is " + "multi-dimensional."; - for (size_t i = 0; i < values->shape.size() - 1; ++i) { - if (sorted_sequence->shape[i].as() && values->shape[i].as()) { - ICHECK_EQ(sorted_sequence->shape[i].as()->value, - values->shape[i].as()->value) - << "sorted sequence and values do not have the same shape along outer axes"; + for (size_t i = 0; i < values->shape.size() - 1; ++i) { + if (sorted_sequence->shape[i].as() && values->shape[i].as()) { + ICHECK_EQ(sorted_sequence->shape[i].as()->value, + values->shape[i].as()->value) + << "`sorted_sequence and `values` do not have the same shape along outer axes"; + } } } diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 4a4aaec91b9c..84511dd66ccb 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3973,6 +3973,12 @@ def test_fn(out_int32=False, right=False): verify_model(test_fn(out_int32=True), [sorted_sequence[0], values[0]]) verify_model(test_fn(right=True), [sorted_sequence, values]) + sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + values = torch.tensor([[3, 6, 9], [4, 2, 7]]) + verify_model(test_fn(), [sorted_sequence_1d, values]) + + verify_model(test_fn(), [sorted_sequence_1d, torch.tensor(6)]) + if __name__ == "__main__": pytest.main([__file__]) From cda495737c20c518a7df5c5be13e5c1750538312 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 2 Oct 2021 21:00:55 +0900 Subject: [PATCH 14/26] more topi test --- tests/python/topi/python/test_topi_searchsorted.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py index 7592354e9bb8..71a3a9a859ca 100644 --- a/tests/python/topi/python/test_topi_searchsorted.py +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -46,8 +46,11 @@ def get_implementations(): @tvm.testing.parametrize_targets def test_searchsorted(dev, target): - def verify(sequence_len, num_search, outer_axes, side): - sorted_sequence_shape = outer_axes + (sequence_len,) + def verify(sequence_len, num_search, outer_axes, side, sorted_sequence_1d=False): + if sorted_sequence_1d: + sorted_sequence_shape = (sequence_len,) + else: + sorted_sequence_shape = outer_axes + (sequence_len,) values_shape = outer_axes + (num_search,) sorted_sequence = te.placeholder(sorted_sequence_shape, dtype="float32") values = te.placeholder(values_shape, dtype="float32") @@ -77,3 +80,4 @@ def verify(sequence_len, num_search, outer_axes, side): verify(999, 2000, (10, 5, 3), "right") verify(1000, 1000, (), "left") verify(2001, 100, (500,), "right") + verify(2001, 100, (500,), "left", sorted_sequence_1d=True) From 16ef46924a4d33603ed3604642376333664f5c77 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 2 Oct 2021 21:15:35 +0900 Subject: [PATCH 15/26] support torch bucketize --- python/tvm/relay/frontend/pytorch.py | 15 ++++++++++----- tests/python/frontend/pytorch/test_forward.py | 12 ++++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 997ef6a077af..80c5a78c72a2 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2774,10 +2774,8 @@ def all_any_common(self, op, inputs, input_types): inp = inputs[0] return op(inp, axis=dim, keepdims=keepdim) - def searchsorted(self, inputs, input_types): - values = inputs[1] - out_int32 = inputs[2] - right = inputs[3] + def searchsorted_common(self, sorted_sequence, values, out_int32, right): + dtype = "int32" if out_int32 else "int64" side = "right" if right else "left" values_shape = _infer_shape(values) @@ -2785,13 +2783,19 @@ def searchsorted(self, inputs, input_types): if len(values_shape) == 0: values = _op.expand_dims(values, 0) - out = _op.searchsorted(inputs[0], values, side=side, dtype=dtype) + out = _op.searchsorted(sorted_sequence, values, side=side, dtype=dtype) if len(values_shape) == 0: return _op.squeeze(out) return out + def searchsorted(self, inputs, input_types): + return self.searchsorted_common(*inputs) + + def bucketize(self, inputs, input_types): + return self.searchsorted_common(inputs[1], inputs[0], inputs[2], inputs[3]) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -3018,6 +3022,7 @@ def create_convert_map(self): "aten::all": functools.partial(self.all_any_common, _op.all), "aten::any": functools.partial(self.all_any_common, _op.any), "aten::searchsorted": self.searchsorted, + "aten::bucketize": self.bucketize, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 84511dd66ccb..0031f4143fab 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3980,5 +3980,17 @@ def test_fn(out_int32=False, right=False): verify_model(test_fn(), [sorted_sequence_1d, torch.tensor(6)]) +@tvm.testing.uses_gpu +def test_bucketize(): + def test_fn(out_int32=False, right=False): + return lambda x, y: torch.bucketize(x, y, out_int32=out_int32, right=right) + + boundaries = torch.tensor([1, 3, 5, 7, 9]) + values = torch.tensor([3, 6, 9]) + + verify_model(test_fn(), [values, boundaries]) + verify_model(test_fn(out_int32=True, right=True), [values, boundaries]) + + if __name__ == "__main__": pytest.main([__file__]) From 6fe38ad154c39a9fb3e7ce87ec20d9fa21ffaf45 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 2 Oct 2021 21:55:50 +0900 Subject: [PATCH 16/26] update doc --- python/tvm/relay/op/algorithm.py | 20 +++++++++++-- python/tvm/topi/cuda/searchsorted.py | 32 +++++++++++++++++++-- python/tvm/topi/searchsorted.py | 39 ++++++++++++++++++++++---- src/relay/op/algorithm/searchsorted.cc | 5 +++- 4 files changed, 86 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 9639a6800433..7b9883f421c5 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -118,10 +118,26 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): def searchsorted(sorted_sequence, values, side="left", dtype="int32"): - """TODO + """Find indices where elements should be inserted to maintain order. + If `sorted_sequence` is N-dimensional, the innermost dimension of + `values` are searched in the corresponding dimension of `sorted_sequence`. Parameters ---------- + sorted_sequence : relay.Expr + N-D or 1-D Tensor, containing monotonically increasing sequence + on the innermost dimension. + + values : relay.Expr + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` + and `values` must be the same, and outer N-1 axes must have the same size. + + side : string, optional + It can be `left` or `right`. If `left`, gets the lower bound index for each value + in `values` on the corresponding innermost dimension of the `sorted_sequence`. + If `right`, gets the upper bound index instead. + dtype : string, optional The data type of the output indices. @@ -129,6 +145,6 @@ def searchsorted(sorted_sequence, values, side="left", dtype="int32"): ------- indices : relay.Expr Tensor with same shape as values, representing the indices of - elements of values if they are inserted in sorted_sequence. + elements of `values` if they are inserted in `sorted_sequence`. """ return _make.searchsorted(sorted_sequence, values, side, dtype) diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py index 3de35cf129e5..d84a1941369f 100644 --- a/python/tvm/topi/cuda/searchsorted.py +++ b/python/tvm/topi/cuda/searchsorted.py @@ -22,7 +22,35 @@ def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): - """TODO""" + """Find indices where elements should be inserted to maintain order. + If `sorted_sequence` is N-dimensional, the innermost dimension of + `values` are searched in the corresponding dimension of `sorted_sequence`. + + Parameters + ---------- + sorted_sequence : te.Tensor + N-D or 1-D Tensor, containing monotonically increasing sequence + on the innermost dimension. + + values : te.Tensor + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` + and `values` must be the same, and outer N-1 axes must have the same size. + + side : string, optional + It can be `left` or `right`. If `left`, gets the lower bound index for each value + in `values` on the corresponding innermost dimension of the `sorted_sequence`. + If `right`, gets the upper bound index instead. + + dtype : string, optional + The data type of the output indices. + + Returns + ------- + indices : te.Tensor + Tensor with same shape as values, representing the indices of + elements of `values` if they are inserted in `sorted_sequence`. + """ def ir(sorted_sequence, values, indices): ib = tvm.tir.ir_builder.create() @@ -55,8 +83,8 @@ def ir(sorted_sequence, values, indices): ib, sequence_offset, search_range, - sorted_sequence, tid, + sorted_sequence, values, indices, side, diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index bdd3ec370094..d70d0842ce9d 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -22,9 +22,10 @@ def binary_search( - ib, sequence_offset, search_range, sorted_sequence, i, values, out_indices, side, out_dtype + ib, sequence_offset, search_range, dst_index, sorted_sequence, values, out_indices, + side, out_dtype ): - """TODO""" + """Common IR generator for CPU and GPU searchsorted.""" lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") hi = ib.allocate(out_dtype, (1,), name="hi", scope="local") @@ -45,11 +46,39 @@ def condition(current_val, target_val): with ib.else_scope(): hi[0] = mid - out_indices[i] = lo[0] + out_indices[dst_index] = lo[0] def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): - """TODO""" + """Find indices where elements should be inserted to maintain order. + If `sorted_sequence` is N-dimensional, the innermost dimension of + `values` are searched in the corresponding dimension of `sorted_sequence`. + + Parameters + ---------- + sorted_sequence : te.Tensor + N-D or 1-D Tensor, containing monotonically increasing sequence + on the innermost dimension. + + values : te.Tensor + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` + and `values` must be the same, and outer N-1 axes must have the same size. + + side : string, optional + It can be `left` or `right`. If `left`, gets the lower bound index for each value + in `values` on the corresponding innermost dimension of the `sorted_sequence`. + If `right`, gets the upper bound index instead. + + dtype : string, optional + The data type of the output indices. + + Returns + ------- + indices : te.Tensor + Tensor with same shape as values, representing the indices of + elements of `values` if they are inserted in `sorted_sequence`. + """ def ir(sorted_sequence, values, indices): ib = ir_builder.create() @@ -73,8 +102,8 @@ def ir(sorted_sequence, values, indices): ib, sequence_offset, search_range, - sorted_sequence, i, + sorted_sequence, values, indices, side, diff --git a/src/relay/op/algorithm/searchsorted.cc b/src/relay/op/algorithm/searchsorted.cc index 90a1be06c89a..bed6eab4e52b 100644 --- a/src/relay/op/algorithm/searchsorted.cc +++ b/src/relay/op/algorithm/searchsorted.cc @@ -71,7 +71,10 @@ Expr MakeSearchSorted(Expr sorted_sequence, Expr values, String side, DataType d TVM_REGISTER_GLOBAL("relay.op._make.searchsorted").set_body_typed(MakeSearchSorted); RELAY_REGISTER_OP("searchsorted") - .describe(R"doc(Find indices where elements should be inserted to maintain order. + .describe( + R"doc(Find indices where elements should be inserted to maintain order. +If `sorted_sequence` is N-dimensional, the innermost dimension of +`values` are searched in the corresponding dimension of `sorted_sequence`. )doc" TVM_ADD_FILELINE) .set_num_inputs(2) .set_attrs_type() From ce02cef184a7debd0a11d316c9d13e7de0c3ce30 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 2 Oct 2021 22:04:29 +0900 Subject: [PATCH 17/26] fix tests --- python/tvm/topi/searchsorted.py | 6 +++--- python/tvm/topi/testing/searchsorted.py | 6 +++++- tests/python/topi/python/test_topi_searchsorted.py | 1 - 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index d70d0842ce9d..3a984a60ecb5 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -22,14 +22,14 @@ def binary_search( - ib, sequence_offset, search_range, dst_index, sorted_sequence, values, out_indices, + ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, side, out_dtype ): """Common IR generator for CPU and GPU searchsorted.""" lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") hi = ib.allocate(out_dtype, (1,), name="hi", scope="local") - v = values[i] + v = values[index] lo[0] = cast(0, out_dtype) hi[0] = cast(search_range, out_dtype) @@ -46,7 +46,7 @@ def condition(current_val, target_val): with ib.else_scope(): hi[0] = mid - out_indices[dst_index] = lo[0] + out_indices[index] = lo[0] def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): diff --git a/python/tvm/topi/testing/searchsorted.py b/python/tvm/topi/testing/searchsorted.py index 045e00df7eeb..e939fc8a0f7d 100644 --- a/python/tvm/topi/testing/searchsorted.py +++ b/python/tvm/topi/testing/searchsorted.py @@ -2,7 +2,11 @@ def searchsorted_ref(sorted_sequence, values, side, out_dtype): - sorted_sequence_2d = np.reshape(sorted_sequence, (-1, sorted_sequence.shape[-1])) + if len(sorted_sequence.shape) == 1 and len(values.shape) > 1: + sorted_sequence_2d = np.tile(sorted_sequence, (np.prod(values.shape[:-1]), 1)) + else: + sorted_sequence_2d = np.reshape(sorted_sequence, (-1, sorted_sequence.shape[-1])) + values_2d = np.reshape(values, (-1, values.shape[-1])) indices = np.zeros(values_2d.shape, dtype=out_dtype) diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py index 71a3a9a859ca..0dfb57caee1a 100644 --- a/tests/python/topi/python/test_topi_searchsorted.py +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -75,7 +75,6 @@ def verify(sequence_len, num_search, outer_axes, side, sorted_sequence_1d=False) ref = searchsorted_ref(a_np, b_np, side, out_dtype) np.testing.assert_equal(c.numpy(), ref) - # The first argument is the range of binary search verify(1024, 1000, (10, 5, 3), "left") verify(999, 2000, (10, 5, 3), "right") verify(1000, 1000, (), "left") From fe01efeb8754a641c59b9ce01ad0b0917604cc4d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 4 Oct 2021 16:13:52 +0900 Subject: [PATCH 18/26] fix lint --- python/tvm/relay/op/_algorithm.py | 1 + python/tvm/relay/op/op_attrs.py | 2 ++ python/tvm/topi/cuda/searchsorted.py | 1 + python/tvm/topi/searchsorted.py | 4 ++-- python/tvm/topi/testing/searchsorted.py | 18 ++++++++++++++++++ 5 files changed, 24 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index ace3de5fff06..19162a108395 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -45,6 +45,7 @@ register_strategy("searchsorted", strategy.searchsorted_strategy) register_pattern("searchsorted", OpPattern.OPAQUE) + @script def _topk_shape_func_input_shape(data_shape, k, axis): ndim = data_shape.shape[0] diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index f9a4051f6aeb..dba40b2f6f34 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -563,10 +563,12 @@ class SparseConv2DAttrs(Attrs): class TopkAttrs(Attrs): """Attributes used in topk operators""" + @tvm._ffi.register_object("relay.attrs.SearchSortedAttrs") class SearchSortedAttrs(Attrs): """Attributes used in searchsorted operators""" + @tvm._ffi.register_object("relay.attrs.TupleGetItemAttrs") class TupleGetItemAttrs(Attrs): """Attributes used in tuple item access operators""" diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py index d84a1941369f..286112ed1776 100644 --- a/python/tvm/topi/cuda/searchsorted.py +++ b/python/tvm/topi/cuda/searchsorted.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """searchsorted operator for GPU""" import tvm from tvm import te diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index 3a984a60ecb5..02b7d2362109 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """searchsorted operator""" from . import utils from . import te @@ -22,8 +23,7 @@ def binary_search( - ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, - side, out_dtype + ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, side, out_dtype ): """Common IR generator for CPU and GPU searchsorted.""" lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") diff --git a/python/tvm/topi/testing/searchsorted.py b/python/tvm/topi/testing/searchsorted.py index e939fc8a0f7d..1f4300b83411 100644 --- a/python/tvm/topi/testing/searchsorted.py +++ b/python/tvm/topi/testing/searchsorted.py @@ -1,7 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The reference implementation of searchsorted in Numpy.""" import numpy as np def searchsorted_ref(sorted_sequence, values, side, out_dtype): + """Run Numpy searchsorted on 1-D or N-D sorted_sequence.""" if len(sorted_sequence.shape) == 1 and len(values.shape) > 1: sorted_sequence_2d = np.tile(sorted_sequence, (np.prod(values.shape[:-1]), 1)) else: From 3b18a324b5504431c50d2d4e03da21b2e463128f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 5 Oct 2021 06:25:21 +0900 Subject: [PATCH 19/26] rebase fix --- python/tvm/relay/frontend/pytorch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 80c5a78c72a2..a78bee6aeb2d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2775,7 +2775,6 @@ def all_any_common(self, op, inputs, input_types): return op(inp, axis=dim, keepdims=keepdim) def searchsorted_common(self, sorted_sequence, values, out_int32, right): - dtype = "int32" if out_int32 else "int64" side = "right" if right else "left" values_shape = _infer_shape(values) From bac6dc5d00a834458e3672e016c9c74ae27f48a8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 13 Oct 2021 05:35:11 +0900 Subject: [PATCH 20/26] make the test case smaller --- tests/python/relay/test_op_level6.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index f1666ecd005d..b36efb43262d 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -153,8 +153,8 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"): @tvm.testing.uses_gpu def test_searchsorted(): def verify_searchsorted(side, dtype): - shape = (10, 20, 100) - values_shape = shape[:-1] + (50,) + shape = (8, 9, 10) + values_shape = shape[:-1] + (10,) sorted_sequence = relay.var("sorted_sequence", relay.TensorType(shape, "float32")) values = relay.var("sorted_sequence", relay.TensorType(values_shape, "float32")) out = relay.searchsorted(sorted_sequence, values, side, dtype) From 5eb0c159e15d15ed3061a04b0afb988f68ef2fda Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 13 Oct 2021 06:23:06 +0900 Subject: [PATCH 21/26] add tests for edge cases --- .../topi/python/test_topi_searchsorted.py | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py index 0dfb57caee1a..aef00663251f 100644 --- a/tests/python/topi/python/test_topi_searchsorted.py +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -46,14 +46,9 @@ def get_implementations(): @tvm.testing.parametrize_targets def test_searchsorted(dev, target): - def verify(sequence_len, num_search, outer_axes, side, sorted_sequence_1d=False): - if sorted_sequence_1d: - sorted_sequence_shape = (sequence_len,) - else: - sorted_sequence_shape = outer_axes + (sequence_len,) - values_shape = outer_axes + (num_search,) - sorted_sequence = te.placeholder(sorted_sequence_shape, dtype="float32") - values = te.placeholder(values_shape, dtype="float32") + def verify_with_input(sorted_sequence_np, values_np, side): + sorted_sequence = te.placeholder(sorted_sequence_np.shape, dtype="float32") + values = te.placeholder(values_np.shape, dtype="float32") out_dtype = "int32" implementations = get_implementations() fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) @@ -65,18 +60,34 @@ def verify(sequence_len, num_search, outer_axes, side, sorted_sequence_1d=False) func = tvm.build(s, [sorted_sequence, values, indices], target=target) dev = tvm.device(target, 0) - a_np = np.random.randn(*sorted_sequence_shape).astype(sorted_sequence.dtype) - b_np = np.random.randn(*values_shape).astype(values.dtype) - a_np = np.sort(a_np, axis=-1) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros(values_shape, dtype=indices.dtype), dev) + a = tvm.nd.array(sorted_sequence_np, dev) + b = tvm.nd.array(values_np, dev) + c = tvm.nd.array(np.zeros(values_np.shape, dtype=indices.dtype), dev) func(a, b, c) - ref = searchsorted_ref(a_np, b_np, side, out_dtype) + ref = searchsorted_ref(sorted_sequence_np, values_np, side, out_dtype) np.testing.assert_equal(c.numpy(), ref) + def verify(sequence_len, num_search, outer_axes, side, sorted_sequence_1d=False): + if sorted_sequence_1d: + sorted_sequence_shape = (sequence_len,) + else: + sorted_sequence_shape = outer_axes + (sequence_len,) + values_shape = outer_axes + (num_search,) + + verify_with_input( + np.sort(np.random.randn(*sorted_sequence_shape).astype("float32"), axis=-1), + np.random.randn(*values_shape).astype("float32"), + side, + ) + verify(1024, 1000, (10, 5, 3), "left") verify(999, 2000, (10, 5, 3), "right") verify(1000, 1000, (), "left") verify(2001, 100, (500,), "right") verify(2001, 100, (500,), "left", sorted_sequence_1d=True) + + # Check edge cases + for side in ["left", "right"]: + sorted_sequence = np.array([1, 2, 3, 4, 5], dtype="float32") + verify_with_input(sorted_sequence, np.array([6], dtype="float32"), side) + verify_with_input(sorted_sequence, np.array([0], dtype="float32"), side) From 5fc1bbb395eccf6fd842abb2b698bb60eb92ceb9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 13 Oct 2021 06:41:36 +0900 Subject: [PATCH 22/26] replace "side" attribute with boolean "right" --- include/tvm/relay/attrs/algorithm.h | 9 ++++--- python/tvm/relay/frontend/pytorch.py | 3 +-- python/tvm/relay/op/algorithm.py | 13 +++++----- python/tvm/relay/op/strategy/generic.py | 4 +-- python/tvm/topi/cuda/searchsorted.py | 13 +++++----- python/tvm/topi/searchsorted.py | 21 ++++++++------- python/tvm/topi/testing/searchsorted.py | 3 ++- src/relay/op/algorithm/searchsorted.cc | 6 ++--- tests/python/relay/test_op_level6.py | 10 +++---- .../topi/python/test_topi_searchsorted.py | 26 +++++++++---------- 10 files changed, 56 insertions(+), 52 deletions(-) diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index 19f88d9bffc6..3652a09e9168 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -77,12 +77,15 @@ struct TopKAttrs : public tvm::AttrsNode { }; struct SearchSortedAttrs : public tvm::AttrsNode { - std::string side; + bool right; DataType dtype; TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs") { - TVM_ATTR_FIELD(side).set_default("left").describe( - "Controls which index is returned if a value lands exactly on one of sorted values."); + TVM_ATTR_FIELD(right).set_default(false).describe( + "Controls which index is returned if a value lands exactly on one of sorted values. If " + " false, the index of the first suitable location found is given. If true, return the " + "last such index. If there is no suitable index, return either 0 or N (where N is the " + "size of the innermost dimension)."); TVM_ATTR_FIELD(dtype) .set_default(DataType::Int(32)) .describe("Data type of the output indices."); diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a78bee6aeb2d..3fc202a7cc91 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2776,13 +2776,12 @@ def all_any_common(self, op, inputs, input_types): def searchsorted_common(self, sorted_sequence, values, out_int32, right): dtype = "int32" if out_int32 else "int64" - side = "right" if right else "left" values_shape = _infer_shape(values) if len(values_shape) == 0: values = _op.expand_dims(values, 0) - out = _op.searchsorted(sorted_sequence, values, side=side, dtype=dtype) + out = _op.searchsorted(sorted_sequence, values, right=right, dtype=dtype) if len(values_shape) == 0: return _op.squeeze(out) diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 7b9883f421c5..809a9061ade0 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -117,7 +117,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): return out -def searchsorted(sorted_sequence, values, side="left", dtype="int32"): +def searchsorted(sorted_sequence, values, right=False, dtype="int32"): """Find indices where elements should be inserted to maintain order. If `sorted_sequence` is N-dimensional, the innermost dimension of `values` are searched in the corresponding dimension of `sorted_sequence`. @@ -133,10 +133,11 @@ def searchsorted(sorted_sequence, values, side="left", dtype="int32"): the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` and `values` must be the same, and outer N-1 axes must have the same size. - side : string, optional - It can be `left` or `right`. If `left`, gets the lower bound index for each value - in `values` on the corresponding innermost dimension of the `sorted_sequence`. - If `right`, gets the upper bound index instead. + right : bool, optional + Controls which index is returned if a value lands exactly on one of sorted values. If + False, the index of the first suitable location found is given. If true, return the + last such index. If there is no suitable index, return either 0 or N (where N is the + size of the innermost dimension). dtype : string, optional The data type of the output indices. @@ -147,4 +148,4 @@ def searchsorted(sorted_sequence, values, side="left", dtype="int32"): Tensor with same shape as values, representing the indices of elements of `values` if they are inserted in `sorted_sequence`. """ - return _make.searchsorted(sorted_sequence, values, side, dtype) + return _make.searchsorted(sorted_sequence, values, right, dtype) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index fc6ba8fda86d..777f17ba6084 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1007,9 +1007,9 @@ def wrap_compute_searchsorted(topi_compute): """Wrap searchsorted compute""" def _compute_searchsorted(attrs, inputs, out_type): - side = attrs.side + right = attrs.right dtype = attrs.dtype - return [topi_compute(inputs[0], inputs[1], side, dtype)] + return [topi_compute(inputs[0], inputs[1], right, dtype)] return _compute_searchsorted diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py index 286112ed1776..46f182b0ced3 100644 --- a/python/tvm/topi/cuda/searchsorted.py +++ b/python/tvm/topi/cuda/searchsorted.py @@ -22,7 +22,7 @@ from ..searchsorted import binary_search -def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): +def searchsorted(sorted_sequence, values, right, out_dtype="int64"): """Find indices where elements should be inserted to maintain order. If `sorted_sequence` is N-dimensional, the innermost dimension of `values` are searched in the corresponding dimension of `sorted_sequence`. @@ -38,10 +38,11 @@ def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` and `values` must be the same, and outer N-1 axes must have the same size. - side : string, optional - It can be `left` or `right`. If `left`, gets the lower bound index for each value - in `values` on the corresponding innermost dimension of the `sorted_sequence`. - If `right`, gets the upper bound index instead. + right : bool, optional + Controls which index is returned if a value lands exactly on one of sorted values. If + False, the index of the first suitable location found is given. If true, return the + last such index. If there is no suitable index, return either 0 or N (where N is the + size of the innermost dimension). dtype : string, optional The data type of the output indices. @@ -88,7 +89,7 @@ def ir(sorted_sequence, values, indices): sorted_sequence, values, indices, - side, + right, out_dtype, ) diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index 02b7d2362109..56d016aee775 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -23,7 +23,7 @@ def binary_search( - ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, side, out_dtype + ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, right, out_dtype ): """Common IR generator for CPU and GPU searchsorted.""" lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") @@ -35,9 +35,9 @@ def binary_search( # Reference: pytorch/aten/src/ATen/native/cuda/Bucketization.cu def condition(current_val, target_val): - if side == "left": - return current_val < target_val - return current_val <= target_val + if right: + return current_val <= target_val + return current_val < target_val with ib.while_loop(lo[0] < hi[0]): mid = lo[0] + (hi[0] - lo[0] >> 1) @@ -49,7 +49,7 @@ def condition(current_val, target_val): out_indices[index] = lo[0] -def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): +def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"): """Find indices where elements should be inserted to maintain order. If `sorted_sequence` is N-dimensional, the innermost dimension of `values` are searched in the corresponding dimension of `sorted_sequence`. @@ -65,10 +65,11 @@ def searchsorted(sorted_sequence, values, side="left", out_dtype="int64"): the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` and `values` must be the same, and outer N-1 axes must have the same size. - side : string, optional - It can be `left` or `right`. If `left`, gets the lower bound index for each value - in `values` on the corresponding innermost dimension of the `sorted_sequence`. - If `right`, gets the upper bound index instead. + right : bool, optional + Controls which index is returned if a value lands exactly on one of sorted values. If + False, the index of the first suitable location found is given. If true, return the + last such index. If there is no suitable index, return either 0 or N (where N is the + size of the innermost dimension). dtype : string, optional The data type of the output indices. @@ -106,7 +107,7 @@ def ir(sorted_sequence, values, indices): sorted_sequence, values, indices, - side, + right, out_dtype, ) diff --git a/python/tvm/topi/testing/searchsorted.py b/python/tvm/topi/testing/searchsorted.py index 1f4300b83411..10762600992d 100644 --- a/python/tvm/topi/testing/searchsorted.py +++ b/python/tvm/topi/testing/searchsorted.py @@ -18,8 +18,9 @@ import numpy as np -def searchsorted_ref(sorted_sequence, values, side, out_dtype): +def searchsorted_ref(sorted_sequence, values, right, out_dtype): """Run Numpy searchsorted on 1-D or N-D sorted_sequence.""" + side = "right" if right else "left" if len(sorted_sequence.shape) == 1 and len(values.shape) > 1: sorted_sequence_2d = np.tile(sorted_sequence, (np.prod(values.shape[:-1]), 1)) else: diff --git a/src/relay/op/algorithm/searchsorted.cc b/src/relay/op/algorithm/searchsorted.cc index bed6eab4e52b..be5921311660 100644 --- a/src/relay/op/algorithm/searchsorted.cc +++ b/src/relay/op/algorithm/searchsorted.cc @@ -39,8 +39,6 @@ bool SearchSortedRel(const Array& types, int num_inputs, const Attrs& attr ICHECK(sorted_sequence) << "Expects TensorType in the first input"; ICHECK(values) << "Expects TensorType in the second input"; ICHECK_GT(values->shape.size(), 0) << "The rank of `values` must be greater than one"; - ICHECK(param->side == "left" || param->side == "right") - << "`side` parameter must be either `left` or `right`"; if (sorted_sequence->shape.size() > 1) { ICHECK_EQ(sorted_sequence->shape.size(), values->shape.size()) @@ -60,11 +58,11 @@ bool SearchSortedRel(const Array& types, int num_inputs, const Attrs& attr return true; } -Expr MakeSearchSorted(Expr sorted_sequence, Expr values, String side, DataType dtype) { +Expr MakeSearchSorted(Expr sorted_sequence, Expr values, Bool right, DataType dtype) { auto attrs = make_object(); static const Op& op = Op::Get("searchsorted"); attrs->dtype = dtype; - attrs->side = side; + attrs->right = right; return Call(op, {sorted_sequence, values}, Attrs(attrs), {}); } diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index b36efb43262d..48c58dc2dc33 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -152,16 +152,16 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"): @tvm.testing.uses_gpu def test_searchsorted(): - def verify_searchsorted(side, dtype): + def verify_searchsorted(right, dtype): shape = (8, 9, 10) values_shape = shape[:-1] + (10,) sorted_sequence = relay.var("sorted_sequence", relay.TensorType(shape, "float32")) values = relay.var("sorted_sequence", relay.TensorType(values_shape, "float32")) - out = relay.searchsorted(sorted_sequence, values, side, dtype) + out = relay.searchsorted(sorted_sequence, values, right, dtype) func = relay.Function([sorted_sequence, values], out) sorted_sequence_np = np.sort(np.random.randn(*shape).astype("float32"), axis=-1) values_np = np.random.randn(*values_shape).astype("float32") - np_indices = searchsorted_ref(sorted_sequence_np, values_np, side, dtype) + np_indices = searchsorted_ref(sorted_sequence_np, values_np, right, dtype) for target, dev in tvm.testing.enabled_targets(): op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( @@ -169,8 +169,8 @@ def verify_searchsorted(side, dtype): ) np.testing.assert_equal(op_res.numpy(), np_indices) - verify_searchsorted("left", "int32") - verify_searchsorted("right", "int64") + verify_searchsorted(False, "int32") + verify_searchsorted(True, "int64") if __name__ == "__main__": diff --git a/tests/python/topi/python/test_topi_searchsorted.py b/tests/python/topi/python/test_topi_searchsorted.py index aef00663251f..7b3976b7eb74 100644 --- a/tests/python/topi/python/test_topi_searchsorted.py +++ b/tests/python/topi/python/test_topi_searchsorted.py @@ -46,7 +46,7 @@ def get_implementations(): @tvm.testing.parametrize_targets def test_searchsorted(dev, target): - def verify_with_input(sorted_sequence_np, values_np, side): + def verify_with_input(sorted_sequence_np, values_np, right): sorted_sequence = te.placeholder(sorted_sequence_np.shape, dtype="float32") values = te.placeholder(values_np.shape, dtype="float32") out_dtype = "int32" @@ -54,7 +54,7 @@ def verify_with_input(sorted_sequence_np, values_np, side): fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) with tvm.target.Target(target): - indices = fcompute(sorted_sequence, values, side, out_dtype) + indices = fcompute(sorted_sequence, values, right, out_dtype) s = fschedule([indices]) func = tvm.build(s, [sorted_sequence, values, indices], target=target) @@ -64,10 +64,10 @@ def verify_with_input(sorted_sequence_np, values_np, side): b = tvm.nd.array(values_np, dev) c = tvm.nd.array(np.zeros(values_np.shape, dtype=indices.dtype), dev) func(a, b, c) - ref = searchsorted_ref(sorted_sequence_np, values_np, side, out_dtype) + ref = searchsorted_ref(sorted_sequence_np, values_np, right, out_dtype) np.testing.assert_equal(c.numpy(), ref) - def verify(sequence_len, num_search, outer_axes, side, sorted_sequence_1d=False): + def verify(sequence_len, num_search, outer_axes, right, sorted_sequence_1d=False): if sorted_sequence_1d: sorted_sequence_shape = (sequence_len,) else: @@ -77,17 +77,17 @@ def verify(sequence_len, num_search, outer_axes, side, sorted_sequence_1d=False) verify_with_input( np.sort(np.random.randn(*sorted_sequence_shape).astype("float32"), axis=-1), np.random.randn(*values_shape).astype("float32"), - side, + right, ) - verify(1024, 1000, (10, 5, 3), "left") - verify(999, 2000, (10, 5, 3), "right") - verify(1000, 1000, (), "left") - verify(2001, 100, (500,), "right") - verify(2001, 100, (500,), "left", sorted_sequence_1d=True) + verify(1024, 1000, (10, 5, 3), False) + verify(999, 2000, (10, 5, 3), True) + verify(1000, 1000, (), False) + verify(2001, 100, (500,), True) + verify(2001, 100, (500,), False, sorted_sequence_1d=True) # Check edge cases - for side in ["left", "right"]: + for right in [True, False]: sorted_sequence = np.array([1, 2, 3, 4, 5], dtype="float32") - verify_with_input(sorted_sequence, np.array([6], dtype="float32"), side) - verify_with_input(sorted_sequence, np.array([0], dtype="float32"), side) + verify_with_input(sorted_sequence, np.array([6], dtype="float32"), right) + verify_with_input(sorted_sequence, np.array([0], dtype="float32"), right) From 4775b72c474dc91c0fabbb71ef1d3eaffa456440 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 13 Oct 2021 07:12:03 +0900 Subject: [PATCH 23/26] add more descrition to binear_search IR gen params --- python/tvm/topi/searchsorted.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index 56d016aee775..81535b77c718 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -25,7 +25,19 @@ def binary_search( ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, right, out_dtype ): - """Common IR generator for CPU and GPU searchsorted.""" + """Common IR generator for binary search used by CPU and GPU backends. + + `sorted_sequence` is a N-D Buffer whose innermost dimension we want to search, + and `search_range` is the size of the innermost dimension. + + `index` is the index of the current value in `values` being searched. `sequence_offset` is + a 1-D linearlized offset specifying which of innermost sequences to search for `values[index]`. + + So the search for `values[index]` is performed over + `sorted_sequence[sequence_offset:(sequence_offset + search_range)]`. + Note that we index N-D Buffer by 1-D linearlized indices. + + """ lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") hi = ib.allocate(out_dtype, (1,), name="hi", scope="local") From c3eace86f41cfbcbef40f8199760a9e7ac23551b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 13 Oct 2021 08:39:21 +0900 Subject: [PATCH 24/26] return index from binary_search rather than update inplace --- python/tvm/topi/cuda/searchsorted.py | 5 ++--- python/tvm/topi/searchsorted.py | 22 +++++++++------------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py index 46f182b0ced3..b896de57ae73 100644 --- a/python/tvm/topi/cuda/searchsorted.py +++ b/python/tvm/topi/cuda/searchsorted.py @@ -81,13 +81,12 @@ def ir(sorted_sequence, values, indices): sequence_id = tid // values_shape[-1] sequence_offset = sequence_id * search_range - binary_search( + indices[tid] = binary_search( ib, sequence_offset, search_range, - tid, sorted_sequence, - values, + values[tid], indices, right, out_dtype, diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index 81535b77c718..cfe645590670 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -23,17 +23,15 @@ def binary_search( - ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, right, out_dtype + ib, sequence_offset, search_range, sorted_sequence, value, out_indices, right, out_dtype ): """Common IR generator for binary search used by CPU and GPU backends. - `sorted_sequence` is a N-D Buffer whose innermost dimension we want to search, - and `search_range` is the size of the innermost dimension. + `sorted_sequence` is a N-D Buffer whose innermost dimension we want to search for `value`, + and `search_range` is the size of the innermost dimension. `sequence_offset` is + a 1-D linearlized offset specifying which of innermost sequences to search. - `index` is the index of the current value in `values` being searched. `sequence_offset` is - a 1-D linearlized offset specifying which of innermost sequences to search for `values[index]`. - - So the search for `values[index]` is performed over + So the search for `value` is performed over `sorted_sequence[sequence_offset:(sequence_offset + search_range)]`. Note that we index N-D Buffer by 1-D linearlized indices. @@ -41,7 +39,6 @@ def binary_search( lo = ib.allocate(out_dtype, (1,), name="lo", scope="local") hi = ib.allocate(out_dtype, (1,), name="hi", scope="local") - v = values[index] lo[0] = cast(0, out_dtype) hi[0] = cast(search_range, out_dtype) @@ -53,12 +50,12 @@ def condition(current_val, target_val): with ib.while_loop(lo[0] < hi[0]): mid = lo[0] + (hi[0] - lo[0] >> 1) - with ib.if_scope(condition(sorted_sequence[sequence_offset + mid], v)): + with ib.if_scope(condition(sorted_sequence[sequence_offset + mid], value)): lo[0] = mid + 1 with ib.else_scope(): hi[0] = mid - out_indices[index] = lo[0] + return lo[0] def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"): @@ -111,13 +108,12 @@ def ir(sorted_sequence, values, indices): sequence_id = i // values_shape[-1] sequence_offset = sequence_id * search_range - binary_search( + indices[i] = binary_search( ib, sequence_offset, search_range, - i, sorted_sequence, - values, + values[i], indices, right, out_dtype, From 169088ced5092c0f5092aab2fd8858cfcc055496 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 13 Oct 2021 08:44:26 +0900 Subject: [PATCH 25/26] remove unused argument --- python/tvm/topi/cuda/searchsorted.py | 1 - python/tvm/topi/searchsorted.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/searchsorted.py b/python/tvm/topi/cuda/searchsorted.py index b896de57ae73..1c39ccaa8632 100644 --- a/python/tvm/topi/cuda/searchsorted.py +++ b/python/tvm/topi/cuda/searchsorted.py @@ -87,7 +87,6 @@ def ir(sorted_sequence, values, indices): search_range, sorted_sequence, values[tid], - indices, right, out_dtype, ) diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index cfe645590670..500b78d7b0e8 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -23,7 +23,7 @@ def binary_search( - ib, sequence_offset, search_range, sorted_sequence, value, out_indices, right, out_dtype + ib, sequence_offset, search_range, sorted_sequence, value, right, out_dtype ): """Common IR generator for binary search used by CPU and GPU backends. @@ -114,7 +114,6 @@ def ir(sorted_sequence, values, indices): search_range, sorted_sequence, values[i], - indices, right, out_dtype, ) From 431db6b3f450e6ff61056e632d2413903d7411e0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 13 Oct 2021 08:48:20 +0900 Subject: [PATCH 26/26] format fix --- python/tvm/topi/searchsorted.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index 500b78d7b0e8..28ffd170c955 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -22,9 +22,7 @@ from .math import cast -def binary_search( - ib, sequence_offset, search_range, sorted_sequence, value, right, out_dtype -): +def binary_search(ib, sequence_offset, search_range, sorted_sequence, value, right, out_dtype): """Common IR generator for binary search used by CPU and GPU backends. `sorted_sequence` is a N-D Buffer whose innermost dimension we want to search for `value`,