From c6e366d0be720e1b44290e877ed90594307a576c Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Wed, 14 Sep 2022 17:35:50 -0700 Subject: [PATCH 1/2] [TVMScript] IRBuilder methods for `For` This PR introduces remaining IRBuilder methods for `For`. Co-authored-by: yongwww --- include/tvm/script/ir_builder/tir/frame.h | 45 +++++ include/tvm/script/ir_builder/tir/ir.h | 53 ++++++ python/tvm/script/ir_builder/tir/frame.py | 9 + python/tvm/script/ir_builder/tir/ir.py | 172 ++++++++++++++++++ src/script/ir_builder/tir/frame.cc | 6 + src/script/ir_builder/tir/ir.cc | 76 ++++++++ .../unittest/test_tvmscript_ir_builder_tir.py | 55 ++++++ 7 files changed, 416 insertions(+) diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 15ab77863e5e..c5f724daf1be 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -187,6 +187,51 @@ class BlockFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode); }; +/*! + * \brief A frame that represents the for loop. + * + * \sa BlockInitFrame + */ +class ForFrameNode : public TIRFrameNode { + public: + /*! \brief The for loop generating function type. */ + using FMakeForLoop = + runtime::TypedPackedFunc, Array, tvm::tir::Stmt)>; + /*! \brief The loop variable. */ + Array vars; + /*! \brief The domains of iteration. */ + Array doms; + /*! \brief The for loop generating function. */ + FMakeForLoop f_make_for_loop; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("vars", &vars); + v->Visit("doms", &doms); + // `f_make_for_loop` is not visited. + } + + static constexpr const char* _type_key = "script.ir_builder.tir.ForFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ForFrameNode. + * + * \sa ForFrameNode + */ +class ForFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode); +}; + /*! * \brief A frame that represents the assert statement. Proceeds if the condition is true, * otherwise aborts with the message. diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index aaa5442eede3..68948196ff6b 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -141,6 +141,59 @@ void PreflattenedBuffer(Buffer postflattened_buffer, Array shape, */ BlockFrame Block(String name, bool no_realize = false); +/*! + * \brief The serial For statement. + * \param start The minimum value of iteration. + * \param stop The maximum value of iteration. + * \param annotations The optional annotations of the For statement. + * \return The ForFrame. + */ +ForFrame Serial(PrimExpr start, PrimExpr stop, + Optional> annotations = NullOpt); +/*! + * \brief The parallel For statement. + * \param start The minimum value of iteration. + * \param stop The maximum value of iteration. + * \param annotations The optional annotations of the For statement. + * \return The ForFrame. + */ +ForFrame Parallel(PrimExpr start, PrimExpr stop, + Optional> annotations = NullOpt); +/*! + * \brief The vectorized For statement. + * \param start The minimum value of iteration. + * \param stop The maximum value of iteration. + * \param annotations The optional annotations of the For statement. + * \return The ForFrame. + */ +ForFrame Vectorized(PrimExpr start, PrimExpr stop, + Optional> annotations = NullOpt); +/*! + * \brief The unrolled For statement. + * \param start The minimum value of iteration. + * \param stop The maximum value of iteration. + * \param annotations The optional annotations of the For statement. + * \return The ForFrame. + */ +ForFrame Unroll(PrimExpr start, PrimExpr stop, + Optional> annotations = NullOpt); +/*! + * \brief The thread-binding For statement. + * \param start The minimum value of iteration. + * \param stop The maximum value of iteration. + * \param thread The thread for loop variable to bind. + * \param annotations The optional annotations of the For statement. + * \return The ForFrame. + */ +ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, + Optional> annotations = NullOpt); +/*! + * \brief The grid For statement. + * \param extents The extents of the iteration. + * \return The ForFrame. + */ +ForFrame Grid(Array extents); + /*! * \brief Evaluate the input expression. * \param value The input expression to evaluate. diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index 0e7eb2bb4720..86dea1cb7746 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -15,8 +15,10 @@ # specific language governing permissions and limitations # under the License. """IRBuilder for TIR""" +from typing import List from tvm._ffi import register_object as _register_object +from tvm.tir import Var from ..base import IRBuilderFrame @@ -34,3 +36,10 @@ class PrimFuncFrame(TIRFrame): @_register_object("script.ir_builder.tir.BlockFrame") class BlockFrame(TIRFrame): ... + + +@_register_object("script.ir_builder.tir.ForFrame") +class ForFrame(TIRFrame): + def __enter__(self) -> List[Var]: + super().__enter__() + return self.vars if len(self.vars) > 1 else self.vars[0] diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 63fd1291f4bc..a5cdf8a3a105 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -344,6 +344,172 @@ def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame: return _ffi_api.Block(name, no_realize) # pylint: disable=no-member # type: ignore +def serial( + start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None +) -> frame.ForFrame: + """The serial For statement. + + Parameters + ---------- + start : PrimExpr + The minimum value of iteration. + + stop : PrimExpr + The maximum value of iteration. + + annotations : Dict[str, Any] + The optional annotations of the For statement. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + if stop is None: + stop = start + start = 0 + return _ffi_api.Serial(start, stop, annotations) # pylint: disable=no-member # type: ignore + + +def parallel( + start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None +) -> frame.ForFrame: + """The parallel For statement. + + Parameters + ---------- + start : PrimExpr + The minimum value of iteration. + + stop : PrimExpr + The maximum value of iteration. + + annotations : Dict[str, Any] + The optional annotations of the For statement. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + if stop is None: + stop = start + start = 0 + return _ffi_api.Parallel(start, stop, annotations) # pylint: disable=no-member # type: ignore + + +def vectorized( + start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None +) -> frame.ForFrame: + """The vectorized For statement. + + Parameters + ---------- + start : PrimExpr + The minimum value of iteration. + + stop : PrimExpr + The maximum value of iteration. + + annotations : Dict[str, Any] + The optional annotations of the For statement. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + if stop is None: + stop = start + start = 0 + return _ffi_api.Vectorized(start, stop, annotations) # pylint: disable=no-member # type: ignore + + +def unroll( + start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None +) -> frame.ForFrame: + """The unrolled For statement. + + Parameters + ---------- + start : PrimExpr + The minimum value of iteration. + + stop : PrimExpr + The maximum value of iteration. + + annotations : Dict[str, Any] + The optional annotations of the For statement. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + if stop is None: + stop = start + start = 0 + return _ffi_api.Unroll(start, stop, annotations) # pylint: disable=no-member # type: ignore + + +def thread_binding( + start: PrimExpr, + stop: PrimExpr = None, + thread: str = None, + *, + annotations: Dict[str, Any] = None, +) -> frame.ForFrame: + """The thread-binding For statement. + + Parameters + ---------- + start : PrimExpr + The minimum value of iteration. + + stop : PrimExpr + The maximum value of iteration. + + thread : str + The thread for loop variable to bind. + + annotations : Dict[str, Any] + The optional annotations of the For statement. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + if thread is None: + if not isinstance(stop, str): + raise ValueError("Thread cannot be None for thread_binding") + thread = stop + stop = start + start = 0 + elif stop is None: + stop = start + start = 0 + return _ffi_api.ThreadBinding( # pylint: disable=no-member # type: ignore + start, stop, thread, annotations + ) + + +def grid(*extents: PrimExpr) -> frame.ForFrame: + """The grid For statement. + + Parameters + ---------- + extents : PrimExpr + The extents of the iteration. + + Returns + ------- + res : frame.ForFrame + The ForFrame. + """ + return _ffi_api.Grid(extents) # pylint: disable=no-member # type: ignore + + def evaluate(value: PrimExpr) -> None: """Evaluate the input expression. @@ -677,6 +843,12 @@ def var(dtype, name="") -> Var: "match_buffer", "preflattened_buffer", "block", + "serial", + "parallel", + "vectorized", + "unroll", + "thread_binding", + "grid", "evaluate", "int8", "int16", diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index dd3097e388b7..e54bf75eeff2 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -73,9 +73,15 @@ void BlockFrameNode::ExitWithScope() { } } +void ForFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts))); +} + TVM_REGISTER_NODE_TYPE(TIRFrameNode); TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); TVM_REGISTER_NODE_TYPE(BlockFrameNode); +TVM_REGISTER_NODE_TYPE(ForFrameNode); } // namespace tir } // namespace ir_builder diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index e2c1218a7e87..22c7face7084 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -173,6 +173,74 @@ BlockFrame Block(String name, bool no_realize) { return BlockFrame(n); } +#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \ + ForFrame Method(PrimExpr start, PrimExpr stop, Optional> annotations) { \ + PrimExpr min = start; \ + PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ + ObjectPtr n = make_object(); \ + int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ + n->vars = {Var("v", DataType::Int(bits))}; \ + n->doms = {Range::FromMinExtent(min, extent)}; \ + n->f_make_for_loop = [annotations](Array vars, Array doms, tvm::tir::Stmt body) { \ + ICHECK_EQ(vars.size(), 1); \ + ICHECK_EQ(doms.size(), 1); \ + return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, \ + annotations.value_or(Map())); \ + }; \ + return ForFrame(n); \ + } + +TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tir::ForKind::kSerial); +TVM_TIR_IR_BUILDER_FOR_FRAME(Parallel, tvm::tir::ForKind::kParallel); +TVM_TIR_IR_BUILDER_FOR_FRAME(Vectorized, tvm::tir::ForKind::kVectorized); +TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tir::ForKind::kUnrolled); + +#undef TVM_TIR_IR_BUILDER_FOR_FRAME + +ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, + Optional> annotations) { + using namespace tvm::tir; + PrimExpr min = start; + PrimExpr extent = arith::Analyzer().Simplify(stop - start); + ObjectPtr n = make_object(); + int bits = std::max(min.dtype().bits(), extent.dtype().bits()); + n->vars = {Var("v", DataType::Int(bits))}; + n->doms = {Range::FromMinExtent(min, extent)}; + n->f_make_for_loop = [annotations, thread](Array vars, Array doms, Stmt body) -> For { + ICHECK_EQ(vars.size(), 1); + ICHECK_EQ(doms.size(), 1); + IterVar iter_var(Range(nullptr), Var("iter", DataType::Int(32)), IterVarType::kThreadIndex, + thread); + return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var, + annotations.value_or(Map())); + }; + return ForFrame(n); +} + +ForFrame Grid(Array extents) { + using namespace tvm::tir; + ObjectPtr n = make_object(); + n->vars.reserve(extents.size()); + n->doms.reserve(extents.size()); + for (const auto& extent : extents) { + DataType dtype = extent.dtype(); + n->vars.push_back(Var("v", extent.dtype())); + n->doms.push_back(Range(make_const(dtype, 0), extent)); + } + n->f_make_for_loop = [](Array vars, Array doms, Stmt body) -> Stmt { + ICHECK_EQ(vars.size(), doms.size()); + int n = vars.size(); + for (int i = n - 1; i >= 0; --i) { + Range dom = doms[i]; + Var var = vars[i]; + body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body), + /*thread_binding=*/NullOpt, /*annotations=*/{}); + } + return body; + }; + return ForFrame(n); +} + void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } using tvm::script::ir_builder::details::Namer; @@ -235,6 +303,14 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuf TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block); + +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Unroll").set_body_typed(Unroll); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid); + TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8); diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 5c93e99909d9..9cbfd75e2280 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -114,5 +114,60 @@ def test_ir_builder_tir_block(): assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True) +def test_ir_builder_tir_for(): + with IRBuilder() as ib: + with T.serial(128) as a: + with T.parallel(64) as b: + with T.vectorized(32) as c: + with T.unroll(16) as d: + with T.thread_binding(8, thread="threadIdx.x") as e: + T.evaluate(0) + + # the for generated by IRBuilder + for_actual = ib.get() + + # the expected for + thread_binding_expected = tir.For( + loop_var=tir.Var("", "int32"), + min_val=0, + extent=8, + kind=tir.ForKind.THREAD_BINDING, + body=tir.Evaluate(0), + thread_binding=tir.IterVar( + None, tir.Var("", "int32"), tir.IterVar.ThreadIndex, "threadIdx.x" + ), + ) + unroll_expected = tir.For( + loop_var=tir.Var("", "int32"), + min_val=0, + extent=16, + kind=tir.ForKind.UNROLLED, + body=thread_binding_expected, + ) + vectorized_expected = tir.For( + loop_var=tir.Var("", "int32"), + min_val=0, + extent=32, + kind=tir.ForKind.VECTORIZED, + body=unroll_expected, + ) + parallel_expected = tir.For( + loop_var=tir.Var("", "int32"), + min_val=0, + extent=64, + kind=tir.ForKind.PARALLEL, + body=vectorized_expected, + ) + for_expected = tir.For( + loop_var=tir.Var("", "int32"), + min_val=0, + extent=128, + kind=tir.ForKind.SERIAL, + body=parallel_expected, + ) + # Check if the generated ir is expected + assert_structural_equal(for_actual, for_expected, map_free_vars=True) + + if __name__ == "__main__": tvm.testing.main() From e69870593227684fdb582e41e94deafaaa61f06b Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Thu, 15 Sep 2022 11:42:42 -0700 Subject: [PATCH 2/2] apply code review suggestions --- include/tvm/script/ir_builder/tir/frame.h | 14 ++++++++++---- python/tvm/script/ir_builder/tir/frame.py | 4 ++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index c5f724daf1be..2902b982d5a6 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -190,13 +190,19 @@ class BlockFrame : public TIRFrame { /*! * \brief A frame that represents the for loop. * - * \sa BlockInitFrame + * \sa ForFrame */ class ForFrameNode : public TIRFrameNode { public: - /*! \brief The for loop generating function type. */ - using FMakeForLoop = - runtime::TypedPackedFunc, Array, tvm::tir::Stmt)>; + /*! + * \brief Functions that generate loop nests. + * \param loop_vars The loop variables, from outer to inner + * \param loop_extents The loop extents that correspond to loop variables + * \param loop_body The loop body + * \return A stmt, the loop nest + */ + using FMakeForLoop = runtime::TypedPackedFunc loop_vars, Array loop_extents, tvm::tir::Stmt loop_body)>; /*! \brief The loop variable. */ Array vars; /*! \brief The domains of iteration. */ diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index 86dea1cb7746..75bb0231aeef 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """IRBuilder for TIR""" -from typing import List +from typing import List, Union from tvm._ffi import register_object as _register_object from tvm.tir import Var @@ -40,6 +40,6 @@ class BlockFrame(TIRFrame): @_register_object("script.ir_builder.tir.ForFrame") class ForFrame(TIRFrame): - def __enter__(self) -> List[Var]: + def __enter__(self) -> Union[Var, List[Var]]: super().__enter__() return self.vars if len(self.vars) > 1 else self.vars[0]