diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index 15ab77863e5e..2902b982d5a6 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -187,6 +187,57 @@ class BlockFrame : public TIRFrame {
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
 };
 
+/*!
+ * \brief A frame that represents the for loop.
+ *
+ * \sa ForFrame
+ */
+class ForFrameNode : public TIRFrameNode {
+ public:
+  /*!
+   * \brief Functions that generate loop nests.
+   * \param loop_vars The loop variables, from outer to inner
+   * \param loop_extents The loop domains (min and extent) that correspond to loop variables
+   * \param loop_body The loop body
+   * \return A stmt, the loop nest
+   */
+  using FMakeForLoop = runtime::TypedPackedFunc<tvm::tir::Stmt(
+      Array<tvm::tir::Var> loop_vars, Array<Range> loop_extents, tvm::tir::Stmt loop_body)>;
+  /*! \brief The loop variables, from outer to inner. */
+  Array<tvm::tir::Var> vars;
+  /*! \brief The domains of iteration, one per loop variable. */
+  Array<Range> doms;
+  /*! \brief The for loop generating function, called with `vars`, `doms` and the loop body. */
+  FMakeForLoop f_make_for_loop;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    TIRFrameNode::VisitAttrs(v);
+    v->Visit("vars", &vars);
+    v->Visit("doms", &doms);
+    // `f_make_for_loop` is not visited.
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.ForFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief The method called when exiting RAII scope.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to ForFrameNode.
+ *
+ * \sa ForFrameNode
+ */
+class ForFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode);
+};
+
 /*!
  * \brief A frame that represents the assert statement. Proceeds if the condition is true,
  * otherwise aborts with the message.
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index aaa5442eede3..68948196ff6b 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -141,6 +141,59 @@ void PreflattenedBuffer(Buffer postflattened_buffer, Array shape,
  */
 BlockFrame Block(String name, bool no_realize = false);
 
