diff --git a/include/tvm/relax/attrs/datatype.h b/include/tvm/relax/attrs/datatype.h new file mode 100644 index 000000000000..79cb345688c9 --- /dev/null +++ b/include/tvm/relax/attrs/datatype.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/datatype.h + * \brief Attributes for datatype operators. + */ +#ifndef TVM_RELAX_ATTRS_DATATYPE_H_ +#define TVM_RELAX_ATTRS_DATATYPE_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in astype operator */ +struct AstypeAttrs : public tvm::AttrsNode { + DataType dtype; + + TVM_DECLARE_ATTRS(AstypeAttrs, "relax.attrs.AstypeAttrs") { + TVM_ATTR_FIELD(dtype).describe("Target data type"); + } +}; // struct AstypeAttrs. + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_DATATYPE_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 3393a5dcae67..f3ab9085b87e 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -20,6 +20,7 @@ # Operators from .base import * from .binary import * +from .datatype import * from .index import * from .manipulate import * from .op_attrs import * diff --git a/python/tvm/relax/op/datatype.py b/python/tvm/relax/op/datatype.py new file mode 100644 index 000000000000..5c02776dd7ee --- /dev/null +++ b/python/tvm/relax/op/datatype.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Datatype operators.""" +from typing import Union + +from tvm import DataType + +from . import _ffi_api +from ..expr import Expr + + +def astype(x: Expr, dtype: Union[str, DataType]) -> Expr: + """Cast input tensor to the given data type. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + dtype: Union[str, DataType] + The target data type + + Returns + ------- + result : relax.Expr + The casted result. + """ + return _ffi_api.astype(x, dtype) # type: ignore diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 44cb2cf3a5b4..cb3336394407 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -19,6 +19,11 @@ import tvm._ffi +@tvm._ffi.register_object("relax.attrs.AstypeAttrs") +class AstypeAttrs(Attrs): + """Attributes used in astype operator""" + + @tvm._ffi.register_object("relax.attrs.TakeAttrs") class TakeAttrs(Attrs): """Attributes used in take operator""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 75a00ea04985..aaee0f4e2f89 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -31,6 +31,7 @@ from tvm.relax.op import ( add, assert_op, + astype, builtin, call_builtin_with_ctx, call_tir, @@ -403,6 +404,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "add", "arg", "assert_op", + "astype", "builtin", "call_packed", "call_tir", diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc new file mode 100644 index 000000000000..0c647aa866be --- /dev/null +++ b/src/relax/op/tensor/datatype.cc @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file datatype.cc + * \brief Datatype operators. + */ + +#include "datatype.h" + +#include + +namespace tvm { +namespace relax { + +/* relax.astype */ +TVM_REGISTER_NODE_TYPE(AstypeAttrs); + +Expr astype(Expr x, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.astype"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.astype").set_body_typed(astype); + +StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + ObjectPtr new_sinfo = make_object(*sinfo.get()); + new_sinfo->dtype = attrs->dtype; + return TensorStructInfo(new_sinfo); +} + +TVM_REGISTER_OP("relax.astype") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoAstype); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/datatype.h b/src/relax/op/tensor/datatype.h new file mode 100644 index 000000000000..6afa7a50d462 --- /dev/null +++ b/src/relax/op/tensor/datatype.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file datatype.h + * \brief The functions to make Relax datatype operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_DATATYPE_H_ +#define TVM_RELAX_OP_TENSOR_DATATYPE_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Cast input tensor to the given data type. + * \param x The input data to the operator. + * \param dtype The target data type + * \return The casted result. + */ +Expr astype(Expr x, DataType dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_DATATYPE_H_ diff --git a/tests/python/relax/test_op_datatype.py b/tests/python/relax/test_op_datatype.py new file mode 100644 index 000000000000..56bbe464cf20 --- /dev/null +++ b/tests/python/relax/test_op_datatype.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + assert relax.op.astype(x, "float16").op == Op.get("relax.astype") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_astype_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor(ndim=2)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo((2, 3), "float16")) + _check_inference( + bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo(dtype="float16", ndim=2) + ) + _check_inference(bb, relax.op.astype(x2, "float16"), relax.TensorStructInfo(dtype="float16")) + _check_inference(bb, relax.op.astype(x3, "float16"), relax.TensorStructInfo((2, 3), "float16")) + _check_inference( + bb, relax.op.astype(x4, "float16"), relax.TensorStructInfo(dtype="float16", ndim=2) + ) + _check_inference(bb, relax.op.astype(x5, "float16"), relax.TensorStructInfo(dtype="float16")) + + +def test_astype_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + + _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo((m, n), "float16")) + _check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo((m, n), "float16")) + + +def test_astype_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.astype(x0, "float16"), relax.TensorStructInfo(s0, "float16")) + _check_inference(bb, relax.op.astype(x1, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.astype(x2, "float16"), relax.TensorStructInfo(s2, "float16")) + + +def test_astype_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int32")) + + _check_inference(bb, relax.op.astype(x0, "float32"), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.astype(x1, "int32"), relax.TensorStructInfo((2, 3), "int32")) + _check_inference(bb, relax.op.astype(x2, "int8"), relax.TensorStructInfo((2, 3), "int8")) + + +def test_astype_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.astype(x0, "float16")) + with pytest.raises(TVMError): + bb.normalize(relax.op.astype(x1, "float16")) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_datatype.py b/tests/python/relax/test_tvmscript_parser_op_datatype.py new file mode 100644 index 000000000000..ec71e868d45b --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_datatype.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_astype(): + @R.function + def expected(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float16"): + gv: R.Tensor((2, 3, 4), "float16") = R.astype(x, "float16") + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("main", [x]): + gv = bb.emit(relax.op.astype(x, "float16")) + bb.emit_func_output(gv) + + _check(expected, bb.get()["main"]) + + +if __name__ == "__main__": + tvm.testing.main()