diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 52656872ad10..ccdc871e8a78 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -314,6 +314,12 @@ struct OneHotAttrs : public tvm::AttrsNode { } }; // struct OneHotAttrs +/*! \brief Attributes for ArgWhere operator */ +struct ArgWhereAttrs : public tvm::AttrsNode { + TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") { + } +}; // struct ArgWhereAttrs + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 5dddfc6f88d6..3197b81289b2 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -15,8 +15,10 @@ # specific language governing permissions and limitations # under the License. """Backend compiler related feature registration""" -# pylint: disable=invalid-name,unused-argument, len-as-condition +# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks from __future__ import absolute_import +import tvm +import topi from topi.util import get_const_int, get_const_tuple from . import op as _reg from ._reduce import _schedule_reduce @@ -204,3 +206,100 @@ def take_shape_func(attrs, inputs, out_ndims): axis += data_ndim assert 0 <= axis < data_ndim return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])] + +@script +def _argwhere_shape_func_1d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(1) + for i1 in range(condition.shape[0]): + if condition[i1] != 0: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_2d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(2) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + if condition[i1, i2] != 0: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_3d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(3) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + if condition[i1, i2, i3] != 0: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_4d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(4) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + for i4 in range(condition.shape[3]): + if condition[i1, i2, i3, i4] != 0: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_5d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(5) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + for i4 in range(condition.shape[3]): + for i5 in range(condition.shape[4]): + if condition[i1, i2, i3, i4, i5] != 0: + out[0] += int64(1) + return out + +@_reg.register_shape_func("argwhere", True) +def argwhere_shape_func(attrs, inputs, out_ndims): + """ + Shape function for argwhere. + """ + if len(inputs[0].shape) == 1: + return [_argwhere_shape_func_1d(inputs[0])] + elif len(inputs[0].shape) == 2: + return [_argwhere_shape_func_2d(inputs[0])] + elif len(inputs[0].shape) == 3: + return [_argwhere_shape_func_3d(inputs[0])] + elif len(inputs[0].shape) == 4: + return [_argwhere_shape_func_4d(inputs[0])] + elif len(inputs[0].shape) == 5: + return [_argwhere_shape_func_5d(inputs[0])] + return ValueError("Does not support rank higher than 5 in argwhere") + +@_reg.register_schedule("argwhere") +def schedule_argwhere(_, outs, target): + """Schedule definition of argwhere""" + with target: + return topi.generic.schedule_argwhere(outs) + + +@_reg.register_compute("argwhere") +def compute_argwhere(attrs, inputs, output_type, _): + """Compute definition of argwhere""" + output_shape = [] + for s in output_type.shape: + if hasattr(s, "value"): + output_shape.append(s) + else: + # see Any, replace it with a var + output_shape.append(tvm.var("any_dim", "int32")) + new_output_type = tvm.relay.ty.TensorType(output_shape, "int32") + return [topi.argwhere(new_output_type, inputs[0])] diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 7f921d03a62f..88d7a448005c 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -144,7 +144,6 @@ def squeeze(data, axis=None): """ return _make.squeeze(data, axis) - def reshape(data, newshape): """Reshapes the input array. @@ -214,6 +213,28 @@ def reshape(data, newshape): newshape = [newshape] return _make.reshape(data, list(newshape)) +def argwhere(condition): + """Find the indices of elements of a tensor that are + non-zero. + + Parameters + ---------- + condition : relay.Expr + The input condition tensor. + + Returns + ------- + out : relay.Expr + Tensor with the indices of elements that are non-zero. + + Examples + -------- + .. code-block:: python + + condition = [[True, False], [False, True]] + relay.argwhere(condition) = [[0, 0], [1, 1]] + """ + return _make.argwhere(condition) def reshape_like(data, shape_like): """Reshapes the input array by the size of another array. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9cd436bbac7a..5411be2b5fe8 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -817,6 +817,40 @@ the input array into an output array with the same shape as the second input arr .set_attr("FTVMCompute", ReshapeCompute) .set_attr("TOpPattern", kInjective); +// ArgWhere +bool ArgWhereRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(num_inputs, 1); + auto tt = types[0].as(); + CHECK(tt != nullptr); + const auto& input_shape = tt->shape; + const auto& input_rank = input_shape.size(); + std::vector result_shape; + result_shape.push_back(Any::make()); + result_shape.push_back(IntImm::make(Int(32), input_rank)); + reporter->Assign(types[1], TensorTypeNode::make(result_shape, Int(32))); + return true; +} + +TVM_REGISTER_API("relay.op._make.argwhere") +.set_body_typed([](Expr data) { + static const Op& op = Op::Get("argwhere"); + auto attrs = make_node(); + return CallNode::make(op, {data}, Attrs(attrs), {}); +}); + +RELAY_REGISTER_OP("argwhere") +.describe(R"doc(Find the indices of elements of a tensor that are +non-zero)doc" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.ArgWhereAttrs") +.add_argument("condition", "Tensor", "The input condition tensor.") +.add_type_rel("ArgWhere", ArgWhereRel) +.set_attr("TOpIsStateful", false) +.set_attr("TOpPattern", kOpaque) +.set_support_level(10); // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 214b88fa1850..d02dcd0b73dd 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -92,6 +92,36 @@ def test_any_reshape(): verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4)) verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12)) +def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): + x = relay.var('x', shape=x_shape, dtype=dtype) + y = relay.argwhere(x) + mod = relay.module.Module() + mod["main"] = relay.Function([x], y) + data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data).asnumpy() + expected = np.argwhere(data) + assert result.shape == expected.shape + tvm.testing.assert_allclose(result.flatten(), expected.flatten()) + +def test_any_argwhere(): + verify_any_argwhere(any_dims(1), (5,)) + verify_any_argwhere(any_dims(2), (5, 5)) + verify_any_argwhere(any_dims(3), (5, 5, 5)) + verify_any_argwhere(any_dims(4), (5, 5, 5, 5)) + verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5)) + verify_any_argwhere(any_dims(1), (5,), "int32") + verify_any_argwhere(any_dims(2), (5, 5), "int32") + verify_any_argwhere(any_dims(3), (5, 5, 5), "int32") + verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int32") + verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int32") + verify_any_argwhere(any_dims(1), (5,), "int8") + verify_any_argwhere(any_dims(2), (5, 5), "int8") + verify_any_argwhere(any_dims(3), (5, 5, 5), "int8") + verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8") + verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8") + def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape): mod = relay.Module() data = relay.var('data', shape=data_shape, dtype='float32') diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py index ac855d144aad..fd293a09b9e7 100644 --- a/topi/python/topi/__init__.py +++ b/topi/python/topi/__init__.py @@ -22,6 +22,7 @@ from .transform import * from .broadcast import * from .sort import * +from .argwhere import * from . import nn from . import x86 from . import cuda diff --git a/topi/python/topi/argwhere.py b/topi/python/topi/argwhere.py new file mode 100644 index 000000000000..32f4e8718c46 --- /dev/null +++ b/topi/python/topi/argwhere.py @@ -0,0 +1,191 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks +"""Argwhere operator""" +import tvm +from tvm import hybrid + +@hybrid.script +def hybrid_argwhere_1d(output_shape, condition): + """Find the indices of elements of a 1-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 1-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + valid_index = 0 + for i1 in range(a1): + if condition[i1] != 0: + a[valid_index, 0] = i1 + valid_index += 1 + return a + +@hybrid.script +def hybrid_argwhere_2d(output_shape, condition): + """Find the indices of elements of a 2-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 2-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + a2 = condition.shape[1] + valid_index = 0 + for i1 in range(a1): + for i2 in range(a2): + if condition[i1, i2] != 0: + a[valid_index, 0] = i1 + a[valid_index, 1] = i2 + valid_index += 1 + return a + +@hybrid.script +def hybrid_argwhere_3d(output_shape, condition): + """Find the indices of elements of a 3-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 3-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + a2 = condition.shape[1] + a3 = condition.shape[2] + valid_index = 0 + for i1 in range(a1): + for i2 in range(a2): + for i3 in range(a3): + if condition[i1, i2, i3] != 0: + a[valid_index, 0] = i1 + a[valid_index, 1] = i2 + a[valid_index, 2] = i3 + valid_index += 1 + return a + +@hybrid.script +def hybrid_argwhere_4d(output_shape, condition): + """Find the indices of elements of a 4-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 4-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + a2 = condition.shape[1] + a3 = condition.shape[2] + a4 = condition.shape[3] + valid_index = 0 + for i1 in range(a1): + for i2 in range(a2): + for i3 in range(a3): + for i4 in range(a4): + if condition[i1, i2, i3, i4] != 0: + a[valid_index, 0] = i1 + a[valid_index, 1] = i2 + a[valid_index, 2] = i3 + a[valid_index, 3] = i4 + valid_index += 1 + return a + +@hybrid.script +def hybrid_argwhere_5d(output_shape, condition): + """Find the indices of elements of a 5-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 5-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + a2 = condition.shape[1] + a3 = condition.shape[2] + a4 = condition.shape[3] + a5 = condition.shape[4] + valid_index = 0 + for i1 in range(a1): + for i2 in range(a2): + for i3 in range(a3): + for i4 in range(a4): + for i5 in range(a5): + if condition[i1, i2, i3, i4, i5] != 0: + a[valid_index, 0] = i1 + a[valid_index, 1] = i2 + a[valid_index, 2] = i3 + a[valid_index, 3] = i4 + a[valid_index, 4] = i5 + valid_index += 1 + return a + +@tvm.target.generic_func +def argwhere(output_shape, condition): + """Find the indices of elements of a tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + Tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + if len(condition.shape) == 1: + return hybrid_argwhere_1d(output_shape.shape, condition) + if len(condition.shape) == 2: + return hybrid_argwhere_2d(output_shape.shape, condition) + if len(condition.shape) == 3: + return hybrid_argwhere_3d(output_shape.shape, condition) + if len(condition.shape) == 4: + return hybrid_argwhere_4d(output_shape.shape, condition) + if len(condition.shape) == 5: + return hybrid_argwhere_5d(output_shape.shape, condition) + raise ValueError("Does not support rank higher than 5 in argwhere") diff --git a/topi/python/topi/generic/__init__.py b/topi/python/topi/generic/__init__.py index 6bf5f3a053c9..18af0e328471 100644 --- a/topi/python/topi/generic/__init__.py +++ b/topi/python/topi/generic/__init__.py @@ -20,3 +20,4 @@ from .extern import * from .vision import * from .sort import * +from .search import * diff --git a/topi/python/topi/generic/search.py b/topi/python/topi/generic/search.py new file mode 100644 index 000000000000..41045e492e53 --- /dev/null +++ b/topi/python/topi/generic/search.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-member +"""Generic search operators""" +from __future__ import absolute_import as _abs +import tvm +from .vision import _default_schedule + +@tvm.target.generic_func +def schedule_argwhere(outs): + """Schedule for argwhere operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of argwhere. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False)