diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 181774bc53bc..887981ccffc8 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -29,6 +29,7 @@ namespace tvm { namespace script { namespace ir_builder { +namespace ir { /*! * \brief A frame that represents the IRModule frame with functions and global variables. @@ -64,6 +65,7 @@ class IRModuleFrame : public IRBuilderFrame { IRModuleFrameNode); }; +} // namespace ir } // namespace ir_builder } // namespace script } // namespace tvm diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index 0bd5473c7eaf..f0e7cc6f5c2f 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -29,6 +29,7 @@ namespace tvm { namespace script { namespace ir_builder { +namespace ir { /*! * \brief The IRModule declaration statement. @@ -36,6 +37,7 @@ namespace ir_builder { */ TVM_DLL IRModuleFrame IRModule(); +} // namespace ir } // namespace ir_builder } // namespace script } // namespace tvm diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h new file mode 100644 index 000000000000..4bfd022af27a --- /dev/null +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_ +#define TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_ + +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace tir { + +/*! + * \brief A base frame that represents the TIR fame with body of statements. + * + * \sa TIRFrame + */ +class TIRFrameNode : public IRBuilderFrameNode { + public: + /*! \brief The Stmt within in this frame. */ + Array stmts; + + void VisitAttrs(tvm::AttrVisitor* v) { + IRBuilderFrameNode::VisitAttrs(v); + v->Visit("stmts", &stmts); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.TIRFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode); +}; + +/*! + * \brief Managed reference to TIRFrameNode. + * + * \sa TIRFrameNode + */ +class TIRFrame : public IRBuilderFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, IRBuilderFrame, TIRFrameNode); + + protected: + TIRFrame() = default; +}; + +/*! + * \brief A frame that represents the PrimFunc containing TIR statements. + * + * \sa PrimFuncFrame + */ +class PrimFuncFrameNode : public TIRFrameNode { + public: + /*! \brief The name of the block. */ + Optional name; + /*! \brief Function parameters. */ + Array args; + /*! \brief The return type of the function. */ + Optional ret_type; + /*! \brief Maps some parameters to specific Buffer data structures. */ + Map buffer_map; + /*! \brief The buffer map prior to flattening. */ + Map preflattened_buffer_map; + /*! \brief Additional attributes storing the meta-data */ + Optional> attrs; + /*! \brief The variable map bound to thread env. */ + Map env_threads; + /*! \brief The buffer allocated in root block. */ + Array root_alloc_buffers; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("args", &args); + v->Visit("ret_type", &ret_type); + v->Visit("buffer_map", &buffer_map); + v->Visit("preflattened_buffer_map", &preflattened_buffer_map); + v->Visit("attrs", &attrs); + v->Visit("env_threads", &env_threads); + v->Visit("root_alloc_buffers", &root_alloc_buffers); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.PrimFuncFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to PrimFuncFrameNode. + * + * \sa PrimFuncFrameNode + */ +class PrimFuncFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); +}; + +/*! + * \brief A frame that represents the assert statement. Proceeds if the condition is true, + * otherwise aborts with the message. + * + * \sa AssertFrame + */ +class AssertFrameNode : public TIRFrameNode { + public: + /*! \brief The PrimExpr to test. */ + PrimExpr condition; + /*! \brief The output error message when the assertion failed. */ + PrimExpr message; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + v->Visit("message", &message); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.AssertFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +} // namespace tir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_ diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h new file mode 100644 index 000000000000..cee60ad4f827 --- /dev/null +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_TIR_IR_H_ +#define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_ + +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace tir { + +/*! + * \brief The primitive function statement. + * \return The PrimFuncFrame. + */ +PrimFuncFrame PrimFunc(); + +/*! + * \brief Evaluate the input expression. + * \param value The input expression to evaluate. + */ +void Evaluate(PrimExpr value); + +} // namespace tir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_TIR_IR_H_ diff --git a/python/tvm/script/ir_builder/tir/__init__.py b/python/tvm/script/ir_builder/tir/__init__.py new file mode 100644 index 000000000000..1e43d1af3498 --- /dev/null +++ b/python/tvm/script/ir_builder/tir/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Package tvm.script.ir_builder.tir""" +from .ir import * # pylint: disable=wildcard-import,redefined-builtin diff --git a/python/tvm/script/ir_builder/tir/_ffi_api.py b/python/tvm/script/ir_builder/tir/_ffi_api.py new file mode 100644 index 000000000000..876f5f3a35a0 --- /dev/null +++ b/python/tvm/script/ir_builder/tir/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs""" +import tvm._ffi + +tvm._ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py new file mode 100644 index 000000000000..61418e0b2aa6 --- /dev/null +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""IRBuilder for TIR""" + +from tvm._ffi import register_object as _register_object + +from ..base import IRBuilderFrame + + +@_register_object("script.ir_builder.tir.TIRFrame") +class TIRFrame(IRBuilderFrame): + ... + + +@_register_object("script.ir_builder.tir.PrimFuncFrame") +class PrimFuncFrame(TIRFrame): + ... diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py new file mode 100644 index 000000000000..ae5d5b260f65 --- /dev/null +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +"""IRBuilder for TIR""" + +from tvm.tir import PrimExpr, StringImm + +from . import _ffi_api, frame + + +def prim_func() -> frame.PrimFuncFrame: + """The primitive function statement. + + Returns + ------- + res : frame.PrimFuncFrame + The PrimFuncFrame. + """ + return _ffi_api.PrimFunc() # pylint: disable=no-member # type: ignore + + +def evaluate(value: PrimExpr) -> None: + """Evaluate the input expression. + + Parameters + ---------- + value: PrimExpr + The input expression to evaluate. + """ + if isinstance(value, str): + value = StringImm(value) + return _ffi_api.Evaluate(value) # pylint: disable=no-member # type: ignore + + +# pylint: enable=invalid-name + + +__all__ = [ + "evaluate", + "prim_func", +] diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index c85e30544aca..a81c56922dff 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -23,6 +23,7 @@ namespace tvm { namespace script { namespace ir_builder { +namespace ir { void IRModuleFrameNode::ExitWithScope() { ICHECK_EQ(functions.size(), global_vars.size()); @@ -38,6 +39,7 @@ void IRModuleFrameNode::ExitWithScope() { TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); +} // namespace ir } // namespace ir_builder } // namespace script } // namespace tvm diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index bcd21de144bb..a8cc452e4f0c 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -23,6 +23,7 @@ namespace tvm { namespace script { namespace ir_builder { +namespace ir { IRModuleFrame IRModule() { ObjectPtr n = make_object(); @@ -33,6 +34,7 @@ IRModuleFrame IRModule() { TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +} // namespace ir } // namespace ir_builder } // namespace script } // namespace tvm diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc new file mode 100644 index 000000000000..139c8193b0ba --- /dev/null +++ b/src/script/ir_builder/tir/frame.cc @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../../../tir/ir/script/script_complete.h" +#include "./utils.h" + +namespace tvm { +namespace script { +namespace ir_builder { +namespace tir { + +void PrimFuncFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + tvm::tir::PrimFunc func( + /*params=*/args, + /*body=*/AsStmt(stmts), + /*ret_type=*/ret_type.value_or(TupleType::Empty()), + /*buffer_map=*/buffer_map, + /*preflattened_buffer_map=*/preflattened_buffer_map, + /*attrs=*/attrs.defined() ? DictAttrs(attrs.value()) : NullValue()); + func = tvm::tir::ScriptComplete(func, root_alloc_buffers); + IRBuilder builder = IRBuilder::Current(); + if (builder->frames.empty()) { + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = func; + } else if (Optional opt_frame = builder->FindFrame()) { + ir::IRModuleFrame frame = opt_frame.value(); + frame->global_vars.push_back(GlobalVar(name.value_or(""))); + frame->functions.push_back(func); + } else { + LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; + } +} + +TVM_REGISTER_NODE_TYPE(TIRFrameNode); +TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); + +} // namespace tir +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc new file mode 100644 index 000000000000..5f994d71ca0a --- /dev/null +++ b/src/script/ir_builder/tir/ir.cc @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace ir_builder { +namespace tir { + +using tvm::tir::IterVar; + +PrimFuncFrame PrimFunc() { + ObjectPtr n = make_object(); + n->name = NullOpt; + n->args.clear(); + n->ret_type = NullOpt; + n->buffer_map.clear(); + n->preflattened_buffer_map.clear(); + n->attrs = NullOpt; + n->env_threads.clear(); + n->root_alloc_buffers.clear(); + return PrimFuncFrame(n); +} + +void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } +TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); +} // namespace tir +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h new file mode 100644 index 000000000000..47557917cca5 --- /dev/null +++ b/src/script/ir_builder/tir/utils.h @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_ + +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace tir { + +inline void AddToParent(tvm::tir::Stmt stmt) { + IRBuilder builder = IRBuilder::Current(); + if (builder->frames.empty()) { + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = stmt; + } else if (const auto* tir_frame = builder->frames.back().as()) { + GetRef(tir_frame)->stmts.push_back(stmt); + } else { + LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back(); + } +} + +inline tvm::tir::Stmt AsStmt(const Array& stmt) { + using namespace tvm::tir; + if (stmt.empty()) { + return tvm::tir::Evaluate(0); + } else if (stmt.size() == 1) { + return stmt[0]; + } else { + return SeqStmt(stmt); + } +} + +inline PrimFuncFrame FindPrimFuncFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + return frame.value(); + } + LOG(FATAL) << "ValueError: PrimFunc frame not find. Please ensure '" << method + << "' is called under T.prim_func()"; + throw; +} + +} // namespace tir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_ diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py new file mode 100644 index 000000000000..70a8f3565d03 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, missing-docstring +"""Unittests for tvm.script.ir_builder.tir""" +import pytest +import tvm.testing +import tvm +from tvm import tir +from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import IRBuilder +from tvm.ir.base import assert_structural_equal + + +def test_ir_builder_tir_primfunc(): + with IRBuilder() as ib: + with T.prim_func(): + T.evaluate(0) + # the prim_func generated by IRBuilder + prim_func_actual = ib.get() + + # the expected prim_func + prim_func_expected = tir.PrimFunc( + params=[], + body=tir.Evaluate(0), + ret_type=None, + buffer_map=None, + preflattened_buffer_map=None, + attrs=None, + ) + # Check if the generated ir is expected + assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True) + + +if __name__ == "__main__": + tvm.testing.main()