From 66814847d125c27bc6b6cd045d608557ee5ba827 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Wed, 10 Feb 2021 14:59:43 -0800 Subject: [PATCH 01/16] Initial commit of the unique operator Add unit tests for unique operator --- python/tvm/relay/op/_transform.py | 27 +++++ python/tvm/relay/op/strategy/generic.py | 25 ++++ python/tvm/relay/op/transform.py | 29 +++++ python/tvm/topi/transform.py | 38 ++++++ src/relay/op/algorithm/unique.cc | 147 ++++++++++++++++++++++++ tests/python/relay/test_op_level6.py | 51 ++++++++ 6 files changed, 317 insertions(+) create mode 100644 src/relay/op/algorithm/unique.cc diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 01bcf4a6cf60..cf6e0ff9a267 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -28,6 +28,7 @@ from . import op as _reg from . import strategy from .op import OpPattern +from .op import register_strategy, register_pattern from ._tensor import elemwise_shape_func _reg.register_broadcast_schedule("broadcast_to") @@ -946,3 +947,29 @@ def where_shape_func(attrs, inputs, _): out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape) return [out_shape] + + +register_strategy("unique", strategy.unique_strategy) +register_pattern("unique", OpPattern.OPAQUE) + + +@script +def _unique_shape_1(data_shape): + shape_tensor = output_tensor((1,), "int64") + shape_tensor[0] = int64(data_shape[0]) + return shape_tensor + + +@script +def _unique_shape_2(inputs): + shape_tensor = output_tensor((1,), "int64") + shape_tensor[0] = int64(1) + return shape_tensor + + +@_reg.register_shape_func("unique", False) +def unique_shape_func(attrs, inputs, _): + """ + Shape func for unique operator. + """ + return [_unique_shape_1(inputs[0]), _unique_shape_1(inputs[0]), _unique_shape_1(inputs[0]), _unique_shape_2(inputs[0])] diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index f076176c5d8a..a9d11791afb0 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1432,3 +1432,28 @@ def cumsum_strategy(attrs, inputs, out_type, target): name="cumsum.generic", ) return strategy + + +def wrap_compute_unique(topi_compute): + """Wrap unique topi compute""" + + def _compute_unique(attrs, inputs, _): + return topi_compute(inputs[0]) + + return _compute_unique + + +def wrap_unique_schedule(outs): + return topi.generic.default.default_schedule(outs, False) + + +@override_native_generic_func("unique_strategy") +def unique_strategy(attrs, inputs, out_type, target): + """unique generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_unique(topi.unique), + wrap_topi_schedule(wrap_unique_schedule), + name="unique.generic", + ) + return strategy diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index b676fe742544..1f34f8b24b74 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1463,3 +1463,32 @@ def cumsum(data, axis=None, dtype=None, exclusive=None): -> [1, 1, 2, 2, 3, 4, 4] """ return _make.cumsum(data, axis, dtype, exclusive) + + +def unique(data): + """ + Find the unique elements of a tensor + Parameters + ---------- + data : relay.Expr + A 1-D tensor of integers + Returns + ------- + output : relay.Expr + A 1-D tensor containing the unique elements of data tensor + inverse_indices : relay.Expr + A 1-D tensor containing the index of each value of data tensor in output tensor + counts : relay.Expr + A 1-D tensor 
containing the count of each element of output tensor in data tensor + num_unique_elements : relay.Expr + A 0-D tensor containing the number of unique elements in data tensor + Examples + -------- + .. code-block:: python + [y, idx, counts, n] = unique([1, 1, 2, 4, 4, 4, 7, 8, 8]) + y = [1, 2, 4, 7, 8, ?, ?, ?, ?] + idx = [0, 0, 1, 2, 2, 2, 3, 4, 4] + count = [2, 1, 3, 1, 2, ?, ?, ?, ?] + n = [5] + """ + return TupleWrapper(_make.unique(data), 4) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 6ddbc73e4666..4a769d0b2030 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -931,3 +931,41 @@ def adv_index(data, indices): Output tensor """ return cpp.adv_index(data, indices) + + +def unique(data): + """ + Find the unique elements of a tensor + Parameters + ---------- + data : tvm.te.Tensor + A 1-D tensor of integers + Returns + ------- + output : tvm.te.Tensor + A 1-D tensor containing the unique elements of data tensor + inverse_indices : rtvm.te.Tensor + A 1-D tensor containing the index of each value of data tensor in output tensor + counts : tvm.te.Tensor + A 1-D tensor containing the count of each element of output tensor in data tensor + num_unique_elements : tvm.te.Tensor + A 0-D tensor containing the number of unique elements in data tensor + Examples + -------- + .. code-block:: python + [y, idx, counts, n] = unique([1, 1, 2, 4, 4, 4, 7, 8, 8]) + y = [1, 2, 4, 7, 8, ?, ?, ?, ?] + idx = [0, 0, 1, 2, 2, 2, 3, 4, 4] + count = [2, 1, 3, 1, 2, ?, ?, ?, ?] + n = [5] + """ + return te.extern( + [data.shape, data.shape, data.shape, (1,)], + [data], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.algorithm.unique", ins[0], *outs + ), + dtype=[data.dtype, "int32", "int32", "int32", "int32"], + name="unique_cpu", + tag="unique_cpu", + ) diff --git a/src/relay/op/algorithm/unique.cc b/src/relay/op/algorithm/unique.cc new file mode 100644 index 000000000000..941e6dac4f53 --- /dev/null +++ b/src/relay/op/algorithm/unique.cc @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file unique.cc + * \brief The unique operator + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +bool UniqueRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types: [data, result] + ICHECK_EQ(types.size(), 2) << "Unique: expect 2 types but " << types.size() << " provided"; + ICHECK_EQ(num_inputs, 1) << "Unique: expect 1 inputs but " << num_inputs << " provided"; + auto data = types[0].as(); + if (data == nullptr) { + ICHECK(types[0].as()) + << "Unique: expect input type to be TensorType but get " << types[0]; + return false; + } + std::vector fields; + fields.push_back(TensorType(data->shape, data->dtype)); + fields.push_back(TensorType(data->shape, DataType::Int(32))); + fields.push_back(TensorType(data->shape, DataType::Int(32))); + fields.push_back(TensorType(Array{1}, DataType::Int(32))); + reporter->Assign(types[1], TupleType(Array(fields))); + return true; +} + +Expr MakeUnique(Expr data) { + static const Op& op = Op::Get("unique"); + return Call(op, {data}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.unique").set_body_typed(MakeUnique); + +RELAY_REGISTER_OP("unique") + .describe( + R"code(This operation returns a tensor **output** containing all of the unique elements of **data** + sorted in the same order that they occur in **data**; **data** does not need to be sorted. + This operation also returns a tensor **inverse_indices** contains the index of each value of **data** in the unique output **output**. + In other words: output[inverse_indices[i]] = data[i] for i in [0, 1,..., len(data) - 1]. + This operation also returns a 0-D tensor **num_unique_elements** contains the number of unique elements in **data**. + Please note **output** and **counts** have the same size of **data** and only items [0, 1,..., num_unique_elements[0]-1] are valid. + + - **data**: A 1-D tensor of integers + + - **output**: A 1-D tensor containing the unique elements of **data** + + - **inverse_indices**: A 1-D tensor containing the index of each value of **data** in **output** + + - **counts**: A 1-D tensor containing the count of each element of **output** in **data** + + - **num_unique_elements**: A 0-D tensor containing the number of unique elements + + Example:: + - [y, idx, counts, n] = unique([1, 1, 2, 4, 4, 4, 7, 8, 8]) + - y = [1, 2, 4, 7, 8, ?, ?, ?, ?] + - idx = [0, 0, 1, 2, 2, 2, 3, 4, 4] + - count = [2, 1, 3, 1, 2, ?, ?, ?, ?] 
+ - n = [5] + )code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .add_type_rel("unique", UniqueRel) + .set_support_level(6); + +template +void calc_unique(DLTensor* input, DLTensor* output, DLTensor* inverse_indices, DLTensor* counts, + DLTensor* num_unique_elements) { + std::unordered_map + unique_map; // map to record the idx of each unique element in the output tensor + auto input_ptr = static_cast(input->data); + auto output_ptr = static_cast(output->data); + auto inverse_indices_ptr = static_cast(inverse_indices->data); + auto counts_ptr = static_cast(counts->data); + auto num_unique_ptr = static_cast(num_unique_elements->data); + + int unique_counter = 0; + for (int i = 0; i < input->shape[0]; i++) { + if (unique_map.count(input_ptr[i]) == 0) { + unique_map[input_ptr[i]] = unique_counter; + output_ptr[unique_counter] = input_ptr[i]; + counts_ptr[unique_counter] = 0; + unique_counter++; + } + inverse_indices_ptr[i] = unique_map[input_ptr[i]]; + counts_ptr[inverse_indices_ptr[i]]++; + } + + num_unique_ptr[0] = unique_counter; +} + +// The unique operator +TVM_REGISTER_GLOBAL("tvm.contrib.algorithm.unique").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* input = args[0]; + DLTensor* output = args[1]; + DLTensor* inverse_indices = args[2]; + DLTensor* counts = args[3]; + DLTensor* num_unique_elements = args[4]; + + ICHECK_EQ(input->ndim, 1) << "The input tensor must be 1-D"; + ICHECK((output->ndim) == 1 && (inverse_indices->ndim) == 1 && (counts->ndim == 1) && + (num_unique_elements->ndim == 1)) + << "The output,inverse_indices,counts,num_unique_elements tensors must be 1-D"; + ICHECK((input->shape[0] == output->shape[0]) && (input->shape[0] == inverse_indices->shape[0]) && + (input->shape[0] == counts->shape[0])) + << "The input,output,inverse_indices,counts tensors must have the " + "same size"; + ICHECK_EQ(num_unique_elements->shape[0], 1) << "The num_unique_elements tensor must have size 1"; + + auto data_dtype = tvm::runtime::DLDataType2String(input->dtype); + + if (data_dtype == "int32") { + calc_unique(input, output, inverse_indices, counts, num_unique_elements); + } else if (data_dtype == "int64") { + calc_unique(input, output, inverse_indices, counts, num_unique_elements); + } else { + LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + } +}); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index 0dac69e36025..fc7ccaf3aeb5 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -139,7 +139,58 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): verify_topk(k, axis, ret_type, False, "float32") +def test_unique(): + def calc_unique(data): + uniq, index, inverse, counts = np.unique( + data, return_index=True, return_inverse=True, return_counts=True) + order = np.argsort(index) + reverse_order = dict(zip(order, np.arange(len(order)))) + uniq = uniq[order].astype(data.dtype) + inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") + counts = counts[order].astype("int32") + num_uniq = np.array([len(uniq)]).astype("int32") + return uniq, inverse, counts, num_uniq + + def verify_unique(len, dtype, is_dyn=False): + if is_dyn: + x = relay.var("x", relay.TensorType([relay.Any()], dtype)) + else: + x = relay.var("x", relay.TensorType([len], dtype)) + outs = relay.unique(x) + outs = outs.astuple() + func = relay.Function([x], outs) + x_data = np.random.randint(100, 
size=len).astype(dtype) + + if is_dyn: + backends = ["vm", "debug"] + else: + backends = ["graph", "debug"] + for target, ctx in tvm.testing.enabled_targets(): + for kind in backends: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor( + kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_data) + ref_res = calc_unique(x_data) + num_uniq = ref_res[3][0] + assert num_uniq == op_res[3].asnumpy()[0] + # output + tvm.testing.assert_allclose( + op_res[0].asnumpy()[:num_uniq], ref_res[0], rtol=1e-5) + # inverse_indices + tvm.testing.assert_allclose( + op_res[1].asnumpy(), ref_res[1], rtol=1e-5) + # count + tvm.testing.assert_allclose( + op_res[2].asnumpy()[:num_uniq], ref_res[2], rtol=1e-5) + for dtype in ["int32", "int64"]: + for is_dyn in [False, True]: + verify_unique((50), dtype, is_dyn=is_dyn) + verify_unique((100), dtype, is_dyn=is_dyn) + + if __name__ == "__main__": test_sort() test_argsort() test_topk() + test_unique() From 9d9fae5b9d4e626aed4d7eefd7271da0c2c5a060 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Wed, 10 Feb 2021 17:40:51 -0800 Subject: [PATCH 02/16] Add tensorflow unique op --- python/tvm/relay/frontend/tensorflow.py | 12 ++++++++ python/tvm/relay/op/_transform.py | 7 ++++- python/tvm/topi/transform.py | 4 +-- .../frontend/tensorflow/test_forward.py | 29 +++++++++++++++++++ tests/python/relay/test_op_level6.py | 22 +++++++------- 5 files changed, 58 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 6a29ce266ea6..b29ad67ee233 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2324,6 +2324,17 @@ def _impl(inputs, attr, params, mod): return _impl +def _unique(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 1 + x = inputs[0] + [output, indices, counts, num_uniq] = _op.unique(x) + output_sliced = _op.strided_slice(output, begin=[0], end=num_uniq, slice_mode="size") + return [output_sliced, indices] + + return _impl + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -2502,6 +2513,7 @@ def _impl(inputs, attr, params, mod): "TopKV2": _topk(), "Transpose": _transpose(), "TruncateMod": _elemwise("mod"), + "Unique": _unique(), "Unpack": _unpack(), "UnravelIndex": _unravel_index(), "Where": _where(), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index cf6e0ff9a267..965ad52515df 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -972,4 +972,9 @@ def unique_shape_func(attrs, inputs, _): """ Shape func for unique operator. 
""" - return [_unique_shape_1(inputs[0]), _unique_shape_1(inputs[0]), _unique_shape_1(inputs[0]), _unique_shape_2(inputs[0])] + return [ + _unique_shape_1(inputs[0]), + _unique_shape_1(inputs[0]), + _unique_shape_1(inputs[0]), + _unique_shape_2(inputs[0]), + ] diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 4a769d0b2030..4795a1a9b369 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -962,9 +962,7 @@ def unique(data): return te.extern( [data.shape, data.shape, data.shape, (1,)], [data], - lambda ins, outs: tvm.tir.call_packed( - "tvm.contrib.algorithm.unique", ins[0], *outs - ), + lambda ins, outs: tvm.tir.call_packed("tvm.contrib.algorithm.unique", ins[0], *outs), dtype=[data.dtype, "int32", "int32", "int32", "int32"], name="unique_cpu", tag="unique_cpu", diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index f956ea02eb47..8c1c613c79fd 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -4839,5 +4839,34 @@ def lstm_cell(): tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) +####################################################################### +# Unique +# ------------ +def _test_unique(n, dtype, is_dyn): + """ One iteration of a Stridedslice """ + + tf.reset_default_graph() + np_data = np.random.randint(100, size=n).astype(dtype) + with tf.Graph().as_default(): + if is_dyn: + in_data = tf.placeholder(dtype, [n], name="in_data") + else: + in_data = tf.constant(np_data, dtype, name="in_data") + tf.unique(in_data) + if is_dyn: + compare_tf_with_tvm(np_data, "in_data:0", ["Unique:0", "Unique:1"], mode="vm") + else: + compare_tf_with_tvm(None, "", ["Unique:0", "Unique:1"]) + + +def test_forward_unique(): + """test Unique""" + + for dtype in ["int32", "int64"]: + for is_dyn in [False, True]: + _test_unique(50, dtype, is_dyn) + _test_unique(100, dtype, is_dyn) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index fc7ccaf3aeb5..31d1f8ef68fe 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -142,7 +142,8 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): def test_unique(): def calc_unique(data): uniq, index, inverse, counts = np.unique( - data, return_index=True, return_inverse=True, return_counts=True) + data, return_index=True, return_inverse=True, return_counts=True + ) order = np.argsort(index) reverse_order = dict(zip(order, np.arange(len(order)))) uniq = uniq[order].astype(data.dtype) @@ -151,15 +152,15 @@ def calc_unique(data): num_uniq = np.array([len(uniq)]).astype("int32") return uniq, inverse, counts, num_uniq - def verify_unique(len, dtype, is_dyn=False): + def verify_unique(n, dtype, is_dyn=False): if is_dyn: x = relay.var("x", relay.TensorType([relay.Any()], dtype)) else: - x = relay.var("x", relay.TensorType([len], dtype)) + x = relay.var("x", relay.TensorType([n], dtype)) outs = relay.unique(x) outs = outs.astuple() func = relay.Function([x], outs) - x_data = np.random.randint(100, size=len).astype(dtype) + x_data = np.random.randint(100, size=n).astype(dtype) if is_dyn: backends = ["vm", "debug"] @@ -168,21 +169,18 @@ def verify_unique(len, dtype, is_dyn=False): for target, ctx in tvm.testing.enabled_targets(): for kind in backends: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor( - kind, 
mod=mod, ctx=ctx, target=target) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_data) ref_res = calc_unique(x_data) num_uniq = ref_res[3][0] assert num_uniq == op_res[3].asnumpy()[0] # output - tvm.testing.assert_allclose( - op_res[0].asnumpy()[:num_uniq], ref_res[0], rtol=1e-5) + tvm.testing.assert_allclose(op_res[0].asnumpy()[:num_uniq], ref_res[0], rtol=1e-5) # inverse_indices - tvm.testing.assert_allclose( - op_res[1].asnumpy(), ref_res[1], rtol=1e-5) + tvm.testing.assert_allclose(op_res[1].asnumpy(), ref_res[1], rtol=1e-5) # count - tvm.testing.assert_allclose( - op_res[2].asnumpy()[:num_uniq], ref_res[2], rtol=1e-5) + tvm.testing.assert_allclose(op_res[2].asnumpy()[:num_uniq], ref_res[2], rtol=1e-5) + for dtype in ["int32", "int64"]: for is_dyn in [False, True]: verify_unique((50), dtype, is_dyn=is_dyn) From e83259c404a0c9a46a83e7bfb21ebff6b6075810 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Fri, 19 Feb 2021 14:57:30 -0800 Subject: [PATCH 03/16] Refactor unique to use sort-based algorithm --- include/tvm/relay/attrs/transform.h | 8 + python/tvm/relay/frontend/tensorflow.py | 11 +- python/tvm/relay/op/_transform.py | 40 +++-- python/tvm/relay/op/strategy/generic.py | 8 +- python/tvm/relay/op/transform.py | 32 ++-- python/tvm/topi/__init__.py | 1 + python/tvm/topi/generic/search.py | 16 ++ python/tvm/topi/transform.py | 36 ----- python/tvm/topi/unique.py | 118 ++++++++++++++ src/relay/op/algorithm/unique.cc | 147 ------------------ src/relay/op/tensor/transform.cc | 42 +++++ .../frontend/tensorflow/test_forward.py | 2 + tests/python/relay/test_op_level3.py | 48 ++++++ tests/python/relay/test_op_level6.py | 49 ------ tests/python/topi/python/test_topi_unique.py | 71 +++++++++ 15 files changed, 351 insertions(+), 278 deletions(-) create mode 100644 python/tvm/topi/unique.py delete mode 100644 src/relay/op/algorithm/unique.cc create mode 100644 tests/python/topi/python/test_topi_unique.py diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 24098b74f3b6..44450a8b7e99 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -452,6 +452,14 @@ struct CumsumAttrs : public tvm::AttrsNode { } }; +/*! 
\brief Attributes used in unique operator */ +struct UniqueAttrs : public tvm::AttrsNode { + bool sorted; + TVM_DECLARE_ATTRS(UniqueAttrs, "relay.attrs.UniqueAttrs") { + TVM_ATTR_FIELD(sorted).describe("Whether the unique elements are sorted").set_default(true); + } +}; // struct UniqueAttrs + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index b29ad67ee233..f64727b286d9 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2327,10 +2327,13 @@ def _impl(inputs, attr, params, mod): def _unique(): def _impl(inputs, attr, params, mod): assert len(inputs) == 1 - x = inputs[0] - [output, indices, counts, num_uniq] = _op.unique(x) - output_sliced = _op.strided_slice(output, begin=[0], end=num_uniq, slice_mode="size") - return [output_sliced, indices] + data = inputs[0] + [unique, indices, num_uniq] = _op.unique(data, is_sorted=False) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + return _expr.TupleWrapper( + _expr.Tuple([unique_sliced, indices]), + 2, + ) return _impl diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 965ad52515df..7476a7e3c998 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -28,7 +28,6 @@ from . import op as _reg from . import strategy from .op import OpPattern -from .op import register_strategy, register_pattern from ._tensor import elemwise_shape_func _reg.register_broadcast_schedule("broadcast_to") @@ -143,6 +142,15 @@ def compute_cumsum(attrs, inputs, output_type): _reg.register_strategy("cumsum", strategy.cumsum_strategy) _reg.register_shape_func("cumsum", False, elemwise_shape_func) + +@_reg.register_compute("unique") +def compute_unique(attrs, inputs, output_type): + """Compute definition of cumsum""" + return topi.unique(inputs[0], attrs.sorfted) + + +_reg.register_strategy("unique", strategy.unique_strategy) + ##################### # Shape functions # ##################### @@ -949,22 +957,15 @@ def where_shape_func(attrs, inputs, _): return [out_shape] -register_strategy("unique", strategy.unique_strategy) -register_pattern("unique", OpPattern.OPAQUE) - - -@script -def _unique_shape_1(data_shape): - shape_tensor = output_tensor((1,), "int64") - shape_tensor[0] = int64(data_shape[0]) - return shape_tensor - - @script -def _unique_shape_2(inputs): - shape_tensor = output_tensor((1,), "int64") - shape_tensor[0] = int64(1) - return shape_tensor +def _unique_shape(data_shape): + unique_shape = output_tensor((1,), "int64") + indices_shape = output_tensor((1,), "int64") + num_unique_shape = output_tensor((1,), "int64") + unique_shape[0] = data_shape[0] + indices_shape[0] = data_shape[0] + num_unique_shape[0] = int64(1) + return (unique_shape, indices_shape, num_unique_shape) @_reg.register_shape_func("unique", False) @@ -972,9 +973,4 @@ def unique_shape_func(attrs, inputs, _): """ Shape func for unique operator. 
""" - return [ - _unique_shape_1(inputs[0]), - _unique_shape_1(inputs[0]), - _unique_shape_1(inputs[0]), - _unique_shape_2(inputs[0]), - ] + return _unique_shape(inputs[0]) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index a9d11791afb0..a282803b0d4a 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1438,22 +1438,18 @@ def wrap_compute_unique(topi_compute): """Wrap unique topi compute""" def _compute_unique(attrs, inputs, _): - return topi_compute(inputs[0]) + return topi_compute(inputs[0], attrs.sorted) return _compute_unique -def wrap_unique_schedule(outs): - return topi.generic.default.default_schedule(outs, False) - - @override_native_generic_func("unique_strategy") def unique_strategy(attrs, inputs, out_type, target): """unique generic strategy""" strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_unique(topi.unique), - wrap_topi_schedule(wrap_unique_schedule), + wrap_topi_schedule(topi.generic.schedule_unique), name="unique.generic", ) return strategy diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 1f34f8b24b74..8adf37d6098d 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1465,30 +1465,34 @@ def cumsum(data, axis=None, dtype=None, exclusive=None): return _make.cumsum(data, axis, dtype, exclusive) -def unique(data): +def unique(data, is_sorted=True): """ Find the unique elements of a tensor Parameters ---------- data : relay.Expr A 1-D tensor of integers + sorted : bool + Whether to sort the unique elements in ascending order before returning as output Returns ------- output : relay.Expr - A 1-D tensor containing the unique elements of data tensor - inverse_indices : relay.Expr - A 1-D tensor containing the index of each value of data tensor in output tensor - counts : relay.Expr - A 1-D tensor containing the count of each element of output tensor in data tensor - num_unique_elements : relay.Expr - A 0-D tensor containing the number of unique elements in data tensor + A 1-D tensor containing the unique elements of the input data tensor + indices : relay.Expr + A 1-D tensor containing the index of each data element in the output tensor + num_unique : relay.Expr + A 0-D tensor containing the number of unique elements in the input data tensor Examples -------- .. code-block:: python - [y, idx, counts, n] = unique([1, 1, 2, 4, 4, 4, 7, 8, 8]) - y = [1, 2, 4, 7, 8, ?, ?, ?, ?] - idx = [0, 0, 1, 2, 2, 2, 3, 4, 4] - count = [2, 1, 3, 1, 2, ?, ?, ?, ?] - n = [5] + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=False) + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=True) + output = [1, 2, 3, 4, 5, ?, ?, ?] + indices = [3, 4, 0, 1, 2, 2, 3, 4] + num_unique = [5] """ - return TupleWrapper(_make.unique(data), 4) + return TupleWrapper(_make.unique(data, is_sorted), 3) diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 2b17162048e0..63dc4bd4ab83 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -43,6 +43,7 @@ from .argwhere import * from .cumsum import * from .einsum import * +from .unique import * from . import generic from . import nn from . 
import x86 diff --git a/python/tvm/topi/generic/search.py b/python/tvm/topi/generic/search.py index 5924d35def73..f458ee7bc782 100644 --- a/python/tvm/topi/generic/search.py +++ b/python/tvm/topi/generic/search.py @@ -70,3 +70,19 @@ def schedule_scatter_add(outs): def schedule_sparse_fill_empty_rows(outs): return _default_schedule(outs, False) + + +def schedule_unique(outs): + """Schedule for unique operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of unique. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 4795a1a9b369..6ddbc73e4666 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -931,39 +931,3 @@ def adv_index(data, indices): Output tensor """ return cpp.adv_index(data, indices) - - -def unique(data): - """ - Find the unique elements of a tensor - Parameters - ---------- - data : tvm.te.Tensor - A 1-D tensor of integers - Returns - ------- - output : tvm.te.Tensor - A 1-D tensor containing the unique elements of data tensor - inverse_indices : rtvm.te.Tensor - A 1-D tensor containing the index of each value of data tensor in output tensor - counts : tvm.te.Tensor - A 1-D tensor containing the count of each element of output tensor in data tensor - num_unique_elements : tvm.te.Tensor - A 0-D tensor containing the number of unique elements in data tensor - Examples - -------- - .. code-block:: python - [y, idx, counts, n] = unique([1, 1, 2, 4, 4, 4, 7, 8, 8]) - y = [1, 2, 4, 7, 8, ?, ?, ?, ?] - idx = [0, 0, 1, 2, 2, 2, 3, 4, 4] - count = [2, 1, 3, 1, 2, ?, ?, ?, ?] - n = [5] - """ - return te.extern( - [data.shape, data.shape, data.shape, (1,)], - [data], - lambda ins, outs: tvm.tir.call_packed("tvm.contrib.algorithm.unique", ins[0], *outs), - dtype=[data.dtype, "int32", "int32", "int32", "int32"], - name="unique_cpu", - tag="unique_cpu", - ) diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py new file mode 100644 index 000000000000..e78a6acb6658 --- /dev/null +++ b/python/tvm/topi/unique.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
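# A minimal NumPy sketch of the sort-based scheme that the hybrid scripts in this
# file implement (sort -> adjacent difference -> inclusive scan -> scatter back).
# Illustrative reference only; `np_unique_sorted` is an assumed helper name and is
# not part of this patch. Assumes a non-empty 1-D integer array.
#
#   import numpy as np
#
#   def np_unique_sorted(data):
#       order = np.argsort(data, kind="stable")           # argsort(data)
#       sorted_data = data[order]                          # sort(data)
#       adj_diff = np.concatenate(                         # 1 where the sorted value changes
#           ([0], (sorted_data[1:] != sorted_data[:-1]).astype("int32"))
#       )
#       inc_scan = np.cumsum(adj_diff)                     # unique id of every element
#       num_unique = inc_scan[-1] + 1                      # cf. _calc_num_unique
#       unique = np.zeros_like(data)                       # tail beyond num_unique is invalid ("?")
#       indices = np.empty(data.shape, dtype="int32")
#       indices[order] = inc_scan                          # scatter ids back to input order
#       unique[inc_scan] = sorted_data                     # first num_unique slots hold the uniques
#       return unique, indices, np.array([num_unique], dtype="int32")
#
#   # e.g. np_unique_sorted(np.array([4, 5, 1, 2, 3, 3, 4, 5]))
#   # -> ([1, 2, 3, 4, 5, 0, 0, 0], [3, 4, 0, 1, 2, 2, 3, 4], [5])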
+# pylint: disable=invalid-name +"""Unique operator""" +from ..te import hybrid +from .cumsum import cumsum +from .sort import sort, argsort + + +@hybrid.script +def _calc_adjacent_diff(data): + output = output_tensor(data.shape, "int32") + output[0] = int32(0) + for i in range(1, data.shape[0]): + output[i] = int32(1) if data[i] != data[i - 1] else int32(0) + return output + + +@hybrid.script +def _calc_num_unique(data): + output = output_tensor((1,), "int32") + output[0] = data[data.shape[0] - 1] + 1 + return output + + +@hybrid.script +def _calc_unique_sorted(data, argsorted_indices, inc_scan): + unique_elements = output_tensor(data.shape, data.dtype) + indices = output_tensor(data.shape, "int32") + for i in range(data.shape[0]): + indices[argsorted_indices[i]] = inc_scan[i] + unique_elements[inc_scan[i]] = data[argsorted_indices[i]] + return unique_elements, indices + + +@hybrid.script +def _calc_first_occurence(argsorted_indices, inc_scan): + first_occurence = output_tensor(argsorted_indices.shape, "int32") + for i in range(argsorted_indices.shape[0]): + first_occurence[i] = argsorted_indices.shape[0] + for i in range(argsorted_indices.shape[0]): + if i == 0 or inc_scan[i] != inc_scan[i - 1]: + first_occurence[inc_scan[i]] = argsorted_indices[i] + return first_occurence + + +@hybrid.script +def _calc_unique_unsorted(data, argsorted_indices, inc_scan, index_converter): + unique_elements = output_tensor(data.shape, data.dtype) + indices = output_tensor(data.shape, "int32") + for i in range(data.shape[0]): + new_unique_idx = index_converter[inc_scan[i]] + new_data_idx = argsorted_indices[i] + unique_elements[new_unique_idx] = data[new_data_idx] + indices[new_data_idx] = new_unique_idx + return unique_elements, indices + + +def unique(data, is_sorted=True): + """ + Find the unique elements of a tensor + Parameters + ---------- + data : relay.Expr + A 1-D tensor of integers + sorted : bool + Whether to sort the unique elements in ascending order before returning as output + Returns + ------- + output : relay.Expr + A 1-D tensor containing the unique elements of the input data tensor + indices : relay.Expr + A 1-D tensor containing the index of each data element in the output tensor + num_unique : relay.Expr + A 0-D tensor containing the number of unique elements in the input data tensor + Examples + -------- + .. code-block:: python + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=False) + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=True) + output = [1, 2, 3, 4, 5, ?, ?, ?] 
+ indices = [3, 4, 0, 1, 2, 2, 3, 4] + num_unique = [5] + """ + + sorted_data = sort(data) + argsorted_indices = argsort(data, dtype="int32") + adjacent_diff = _calc_adjacent_diff(sorted_data) + inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0) + num_unique_elements = _calc_num_unique(inc_scan) + if is_sorted: + unique_elements, inverse_indices = _calc_unique_sorted(data, argsorted_indices, inc_scan) + else: + first_occurence = _calc_first_occurence(argsorted_indices, inc_scan) + argsorted_first_occurence = argsort(first_occurence, dtype="int32") + index_converter = argsort(argsorted_first_occurence, dtype="int32") + unique_elements, inverse_indices = _calc_unique_unsorted( + data, argsorted_indices, inc_scan, index_converter + ) + return [unique_elements, inverse_indices, num_unique_elements] diff --git a/src/relay/op/algorithm/unique.cc b/src/relay/op/algorithm/unique.cc deleted file mode 100644 index 941e6dac4f53..000000000000 --- a/src/relay/op/algorithm/unique.cc +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file unique.cc - * \brief The unique operator - */ -#include -#include -#include -#include -#include - -namespace tvm { -namespace relay { - -bool UniqueRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - // types: [data, result] - ICHECK_EQ(types.size(), 2) << "Unique: expect 2 types but " << types.size() << " provided"; - ICHECK_EQ(num_inputs, 1) << "Unique: expect 1 inputs but " << num_inputs << " provided"; - auto data = types[0].as(); - if (data == nullptr) { - ICHECK(types[0].as()) - << "Unique: expect input type to be TensorType but get " << types[0]; - return false; - } - std::vector fields; - fields.push_back(TensorType(data->shape, data->dtype)); - fields.push_back(TensorType(data->shape, DataType::Int(32))); - fields.push_back(TensorType(data->shape, DataType::Int(32))); - fields.push_back(TensorType(Array{1}, DataType::Int(32))); - reporter->Assign(types[1], TupleType(Array(fields))); - return true; -} - -Expr MakeUnique(Expr data) { - static const Op& op = Op::Get("unique"); - return Call(op, {data}, Attrs(), {}); -} - -TVM_REGISTER_GLOBAL("relay.op._make.unique").set_body_typed(MakeUnique); - -RELAY_REGISTER_OP("unique") - .describe( - R"code(This operation returns a tensor **output** containing all of the unique elements of **data** - sorted in the same order that they occur in **data**; **data** does not need to be sorted. - This operation also returns a tensor **inverse_indices** contains the index of each value of **data** in the unique output **output**. - In other words: output[inverse_indices[i]] = data[i] for i in [0, 1,..., len(data) - 1]. 
- This operation also returns a 0-D tensor **num_unique_elements** contains the number of unique elements in **data**. - Please note **output** and **counts** have the same size of **data** and only items [0, 1,..., num_unique_elements[0]-1] are valid. - - - **data**: A 1-D tensor of integers - - - **output**: A 1-D tensor containing the unique elements of **data** - - - **inverse_indices**: A 1-D tensor containing the index of each value of **data** in **output** - - - **counts**: A 1-D tensor containing the count of each element of **output** in **data** - - - **num_unique_elements**: A 0-D tensor containing the number of unique elements - - Example:: - - [y, idx, counts, n] = unique([1, 1, 2, 4, 4, 4, 7, 8, 8]) - - y = [1, 2, 4, 7, 8, ?, ?, ?, ?] - - idx = [0, 0, 1, 2, 2, 2, 3, 4, 4] - - count = [2, 1, 3, 1, 2, ?, ?, ?, ?] - - n = [5] - )code" TVM_ADD_FILELINE) - .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor") - .add_type_rel("unique", UniqueRel) - .set_support_level(6); - -template -void calc_unique(DLTensor* input, DLTensor* output, DLTensor* inverse_indices, DLTensor* counts, - DLTensor* num_unique_elements) { - std::unordered_map - unique_map; // map to record the idx of each unique element in the output tensor - auto input_ptr = static_cast(input->data); - auto output_ptr = static_cast(output->data); - auto inverse_indices_ptr = static_cast(inverse_indices->data); - auto counts_ptr = static_cast(counts->data); - auto num_unique_ptr = static_cast(num_unique_elements->data); - - int unique_counter = 0; - for (int i = 0; i < input->shape[0]; i++) { - if (unique_map.count(input_ptr[i]) == 0) { - unique_map[input_ptr[i]] = unique_counter; - output_ptr[unique_counter] = input_ptr[i]; - counts_ptr[unique_counter] = 0; - unique_counter++; - } - inverse_indices_ptr[i] = unique_map[input_ptr[i]]; - counts_ptr[inverse_indices_ptr[i]]++; - } - - num_unique_ptr[0] = unique_counter; -} - -// The unique operator -TVM_REGISTER_GLOBAL("tvm.contrib.algorithm.unique").set_body([](TVMArgs args, TVMRetValue* ret) { - DLTensor* input = args[0]; - DLTensor* output = args[1]; - DLTensor* inverse_indices = args[2]; - DLTensor* counts = args[3]; - DLTensor* num_unique_elements = args[4]; - - ICHECK_EQ(input->ndim, 1) << "The input tensor must be 1-D"; - ICHECK((output->ndim) == 1 && (inverse_indices->ndim) == 1 && (counts->ndim == 1) && - (num_unique_elements->ndim == 1)) - << "The output,inverse_indices,counts,num_unique_elements tensors must be 1-D"; - ICHECK((input->shape[0] == output->shape[0]) && (input->shape[0] == inverse_indices->shape[0]) && - (input->shape[0] == counts->shape[0])) - << "The input,output,inverse_indices,counts tensors must have the " - "same size"; - ICHECK_EQ(num_unique_elements->shape[0], 1) << "The num_unique_elements tensor must have size 1"; - - auto data_dtype = tvm::runtime::DLDataType2String(input->dtype); - - if (data_dtype == "int32") { - calc_unique(input, output, inverse_indices, counts, num_unique_elements); - } else if (data_dtype == "int64") { - calc_unique(input, output, inverse_indices, counts, num_unique_elements); - } else { - LOG(FATAL) << "Unsupported input dtype: " << data_dtype; - } -}); - -} // namespace relay -} // namespace tvm diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 12db859d1ae1..4e8617a59804 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3772,5 +3772,47 @@ RELAY_REGISTER_OP("cumsum") .add_type_rel("Cumsum", CumsumRel) .set_attr("TOpPattern", 
kOpaque); +TVM_REGISTER_NODE_TYPE(UniqueAttrs); + +bool UniqueRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types: [data, result] + ICHECK_EQ(types.size(), 2) << "Unique: expect 2 types but " << types.size() << " provided"; + ICHECK_EQ(num_inputs, 1) << "Unique: expect 1 inputs but " << num_inputs << " provided"; + auto data = types[0].as(); + if (data == nullptr) { + ICHECK(types[0].as()) + << "Unique: expect input type to be TensorType but get " << types[0]; + return false; + } + const int ndim = static_cast(data->shape.size()); + ICHECK_EQ(ndim, 1) << "Unique: input must be 1-D tensor"; + ICHECK_EQ(data->dtype.is_int(), true) << "Unique: input must have int32 or int64 dtype"; + std::vector fields; + fields.push_back(TensorType(data->shape, data->dtype)); // unique + fields.push_back(TensorType(data->shape, DataType::Int(32))); // indices + fields.push_back(TensorType(Array{1}, DataType::Int(32))); // num_unique + reporter->Assign(types[1], TupleType(Array(fields))); + return true; +} + +Expr MakeUnique(Expr data, bool sorted) { + auto attrs = make_object(); + attrs->sorted = sorted; + static const Op& op = Op::Get("unique"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.unique").set_body_typed(MakeUnique); + +RELAY_REGISTER_OP("unique") + .describe( + R"code(This operation returns the unique elements and the new index of each item in a given 1-D array. + )code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .add_type_rel("unique", UniqueRel) + .set_support_level(3) + .set_attr("TOpPattern", kOpaque); } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 8c1c613c79fd..b6cec3f8954b 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -4842,6 +4842,8 @@ def lstm_cell(): ####################################################################### # Unique # ------------ + + def _test_unique(n, dtype, is_dyn): """ One iteration of a Stridedslice """ diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 94fac3ba1264..b6693e6e0252 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1453,5 +1453,53 @@ def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, verify_scatter_nd_with_stack(data, indices, shape, out) +@tvm.testing.uses_gpu +def test_unique(): + def calc_numpy_unique(data, is_sorted=False): + uniq, index, inverse, counts = np.unique( + data, return_index=True, return_inverse=True, return_counts=True + ) + num_uniq = np.array([len(uniq)]).astype("int32") + if not is_sorted: + order = np.argsort(index) + reverse_order = np.argsort(order) + uniq = uniq[order].astype(data.dtype) + inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") + counts = counts[order].astype("int32") + return [uniq.astype(data.dtype), inverse.astype("int32"), counts, num_uniq] + + def verify_unique(n, dtype, is_dyn=False, is_sorted=False): + if is_dyn: + x = relay.var("x", relay.TensorType([relay.Any()], dtype)) + else: + x = relay.var("x", relay.TensorType([n], dtype)) + outs = relay.unique(x, is_sorted) + outs = outs.astuple() + func = relay.Function([x], outs) + x_data = np.random.randint(50, size=n).astype(dtype) + + if is_dyn: + backends = ["vm", "debug"] + else: + backends = 
["graph", "debug"] + for target, ctx in tvm.testing.enabled_targets(): + for kind in backends: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + tvm_res = intrp.evaluate()(x_data) + np_res = calc_numpy_unique(x_data, is_sorted) + num_unique = np_res[3][0] + assert num_unique == tvm_res[2].asnumpy()[0] + # unique + tvm.testing.assert_allclose(tvm_res[0].asnumpy()[:num_unique], np_res[0], rtol=1e-5) + # inverse_indices + tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_res[1], rtol=1e-5) + + for dtype in ["int32", "int64"]: + for is_dyn in [True, False]: + verify_unique(1, dtype, is_dyn, True) + verify_unique(50, dtype, is_dyn, False) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index 31d1f8ef68fe..0dac69e36025 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -139,56 +139,7 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): verify_topk(k, axis, ret_type, False, "float32") -def test_unique(): - def calc_unique(data): - uniq, index, inverse, counts = np.unique( - data, return_index=True, return_inverse=True, return_counts=True - ) - order = np.argsort(index) - reverse_order = dict(zip(order, np.arange(len(order)))) - uniq = uniq[order].astype(data.dtype) - inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") - counts = counts[order].astype("int32") - num_uniq = np.array([len(uniq)]).astype("int32") - return uniq, inverse, counts, num_uniq - - def verify_unique(n, dtype, is_dyn=False): - if is_dyn: - x = relay.var("x", relay.TensorType([relay.Any()], dtype)) - else: - x = relay.var("x", relay.TensorType([n], dtype)) - outs = relay.unique(x) - outs = outs.astuple() - func = relay.Function([x], outs) - x_data = np.random.randint(100, size=n).astype(dtype) - - if is_dyn: - backends = ["vm", "debug"] - else: - backends = ["graph", "debug"] - for target, ctx in tvm.testing.enabled_targets(): - for kind in backends: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) - op_res = intrp.evaluate()(x_data) - ref_res = calc_unique(x_data) - num_uniq = ref_res[3][0] - assert num_uniq == op_res[3].asnumpy()[0] - # output - tvm.testing.assert_allclose(op_res[0].asnumpy()[:num_uniq], ref_res[0], rtol=1e-5) - # inverse_indices - tvm.testing.assert_allclose(op_res[1].asnumpy(), ref_res[1], rtol=1e-5) - # count - tvm.testing.assert_allclose(op_res[2].asnumpy()[:num_uniq], ref_res[2], rtol=1e-5) - - for dtype in ["int32", "int64"]: - for is_dyn in [False, True]: - verify_unique((50), dtype, is_dyn=is_dyn) - verify_unique((100), dtype, is_dyn=is_dyn) - - if __name__ == "__main__": test_sort() test_argsort() test_topk() - test_unique() diff --git a/tests/python/topi/python/test_topi_unique.py b/tests/python/topi/python/test_topi_unique.py new file mode 100644 index 000000000000..87d8c4df0271 --- /dev/null +++ b/tests/python/topi/python/test_topi_unique.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +import tvm.testing +from tvm import topi +import tvm.topi.testing + + +@tvm.testing.parametrize_targets +def test_unique(ctx, target): + def calc_numpy_unique(data, is_sorted=False): + uniq, index, inverse, counts = np.unique( + data, return_index=True, return_inverse=True, return_counts=True + ) + num_uniq = np.array([len(uniq)]).astype("int32") + if not is_sorted: + order = np.argsort(index) + reverse_order = np.argsort(order) + uniq = uniq[order].astype(data.dtype) + inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") + counts = counts[order].astype("int32") + return [uniq.astype(data.dtype), inverse.astype("int32"), counts, num_uniq] + + def check_unique(data, is_sorted=False): + implementations = { + "generic": (lambda x: topi.unique(x, is_sorted), topi.generic.schedule_unique), + } + fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) + tvm_data = tvm.nd.array(data, ctx=ctx) + tvm_unique = tvm.nd.array(np.zeros(data.shape).astype(data.dtype), ctx=ctx) + tvm_indices = tvm.nd.array(np.zeros(data.shape).astype("int32"), ctx=ctx) + tvm_num_unique = tvm.nd.array(np.zeros([1]).astype("int32"), ctx=ctx) + with tvm.target.Target(target): + te_input = tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype)) + outs = fcompute(te_input) + s = fschedule(outs) + func = tvm.build(s, [te_input, *outs]) + func(tvm_data, tvm_unique, tvm_indices, tvm_num_unique) + np_unique, np_indices, _, np_num_unique = calc_numpy_unique(data, is_sorted) + num_unique = np_num_unique[0] + assert tvm_num_unique.asnumpy()[0] == np_num_unique + np.testing.assert_allclose( + tvm_unique.asnumpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5 + ) + np.testing.assert_allclose(tvm_indices.asnumpy(), np_indices, atol=1e-5, rtol=1e-5) + + for in_dtype in ["int32", "int64"]: + for is_sorted in [True, False]: + data = np.random.randint(0, 100, size=(100)).astype(in_dtype) + check_unique(data, is_sorted) + data = np.random.randint(0, 10, size=(100)).astype(in_dtype) + check_unique(data, is_sorted) + + +if __name__ == "__main__": + test_unique(tvm.context("cpu"), tvm.target.Target("llvm")) From 99463dd77bcc53e818110965b91e857c2be60abc Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Fri, 19 Feb 2021 18:24:14 -0800 Subject: [PATCH 04/16] Change relay.unique test to run only on cpu --- tests/python/relay/test_op_level3.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index b6693e6e0252..b551bd3e4afd 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1453,7 +1453,6 @@ def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, verify_scatter_nd_with_stack(data, indices, shape, out) -@tvm.testing.uses_gpu def test_unique(): def calc_numpy_unique(data, is_sorted=False): uniq, index, inverse, counts = np.unique( @@ -1482,18 +1481,19 @@ def verify_unique(n, dtype, is_dyn=False, is_sorted=False): backends = ["vm", "debug"] else: backends = ["graph", "debug"] - for 
target, ctx in tvm.testing.enabled_targets(): - for kind in backends: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) - tvm_res = intrp.evaluate()(x_data) - np_res = calc_numpy_unique(x_data, is_sorted) - num_unique = np_res[3][0] - assert num_unique == tvm_res[2].asnumpy()[0] - # unique - tvm.testing.assert_allclose(tvm_res[0].asnumpy()[:num_unique], np_res[0], rtol=1e-5) - # inverse_indices - tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_res[1], rtol=1e-5) + + target, ctx = "llvm", tvm.cpu() + for kind in backends: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + tvm_res = intrp.evaluate()(x_data) + np_res = calc_numpy_unique(x_data, is_sorted) + num_unique = np_res[3][0] + assert num_unique == tvm_res[2].asnumpy()[0] + # unique + tvm.testing.assert_allclose(tvm_res[0].asnumpy()[:num_unique], np_res[0], rtol=1e-5) + # inverse_indices + tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_res[1], rtol=1e-5) for dtype in ["int32", "int64"]: for is_dyn in [True, False]: From 0ffad98e3b9a475a048695c313df9f8ecd0bd228 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Fri, 19 Feb 2021 18:35:23 -0800 Subject: [PATCH 05/16] Change topi.unique test to run only on cpu --- tests/python/topi/python/test_topi_unique.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/python/topi/python/test_topi_unique.py b/tests/python/topi/python/test_topi_unique.py index 87d8c4df0271..4f959aa02356 100644 --- a/tests/python/topi/python/test_topi_unique.py +++ b/tests/python/topi/python/test_topi_unique.py @@ -21,8 +21,10 @@ import tvm.topi.testing -@tvm.testing.parametrize_targets -def test_unique(ctx, target): +def test_unique(): + target = "llvm" + ctx = tvm.cpu() + def calc_numpy_unique(data, is_sorted=False): uniq, index, inverse, counts = np.unique( data, return_index=True, return_inverse=True, return_counts=True @@ -68,4 +70,4 @@ def check_unique(data, is_sorted=False): if __name__ == "__main__": - test_unique(tvm.context("cpu"), tvm.target.Target("llvm")) + test_unique() From f31f53ff0da43fb080777e8e85554f49f918ca51 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Fri, 19 Feb 2021 18:50:28 -0800 Subject: [PATCH 06/16] Change range to parallel for parallelizable loops --- python/tvm/topi/unique.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py index e78a6acb6658..1768bd191747 100644 --- a/python/tvm/topi/unique.py +++ b/python/tvm/topi/unique.py @@ -25,7 +25,7 @@ def _calc_adjacent_diff(data): output = output_tensor(data.shape, "int32") output[0] = int32(0) - for i in range(1, data.shape[0]): + for i in parallel(1, data.shape[0]): output[i] = int32(1) if data[i] != data[i - 1] else int32(0) return output @@ -41,7 +41,7 @@ def _calc_num_unique(data): def _calc_unique_sorted(data, argsorted_indices, inc_scan): unique_elements = output_tensor(data.shape, data.dtype) indices = output_tensor(data.shape, "int32") - for i in range(data.shape[0]): + for i in parallel(data.shape[0]): indices[argsorted_indices[i]] = inc_scan[i] unique_elements[inc_scan[i]] = data[argsorted_indices[i]] return unique_elements, indices @@ -50,9 +50,9 @@ def _calc_unique_sorted(data, argsorted_indices, inc_scan): @hybrid.script def _calc_first_occurence(argsorted_indices, inc_scan): first_occurence = output_tensor(argsorted_indices.shape, "int32") - for i in range(argsorted_indices.shape[0]): + 
for i in parallel(argsorted_indices.shape[0]): first_occurence[i] = argsorted_indices.shape[0] - for i in range(argsorted_indices.shape[0]): + for i in parallel(argsorted_indices.shape[0]): if i == 0 or inc_scan[i] != inc_scan[i - 1]: first_occurence[inc_scan[i]] = argsorted_indices[i] return first_occurence @@ -62,7 +62,7 @@ def _calc_first_occurence(argsorted_indices, inc_scan): def _calc_unique_unsorted(data, argsorted_indices, inc_scan, index_converter): unique_elements = output_tensor(data.shape, data.dtype) indices = output_tensor(data.shape, "int32") - for i in range(data.shape[0]): + for i in parallel(data.shape[0]): new_unique_idx = index_converter[inc_scan[i]] new_data_idx = argsorted_indices[i] unique_elements[new_unique_idx] = data[new_data_idx] From c403e52762a9f540cba386d938ada21f82953c1c Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Fri, 19 Feb 2021 23:51:51 -0800 Subject: [PATCH 07/16] Add return_counts option for relay.unique and topi.unique, add pytorch frontend --- include/tvm/relay/attrs/transform.h | 4 + python/tvm/relay/frontend/pytorch.py | 16 ++++ python/tvm/relay/frontend/tensorflow.py | 18 ++++- python/tvm/relay/op/_transform.py | 20 ++++- python/tvm/relay/op/strategy/generic.py | 2 +- python/tvm/relay/op/transform.py | 19 ++++- python/tvm/topi/unique.py | 74 +++++++++++++++++-- src/relay/op/tensor/transform.cc | 7 +- tests/python/frontend/pytorch/test_forward.py | 25 ++++++- .../frontend/tensorflow/test_forward.py | 38 +++++++++- tests/python/relay/test_op_level3.py | 13 ++-- tests/python/topi/python/test_topi_unique.py | 36 ++++++++- 12 files changed, 245 insertions(+), 27 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 44450a8b7e99..ff344f5e1a85 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -455,8 +455,12 @@ struct CumsumAttrs : public tvm::AttrsNode { /*! 
\brief Attributes used in unique operator */ struct UniqueAttrs : public tvm::AttrsNode { bool sorted; + bool return_counts; TVM_DECLARE_ATTRS(UniqueAttrs, "relay.attrs.UniqueAttrs") { TVM_ATTR_FIELD(sorted).describe("Whether the unique elements are sorted").set_default(true); + TVM_ATTR_FIELD(return_counts) + .describe("Whether to return an additional tensor with counts of each unique elements") + .set_default(false); } }; // struct UniqueAttrs diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 205b2aa779e6..88718a5e673c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2164,6 +2164,21 @@ def is_floating_point(self, inputs, input_types): is_float = input_type in ["float32", "float64", "float16", "bfloat16"] return _expr.const(is_float) + def unique(self, inputs, input_types): + assert len(inputs) == 4 + [data, is_sorted, return_inverse, return_counts] = inputs + if return_counts: + [unique, indices, num_uniq, counts] = _op.unique( + data, is_sorted=is_sorted, return_counts=True + ) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size") + return (unique_sliced, indices, counts_sliced) + else: + [unique, indices, num_uniq] = _op.unique(data, is_sorted=is_sorted, return_counts=False) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + return (unique_sliced, indices) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2370,6 +2385,7 @@ def create_convert_map(self): "aten::masked_select": self.masked_select, "aten::argsort": self.argsort, "aten::sort": self.sort, + "aten::_unique2": self.unique, } def update_convert_map(self, custom_map): diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index f64727b286d9..d58fe24a3206 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2328,7 +2328,7 @@ def _unique(): def _impl(inputs, attr, params, mod): assert len(inputs) == 1 data = inputs[0] - [unique, indices, num_uniq] = _op.unique(data, is_sorted=False) + [unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False) unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") return _expr.TupleWrapper( _expr.Tuple([unique_sliced, indices]), @@ -2338,6 +2338,21 @@ def _impl(inputs, attr, params, mod): return _impl +def _unique_with_counts(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 1 + data = inputs[0] + [unique, indices, num_uniq, counts] = _op.unique(data, is_sorted=False, return_counts=True) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size") + return _expr.TupleWrapper( + _expr.Tuple([unique_sliced, indices, counts_sliced]), + 3, + ) + + return _impl + + # compatible operators that do NOT require any conversion. 
_identity_list = [] @@ -2517,6 +2532,7 @@ def _impl(inputs, attr, params, mod): "Transpose": _transpose(), "TruncateMod": _elemwise("mod"), "Unique": _unique(), + "UniqueWithCounts": _unique_with_counts(), "Unpack": _unpack(), "UnravelIndex": _unravel_index(), "Where": _where(), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 7476a7e3c998..5e6a66a47dae 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -146,7 +146,7 @@ def compute_cumsum(attrs, inputs, output_type): @_reg.register_compute("unique") def compute_unique(attrs, inputs, output_type): """Compute definition of cumsum""" - return topi.unique(inputs[0], attrs.sorfted) + return topi.unique(inputs[0], attrs.sorted, attrs.return_counts) _reg.register_strategy("unique", strategy.unique_strategy) @@ -968,9 +968,25 @@ def _unique_shape(data_shape): return (unique_shape, indices_shape, num_unique_shape) +@script +def _unique_with_counts_shape(data_shape): + unique_shape = output_tensor((1,), "int64") + indices_shape = output_tensor((1,), "int64") + num_unique_shape = output_tensor((1,), "int64") + counts_shape = output_tensor((1,), "int64") + unique_shape[0] = data_shape[0] + indices_shape[0] = data_shape[0] + num_unique_shape[0] = int64(1) + counts_shape[0] = data_shape[0] + return (unique_shape, indices_shape, num_unique_shape, counts_shape) + + @_reg.register_shape_func("unique", False) def unique_shape_func(attrs, inputs, _): """ Shape func for unique operator. """ - return _unique_shape(inputs[0]) + if attrs.return_counts: + return _unique_with_counts_shape(inputs[0]) + else: + return _unique_shape(inputs[0]) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index a282803b0d4a..8a2724dfb614 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1438,7 +1438,7 @@ def wrap_compute_unique(topi_compute): """Wrap unique topi compute""" def _compute_unique(attrs, inputs, _): - return topi_compute(inputs[0], attrs.sorted) + return topi_compute(inputs[0], attrs.sorted, attrs.return_counts) return _compute_unique diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 8adf37d6098d..ae00240b5891 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1465,7 +1465,7 @@ def cumsum(data, axis=None, dtype=None, exclusive=None): return _make.cumsum(data, axis, dtype, exclusive) -def unique(data, is_sorted=True): +def unique(data, is_sorted=True, return_counts=False): """ Find the unique elements of a tensor Parameters @@ -1474,6 +1474,8 @@ def unique(data, is_sorted=True): A 1-D tensor of integers sorted : bool Whether to sort the unique elements in ascending order before returning as output + return_counts : bool + Whether to return the array with count of each unique element Returns ------- output : relay.Expr @@ -1482,17 +1484,28 @@ def unique(data, is_sorted=True): A 1-D tensor containing the index of each data element in the output tensor num_unique : relay.Expr A 0-D tensor containing the number of unique elements in the input data tensor + counts (optional) : relay.Expr + A 1-D tensor containing the count of each unique element in the output Examples -------- .. code-block:: python - [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=False) + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=False, return_counts=False) output = [4, 5, 1, 2, 3, ?, ?, ?] 
indices = [0, 1, 2, 3, 4, 4, 0, 1] num_unique = [5] + [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=False, return_counts=True) + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + counts = [2, 2, 1, 1, 2, ?, ?, ?] + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=True) output = [1, 2, 3, 4, 5, ?, ?, ?] indices = [3, 4, 0, 1, 2, 2, 3, 4] num_unique = [5] """ - return TupleWrapper(_make.unique(data, is_sorted), 3) + if return_counts: + return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4) + else: + return TupleWrapper(_make.unique(data, is_sorted, return_counts), 3) diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py index 1768bd191747..031ca0f3309f 100644 --- a/python/tvm/topi/unique.py +++ b/python/tvm/topi/unique.py @@ -33,7 +33,7 @@ def _calc_adjacent_diff(data): @hybrid.script def _calc_num_unique(data): output = output_tensor((1,), "int32") - output[0] = data[data.shape[0] - 1] + 1 + output[0] = data[data.shape[0] - 1] + int32(1) return output @@ -47,6 +47,21 @@ def _calc_unique_sorted(data, argsorted_indices, inc_scan): return unique_elements, indices +@hybrid.script +def _calc_unique_sorted_with_counts(data, argsorted_indices, inc_scan): + unique_elements = output_tensor(data.shape, data.dtype) + indices = output_tensor(data.shape, "int32") + counts = output_tensor(data.shape, "int32") + for i in parallel(data.shape[0]): + counts[i] = int32(0) + for i in parallel(data.shape[0]): + indices[argsorted_indices[i]] = inc_scan[i] + unique_elements[inc_scan[i]] = data[argsorted_indices[i]] + for i in range(data.shape[0]): + counts[inc_scan[i]] += int32(1) + return unique_elements, indices, counts + + @hybrid.script def _calc_first_occurence(argsorted_indices, inc_scan): first_occurence = output_tensor(argsorted_indices.shape, "int32") @@ -70,7 +85,25 @@ def _calc_unique_unsorted(data, argsorted_indices, inc_scan, index_converter): return unique_elements, indices -def unique(data, is_sorted=True): +@hybrid.script +def _calc_unique_unsorted_with_counts(data, argsorted_indices, inc_scan, index_converter): + unique_elements = output_tensor(data.shape, data.dtype) + indices = output_tensor(data.shape, "int32") + counts = output_tensor(data.shape, "int32") + for i in parallel(data.shape[0]): + counts[i] = int32(0) + for i in parallel(data.shape[0]): + new_unique_idx = index_converter[inc_scan[i]] + new_data_idx = argsorted_indices[i] + unique_elements[new_unique_idx] = data[new_data_idx] + indices[new_data_idx] = new_unique_idx + for i in range(data.shape[0]): + idx = index_converter[inc_scan[i]] + counts[idx] += int32(1) + return unique_elements, indices, counts + + +def unique(data, is_sorted=True, return_counts=False): """ Find the unique elements of a tensor Parameters @@ -79,6 +112,8 @@ def unique(data, is_sorted=True): A 1-D tensor of integers sorted : bool Whether to sort the unique elements in ascending order before returning as output + return_counts : bool + Whether to return the array with count of each unique element Returns ------- output : relay.Expr @@ -87,13 +122,21 @@ def unique(data, is_sorted=True): A 1-D tensor containing the index of each data element in the output tensor num_unique : relay.Expr A 0-D tensor containing the number of unique elements in the input data tensor + counts (optional) : relay.Expr + A 1-D tensor containing the count of each unique element in the output Examples -------- .. 
code-block:: python - [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=False) + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=False, return_counts=False) + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + + [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=False, return_counts=True) output = [4, 5, 1, 2, 3, ?, ?, ?] indices = [0, 1, 2, 3, 4, 4, 0, 1] num_unique = [5] + counts = [2, 2, 1, 1, 2, ?, ?, ?] [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=True) output = [1, 2, 3, 4, 5, ?, ?, ?] @@ -107,12 +150,27 @@ def unique(data, is_sorted=True): inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0) num_unique_elements = _calc_num_unique(inc_scan) if is_sorted: - unique_elements, inverse_indices = _calc_unique_sorted(data, argsorted_indices, inc_scan) + if return_counts: + unique_elements, inverse_indices, counts = _calc_unique_sorted_with_counts( + data, argsorted_indices, inc_scan + ) + return [unique_elements, inverse_indices, num_unique_elements, counts] + else: + unique_elements, inverse_indices = _calc_unique_sorted( + data, argsorted_indices, inc_scan + ) + return [unique_elements, inverse_indices, num_unique_elements] else: first_occurence = _calc_first_occurence(argsorted_indices, inc_scan) argsorted_first_occurence = argsort(first_occurence, dtype="int32") index_converter = argsort(argsorted_first_occurence, dtype="int32") - unique_elements, inverse_indices = _calc_unique_unsorted( - data, argsorted_indices, inc_scan, index_converter - ) - return [unique_elements, inverse_indices, num_unique_elements] + if return_counts: + unique_elements, inverse_indices, counts = _calc_unique_unsorted_with_counts( + data, argsorted_indices, inc_scan, index_converter + ) + return [unique_elements, inverse_indices, num_unique_elements, counts] + else: + unique_elements, inverse_indices = _calc_unique_unsorted( + data, argsorted_indices, inc_scan, index_converter + ) + return [unique_elements, inverse_indices, num_unique_elements] diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 4e8617a59804..eae231fd8d06 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3792,13 +3792,18 @@ bool UniqueRel(const Array& types, int num_inputs, const Attrs& attrs, fields.push_back(TensorType(data->shape, data->dtype)); // unique fields.push_back(TensorType(data->shape, DataType::Int(32))); // indices fields.push_back(TensorType(Array{1}, DataType::Int(32))); // num_unique + const auto* param = attrs.as(); + if (param->return_counts) { + fields.push_back(TensorType(data->shape, DataType::Int(32))); // counts + } reporter->Assign(types[1], TupleType(Array(fields))); return true; } -Expr MakeUnique(Expr data, bool sorted) { +Expr MakeUnique(Expr data, bool sorted, bool return_counts) { auto attrs = make_object(); attrs->sorted = sorted; + attrs->return_counts = return_counts; static const Op& op = Op::Get("unique"); return Call(op, {data}, Attrs(attrs), {}); } diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index aa42b0fb84e4..e589171b0743 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2064,7 +2064,12 @@ def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=["llv pt_result = input_model(*input_data) # Verify the accuracy - if not isinstance(pt_result, 
torch.Tensor): + if isinstance(pt_result, tuple): + # handle multiple outputs + for i in range(len(pt_result)): + tvm_res = vm_res[i].asnumpy() + tvm.testing.assert_allclose(tvm_res, pt_result[i].numpy(), rtol=1e-5, atol=1e-5) + elif not isinstance(pt_result, torch.Tensor): tvm_res = vm_res.asnumpy().item() assert pt_result == tvm_res else: @@ -3654,6 +3659,23 @@ def test_fn(x, mask): verify_trace_model(test_fn, [x, mask], ["llvm", "cuda", "nvptx"]) +def test_unique(): + def test_fn(is_sorted, return_inverse, return_counts): + return lambda x: torch.unique(x, is_sorted, return_inverse, return_counts) + + in_data = torch.randint(0, 20, (10,), dtype=torch.int32) + targets = ["llvm"] + verify_trace_model(test_fn(True, True, True), [in_data], targets) + verify_trace_model(test_fn(True, False, True), [in_data], targets) + verify_trace_model(test_fn(True, True, False), [in_data], targets) + verify_trace_model(test_fn(True, False, True), [in_data], targets) + in_data = torch.randint(0, 20, (20,), dtype=torch.int64) + verify_trace_model(test_fn(True, True, True), [in_data], targets) + verify_trace_model(test_fn(True, False, True), [in_data], targets) + verify_trace_model(test_fn(True, True, False), [in_data], targets) + verify_trace_model(test_fn(True, False, True), [in_data], targets) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -3789,6 +3811,7 @@ def test_fn(x, mask): test_argsort() test_logical_and() test_masked_select() + test_unique() # Model tests test_resnet18() diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index b6cec3f8954b..4fd4f98c26e1 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -4845,8 +4845,6 @@ def lstm_cell(): def _test_unique(n, dtype, is_dyn): - """ One iteration of a Stridedslice """ - tf.reset_default_graph() np_data = np.random.randint(100, size=n).astype(dtype) with tf.Graph().as_default(): @@ -4870,5 +4868,41 @@ def test_forward_unique(): _test_unique(100, dtype, is_dyn) +####################################################################### +# Unique with counts +# ------------ + + +def _test_unique_with_counts(n, dtype, is_dyn): + tf.reset_default_graph() + np_data = np.random.randint(100, size=n).astype(dtype) + with tf.Graph().as_default(): + if is_dyn: + in_data = tf.placeholder(dtype, [n], name="in_data") + else: + in_data = tf.constant(np_data, dtype, name="in_data") + tf.unique_with_counts(in_data) + if is_dyn: + compare_tf_with_tvm( + np_data, + "in_data:0", + ["UniqueWithCounts:0", "UniqueWithCounts:1", "UniqueWithCounts:2"], + mode="vm", + ) + else: + compare_tf_with_tvm( + None, "", ["UniqueWithCounts:0", "UniqueWithCounts:1", "UniqueWithCounts:2"] + ) + + +def test_forward_unique_with_counts(): + """test UniqueWithCounts""" + + for dtype in ["int32", "int64"]: + for is_dyn in [False, True]: + _test_unique_with_counts(10, dtype, is_dyn) + _test_unique_with_counts(20, dtype, is_dyn) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index b551bd3e4afd..22d13b1151a0 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1467,12 +1467,12 @@ def calc_numpy_unique(data, is_sorted=False): counts = counts[order].astype("int32") return [uniq.astype(data.dtype), inverse.astype("int32"), counts, num_uniq] - def verify_unique(n, dtype, is_dyn=False, 
is_sorted=False): + def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): if is_dyn: x = relay.var("x", relay.TensorType([relay.Any()], dtype)) else: x = relay.var("x", relay.TensorType([n], dtype)) - outs = relay.unique(x, is_sorted) + outs = relay.unique(x, is_sorted, return_counts) outs = outs.astuple() func = relay.Function([x], outs) x_data = np.random.randint(50, size=n).astype(dtype) @@ -1494,11 +1494,14 @@ def verify_unique(n, dtype, is_dyn=False, is_sorted=False): tvm.testing.assert_allclose(tvm_res[0].asnumpy()[:num_unique], np_res[0], rtol=1e-5) # inverse_indices tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_res[1], rtol=1e-5) + # counts + if return_counts: + tvm.testing.assert_allclose(tvm_res[3].asnumpy()[:num_unique], np_res[2], rtol=1e-5) for dtype in ["int32", "int64"]: - for is_dyn in [True, False]: - verify_unique(1, dtype, is_dyn, True) - verify_unique(50, dtype, is_dyn, False) + for i in range(8): + is_dyn, is_sorted, return_counts = bool(i & 1), bool(i & 2), bool(i & 4) + verify_unique(10, dtype, is_dyn, is_sorted, return_counts) if __name__ == "__main__": diff --git a/tests/python/topi/python/test_topi_unique.py b/tests/python/topi/python/test_topi_unique.py index 4f959aa02356..841083d20b14 100644 --- a/tests/python/topi/python/test_topi_unique.py +++ b/tests/python/topi/python/test_topi_unique.py @@ -39,20 +39,45 @@ def calc_numpy_unique(data, is_sorted=False): return [uniq.astype(data.dtype), inverse.astype("int32"), counts, num_uniq] def check_unique(data, is_sorted=False): + # numpy reference + np_unique, np_indices, np_counts, np_num_unique = calc_numpy_unique(data, is_sorted) + num_unique = np_num_unique[0] + implementations = { - "generic": (lambda x: topi.unique(x, is_sorted), topi.generic.schedule_unique), + "generic": ( + lambda x, return_counts: topi.unique(x, is_sorted, return_counts), + topi.generic.schedule_unique, + ), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) tvm_data = tvm.nd.array(data, ctx=ctx) tvm_unique = tvm.nd.array(np.zeros(data.shape).astype(data.dtype), ctx=ctx) tvm_indices = tvm.nd.array(np.zeros(data.shape).astype("int32"), ctx=ctx) tvm_num_unique = tvm.nd.array(np.zeros([1]).astype("int32"), ctx=ctx) + + # without counts with tvm.target.Target(target): te_input = tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype)) - outs = fcompute(te_input) + outs = fcompute(te_input, False) s = fschedule(outs) func = tvm.build(s, [te_input, *outs]) func(tvm_data, tvm_unique, tvm_indices, tvm_num_unique) + + assert tvm_num_unique.asnumpy()[0] == np_num_unique + np.testing.assert_allclose( + tvm_unique.asnumpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5 + ) + np.testing.assert_allclose(tvm_indices.asnumpy(), np_indices, atol=1e-5, rtol=1e-5) + + # with counts + tvm_counts = tvm.nd.array(np.zeros(data.shape).astype("int32"), ctx=ctx) + with tvm.target.Target(target): + te_input = tvm.te.placeholder(shape=data.shape, dtype=str(data.dtype)) + outs = fcompute(te_input, True) + s = fschedule(outs) + func = tvm.build(s, [te_input, *outs]) + func(tvm_data, tvm_unique, tvm_indices, tvm_num_unique, tvm_counts) + np_unique, np_indices, _, np_num_unique = calc_numpy_unique(data, is_sorted) num_unique = np_num_unique[0] assert tvm_num_unique.asnumpy()[0] == np_num_unique @@ -60,10 +85,15 @@ def check_unique(data, is_sorted=False): tvm_unique.asnumpy()[:num_unique], np_unique, atol=1e-5, rtol=1e-5 ) np.testing.assert_allclose(tvm_indices.asnumpy(), np_indices, atol=1e-5, rtol=1e-5) + 
np.testing.assert_allclose( + tvm_counts.asnumpy()[:num_unique], np_counts, atol=1e-5, rtol=1e-5 + ) for in_dtype in ["int32", "int64"]: for is_sorted in [True, False]: - data = np.random.randint(0, 100, size=(100)).astype(in_dtype) + data = np.random.randint(0, 100, size=(1)).astype(in_dtype) + check_unique(data, is_sorted) + data = np.random.randint(0, 100, size=(50)).astype(in_dtype) check_unique(data, is_sorted) data = np.random.randint(0, 10, size=(100)).astype(in_dtype) check_unique(data, is_sorted) From 2c2f2c43c362ee09e92054ccb407b46f79f87910 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Sat, 20 Feb 2021 00:11:40 -0800 Subject: [PATCH 08/16] Fix pylint --- python/tvm/relay/op/transform.py | 7 +++---- python/tvm/topi/unique.py | 9 +++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index ae00240b5891..cc766bbdf9bd 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1489,12 +1489,12 @@ def unique(data, is_sorted=True, return_counts=False): Examples -------- .. code-block:: python - [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=False, return_counts=False) + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False) output = [4, 5, 1, 2, 3, ?, ?, ?] indices = [0, 1, 2, 3, 4, 4, 0, 1] num_unique = [5] - [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=False, return_counts=True) + [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True) output = [4, 5, 1, 2, 3, ?, ?, ?] indices = [0, 1, 2, 3, 4, 4, 0, 1] num_unique = [5] @@ -1507,5 +1507,4 @@ def unique(data, is_sorted=True, return_counts=False): """ if return_counts: return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4) - else: - return TupleWrapper(_make.unique(data, is_sorted, return_counts), 3) + return TupleWrapper(_make.unique(data, is_sorted, return_counts), 3) diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py index 031ca0f3309f..b88b1d09e53e 100644 --- a/python/tvm/topi/unique.py +++ b/python/tvm/topi/unique.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name, no-else-return """Unique operator""" from ..te import hybrid from .cumsum import cumsum @@ -127,18 +127,18 @@ def unique(data, is_sorted=True, return_counts=False): Examples -------- .. code-block:: python - [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=False, return_counts=False) + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False) output = [4, 5, 1, 2, 3, ?, ?, ?] indices = [0, 1, 2, 3, 4, 4, 0, 1] num_unique = [5] - [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=False, return_counts=True) + [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True) output = [4, 5, 1, 2, 3, ?, ?, ?] indices = [0, 1, 2, 3, 4, 4, 0, 1] num_unique = [5] counts = [2, 2, 1, 1, 2, ?, ?, ?] - [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=True) + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True) output = [1, 2, 3, 4, 5, ?, ?, ?] 
indices = [3, 4, 0, 1, 2, 2, 3, 4] num_unique = [5] @@ -149,6 +149,7 @@ def unique(data, is_sorted=True, return_counts=False): adjacent_diff = _calc_adjacent_diff(sorted_data) inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0) num_unique_elements = _calc_num_unique(inc_scan) + if is_sorted: if return_counts: unique_elements, inverse_indices, counts = _calc_unique_sorted_with_counts( From 8c74e19b25d87d053ec9c52fcddf2297da805018 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Mon, 22 Feb 2021 15:53:44 -0800 Subject: [PATCH 09/16] Patch pytorch frontend --- python/tvm/relay/frontend/pytorch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 88718a5e673c..fd30792b7625 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2167,6 +2167,9 @@ def is_floating_point(self, inputs, input_types): def unique(self, inputs, input_types): assert len(inputs) == 4 [data, is_sorted, return_inverse, return_counts] = inputs + if is_sorted == False: + logging.warning("TVM always assumes sorted=True for torch.unique") + is_sorted = True if return_counts: [unique, indices, num_uniq, counts] = _op.unique( data, is_sorted=is_sorted, return_counts=True From 1553d483bd69ba2260b4abc1742887bb53545b17 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Tue, 23 Feb 2021 07:24:08 +0000 Subject: [PATCH 10/16] Initial support of topi.cuda.unique --- python/tvm/relay/op/strategy/cuda.py | 12 + python/tvm/topi/cuda/__init__.py | 1 + python/tvm/topi/cuda/unique.py | 287 ++++++++++++++++++ python/tvm/topi/unique.py | 12 +- tests/python/frontend/pytorch/test_forward.py | 2 +- tests/python/relay/test_op_level3.py | 32 +- tests/python/topi/python/test_topi_unique.py | 20 +- 7 files changed, 340 insertions(+), 26 deletions(-) create mode 100644 python/tvm/topi/cuda/unique.py diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index cb4688c4889e..3c2969a3ab58 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1017,3 +1017,15 @@ def cumsum_strategy_cuda(attrs, inputs, out_type, target): name="cumsum.cuda", ) return strategy + + +@unique_strategy.register(["cuda", "gpu"]) +def unique_strategy_cuda(attrs, inputs, out_type, target): + """unique cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_unique(topi.cuda.unique), + wrap_topi_schedule(topi.cuda.schedule_scan), + name="unique.cuda", + ) + return strategy diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index bf3582c01d4f..560de1a48386 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -58,3 +58,4 @@ from . import tensorcore_alter_op from .argwhere import * from .scan import * +from .unique import * \ No newline at end of file diff --git a/python/tvm/topi/cuda/unique.py b/python/tvm/topi/cuda/unique.py new file mode 100644 index 000000000000..effee77666d3 --- /dev/null +++ b/python/tvm/topi/cuda/unique.py @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-else-return +"""Unique operator""" +from ...te import hybrid +from .scan import cumsum +from .sort import sort, argsort +from tvm import te +import tvm +from ..utils import ceil_div +from .nms import atomic_add + + +@hybrid.script +def _calc_adjacent_diff(data): + output = output_tensor(data.shape, "int32") + idx = allocate((1,), "int32", "local") + i_extent = min(data.shape[0], max_num_threads(False)) + j_extent = ceil_div(data.shape[0], i_extent) + for i in bind("threadIdx.x", i_extent): + for j in range(j_extent): + idx[0] = j * i_extent + i + if idx[0] == 0: + output[0] = int32(0) + elif idx[0] < data.shape[0]: + output[idx[0]] = int32(1) if data[idx[0]] != data[idx[0] - 1] else int32(0) + return output + + +@hybrid.script +def _calc_num_unique(data): + output = output_tensor((1,), "int32") + for i in bind("threadIdx.x", 1): + output[0] = data[data.shape[0] - 1] + int32(1) + return output + + +@hybrid.script +def _calc_unique_sorted(data, argsorted_indices, inc_scan): + unique_elements = output_tensor(data.shape, data.dtype) + indices = output_tensor(data.shape, "int32") + idx = allocate((1,), "int32", "local") + i_extent = min(data.shape[0], max_num_threads(False)) + j_extent = ceil_div(data.shape[0], i_extent) + for i in bind("threadIdx.x", i_extent): + for j in range(j_extent): + idx[0] = j * i_extent + i + if idx[0] < data.shape[0]: + indices[argsorted_indices[idx[0]]] = inc_scan[idx[0]] + if idx[0] == 0 or inc_scan[idx[0]] != inc_scan[idx[0] - 1]: + unique_elements[inc_scan[idx[0]]] = data[argsorted_indices[idx[0]]] + return unique_elements, indices + + +def _calc_counts_sorted_ir(inc_scan, counts): + ib = tvm.tir.ir_builder.create() + inc_scan_ptr = ib.buffer_ptr(inc_scan) + counts_ptr = ib.buffer_ptr(counts) + batch_size = inc_scan.shape[0] + max_threads = min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + counts_ptr[tid] = 0 + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + atomic_add_return = ib.allocate(counts.dtype, (1,), name="atomic_add_return", scope="local") + with ib.if_scope(tid < batch_size): + index = inc_scan_ptr[tid] + atomic_add_return[0] = tvm.tir.call_intrin( + counts.dtype, + "tir.atomic_add", + tvm.tir.call_intrin("handle", "tir.address_of", counts_ptr[index]), + 1, + ) + return ib.get() + + +@hybrid.script +def _calc_first_occurence(argsorted_indices, inc_scan): + first_occurence = output_tensor(argsorted_indices.shape, "int32") + idx = allocate((1,), "int32", "local") + i_extent = 
min(argsorted_indices.shape[0], max_num_threads(False)) + j_extent = ceil_div(argsorted_indices.shape[0], i_extent) + for i in bind("threadIdx.x", i_extent): + for j in range(j_extent): + idx[0] = j * i_extent + i + if idx[0] < argsorted_indices.shape[0]: + first_occurence[idx[0]] = argsorted_indices.shape[0] + for i in bind("threadIdx.x", i_extent): + for j in range(j_extent): + idx[0] = j * i_extent + i + if idx[0] < argsorted_indices.shape[0]: + if idx[0] == 0 or inc_scan[idx[0]] != inc_scan[idx[0] - 1]: + first_occurence[inc_scan[idx[0]]] = argsorted_indices[idx[0]] + return first_occurence + + +@hybrid.script +def _calc_unique_unsorted(data, argsorted_indices, inc_scan, index_converter): + unique_elements = output_tensor(data.shape, data.dtype) + indices = output_tensor(data.shape, "int32") + for i in parallel(data.shape[0]): + new_unique_idx = index_converter[inc_scan[i]] + new_data_idx = argsorted_indices[i] + indices[new_data_idx] = new_unique_idx + if i == 0 or inc_scan[i] != inc_scan[i - 1]: + unique_elements[new_unique_idx] = data[new_data_idx] + return unique_elements, indices + + +@hybrid.script +def _calc_unique_unsorted(data, argsorted_indices, inc_scan, index_converter): + unique_elements = output_tensor(data.shape, data.dtype) + indices = output_tensor(data.shape, "int32") + idx = allocate((1,), "int32", "local") + i_extent = min(data.shape[0], max_num_threads(False)) + j_extent = ceil_div(data.shape[0], i_extent) + for i in bind("threadIdx.x", i_extent): + for j in range(j_extent): + idx[0] = j * i_extent + i + if idx[0] < data.shape[0]: + indices[argsorted_indices[idx[0]]] = index_converter[inc_scan[idx[0]]] + if idx[0] == 0 or inc_scan[idx[0]] != inc_scan[idx[0] - 1]: + unique_elements[index_converter[inc_scan[idx[0]]]] = data[ + argsorted_indices[idx[0]] + ] + return unique_elements, indices + + +def _calc_counts_unsorted_ir(inc_scan, index_converter, counts): + ib = tvm.tir.ir_builder.create() + inc_scan_ptr = ib.buffer_ptr(inc_scan) + index_converter_ptr = ib.buffer_ptr(index_converter) + counts_ptr = ib.buffer_ptr(counts) + batch_size = inc_scan.shape[0] + max_threads = min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + counts_ptr[tid] = 0 + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + atomic_add_return = ib.allocate(counts.dtype, (1,), name="atomic_add_return", scope="local") + with ib.if_scope(tid < batch_size): + index = index_converter_ptr[inc_scan_ptr[tid]] + atomic_add_return[0] = tvm.tir.call_intrin( + counts.dtype, + "tir.atomic_add", + tvm.tir.call_intrin("handle", "tir.address_of", counts_ptr[index]), + 1, + ) + return ib.get() + + +def unique(data, is_sorted=True, return_counts=False): + """ + Find the unique elements of a tensor + Parameters + ---------- + data : relay.Expr + A 1-D tensor of integers + sorted : bool + Whether to sort the unique elements in ascending order before returning as output + return_counts : bool + Whether to return the 
array with count of each unique element + Returns + ------- + output : relay.Expr + A 1-D tensor containing the unique elements of the input data tensor + indices : relay.Expr + A 1-D tensor containing the index of each data element in the output tensor + num_unique : relay.Expr + A 0-D tensor containing the number of unique elements in the input data tensor + counts (optional) : relay.Expr + A 1-D tensor containing the count of each unique element in the output + Examples + -------- + .. code-block:: python + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False) + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + + [output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True) + output = [4, 5, 1, 2, 3, ?, ?, ?] + indices = [0, 1, 2, 3, 4, 4, 0, 1] + num_unique = [5] + counts = [2, 2, 1, 1, 2, ?, ?, ?] + + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True) + output = [1, 2, 3, 4, 5, ?, ?, ?] + indices = [3, 4, 0, 1, 2, 2, 3, 4] + num_unique = [5] + """ + + sorted_data = sort(data) + argsorted_indices = argsort(data, dtype="int32") + adjacent_diff = _calc_adjacent_diff(sorted_data) + inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0) + num_unique_elements = _calc_num_unique(inc_scan) + if is_sorted: + unique_elements, inverse_indices = _calc_unique_sorted(data, argsorted_indices, inc_scan) + if not return_counts: + return [unique_elements, inverse_indices, num_unique_elements] + else: + inc_scan_buf = tvm.tir.decl_buffer( + data.shape, "int32", "inc_scan_buf", data_alignment=8 + ) + counts_buf = tvm.tir.decl_buffer(data.shape, "int32", "counts_buf", data_alignment=8) + counts = te.extern( + [data.shape], + [inc_scan], + lambda ins, outs: _calc_counts_sorted_ir(ins[0], outs[0]), + dtype=["int32"], + in_buffers=[inc_scan_buf], + out_buffers=[counts_buf], + name="calc_counts_sorted", + tag="calc_counts_sorted_gpu", + ) + return [unique_elements, inverse_indices, num_unique_elements, counts] + else: + first_occurence = _calc_first_occurence(argsorted_indices, inc_scan) + argsorted_first_occurence = argsort(first_occurence, dtype="int32") + index_converter = argsort(argsorted_first_occurence, dtype="int32") + unique_elements, inverse_indices = _calc_unique_unsorted( + data, argsorted_indices, inc_scan, index_converter + ) + if not return_counts: + return [unique_elements, inverse_indices, num_unique_elements] + else: + inc_scan_buf = tvm.tir.decl_buffer( + data.shape, "int32", "inc_scan_buf", data_alignment=8 + ) + index_converter_buf = tvm.tir.decl_buffer( + data.shape, "int32", "index_converter_buf", data_alignment=8 + ) + counts_buf = tvm.tir.decl_buffer(data.shape, "int32", "counts_buf", data_alignment=8) + counts = te.extern( + [data.shape], + [inc_scan, index_converter], + lambda ins, outs: _calc_counts_unsorted_ir(ins[0], ins[1], outs[0]), + dtype=["int32"], + in_buffers=[inc_scan_buf, index_converter_buf], + out_buffers=[counts_buf], + name="calc_counts_unsorted", + tag="calc_counts_unsorted_gpu", + ) + return [unique_elements, inverse_indices, num_unique_elements, counts] diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py index b88b1d09e53e..86dd82a2bcef 100644 --- a/python/tvm/topi/unique.py +++ b/python/tvm/topi/unique.py @@ -43,7 +43,8 @@ def _calc_unique_sorted(data, argsorted_indices, inc_scan): indices = output_tensor(data.shape, "int32") for i in parallel(data.shape[0]): indices[argsorted_indices[i]] = inc_scan[i] - unique_elements[inc_scan[i]] 
= data[argsorted_indices[i]] + if i == 0 or inc_scan[i] != inc_scan[i - 1]: + unique_elements[inc_scan[i]] = data[argsorted_indices[i]] return unique_elements, indices @@ -56,7 +57,8 @@ def _calc_unique_sorted_with_counts(data, argsorted_indices, inc_scan): counts[i] = int32(0) for i in parallel(data.shape[0]): indices[argsorted_indices[i]] = inc_scan[i] - unique_elements[inc_scan[i]] = data[argsorted_indices[i]] + if i == 0 or inc_scan[i] != inc_scan[i - 1]: + unique_elements[inc_scan[i]] = data[argsorted_indices[i]] for i in range(data.shape[0]): counts[inc_scan[i]] += int32(1) return unique_elements, indices, counts @@ -80,8 +82,9 @@ def _calc_unique_unsorted(data, argsorted_indices, inc_scan, index_converter): for i in parallel(data.shape[0]): new_unique_idx = index_converter[inc_scan[i]] new_data_idx = argsorted_indices[i] - unique_elements[new_unique_idx] = data[new_data_idx] indices[new_data_idx] = new_unique_idx + if i == 0 or inc_scan[i] != inc_scan[i - 1]: + unique_elements[new_unique_idx] = data[new_data_idx] return unique_elements, indices @@ -95,8 +98,9 @@ def _calc_unique_unsorted_with_counts(data, argsorted_indices, inc_scan, index_c for i in parallel(data.shape[0]): new_unique_idx = index_converter[inc_scan[i]] new_data_idx = argsorted_indices[i] - unique_elements[new_unique_idx] = data[new_data_idx] indices[new_data_idx] = new_unique_idx + if i == 0 or inc_scan[i] != inc_scan[i - 1]: + unique_elements[new_unique_idx] = data[new_data_idx] for i in range(data.shape[0]): idx = index_converter[inc_scan[i]] counts[idx] += int32(1) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e589171b0743..0cf4839c6ebb 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3664,7 +3664,7 @@ def test_fn(is_sorted, return_inverse, return_counts): return lambda x: torch.unique(x, is_sorted, return_inverse, return_counts) in_data = torch.randint(0, 20, (10,), dtype=torch.int32) - targets = ["llvm"] + targets = ["llvm", "cuda", "nvptx"] verify_trace_model(test_fn(True, True, True), [in_data], targets) verify_trace_model(test_fn(True, False, True), [in_data], targets) verify_trace_model(test_fn(True, True, False), [in_data], targets) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 22d13b1151a0..0fecd368004c 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1482,21 +1482,23 @@ def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): else: backends = ["graph", "debug"] - target, ctx = "llvm", tvm.cpu() - for kind in backends: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) - tvm_res = intrp.evaluate()(x_data) - np_res = calc_numpy_unique(x_data, is_sorted) - num_unique = np_res[3][0] - assert num_unique == tvm_res[2].asnumpy()[0] - # unique - tvm.testing.assert_allclose(tvm_res[0].asnumpy()[:num_unique], np_res[0], rtol=1e-5) - # inverse_indices - tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_res[1], rtol=1e-5) - # counts - if return_counts: - tvm.testing.assert_allclose(tvm_res[3].asnumpy()[:num_unique], np_res[2], rtol=1e-5) + for target, ctx in tvm.testing.enabled_targets(): + for kind in backends: + if is_dyn and ctx.device_type == 2: # skip dynamic shape on GPU + continue + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + 
tvm_res = intrp.evaluate()(x_data) + np_res = calc_numpy_unique(x_data, is_sorted) + num_unique = np_res[3][0] + assert num_unique == tvm_res[2].asnumpy()[0] + # unique + tvm.testing.assert_allclose(tvm_res[0].asnumpy()[:num_unique], np_res[0], rtol=1e-5) + # inverse_indices + tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_res[1], rtol=1e-5) + # counts + if return_counts: + tvm.testing.assert_allclose(tvm_res[3].asnumpy()[:num_unique], np_res[2], rtol=1e-5) for dtype in ["int32", "int64"]: for i in range(8): diff --git a/tests/python/topi/python/test_topi_unique.py b/tests/python/topi/python/test_topi_unique.py index 841083d20b14..f35aed0a60ce 100644 --- a/tests/python/topi/python/test_topi_unique.py +++ b/tests/python/topi/python/test_topi_unique.py @@ -21,10 +21,8 @@ import tvm.topi.testing -def test_unique(): - target = "llvm" - ctx = tvm.cpu() - +@tvm.testing.parametrize_targets +def test_unique(ctx, target): def calc_numpy_unique(data, is_sorted=False): uniq, index, inverse, counts = np.unique( data, return_index=True, return_inverse=True, return_counts=True @@ -48,6 +46,14 @@ def check_unique(data, is_sorted=False): lambda x, return_counts: topi.unique(x, is_sorted, return_counts), topi.generic.schedule_unique, ), + "cuda": ( + lambda x, return_counts: topi.cuda.unique(x, is_sorted, return_counts), + topi.cuda.schedule_scan, + ), + "nvptx": ( + lambda x, return_counts: topi.cuda.unique(x, is_sorted, return_counts), + topi.cuda.schedule_scan, + ), } fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) tvm_data = tvm.nd.array(data, ctx=ctx) @@ -95,9 +101,11 @@ def check_unique(data, is_sorted=False): check_unique(data, is_sorted) data = np.random.randint(0, 100, size=(50)).astype(in_dtype) check_unique(data, is_sorted) - data = np.random.randint(0, 10, size=(100)).astype(in_dtype) + data = np.random.randint(0, 100, size=(10000)).astype(in_dtype) check_unique(data, is_sorted) if __name__ == "__main__": - test_unique() + test_unique(tvm.context("cpu"), tvm.target.Target("llvm")) + test_unique(tvm.context("cuda"), tvm.target.Target("cuda")) + test_unique(tvm.context("nvptx"), tvm.target.Target("nvptx")) From 14811bfaab6fa0797b960f4449c6bba1f5343ed4 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Tue, 23 Feb 2021 20:30:17 +0000 Subject: [PATCH 11/16] Refactor to use ir_builder directly --- python/tvm/relay/frontend/pytorch.py | 2 +- python/tvm/relay/op/_transform.py | 2 +- python/tvm/topi/cuda/__init__.py | 2 +- python/tvm/topi/cuda/unique.py | 287 ++++++++++++++++++--------- tests/python/relay/test_op_level3.py | 6 +- 5 files changed, 198 insertions(+), 101 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index fd30792b7625..6be1b4648fb5 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2167,7 +2167,7 @@ def is_floating_point(self, inputs, input_types): def unique(self, inputs, input_types): assert len(inputs) == 4 [data, is_sorted, return_inverse, return_counts] = inputs - if is_sorted == False: + if not is_sorted: logging.warning("TVM always assumes sorted=True for torch.unique") is_sorted = True if return_counts: diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 5e6a66a47dae..e9cf3d83eaeb 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -145,7 +145,7 @@ def compute_cumsum(attrs, inputs, output_type): @_reg.register_compute("unique") def compute_unique(attrs, inputs, output_type): - 
"""Compute definition of cumsum""" + """Compute definition of unique""" return topi.unique(inputs[0], attrs.sorted, attrs.return_counts) diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index 560de1a48386..df75c676fad3 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -58,4 +58,4 @@ from . import tensorcore_alter_op from .argwhere import * from .scan import * -from .unique import * \ No newline at end of file +from .unique import * diff --git a/python/tvm/topi/cuda/unique.py b/python/tvm/topi/cuda/unique.py index effee77666d3..79a7b939560b 100644 --- a/python/tvm/topi/cuda/unique.py +++ b/python/tvm/topi/cuda/unique.py @@ -16,62 +16,83 @@ # under the License. # pylint: disable=invalid-name, no-else-return """Unique operator""" +from tvm import te, tir +import tvm + from ...te import hybrid from .scan import cumsum from .sort import sort, argsort -from tvm import te -import tvm from ..utils import ceil_div -from .nms import atomic_add -@hybrid.script -def _calc_adjacent_diff(data): - output = output_tensor(data.shape, "int32") - idx = allocate((1,), "int32", "local") - i_extent = min(data.shape[0], max_num_threads(False)) - j_extent = ceil_div(data.shape[0], i_extent) - for i in bind("threadIdx.x", i_extent): - for j in range(j_extent): - idx[0] = j * i_extent + i - if idx[0] == 0: - output[0] = int32(0) - elif idx[0] < data.shape[0]: - output[idx[0]] = int32(1) if data[idx[0]] != data[idx[0] - 1] else int32(0) - return output +def _calc_adjacent_diff_ir(data, adjacent_diff): + ib = tvm.tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + adjacent_diff_ptr = ib.buffer_ptr(adjacent_diff) + batch_size = data.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + with ib.if_scope(tid == 0): + adjacent_diff_ptr[tid] = 0 + with ib.else_scope(): + with ib.if_scope(data_ptr[tid] != data_ptr[tid - 1]): + adjacent_diff_ptr[tid] = 1 + with ib.else_scope(): + adjacent_diff_ptr[tid] = 0 + return ib.get() @hybrid.script def _calc_num_unique(data): output = output_tensor((1,), "int32") for i in bind("threadIdx.x", 1): - output[0] = data[data.shape[0] - 1] + int32(1) + output[i] = data[data.shape[0] - 1] + int32(1) return output -@hybrid.script -def _calc_unique_sorted(data, argsorted_indices, inc_scan): - unique_elements = output_tensor(data.shape, data.dtype) - indices = output_tensor(data.shape, "int32") - idx = allocate((1,), "int32", "local") - i_extent = min(data.shape[0], max_num_threads(False)) - j_extent = ceil_div(data.shape[0], i_extent) - for i in bind("threadIdx.x", i_extent): - for j in range(j_extent): - idx[0] = j * i_extent + i - if idx[0] < data.shape[0]: - indices[argsorted_indices[idx[0]]] = inc_scan[idx[0]] - if idx[0] == 0 or inc_scan[idx[0]] != inc_scan[idx[0] - 1]: - unique_elements[inc_scan[idx[0]]] = data[argsorted_indices[idx[0]]] - return unique_elements, indices +def _calc_unique_sorted_ir(data, argsorted_indices, inc_scan, unique_elements, indices): + ib = tvm.tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) + inc_scan_ptr = ib.buffer_ptr(inc_scan) + 
unique_elements_ptr = ib.buffer_ptr(unique_elements) + indices_ptr = ib.buffer_ptr(indices) + + batch_size = data.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + indices_ptr[argsorted_indices_ptr[tid]] = inc_scan_ptr[tid] + with ib.if_scope(tid == 0): + unique_elements_ptr[inc_scan_ptr[tid]] = data_ptr[argsorted_indices_ptr[tid]] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): + unique_elements_ptr[inc_scan_ptr[tid]] = data_ptr[argsorted_indices_ptr[tid]] + return ib.get() def _calc_counts_sorted_ir(inc_scan, counts): ib = tvm.tir.ir_builder.create() inc_scan_ptr = ib.buffer_ptr(inc_scan) counts_ptr = ib.buffer_ptr(counts) + batch_size = inc_scan.shape[0] - max_threads = min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): nthread_tx = max_threads nthread_bx = ceil_div(batch_size, max_threads) @@ -102,56 +123,73 @@ def _calc_counts_sorted_ir(inc_scan, counts): return ib.get() -@hybrid.script -def _calc_first_occurence(argsorted_indices, inc_scan): - first_occurence = output_tensor(argsorted_indices.shape, "int32") - idx = allocate((1,), "int32", "local") - i_extent = min(argsorted_indices.shape[0], max_num_threads(False)) - j_extent = ceil_div(argsorted_indices.shape[0], i_extent) - for i in bind("threadIdx.x", i_extent): - for j in range(j_extent): - idx[0] = j * i_extent + i - if idx[0] < argsorted_indices.shape[0]: - first_occurence[idx[0]] = argsorted_indices.shape[0] - for i in bind("threadIdx.x", i_extent): - for j in range(j_extent): - idx[0] = j * i_extent + i - if idx[0] < argsorted_indices.shape[0]: - if idx[0] == 0 or inc_scan[idx[0]] != inc_scan[idx[0] - 1]: - first_occurence[inc_scan[idx[0]]] = argsorted_indices[idx[0]] - return first_occurence - +def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence): + ib = tvm.tir.ir_builder.create() + argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) + inc_scan_ptr = ib.buffer_ptr(inc_scan) + first_occurence_ptr = ib.buffer_ptr(first_occurence) + batch_size = argsorted_indices.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + first_occurence_ptr[tid] = batch_size + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + with ib.if_scope(tid == 0): + first_occurence_ptr[inc_scan_ptr[tid]] = argsorted_indices_ptr[tid] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != 
inc_scan_ptr[tid - 1]): + first_occurence_ptr[inc_scan_ptr[tid]] = argsorted_indices_ptr[tid] + return ib.get() -@hybrid.script -def _calc_unique_unsorted(data, argsorted_indices, inc_scan, index_converter): - unique_elements = output_tensor(data.shape, data.dtype) - indices = output_tensor(data.shape, "int32") - for i in parallel(data.shape[0]): - new_unique_idx = index_converter[inc_scan[i]] - new_data_idx = argsorted_indices[i] - indices[new_data_idx] = new_unique_idx - if i == 0 or inc_scan[i] != inc_scan[i - 1]: - unique_elements[new_unique_idx] = data[new_data_idx] - return unique_elements, indices +def _calc_unique_unsorted_ir( + data, argsorted_indices, inc_scan, index_converter, unique_elements, indices +): + ib = tvm.tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) + inc_scan_ptr = ib.buffer_ptr(inc_scan) + index_converter_ptr = ib.buffer_ptr(index_converter) + unique_elements_ptr = ib.buffer_ptr(unique_elements) + indices_ptr = ib.buffer_ptr(indices) -@hybrid.script -def _calc_unique_unsorted(data, argsorted_indices, inc_scan, index_converter): - unique_elements = output_tensor(data.shape, data.dtype) - indices = output_tensor(data.shape, "int32") - idx = allocate((1,), "int32", "local") - i_extent = min(data.shape[0], max_num_threads(False)) - j_extent = ceil_div(data.shape[0], i_extent) - for i in bind("threadIdx.x", i_extent): - for j in range(j_extent): - idx[0] = j * i_extent + i - if idx[0] < data.shape[0]: - indices[argsorted_indices[idx[0]]] = index_converter[inc_scan[idx[0]]] - if idx[0] == 0 or inc_scan[idx[0]] != inc_scan[idx[0] - 1]: - unique_elements[index_converter[inc_scan[idx[0]]]] = data[ - argsorted_indices[idx[0]] + batch_size = data.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + indices_ptr[argsorted_indices_ptr[tid]] = index_converter_ptr[inc_scan_ptr[tid]] + with ib.if_scope(tid == 0): + unique_elements_ptr[index_converter_ptr[inc_scan_ptr[tid]]] = data_ptr[ + argsorted_indices_ptr[tid] + ] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): + unique_elements_ptr[index_converter_ptr[inc_scan_ptr[tid]]] = data_ptr[ + argsorted_indices_ptr[tid] ] - return unique_elements, indices + return ib.get() def _calc_counts_unsorted_ir(inc_scan, index_converter, counts): @@ -159,8 +197,9 @@ def _calc_counts_unsorted_ir(inc_scan, index_converter, counts): inc_scan_ptr = ib.buffer_ptr(inc_scan) index_converter_ptr = ib.buffer_ptr(index_converter) counts_ptr = ib.buffer_ptr(counts) + batch_size = inc_scan.shape[0] - max_threads = min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): nthread_tx = max_threads nthread_bx = ceil_div(batch_size, max_threads) @@ -234,17 +273,55 @@ def unique(data, is_sorted=True, return_counts=False): sorted_data = sort(data) argsorted_indices = argsort(data, dtype="int32") - adjacent_diff = _calc_adjacent_diff(sorted_data) + # calculate adjacent difference + sorted_data_buf = tvm.tir.decl_buffer( + data.shape, 
data.dtype, "sorted_data_buf", data_alignment=8 + ) + adjacent_diff_buf = tvm.tir.decl_buffer( + data.shape, "int32", "adjacent_diff_buf", data_alignment=8 + ) + adjacent_diff = te.extern( + [data.shape], + [sorted_data], + lambda ins, outs: _calc_adjacent_diff_ir(ins[0], outs[0]), + dtype=["int32"], + in_buffers=[sorted_data_buf], + out_buffers=[adjacent_diff_buf], + name="_calc_adjacent_diff", + tag="_calc_adjacent_diff_gpu", + ) + # calculate inclusive scan inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0) + # calculate number of unique elements num_unique_elements = _calc_num_unique(inc_scan) + # declare buffers + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + argsorted_indices_buf = tvm.tir.decl_buffer( + data.shape, "int32", "argsorted_indices_buf", data_alignment=8 + ) + inc_scan_buf = tvm.tir.decl_buffer(data.shape, "int32", "inc_scan_buf", data_alignment=8) + unique_elements_buf = tvm.tir.decl_buffer( + data.shape, data.dtype, "unique_elements_buf", data_alignment=8 + ) + inverse_indices_buf = tvm.tir.decl_buffer( + data.shape, "int32", "inverse_indices_buf", data_alignment=8 + ) if is_sorted: - unique_elements, inverse_indices = _calc_unique_sorted(data, argsorted_indices, inc_scan) + # calculate unique elements and inverse indices + unique_elements, inverse_indices = te.extern( + [data.shape, data.shape], + [data, argsorted_indices, inc_scan], + lambda ins, outs: _calc_unique_sorted_ir(*ins, *outs), + dtype=[data.dtype, "int32"], + in_buffers=[data_buf, argsorted_indices_buf, inc_scan_buf], + out_buffers=[unique_elements_buf, inverse_indices_buf], + name="_calc_unique_sorted", + tag="_calc_unique_sorted_gpu", + ) if not return_counts: return [unique_elements, inverse_indices, num_unique_elements] else: - inc_scan_buf = tvm.tir.decl_buffer( - data.shape, "int32", "inc_scan_buf", data_alignment=8 - ) + # calculate counts of unique elements counts_buf = tvm.tir.decl_buffer(data.shape, "int32", "counts_buf", data_alignment=8) counts = te.extern( [data.shape], @@ -258,21 +335,41 @@ def unique(data, is_sorted=True, return_counts=False): ) return [unique_elements, inverse_indices, num_unique_elements, counts] else: - first_occurence = _calc_first_occurence(argsorted_indices, inc_scan) + # calculate first occurence + first_occurence_buf = tvm.tir.decl_buffer( + data.shape, "int32", "first_occurence_buf", data_alignment=8 + ) + first_occurence = te.extern( + [data.shape], + [argsorted_indices, inc_scan], + lambda ins, outs: _calc_first_occurence_ir(ins[0], ins[1], outs[0]), + dtype=["int32"], + in_buffers=[argsorted_indices_buf, inc_scan_buf], + out_buffers=[first_occurence_buf], + name="_calc_first_occurence", + tag="_calc_first_occurence_gpu", + ) + # calculate index converter by sorting unique elements by their first occurence argsorted_first_occurence = argsort(first_occurence, dtype="int32") index_converter = argsort(argsorted_first_occurence, dtype="int32") - unique_elements, inverse_indices = _calc_unique_unsorted( - data, argsorted_indices, inc_scan, index_converter + # calculate unique elements and inverse indices + index_converter_buf = tvm.tir.decl_buffer( + data.shape, "int32", "index_converter_buf", data_alignment=8 + ) + unique_elements, inverse_indices = te.extern( + [data.shape, data.shape], + [data, argsorted_indices, inc_scan, index_converter], + lambda ins, outs: _calc_unique_unsorted_ir(*ins, *outs), + dtype=[data.dtype, "int32"], + in_buffers=[data_buf, argsorted_indices_buf, inc_scan_buf, index_converter_buf], + 
out_buffers=[unique_elements_buf, inverse_indices_buf], + name="_calc_unique_unsorted", + tag="_calc_unique_unsorted_gpu", ) if not return_counts: return [unique_elements, inverse_indices, num_unique_elements] else: - inc_scan_buf = tvm.tir.decl_buffer( - data.shape, "int32", "inc_scan_buf", data_alignment=8 - ) - index_converter_buf = tvm.tir.decl_buffer( - data.shape, "int32", "index_converter_buf", data_alignment=8 - ) + # calculate counts of unique elements counts_buf = tvm.tir.decl_buffer(data.shape, "int32", "counts_buf", data_alignment=8) counts = te.extern( [data.shape], diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 0fecd368004c..ee55b532218d 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1484,8 +1484,6 @@ def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): for target, ctx in tvm.testing.enabled_targets(): for kind in backends: - if is_dyn and ctx.device_type == 2: # skip dynamic shape on GPU - continue mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) tvm_res = intrp.evaluate()(x_data) @@ -1498,7 +1496,9 @@ def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): tvm.testing.assert_allclose(tvm_res[1].asnumpy(), np_res[1], rtol=1e-5) # counts if return_counts: - tvm.testing.assert_allclose(tvm_res[3].asnumpy()[:num_unique], np_res[2], rtol=1e-5) + tvm.testing.assert_allclose( + tvm_res[3].asnumpy()[:num_unique], np_res[2], rtol=1e-5 + ) for dtype in ["int32", "int64"]: for i in range(8): From b519bc4ff1fdb1e5266de8c62fd5e7cfd9b1dcbf Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Wed, 24 Feb 2021 21:21:17 +0000 Subject: [PATCH 12/16] Modularize adjacent difference --- python/tvm/topi/cuda/unique.py | 83 ++++++++++++++++++++++++---------- 1 file changed, 59 insertions(+), 24 deletions(-) diff --git a/python/tvm/topi/cuda/unique.py b/python/tvm/topi/cuda/unique.py index 79a7b939560b..250d8b9e8c58 100644 --- a/python/tvm/topi/cuda/unique.py +++ b/python/tvm/topi/cuda/unique.py @@ -25,10 +25,25 @@ from ..utils import ceil_div -def _calc_adjacent_diff_ir(data, adjacent_diff): +def _calc_adjacent_diff_ir(data, output, binop=tir.Sub): + """Low level IR to calculate adjacent difference in an 1-D array + + Parameters + ---------- + data : Buffer + Input 1-D Buffer. + + output: Buffer + A buffer to store adjacent difference, of the same shape as data. The adjacent difference is defined as: + output[0] = 0, output[i] = binop(data[i], data[i-1]) where i > 0 and i < len(data). + + binop: function, optional + A binary associative op to use for calculating adjacent difference. The function takes two TIR expressions + and produce a new TIR expression. By default it uses tvm.tir.Sub to compute the adjacent difference. 
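    For example (illustrative only, with binop=tir.NE as the unique compute in this
    file passes, and an int32 output buffer; output[0] is always 0):

    .. code-block:: python

        data   = [1, 1, 2, 4, 4, 4, 7]
        output = [0, 0, 1, 1, 0, 0, 1]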
+ """ ib = tvm.tir.ir_builder.create() data_ptr = ib.buffer_ptr(data) - adjacent_diff_ptr = ib.buffer_ptr(adjacent_diff) + output_ptr = ib.buffer_ptr(output) batch_size = data.shape[0] max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) with ib.new_scope(): @@ -41,17 +56,52 @@ def _calc_adjacent_diff_ir(data, adjacent_diff): tid = bx * max_threads + tx with ib.if_scope(tid < batch_size): with ib.if_scope(tid == 0): - adjacent_diff_ptr[tid] = 0 + output_ptr[tid] = 0 with ib.else_scope(): - with ib.if_scope(data_ptr[tid] != data_ptr[tid - 1]): - adjacent_diff_ptr[tid] = 1 - with ib.else_scope(): - adjacent_diff_ptr[tid] = 0 + output_ptr[tid] = tir.Cast(output.dtype, binop(data_ptr[tid], data_ptr[tid - 1])) return ib.get() +def _calc_adjacent_diff(data, out_dtype="int32", binop=tir.Sub): + """Function calculate adjacent difference in an 1-D array + + Parameters + ---------- + data : tvm.te.Tensor + Input 1-D tensor. + + output_dtype : str + The output tensor data type. + + binop: function, optional + A binary associative op to use for calculating difference. The function takes two TIR expressions + and produce a new TIR expression. By default it uses tvm.tir.Sub to compute the adjacent difference. + + Returns + ------- + output : tvm.te.Tensor + 1-D tensor storing the adjacent difference of the input tensor. The adjacent difference is defined as: + output[0] = 0, output[i] = binop(data[i], data[i-1]) where i > 0 and i < len(data). + """ + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8) + output_buf = tvm.tir.decl_buffer(data.shape, out_dtype, "output_buf", data_alignment=8) + output = te.extern( + [data.shape], + [data], + lambda ins, outs: _calc_adjacent_diff_ir(ins[0], outs[0], binop=binop), + dtype=[out_dtype], + in_buffers=[data_buf], + out_buffers=[output_buf], + name="_calc_adjacent_diff", + tag="_calc_adjacent_diff_gpu", + ) + return output + + @hybrid.script def _calc_num_unique(data): + """Function to get the last element of a 1-D tensor + """ output = output_tensor((1,), "int32") for i in bind("threadIdx.x", 1): output[i] = data[data.shape[0] - 1] + int32(1) @@ -274,25 +324,10 @@ def unique(data, is_sorted=True, return_counts=False): sorted_data = sort(data) argsorted_indices = argsort(data, dtype="int32") # calculate adjacent difference - sorted_data_buf = tvm.tir.decl_buffer( - data.shape, data.dtype, "sorted_data_buf", data_alignment=8 - ) - adjacent_diff_buf = tvm.tir.decl_buffer( - data.shape, "int32", "adjacent_diff_buf", data_alignment=8 - ) - adjacent_diff = te.extern( - [data.shape], - [sorted_data], - lambda ins, outs: _calc_adjacent_diff_ir(ins[0], outs[0]), - dtype=["int32"], - in_buffers=[sorted_data_buf], - out_buffers=[adjacent_diff_buf], - name="_calc_adjacent_diff", - tag="_calc_adjacent_diff_gpu", - ) + adjacent_diff = _calc_adjacent_diff(sorted_data, out_dtype="int32", binop=tir.NE) # calculate inclusive scan inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0) - # calculate number of unique elements + # calculate total number of unique elements num_unique_elements = _calc_num_unique(inc_scan) # declare buffers data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) From 822e28e4c2e93fe21e42734c48b59d29053bea00 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Thu, 25 Feb 2021 02:27:52 +0000 Subject: [PATCH 13/16] Refactor to simplify --- python/tvm/relay/op/transform.py | 27 +- python/tvm/topi/cuda/unique.py | 397 +++++++++---------- 
python/tvm/topi/unique.py | 306 +++++++++----- tests/python/topi/python/test_topi_unique.py | 10 +- 4 files changed, 418 insertions(+), 322 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index cc766bbdf9bd..c0a0d31478ef 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1467,25 +1467,34 @@ def cumsum(data, axis=None, dtype=None, exclusive=None): def unique(data, is_sorted=True, return_counts=False): """ - Find the unique elements of a tensor + Find the unique elements of a 1-D tensor. Please note `output` and `counts` are all padded to + have the same length of `data` and element with index >= num_unique[0] has undefined value. + Parameters ---------- data : relay.Expr - A 1-D tensor of integers + A 1-D tensor of integers. + sorted : bool - Whether to sort the unique elements in ascending order before returning as output + Whether to sort the unique elements in ascending order before returning as output. + return_counts : bool - Whether to return the array with count of each unique element + Whether to return the count of each unique element. + Returns ------- output : relay.Expr - A 1-D tensor containing the unique elements of the input data tensor + A 1-D tensor containing the unique elements of the input data tensor. + indices : relay.Expr - A 1-D tensor containing the index of each data element in the output tensor + A 1-D tensor containing the index of each data element in the output tensor. + num_unique : relay.Expr - A 0-D tensor containing the number of unique elements in the input data tensor + A 1-D tensor with size=1 containing the number of unique elements in the input data tensor. + counts (optional) : relay.Expr - A 1-D tensor containing the count of each unique element in the output + A 1-D tensor containing the count of each unique element in the output. + Examples -------- .. code-block:: python @@ -1500,7 +1509,7 @@ def unique(data, is_sorted=True, return_counts=False): num_unique = [5] counts = [2, 2, 1, 1, 2, ?, ?, ?] - [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], sorted=True) + [output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True) output = [1, 2, 3, 4, 5, ?, ?, ?] indices = [3, 4, 0, 1, 2, 2, 3, 4] num_unique = [5] diff --git a/python/tvm/topi/cuda/unique.py b/python/tvm/topi/cuda/unique.py index 250d8b9e8c58..b57176162e05 100644 --- a/python/tvm/topi/cuda/unique.py +++ b/python/tvm/topi/cuda/unique.py @@ -14,11 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, no-else-return +# pylint: disable=invalid-name """Unique operator""" -from tvm import te, tir import tvm - +from tvm import te, tir from ...te import hybrid from .scan import cumsum from .sort import sort, argsort @@ -26,7 +25,7 @@ def _calc_adjacent_diff_ir(data, output, binop=tir.Sub): - """Low level IR to calculate adjacent difference in an 1-D array + """Low level IR to calculate adjacent difference in an 1-D array. Parameters ---------- @@ -34,14 +33,16 @@ def _calc_adjacent_diff_ir(data, output, binop=tir.Sub): Input 1-D Buffer. output: Buffer - A buffer to store adjacent difference, of the same shape as data. The adjacent difference is defined as: - output[0] = 0, output[i] = binop(data[i], data[i-1]) where i > 0 and i < len(data). + A buffer to store adjacent difference, of the same shape as data. 
The adjacent difference + is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1]) + where i > 0 and i < len(data). binop: function, optional - A binary associative op to use for calculating adjacent difference. The function takes two TIR expressions - and produce a new TIR expression. By default it uses tvm.tir.Sub to compute the adjacent difference. + A binary associative op to use for calculating adjacent difference. The function takes two + TIR expressions and produce a new TIR expression. By default it uses tvm.tir.Sub to + compute the adjacent difference. """ - ib = tvm.tir.ir_builder.create() + ib = tir.ir_builder.create() data_ptr = ib.buffer_ptr(data) output_ptr = ib.buffer_ptr(output) batch_size = data.shape[0] @@ -63,7 +64,7 @@ def _calc_adjacent_diff_ir(data, output, binop=tir.Sub): def _calc_adjacent_diff(data, out_dtype="int32", binop=tir.Sub): - """Function calculate adjacent difference in an 1-D array + """Function calculate adjacent difference in an 1-D array. Parameters ---------- @@ -74,18 +75,20 @@ def _calc_adjacent_diff(data, out_dtype="int32", binop=tir.Sub): The output tensor data type. binop: function, optional - A binary associative op to use for calculating difference. The function takes two TIR expressions - and produce a new TIR expression. By default it uses tvm.tir.Sub to compute the adjacent difference. + A binary associative op to use for calculating difference. The function takes two + TIR expressions and produce a new TIR expression. By default it uses tvm.tir.Sub to + compute the adjacent difference. Returns ------- output : tvm.te.Tensor - 1-D tensor storing the adjacent difference of the input tensor. The adjacent difference is defined as: - output[0] = 0, output[i] = binop(data[i], data[i-1]) where i > 0 and i < len(data). + 1-D tensor storing the adjacent difference of the input tensor. The adjacent difference + is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1]) + where i > 0 and i < len(data). """ - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8) - output_buf = tvm.tir.decl_buffer(data.shape, out_dtype, "output_buf", data_alignment=8) - output = te.extern( + data_buf = tir.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8) + output_buf = tir.decl_buffer(data.shape, out_dtype, "output_buf", data_alignment=8) + return te.extern( [data.shape], [data], lambda ins, outs: _calc_adjacent_diff_ir(ins[0], outs[0], binop=binop), @@ -95,29 +98,67 @@ def _calc_adjacent_diff(data, out_dtype="int32", binop=tir.Sub): name="_calc_adjacent_diff", tag="_calc_adjacent_diff_gpu", ) - return output @hybrid.script -def _calc_num_unique(data): - """Function to get the last element of a 1-D tensor - """ +def _calc_num_unique(inc_scan): + """Helper function to get the number of unique elements fron inc_scan tensor""" output = output_tensor((1,), "int32") for i in bind("threadIdx.x", 1): - output[i] = data[data.shape[0] - 1] + int32(1) + output[i] = inc_scan[inc_scan.shape[0] - 1] + int32(1) return output -def _calc_unique_sorted_ir(data, argsorted_indices, inc_scan, unique_elements, indices): - ib = tvm.tir.ir_builder.create() +def _calc_unique_ir( + data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts +): + """Low level IR to calculate unique elements, inverse indices, and counts (optional) of + unique elements of 1-D array. + + Parameters + ---------- + data : Buffer + Input 1-D Buffer. 
+ + argsorted_indices : Buffer + A buffer that stores the argsorted indices of the input data. + + inc_scan : Buffer + A buffer that stores the inclusive scan of the binary tir.NE adjacent difference + of the sorted data. + + index_converter (optional) : Buffer + An optional index converter that transforms the unique element index + such that new_idx = index_converter[old_idx]. + + unique_elements : Buffer + A buffer that stores the unique elements. + + indices : Buffer + A buffer that stores the the index of each input data element in the unique element array. + + counts (optional) : Buffer + A buffer that stores the count of each unique element. + """ + ib = tir.ir_builder.create() data_ptr = ib.buffer_ptr(data) argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) inc_scan_ptr = ib.buffer_ptr(inc_scan) unique_elements_ptr = ib.buffer_ptr(unique_elements) indices_ptr = ib.buffer_ptr(indices) + index_converter_ptr = None + if isinstance(index_converter, tir.Buffer): + index_converter_ptr = ib.buffer_ptr(index_converter) + + if isinstance(counts, tir.Buffer): + counts_ptr = ib.buffer_ptr(counts) + arange_ptr = ib.allocate(counts_ptr.dtype, counts.shape, name="arange_buf", scope="global") + batch_size = data.shape[0] max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + + # calculate unique elements and inverse indices with ib.new_scope(): nthread_tx = max_threads nthread_bx = ceil_div(batch_size, max_threads) @@ -127,54 +168,70 @@ def _calc_unique_sorted_ir(data, argsorted_indices, inc_scan, unique_elements, i ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx with ib.if_scope(tid < batch_size): - indices_ptr[argsorted_indices_ptr[tid]] = inc_scan_ptr[tid] + data_idx = argsorted_indices_ptr[tid] + unique_idx = ( + inc_scan_ptr[tid] + if not index_converter_ptr + else index_converter_ptr[inc_scan_ptr[tid]] + ) + indices_ptr[data_idx] = unique_idx with ib.if_scope(tid == 0): - unique_elements_ptr[inc_scan_ptr[tid]] = data_ptr[argsorted_indices_ptr[tid]] + unique_elements_ptr[unique_idx] = data_ptr[data_idx] with ib.else_scope(): with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): - unique_elements_ptr[inc_scan_ptr[tid]] = data_ptr[argsorted_indices_ptr[tid]] + unique_elements_ptr[unique_idx] = data_ptr[data_idx] + + # if need to return counts + if isinstance(counts, tir.Buffer): + num_unique = inc_scan_ptr[inc_scan.shape[0] - 1] + 1 + num_elements = data.shape[0] + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + with ib.if_scope(tid == 0): + arange_ptr[num_unique - 1] = num_elements + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): + arange_ptr[inc_scan_ptr[tid] - 1] = tid + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < num_unique): + unique_idx = tid if not index_converter_ptr else index_converter_ptr[tid] + with ib.if_scope(tid == 0): + counts_ptr[unique_idx] = arange_ptr[tid] + with ib.else_scope(): + 
counts_ptr[unique_idx] = arange_ptr[tid] - arange_ptr[tid - 1] return ib.get() -def _calc_counts_sorted_ir(inc_scan, counts): - ib = tvm.tir.ir_builder.create() - inc_scan_ptr = ib.buffer_ptr(inc_scan) - counts_ptr = ib.buffer_ptr(counts) +def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence): + """Low level IR to calculate the first occurence of each unique element in the input data. - batch_size = inc_scan.shape[0] - max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = ceil_div(batch_size, max_threads) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - with ib.if_scope(tid < batch_size): - counts_ptr[tid] = 0 - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = ceil_div(batch_size, max_threads) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - atomic_add_return = ib.allocate(counts.dtype, (1,), name="atomic_add_return", scope="local") - with ib.if_scope(tid < batch_size): - index = inc_scan_ptr[tid] - atomic_add_return[0] = tvm.tir.call_intrin( - counts.dtype, - "tir.atomic_add", - tvm.tir.call_intrin("handle", "tir.address_of", counts_ptr[index]), - 1, - ) - return ib.get() + Parameters + ---------- + argsorted_indices : Buffer + A buffer that stores the argsorted indices of the input data. + inc_scan : Buffer + A buffer that stores the inclusive scan of the binary tir.NE adjacent difference + of the sorted data. -def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence): - ib = tvm.tir.ir_builder.create() + first_occurence : Buffer + A buffer that stores the first occurence of each unique element in the input data. 
+ """ + ib = tir.ir_builder.create() argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) inc_scan_ptr = ib.buffer_ptr(inc_scan) first_occurence_ptr = ib.buffer_ptr(first_occurence) @@ -207,100 +264,36 @@ def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence): return ib.get() -def _calc_unique_unsorted_ir( - data, argsorted_indices, inc_scan, index_converter, unique_elements, indices -): - ib = tvm.tir.ir_builder.create() - data_ptr = ib.buffer_ptr(data) - argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) - inc_scan_ptr = ib.buffer_ptr(inc_scan) - index_converter_ptr = ib.buffer_ptr(index_converter) - unique_elements_ptr = ib.buffer_ptr(unique_elements) - indices_ptr = ib.buffer_ptr(indices) - - batch_size = data.shape[0] - max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = ceil_div(batch_size, max_threads) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - with ib.if_scope(tid < batch_size): - indices_ptr[argsorted_indices_ptr[tid]] = index_converter_ptr[inc_scan_ptr[tid]] - with ib.if_scope(tid == 0): - unique_elements_ptr[index_converter_ptr[inc_scan_ptr[tid]]] = data_ptr[ - argsorted_indices_ptr[tid] - ] - with ib.else_scope(): - with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): - unique_elements_ptr[index_converter_ptr[inc_scan_ptr[tid]]] = data_ptr[ - argsorted_indices_ptr[tid] - ] - return ib.get() - - -def _calc_counts_unsorted_ir(inc_scan, index_converter, counts): - ib = tvm.tir.ir_builder.create() - inc_scan_ptr = ib.buffer_ptr(inc_scan) - index_converter_ptr = ib.buffer_ptr(index_converter) - counts_ptr = ib.buffer_ptr(counts) - - batch_size = inc_scan.shape[0] - max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = ceil_div(batch_size, max_threads) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - with ib.if_scope(tid < batch_size): - counts_ptr[tid] = 0 - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = ceil_div(batch_size, max_threads) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - atomic_add_return = ib.allocate(counts.dtype, (1,), name="atomic_add_return", scope="local") - with ib.if_scope(tid < batch_size): - index = index_converter_ptr[inc_scan_ptr[tid]] - atomic_add_return[0] = tvm.tir.call_intrin( - counts.dtype, - "tir.atomic_add", - tvm.tir.call_intrin("handle", "tir.address_of", counts_ptr[index]), - 1, - ) - return ib.get() - - def unique(data, is_sorted=True, return_counts=False): """ - Find the unique elements of a tensor + Find the unique elements of a 1-D tensor. Please note `output` and `counts` are all padded to + have the same length of `data` and element with index >= num_unique[0] has undefined value. + Parameters ---------- - data : relay.Expr - A 1-D tensor of integers + data : tvm.te.Tensor + A 1-D tensor of integers. 
+ sorted : bool - Whether to sort the unique elements in ascending order before returning as output + Whether to sort the unique elements in ascending order before returning as output. + return_counts : bool - Whether to return the array with count of each unique element + Whether to return the count of each unique element. + Returns ------- - output : relay.Expr - A 1-D tensor containing the unique elements of the input data tensor - indices : relay.Expr - A 1-D tensor containing the index of each data element in the output tensor - num_unique : relay.Expr - A 0-D tensor containing the number of unique elements in the input data tensor - counts (optional) : relay.Expr - A 1-D tensor containing the count of each unique element in the output + output : tvm.te.Tensor + A 1-D tensor containing the unique elements of the input data tensor. + + indices : tvm.te.Tensor + A 1-D tensor containing the index of each data element in the output tensor. + + num_unique : tvm.te.Tensor + A 1-D tensor with size=1 containing the number of unique elements in the input data tensor. + + counts (optional) : tvm.te.Tensor + A 1-D tensor containing the count of each unique element in the output. + Examples -------- .. code-block:: python @@ -320,58 +313,48 @@ def unique(data, is_sorted=True, return_counts=False): indices = [3, 4, 0, 1, 2, 2, 3, 4] num_unique = [5] """ - sorted_data = sort(data) argsorted_indices = argsort(data, dtype="int32") - # calculate adjacent difference + # adjacent difference adjacent_diff = _calc_adjacent_diff(sorted_data, out_dtype="int32", binop=tir.NE) - # calculate inclusive scan + # inclusive scan inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0) - # calculate total number of unique elements + # total number of unique elements num_unique_elements = _calc_num_unique(inc_scan) - # declare buffers - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - argsorted_indices_buf = tvm.tir.decl_buffer( + # buffers + data_buf = tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + argsorted_indices_buf = tir.decl_buffer( data.shape, "int32", "argsorted_indices_buf", data_alignment=8 ) inc_scan_buf = tvm.tir.decl_buffer(data.shape, "int32", "inc_scan_buf", data_alignment=8) - unique_elements_buf = tvm.tir.decl_buffer( + unique_elements_buf = tir.decl_buffer( data.shape, data.dtype, "unique_elements_buf", data_alignment=8 ) inverse_indices_buf = tvm.tir.decl_buffer( data.shape, "int32", "inverse_indices_buf", data_alignment=8 ) + # prepare outputs + if return_counts: + counts_buf = tir.decl_buffer(data.shape, "int32", "counts_buf", data_alignment=8) + out_data_shape = [data.shape] * 3 + out_buffers = [unique_elements_buf, inverse_indices_buf, counts_buf] + out_dtypes = [data.dtype, "int32", "int32"] + else: + out_data_shape = [data.shape] * 2 + out_buffers = [unique_elements_buf, inverse_indices_buf] + out_dtypes = [data.dtype, "int32"] + # prepare inputs and fcompute if is_sorted: - # calculate unique elements and inverse indices - unique_elements, inverse_indices = te.extern( - [data.shape, data.shape], - [data, argsorted_indices, inc_scan], - lambda ins, outs: _calc_unique_sorted_ir(*ins, *outs), - dtype=[data.dtype, "int32"], - in_buffers=[data_buf, argsorted_indices_buf, inc_scan_buf], - out_buffers=[unique_elements_buf, inverse_indices_buf], - name="_calc_unique_sorted", - tag="_calc_unique_sorted_gpu", - ) - if not return_counts: - return [unique_elements, inverse_indices, num_unique_elements] + in_data = [data, 
argsorted_indices, inc_scan] + in_buffers = [data_buf, argsorted_indices_buf, inc_scan_buf] + if return_counts: + fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs) else: - # calculate counts of unique elements - counts_buf = tvm.tir.decl_buffer(data.shape, "int32", "counts_buf", data_alignment=8) - counts = te.extern( - [data.shape], - [inc_scan], - lambda ins, outs: _calc_counts_sorted_ir(ins[0], outs[0]), - dtype=["int32"], - in_buffers=[inc_scan_buf], - out_buffers=[counts_buf], - name="calc_counts_sorted", - tag="calc_counts_sorted_gpu", - ) - return [unique_elements, inverse_indices, num_unique_elements, counts] + fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs, None) else: + # calculate the index converter if the unique elements should not be sorted # calculate first occurence - first_occurence_buf = tvm.tir.decl_buffer( + first_occurence_buf = tir.decl_buffer( data.shape, "int32", "first_occurence_buf", data_alignment=8 ) first_occurence = te.extern( @@ -387,33 +370,25 @@ def unique(data, is_sorted=True, return_counts=False): # calculate index converter by sorting unique elements by their first occurence argsorted_first_occurence = argsort(first_occurence, dtype="int32") index_converter = argsort(argsorted_first_occurence, dtype="int32") - # calculate unique elements and inverse indices - index_converter_buf = tvm.tir.decl_buffer( + index_converter_buf = tir.decl_buffer( data.shape, "int32", "index_converter_buf", data_alignment=8 ) - unique_elements, inverse_indices = te.extern( - [data.shape, data.shape], - [data, argsorted_indices, inc_scan, index_converter], - lambda ins, outs: _calc_unique_unsorted_ir(*ins, *outs), - dtype=[data.dtype, "int32"], - in_buffers=[data_buf, argsorted_indices_buf, inc_scan_buf, index_converter_buf], - out_buffers=[unique_elements_buf, inverse_indices_buf], - name="_calc_unique_unsorted", - tag="_calc_unique_unsorted_gpu", - ) - if not return_counts: - return [unique_elements, inverse_indices, num_unique_elements] + in_data = [data, argsorted_indices, inc_scan, index_converter] + in_buffers = [data_buf, argsorted_indices_buf, inc_scan_buf, index_converter_buf] + if return_counts: + fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs) else: - # calculate counts of unique elements - counts_buf = tvm.tir.decl_buffer(data.shape, "int32", "counts_buf", data_alignment=8) - counts = te.extern( - [data.shape], - [inc_scan, index_converter], - lambda ins, outs: _calc_counts_unsorted_ir(ins[0], ins[1], outs[0]), - dtype=["int32"], - in_buffers=[inc_scan_buf, index_converter_buf], - out_buffers=[counts_buf], - name="calc_counts_unsorted", - tag="calc_counts_unsorted_gpu", - ) - return [unique_elements, inverse_indices, num_unique_elements, counts] + fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs, None) + outs = te.extern( + out_data_shape, + in_data, + fcompute, + dtype=out_dtypes, + in_buffers=in_buffers, + out_buffers=out_buffers, + name="_calc_unique", + tag="_calc_unique_gpu", + ) + if return_counts: + return [outs[0], outs[1], num_unique_elements, outs[2]] + return [*outs, num_unique_elements] diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py index 86dd82a2bcef..5f919f44a370 100644 --- a/python/tvm/topi/unique.py +++ b/python/tvm/topi/unique.py @@ -14,58 +14,180 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=invalid-name, no-else-return +# pylint: disable=invalid-name """Unique operator""" +from tvm import te, tir from ..te import hybrid from .cumsum import cumsum from .sort import sort, argsort -@hybrid.script -def _calc_adjacent_diff(data): - output = output_tensor(data.shape, "int32") - output[0] = int32(0) - for i in parallel(1, data.shape[0]): - output[i] = int32(1) if data[i] != data[i - 1] else int32(0) - return output +def _calc_adjacent_diff_ir(data, output, binop=tir.Sub): + """Low level IR to calculate adjacent difference in an 1-D array. + + Parameters + ---------- + data : Buffer + Input 1-D Buffer. + + output: Buffer + A buffer to store adjacent difference, of the same shape as data. The adjacent difference + is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1]) + where i > 0 and i < len(data). + + binop: function, optional + A binary associative op to use for calculating adjacent difference. The function takes two + TIR expressions and produce a new TIR expression. By default it uses tvm.tir.Sub to + compute the adjacent difference. + """ + ib = tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + output_ptr = ib.buffer_ptr(output) + with ib.for_range(0, data.shape[0], kind="parallel") as i: + with ib.if_scope(i == 0): + output_ptr[0] = 0 + with ib.else_scope(): + output_ptr[i] = tir.Cast(output.dtype, binop(data_ptr[i], data_ptr[i - 1])) + return ib.get() + + +def _calc_adjacent_diff(data, out_dtype="int32", binop=tir.Sub): + """Function calculate adjacent difference in an 1-D array. + + Parameters + ---------- + data : tvm.te.Tensor + Input 1-D tensor. + + output_dtype : str + The output tensor data type. + + binop: function, optional + A binary associative op to use for calculating difference. The function takes two + TIR expressions and produce a new TIR expression. By default it uses tvm.tir.Sub to + compute the adjacent difference. + + Returns + ------- + output : tvm.te.Tensor + 1-D tensor storing the adjacent difference of the input tensor. The adjacent difference + is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1]) + where i > 0 and i < len(data). + """ + return te.extern( + [data.shape], + [data], + lambda ins, outs: _calc_adjacent_diff_ir(ins[0], outs[0], binop=binop), + dtype=[out_dtype], + name="_calc_adjacent_diff", + tag="_calc_adjacent_diff_cpu", + ) @hybrid.script -def _calc_num_unique(data): +def _calc_num_unique(inc_scan): + """Helper function to get the number of unique elements fron inc_scan tensor""" output = output_tensor((1,), "int32") - output[0] = data[data.shape[0] - 1] + int32(1) + output[0] = inc_scan[inc_scan.shape[0] - 1] + int32(1) return output -@hybrid.script -def _calc_unique_sorted(data, argsorted_indices, inc_scan): - unique_elements = output_tensor(data.shape, data.dtype) - indices = output_tensor(data.shape, "int32") - for i in parallel(data.shape[0]): - indices[argsorted_indices[i]] = inc_scan[i] - if i == 0 or inc_scan[i] != inc_scan[i - 1]: - unique_elements[inc_scan[i]] = data[argsorted_indices[i]] - return unique_elements, indices +def _calc_unique_ir( + data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts +): + """Low level IR to calculate unique elements, inverse indices, and counts (optional) of + unique elements of 1-D array. + Parameters + ---------- + data : Buffer + Input 1-D Buffer. 
-@hybrid.script -def _calc_unique_sorted_with_counts(data, argsorted_indices, inc_scan): - unique_elements = output_tensor(data.shape, data.dtype) - indices = output_tensor(data.shape, "int32") - counts = output_tensor(data.shape, "int32") - for i in parallel(data.shape[0]): - counts[i] = int32(0) - for i in parallel(data.shape[0]): - indices[argsorted_indices[i]] = inc_scan[i] - if i == 0 or inc_scan[i] != inc_scan[i - 1]: - unique_elements[inc_scan[i]] = data[argsorted_indices[i]] - for i in range(data.shape[0]): - counts[inc_scan[i]] += int32(1) - return unique_elements, indices, counts + argsorted_indices : Buffer + A buffer that stores the argsorted indices of the input data. + + inc_scan : Buffer + A buffer that stores the inclusive scan of the binary tir.NE adjacent difference + of the sorted data. + + index_converter (optional) : Buffer + An optional index converter that transforms the unique element index + such that new_idx = index_converter[old_idx]. + + unique_elements : Buffer + A buffer that stores the unique elements. + + indices : Buffer + A buffer that stores the the index of each input data element in the unique element array. + + counts (optional) : Buffer + A buffer that stores the count of each unique element. + """ + ib = tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) + inc_scan_ptr = ib.buffer_ptr(inc_scan) + unique_elements_ptr = ib.buffer_ptr(unique_elements) + indices_ptr = ib.buffer_ptr(indices) + + index_converter_ptr = None + if isinstance(index_converter, tir.Buffer): + index_converter_ptr = ib.buffer_ptr(index_converter) + + if isinstance(counts, tir.Buffer): + counts_ptr = ib.buffer_ptr(counts) + arange_ptr = ib.allocate(counts_ptr.dtype, counts.shape, name="arange_buf", scope="local") + + data_length = data.shape[0] + + with ib.new_scope(): + with ib.for_range(0, data_length, kind="parallel") as i: + data_idx = argsorted_indices_ptr[i] + unique_idx = ( + inc_scan_ptr[i] if not index_converter_ptr else index_converter_ptr[inc_scan_ptr[i]] + ) + indices_ptr[data_idx] = unique_idx + with ib.if_scope(i == 0): + unique_elements_ptr[unique_idx] = data_ptr[data_idx] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[i] != inc_scan_ptr[i - 1]): + unique_elements_ptr[unique_idx] = data_ptr[data_idx] + + if isinstance(counts, tir.Buffer): + num_unique = inc_scan_ptr[inc_scan.shape[0] - 1] + 1 + num_elements = data.shape[0] + arange_ptr[num_unique - 1] = num_elements + with ib.new_scope(): + with ib.for_range(0, data_length, kind="parallel") as i: + with ib.if_scope(i > 0): + with ib.if_scope(inc_scan_ptr[i] != inc_scan_ptr[i - 1]): + arange_ptr[inc_scan_ptr[i] - 1] = i + with ib.new_scope(): + with ib.for_range(0, num_unique, kind="parallel") as i: + unique_idx = i if not index_converter_ptr else index_converter_ptr[i] + with ib.if_scope(i == 0): + counts_ptr[unique_idx] = arange_ptr[i] + with ib.else_scope(): + counts_ptr[unique_idx] = arange_ptr[i] - arange_ptr[i - 1] + return ib.get() @hybrid.script def _calc_first_occurence(argsorted_indices, inc_scan): + """Hybrid script to calculate the first occurence of each unique element in the input data. + + Parameters + ---------- + argsorted_indices : tvm.te.Tensor + A tensor that stores the argsorted indices of the input data. + + inc_scan : tvm.te.Tensor + A tensor that stores the inclusive scan of the binary tir.NE adjacent difference + of the sorted data. 
+ + first_occurence : tvm.te.Tensor + A tensor that stores the first occurence of each unique element in the input data. + """ first_occurence = output_tensor(argsorted_indices.shape, "int32") for i in parallel(argsorted_indices.shape[0]): first_occurence[i] = argsorted_indices.shape[0] @@ -75,59 +197,36 @@ def _calc_first_occurence(argsorted_indices, inc_scan): return first_occurence -@hybrid.script -def _calc_unique_unsorted(data, argsorted_indices, inc_scan, index_converter): - unique_elements = output_tensor(data.shape, data.dtype) - indices = output_tensor(data.shape, "int32") - for i in parallel(data.shape[0]): - new_unique_idx = index_converter[inc_scan[i]] - new_data_idx = argsorted_indices[i] - indices[new_data_idx] = new_unique_idx - if i == 0 or inc_scan[i] != inc_scan[i - 1]: - unique_elements[new_unique_idx] = data[new_data_idx] - return unique_elements, indices - - -@hybrid.script -def _calc_unique_unsorted_with_counts(data, argsorted_indices, inc_scan, index_converter): - unique_elements = output_tensor(data.shape, data.dtype) - indices = output_tensor(data.shape, "int32") - counts = output_tensor(data.shape, "int32") - for i in parallel(data.shape[0]): - counts[i] = int32(0) - for i in parallel(data.shape[0]): - new_unique_idx = index_converter[inc_scan[i]] - new_data_idx = argsorted_indices[i] - indices[new_data_idx] = new_unique_idx - if i == 0 or inc_scan[i] != inc_scan[i - 1]: - unique_elements[new_unique_idx] = data[new_data_idx] - for i in range(data.shape[0]): - idx = index_converter[inc_scan[i]] - counts[idx] += int32(1) - return unique_elements, indices, counts - - def unique(data, is_sorted=True, return_counts=False): """ - Find the unique elements of a tensor + Find the unique elements of a 1-D tensor. Please note `output` and `counts` are all padded to + have the same length of `data` and element with index >= num_unique[0] has undefined value. + Parameters ---------- - data : relay.Expr - A 1-D tensor of integers + data : tvm.te.Tensor + A 1-D tensor of integers. + sorted : bool - Whether to sort the unique elements in ascending order before returning as output + Whether to sort the unique elements in ascending order before returning as output. + return_counts : bool - Whether to return the array with count of each unique element + Whether to return the count of each unique element. + Returns ------- - output : relay.Expr - A 1-D tensor containing the unique elements of the input data tensor - indices : relay.Expr - A 1-D tensor containing the index of each data element in the output tensor - num_unique : relay.Expr - A 0-D tensor containing the number of unique elements in the input data tensor - counts (optional) : relay.Expr - A 1-D tensor containing the count of each unique element in the output + output : tvm.te.Tensor + A 1-D tensor containing the unique elements of the input data tensor. + + indices : tvm.te.Tensor + A 1-D tensor containing the index of each data element in the output tensor. + + num_unique : tvm.te.Tensor + A 1-D tensor with size=1 containing the number of unique elements in the input data tensor. + + counts (optional) : tvm.te.Tensor + A 1-D tensor containing the count of each unique element in the output. + Examples -------- .. 
code-block:: python @@ -147,35 +246,48 @@ def unique(data, is_sorted=True, return_counts=False): indices = [3, 4, 0, 1, 2, 2, 3, 4] num_unique = [5] """ - sorted_data = sort(data) argsorted_indices = argsort(data, dtype="int32") - adjacent_diff = _calc_adjacent_diff(sorted_data) + # adjacent difference + adjacent_diff = _calc_adjacent_diff(sorted_data, "int32", tir.NE) + # inclusive scan inc_scan = cumsum(adjacent_diff, dtype="int32", exclusive=0) + # total number of unique elements num_unique_elements = _calc_num_unique(inc_scan) - + # prepare outputs + if return_counts: + out_data_shape = [data.shape] * 3 + out_dtypes = [data.dtype, "int32", "int32"] + else: + out_data_shape = [data.shape] * 2 + out_dtypes = [data.dtype, "int32"] + # prepare inputs and fcompute if is_sorted: + in_data = [data, argsorted_indices, inc_scan] if return_counts: - unique_elements, inverse_indices, counts = _calc_unique_sorted_with_counts( - data, argsorted_indices, inc_scan - ) - return [unique_elements, inverse_indices, num_unique_elements, counts] + fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs) else: - unique_elements, inverse_indices = _calc_unique_sorted( - data, argsorted_indices, inc_scan - ) - return [unique_elements, inverse_indices, num_unique_elements] + fcompute = lambda ins, outs: _calc_unique_ir(*ins, None, *outs, None) else: + # calculate the index converter if the unique elements should not be sorted + # calculate first occurence first_occurence = _calc_first_occurence(argsorted_indices, inc_scan) + # calculate index converter by sorting unique elements by their first occurence argsorted_first_occurence = argsort(first_occurence, dtype="int32") index_converter = argsort(argsorted_first_occurence, dtype="int32") + in_data = [data, argsorted_indices, inc_scan, index_converter] if return_counts: - unique_elements, inverse_indices, counts = _calc_unique_unsorted_with_counts( - data, argsorted_indices, inc_scan, index_converter - ) - return [unique_elements, inverse_indices, num_unique_elements, counts] + fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs) else: - unique_elements, inverse_indices = _calc_unique_unsorted( - data, argsorted_indices, inc_scan, index_converter - ) - return [unique_elements, inverse_indices, num_unique_elements] + fcompute = lambda ins, outs: _calc_unique_ir(*ins, *outs, None) + outs = te.extern( + out_data_shape, + in_data, + fcompute, + dtype=out_dtypes, + name="_calc_unique", + tag="_calc_unique_cpu", + ) + if return_counts: + return [outs[0], outs[1], num_unique_elements, outs[2]] + return [*outs, num_unique_elements] diff --git a/tests/python/topi/python/test_topi_unique.py b/tests/python/topi/python/test_topi_unique.py index f35aed0a60ce..fd02e622022b 100644 --- a/tests/python/topi/python/test_topi_unique.py +++ b/tests/python/topi/python/test_topi_unique.py @@ -97,12 +97,12 @@ def check_unique(data, is_sorted=False): for in_dtype in ["int32", "int64"]: for is_sorted in [True, False]: - data = np.random.randint(0, 100, size=(1)).astype(in_dtype) - check_unique(data, is_sorted) - data = np.random.randint(0, 100, size=(50)).astype(in_dtype) - check_unique(data, is_sorted) - data = np.random.randint(0, 100, size=(10000)).astype(in_dtype) + # data = np.random.randint(0, 100, size=(1)).astype(in_dtype) + # check_unique(data, is_sorted) + data = np.random.randint(0, 10, size=(10)).astype(in_dtype) check_unique(data, is_sorted) + # data = np.random.randint(0, 100, size=(10000)).astype(in_dtype) + # check_unique(data, is_sorted) if __name__ == 
"__main__": From 8705397f21947a61753121c7fe1160ae40ec4ef3 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Wed, 24 Feb 2021 19:12:46 -0800 Subject: [PATCH 14/16] Fix typo --- tests/python/topi/python/test_topi_unique.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/topi/python/test_topi_unique.py b/tests/python/topi/python/test_topi_unique.py index fd02e622022b..d7ee74282922 100644 --- a/tests/python/topi/python/test_topi_unique.py +++ b/tests/python/topi/python/test_topi_unique.py @@ -97,12 +97,12 @@ def check_unique(data, is_sorted=False): for in_dtype in ["int32", "int64"]: for is_sorted in [True, False]: - # data = np.random.randint(0, 100, size=(1)).astype(in_dtype) - # check_unique(data, is_sorted) + data = np.random.randint(0, 100, size=(1)).astype(in_dtype) + check_unique(data, is_sorted) data = np.random.randint(0, 10, size=(10)).astype(in_dtype) check_unique(data, is_sorted) - # data = np.random.randint(0, 100, size=(10000)).astype(in_dtype) - # check_unique(data, is_sorted) + data = np.random.randint(0, 100, size=(10000)).astype(in_dtype) + check_unique(data, is_sorted) if __name__ == "__main__": From 94c0b560033103481f76c559774ad356d4e91aab Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Wed, 24 Feb 2021 21:27:56 -0800 Subject: [PATCH 15/16] Combine _unique and _unique_with_counts --- python/tvm/relay/frontend/tensorflow.py | 36 ++++++++++++++----------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d58fe24a3206..8a939f0892eb 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2324,16 +2324,27 @@ def _impl(inputs, attr, params, mod): return _impl -def _unique(): +def _unique(return_counts=True): def _impl(inputs, attr, params, mod): assert len(inputs) == 1 data = inputs[0] - [unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False) - unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") - return _expr.TupleWrapper( - _expr.Tuple([unique_sliced, indices]), - 2, - ) + if return_counts: + [unique, indices, num_uniq, counts] = _op.unique( + data, is_sorted=False, return_counts=True + ) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size") + return _expr.TupleWrapper( + _expr.Tuple([unique_sliced, indices, counts_sliced]), + 3, + ) + else: + [unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + return _expr.TupleWrapper( + _expr.Tuple([unique_sliced, indices]), + 2, + ) return _impl @@ -2342,13 +2353,6 @@ def _unique_with_counts(): def _impl(inputs, attr, params, mod): assert len(inputs) == 1 data = inputs[0] - [unique, indices, num_uniq, counts] = _op.unique(data, is_sorted=False, return_counts=True) - unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") - counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size") - return _expr.TupleWrapper( - _expr.Tuple([unique_sliced, indices, counts_sliced]), - 3, - ) return _impl @@ -2531,8 +2535,8 @@ def _impl(inputs, attr, params, mod): "TopKV2": _topk(), "Transpose": _transpose(), "TruncateMod": _elemwise("mod"), - "Unique": _unique(), - "UniqueWithCounts": _unique_with_counts(), + 
"Unique": _unique(False), + "UniqueWithCounts": _unique(True), "Unpack": _unpack(), "UnravelIndex": _unravel_index(), "Where": _where(), From 74c8fc1bc2fab1fb21c9f0bb1d1875e10f1e0644 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Thu, 25 Feb 2021 19:20:32 +0000 Subject: [PATCH 16/16] Reuse indices_ptr to remove arange_ptr --- python/tvm/relay/frontend/tensorflow.py | 21 +++------ python/tvm/topi/cuda/unique.py | 58 +++++++++++++------------ python/tvm/topi/unique.py | 40 +++++++++-------- 3 files changed, 58 insertions(+), 61 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 8a939f0892eb..52c5c8b9cacc 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2338,21 +2338,12 @@ def _impl(inputs, attr, params, mod): _expr.Tuple([unique_sliced, indices, counts_sliced]), 3, ) - else: - [unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False) - unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") - return _expr.TupleWrapper( - _expr.Tuple([unique_sliced, indices]), - 2, - ) - - return _impl - - -def _unique_with_counts(): - def _impl(inputs, attr, params, mod): - assert len(inputs) == 1 - data = inputs[0] + [unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False) + unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size") + return _expr.TupleWrapper( + _expr.Tuple([unique_sliced, indices]), + 2, + ) return _impl diff --git a/python/tvm/topi/cuda/unique.py b/python/tvm/topi/cuda/unique.py index b57176162e05..02a5cf3bc592 100644 --- a/python/tvm/topi/cuda/unique.py +++ b/python/tvm/topi/cuda/unique.py @@ -153,34 +153,12 @@ def _calc_unique_ir( if isinstance(counts, tir.Buffer): counts_ptr = ib.buffer_ptr(counts) - arange_ptr = ib.allocate(counts_ptr.dtype, counts.shape, name="arange_buf", scope="global") + # use indices_ptr as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1] + unique_seq_indices_ptr = ib.buffer_ptr(indices) batch_size = data.shape[0] max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) - # calculate unique elements and inverse indices - with ib.new_scope(): - nthread_tx = max_threads - nthread_bx = ceil_div(batch_size, max_threads) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx - with ib.if_scope(tid < batch_size): - data_idx = argsorted_indices_ptr[tid] - unique_idx = ( - inc_scan_ptr[tid] - if not index_converter_ptr - else index_converter_ptr[inc_scan_ptr[tid]] - ) - indices_ptr[data_idx] = unique_idx - with ib.if_scope(tid == 0): - unique_elements_ptr[unique_idx] = data_ptr[data_idx] - with ib.else_scope(): - with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): - unique_elements_ptr[unique_idx] = data_ptr[data_idx] - # if need to return counts if isinstance(counts, tir.Buffer): num_unique = inc_scan_ptr[inc_scan.shape[0] - 1] + 1 @@ -195,10 +173,10 @@ def _calc_unique_ir( tid = bx * max_threads + tx with ib.if_scope(tid < batch_size): with ib.if_scope(tid == 0): - arange_ptr[num_unique - 1] = num_elements + unique_seq_indices_ptr[num_unique - 1] = num_elements with ib.else_scope(): with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): - arange_ptr[inc_scan_ptr[tid] - 1] = tid + unique_seq_indices_ptr[inc_scan_ptr[tid] - 1] = 
tid with ib.new_scope(): nthread_tx = max_threads nthread_bx = ceil_div(batch_size, max_threads) @@ -210,9 +188,33 @@ def _calc_unique_ir( with ib.if_scope(tid < num_unique): unique_idx = tid if not index_converter_ptr else index_converter_ptr[tid] with ib.if_scope(tid == 0): - counts_ptr[unique_idx] = arange_ptr[tid] + counts_ptr[unique_idx] = unique_seq_indices_ptr[tid] with ib.else_scope(): - counts_ptr[unique_idx] = arange_ptr[tid] - arange_ptr[tid - 1] + counts_ptr[unique_idx] = ( + unique_seq_indices_ptr[tid] - unique_seq_indices_ptr[tid - 1] + ) + # calculate unique elements and inverse indices + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + data_idx = argsorted_indices_ptr[tid] + unique_idx = ( + inc_scan_ptr[tid] + if not index_converter_ptr + else index_converter_ptr[inc_scan_ptr[tid]] + ) + indices_ptr[data_idx] = unique_idx + with ib.if_scope(tid == 0): + unique_elements_ptr[unique_idx] = data_ptr[data_idx] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): + unique_elements_ptr[unique_idx] = data_ptr[data_idx] return ib.get() diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py index 5f919f44a370..b4f27b38f65f 100644 --- a/python/tvm/topi/unique.py +++ b/python/tvm/topi/unique.py @@ -136,39 +136,43 @@ def _calc_unique_ir( if isinstance(counts, tir.Buffer): counts_ptr = ib.buffer_ptr(counts) - arange_ptr = ib.allocate(counts_ptr.dtype, counts.shape, name="arange_buf", scope="local") + # use indices_ptr as a tmp buffer to store tids with inc_scan[tid] != inc_scan[tid-1] + unique_seq_indices_ptr = ib.buffer_ptr(indices) data_length = data.shape[0] - with ib.new_scope(): - with ib.for_range(0, data_length, kind="parallel") as i: - data_idx = argsorted_indices_ptr[i] - unique_idx = ( - inc_scan_ptr[i] if not index_converter_ptr else index_converter_ptr[inc_scan_ptr[i]] - ) - indices_ptr[data_idx] = unique_idx - with ib.if_scope(i == 0): - unique_elements_ptr[unique_idx] = data_ptr[data_idx] - with ib.else_scope(): - with ib.if_scope(inc_scan_ptr[i] != inc_scan_ptr[i - 1]): - unique_elements_ptr[unique_idx] = data_ptr[data_idx] - + # if need to return counts if isinstance(counts, tir.Buffer): num_unique = inc_scan_ptr[inc_scan.shape[0] - 1] + 1 num_elements = data.shape[0] - arange_ptr[num_unique - 1] = num_elements + unique_seq_indices_ptr[num_unique - 1] = num_elements with ib.new_scope(): with ib.for_range(0, data_length, kind="parallel") as i: with ib.if_scope(i > 0): with ib.if_scope(inc_scan_ptr[i] != inc_scan_ptr[i - 1]): - arange_ptr[inc_scan_ptr[i] - 1] = i + unique_seq_indices_ptr[inc_scan_ptr[i] - 1] = i with ib.new_scope(): with ib.for_range(0, num_unique, kind="parallel") as i: unique_idx = i if not index_converter_ptr else index_converter_ptr[i] with ib.if_scope(i == 0): - counts_ptr[unique_idx] = arange_ptr[i] + counts_ptr[unique_idx] = unique_seq_indices_ptr[i] with ib.else_scope(): - counts_ptr[unique_idx] = arange_ptr[i] - arange_ptr[i - 1] + counts_ptr[unique_idx] = ( + unique_seq_indices_ptr[i] - unique_seq_indices_ptr[i - 1] + ) + # calculate unique elements and inverse indices + with ib.new_scope(): + with ib.for_range(0, data_length, kind="parallel") as i: + data_idx = argsorted_indices_ptr[i] + unique_idx = ( + 
inc_scan_ptr[i] if not index_converter_ptr else index_converter_ptr[inc_scan_ptr[i]] + ) + indices_ptr[data_idx] = unique_idx + with ib.if_scope(i == 0): + unique_elements_ptr[unique_idx] = data_ptr[data_idx] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[i] != inc_scan_ptr[i - 1]): + unique_elements_ptr[unique_idx] = data_ptr[data_idx] return ib.get()
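
For reference while reviewing, the sketch below is a NumPy model (not part of the patch; the
helper name `unique_reference` is illustrative only) of the algorithm both the CPU and GPU
kernels implement after this series: stable argsort, a tir.NE-style adjacent difference, an
inclusive scan that assigns each element its unique id, and a final scatter that fills the
padded `output`, `indices`, and optional `counts` exactly as the operator docstrings describe.

.. code-block:: python

    import numpy as np

    def unique_reference(data, is_sorted=True, return_counts=False):
        # data is a 1-D integer array; all outputs are padded to len(data)
        n = data.shape[0]
        argsorted = np.argsort(data, kind="stable").astype("int32")
        sorted_data = data[argsorted]
        # adjacent tir.NE difference: 1 where a new unique value starts, else 0
        adjacent_diff = np.zeros(n, dtype="int32")
        adjacent_diff[1:] = (sorted_data[1:] != sorted_data[:-1]).astype("int32")
        # inclusive scan gives each sorted element its unique id
        inc_scan = np.cumsum(adjacent_diff, dtype="int32")
        num_unique = int(inc_scan[-1]) + 1
        if not is_sorted:
            # index converter: renumber unique ids by first occurrence in the input
            first_occurence = np.full(n, n, dtype="int32")
            for i in range(n):
                if i == 0 or inc_scan[i] != inc_scan[i - 1]:
                    first_occurence[inc_scan[i]] = argsorted[i]
            converter = np.argsort(np.argsort(first_occurence, kind="stable"), kind="stable")
        # scatter unique elements, inverse indices, and counts (all padded)
        output = np.zeros(n, dtype=data.dtype)
        indices = np.zeros(n, dtype="int32")
        counts = np.zeros(n, dtype="int32")
        for i in range(n):
            uid = inc_scan[i] if is_sorted else converter[inc_scan[i]]
            indices[argsorted[i]] = uid
            counts[uid] += 1
            if i == 0 or inc_scan[i] != inc_scan[i - 1]:
                output[uid] = sorted_data[i]
        num_unique = np.array([num_unique], dtype="int32")
        if return_counts:
            return output, indices, num_unique, counts
        return output, indices, num_unique

    # matches the example in the docstrings above
    x = np.array([4, 5, 1, 2, 3, 3, 4, 5], dtype="int32")
    output, indices, num_unique, counts = unique_reference(x, is_sorted=False, return_counts=True)
    # output[:num_unique[0]]  -> [4, 5, 1, 2, 3]
    # indices                 -> [0, 1, 2, 3, 4, 4, 0, 1]
    # counts[:num_unique[0]]  -> [2, 2, 1, 1, 2]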
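
Because `output` and `counts` come back padded, a consumer typically trims them with
`strided_slice(..., slice_mode="size")` using `num_unique`, which is what the TensorFlow
converter in this series does. A rough Relay-level sketch of that pattern follows; it assumes
the op is exported as `relay.unique` like the other transform ops, and it only builds the
module, since the data-dependent output shape needs the VM or interpreter rather than the
graph executor to run.

.. code-block:: python

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8,), dtype="int32")
    # padded unique outputs plus the size-1 num_unique tensor
    [output, indices, num_unique, counts] = relay.unique(x, is_sorted=False, return_counts=True)
    # trim the padded tensors to the actual number of unique elements
    output_trimmed = relay.strided_slice(output, begin=[0], end=num_unique, slice_mode="size")
    counts_trimmed = relay.strided_slice(counts, begin=[0], end=num_unique, slice_mode="size")
    func = relay.Function([x], relay.Tuple([output_trimmed, indices, counts_trimmed]))
    mod = tvm.IRModule.from_expr(func)
    print(mod)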