diff --git a/include/tvm/relax/attrs/statistical.h b/include/tvm/relax/attrs/statistical.h new file mode 100644 index 000000000000..bb1ab2195d9a --- /dev/null +++ b/include/tvm/relax/attrs/statistical.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/statistical.h + * \brief Attributes for statistical operators. + */ +#ifndef TVM_RELAX_ATTRS_STATISTICAL_H_ +#define TVM_RELAX_ATTRS_STATISTICAL_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes for statistical operators */ +struct StatisticalAttrs : public tvm::AttrsNode { + Optional> axis; + bool keepdims; + + TVM_DECLARE_ATTRS(StatisticalAttrs, "relax.attrs.StatisticalAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis or axes along which to perform the reduction."); + TVM_ATTR_FIELD(keepdims).describe( + "If this is set to `True`, the reduced axes are left in the result as dimension with size " + "one."); + } +}; // struct StatisticalAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_STATISTICAL_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 344576fe13b2..68152c2056e1 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -24,6 +24,7 @@ from .index import * from .manipulate import * from .op_attrs import * +from .statistical import * from .set import * from .ternary import * from .unary import * diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index fb64443b7e09..1fb8853040fd 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -34,6 +34,11 @@ class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" +@tvm._ffi.register_object("relax.attrs.StatisticalAttrs") +class StatisticalAttrs(Attrs): + """Attributes used in statistical operator""" + + @tvm._ffi.register_object("relax.attrs.Resize2DAttrs") class Resize2DAttrs(Attrs): """Attributes used in image resize2d operator""" diff --git a/python/tvm/relax/op/statistical.py b/python/tvm/relax/op/statistical.py new file mode 100644 index 000000000000..4669c783adda --- /dev/null +++ b/python/tvm/relax/op/statistical.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin +"""Statistical operators.""" +from typing import List, Optional, Union + +from . import _ffi_api +from ..expr import Expr + + +def max(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the max of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a max operation is performed. + The default, axis=None, will compute the max of all elements in the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.max(x, axis, keepdims) # type: ignore + + +def mean(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the mean of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a mean operation is performed. + The default, axis=None, will compute the mean of all elements in the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.mean(x, axis, keepdims) # type: ignore + + +def min(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the min of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a min operation is performed. + The default, axis=None, will compute the min of all elements in the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.min(x, axis, keepdims) # type: ignore + + +def prod(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the product of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a product is performed. + The default, axis=None, will compute the product of all elements of the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.prod(x, axis, keepdims) # type: ignore + + +def std(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the standard deviation of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a standard deviation is performed. + The default, axis=None, will compute the std of all elements of the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.std(x, axis, keepdims) # type: ignore + + +def sum(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the sum of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a sum is performed. + The default, axis=None, will sum all of the elements of the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.sum(x, axis, keepdims) # type: ignore + + +def variance(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr: + """Computes the variance of tensor elements over given axes. + + Parameters + ---------- + x : relax.Expr + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a variance operation is performed. + The default, axis=None, will compute the variance of all elements in the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.variance(x, axis, keepdims) # type: ignore diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index a5cb574a06f0..47779a602452 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -63,15 +63,24 @@ less_equal, log, make_closure, + max, + mean, memory, + min, multiply, negative, not_equal, null_value, print, + prod, reshape, round, shape_of, + std, + strided_slice, + sum, + take, + variance, sigmoid, sign, sin, @@ -486,7 +495,10 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "less_equal", "log", "make_closure", + "max", + "mean", "memory", + "min", "multiply", "negative", "not_equal", @@ -494,10 +506,15 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "output", "prim_value", "print", + "prod", "reshape", "round", "shape", "shape_of", + "std", + "str", + "strided_slice", + "sum", "sigmoid", "sign", "sin", @@ -511,5 +528,6 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "tan", "tanh", "tuple", + "variance", "unique", ] diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc new file mode 100644 index 000000000000..41b99fbe36c1 --- /dev/null +++ b/src/relax/op/tensor/statistical.cc @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file statistical.cc + * \brief Statistical operators. + */ + +#include "statistical.h" + +#include + +namespace tvm { +namespace relax { + +StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + + std::vector axes; + if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) { + axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value()); + } + + int out_ndim; + if (attrs->keepdims) { + out_ndim = data_sinfo->ndim; + } else if (!attrs->axis.defined()) { + out_ndim = 0; + } else if (data_sinfo->IsUnknownNdim()) { + out_ndim = kUnknownNDim; + } else { + out_ndim = data_sinfo->ndim - axes.size(); + ICHECK_GE(out_ndim, 0); + } + + // The inference rule for reduction operator output shapes: + // - axes is None, keepdims is false -> return the zero-rank shape; + // - axes is None, keepdims is true -> return the shape whose ndim is the same as input and every + // value is 1. + // - axes is not None, keepdims is false -> the returned shape does not contain the input axes. + // - axes is not None, keepdims is true -> the returned shape has value 1 at the positions of the + // input axes + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) { + return TensorStructInfo( + ShapeExpr(Array(out_ndim, IntImm(DataType::Int(64), /*value=*/1))), + data_sinfo->dtype); + } else { + return out_ndim == 0 ? TensorStructInfo(ShapeExpr(Array()), data_sinfo->dtype) + : TensorStructInfo(data_sinfo->dtype, out_ndim); + } + } + + Array out_shape; + out_shape.reserve(out_ndim); + for (int i = 0; i < data_sinfo->ndim; ++i) { + if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) { + out_shape.push_back(data_shape->values[i]); + } else if (attrs->keepdims) { + out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1)); + } + } + ICHECK_EQ(static_cast(out_shape.size()), out_ndim); + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype); +} + +TVM_REGISTER_NODE_TYPE(StatisticalAttrs); + +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(max); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(mean); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(min); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(prod); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(std); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(sum); +RELAX_REGISTER_STATISTICAL_OP_INTERFACE(variance); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h new file mode 100644 index 000000000000..7d322d11293c --- /dev/null +++ b/src/relax/op/tensor/statistical.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file statistical.h + * \brief The functions to make Relax statistical operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_STATISTICAL_H_ +#define TVM_RELAX_OP_TENSOR_STATISTICAL_H_ + +#include + +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Quick helper macro + * - Expose a make function to construct the node. + * - Register op to the registry. + * \param OpName The name of operator to register. The name passed in will + * 1. be prepended with a prefix "relax.op." as the FFI identifier string for the make function, + * 2. be prepended with a prefix "relax." as the identifier string in the operator registry. + */ +#define RELAX_REGISTER_STATISTICAL_OP_INTERFACE(OpName) \ + Expr OpName(Expr x, Optional> axis, bool keepdims) { \ + ObjectPtr attrs = make_object(); \ + attrs->axis = std::move(axis); \ + attrs->keepdims = keepdims; \ + static const Op& op = Op::Get("relax." #OpName); \ + return Call(op, {std::move(x)}, Attrs{attrs}, {}); \ + } \ + TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_REGISTER_OP("relax." #OpName) \ + .set_num_inputs(1) \ + .add_argument("x", "Tensor", "The input data tensor") \ + .set_attr("FInferStructInfo", InferStructInfoStatistical) + +/*! + * \brief Computes the maximum value of tensor elements over given axes. + * \param x The input data tensor + * \param axis Axis or axes along which a max is performed. Being `NullOpt` means to max all the + * elements of the input tensor + * \param keepdims If this is set to True, the axes which are reduced are left in the result as + * dimensions with size one. With this option, the result will broadcast correctly against the + * input tensor. + * \return The result after reduction. + */ +Expr max(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the mean of tensor elements over given axes. */ +Expr mean(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the min of tensor elements over given axes. */ +Expr min(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the product of tensor elements over given axes. */ +Expr prod(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the standard deviation of tensor elements over given axes. */ +Expr std(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the sum of tensor elements over given axes. */ +Expr sum(Expr x, Optional> axis, bool keepdims); + +/*! \brief Computes the variance of tensor elements over given axes. */ +Expr variance(Expr x, Optional> axis, bool keepdims); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_STATISTICAL_H_ diff --git a/tests/python/relax/test_op_statistical.py b/tests/python/relax/test_op_statistical.py new file mode 100644 index 000000000000..b1bdd8e44d85 --- /dev/null +++ b/tests/python/relax/test_op_statistical.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + assert relax.op.max(x).op == Op.get("relax.max") + assert relax.op.mean(x).op == Op.get("relax.mean") + assert relax.op.min(x).op == Op.get("relax.min") + assert relax.op.prod(x).op == Op.get("relax.prod") + assert relax.op.std(x).op == Op.get("relax.std") + assert relax.op.sum(x).op == Op.get("relax.sum") + assert relax.op.variance(x).op == Op.get("relax.variance") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_statistical_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + + _check_inference(bb, relax.op.sum(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32")) + _check_inference( + bb, + relax.op.sum(x0, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((2, 1, 1, 5), "float32"), + ) + _check_inference(bb, relax.op.sum(x0, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference( + bb, + relax.op.sum(x0, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), "float32"), + ) + _check_inference( + bb, relax.op.mean(x1, axis=[1, 2]), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, + relax.op.mean(x1, axis=[1, 2], keepdims=True), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference(bb, relax.op.mean(x1, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference( + bb, + relax.op.mean(x1, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), "float32"), + ) + _check_inference( + bb, relax.op.variance(x2, axis=[1, 2]), relax.TensorStructInfo(dtype="float32") + ) + _check_inference( + bb, + relax.op.variance(x2, axis=[1, 2], keepdims=True), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference(bb, relax.op.variance(x2, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference( + bb, + relax.op.variance(x2, axis=None, keepdims=True), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference(bb, relax.op.max(x3, axis=[1, 2]), relax.TensorStructInfo((2, 5), dtype="")) + _check_inference( + bb, + relax.op.max(x3, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((2, 1, 1, 5), dtype=""), + ) + _check_inference(bb, relax.op.max(x3, axis=None), relax.TensorStructInfo((), dtype="")) + _check_inference( + bb, + relax.op.max(x3, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), dtype=""), + ) + _check_inference(bb, relax.op.prod(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32")) + _check_inference( + bb, + relax.op.prod(x0, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((2, 1, 1, 5), "float32"), + ) + _check_inference(bb, relax.op.std(x0, axis=[1, 2]), relax.TensorStructInfo((2, 5), "float32")) + _check_inference( + bb, + relax.op.std(x0, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((2, 1, 1, 5), "float32"), + ) + _check_inference(bb, relax.op.sum(x0, axis=[-1, -4]), relax.TensorStructInfo((3, 4), "float32")) + _check_inference(bb, relax.op.sum(x0, axis=[]), relax.TensorStructInfo((2, 3, 4, 5), "float32")) + + +def test_statistical_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) + + _check_inference(bb, relax.op.min(x, axis=[1, 2]), relax.TensorStructInfo((a, d), "float32")) + _check_inference( + bb, + relax.op.min(x, axis=[1, 2], keepdims=True), + relax.TensorStructInfo((a, 1, 1, d), "float32"), + ) + _check_inference(bb, relax.op.min(x, axis=None), relax.TensorStructInfo((), "float32")) + _check_inference( + bb, + relax.op.min(x, axis=None, keepdims=True), + relax.TensorStructInfo((1, 1, 1, 1), "float32"), + ) + + +def test_statistical_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, relax.op.max(x0), relax.TensorStructInfo((), dtype="float32")) + _check_inference( + bb, relax.op.max(x0, keepdims=True), relax.TensorStructInfo((1, 1, 1, 1), dtype="float32") + ) + _check_inference( + bb, relax.op.max(x0, axis=[2, 3]), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, + relax.op.max(x0, axis=[2, 3], keepdims=True), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference(bb, relax.op.max(x1), relax.TensorStructInfo((), dtype="float32")) + _check_inference(bb, relax.op.max(x1, keepdims=True), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.max(x1, axis=[2, 3]), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.max(x1, axis=[2, 3], keepdims=True), relax.TensorStructInfo(dtype="float32") + ) + + +def test_statistical_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + + _check_inference(bb, relax.op.sum(x0), relax.TensorStructInfo((), "float16")) + _check_inference(bb, relax.op.sum(x1), relax.TensorStructInfo((), "int8")) + + +def test_statistical_infer_struct_info_axis_out_of_range_repetitive(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x0, axis=[4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x1, axis=[3, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x0, axis=[-1, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x1, axis=[-4, -4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.mean(x0, axis=[-5])) + + +def test_statistical_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.variance(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.variance(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_statistical.py b/tests/python/relax/test_tvmscript_parser_op_statistical.py new file mode 100644 index 000000000000..221d2a17a8b8 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_statistical.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_sum(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3), "float32"): + gv: R.Tensor((1, 3), "float32") = R.sum(x, axis=[1, 3]) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.sum(x, axis=[1, 3])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_sum_without_specified_axis(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((), "float32"): + gv: R.Tensor((), "float32") = R.sum(x) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.sum(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_sum_keep_dims(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 1, 3, 1), "float32"): + gv: R.Tensor((1, 1, 3, 1), "float32") = R.sum(x, axis=[1, 3], keepdims=True) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.sum(x, axis=[1, 3], keepdims=True)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_mean(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3), "float32"): + gv: R.Tensor((1, 3), "float32") = R.mean(x, axis=[1, 3]) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.mean(x, axis=[1, 3])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_variance(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1,), "float32"): + gv: R.Tensor((1,), "float32") = R.variance(x, axis=[-1, -2, -3]) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.variance(x, axis=[-1, -2, -3])) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_max(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 1, 1, 1), "float32"): + gv: R.Tensor((1, 1, 1, 1), "float32") = R.variance(x, axis=[-1, -2, -3], keepdims=True) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.variance(x, axis=[-1, -2, -3], keepdims=True)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_min(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3, 4), "float32"): + gv: R.Tensor((1, 3, 4), "float32") = R.min(x, axis=1) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.min(x, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_prod(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3, 4), "float32"): + gv: R.Tensor((1, 3, 4), "float32") = R.prod(x, axis=1) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.prod(x, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_std(): + @R.function + def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3, 4), "float32"): + gv: R.Tensor((1, 3, 4), "float32") = R.std(x, axis=1) + return gv + + x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.std(x, axis=1)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main()