diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 4b2f990eaa27..39a645ffea54 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -27,6 +27,7 @@ from .manipulate import * from .op_attrs import * from .statistical import * +from .search import * from .set import * from .ternary import * from .unary import * diff --git a/python/tvm/relax/op/search.py b/python/tvm/relax/op/search.py new file mode 100644 index 000000000000..8252b0e1d851 --- /dev/null +++ b/python/tvm/relax/op/search.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Search operators.""" +from . import _ffi_api +from ..expr import Expr + + +def where(condition: Expr, x1: Expr, x2: Expr) -> Expr: + """Selecting elements from either the input tensors depending on the value of the + condition. + + For a given position, return the corresponding value in `x1` if `condition` is True, + and return the corresponding value in `x2` otherwise. + + Parameters + ---------- + condition : relax.Expr + When True, yield `x1`; otherwise, yield `x2`. + Must be broadcasting compatible with `x1` and `x2`. + Must have boolean dtype. + + x1 : relax.Expr + The first input tensor. + Must be broadcasting compatible with `condition` and `x2`. + + x2 : relax.Expr + The second input tensor. + Must be broadcasting compatible with `condition` and `x1`. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.where(condition, x1, x2) # type: ignore diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 9f5fe03decfb..b779bdac9c13 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -101,6 +101,7 @@ tril, triu, unique, + where, zeros, zeros_like, nn, @@ -547,8 +548,9 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "tril", "triu", "tuple", - "variance", "unique", + "variance", + "where", "zeros", "zeros_like", "nn", diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc new file mode 100644 index 000000000000..5191017ea17f --- /dev/null +++ b/src/relax/op/tensor/search.cc @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file search.cc + * \brief Searching operators. + */ + +#include "search.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.where */ +Expr where(Expr condition, Expr x1, Expr x2) { + static const Op& op = Op::Get("relax.where"); + return Call(op, {std::move(condition), std::move(x1), std::move(x2)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.where").set_body_typed(where); + +StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo cond_sinfo = input_sinfo[0]; + TensorStructInfo x1_sinfo = input_sinfo[1]; + TensorStructInfo x2_sinfo = input_sinfo[2]; + + if (!cond_sinfo->dtype.is_bool()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Where requires the input condition tensor to have boolean dtype. However, " + "the given condition dtype is " + << cond_sinfo->dtype); + } + DataType output_dtype = InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo); + + int output_ndim; + if (cond_sinfo->IsUnknownNdim() || x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) { + output_ndim = kUnknownNDim; + } else { + output_ndim = std::max(cond_sinfo->ndim, std::max(x1_sinfo->ndim, x2_sinfo->ndim)); + } + + const auto* cond_shape = cond_sinfo->shape.as(); + const auto* x1_shape = x1_sinfo->shape.as(); + const auto* x2_shape = x2_sinfo->shape.as(); + if (cond_shape && x1_shape && x2_shape) { + // Step 1. Compute the broadcasted shape of x1's and x2's + Optional> broadcasted_shape = + InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values); + if (!broadcasted_shape.defined()) { + return TensorStructInfo(output_dtype, output_ndim); + } + // Step 2. Compute the broadcasted shape of cond's and the previous broadcasted shape. + broadcasted_shape = + InferBinaryBroadcastShape(call, ctx, cond_shape->values, broadcasted_shape.value()); + if (!broadcasted_shape.defined()) { + return TensorStructInfo(output_dtype, output_ndim); + } + ICHECK_EQ(static_cast(broadcasted_shape.value().size()), output_ndim); + return TensorStructInfo(ShapeExpr(broadcasted_shape.value()), output_dtype); + } else if (cond_sinfo->shape.defined() && // + x1_sinfo->shape.defined() && // + x2_sinfo->shape.defined() && // + cond_sinfo->shape.same_as(x1_sinfo->shape) && // + cond_sinfo->shape.same_as(x2_sinfo->shape)) { + return TensorStructInfo(cond_sinfo->shape.value(), output_dtype); + } else { + return TensorStructInfo(output_dtype, output_ndim); + } +} + +TVM_REGISTER_OP("relax.where") + .set_num_inputs(3) + .add_argument("condition", "Tensor", "When True, yield `x1`; otherwise, yield `x2`.") + .add_argument("x1", "Tensor", "The first input tensor.") + .add_argument("x2", "Tensor", "The second input tensor.") + .set_attr("FInferStructInfo", InferStructInfoWhere); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/search.h b/src/relax/op/tensor/search.h new file mode 100644 index 000000000000..aeae4a7157b3 --- /dev/null +++ b/src/relax/op/tensor/search.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file search.h + * \brief The functions to make Relax searching operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_SEARCH_H_ +#define TVM_RELAX_OP_TENSOR_SEARCH_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Selecting elements from either the input tensors depending on the value of the + * condition. + */ +Expr where(Expr condition, Expr x1, Expr x2); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_SEARCH_H_ diff --git a/tests/python/relax/test_op_search.py b/tests/python/relax/test_op_search.py new file mode 100644 index 000000000000..a2f271671ba6 --- /dev/null +++ b/tests/python/relax/test_op_search.py @@ -0,0 +1,278 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + cond = relax.Var("cond", R.Tensor((2, 3), "bool")) + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("x", R.Tensor((2, 3), "float32")) + assert relax.op.where(cond, x, y).op == Op.get("relax.where") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_where_infer_struct_info(): + bb = relax.BlockBuilder() + cond0 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool")) + cond1 = relax.Var("cond", R.Tensor("bool", ndim=5)) + cond2 = relax.Var("cond", R.Tensor("bool")) + x0 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((5, 1, 3, 2))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((4, 3, 1), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=3)) + y2 = relax.Var("y", R.Tensor("float32")) + y3 = relax.Var("y", R.Tensor((4, 3, 1))) + y4 = relax.Var("y", R.Tensor(ndim=3)) + y5 = relax.Var("y", R.Tensor()) + + _check_inference( + bb, relax.op.where(cond0, x0, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), "float32") + ) + _check_inference( + bb, relax.op.where(cond0, x1, y0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond0, x2, y0), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.where(cond0, x3, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), dtype="") + ) + _check_inference(bb, relax.op.where(cond0, x4, y0), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y0), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.where(cond0, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond0, x2, y1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x3, y1), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x4, y1), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x2, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x3, y2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x4, y2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x5, y2), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.where(cond0, x3, y3), relax.TensorStructInfo((6, 5, 4, 3, 2), dtype="") + ) + _check_inference(bb, relax.op.where(cond0, x4, y3), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y3), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x4, y4), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.where(cond0, x5, y4), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.where(cond0, x5, y5), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.where(cond1, x0, y0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond1, x2, y0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond2, x0, y0), relax.TensorStructInfo(dtype="float32")) + + +def test_where_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d0 = tir.Var("d", "int64") + d1 = tir.Var("d", "int64") + e = tir.Var("e", "int64") + cond = relax.Var("cond", R.Tensor((a, b, 1, d0, 1), "bool")) + x0 = relax.Var("x", R.Tensor((b, 1, d0, e), "float32")) + x1 = relax.Var("x", R.Tensor((b, 1, d1, e), "float32")) + x2 = relax.Var("x", R.Tensor((b, 1, d0, e))) + y0 = relax.Var("y", R.Tensor((c, d0, 1), "float32")) + y1 = relax.Var("y", R.Tensor((c, d0, 1))) + + _check_inference( + bb, relax.op.where(cond, x0, y0), relax.TensorStructInfo((a, b, c, d0, e), "float32") + ) + _check_inference( + bb, relax.op.where(cond, x1, y0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.where(cond, x2, y0), relax.TensorStructInfo((a, b, c, d0, e), dtype="") + ) + _check_inference( + bb, relax.op.where(cond, x0, y1), relax.TensorStructInfo((a, b, c, d0, e), dtype="") + ) + _check_inference(bb, relax.op.where(cond, x1, y1), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference( + bb, relax.op.where(cond, x2, y1), relax.TensorStructInfo((a, b, c, d0, e), dtype="") + ) + + +def test_where_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + scond0 = relax.Var("scond", relax.ShapeStructInfo((6, 5, 1, 3, 1))) + scond1 = relax.Var("scond", relax.ShapeStructInfo(ndim=5)) + scond2 = relax.Var("scond", relax.ShapeStructInfo()) + sx0 = relax.Var("sx", relax.ShapeStructInfo((5, 1, 3, 2))) + sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=4)) + sx2 = relax.Var("sx", relax.ShapeStructInfo()) + sy0 = relax.Var("sy", relax.ShapeStructInfo((4, 3, 1))) + sy1 = relax.Var("sy", relax.ShapeStructInfo(ndim=3)) + sy2 = relax.Var("sy", relax.ShapeStructInfo()) + s0 = relax.Var("s", relax.ShapeStructInfo((6, 5, 4, 3, 2))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + cond0 = relax.Var("cond", relax.TensorStructInfo(scond0, "bool")) + cond1 = relax.Var("cond", relax.TensorStructInfo(scond1, "bool")) + cond2 = relax.Var("cond", relax.TensorStructInfo(scond2, "bool")) + cond3 = relax.Var("cond", relax.TensorStructInfo(s0, "bool")) + cond4 = relax.Var("cond", relax.TensorStructInfo(s1, "bool")) + cond5 = relax.Var("cond", relax.TensorStructInfo(s2, "bool")) + x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(sy0, "float32")) + y1 = relax.Var("y", relax.TensorStructInfo(sy1, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(sy2, "float32")) + y3 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) + y4 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + y5 = relax.Var("y", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.where(cond0, x0, y0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.where(cond0, x0, y1), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond0, x0, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.where(cond0, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond0, x1, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond0, x2, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.where(cond1, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond1, x1, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond1, x2, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond2, x2, y2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond3, x3, y3), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, relax.op.where(cond3, x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.where(cond3, x4, y3), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.where(cond4, x3, y3), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.where(cond4, x4, y4), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.where(cond4, x4, y5), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond4, x5, y4), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond5, x4, y4), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.where(cond5, x5, y5), relax.TensorStructInfo(s2, "float32")) + + +def test_where_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + cond = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool")) + x0 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float16")) + y0 = relax.Var("y", R.Tensor((4, 3, 1), "float16")) + x1 = relax.Var("x", R.Tensor((5, 1, 3, 2), "int8")) + y1 = relax.Var("y", R.Tensor((4, 3, 1), "int8")) + x2 = relax.Var("x", R.Tensor((5, 1, 3, 2), "int32")) + y2 = relax.Var("y", R.Tensor((4, 3, 1), "int32")) + + _check_inference( + bb, relax.op.where(cond, x0, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), "float16") + ) + _check_inference( + bb, relax.op.where(cond, x1, y1), relax.TensorStructInfo((6, 5, 4, 3, 2), "int8") + ) + _check_inference( + bb, relax.op.where(cond, x2, y2), relax.TensorStructInfo((6, 5, 4, 3, 2), "int32") + ) + + +def test_where_infer_struct_info_cond_not_boolean(): + bb = relax.BlockBuilder() + cond0 = relax.Var("cond", R.Tensor((2, 3), "float32")) + cond1 = relax.Var("cond", R.Tensor((2, 3))) + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond0, x, y)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x, y)) + + +def test_where_infer_struct_info_shape_unequal_const_int(): + bb = relax.BlockBuilder() + cond0 = relax.Var("cond", R.Tensor((6, 5, 1, 4, 1), "bool")) + cond1 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool")) + x0 = relax.Var("x", R.Tensor((5, 1, 4, 2), "float32")) + x1 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float32")) + y0 = relax.Var("y", R.Tensor((4, 4, 1), "float32")) + y1 = relax.Var("y", R.Tensor((4, 3, 1), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond0, x1, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x0, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x1, y0)) + + +def test_where_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + cond = relax.Var("cond", R.Tensor((2, 3), "bool")) + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond, x0, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond, x1, y1)) + + +def test_where_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + cond0 = relax.Var("cond", relax.ShapeStructInfo((2, 3))) + cond1 = relax.Var("cond", R.Tensor((2, 3), "bool")) + x0 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + x1 = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", relax.TupleStructInfo([R.Tensor((2, 3), "float32")])) + y1 = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond0, x1, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x0, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.where(cond1, x1, y0)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_search.py b/tests/python/relax/test_tvmscript_parser_op_search.py new file mode 100644 index 000000000000..a8eaa814aa2e --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_search.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_where(): + @R.function + def foo( + condition: R.Tensor((2, 1), "bool"), + x: R.Tensor((2, 3), "float32"), + y: R.Tensor((1, 3), "float32"), + ) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.where(condition, x, y) + return gv + + bb = relax.BlockBuilder() + condition = relax.Var("condition", R.Tensor((2, 1), "bool")) + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((1, 3), "float32")) + with bb.function("foo", [condition, x, y]): + gv = bb.emit(relax.op.where(condition, x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main()