From e504b78e8d6ce7f39439abd910e9b64fc9640021 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 14:22:59 -0500 Subject: [PATCH] [Unity] Relax op: set This PR is about the high-level tensor computation operators in Relax. This PR includes the set operators. Co-authored-by: Prakalp Srivastava --- include/tvm/relax/attrs/set.h | 62 ++ python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/op_attrs.py | 5 + python/tvm/relax/op/set.py | 101 ++ python/tvm/script/ir_builder/relax/ir.py | 2 + src/relax/op/tensor/set.cc | 103 +++ src/relax/op/tensor/set.h | 40 + tests/python/relax/test_op_set.py | 862 ++++++++++++++++++ .../relax/test_tvmscript_parser_op_set.py | 68 ++ 9 files changed, 1244 insertions(+) create mode 100644 include/tvm/relax/attrs/set.h create mode 100644 python/tvm/relax/op/set.py create mode 100644 src/relax/op/tensor/set.cc create mode 100644 src/relax/op/tensor/set.h create mode 100644 tests/python/relax/test_op_set.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_set.py diff --git a/include/tvm/relax/attrs/set.h b/include/tvm/relax/attrs/set.h new file mode 100644 index 000000000000..3fae7646ff8e --- /dev/null +++ b/include/tvm/relax/attrs/set.h @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/set.h + * \brief Attributes for set operators. + */ +#ifndef TVM_RELAX_ATTRS_SET_H_ +#define TVM_RELAX_ATTRS_SET_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in unique operator */ +struct UniqueAttrs : public tvm::AttrsNode { + bool sorted; + bool return_index; + bool return_inverse; + bool return_counts; + Optional axis; + + TVM_DECLARE_ATTRS(UniqueAttrs, "relax.attrs.UniqueAttrs") { + TVM_ATTR_FIELD(sorted).describe( + "Whether to sort the unique elements in ascending order before returning as output."); + TVM_ATTR_FIELD(return_index) + .describe( + "Whether to return an additional tensor with indices for where elements in the unique " + "tensor come from the original input."); + TVM_ATTR_FIELD(return_inverse) + .describe( + "Whether to return an additional tensor with indices for where elements in the " + "original input ended up in the returned unique list."); + TVM_ATTR_FIELD(return_counts) + .describe("Whether to return an additional tensor with counts of each unique elements"); + TVM_ATTR_FIELD(axis).describe( + "The dimension to apply unique. If it is NullOpt, the unique values of the flattened input " + "is are returned."); + } +}; // struct UniqueAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_SET_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 3393a5dcae67..9a9c1754bae5 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -23,5 +23,6 @@ from .index import * from .manipulate import * from .op_attrs import * +from .set import * from . import builtin from . import memory diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 44cb2cf3a5b4..176020da63ee 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -27,3 +27,8 @@ class TakeAttrs(Attrs): @tvm._ffi.register_object("relax.attrs.StridedSliceAttrs") class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" + + +@tvm._ffi.register_object("relax.attrs.UniqueAttrs") +class UniqueAttrs(Attrs): + """Attributes used for the unique operator""" diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py new file mode 100644 index 000000000000..b7ee0f381169 --- /dev/null +++ b/python/tvm/relax/op/set.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-outside-toplevel, redefined-builtin, unused-argument +"""Set operators.""" +from typing import Optional + +import numpy as np # type: ignore +import tvm + +from . import _ffi_api +from ..expr import Expr + + +def unique( + x: Expr, + sorted: bool = True, + return_index: bool = False, + return_inverse: bool = False, + return_counts: bool = False, + axis: Optional[int] = None, +) -> Expr: + """Find the unique elements in a given tensor. + In addition, it optionally returns + - the indices of the input tensor that give the unique values; + - the indices of the unique tensor that reconstruct the input tensor; + - the number of times each unique value comes up in the input tensor. + + Parameters + ---------- + x : relax.Expr + The input tensor. + + sorted : bool + Whether to sort the unique elements in ascending order before + returning as output. + + return_index : bool + Whether to return an additional tensor with indices for where elements in + the unique tensor come from the original input. + + return_inverse : bool + Whether to return an additional tensor with indices for where elements in + the original input ended up in the returned unique list. + + return_counts : bool + Whether to return an additional tensor with counts of each unique elements. + + axis : Optional + The dimension to apply unique. + If not specified, the unique values of the flattened input are returned. + + Returns + ------- + ret : relax.Expr + The created relax call with + """ + + return _ffi_api.unique( # type: ignore + x, sorted, return_index, return_inverse, return_counts, axis + ) + + +@tvm.register_func("relax.run.unique") +def numpy_unique( + x: tvm.nd.array, + sorted: int, + return_index: int, + return_inverse: int, + return_counts: int, + axis: Optional[int], +) -> tvm.nd.array: + """Returns the unique elements of the input tensor. + + Uses numpy.unique to compute unique elements. + """ + import builtins + + # TODO(prakalp): add support for returning a tuple when return_inverse or return_counts is True + if bool(return_index) or bool(return_inverse) or bool(return_counts): + raise NotImplementedError("missing support return_inverse or return_counts set to true") + x_numpy = x.numpy() + # TODO(prakalp): use torch.unique instead of numpy when torch is installed in ci. + output_sorted_numpy, indices = np.unique(x_numpy, return_index=True) + if sorted: + return tvm.nd.array(output_sorted_numpy) + output_numpy = [x_numpy.flatten()[index] for index in builtins.sorted(indices, reverse=True)] + return tvm.nd.array(output_numpy) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 75a00ea04985..e283df18f6f7 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -44,6 +44,7 @@ shape_of, strided_slice, take, + unique, ) from tvm.relax.struct_info import StructInfo from tvm.relax.utils import args_converter @@ -432,4 +433,5 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "strided_slice", "take", "tuple", + "unique", ] diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc new file mode 100644 index 000000000000..4d5a274e17fa --- /dev/null +++ b/src/relax/op/tensor/set.cc @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file set.cc + * \brief Relax set operators. + */ + +#include "set.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.unique */ +TVM_REGISTER_NODE_TYPE(UniqueAttrs); + +Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool return_counts, + Optional axis) { + ObjectPtr attrs = make_object(); + attrs->sorted = sorted; + attrs->return_index = return_index; + attrs->return_inverse = return_inverse; + attrs->return_counts = return_counts; + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.unique"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique); + +StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) { + // Normalize the axis for sanity check purpose. + NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value); + } + + int n_int_return = static_cast(attrs->return_index) + + static_cast(attrs->return_inverse) + + static_cast(attrs->return_counts); + + std::vector output_sinfo; + output_sinfo.reserve(1 + n_int_return); + + // unique values + if (data_sinfo->ndim == 0) { + output_sinfo.push_back( + TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), data_sinfo->dtype)); + } else if (attrs->axis.defined()) { + output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim)); + } else { + output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1)); + } + + // index, reverse and counts + TensorStructInfo int_return{nullptr}; + if (data_sinfo->ndim == 0) { + int_return = + TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), DataType::Int(64)); + } else { + int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1); + } + for (int i = 0; i < n_int_return; ++i) { + output_sinfo.push_back(int_return); + } + + if (output_sinfo.size() == 1) { + return output_sinfo[0]; + } else { + return TupleStructInfo(output_sinfo); + } +} + +TVM_REGISTER_OP("relax.unique") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoUnique) + .set_attr("FCallPacked", "relax.run.unique"); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h new file mode 100644 index 000000000000..83d2619e4d2c --- /dev/null +++ b/src/relax/op/tensor/set.h @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. Sex The NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. Sex The License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file set.h + * \brief The functions to make Relax set operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_SET_H_ +#define TVM_RELAX_OP_TENSOR_SET_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +Expr unique(Expr x, bool sorted, bool return_index, bool return_inverse, bool return_counts, + Optional axis); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_SET_H_ diff --git a/tests/python/relax/test_op_set.py b/tests/python/relax/test_op_set.py new file mode 100644 index 000000000000..755d5e8f870c --- /dev/null +++ b/tests/python/relax/test_op_set.py @@ -0,0 +1,862 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + assert relax.op.unique(x).op == Op.get("relax.unique") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_unique_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=False, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=True, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x0, return_index=True, return_inverse=False, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=False, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=False, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=False, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=-2), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x0, sorted=True, return_index=True, return_inverse=True, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x0, sorted=True, return_index=True, return_inverse=True, return_counts=True, axis=1 + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x1, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.unique( + x1, return_index=False, return_inverse=True, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=True, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=False, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x2, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.unique( + x2, return_index=True, return_inverse=False, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=False, return_counts=False, axis=1), + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="int64", ndim=1)] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=False, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=False, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x3, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="", ndim=3), + ) + _check_inference( + bb, + relax.op.unique( + x3, return_index=False, return_inverse=False, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x3, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + + +def test_unique_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x = relax.Var("x", R.Tensor((a, b, c), "float32")) + + _check_inference( + bb, + relax.op.unique( + x, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=False, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + + +def test_unique_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.unique( + x0, return_index=False, return_inverse=False, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x1, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=False, return_counts=False, axis=1), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.unique( + x1, return_index=False, return_inverse=False, return_counts=True, axis=None + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=False, return_counts=True, axis=1), + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="int64", ndim=1)] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=False, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + + +def test_unique_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) + + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float16", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="int8", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x2, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="int32", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + relax.TensorStructInfo(dtype="int64", ndim=1), + ] + ), + ) + + +def test_unique_infer_struct_info_input_zero_rank(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(())) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + x0 = relax.Var("x", R.Tensor((), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=0)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, + relax.op.unique(x0, return_index=True, return_inverse=True, return_counts=True, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((1,), "float32"), + relax.TensorStructInfo((1,), "int64"), + relax.TensorStructInfo((1,), "int64"), + relax.TensorStructInfo((1,), "int64"), + ] + ), + ) + _check_inference( + bb, + relax.op.unique(x1, return_index=True, return_inverse=True, return_counts=False, axis=None), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((1,), "float32"), + relax.TensorStructInfo((1,), "int64"), + relax.TensorStructInfo((1,), "int64"), + ] + ), + ) + _check_inference( + bb, + relax.op.unique( + x2, return_index=True, return_inverse=False, return_counts=False, axis=None + ), + relax.TupleStructInfo( + [relax.TensorStructInfo((1,), "float32"), relax.TensorStructInfo((1,), "int64")] + ), + ) + _check_inference( + bb, + relax.op.unique( + x3, return_index=False, return_inverse=False, return_counts=False, axis=None + ), + relax.TensorStructInfo((1,), "float32"), + ) + + +def test_unique_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x0, axis=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x0, axis=-4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x1, axis=0)) + + +def test_unique_infer_struct_info_wrong_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.unique(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_set.py b/tests/python/relax/test_tvmscript_parser_op_set.py new file mode 100644 index 000000000000..8e01fa6f6215 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_set.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_unique(): + @R.function + def foo( + x: R.Tensor((2, 3, 4), dtype="float32") + ) -> R.Tuple( + R.Tensor(dtype="float32", ndim=3), + R.Tensor(dtype="int64", ndim=1), + R.Tensor(dtype="int64", ndim=1), + ): + gv: R.Tuple( + R.Tensor(dtype="float32", ndim=3), + R.Tensor(dtype="int64", ndim=1), + R.Tensor(dtype="int64", ndim=1), + ) = R.unique( + x, sorted=True, return_index=False, return_inverse=True, return_counts=True, axis=1 + ) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit( + relax.op.unique(x, sorted=True, return_inverse=True, return_counts=True, axis=1) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main()