From 961cea5b4c8c100ad470ca0b9439bed16e17a3f2 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 19 Sep 2019 18:38:03 -0700 Subject: [PATCH 01/14] Add op argwhere --- include/tvm/relay/attrs/algorithm.h | 6 + python/tvm/relay/op/_algorithm.py | 22 ++++ python/tvm/relay/op/_transform.py | 67 ++++++++++- python/tvm/relay/op/algorithm.py | 16 +++ python/tvm/relay/op/transform.py | 1 - src/relay/op/algorithm/argwhere.cc | 68 +++++++++++ src/relay/op/tensor/unary.cc | 1 - tests/python/relay/test_any.py | 19 +++ topi/python/topi/__init__.py | 1 + topi/python/topi/argwhere.py | 166 +++++++++++++++++++++++++++ topi/python/topi/generic/__init__.py | 1 + topi/python/topi/generic/where.py | 37 ++++++ topi/python/topi/transform.py | 1 + 13 files changed, 403 insertions(+), 3 deletions(-) create mode 100644 src/relay/op/algorithm/argwhere.cc create mode 100644 topi/python/topi/argwhere.py create mode 100644 topi/python/topi/generic/where.py diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index ce14a6a2d535..781ff79a6a41 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -31,6 +31,12 @@ namespace tvm { namespace relay { +/*! \brief Attributes for ArgWhere operator */ +struct ArgWhereAttrs : public tvm::AttrsNode { + TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") { + } +}; + /*! \brief Attributes used in argsort operators */ struct ArgsortAttrs : public tvm::AttrsNode { int axis; diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 09746be13e30..5617a1666f76 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -17,6 +17,8 @@ "Definition of classic algorithms" # pylint: disable=invalid-name,unused-argument from __future__ import absolute_import +import tvm +from tvm.relay.ty import TensorType import topi from topi.util import get_const_int @@ -41,6 +43,26 @@ def compute_argsort(attrs, inputs, _, target): register_pattern("argsort", OpPattern.OPAQUE) +# argwhere +@register_schedule("argwhere") +def schedule_argwhere(_, outs, target): + """Schedule definition of argwhere""" + with target: + return topi.generic.schedule_argwhere(outs) + + +@register_compute("argwhere") +def compute_argwhere(attrs, inputs, output_type, _): + """Compute definition of argwhere""" + output_shape = [] + for s in output_type.shape: + if hasattr(s, "value"): + output_shape.append(s) + else: + # see Any, replace it with a var + output_shape.append(tvm.var("any_dim", "int32")) + new_output_type = TensorType(output_shape, "int32") + return [topi.argwhere(new_output_type, inputs[0])] @register_schedule("topk") def schedule_topk(_, outs, target): diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 5dddfc6f88d6..e6082669e0f7 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Backend compiler related feature registration""" -# pylint: disable=invalid-name,unused-argument, len-as-condition +# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks from __future__ import absolute_import from topi.util import get_const_int, get_const_tuple from . 
import op as _reg @@ -204,3 +204,68 @@ def take_shape_func(attrs, inputs, out_ndims): axis += data_ndim assert 0 <= axis < data_ndim return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])] + +@script +def _argwhere_shape_func_2d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(2) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + if condition[i1, i2]: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_3d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(3) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + if condition[i1, i2, i3]: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_4d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(4) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + for i4 in range(condition.shape[3]): + if condition[i1, i2, i3, i4]: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_5d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(5) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + for i4 in range(condition.shape[3]): + for i5 in range(condition.shape[4]): + if condition[i1, i2, i3, i4, i5]: + out[0] += int64(1) + return out + +@_reg.register_shape_func("argwhere", True) +def argwhere_shape_func(attrs, inputs, out_ndims): + """ + Shape function for argwhere. + """ + if len(inputs[0].shape) == 2: + return [_argwhere_shape_func_2d(inputs[0])] + elif len(inputs[0].shape) == 3: + return [_argwhere_shape_func_3d(inputs[0])] + elif len(inputs[0].shape) == 4: + return [_argwhere_shape_func_4d(inputs[0])] + elif len(inputs[0].shape) == 5: + return [_argwhere_shape_func_5d(inputs[0])] + return [] diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 6f875919df4c..20d4f9c6a412 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -19,6 +19,22 @@ from . import _make from ..expr import TupleWrapper +def argwhere(condition): + """Find the indices of elements of a tensor that are + non-zero. + + Parameters + ---------- + condition : relay.Expr + The input condition tensor. + + Returns + ------- + out : relay.Expr + Tensor with the indices of elements that are non-zero. + """ + return _make.argwhere(condition) + def argsort(data, axis=-1, is_ascend=1, dtype="int32"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 7f921d03a62f..c730c7047bd0 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -144,7 +144,6 @@ def squeeze(data, axis=None): """ return _make.squeeze(data, axis) - def reshape(data, newshape): """Reshapes the input array. diff --git a/src/relay/op/algorithm/argwhere.cc b/src/relay/op/algorithm/argwhere.cc new file mode 100644 index 000000000000..87a2c7b17fd9 --- /dev/null +++ b/src/relay/op/algorithm/argwhere.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file argwhere.cc + * \brief Argwhere operators + */ +#include +#include +#include + +namespace tvm { +namespace relay { + +// ArgWhere +bool ArgWhereRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(num_inputs, 1); + auto tt = types[0].as(); + CHECK(tt != nullptr); + const auto& input_shape = tt->shape; + const auto& input_rank = input_shape.size(); + std::vector result_shape; + result_shape.push_back(Any::make()); + result_shape.push_back(IntImm::make(Int(32), input_rank)); + reporter->Assign(types[1], TensorTypeNode::make(result_shape, Int(32))); + return true; +} + +TVM_REGISTER_API("relay.op._make.argwhere") +.set_body_typed([](Expr data) { + static const Op& op = Op::Get("argwhere"); + auto attrs = make_node(); + return CallNode::make(op, {data}, Attrs(attrs), {}); +}); + +RELAY_REGISTER_OP("argwhere") +.describe(R"doc(Find the indices of elements of a tensor that are +non-zero)doc" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.ArgWhereAttrs") +.add_argument("condition", "Tensor", "The input condition tensor.") +.add_type_rel("ArgWhere", ArgWhereRel) +.set_attr("TOpIsStateful", false) +.set_attr("TOpPattern", kOpaque) +.set_support_level(10); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 0e3e539cc928..3a99c2656173 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -326,7 +326,6 @@ RELAY_REGISTER_OP("shape_of") .set_support_level(10) .set_attr("FTVMCompute", ShapeOfCompute); - TVM_REGISTER_NODE_TYPE(NdarraySizeAttrs); bool NdarraySizeRel(const Array& types, diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 214b88fa1850..df6d8d149c00 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -92,6 +92,25 @@ def test_any_reshape(): verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4)) verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12)) +def verify_any_argwhere(x_shape, x_np_shape, out_shape): + x = relay.var('x', shape=x_shape, dtype="bool") + y = relay.argwhere(x) + mod = relay.module.Module() + mod["main"] = relay.Function([x], y) + data = np.random.choice([True, False], size=x_np_shape) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data).asnumpy() + expected = np.argwhere(data) + assert result.shape == expected.shape + tvm.testing.assert_allclose(result.flatten(), expected.flatten()) + +def test_any_argwhere(): + verify_any_argwhere(any_dims(2), (5, 5), None) + verify_any_argwhere(any_dims(3), (5, 5, 5), None) + verify_any_argwhere(any_dims(4), (5, 5, 5, 5), None) + 
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), None) + def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape): mod = relay.Module() data = relay.var('data', shape=data_shape, dtype='float32') diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py index ac855d144aad..fd293a09b9e7 100644 --- a/topi/python/topi/__init__.py +++ b/topi/python/topi/__init__.py @@ -22,6 +22,7 @@ from .transform import * from .broadcast import * from .sort import * +from .argwhere import * from . import nn from . import x86 from . import cuda diff --git a/topi/python/topi/argwhere.py b/topi/python/topi/argwhere.py new file mode 100644 index 000000000000..ff9cbbf56927 --- /dev/null +++ b/topi/python/topi/argwhere.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks +"""Argwhere operator""" +import tvm +from tvm import hybrid + +@hybrid.script +def hybrid_argwhere_2d(output_shape, condition): + """Find the indices of elements of a 2-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 2-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + a2 = condition.shape[1] + valid_index = 0 + for i1 in range(a1): + for i2 in range(a2): + if condition[i1, i2]: + a[valid_index, 0] = i1 + a[valid_index, 1] = i2 + valid_index += 1 + return a + +@hybrid.script +def hybrid_argwhere_3d(output_shape, condition): + """Find the indices of elements of a 3-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 3-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + a2 = condition.shape[1] + a3 = condition.shape[2] + valid_index = 0 + for i1 in range(a1): + for i2 in range(a2): + for i3 in range(a3): + if condition[i1, i2, i3]: + a[valid_index, 0] = i1 + a[valid_index, 1] = i2 + a[valid_index, 2] = i3 + valid_index += 1 + return a + +@hybrid.script +def hybrid_argwhere_4d(output_shape, condition): + """Find the indices of elements of a 4-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 4-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. 
+ """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + a2 = condition.shape[1] + a3 = condition.shape[2] + a4 = condition.shape[3] + valid_index = 0 + for i1 in range(a1): + for i2 in range(a2): + for i3 in range(a3): + for i4 in range(a4): + if condition[i1, i2, i3, i4]: + a[valid_index, 0] = i1 + a[valid_index, 1] = i2 + a[valid_index, 2] = i3 + a[valid_index, 3] = i4 + valid_index += 1 + return a + +@hybrid.script +def hybrid_argwhere_5d(output_shape, condition): + """Find the indices of elements of a 5-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 5-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + a2 = condition.shape[1] + a3 = condition.shape[2] + a4 = condition.shape[3] + a5 = condition.shape[4] + valid_index = 0 + for i1 in range(a1): + for i2 in range(a2): + for i3 in range(a3): + for i4 in range(a4): + for i5 in range(a5): + if condition[i1, i2, i3, i4, i5]: + a[valid_index, 0] = i1 + a[valid_index, 1] = i2 + a[valid_index, 2] = i3 + a[valid_index, 3] = i4 + a[valid_index, 4] = i5 + valid_index += 1 + return a + +@tvm.target.generic_func +def argwhere(output_shape, condition): + """Find the indices of elements of a tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + Tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + if len(condition.shape) == 2: + return hybrid_argwhere_2d(output_shape.shape, condition) + if len(condition.shape) == 3: + return hybrid_argwhere_3d(output_shape.shape, condition) + if len(condition.shape) == 4: + return hybrid_argwhere_4d(output_shape.shape, condition) + if len(condition.shape) == 5: + return hybrid_argwhere_5d(output_shape.shape, condition) + return [] diff --git a/topi/python/topi/generic/__init__.py b/topi/python/topi/generic/__init__.py index 6bf5f3a053c9..e3317a3c11f8 100644 --- a/topi/python/topi/generic/__init__.py +++ b/topi/python/topi/generic/__init__.py @@ -20,3 +20,4 @@ from .extern import * from .vision import * from .sort import * +from .where import * diff --git a/topi/python/topi/generic/where.py b/topi/python/topi/generic/where.py new file mode 100644 index 000000000000..8d3d137f7249 --- /dev/null +++ b/topi/python/topi/generic/where.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-member +"""Generic vision operators""" +from __future__ import absolute_import as _abs +import tvm +from .vision import _default_schedule + +@tvm.target.generic_func +def schedule_argwhere(outs): + """Schedule for argwhere operator. 
+ + Parameters + ---------- + outs: Array of Tensor + The computation graph description of argwhere. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 3c7fc9c0dffb..aeb826b3a1c2 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -562,3 +562,4 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype): [0, 0, 1]] """ return cpp.one_hot(indices, on_value, off_value, depth, axis, dtype) + \ No newline at end of file From 30500280f2cf14c4132e18be9ff73d79b614d421 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 25 Sep 2019 10:17:29 -0700 Subject: [PATCH 02/14] Move shape func to _algorithm.py --- python/tvm/relay/op/_algorithm.py | 84 +++++++++++++++++++++++++++---- python/tvm/relay/op/_transform.py | 65 ------------------------ 2 files changed, 75 insertions(+), 74 deletions(-) diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 5617a1666f76..a22307e50ce7 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -22,17 +22,18 @@ import topi from topi.util import get_const_int -from ..op import OpPattern, register_compute, register_schedule, register_pattern +from . import op as _reg +from ...hybrid import script -@register_schedule("argsort") +@_reg.register_schedule("argsort") def schedule_argsort(_, outs, target): """Schedule definition of argsort""" with target: return topi.generic.schedule_argsort(outs) -@register_compute("argsort") +@_reg.register_compute("argsort") def compute_argsort(attrs, inputs, _, target): """Compute definition of argsort""" axis = get_const_int(attrs.axis) @@ -41,17 +42,17 @@ def compute_argsort(attrs, inputs, _, target): return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)] -register_pattern("argsort", OpPattern.OPAQUE) +_reg.register_pattern("argsort", _reg.OpPattern.OPAQUE) # argwhere -@register_schedule("argwhere") +@_reg.register_schedule("argwhere") def schedule_argwhere(_, outs, target): """Schedule definition of argwhere""" with target: return topi.generic.schedule_argwhere(outs) -@register_compute("argwhere") +@_reg.register_compute("argwhere") def compute_argwhere(attrs, inputs, output_type, _): """Compute definition of argwhere""" output_shape = [] @@ -64,14 +65,14 @@ def compute_argwhere(attrs, inputs, output_type, _): new_output_type = TensorType(output_shape, "int32") return [topi.argwhere(new_output_type, inputs[0])] -@register_schedule("topk") +@_reg.register_schedule("topk") def schedule_topk(_, outs, target): """Schedule definition of argsort""" with target: return topi.generic.schedule_topk(outs) -@register_compute("topk") +@_reg.register_compute("topk") def compute_topk(attrs, inputs, _, target): """Compute definition of argsort""" k = get_const_int(attrs.k) @@ -84,4 +85,69 @@ def compute_topk(attrs, inputs, _, target): return out -register_pattern("topk", OpPattern.OPAQUE) +_reg.register_pattern("topk", _reg.OpPattern.OPAQUE) + +@script +def _argwhere_shape_func_2d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(2) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + if condition[i1, i2]: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_3d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(3) + for i1 in range(condition.shape[0]): + for i2 in 
range(condition.shape[1]): + for i3 in range(condition.shape[2]): + if condition[i1, i2, i3]: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_4d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(4) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + for i4 in range(condition.shape[3]): + if condition[i1, i2, i3, i4]: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_5d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(5) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + for i4 in range(condition.shape[3]): + for i5 in range(condition.shape[4]): + if condition[i1, i2, i3, i4, i5]: + out[0] += int64(1) + return out + +@_reg.register_shape_func("argwhere", True) +def argwhere_shape_func(attrs, inputs, out_ndims): + """ + Shape function for argwhere. + """ + if len(inputs[0].shape) == 2: + return [_argwhere_shape_func_2d(inputs[0])] + elif len(inputs[0].shape) == 3: + return [_argwhere_shape_func_3d(inputs[0])] + elif len(inputs[0].shape) == 4: + return [_argwhere_shape_func_4d(inputs[0])] + elif len(inputs[0].shape) == 5: + return [_argwhere_shape_func_5d(inputs[0])] + return [] diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index e6082669e0f7..80b0491d4cd3 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -204,68 +204,3 @@ def take_shape_func(attrs, inputs, out_ndims): axis += data_ndim assert 0 <= axis < data_ndim return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])] - -@script -def _argwhere_shape_func_2d(condition): - out = output_tensor((2, ), "int64") - out[0] = int64(0) - out[1] = int64(2) - for i1 in range(condition.shape[0]): - for i2 in range(condition.shape[1]): - if condition[i1, i2]: - out[0] += int64(1) - return out - -@script -def _argwhere_shape_func_3d(condition): - out = output_tensor((2, ), "int64") - out[0] = int64(0) - out[1] = int64(3) - for i1 in range(condition.shape[0]): - for i2 in range(condition.shape[1]): - for i3 in range(condition.shape[2]): - if condition[i1, i2, i3]: - out[0] += int64(1) - return out - -@script -def _argwhere_shape_func_4d(condition): - out = output_tensor((2, ), "int64") - out[0] = int64(0) - out[1] = int64(4) - for i1 in range(condition.shape[0]): - for i2 in range(condition.shape[1]): - for i3 in range(condition.shape[2]): - for i4 in range(condition.shape[3]): - if condition[i1, i2, i3, i4]: - out[0] += int64(1) - return out - -@script -def _argwhere_shape_func_5d(condition): - out = output_tensor((2, ), "int64") - out[0] = int64(0) - out[1] = int64(5) - for i1 in range(condition.shape[0]): - for i2 in range(condition.shape[1]): - for i3 in range(condition.shape[2]): - for i4 in range(condition.shape[3]): - for i5 in range(condition.shape[4]): - if condition[i1, i2, i3, i4, i5]: - out[0] += int64(1) - return out - -@_reg.register_shape_func("argwhere", True) -def argwhere_shape_func(attrs, inputs, out_ndims): - """ - Shape function for argwhere. 
- """ - if len(inputs[0].shape) == 2: - return [_argwhere_shape_func_2d(inputs[0])] - elif len(inputs[0].shape) == 3: - return [_argwhere_shape_func_3d(inputs[0])] - elif len(inputs[0].shape) == 4: - return [_argwhere_shape_func_4d(inputs[0])] - elif len(inputs[0].shape) == 5: - return [_argwhere_shape_func_5d(inputs[0])] - return [] From 0f94dfa894ec49a1e2e13ca9c6aad4fe5a698a77 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 25 Sep 2019 10:47:05 -0700 Subject: [PATCH 03/14] Add lint rule --- python/tvm/relay/op/_algorithm.py | 2 +- python/tvm/relay/op/_transform.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index a22307e50ce7..3d86d9a408c8 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. "Definition of classic algorithms" -# pylint: disable=invalid-name,unused-argument +# pylint: disable=invalid-name,unused-argument, too-many-nested-blocks from __future__ import absolute_import import tvm from tvm.relay.ty import TensorType diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 80b0491d4cd3..5dddfc6f88d6 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Backend compiler related feature registration""" -# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks +# pylint: disable=invalid-name,unused-argument, len-as-condition from __future__ import absolute_import from topi.util import get_const_int, get_const_tuple from . import op as _reg From 73c110c888cdf4989c6d072a8b6289cad38e9965 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 25 Sep 2019 12:40:52 -0700 Subject: [PATCH 04/14] Raise exception if rank is not supported --- topi/python/topi/argwhere.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/topi/python/topi/argwhere.py b/topi/python/topi/argwhere.py index ff9cbbf56927..6d1e611e96ae 100644 --- a/topi/python/topi/argwhere.py +++ b/topi/python/topi/argwhere.py @@ -163,4 +163,4 @@ def argwhere(output_shape, condition): return hybrid_argwhere_4d(output_shape.shape, condition) if len(condition.shape) == 5: return hybrid_argwhere_5d(output_shape.shape, condition) - return [] + raise ValueError("Does not support rank higher than 5") From d147e127d9a894d621ffd99b82857a1922589397 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 25 Sep 2019 13:30:10 -0700 Subject: [PATCH 05/14] Move argwhere to transform --- include/tvm/relay/attrs/algorithm.h | 6 -- include/tvm/relay/attrs/transform.h | 6 ++ python/tvm/relay/op/_algorithm.py | 104 +++------------------------- python/tvm/relay/op/_transform.py | 88 +++++++++++++++++++++++ python/tvm/relay/op/algorithm.py | 16 ----- python/tvm/relay/op/transform.py | 15 ++++ src/relay/op/algorithm/argwhere.cc | 68 ------------------ src/relay/op/tensor/transform.cc | 34 +++++++++ src/relay/op/tensor/unary.cc | 1 + 9 files changed, 152 insertions(+), 186 deletions(-) delete mode 100644 src/relay/op/algorithm/argwhere.cc diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index 781ff79a6a41..ce14a6a2d535 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -31,6 +31,12 @@ namespace tvm { namespace relay { -/*!
\brief Attributes for ArgWhere operator */ -struct ArgWhereAttrs : public tvm::AttrsNode { - TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") { - } -}; - /*! \brief Attributes used in argsort operators */ struct ArgsortAttrs : public tvm::AttrsNode { int axis; diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 52656872ad10..ccdc871e8a78 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -314,6 +314,12 @@ struct OneHotAttrs : public tvm::AttrsNode { } }; // struct OneHotAttrs +/*! \brief Attributes for ArgWhere operator */ +struct ArgWhereAttrs : public tvm::AttrsNode { + TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") { + } +}; // struct ArgWhereAttrs + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 3d86d9a408c8..09746be13e30 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -15,25 +15,22 @@ # specific language governing permissions and limitations # under the License. "Definition of classic algorithms" -# pylint: disable=invalid-name,unused-argument, too-many-nested-blocks +# pylint: disable=invalid-name,unused-argument from __future__ import absolute_import -import tvm -from tvm.relay.ty import TensorType import topi from topi.util import get_const_int -from . import op as _reg -from ...hybrid import script +from ..op import OpPattern, register_compute, register_schedule, register_pattern -@_reg.register_schedule("argsort") +@register_schedule("argsort") def schedule_argsort(_, outs, target): """Schedule definition of argsort""" with target: return topi.generic.schedule_argsort(outs) -@_reg.register_compute("argsort") +@register_compute("argsort") def compute_argsort(attrs, inputs, _, target): """Compute definition of argsort""" axis = get_const_int(attrs.axis) @@ -42,37 +39,17 @@ def compute_argsort(attrs, inputs, _, target): return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)] -_reg.register_pattern("argsort", _reg.OpPattern.OPAQUE) +register_pattern("argsort", OpPattern.OPAQUE) -# argwhere -@_reg.register_schedule("argwhere") -def schedule_argwhere(_, outs, target): - """Schedule definition of argwhere""" - with target: - return topi.generic.schedule_argwhere(outs) - - -@_reg.register_compute("argwhere") -def compute_argwhere(attrs, inputs, output_type, _): - """Compute definition of argwhere""" - output_shape = [] - for s in output_type.shape: - if hasattr(s, "value"): - output_shape.append(s) - else: - # see Any, replace it with a var - output_shape.append(tvm.var("any_dim", "int32")) - new_output_type = TensorType(output_shape, "int32") - return [topi.argwhere(new_output_type, inputs[0])] -@_reg.register_schedule("topk") +@register_schedule("topk") def schedule_topk(_, outs, target): """Schedule definition of argsort""" with target: return topi.generic.schedule_topk(outs) -@_reg.register_compute("topk") +@register_compute("topk") def compute_topk(attrs, inputs, _, target): """Compute definition of argsort""" k = get_const_int(attrs.k) @@ -85,69 +62,4 @@ def compute_topk(attrs, inputs, _, target): return out -_reg.register_pattern("topk", _reg.OpPattern.OPAQUE) - -@script -def _argwhere_shape_func_2d(condition): - out = output_tensor((2, ), "int64") - out[0] = int64(0) - out[1] = int64(2) - for i1 in range(condition.shape[0]): - for i2 in range(condition.shape[1]): - if condition[i1, i2]: - 
out[0] += int64(1) - return out - -@script -def _argwhere_shape_func_3d(condition): - out = output_tensor((2, ), "int64") - out[0] = int64(0) - out[1] = int64(3) - for i1 in range(condition.shape[0]): - for i2 in range(condition.shape[1]): - for i3 in range(condition.shape[2]): - if condition[i1, i2, i3]: - out[0] += int64(1) - return out - -@script -def _argwhere_shape_func_4d(condition): - out = output_tensor((2, ), "int64") - out[0] = int64(0) - out[1] = int64(4) - for i1 in range(condition.shape[0]): - for i2 in range(condition.shape[1]): - for i3 in range(condition.shape[2]): - for i4 in range(condition.shape[3]): - if condition[i1, i2, i3, i4]: - out[0] += int64(1) - return out - -@script -def _argwhere_shape_func_5d(condition): - out = output_tensor((2, ), "int64") - out[0] = int64(0) - out[1] = int64(5) - for i1 in range(condition.shape[0]): - for i2 in range(condition.shape[1]): - for i3 in range(condition.shape[2]): - for i4 in range(condition.shape[3]): - for i5 in range(condition.shape[4]): - if condition[i1, i2, i3, i4, i5]: - out[0] += int64(1) - return out - -@_reg.register_shape_func("argwhere", True) -def argwhere_shape_func(attrs, inputs, out_ndims): - """ - Shape function for argwhere. - """ - if len(inputs[0].shape) == 2: - return [_argwhere_shape_func_2d(inputs[0])] - elif len(inputs[0].shape) == 3: - return [_argwhere_shape_func_3d(inputs[0])] - elif len(inputs[0].shape) == 4: - return [_argwhere_shape_func_4d(inputs[0])] - elif len(inputs[0].shape) == 5: - return [_argwhere_shape_func_5d(inputs[0])] - return [] +register_pattern("topk", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 5dddfc6f88d6..489904420155 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -17,12 +17,15 @@ """Backend compiler related feature registration""" # pylint: disable=invalid-name,unused-argument, len-as-condition from __future__ import absolute_import +import tvm +import topi from topi.util import get_const_int, get_const_tuple from . 
import op as _reg from ._reduce import _schedule_reduce from .op import OpPattern from ...hybrid import script from ...api import convert +from tvm.relay.ty import TensorType schedule_injective = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective @@ -204,3 +207,88 @@ def take_shape_func(attrs, inputs, out_ndims): axis += data_ndim assert 0 <= axis < data_ndim return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])] + +@script +def _argwhere_shape_func_2d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(2) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + if condition[i1, i2]: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_3d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(3) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + if condition[i1, i2, i3]: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_4d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(4) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + for i4 in range(condition.shape[3]): + if condition[i1, i2, i3, i4]: + out[0] += int64(1) + return out + +@script +def _argwhere_shape_func_5d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(5) + for i1 in range(condition.shape[0]): + for i2 in range(condition.shape[1]): + for i3 in range(condition.shape[2]): + for i4 in range(condition.shape[3]): + for i5 in range(condition.shape[4]): + if condition[i1, i2, i3, i4, i5]: + out[0] += int64(1) + return out + +@_reg.register_shape_func("argwhere", True) +def argwhere_shape_func(attrs, inputs, out_ndims): + """ + Shape function for argwhere. + """ + if len(inputs[0].shape) == 2: + return [_argwhere_shape_func_2d(inputs[0])] + elif len(inputs[0].shape) == 3: + return [_argwhere_shape_func_3d(inputs[0])] + elif len(inputs[0].shape) == 4: + return [_argwhere_shape_func_4d(inputs[0])] + elif len(inputs[0].shape) == 5: + return [_argwhere_shape_func_5d(inputs[0])] + return [] + +@_reg.register_schedule("argwhere") +def schedule_argwhere(_, outs, target): + """Schedule definition of argwhere""" + with target: + return topi.generic.schedule_argwhere(outs) + + +@_reg.register_compute("argwhere") +def compute_argwhere(attrs, inputs, output_type, _): + """Compute definition of argwhere""" + output_shape = [] + for s in output_type.shape: + if hasattr(s, "value"): + output_shape.append(s) + else: + # see Any, replace it with a var + output_shape.append(tvm.var("any_dim", "int32")) + new_output_type = TensorType(output_shape, "int32") + return [topi.argwhere(new_output_type, inputs[0])] diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 20d4f9c6a412..6f875919df4c 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -19,22 +19,6 @@ from . import _make from ..expr import TupleWrapper -def argwhere(condition): - """Find the indices of elements of a tensor that are - non-zero. - - Parameters - ---------- - condition : relay.Expr - The input condition tensor. - - Returns - ------- - out : relay.Expr - Tensor with the indices of elements that are non-zero. 
- """ - return _make.argwhere(condition) - def argsort(data, axis=-1, is_ascend=1, dtype="int32"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index c730c7047bd0..81dacf8700e5 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -213,6 +213,21 @@ def reshape(data, newshape): newshape = [newshape] return _make.reshape(data, list(newshape)) +def argwhere(condition): + """Find the indices of elements of a tensor that are + non-zero. + + Parameters + ---------- + condition : relay.Expr + The input condition tensor. + + Returns + ------- + out : relay.Expr + Tensor with the indices of elements that are non-zero. + """ + return _make.argwhere(condition) def reshape_like(data, shape_like): """Reshapes the input array by the size of another array. diff --git a/src/relay/op/algorithm/argwhere.cc b/src/relay/op/algorithm/argwhere.cc deleted file mode 100644 index 87a2c7b17fd9..000000000000 --- a/src/relay/op/algorithm/argwhere.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * Copyright (c) 2019 by Contributors - * \file argwhere.cc - * \brief Argwhere operators - */ -#include -#include -#include - -namespace tvm { -namespace relay { - -// ArgWhere -bool ArgWhereRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(num_inputs, 1); - auto tt = types[0].as(); - CHECK(tt != nullptr); - const auto& input_shape = tt->shape; - const auto& input_rank = input_shape.size(); - std::vector result_shape; - result_shape.push_back(Any::make()); - result_shape.push_back(IntImm::make(Int(32), input_rank)); - reporter->Assign(types[1], TensorTypeNode::make(result_shape, Int(32))); - return true; -} - -TVM_REGISTER_API("relay.op._make.argwhere") -.set_body_typed([](Expr data) { - static const Op& op = Op::Get("argwhere"); - auto attrs = make_node(); - return CallNode::make(op, {data}, Attrs(attrs), {}); -}); - -RELAY_REGISTER_OP("argwhere") -.describe(R"doc(Find the indices of elements of a tensor that are -non-zero)doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type_key("relay.attrs.ArgWhereAttrs") -.add_argument("condition", "Tensor", "The input condition tensor.") -.add_type_rel("ArgWhere", ArgWhereRel) -.set_attr("TOpIsStateful", false) -.set_attr("TOpPattern", kOpaque) -.set_support_level(10); - -} // namespace relay -} // namespace tvm diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9cd436bbac7a..5f5ae23f0b03 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -817,6 +817,40 @@ the input array into an output array with the same shape as the second input arr .set_attr("FTVMCompute", ReshapeCompute) .set_attr("TOpPattern", kInjective); +// ArgWhere +bool ArgWhereRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(num_inputs, 1); + auto tt = types[0].as(); + CHECK(tt != nullptr); + const auto& input_shape = tt->shape; + const auto& input_rank = input_shape.size(); + std::vector result_shape; + result_shape.push_back(Any::make()); + result_shape.push_back(IntImm::make(Int(32), input_rank)); + reporter->Assign(types[1], TensorTypeNode::make(result_shape, Int(32))); + return true; +} + +TVM_REGISTER_API("relay.op._make.argwhere") +.set_body_typed([](Expr data) { + static const Op& op = Op::Get("argwhere"); + auto attrs = make_node(); + return CallNode::make(op, {data}, Attrs(attrs), {}); +}); + +RELAY_REGISTER_OP("argwhere") +.describe(R"doc(Find the indices of elements of a tensor that are +non-zero)doc" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.ArgWhereAttrs") +.add_argument("condition", "Tensor", "The input condition tensor.") +.add_type_rel("ArgWhere", ArgWhereRel) +.set_attr("TOpIsStateful", false) +.set_attr("TOpPattern", kOpaque) +.set_support_level(10); // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 3a99c2656173..0e3e539cc928 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -326,6 +326,7 @@ RELAY_REGISTER_OP("shape_of") .set_support_level(10) .set_attr("FTVMCompute", ShapeOfCompute); + TVM_REGISTER_NODE_TYPE(NdarraySizeAttrs); bool NdarraySizeRel(const Array& types, From 8aa6729ce59385196f2ad21263ab1653c7853683 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 25 Sep 2019 14:02:25 -0700 Subject: [PATCH 06/14] Add argwhere example --- python/tvm/relay/op/transform.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git 
a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 81dacf8700e5..88d7a448005c 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -226,6 +226,13 @@ def argwhere(condition): ------- out : relay.Expr Tensor with the indices of elements that are non-zero. + + Examples + -------- + .. code-block:: python + + condition = [[True, False], [False, True]] + relay.argwhere(condition) = [[0, 0], [1, 1]] """ return _make.argwhere(condition) From 235f8b4605e73491081ba22648d9ae61360acf6c Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 25 Sep 2019 14:04:10 -0700 Subject: [PATCH 07/14] Fix lint --- python/tvm/relay/op/_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 489904420155..d2a74ec92402 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -15,17 +15,17 @@ # specific language governing permissions and limitations # under the License. """Backend compiler related feature registration""" -# pylint: disable=invalid-name,unused-argument, len-as-condition +# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks from __future__ import absolute_import import tvm import topi from topi.util import get_const_int, get_const_tuple +from tvm.relay.ty import TensorType from . import op as _reg from ._reduce import _schedule_reduce from .op import OpPattern from ...hybrid import script from ...api import convert -from tvm.relay.ty import TensorType schedule_injective = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective From 5210cc9f63ab64d328f2b8949539d26292ba4612 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 25 Sep 2019 14:11:23 -0700 Subject: [PATCH 08/14] Add 1-d support --- python/tvm/relay/op/_transform.py | 14 +++++++++++++- tests/python/relay/test_any.py | 1 + topi/python/topi/argwhere.py | 25 +++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index d2a74ec92402..cc96f19606d0 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -208,6 +208,16 @@ def take_shape_func(attrs, inputs, out_ndims): assert 0 <= axis < data_ndim return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])] +@script +def _argwhere_shape_func_1d(condition): + out = output_tensor((2, ), "int64") + out[0] = int64(0) + out[1] = int64(1) + for i1 in range(condition.shape[0]): + if condition[i1]: + out[0] += int64(1) + return out + @script def _argwhere_shape_func_2d(condition): out = output_tensor((2, ), "int64") @@ -263,7 +273,9 @@ def argwhere_shape_func(attrs, inputs, out_ndims): """ Shape function for argwhere. 
""" - if len(inputs[0].shape) == 2: + if len(inputs[0].shape) == 1: + return [_argwhere_shape_func_1d(inputs[0])] + elif len(inputs[0].shape) == 2: return [_argwhere_shape_func_2d(inputs[0])] elif len(inputs[0].shape) == 3: return [_argwhere_shape_func_3d(inputs[0])] diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index df6d8d149c00..802e4f342ebf 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -106,6 +106,7 @@ def verify_any_argwhere(x_shape, x_np_shape, out_shape): tvm.testing.assert_allclose(result.flatten(), expected.flatten()) def test_any_argwhere(): + verify_any_argwhere(any_dims(1), (5,), None) verify_any_argwhere(any_dims(2), (5, 5), None) verify_any_argwhere(any_dims(3), (5, 5, 5), None) verify_any_argwhere(any_dims(4), (5, 5, 5, 5), None) diff --git a/topi/python/topi/argwhere.py b/topi/python/topi/argwhere.py index 6d1e611e96ae..659c075686c6 100644 --- a/topi/python/topi/argwhere.py +++ b/topi/python/topi/argwhere.py @@ -19,6 +19,29 @@ import tvm from tvm import hybrid +@hybrid.script +def hybrid_argwhere_1d(output_shape, condition): + """Find the indices of elements of a 1-D tensor that are non-zero. + + Parameters + ---------- + condition : tvm.Tensor + 1-D tensor with boolean values. + + Returns + ------- + out : tvm.Tensor + Indices of non-zero elements. + """ + a = output_tensor(output_shape, "int32") + a1 = condition.shape[0] + valid_index = 0 + for i1 in range(a1): + if condition[i1]: + a[valid_index, 0] = i1 + valid_index += 1 + return a + @hybrid.script def hybrid_argwhere_2d(output_shape, condition): """Find the indices of elements of a 2-D tensor that are non-zero. @@ -155,6 +178,8 @@ def argwhere(output_shape, condition): out : tvm.Tensor Indices of non-zero elements. """ + if len(condition.shape) == 1: + return hybrid_argwhere_1d(output_shape.shape, condition) if len(condition.shape) == 2: return hybrid_argwhere_2d(output_shape.shape, condition) if len(condition.shape) == 3: From 968aa7c0581f3c47e8ce42295f35653ef8fb7d16 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 25 Sep 2019 14:12:01 -0700 Subject: [PATCH 09/14] cleanup --- python/tvm/relay/op/_transform.py | 3 +-- topi/python/topi/transform.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index cc96f19606d0..a11af63d1bd4 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -20,7 +20,6 @@ import tvm import topi from topi.util import get_const_int, get_const_tuple -from tvm.relay.ty import TensorType from . 
import op as _reg from ._reduce import _schedule_reduce from .op import OpPattern @@ -302,5 +301,5 @@ def compute_argwhere(attrs, inputs, output_type, _): else: # see Any, replace it with a var output_shape.append(tvm.var("any_dim", "int32")) - new_output_type = TensorType(output_shape, "int32") + new_output_type = tvm.relay.ty.TensorType(output_shape, "int32") return [topi.argwhere(new_output_type, inputs[0])] diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index aeb826b3a1c2..3c7fc9c0dffb 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -562,4 +562,3 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype): [0, 0, 1]] """ return cpp.one_hot(indices, on_value, off_value, depth, axis, dtype) - \ No newline at end of file From e7f9496b2142a0c1f1407a88c53eac74e2ab52b6 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Fri, 27 Sep 2019 15:38:25 -0700 Subject: [PATCH 10/14] Add more dtype support --- python/tvm/relay/op/_transform.py | 10 +++++----- tests/python/relay/test_any.py | 16 +++++++++++++--- topi/python/topi/argwhere.py | 10 +++++----- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index a11af63d1bd4..ce98ff996bee 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -213,7 +213,7 @@ def _argwhere_shape_func_1d(condition): out[0] = int64(0) out[1] = int64(1) for i1 in range(condition.shape[0]): - if condition[i1]: + if condition[i1] != 0: out[0] += int64(1) return out @@ -224,7 +224,7 @@ def _argwhere_shape_func_2d(condition): out[1] = int64(2) for i1 in range(condition.shape[0]): for i2 in range(condition.shape[1]): - if condition[i1, i2]: + if condition[i1, i2] != 0: out[0] += int64(1) return out @@ -236,7 +236,7 @@ def _argwhere_shape_func_3d(condition): for i1 in range(condition.shape[0]): for i2 in range(condition.shape[1]): for i3 in range(condition.shape[2]): - if condition[i1, i2, i3]: + if condition[i1, i2, i3] != 0: out[0] += int64(1) return out @@ -249,7 +249,7 @@ def _argwhere_shape_func_4d(condition): for i2 in range(condition.shape[1]): for i3 in range(condition.shape[2]): for i4 in range(condition.shape[3]): - if condition[i1, i2, i3, i4]: + if condition[i1, i2, i3, i4] != 0: out[0] += int64(1) return out @@ -263,7 +263,7 @@ def _argwhere_shape_func_5d(condition): for i3 in range(condition.shape[2]): for i4 in range(condition.shape[3]): for i5 in range(condition.shape[4]): - if condition[i1, i2, i3, i4, i5]: + if condition[i1, i2, i3, i4, i5] != 0: out[0] += int64(1) return out diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 802e4f342ebf..d51b35ecf911 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -92,12 +92,12 @@ def test_any_reshape(): verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4)) verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12)) -def verify_any_argwhere(x_shape, x_np_shape, out_shape): - x = relay.var('x', shape=x_shape, dtype="bool") +def verify_any_argwhere(x_shape, x_np_shape, out_shape, dtype="bool"): + x = relay.var('x', shape=x_shape, dtype=dtype) y = relay.argwhere(x) mod = relay.module.Module() mod["main"] = relay.Function([x], y) - data = np.random.choice([True, False], size=x_np_shape) + data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype) for kind in ["debug", "vm"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") 
result = ex.evaluate()(data).asnumpy() @@ -111,6 +111,16 @@ def test_any_argwhere(): verify_any_argwhere(any_dims(3), (5, 5, 5), None) verify_any_argwhere(any_dims(4), (5, 5, 5, 5), None) verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), None) + verify_any_argwhere(any_dims(1), (5,), None, "int32") + verify_any_argwhere(any_dims(2), (5, 5), None, "int32") + verify_any_argwhere(any_dims(3), (5, 5, 5), None, "int32") + verify_any_argwhere(any_dims(4), (5, 5, 5, 5), None, "int32") + verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), None, "int32") + verify_any_argwhere(any_dims(1), (5,), None, "int8") + verify_any_argwhere(any_dims(2), (5, 5), None, "int8") + verify_any_argwhere(any_dims(3), (5, 5, 5), None, "int8") + verify_any_argwhere(any_dims(4), (5, 5, 5, 5), None, "int8") + verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), None, "int8") def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape): mod = relay.Module() diff --git a/topi/python/topi/argwhere.py b/topi/python/topi/argwhere.py index 659c075686c6..34a2039950f2 100644 --- a/topi/python/topi/argwhere.py +++ b/topi/python/topi/argwhere.py @@ -37,7 +37,7 @@ def hybrid_argwhere_1d(output_shape, condition): a1 = condition.shape[0] valid_index = 0 for i1 in range(a1): - if condition[i1]: + if condition[i1] != 0: a[valid_index, 0] = i1 valid_index += 1 return a @@ -62,7 +62,7 @@ def hybrid_argwhere_2d(output_shape, condition): valid_index = 0 for i1 in range(a1): for i2 in range(a2): - if condition[i1, i2]: + if condition[i1, i2] != 0: a[valid_index, 0] = i1 a[valid_index, 1] = i2 valid_index += 1 @@ -90,7 +90,7 @@ def hybrid_argwhere_3d(output_shape, condition): for i1 in range(a1): for i2 in range(a2): for i3 in range(a3): - if condition[i1, i2, i3]: + if condition[i1, i2, i3] != 0: a[valid_index, 0] = i1 a[valid_index, 1] = i2 a[valid_index, 2] = i3 @@ -121,7 +121,7 @@ def hybrid_argwhere_4d(output_shape, condition): for i2 in range(a2): for i3 in range(a3): for i4 in range(a4): - if condition[i1, i2, i3, i4]: + if condition[i1, i2, i3, i4] != 0: a[valid_index, 0] = i1 a[valid_index, 1] = i2 a[valid_index, 2] = i3 @@ -155,7 +155,7 @@ def hybrid_argwhere_5d(output_shape, condition): for i3 in range(a3): for i4 in range(a4): for i5 in range(a5): - if condition[i1, i2, i3, i4, i5]: + if condition[i1, i2, i3, i4, i5] != 0: a[valid_index, 0] = i1 a[valid_index, 1] = i2 a[valid_index, 2] = i3 From 1866edf019b4e3b618374bf08a223c1cfc24a4a2 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Sun, 29 Sep 2019 17:49:43 -0700 Subject: [PATCH 11/14] CR comment --- src/relay/op/tensor/transform.cc | 6 +++--- tests/python/relay/test_any.py | 32 ++++++++++++++++---------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 5f5ae23f0b03..5411be2b5fe8 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -819,9 +819,9 @@ the input array into an output array with the same shape as the second input arr // ArgWhere bool ArgWhereRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { CHECK_EQ(num_inputs, 1); auto tt = types[0].as(); CHECK(tt != nullptr); diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index d51b35ecf911..d02dcd0b73dd 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -92,7 +92,7 @@ def test_any_reshape(): 
verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4)) verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12)) -def verify_any_argwhere(x_shape, x_np_shape, out_shape, dtype="bool"): +def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): x = relay.var('x', shape=x_shape, dtype=dtype) y = relay.argwhere(x) mod = relay.module.Module() @@ -106,21 +106,21 @@ def verify_any_argwhere(x_shape, x_np_shape, out_shape, dtype="bool"): tvm.testing.assert_allclose(result.flatten(), expected.flatten()) def test_any_argwhere(): - verify_any_argwhere(any_dims(1), (5,), None) - verify_any_argwhere(any_dims(2), (5, 5), None) - verify_any_argwhere(any_dims(3), (5, 5, 5), None) - verify_any_argwhere(any_dims(4), (5, 5, 5, 5), None) - verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), None) - verify_any_argwhere(any_dims(1), (5,), None, "int32") - verify_any_argwhere(any_dims(2), (5, 5), None, "int32") - verify_any_argwhere(any_dims(3), (5, 5, 5), None, "int32") - verify_any_argwhere(any_dims(4), (5, 5, 5, 5), None, "int32") - verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), None, "int32") - verify_any_argwhere(any_dims(1), (5,), None, "int8") - verify_any_argwhere(any_dims(2), (5, 5), None, "int8") - verify_any_argwhere(any_dims(3), (5, 5, 5), None, "int8") - verify_any_argwhere(any_dims(4), (5, 5, 5, 5), None, "int8") - verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), None, "int8") + verify_any_argwhere(any_dims(1), (5,)) + verify_any_argwhere(any_dims(2), (5, 5)) + verify_any_argwhere(any_dims(3), (5, 5, 5)) + verify_any_argwhere(any_dims(4), (5, 5, 5, 5)) + verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5)) + verify_any_argwhere(any_dims(1), (5,), "int32") + verify_any_argwhere(any_dims(2), (5, 5), "int32") + verify_any_argwhere(any_dims(3), (5, 5, 5), "int32") + verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int32") + verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int32") + verify_any_argwhere(any_dims(1), (5,), "int8") + verify_any_argwhere(any_dims(2), (5, 5), "int8") + verify_any_argwhere(any_dims(3), (5, 5, 5), "int8") + verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8") + verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8") def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape): mod = relay.Module() From b5f30245a4da36cec68c06011cd0ff862aeb388e Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Sun, 29 Sep 2019 17:51:19 -0700 Subject: [PATCH 12/14] Improve error message --- topi/python/topi/argwhere.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/topi/python/topi/argwhere.py b/topi/python/topi/argwhere.py index 34a2039950f2..32f4e8718c46 100644 --- a/topi/python/topi/argwhere.py +++ b/topi/python/topi/argwhere.py @@ -188,4 +188,4 @@ def argwhere(output_shape, condition): return hybrid_argwhere_4d(output_shape.shape, condition) if len(condition.shape) == 5: return hybrid_argwhere_5d(output_shape.shape, condition) - raise ValueError("Does not support rank higher than 5") + raise ValueError("Does not support rank higher than 5 in argwhere") From 0e2408421f2bdbbb919254381ceb0fa9a29bcf7d Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Sun, 29 Sep 2019 17:56:19 -0700 Subject: [PATCH 13/14] Docs --- topi/python/topi/generic/__init__.py | 2 +- topi/python/topi/generic/{where.py => search.py} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename topi/python/topi/generic/{where.py => search.py} (97%) diff --git a/topi/python/topi/generic/__init__.py b/topi/python/topi/generic/__init__.py index 
e3317a3c11f8..18af0e328471 100644 --- a/topi/python/topi/generic/__init__.py +++ b/topi/python/topi/generic/__init__.py @@ -20,4 +20,4 @@ from .extern import * from .vision import * from .sort import * -from .where import * +from .search import * diff --git a/topi/python/topi/generic/where.py b/topi/python/topi/generic/search.py similarity index 97% rename from topi/python/topi/generic/where.py rename to topi/python/topi/generic/search.py index 8d3d137f7249..41045e492e53 100644 --- a/topi/python/topi/generic/where.py +++ b/topi/python/topi/generic/search.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, no-member -"""Generic vision operators""" +"""Generic search operators""" from __future__ import absolute_import as _abs import tvm from .vision import _default_schedule From 6b40ff91825f1c7c1d0487f2869bf561979a595e Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Mon, 30 Sep 2019 11:35:48 -0700 Subject: [PATCH 14/14] Raise exception --- python/tvm/relay/op/_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index ce98ff996bee..3197b81289b2 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -282,7 +282,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims): return [_argwhere_shape_func_4d(inputs[0])] elif len(inputs[0].shape) == 5: return [_argwhere_shape_func_5d(inputs[0])] - return [] + raise ValueError("Does not support rank higher than 5 in argwhere") @_reg.register_schedule("argwhere") def schedule_argwhere(_, outs, target):