Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,12 @@ struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
}
}; // struct OneHotAttrs

/*! \brief Attributes for ArgWhere operator */
struct ArgWhereAttrs : public tvm::AttrsNode<ArgWhereAttrs> {
TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") {
}
}; // struct ArgWhereAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
101 changes: 100 additions & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument, len-as-condition
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks
from __future__ import absolute_import
import tvm
import topi
from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from ._reduce import _schedule_reduce
Expand Down Expand Up @@ -204,3 +206,100 @@ def take_shape_func(attrs, inputs, out_ndims):
axis += data_ndim
assert 0 <= axis < data_ndim
return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]

@script
def _argwhere_shape_func_1d(condition):
out = output_tensor((2, ), "int64")
out[0] = int64(0)
out[1] = int64(1)
for i1 in range(condition.shape[0]):
if condition[i1] != 0:
out[0] += int64(1)
return out

@script
def _argwhere_shape_func_2d(condition):
out = output_tensor((2, ), "int64")
out[0] = int64(0)
out[1] = int64(2)
for i1 in range(condition.shape[0]):
for i2 in range(condition.shape[1]):
if condition[i1, i2] != 0:
out[0] += int64(1)
return out

@script
def _argwhere_shape_func_3d(condition):
out = output_tensor((2, ), "int64")
out[0] = int64(0)
out[1] = int64(3)
for i1 in range(condition.shape[0]):
for i2 in range(condition.shape[1]):
for i3 in range(condition.shape[2]):
if condition[i1, i2, i3] != 0:
out[0] += int64(1)
return out

@script
def _argwhere_shape_func_4d(condition):
out = output_tensor((2, ), "int64")
out[0] = int64(0)
out[1] = int64(4)
for i1 in range(condition.shape[0]):
for i2 in range(condition.shape[1]):
for i3 in range(condition.shape[2]):
for i4 in range(condition.shape[3]):
if condition[i1, i2, i3, i4] != 0:
out[0] += int64(1)
return out

@script
def _argwhere_shape_func_5d(condition):
out = output_tensor((2, ), "int64")
out[0] = int64(0)
out[1] = int64(5)
for i1 in range(condition.shape[0]):
for i2 in range(condition.shape[1]):
for i3 in range(condition.shape[2]):
for i4 in range(condition.shape[3]):
for i5 in range(condition.shape[4]):
if condition[i1, i2, i3, i4, i5] != 0:
out[0] += int64(1)
return out

@_reg.register_shape_func("argwhere", True)
def argwhere_shape_func(attrs, inputs, out_ndims):
"""
Shape function for argwhere.
"""
if len(inputs[0].shape) == 1:
return [_argwhere_shape_func_1d(inputs[0])]
elif len(inputs[0].shape) == 2:
return [_argwhere_shape_func_2d(inputs[0])]
elif len(inputs[0].shape) == 3:
return [_argwhere_shape_func_3d(inputs[0])]
elif len(inputs[0].shape) == 4:
return [_argwhere_shape_func_4d(inputs[0])]
elif len(inputs[0].shape) == 5:
return [_argwhere_shape_func_5d(inputs[0])]
return ValueError("Does not support rank higher than 5 in argwhere")

@_reg.register_schedule("argwhere")
def schedule_argwhere(_, outs, target):
"""Schedule definition of argwhere"""
with target:
return topi.generic.schedule_argwhere(outs)


@_reg.register_compute("argwhere")
def compute_argwhere(attrs, inputs, output_type, _):
"""Compute definition of argwhere"""
output_shape = []
for s in output_type.shape:
if hasattr(s, "value"):
output_shape.append(s)
else:
# see Any, replace it with a var
output_shape.append(tvm.var("any_dim", "int32"))
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]
23 changes: 22 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def squeeze(data, axis=None):
"""
return _make.squeeze(data, axis)


def reshape(data, newshape):
"""Reshapes the input array.

Expand Down Expand Up @@ -214,6 +213,28 @@ def reshape(data, newshape):
newshape = [newshape]
return _make.reshape(data, list(newshape))

def argwhere(condition):
"""Find the indices of elements of a tensor that are
non-zero.

Parameters
----------
condition : relay.Expr
The input condition tensor.

Returns
-------
out : relay.Expr
Tensor with the indices of elements that are non-zero.

Examples
--------
.. code-block:: python

condition = [[True, False], [False, True]]
relay.argwhere(condition) = [[0, 0], [1, 1]]
"""
return _make.argwhere(condition)

def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
Expand Down
34 changes: 34 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,40 @@ the input array into an output array with the same shape as the second input arr
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// ArgWhere
bool ArgWhereRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 1);
auto tt = types[0].as<TensorTypeNode>();
CHECK(tt != nullptr);
const auto& input_shape = tt->shape;
const auto& input_rank = input_shape.size();
std::vector<IndexExpr> result_shape;
result_shape.push_back(Any::make());
result_shape.push_back(IntImm::make(Int(32), input_rank));
reporter->Assign(types[1], TensorTypeNode::make(result_shape, Int(32)));
return true;
}

TVM_REGISTER_API("relay.op._make.argwhere")
.set_body_typed<Expr(Expr)>([](Expr data) {
static const Op& op = Op::Get("argwhere");
auto attrs = make_node<ArgWhereAttrs>();
return CallNode::make(op, {data}, Attrs(attrs), {});
});

RELAY_REGISTER_OP("argwhere")
.describe(R"doc(Find the indices of elements of a tensor that are
non-zero)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ArgWhereAttrs")
.add_argument("condition", "Tensor", "The input condition tensor.")
.add_type_rel("ArgWhere", ArgWhereRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_support_level(10);

// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);
Expand Down
30 changes: 30 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,36 @@ def test_any_reshape():
verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))

def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
x = relay.var('x', shape=x_shape, dtype=dtype)
y = relay.argwhere(x)
mod = relay.module.Module()
mod["main"] = relay.Function([x], y)
data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data).asnumpy()
expected = np.argwhere(data)
assert result.shape == expected.shape
tvm.testing.assert_allclose(result.flatten(), expected.flatten())

def test_any_argwhere():
verify_any_argwhere(any_dims(1), (5,))
verify_any_argwhere(any_dims(2), (5, 5))
verify_any_argwhere(any_dims(3), (5, 5, 5))
verify_any_argwhere(any_dims(4), (5, 5, 5, 5))
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5))
verify_any_argwhere(any_dims(1), (5,), "int32")
verify_any_argwhere(any_dims(2), (5, 5), "int32")
verify_any_argwhere(any_dims(3), (5, 5, 5), "int32")
verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int32")
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int32")
verify_any_argwhere(any_dims(1), (5,), "int8")
verify_any_argwhere(any_dims(2), (5, 5), "int8")
verify_any_argwhere(any_dims(3), (5, 5, 5), "int8")
verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8")
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8")

def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape):
mod = relay.Module()
data = relay.var('data', shape=data_shape, dtype='float32')
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .transform import *
from .broadcast import *
from .sort import *
from .argwhere import *
from . import nn
from . import x86
from . import cuda
Expand Down
Loading