Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,18 @@ struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
}
};

/*! \brief Attributes used in unique operator */
struct UniqueAttrs : public tvm::AttrsNode<UniqueAttrs> {
bool sorted;
bool return_counts;
TVM_DECLARE_ATTRS(UniqueAttrs, "relay.attrs.UniqueAttrs") {
TVM_ATTR_FIELD(sorted).describe("Whether the unique elements are sorted").set_default(true);
TVM_ATTR_FIELD(return_counts)
.describe("Whether to return an additional tensor with counts of each unique elements")
.set_default(false);
}
}; // struct UniqueAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
19 changes: 19 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2164,6 +2164,24 @@ def is_floating_point(self, inputs, input_types):
is_float = input_type in ["float32", "float64", "float16", "bfloat16"]
return _expr.const(is_float)

def unique(self, inputs, input_types):
assert len(inputs) == 4
[data, is_sorted, return_inverse, return_counts] = inputs
if not is_sorted:
logging.warning("TVM always assumes sorted=True for torch.unique")
is_sorted = True
if return_counts:
[unique, indices, num_uniq, counts] = _op.unique(
data, is_sorted=is_sorted, return_counts=True
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
return (unique_sliced, indices, counts_sliced)
else:
[unique, indices, num_uniq] = _op.unique(data, is_sorted=is_sorted, return_counts=False)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
return (unique_sliced, indices)

# Operator mappings
def create_convert_map(self):
self.convert_map = {
Expand Down Expand Up @@ -2370,6 +2388,7 @@ def create_convert_map(self):
"aten::masked_select": self.masked_select,
"aten::argsort": self.argsort,
"aten::sort": self.sort,
"aten::_unique2": self.unique,
}

def update_convert_map(self, custom_map):
Expand Down
26 changes: 26 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2324,6 +2324,30 @@ def _impl(inputs, attr, params, mod):
return _impl


def _unique(return_counts=True):
def _impl(inputs, attr, params, mod):
assert len(inputs) == 1
data = inputs[0]
if return_counts:
[unique, indices, num_uniq, counts] = _op.unique(
data, is_sorted=False, return_counts=True
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
return _expr.TupleWrapper(
_expr.Tuple([unique_sliced, indices, counts_sliced]),
3,
)
[unique, indices, num_uniq] = _op.unique(data, is_sorted=False, return_counts=False)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
return _expr.TupleWrapper(
_expr.Tuple([unique_sliced, indices]),
2,
)

return _impl


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -2502,6 +2526,8 @@ def _impl(inputs, attr, params, mod):
"TopKV2": _topk(),
"Transpose": _transpose(),
"TruncateMod": _elemwise("mod"),
"Unique": _unique(False),
"UniqueWithCounts": _unique(True),
"Unpack": _unpack(),
"UnravelIndex": _unravel_index(),
"Where": _where(),
Expand Down
44 changes: 44 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,15 @@ def compute_cumsum(attrs, inputs, output_type):
_reg.register_strategy("cumsum", strategy.cumsum_strategy)
_reg.register_shape_func("cumsum", False, elemwise_shape_func)


@_reg.register_compute("unique")
def compute_unique(attrs, inputs, output_type):
"""Compute definition of unique"""
return topi.unique(inputs[0], attrs.sorted, attrs.return_counts)


_reg.register_strategy("unique", strategy.unique_strategy)

#####################
# Shape functions #
#####################
Expand Down Expand Up @@ -946,3 +955,38 @@ def where_shape_func(attrs, inputs, _):
out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape)

return [out_shape]


@script
def _unique_shape(data_shape):
unique_shape = output_tensor((1,), "int64")
indices_shape = output_tensor((1,), "int64")
num_unique_shape = output_tensor((1,), "int64")
unique_shape[0] = data_shape[0]
indices_shape[0] = data_shape[0]
num_unique_shape[0] = int64(1)
return (unique_shape, indices_shape, num_unique_shape)


@script
def _unique_with_counts_shape(data_shape):
unique_shape = output_tensor((1,), "int64")
indices_shape = output_tensor((1,), "int64")
num_unique_shape = output_tensor((1,), "int64")
counts_shape = output_tensor((1,), "int64")
unique_shape[0] = data_shape[0]
indices_shape[0] = data_shape[0]
num_unique_shape[0] = int64(1)
counts_shape[0] = data_shape[0]
return (unique_shape, indices_shape, num_unique_shape, counts_shape)


@_reg.register_shape_func("unique", False)
def unique_shape_func(attrs, inputs, _):
"""
Shape func for unique operator.
"""
if attrs.return_counts:
return _unique_with_counts_shape(inputs[0])
else:
return _unique_shape(inputs[0])
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,3 +1017,15 @@ def cumsum_strategy_cuda(attrs, inputs, out_type, target):
name="cumsum.cuda",
)
return strategy


