diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index e5d2c440e57b..1ac3f80ecf39 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -110,7 +110,7 @@ class ScheduleNode : public runtime::Object { * guaranteeing that * 1) SRef tree is completely reconstructed; * 2) The IRModule being scheduled is not modified; - * 3) All the random variables are valid in the copy, pointing to the correpsonding sref + * 3) All the random variables are valid in the copy, pointing to the corresponding sref * reconstructed */ virtual Schedule Copy() const = 0; @@ -220,6 +220,43 @@ class ScheduleNode : public runtime::Object { */ virtual Array Split(const LoopRV& loop_rv, const Array>& factors) = 0; /******** Schedule: Manipulate ForKind ********/ + /*! + * \brief Parallelize the input loop. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' + * bindings + * \param loop_rv The loop to be parallelized + */ + virtual void Parallel(const LoopRV& loop_rv) = 0; + /*! + * \brief Vectorize the input loop. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' + * bindings + * \param loop_rv The loop to be vectorized + */ + virtual void Vectorize(const LoopRV& loop_rv) = 0; + /*! + * \brief Bind the input loop to the given thread axis. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, if the thread axis starts with "threadIdx`, the loop can only + * be contained in data-parallel block iter and reduction block iters' bindings. Otherwise the + * loop can only be contained in data-parallel block iters' bindings + * \param loop_rv The loop to be bound to the thread axis + * \param thread_axis The thread axis to be bound to the loop + */ + virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0; + /*! + * \brief Unroll the input loop. It requires nothing + * \param loop_rv The loop to be unrolled + */ + virtual void Unroll(const LoopRV& loop_rv) = 0; /******** Schedule: Insert cache stages ********/ /******** Schedule: Compute location ********/ /*! diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index d1308fe0059e..c0fa62d7caf0 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -437,6 +437,18 @@ TVM_DLL Pass LowerMatchBuffer(); */ TVM_DLL Pass FlattenBuffer(); +/*! + * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and + * "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., + * "threadIdx.x") use different IterVars and variables in their AttrStmts. After the + * unification, we use a consolidated IterVar and a variable for them. + * \return The pass. + * \note `vthread` is a legacy behavior that will be deprecated, though thread bindings of `vthread` + * are still also unified in this pass. 
Please use `vthread.x`, `vthread.y` and `vthread.z` + * instead. + */ +TVM_DLL Pass UnifyThreadBinding(); + /*! * A pass to merge multiple TIR-level dynamic shared memory allocations into one */ diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index e8415d2bd522..46e5fd6fddcb 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -170,7 +170,7 @@ def copy(self) -> "Schedule": * guaranteeing that * 1) SRef tree is completely reconstructed; * 2) The IRModule being scheduled is untouched; - * 3) All the random variables are valid in the copy, pointing to the correpsonding sref + * 3) All the random variables are valid in the copy, pointing to the corresponding sref * reconstructed Returns @@ -226,7 +226,7 @@ def get( Returns ------- result : Optional[Union[int, Block, For]] - The correpsonding result + The corresponding result """ if isinstance(rand_var_or_sref, StmtSRef): return rand_var_or_sref.stmt @@ -236,7 +236,7 @@ def get( return result def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Optional[StmtSRef]: - """Returns the correpsonding sref to the given + """Returns the corresponding sref to the given 1) LoopRV 2) BlockRV 3) Block @@ -250,7 +250,7 @@ def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Opti Returns ------- result : Optional[StmtSRef] - The correpsonding result + The corresponding result """ return _ffi_api.ScheduleGetSRef( # type: ignore # pylint: disable=no-member self, rand_var_or_stmt @@ -413,7 +413,7 @@ def before_split(a: ty.handle, b: ty.handle) -> None: with tir.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - Create the schedule and do fuse: + Create the schedule and do split: .. code-block:: python @@ -444,6 +444,234 @@ def after_split(a: ty.handle, b: ty.handle) -> None: ########## Schedule: Manipulate ForKind ########## + def parallel(self, loop: LoopRV) -> None: + """Parallelize the input loop. It requires: + 1) The scope block that the loop is in should have stage-pipeline property + 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + bindings + 3) For each block under the loop, the loop can only be contained in data-parallel block + iters' bindings + + Parameters + ---------- + loop : LoopRV + The loop to be parallelized + + Examples + -------- + + Before parallel, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_parallel(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do parallel: + + .. code-block:: python + + sch = tir.Schedule(before_parallel) + i, j = sch.get_loops(sch.get_block("B")) + sch.parallel(i) + + After applying parallel, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_parallel(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i in tir.parallel(0, 128): + for j in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleParallel(self, loop) # type: ignore # pylint: disable=no-member + + def vectorize(self, loop: LoopRV) -> None: + """Vectorize the input loop. 
It requires: + 1) The scope block that the loop is in should have stage-pipeline property + 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + bindings + 3) For each block under the loop, the loop can only be contained in data-parallel block + iters' bindings + + Parameters + ---------- + loop : LoopRV + The loop to be vectorized + + Examples + -------- + + Before vectorize, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_vectorize(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do vectorize: + + .. code-block:: python + + sch = tir.Schedule(before_vectorize) + i, j = sch.get_loops(sch.get_block("B")) + sch.vectorize(j) + + After applying vectorize, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_vectorize(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i in tir.serial(0, 128): + for j in tir.vectorized(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleVectorize(self, loop) # type: ignore # pylint: disable=no-member + + def bind(self, loop: LoopRV, thread_axis: str) -> None: + """Bind the input loop to the given thread axis. It requires: + 1) The scope block that the loop is in should have stage-pipeline property + 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + bindings + 3) For each block under the loop, if the thread axis starts with "threadIdx`, the loop can + only be contained in data-parallel block iter and reduction block iters' bindings. Otherwise + the loop can only be contained in data-parallel block iters' bindings + + Parameters + ---------- + loop : LoopRV + The loop to be bound to the thread axis + thread_axis : str + The thread axis to be bound to the loop. Possible candidates: + - blockIdx.x/y/z + - threadIdx.x/y/z + - vthread.x/y/z + - vthread (It is a legacy behavior that will be deprecated. Please use `vthread.x/y/z` + instead.) + + Examples + -------- + + Before bind, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_bind(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do bind: + + .. code-block:: python + + sch = tir.Schedule(before_bind) + i, j = sch.get_loops(sch.get_block("B")) + sch.bind(i, "blockIdx.x") + sch.bind(j, "threadIdx.x") + + After applying bind, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_bind(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i in tir.thread_binding(0, 128, thread = "blockIdx.x"): + for j in tir.thread_binding(0, 128, thread = "threadIdx.x"): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleBind(self, loop, thread_axis) # type: ignore # pylint: disable=no-member + + def unroll(self, loop: LoopRV) -> None: + """Unroll the input loop. 
It requires nothing + + Parameters + ---------- + loop : LoopRV + The loop to be unrolled + + Examples + -------- + + Before unroll, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_unroll(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do unroll: + + .. code-block:: python + + sch = tir.Schedule(before_unroll) + i, j = sch.get_loops(sch.get_block("B")) + sch.unroll(i) + + After applying unroll, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_unroll(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i in tir.unroll(0, 128): + for j in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + _ffi_api.ScheduleUnroll(self, loop) # type: ignore # pylint: disable=no-member + ########## Schedule: Insert cache stages ########## ########## Schedule: Compute location ########## @@ -581,7 +809,7 @@ def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV: RFactor is a schedule primitive that implements the transformation described above: Given a block that writes to buffer `B`, it factorizes a loop of extent `n`. - For example, the pesudocode below accumulates `B[i] = sum(A[i, : , : ])`: + For example, the pseudocode below accumulates `B[i] = sum(A[i, : , : ])`: .. code-block:: python diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 537499a27fa9..74dafa4157d7 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -668,6 +668,28 @@ def FlattenBuffer(): return _ffi_api.FlattenBuffer() # type: ignore +def UnifyThreadBinding(): + """Unify all the thread bindings for "blockIdx.x/y/z", + "threadIdx.x/y/z", and "vthread.x/y/z". Before the unification, + two vars that are bound to a thread axis (e.g., "threadIdx.x") + use different IterVars and variables in their AttrStmts. After + the unification, we use a consolidated IterVar and a variable + for them. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + + Note + ---- + `vthread` is a legacy behavior that will be deprecated, though + thread bindings of `vthread` are still also unified in this + pass. Please use `vthread.x`, `vthread.y` and `vthread.z` instead. + """ + return _ffi_api.UnifyThreadBinding() # type: ignore + + def MergeDynamicSharedMemoryAllocations(): """This pass merges multiple TIR-level dynamic shared memory allocations into one allocation. 
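Not part of the patch: a minimal sketch of where the new pass is intended to sit when assembling a TIR pipeline by hand, mirroring the driver_api.cc change below, which inserts UnifyThreadBinding right after FlattenBuffer (i.e. after `thread_binding` loops have been lowered to AttrStmt thread-extent annotations). The empty `mod` placeholder is an assumption; substitute an IRModule produced by `tir.Schedule`.

.. code-block:: python

    import tvm

    # Assumed pipeline slice: UnifyThreadBinding consumes the AttrStmt-based
    # thread bindings that FlattenBuffer emits, so it is placed right after it,
    # matching the order added to the default pass list in driver_api.cc.
    seq = tvm.transform.Sequential(
        [
            tvm.tir.transform.FlattenBuffer(),
            tvm.tir.transform.UnifyThreadBinding(),
            tvm.tir.transform.Simplify(),
        ]
    )

    mod = tvm.IRModule()  # placeholder; a real module would come from tir.Schedule(...).mod
    mod = seq(mod)        # afterwards, each thread tag shares a single IterVar/Var per kernel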
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d6af9936ca40..ff00e68d91f0 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -222,6 +222,7 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::FlattenBuffer()); + pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index ac8260ffbe39..d577770db1a9 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -163,7 +163,7 @@ struct ThreadScope { */ static ThreadScope Create(const std::string& s) { ThreadScope r; - if (s == "vthread" || s == "cthread") { + if (s.compare(0, 7, "vthread") == 0 || s == "cthread") { // virtual thread at the same level as local r.rank = 1; r.dim_index = -1; diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 370aa01a33c0..3fa0c63b2e2f 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -120,6 +120,20 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref); +/*! + * \brief Check whether a subtree on SRef tree has compact data flow, and throw an exception if the + * subtree does not have compact data flow + * \details For a given StmtSRef, We say the subtree rooted from the StmtSRef has "compact data + * flow" property if: + * - the scope root of the input subtree root has stage-pipeline property, and + * - all its child blocks on SRef tree are complete blocks or reduction blocks. + * \param self The schedule state + * \param subtree_root_sref The root of the subtree to be checked in the SRef tree + * \throw ScheduleError If the subtree does not have compact data flow + * \sa IsCompleteBlock, IsReductionBlock + */ +void CheckSRefSubtreeCompactDataFlow(const ScheduleState& self, const StmtSRef& subtree_root_sref); + /******** Binding ********/ /*! * \brief Verifies if the block binding in a specific BlockRealize is an affine binding. @@ -132,6 +146,15 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, arith::Analyzer* analyzer); +/*! + * \brief Check whether a block has an affine binding using the cached flag, and throw an exception + * if the block does not have an affine binding. + * \param self The schedule state + * \param block The block to be checked + * \throw ScheduleError If the input block does not have an affine binding + */ +void CheckAffineBinding(const ScheduleState& self, Block block); + /*! 
* \brief Extracts the ranges of loop variables in a path of the sref tree * \param low_inclusive The lowest node in the path diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 8d1913fdee86..c9f8ff4c7e75 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -315,6 +315,43 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, } } +void CheckSRefSubtreeCompactDataFlow(const ScheduleState& self, const StmtSRef& subtree_root_sref) { + class NotCompactDataFlowError : public ScheduleError { + public: + explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block) + : mod_(std::move(mod)), + subtree_root_(std::move(subtree_root)), + violate_block_(std::move(violate_block)) { + ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); + } + String FastErrorString() const final { + return "ScheduleError: The queried subtree root in SRef tree does not have compact data " + "flow, because some of its child block on SRef tree is neither a complete block nor a " + "reduction block"; + } + String DetailRenderTemplate() const final { + return "The queried subtree root {0} in SRef tree does not have compact data flow, because " + "its child block {1} on SRef tree is neither a complete block nor a reduction block"; + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {subtree_root_, violate_block_}; } + + IRModule mod_; + Stmt subtree_root_; + Block violate_block_; + }; + + StmtSRef scope_root = GetScopeRoot(self, subtree_root_sref, /*require_stage_pipeline=*/true); + Array child_blocks = GetChildBlockSRefOnSRefTree(self, scope_root); + for (const StmtSRef& block : child_blocks) { + if (!IsCompleteBlock(self, block, scope_root) && !IsReductionBlock(self, block, scope_root)) { + const BlockNode* violate_block = TVM_SREF_TO_BLOCK(violate_block, block); + throw NotCompactDataFlowError(self->mod, GetRef(subtree_root_sref->stmt), + GetRef(violate_block)); + } + } +} + /******** Binding ********/ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, @@ -340,6 +377,28 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va return true; } +void CheckAffineBinding(const ScheduleState& self, Block block) { + class NotAffineBindingError : public ScheduleError { + public: + explicit NotAffineBindingError(IRModule mod, Block block) + : mod_(std::move(mod)), block_(std::move(block)) {} + String FastErrorString() const final { + return "ScheduleError: The block is required to have an affine binding"; + } + String DetailRenderTemplate() const final { + return "The block {0} is required to have an affine binding"; + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; + }; + + if (!self->IsAffineBlockBinding(self->stmt2ref.at(block.get()))) { + throw NotAffineBindingError(self->mod, std::move(block)); + } +} + Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, const Optional& high_exclusive, const runtime::StorageScope& extra_relax_scope) { diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 688ea8059c0e..b18090dd7215 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -207,7 +207,8 @@ Schedule ConcreteScheduleNode::Copy() const { } \ } -/******** Block/Loop relation ********/ +/******** Schedule: 
Schedule: Sampling ********/ +/******** Schedule: Get blocks & loops ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { class NotSingleResult : public ScheduleError { @@ -257,7 +258,7 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { return CreateRV(tir::GetLoops(this->GetSRef(block_rv))); } -/******** Schedule: loops manipulation ********/ +/******** Schedule: Transform loops ********/ LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)"; @@ -345,7 +346,44 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, return CreateRV(results); } -/******** Schedule: compute location ********/ +/******** Schedule: Manipulate ForKind ********/ + +void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Parallel(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("parallel", this->error_render_level_); +} + +void ConcreteScheduleNode::Vectorize(const LoopRV& loop_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Vectorize(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("vectorize", this->error_render_level_); +} + +void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) { + if (thread_axis == "vthread") { + LOG(WARNING) << "`vthread` is legacy behavior and is going to be deprecated. Please use " + "`vthread.x`, `vthread.y` and `vthread.z` instead"; + } + TVM_TIR_SCHEDULE_BEGIN(); + tir::Bind(state_, this->GetSRef(loop_rv), + IterVar(/*dom=*/Range(nullptr), /*var=*/Var(thread_axis), /*iter_type=*/kThreadIndex, + /*thread_tag=*/thread_axis)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("bind", this->error_render_level_); +} + +void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Unroll(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("unroll", this->error_render_level_); +} + +/******** Schedule: Insert cache stages ********/ +/******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); @@ -361,8 +399,7 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { this->state_->DebugVerify(); } -/******** Schedule: loop binding/annotation ********/ -/******** Schedule: block annotation ********/ +/******** Schedule: Block Annotation ********/ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) { @@ -372,8 +409,7 @@ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_inde this->state_->DebugVerify(); } -/******** Schedule: cache read/write ********/ -/******** Schedule: reduction ********/ +/******** Schedule: Reduction ********/ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { StmtSRef result{nullptr}; @@ -384,7 +420,9 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { return CreateRV(result); } -/******** Schedule: blockize & tensorize ********/ +/******** Schedule: Blockize & Tensorize ********/ +/******** Schedule: Annotation ********/ +/******** Schedule: Misc ********/ } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index cfdd9c8452f7..2af4675ddcca 100644 --- a/src/tir/schedule/concrete_schedule.h +++ 
b/src/tir/schedule/concrete_schedule.h @@ -49,7 +49,7 @@ class ConcreteScheduleNode : public ScheduleNode { // `state_` is not visited // `error_render_level_` is not visited // `symbol_table_` is not visited - // `analyzer_` is not visitied + // `analyzer_` is not visited } virtual ~ConcreteScheduleNode() = default; @@ -82,6 +82,10 @@ class ConcreteScheduleNode : public ScheduleNode { LoopRV Fuse(const Array& loop_rvs) override; Array Split(const LoopRV& loop_rv, const Array>& factors) override; /******** Schedule: Manipulate ForKind ********/ + void Parallel(const LoopRV& loop_rv) override; + void Vectorize(const LoopRV& loop_rv) override; + void Bind(const LoopRV& loop_rv, const String& thread_axis) override; + void Unroll(const LoopRV& loop_rv) override; /******** Schedule: Insert cache stages ********/ /******** Schedule: Compute location ********/ void ComputeInline(const BlockRV& block) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 4b9c76947bb1..04c38f67da7d 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -42,7 +42,6 @@ Array GetBlocks(const ScheduleState& self, const String& name, const S */ Array GetLoops(const StmtSRef& block_sref); /******** Schedule: Transform loops ********/ - /*! * Split a loop into a list of consecutive loops. It requires: * 1) The loop can't have annotation or thread binding. @@ -65,6 +64,47 @@ TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, */ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs); /******** Schedule: Manipulate ForKind ********/ +/*! + * \brief Parallelize the input loop. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' + * bindings + * \param self The state of the schedule + * \param loop_sref The sref of the loop to be parallelized + */ +TVM_DLL void Parallel(ScheduleState self, const StmtSRef& loop_sref); +/*! + * \brief Vectorize the input loop. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' + * bindings + * \param self The state of the schedule + * \param loop_sref The sref of the loop to be vectorized + */ +TVM_DLL void Vectorize(ScheduleState self, const StmtSRef& loop_sref); +/*! + * \brief Bind the input loop to the given thread axis. It requires: + * 1) The scope block that the loop is in should have stage-pipeline property + * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine + * bindings + * 3) For each block under the loop, if the thread axis starts with "threadIdx`, the loop can only + * be contained in data-parallel block iter and reduction block iters' bindings. Otherwise the + * loop can only be contained in data-parallel block iters' bindings + * \param self The state of the schedule + * \param loop_sref The sref of the loop to be bound to the thread axis + * \param thread_axis The thread axis to be bound to the loop + */ +TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis); +/*! + * \brief Unroll the input loop. 
It requires nothing + * \param self The state of the schedule + * \param loop_sref The loop to be unrolled + */ +TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref); /******** Schedule: Insert cache stages ********/ /******** Schedule: Compute location ********/ /*! @@ -96,6 +136,7 @@ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref /*! * \brief Factor a reduction block by the specified loop * \details See python/tvm/tir/schedule/schedule.py + * \param self The state of the schedule * \param loop_sref The loop outside block for which we want to do rfactor * \param factor_axis The position where the new dimension is placed in the new introduced rfactor * buffer. Suppose the original reduction block writes to buffer `B` with diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc new file mode 100644 index 000000000000..a6056d607042 --- /dev/null +++ b/src/tir/schedule/primitive/for_kind.cc @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +class WrongBlockIterTypeError : public ScheduleError { + public: + explicit WrongBlockIterTypeError(IRModule mod, ForKind for_kind, Var loop_var, Block block) + : mod_(std::move(mod)), loop_var_(std::move(loop_var)), block_(std::move(block)) { + op_str_ = for_kind == ForKind::kParallel + ? "parallel" + : for_kind == ForKind::kVectorized ? "vectorize" : "bind"; + } + String FastErrorString() const final { + std::ostringstream os; + os << "ScheduleError: The \"" << op_str_ + << "\" cannot be fulfilled with regard to some of its underlying block"; + return os.str(); + } + String DetailRenderTemplate() const final { + std::ostringstream os; + if (op_str_ != "bind") { + os << "The \"" << op_str_ + << "\" cannot be fulfilled with regard to block {0} because some block iter whose block " + "binding contains the loop var is not a data parallel block iter"; + } else { + os << "The \"bind\" cannot be fulfilled with regard to block {0}. This is because some of its" + " block iter whose block binding contains " + << loop_var_ + << " does not meet any of the conditions:\n1) the block iter is data parallel;\n2) the " + "block iter is a reduction block iter, and the thread axis to be bound is " + "\"threadIdx.x/y/z\""; + } + return os.str(); + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + std::string op_str_; + Var loop_var_; + Block block_; +}; + +/*! 
+ * \brief Check if a loop can be parallelized/vectorized/bound with regard to a specific block + * \details There are two conditions: + * 1) The block is required to have affine bindings, and + * 2) For each block iter whose binding contains the input loop variable, either + * - the block iter is data parallel, or + * - the block iter is a reduction block iter, and the input `thread_tag` starts with "threadIdx" + * in case of cross-thread reduction. + * \param self The schedule state + * \param for_kind The desired ForKind (only `kParallel`, `kVectorized` and `kThreadBinding` are + * allowed) + * \param loop_var The loop variable of the loop to be checked + * \param block_realize The block-realize of the block to be checked + * \param thread_scope The thread scope of the thread axis to be bound, which is an invalid value if + * the operation is not "bind" + * \throws ScheduleError If the input loop cannot be parallelized/vectorized/bound with regard to + * the input block + */ +void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind, + const Var& loop_var, const BlockRealize& block_realize, + runtime::ThreadScope thread_scope) { + const Block& block = block_realize->block; + + // Cond 1. The block is required to have affine bindings. + CheckAffineBinding(self, block); + + // Cond 2. For each block iter whose binding contains `loop_var`, only two cases are allowed. + ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); + int n_iters = static_cast(block->iter_vars.size()); + for (int i = 0; i < n_iters; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& binding = block_realize->iter_values[i]; + + if (!UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) { + continue; + } + // Only two cases are allowed: + // - The block iter is data parallel, or + // - The block iter is a reduction block iter, and the `thread_scope` is "threadIdx.x/y/z" + // in case of cross-thread reduction. + IterVarType iter_type = iter_var->iter_type; + if (!(iter_type == kDataPar || + (iter_type == kCommReduce && thread_scope.rank == 1 && thread_scope.dim_index != -1))) { + throw WrongBlockIterTypeError(self->mod, for_kind, loop_var, block); + } + } +} + +/*! + * \brief For each block (recursive) under the given loop, check whether the input loop can be + * parallelized/vectorized/bound with regard to the block + * \param self The schedule state + * \param loop The loop to be parallelized/vectorized/bound + * \param for_kind The desired ForKind (only `kParallel`, `kVectorized` and `kThreadBinding` are + * allowed) + * \param thread_scope The thread scope of the thread axis to be bound, which is an invalid value if + * the operation is not "bind" + */ +void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind for_kind, + runtime::ThreadScope thread_scope) { + PreOrderVisit(loop, [&](const ObjectRef& node) { + if (const auto* realize = node.as()) { + CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, GetRef(realize), + thread_scope); + } + return true; + }); +} + +/*! 
+ * \brief The implementation of parallelizing/vectorizing/binding a given loop + * \param self The schedule state + * \param loop_sref The sref of the loop to be parallelized/vectorized/bound + * \param for_kind The type of the operation (only `kParallel`, `kVectorized` and `kThreadBinding` + * are allowed) + * \param thread_axis The thread axis that the input loop is bound to, which is defined only when + * `for_kind` is `kThreadBinding` + */ +void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref, ForKind for_kind, + Optional thread_axis) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + + /* + * Check: + * - 1. the subtree rooted from the input loop in sref tree has compact data flow + * - 2. all the blocks under the given loop have affine block bindings + * - 3. the input loop can be only bound to data parallel block iters, or the loop can be bound to + * reduction block iter if `thread` is `threadIdx.x/y/z` in case of cross-thread reduction + * When the above conditions are all satisfied, this input loop can be + * parallelized/vectorized/bound. + */ + // Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow. + CheckSRefSubtreeCompactDataFlow(self, loop_sref); + + // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each + // underlying block. + CheckParallelizability(self, GetRef(loop), for_kind, + thread_axis.defined() + ? runtime::ThreadScope::Create(thread_axis.value()->thread_tag) + : runtime::ThreadScope{-1, -1}); + + // Step 3. Loop update and IR replacement + ObjectPtr new_loop = make_object(*loop); + new_loop->kind = for_kind; + new_loop->thread_binding = std::move(thread_axis); + self->Replace(loop_sref, For(new_loop), {}); +} + +void Parallel(ScheduleState self, const StmtSRef& loop_sref) { + ParallelizeComputation(self, loop_sref, ForKind::kParallel, NullOpt); +} + +void Vectorize(ScheduleState self, const StmtSRef& loop_sref) { + ParallelizeComputation(self, loop_sref, ForKind::kVectorized, NullOpt); +} + +void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis) { + ParallelizeComputation(self, loop_sref, ForKind::kThreadBinding, thread_axis); +} + +void Unroll(ScheduleState self, const StmtSRef& loop_sref) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + ObjectPtr new_loop = make_object(*loop); + new_loop->kind = ForKind::kUnrolled; + new_loop->thread_binding = NullOpt; + self->Replace(loop_sref, For(new_loop), {}); +} + +/******** Instruction Registration ********/ + +struct ParallelTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Parallel"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { + return sch->Parallel(loop_rv); + } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("parallel"); + py.Input("loop", loop_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct VectorizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Vectorize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { 
+ return sch->Vectorize(loop_rv); + } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("vectorize"); + py.Input("loop", loop_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct BindTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Bind"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, String thread) { + return sch->Bind(loop_rv, thread); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, String thread) { + PythonAPICall py("bind"); + py.Input("loop", loop_rv); + py.Input("thread", thread); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct UnrollTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Unroll"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { return sch->Unroll(loop_rv); } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("unroll"); + py.Input("loop", loop_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ParallelTraits); +TVM_REGISTER_INST_KIND_TRAITS(VectorizeTraits); +TVM_REGISTER_INST_KIND_TRAITS(BindTraits); +TVM_REGISTER_INST_KIND_TRAITS(UnrollTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index bf29ceb1ef9f..af77e51e4d83 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -938,7 +938,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax Block new_scope_root_block = BlockReplacer::Replace( old_scope_root_block, rf_body, loops[0], wb_block_creator.new_block_realize_, block_realize, GetRef(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffer); - self->Replace(scope_root, new_scope_root_block, {{old_scope_root_block, new_scope_root_block}}); + self->Replace( + scope_root, new_scope_root_block, + {{old_scope_root_block, new_scope_root_block}, {block, wb_block_creator.new_block_}}); // Step 2. Update scope information. 
std::vector new_block_srefs{self->stmt2ref.at(rf_block_creator.new_block_.get()), diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index d6dc0b446e16..f21a4c370a5b 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -126,6 +126,12 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); /******** (FFI) Manipulate ForKind ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel") + .set_body_method(&ScheduleNode::Parallel); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize") + .set_body_method(&ScheduleNode::Vectorize); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method(&ScheduleNode::Bind); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method(&ScheduleNode::Unroll); /******** (FFI) Insert cache stages ********/ /******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 6dd09680e987..9a9b97497e04 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -112,7 +112,7 @@ bool ProducerCoversConsumer(const Array& buffer_shape, * \param self The schedule class * \param stmt The statement, or the realize node of the statement whose sref to be set * \param seq_index The seq_index to be set - * \note The method is NOP for statements that are not scheduleable, i.e. not For or Block + * \note The method is NOP for statements that are not schedulable, i.e. not For or Block */ void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) { if (const auto* realize = stmt.as()) { @@ -405,7 +405,7 @@ class StateCreator : private StmtVisitor { std::unordered_map block2realize_; /*! \brief The stack frames of blocks in the DFS visit. */ std::vector> block_frames_; - /*! \brief The auxilary analyzer */ + /*! 
\brief The auxiliary analyzer */ arith::Analyzer analyzer_; }; @@ -565,7 +565,7 @@ class SRefTreePruner : public StmtVisitor { } auto it = self_->stmt2ref.find(op); ICHECK(it != self_->stmt2ref.end()) - << "IndexError: Cannot find correpsonding StmtSRef for the loop:\n" + << "IndexError: Cannot find corresponding StmtSRef for the loop:\n" << GetRef(op); StmtSRef& sref = it->second; // Detect reuse @@ -588,7 +588,7 @@ class SRefTreePruner : public StmtVisitor { } auto it = self_->stmt2ref.find(op); ICHECK(it != self_->stmt2ref.end()) - << "IndexError: Cannot find correpsonding StmtSRef for the block:\n" + << "IndexError: Cannot find corresponding StmtSRef for the block:\n" << GetRef(op); StmtSRef& sref = it->second; // Detect reuse diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index e0ffdc7b019f..e3f675e8628f 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -101,6 +101,46 @@ Array TracedScheduleNode::Split(const LoopRV& loop_rv, /******** Schedule: Manipulate ForKind ********/ +void TracedScheduleNode::Parallel(const LoopRV& loop_rv) { + ConcreteScheduleNode::Parallel(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("Parallel"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Vectorize(const LoopRV& loop_rv) { + ConcreteScheduleNode::Vectorize(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("Vectorize"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) { + ConcreteScheduleNode::Bind(loop_rv, thread_axis); + + static const InstructionKind& kind = InstructionKind::Get("Bind"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{thread_axis}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Unroll(const LoopRV& loop_rv) { + ConcreteScheduleNode::Unroll(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("Unroll"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + /******** Schedule: Insert cache stages ********/ /******** Schedule: Compute location ********/ diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 4650c44ba8c3..f5f31abe1556 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -55,6 +55,10 @@ class TracedScheduleNode : public ConcreteScheduleNode { LoopRV Fuse(const Array& loop_rvs) final; Array Split(const LoopRV& loop_rv, const Array>& factor_rvs) final; /******** Schedule: Manipulate ForKind ********/ + void Parallel(const LoopRV& loop_rv) final; + void Vectorize(const LoopRV& loop_rv) final; + void Bind(const LoopRV& loop_rv, const String& thread_axis) final; + void Unroll(const LoopRV& loop_rv) final; /******** Schedule: Insert cache stages ********/ /******** Schedule: Compute location ********/ void ComputeInline(const BlockRV& block_rv) final; diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 85c412346056..5eb6d5b03921 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -140,7 +140,10 @@ class BufferFlattener : public StmtExprMutator { /*var=*/std::move(var), /*iter_type=*/IterVarType::kThreadIndex, 
/*thread_tag=*/thread_tag); - String attr_key = thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent; + String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || + thread_tag == "vthread.y" || thread_tag == "vthread.z") + ? attr::virtual_thread + : attr::thread_extent; return AttrStmt(/*node=*/std::move(iter_var), /*attr_key=*/std::move(attr_key), /*value=*/std::move(extent), diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 4ef10f326bb0..4964bec0334e 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -459,12 +459,12 @@ class VirtualThreadInjector : public StmtMutator { op = stmt.as(); if (op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); - bool allow_share = iv->thread_tag == "vthread"; + bool allow_share = std::string(iv->thread_tag).substr(0, 7) == "vthread"; int nthread = static_cast(op->value.as()->value); VarTouchedAnalysis vs; auto touched = vs.TouchedVar(op->body, iv->var.get()); - VTInjector injecter(iv->var, nthread, touched, allow_share); - return injecter(op->body); + VTInjector injector(iv->var, nthread, touched, allow_share); + return injector(op->body); } else { return stmt; } @@ -476,11 +476,6 @@ class VirtualThreadInjector : public StmtMutator { } }; -Stmt InjectVirtualThread(Stmt stmt) { - stmt = VirtualThreadInjector()(std::move(stmt)); - return ConvertSSA(std::move(stmt)); -} - namespace transform { Pass InjectVirtualThread() { diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc new file mode 100644 index 000000000000..6a26103e6079 --- /dev/null +++ b/src/tir/transforms/unify_thread_binding.cc @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file unify_thread_binding.cc + */ + +#include +#include +#include +#include + +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief A mutator which searches AttrStmts of thread bindings and changes the `node` field IterVar + * of the AttrStmts, so that for one kind of thread binding, all such thread bindings use the same + * IterVar + */ +class ThreadBindingUnifier : public StmtExprMutator { + public: + static Stmt Unify(Stmt stmt) { return ThreadBindingUnifier()(std::move(stmt)); } + + private: + Stmt VisitStmt_(const AttrStmtNode* attr) final { + // If this AttrStmt is not thread binding attribute, return as usual. + if (attr->attr_key != attr::thread_extent && attr->attr_key != attr::virtual_thread) { + return StmtMutator::VisitStmt_(attr); + } + + // Step 1. Fetch the old IterVar and the thread tag. 
+ IterVar old_iter_var = Downcast(attr->node); + IterVar new_iter_var{nullptr}; + const String& thread_tag = old_iter_var->thread_tag; + + // Step 2: Increase `thread_block_depth_` if the thread tag starts with "blockIdx". If the + // thread block depth is 0 before the increasement, it means we are entering a new kernel, and + // therefore we need to make `thread_tag2iter_var_map_` empty, as different kernels can have + // thread axes with different extents. + if (std::string(thread_tag).substr(0, 9) == "blockIdx.") { + if (!thread_block_depth_) { + thread_tag2iter_var_map_.clear(); + } + ++thread_block_depth_; + } + + // Step 3. See if an IterVar for this kind of thread binding was created before. If so, we use + // the created IterVar. Otherwise, we create a new IterVar for this thread binding and store the + // IterVar in mapping `thread_tag2iter_var_map_`. + Map::iterator it = thread_tag2iter_var_map_.find(thread_tag); + if (it != thread_tag2iter_var_map_.end()) { + new_iter_var = (*it).second; + CHECK(ana.CanProveEqual(old_iter_var->dom->extent, (*it).second->dom->extent)) + << "ValueError: All loops that are bound to `" << thread_tag + << "` should have the same extent. However, there are two loops with extent " + << (*it).second->dom->extent << " and " << old_iter_var->dom->extent + << ", which are not equal"; + } else { + ObjectPtr p_new_iter_var = make_object(*old_iter_var.get()); + p_new_iter_var->var = Var(thread_tag); + new_iter_var = IterVar(p_new_iter_var); + thread_tag2iter_var_map_.Set(thread_tag, new_iter_var); + } + + // Step 4. We will substitute the occurrences of the old variable in the old IterVar with the + // new variable in further mutation. Thus, we store the mapping entry. + var_substitution_map_.Set(old_iter_var->var, new_iter_var->var); + + // Step 5. Mutate recursively, update the AttrStmt with the new IterVar, and decrease the depth + // counter if the thread tag starts with "blockIdx". + AttrStmt new_attr = Downcast(StmtMutator::VisitStmt_(attr)); + ObjectPtr p_new_attr = CopyOnWrite(new_attr.get()); + p_new_attr->node = new_iter_var; + if (std::string(thread_tag).substr(0, 9) == "blockIdx.") { + --thread_block_depth_; + } + return Stmt(p_new_attr); + } + + PrimExpr VisitExpr_(const VarNode* var) final { + // If this variable appears as a key in `var_substitution_map_`, we substitute it with its + // corresponding value in the mapping. + Map::iterator it = var_substitution_map_.find(GetRef(var)); + return it != var_substitution_map_.end() ? (*it).second : GetRef(var); + } + + /*! + * \brief A mapping from a thread tag to its corresponding IterVar that is shared by all + * occurrences of the thread tag + * */ + Map thread_tag2iter_var_map_; + /*! \brief A mapping from old variables to new variables, which is used for substitution */ + Map var_substitution_map_; + /*! \brief A integer counter storing the depth of thread bindings of "blockIdx.x/y/z" */ + int thread_block_depth_ = 0; + /*! 
\brief An analyzer used for equality proof */ + arith::Analyzer ana; +}; + +PrimFunc UnifyThreadBinding(PrimFunc f) { + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = ThreadBindingUnifier::Unify(std::move(f->body)); + return f; + } else { + return f; + } +} + +namespace transform { + +Pass UnifyThreadBinding() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return UnifyThreadBinding(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.UnifyThreadBinding", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.UnifyThreadBinding").set_body_typed(UnifyThreadBinding); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py new file mode 100644 index 000000000000..5649a06bd3b8 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -0,0 +1,365 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def element_wise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_parallelized(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i0 in tir.parallel(0, 128): + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, i1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_i_bound(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i0 in tir.thread_binding(0, 128, thread="threadIdx.x"): + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, i1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_compute_at_split(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + for i in tir.serial(0, 128): + for j0 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j0) + B[vi, vj] = A[vi, vj] * 2.0 + for j1o, j1i in tir.grid(32, 4): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j1o * 4 + j1i) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def element_wise_compute_at_split_vectorized(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + for i in tir.serial(0, 128): + for j0 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j0) + B[vi, vj] = A[vi, vj] * 2.0 + for j1o in tir.serial(0, 32): + for j1i in tir.vectorized(0, 4): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j1o * 4 + j1i) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def element_wise_split_predicate(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + for i, j_0, j_1 in tir.grid(128, 13, 10): + with tir.block([128, 128], "B") as [vi, vj]: + tir.where(j_0 * 10 + j_1 < 128) + tir.bind(vi, i) + tir.bind(vj, j_0 * 10 + j_1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_split_predicate_parallelized(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + for i in tir.serial(0, 128): + for j_0 in tir.parallel(0, 13): + for j_1 in tir.serial(0, 10): + with tir.block([128, 128], "B") as [vi, vj]: + tir.where(j_0 * 10 + j_1 < 128) + tir.bind(vi, i) + tir.bind(vj, j_0 * 10 + j_1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_split_predicate_vectorized(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + for i in tir.vectorized(0, 128): + for j_0, j_1 in tir.grid(13, 10): + with tir.block([128, 128], "B") as [vi, vj]: + tir.where(j_0 * 10 + j_1 < 128) + tir.bind(vi, i) + tir.bind(vj, j_0 * 10 
+ j_1) + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def element_wise_compute_at_split_j0_j1o_bound(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + B = tir.alloc_buffer((128, 128)) + for i in tir.serial(0, 128): + for j0 in tir.thread_binding(0, 128, thread="threadIdx.x"): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j0) + B[vi, vj] = A[vi, vj] * 2.0 + for j1o in tir.thread_binding(0, 32, thread="threadIdx.x"): + for j1i in tir.serial(0, 4): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j1o * 4 + j1i) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def rowsum(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_unrolled(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + for i0 in tir.unroll(0, 128): + for i1 in tir.serial(0, 128): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i0) + tir.bind(vk, i1) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + for i, k in tir.grid(128, 16): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i) + tir.bind(vk, tir.floordiv(k * k, 2)) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_compact_data_flow(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vk] = 0.0 + B[vk] = B[vk] + A[vi, vk] + + +@tvm.script.tir +def rowsum_cross_thread_reduction(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + for i0 in tir.serial(0, 128): + for i1 in tir.thread_binding(0, 128, thread="threadIdx.x"): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i0) + tir.bind(vk, i1) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def opaque_block(a: ty.handle) -> None: + A = tir.match_buffer(a, (16,)) + for i in tir.serial(0, 15): + with tir.block([], "opaque"): + A[i + 1] = A[i + 1] + A[i] + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_parallel(): + s = tir.Schedule(element_wise, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.parallel(i) + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_parallelized) + verify_trace_roundtrip(s, mod=element_wise) + + +def test_parallel_predicate(): + s = tir.Schedule(element_wise_split_predicate, debug_mask="all") + _, j, _ = s.get_loops(s.get_block("B")) + s.parallel(j) + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_split_predicate_parallelized) + 
verify_trace_roundtrip(s, mod=element_wise_split_predicate) + + +def test_parallel_reduction_block_iter(): + s = tir.Schedule(matmul, debug_mask="all") + _, _, k = s.get_loops(s.get_block("C")) + with pytest.raises(tvm.tir.ScheduleError): + s.parallel(k) + + +def test_parallel_not_quasi_affine(): + s = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.parallel(i) + + +def test_parallel_not_compact_data_flow(): + s = tir.Schedule(rowsum_not_compact_data_flow, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.parallel(i) + + +def test_vectorize(): + s = tir.Schedule(element_wise_compute_at_split, debug_mask="all") + _, _, j1i = s.get_loops(s.get_block("C")) + s.vectorize(j1i) + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_compute_at_split_vectorized) + verify_trace_roundtrip(s, mod=element_wise_compute_at_split) + + +def test_vectorize_predicate(): + s = tir.Schedule(element_wise_split_predicate, debug_mask="all") + i, _, _ = s.get_loops(s.get_block("B")) + s.vectorize(i) + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_split_predicate_vectorized) + verify_trace_roundtrip(s, mod=element_wise_split_predicate) + + +def test_vectorize_opaque_block(): + s = tir.Schedule(opaque_block, debug_mask="all") + (i,) = s.get_loops(s.get_block("opaque")) + with pytest.raises(tvm.tir.ScheduleError): + s.vectorize(i) + + +def test_unroll(): + s = tir.Schedule(rowsum, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.unroll(i) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_unrolled) + verify_trace_roundtrip(s, mod=rowsum) + + +def test_unroll_after_bind(): + s = tir.Schedule(rowsum, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.bind(i, "blockIdx.x") + s.unroll(i) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_unrolled) + verify_trace_roundtrip(s, mod=rowsum) + + +def test_bind1(): + s = tir.Schedule(element_wise, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.bind(i, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_i_bound) + verify_trace_roundtrip(s, mod=element_wise) + + +def test_bind2(): + s = tir.Schedule(element_wise_compute_at_split, debug_mask="all") + _, j0 = s.get_loops(s.get_block("B")) + _, j1o, _ = s.get_loops(s.get_block("C")) + s.bind(j0, "threadIdx.x") + s.bind(j1o, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_compute_at_split_j0_j1o_bound) + verify_trace_roundtrip(s, mod=element_wise_compute_at_split) + + +def test_bind_cross_thread_reduction(): + s = tir.Schedule(rowsum, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) + s.bind(k, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_cross_thread_reduction) + verify_trace_roundtrip(s, mod=rowsum) + + +def test_bind_not_cross_thread_reduction(): + s = tir.Schedule(rowsum, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.bind(k, "blockIdx.x") + + +def test_bind_after_bind(): + s = tir.Schedule(element_wise, debug_mask="all") + i, _ = s.get_loops(s.get_block("B")) + s.bind(i, "blockIdx.x") + s.bind(i, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], element_wise_i_bound) + verify_trace_roundtrip(s, mod=element_wise) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_reduction.py 
b/tests/python/unittest/test_tir_schedule_reduction.py index 6b4ac235039a..067952899c0a 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -459,28 +459,34 @@ def multiple_reduction_blocks_rfactor(a: ty.handle, f: ty.handle) -> None: def test_reduction_rfactor_matmul(): s = tir.Schedule(transformed_matmul, debug_mask="all") - _, _, _, _, kii = s.get_loops(s.get_block("update")) + update = s.get_block("update") + _, _, _, _, kii = s.get_loops(update) rf_block = s.rfactor(kii, 0) tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) + assert s.get(update).same_as(s.get(s.get_block("update"))) verify_trace_roundtrip(s, mod=transformed_matmul) def test_reduction_rfactor_square_sum(): s = tir.Schedule(square_sum, debug_mask="all") - _, _, j = s.get_loops(s.get_block("C")) + C = s.get_block("C") + _, _, j = s.get_loops(C) rf_block = s.rfactor(j, 1) tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) + assert s.get(C).same_as(s.get(s.get_block("C"))) verify_trace_roundtrip(s, mod=square_sum) def test_reduction_rfactor_square_sum_square_root(): s = tir.Schedule(transformed_square_sum_square_root, debug_mask="all") - _, _, f_i = s.get_loops(s.get_block("C")) + C = s.get_block("C") + _, _, f_i = s.get_loops(C) rf_block = s.rfactor(f_i, 0) tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) + assert s.get(C).same_as(s.get(s.get_block("C"))) verify_trace_roundtrip(s, mod=transformed_square_sum_square_root) @@ -544,10 +550,12 @@ def test_reduction_rfactor_factor_axis_range_fail(): def test_reduction_rfactor_factor_axis_range(): s = tir.Schedule(transformed_matmul, debug_mask="all") - _, _, _, _, kii = s.get_loops(s.get_block("update")) + update = s.get_block("update") + _, _, _, _, kii = s.get_loops(update) rf_block = s.rfactor(kii, -3) tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) + assert s.get(update).same_as(s.get(s.get_block("update"))) verify_trace_roundtrip(s, mod=transformed_matmul) @@ -581,9 +589,12 @@ def test_reduction_rfactor_wrong_loops2(): def test_reduction_rfactor_zero_dim(): s = tir.Schedule(rowsum_zero_dim, debug_mask="all") - (k,) = s.get_loops(s.get_block("B")) - s.rfactor(k, 0) + B = s.get_block("B") + (k,) = s.get_loops(B) + rf_block = s.rfactor(k, 0) tvm.ir.assert_structural_equal(s.mod["main"], rowsum_zero_dim_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("B_rf"))) + assert s.get(B).same_as(s.get(s.get_block("B"))) verify_trace_roundtrip(s, mod=rowsum_zero_dim) @@ -608,9 +619,12 @@ def test_reduction_rfactor_outermost_loop_multiple_children_fail(): # pylint: d def test_reduction_rfactor_outermost_loop_multiple_children(): # pylint: disable=invalid-name s = tir.Schedule(multiple_reduction_blocks, debug_mask="all") - _, _, k1o, _ = s.get_loops(s.get_block("C")) - s.rfactor(k1o, 2) + C = s.get_block("C") + _, _, k1o, _ = s.get_loops(C) + rf_block = s.rfactor(k1o, 2) tvm.ir.assert_structural_equal(s.mod["main"], multiple_reduction_blocks_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) + assert s.get(C).same_as(s.get(s.get_block("C"))) verify_trace_roundtrip(s, mod=multiple_reduction_blocks) diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py 
b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 15da022e67d6..cefdb5fd8c6a 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -438,3 +438,4 @@ def test_storage_align(): test_complex() test_match_buffer() test_storage_align() + test_lower_te() diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py index 708f1af0c064..cfdcc1a65911 100644 --- a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py @@ -84,3 +84,4 @@ def test_lower_te(): if __name__ == "__main__": test_elementwise() + test_lower_te() diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 3b2b3cf2f55b..c51b5319e85f 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -250,3 +250,4 @@ def test_lower_te(): test_predicate() test_unit_loops() test_multi_alloc() + test_lower_te() diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py index 8499c9334e46..1f8a4adf7054 100644 --- a/tests/python/unittest/test_tir_transform_lower_init_block.py +++ b/tests/python/unittest/test_tir_transform_lower_init_block.py @@ -97,3 +97,4 @@ def test_lower_te(): if __name__ == "__main__": test_lower_reduction() test_lower_match_buffer() + test_lower_te() diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index 72a2f5ebc240..8418e192d060 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -164,3 +164,4 @@ def test_lower_te(): test_elementwise() test_locate_buffer_allocation() test_match_buffer_allocation() + test_lower_te() diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 1dd4a4852938..b57fa6c417b2 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -154,3 +154,4 @@ def test_flatten_tir(): test_flatten_storage_align() test_flatten_double_buffer() test_flatten_prefetch() + test_flatten_tir() diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py b/tests/python/unittest/test_tir_transform_unify_thread_binding.py new file mode 100644 index 000000000000..8e0b6dc804aa --- /dev/null +++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py @@ -0,0 +1,227 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +from tvm import tir, te +from tvm.script import ty + + +def _check(original, transformed): + mod = tvm.IRModule.from_expr(original) + mod = tvm.tir.transform.UnifyThreadBinding()(mod) + mod = tvm.tir.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + + +def _check_fail(original): + mod = tvm.IRModule.from_expr(original) + with pytest.raises(ValueError): + tvm.tir.transform.UnifyThreadBinding()(mod) + + +@tvm.script.tir +def element_wise_thread_x(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + j1_0 = tir.env_thread("threadIdx.x") + j0_0 = tir.env_thread("threadIdx.x") + i = tir.env_thread("blockIdx.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + tir.launch_thread(i, 128) + with tir.launch_thread(j0_0, 4): + for j0_1 in tir.serial(0, 32): + tir.store( + B.data, + i * 128 + j0_0 * 32 + j0_1, + tir.load("float32", A.data, i * 128 + j0_0 * 32 + j0_1) * 2.0, + True, + ) + tir.launch_thread(j1_0, 4) + for j1_1 in tir.serial(0, 32): + tir.store( + C.data, + i * 128 + j1_0 * 32 + j1_1, + tir.load("float32", A.data, i * 128 + j1_0 * 32 + j1_1) + 1.0, + True, + ) + + +@tvm.script.tir +def unified_element_wise_thread_x(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + thread_x = tir.env_thread("threadIdx.x") + block_x = tir.env_thread("blockIdx.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + tir.launch_thread(block_x, 128) + with tir.launch_thread(thread_x, 4): + for j0_1 in tir.serial(0, 32): + tir.store( + B.data, + block_x * 128 + thread_x * 32 + j0_1, + tir.load("float32", A.data, block_x * 128 + thread_x * 32 + j0_1) * 2.0, + True, + ) + tir.launch_thread(thread_x, 4) + for j1_1 in tir.serial(0, 32): + tir.store( + C.data, + block_x * 128 + thread_x * 32 + j1_1, + tir.load("float32", A.data, block_x * 128 + thread_x * 32 + j1_1) + 1.0, + True, + ) + + +@tvm.script.tir +def element_wise_vthread_x(a: ty.handle, b: ty.handle) -> None: + i_0 = tir.env_thread("vthread.x") + i_1 = tir.env_thread("threadIdx.x") + j_0 = tir.env_thread("vthread.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + tir.launch_thread(i_0, 2) + tir.launch_thread(i_1, 64) + tir.launch_thread(j_0, 2) + for j_1 in tir.serial(0, 64): + tir.store( + B.data, + i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1, + tir.load("float32", A.data, i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1) * 2.0, + True, + ) + + +@tvm.script.tir +def unified_element_wise_vthread_x(a: ty.handle, b: ty.handle) -> None: + vthread_x = tir.env_thread("vthread.x") + thread_x = tir.env_thread("threadIdx.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + tir.launch_thread(vthread_x, 2) + tir.launch_thread(thread_x, 64) + tir.launch_thread(vthread_x, 2) + for j_1 in tir.serial(0, 64): + tir.store( + B.data, + vthread_x * 8256 + thread_x * 128 + j_1, + tir.load("float32", A.data, vthread_x * 8256 + thread_x * 128 + j_1) * 2.0, + True, + ) + + +@tvm.script.tir +def 
element_wise_two_thread_x_in_same_kernel_not_equal( + a: ty.handle, b: ty.handle, c: ty.handle +) -> None: + i = tir.env_thread("blockIdx.x") + j0 = tir.env_thread("threadIdx.x") + j1 = tir.env_thread("threadIdx.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 64]) + tir.launch_thread(i, 128) + with tir.launch_thread(j0, 128): + tir.store(B.data, i * 64 + j0, tir.load("float32", A.data, i * 128 + j0) * 2.0, True) + tir.launch_thread(j1, 64) + tir.store(C.data, i * 64 + j1, tir.load("float32", A.data, i * 128 + j1) + 1.0, True) + + +@tvm.script.tir +def element_wise_kernels_with_different_size( + a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle +) -> None: + i0 = tir.env_thread("blockIdx.x") + j0 = tir.env_thread("threadIdx.x") + i1 = tir.env_thread("blockIdx.x") + j1 = tir.env_thread("threadIdx.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [256, 256]) + D = tir.match_buffer(d, [256, 256]) + with tir.launch_thread(i0, 128): + tir.launch_thread(j0, 128) + tir.store(B.data, i0 * 128 + j0, tir.load("float32", A.data, i0 * 128 + j0) * 2.0, True) + tir.launch_thread(i1, 256) + tir.launch_thread(j1, 256) + tir.store(D.data, i1 * 256 + j1, tir.load("float32", C.data, i1 * 256 + j1) + 1.0, True) + + +@tvm.script.tir +def unified_element_wise_kernels_with_different_size( + a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle +) -> None: + block_x = tir.env_thread("blockIdx.x") + thread_x = tir.env_thread("threadIdx.x") + block_x_1 = tir.env_thread("blockIdx.x") + thread_x_1 = tir.env_thread("threadIdx.x") + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [256, 256]) + D = tir.match_buffer(d, [256, 256]) + with tir.launch_thread(block_x, 128): + tir.launch_thread(thread_x, 128) + tir.store( + B.data, + block_x * 128 + thread_x, + tir.load("float32", A.data, block_x * 128 + thread_x) * 2.0, + True, + ) + tir.launch_thread(block_x_1, 256) + tir.launch_thread(thread_x_1, 256) + tir.store( + D.data, + block_x_1 * 256 + thread_x_1, + tir.load("float32", C.data, block_x_1 * 256 + thread_x_1) + 1.0, + True, + ) + + +def test_thread_x(): + _check(element_wise_thread_x, unified_element_wise_thread_x) + + +def test_vthread_x(): + _check(element_wise_vthread_x, unified_element_wise_vthread_x) + + +def test_two_thread_x_in_same_kernel_not_equal(): + _check_fail(element_wise_two_thread_x_in_same_kernel_not_equal) + + +def test_kernels_with_different_size(): + _check( + element_wise_kernels_with_different_size, unified_element_wise_kernels_with_different_size + ) + + +def test_lower_te(): + a = te.placeholder((32, 2, 2)) + b = te.compute((32, 2, 2), lambda i, j, k: a[i, j, k] * 2.0) + s = te.create_schedule(b.op) + s[b].bind(b.op.axis[1], te.thread_axis("threadIdx.x")) + s[b].bind(b.op.axis[2], te.thread_axis("threadIdx.x")) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [a, b]) + mod = tvm.tir.transform.UnifyThreadBinding()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # UnifyThreadBinding should do nothing on TE + + +if __name__ == "__main__": + test_thread_x() + test_vthread_x() + test_two_thread_x_in_same_kernel_not_equal() + test_kernels_with_different_size() + test_lower_te()
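
Editor's note (illustrative sketch, not part of the patch): the tests above exercise the new `parallel` / `vectorize` / `bind` / `unroll` schedule primitives and the `UnifyThreadBinding` pass one at a time. The snippet below is a minimal, hedged sketch of how they compose on the `element_wise` workload defined in the new test_tir_schedule_for_kind.py, using only calls that appear in this diff; the split factors are arbitrary, and the loops are re-queried after each split rather than relying on LoopRV reuse.

    import tvm
    from tvm import tir
    from tvm.script import ty


    @tvm.script.tir
    def element_wise(a: ty.handle, b: ty.handle) -> None:
        A = tir.match_buffer(a, (128, 128))
        B = tir.match_buffer(b, (128, 128))
        with tir.block([128, 128], "B") as [vi, vj]:
            B[vi, vj] = A[vi, vj] * 2.0


    # Build a schedule and change the ForKind of each loop with the new primitives.
    sch = tir.Schedule(element_wise, debug_mask="all")
    i, j = sch.get_loops(sch.get_block("B"))
    i_outer, i_inner = sch.split(i, factors=[4, 32])
    sch.bind(i_outer, "blockIdx.x")    # bind the outer loop to a block axis
    sch.bind(i_inner, "threadIdx.x")   # bind the inner loop to a thread axis
    _, _, j = sch.get_loops(sch.get_block("B"))  # re-query loops after the split
    j_outer, j_inner = sch.split(j, factors=[32, 4])
    sch.unroll(j_outer)
    sch.vectorize(j_inner)             # sch.parallel(...) is used the same way for CPU loops

    # The new pass consolidates duplicate vars bound to the same thread tag, if any,
    # mirroring what the unify-thread-binding tests above check structurally.
    mod = tvm.tir.transform.UnifyThreadBinding()(sch.mod)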