+/*!
+ * \brief The serial For statement.
+ * \param start The minimum value of iteration.
+ * \param stop The exclusive upper bound of iteration; the extent is `stop - start`.
+ * \param annotations The optional annotations of the For statement.
+ * \return The ForFrame.
+ */
+ForFrame Serial(PrimExpr start, PrimExpr stop,
+                Optional<Map<String, ObjectRef>> annotations = NullOpt);
+/*!
+ * \brief The parallel For statement.
+ * \param start The minimum value of iteration.
+ * \param stop The exclusive upper bound of iteration; the extent is `stop - start`.
+ * \param annotations The optional annotations of the For statement.
+ * \return The ForFrame.
+ */
+ForFrame Parallel(PrimExpr start, PrimExpr stop,
+                  Optional<Map<String, ObjectRef>> annotations = NullOpt);
+/*!
+ * \brief The vectorized For statement.
+ * \param start The minimum value of iteration.
+ * \param stop The exclusive upper bound of iteration; the extent is `stop - start`.
+ * \param annotations The optional annotations of the For statement.
+ * \return The ForFrame.
+ */
+ForFrame Vectorized(PrimExpr start, PrimExpr stop,
+                    Optional<Map<String, ObjectRef>> annotations = NullOpt);
+/*!
+ * \brief The unrolled For statement.
+ * \param start The minimum value of iteration.
+ * \param stop The exclusive upper bound of iteration; the extent is `stop - start`.
+ * \param annotations The optional annotations of the For statement.
+ * \return The ForFrame.
+ */
+ForFrame Unroll(PrimExpr start, PrimExpr stop,
+                Optional<Map<String, ObjectRef>> annotations = NullOpt);
+/*!
+ * \brief The thread-binding For statement.
+ * \param start The minimum value of iteration.
+ * \param stop The exclusive upper bound of iteration; the extent is `stop - start`.
+ * \param thread The thread for loop variable to bind.
+ * \param annotations The optional annotations of the For statement.
+ * \return The ForFrame.
+ */
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+                       Optional<Map<String, ObjectRef>> annotations = NullOpt);
+/*!
+ * \brief The grid For statement.
+ * \param extents The extents of the iteration, outermost loop first; every loop starts at 0.
+ * \return The ForFrame.
+ */
+ForFrame Grid(Array<PrimExpr> extents);
+
 /*!
  * \brief Evaluate the input expression.
  * \param value The input expression to evaluate.
diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py
index 0e7eb2bb4720..75bb0231aeef 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -15,8 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 """IRBuilder for TIR"""
+from typing import List, Union
 
 from tvm._ffi import register_object as _register_object
+from tvm.tir import Var
 
 from ..base import IRBuilderFrame
 
@@ -34,3 +36,11 @@ class PrimFuncFrame(TIRFrame):
 @_register_object("script.ir_builder.tir.BlockFrame")
 class BlockFrame(TIRFrame):
     ...
+
+
+@_register_object("script.ir_builder.tir.ForFrame")
+class ForFrame(TIRFrame):
+    def __enter__(self) -> Union[Var, List[Var]]:
+        """Enter the frame; return the loop var itself if there is exactly one, else all of them."""
+        super().__enter__()
+        return self.vars[0] if len(self.vars) == 1 else self.vars
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 63fd1291f4bc..a5cdf8a3a105 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -344,6 +344,172 @@ def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
     return _ffi_api.Block(name, no_realize)  # pylint: disable=no-member # type: ignore
 
 
+def serial(
+    start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
+) -> frame.ForFrame:
+    """The serial For statement.
+
+    Parameters
+    ----------
+    start : PrimExpr
+        The minimum value of iteration. If `stop` is omitted, it is the upper bound instead.
+
+    stop : PrimExpr
+        The exclusive upper bound of iteration. If omitted, iteration runs from 0 to `start`.
+
+    annotations : Dict[str, Any]
+        The optional annotations of the For statement.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
+    """
+    if stop is None:
+        stop = start
+        start = 0
+    return _ffi_api.Serial(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def parallel(
+    start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
+) -> frame.ForFrame:
+    """The parallel For statement.
+
+    Parameters
+    ----------
+    start : PrimExpr
+        The minimum value of iteration. If `stop` is omitted, it is the upper bound instead.
+
+    stop : PrimExpr
+        The exclusive upper bound of iteration. If omitted, iteration runs from 0 to `start`.
+
+    annotations : Dict[str, Any]
+        The optional annotations of the For statement.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
+    """
+    if stop is None:
+        stop = start
+        start = 0
+    return _ffi_api.Parallel(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def vectorized(
+    start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
+) -> frame.ForFrame:
+    """The vectorized For statement.
+
+    Parameters
+    ----------
+    start : PrimExpr
+        The minimum value of iteration. If `stop` is omitted, it is the upper bound instead.
+
+    stop : PrimExpr
+        The exclusive upper bound of iteration. If omitted, iteration runs from 0 to `start`.
+
+    annotations : Dict[str, Any]
+        The optional annotations of the For statement.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
+    """
+    if stop is None:
+        stop = start
+        start = 0
+    return _ffi_api.Vectorized(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def unroll(
+    start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
+) -> frame.ForFrame:
+    """The unrolled For statement.
+
+    Parameters
+    ----------
+    start : PrimExpr
+        The minimum value of iteration. If `stop` is omitted, it is the upper bound instead.
+
+    stop : PrimExpr
+        The exclusive upper bound of iteration. If omitted, iteration runs from 0 to `start`.
+
+    annotations : Dict[str, Any]
+        The optional annotations of the For statement.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
+    """
+    if stop is None:
+        stop = start
+        start = 0
+    return _ffi_api.Unroll(start, stop, annotations)  # pylint: disable=no-member # type: ignore
+
+
+def thread_binding(
+    start: PrimExpr,
+    stop: PrimExpr = None,
+    thread: str = None,
+    *,
+    annotations: Dict[str, Any] = None,
+) -> frame.ForFrame:
+    """The thread-binding For statement.
+
+    Parameters
+    ----------
+    start : PrimExpr
+        The minimum value of iteration. If `stop` is omitted, it is the upper bound instead.
+
+    stop : PrimExpr
+        The exclusive upper bound of iteration, or the thread tag when `thread` is omitted.
+
+    thread : str
+        The thread for loop variable to bind.
+
+    annotations : Dict[str, Any]
+        The optional annotations of the For statement.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
+    """
+    if thread is None:
+        if not isinstance(stop, str):
+            raise ValueError("Thread cannot be None for thread_binding")
+        thread = stop
+        stop = start
+        start = 0
+    elif stop is None:
+        stop = start
+        start = 0
+    return _ffi_api.ThreadBinding(  # pylint: disable=no-member # type: ignore
+        start, stop, thread, annotations
+    )
+
+
+def grid(*extents: PrimExpr) -> frame.ForFrame:
+    """The grid For statement.
+
+    Parameters
+    ----------
+    extents : PrimExpr
+        The extents of the iteration, from the outermost loop to the innermost.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
+    """
+    return _ffi_api.Grid(extents)  # pylint: disable=no-member # type: ignore
+
+
 def evaluate(value: PrimExpr) -> None:
     """Evaluate the input expression.
 
@@ -677,6 +843,12 @@ def var(dtype, name="") -> Var:
     "match_buffer",
     "preflattened_buffer",
     "block",
+    "serial",
+    "parallel",
+    "vectorized",
+    "unroll",
+    "thread_binding",
+    "grid",
     "evaluate",
     "int8",
     "int16",
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index dd3097e388b7..e54bf75eeff2 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -73,9 +73,15 @@ void BlockFrameNode::ExitWithScope() {
   }
 }
 
+void ForFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts)));
+}
+
 TVM_REGISTER_NODE_TYPE(TIRFrameNode);
 TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
 TVM_REGISTER_NODE_TYPE(BlockFrameNode);
+TVM_REGISTER_NODE_TYPE(ForFrameNode);
 
 }  // namespace tir
 }  // namespace ir_builder
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index e2c1218a7e87..22c7face7084 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -173,6 +173,76 @@ BlockFrame Block(String name, bool no_realize) {
   return BlockFrame(n);
 }
 
+// Defines `ForFrame Method(start, stop, annotations)`: a one-loop frame that emits a For of `Kind`.
+#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind)                                                \
+  ForFrame Method(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) { \
+    PrimExpr min = start;                                                                         \
+    PrimExpr extent = arith::Analyzer().Simplify(stop - start);                                   \
+    ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();                                      \
+    int bits = std::max(min.dtype().bits(), extent.dtype().bits());                               \
+    n->vars = {Var("v", DataType::Int(bits))};                                                    \
+    n->doms = {Range::FromMinExtent(min, extent)};                                                \
+    n->f_make_for_loop = [annotations](Array<Var> vars, Array<Range> doms, tvm::tir::Stmt body) { \
+      ICHECK_EQ(vars.size(), 1);                                                                  \
+      ICHECK_EQ(doms.size(), 1);                                                                  \
+      return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt,           \
+                           annotations.value_or(Map<String, ObjectRef>()));                       \
+    };                                                                                            \
+    return ForFrame(n);                                                                           \
+  }
+
+TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tir::ForKind::kSerial);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Parallel, tvm::tir::ForKind::kParallel);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Vectorized, tvm::tir::ForKind::kVectorized);
+TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tir::ForKind::kUnrolled);
+
+#undef TVM_TIR_IR_BUILDER_FOR_FRAME
+
+ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
+                       Optional<Map<String, ObjectRef>> annotations) {
+  using namespace tvm::tir;
+  PrimExpr min = start;
+  PrimExpr extent = arith::Analyzer().Simplify(stop - start);
+  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  int bits = std::max(min.dtype().bits(), extent.dtype().bits());
+  n->vars = {Var("v", DataType::Int(bits))};
+  n->doms = {Range::FromMinExtent(min, extent)};
+  n->f_make_for_loop = [annotations, thread](Array<Var> vars, Array<Range> doms, Stmt body) -> For {
+    ICHECK_EQ(vars.size(), 1);
+    ICHECK_EQ(doms.size(), 1);
+    IterVar iter_var(Range(nullptr), Var("iter", DataType::Int(32)), IterVarType::kThreadIndex,
+                     thread);
+    return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var,
+               annotations.value_or(Map<String, ObjectRef>()));
+  };
+  return ForFrame(n);
+}
+
+ForFrame Grid(Array<PrimExpr> extents) {
+  using namespace tvm::tir;
+  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  n->vars.reserve(extents.size());
+  n->doms.reserve(extents.size());
+  for (const auto& extent : extents) {
+    DataType dtype = extent.dtype();
+    n->vars.push_back(Var("v", dtype));
+    n->doms.push_back(Range(make_const(dtype, 0), extent));
+  }
+  n->f_make_for_loop = [](Array<Var> vars, Array<Range> doms, Stmt body) -> Stmt {
+    ICHECK_EQ(vars.size(), doms.size());
+    // Wrap the body from the innermost loop outwards so that vars[0] ends up outermost.
+    int num_loops = static_cast<int>(vars.size());
+    for (int i = num_loops - 1; i >= 0; --i) {
+      Range dom = doms[i];
+      Var var = vars[i];
+      body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body),
+                 /*thread_binding=*/NullOpt, /*annotations=*/{});
+    }
+    return body;
+  };
+  return ForFrame(n);
+}
+
 void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
 
 using tvm::script::ir_builder::details::Namer;