@unique_strategy.register(["cuda", "gpu"])
def unique_strategy_cuda(attrs, inputs, out_type, target):
"""unique cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_unique(topi.cuda.unique),
wrap_topi_schedule(topi.cuda.schedule_scan),
name="unique.cuda",
)
return strategy
21 changes: 21 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,3 +1432,24 @@ def cumsum_strategy(attrs, inputs, out_type, target):
name="cumsum.generic",
)
return strategy


def wrap_compute_unique(topi_compute):
"""Wrap unique topi compute"""

def _compute_unique(attrs, inputs, _):
return topi_compute(inputs[0], attrs.sorted, attrs.return_counts)

return _compute_unique


@override_native_generic_func("unique_strategy")
def unique_strategy(attrs, inputs, out_type, target):
"""unique generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_unique(topi.unique),
wrap_topi_schedule(topi.generic.schedule_unique),
name="unique.generic",
)
return strategy
54 changes: 54 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,3 +1463,57 @@ def cumsum(data, axis=None, dtype=None, exclusive=None):
-> [1, 1, 2, 2, 3, 4, 4]
"""
return _make.cumsum(data, axis, dtype, exclusive)


def unique(data, is_sorted=True, return_counts=False):
"""
Find the unique elements of a 1-D tensor. Please note `output` and `counts` are all padded to
have the same length of `data` and element with index >= num_unique[0] has undefined value.

Parameters
----------
data : relay.Expr
A 1-D tensor of integers.

sorted : bool
Whether to sort the unique elements in ascending order before returning as output.

return_counts : bool
Whether to return the count of each unique element.

Returns
-------
output : relay.Expr
A 1-D tensor containing the unique elements of the input data tensor.

indices : relay.Expr
A 1-D tensor containing the index of each data element in the output tensor.

num_unique : relay.Expr
A 1-D tensor with size=1 containing the number of unique elements in the input data tensor.

counts (optional) : relay.Expr
A 1-D tensor containing the count of each unique element in the output.

Examples
--------
.. code-block:: python
[output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, False)
output = [4, 5, 1, 2, 3, ?, ?, ?]
indices = [0, 1, 2, 3, 4, 4, 0, 1]
num_unique = [5]

[output, indices, num_unique, counts] = unique([4, 5, 1, 2, 3, 3, 4, 5], False, True)
output = [4, 5, 1, 2, 3, ?, ?, ?]
indices = [0, 1, 2, 3, 4, 4, 0, 1]
num_unique = [5]
counts = [2, 2, 1, 1, 2, ?, ?, ?]

[output, indices, num_unique] = unique([4, 5, 1, 2, 3, 3, 4, 5], True)
output = [1, 2, 3, 4, 5, ?, ?, ?]
indices = [3, 4, 0, 1, 2, 2, 3, 4]
num_unique = [5]
"""
if return_counts:
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4)
return TupleWrapper(_make.unique(data, is_sorted, return_counts), 3)
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .argwhere import *
from .cumsum import *
from .einsum import *
from .unique import *
from . import generic
from . import nn
from . import x86
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@
from . import tensorcore_alter_op
from .argwhere import *
from .scan import *
from .unique import *
Loading