@@ -235,6 +305,14 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuf
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer);
 
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block);
+
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Unroll").set_body_typed(Unroll);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid);
+
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
 
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8);
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index 5c93e99909d9..9cbfd75e2280 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -114,5 +114,60 @@ def test_ir_builder_tir_block():
     assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
 
 
+def test_ir_builder_tir_for():
+    with IRBuilder() as ib:
+        with T.serial(128) as a:
+            with T.parallel(64) as b:
+                with T.vectorized(32) as c:
+                    with T.unroll(16) as d:
+                        with T.thread_binding(8, thread="threadIdx.x") as e:
+                            T.evaluate(0)
+
+    # the loop nest generated by IRBuilder (one loop of each kind, serial outermost)
+    for_actual = ib.get()
+
+    # the expected loop nest, constructed from the innermost loop outwards
+    thread_binding_expected = tir.For(
+        loop_var=tir.Var("", "int32"),
+        min_val=0,
+        extent=8,
+        kind=tir.ForKind.THREAD_BINDING,
+        body=tir.Evaluate(0),
+        thread_binding=tir.IterVar(
+            None, tir.Var("", "int32"), tir.IterVar.ThreadIndex, "threadIdx.x"
+        ),
+    )
+    unroll_expected = tir.For(
+        loop_var=tir.Var("", "int32"),
+        min_val=0,
+        extent=16,
+        kind=tir.ForKind.UNROLLED,
+        body=thread_binding_expected,
+    )
+    vectorized_expected = tir.For(
+        loop_var=tir.Var("", "int32"),
+        min_val=0,
+        extent=32,
+        kind=tir.ForKind.VECTORIZED,
+        body=unroll_expected,
+    )
+    parallel_expected = tir.For(
+        loop_var=tir.Var("", "int32"),
+        min_val=0,
+        extent=64,
+        kind=tir.ForKind.PARALLEL,
+        body=vectorized_expected,
+    )
+    for_expected = tir.For(
+        loop_var=tir.Var("", "int32"),
+        min_val=0,
+        extent=128,
+        kind=tir.ForKind.SERIAL,
+        body=parallel_expected,
+    )
+    # Check that the generated IR is structurally equal to the expected one
+    assert_structural_equal(for_actual, for_expected, map_free_vars=True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()