From 03113d424d1f26fe06f90ea601be86806b060516 Mon Sep 17 00:00:00 2001 From: Hongyi Jin <3231950289@qq.com> Date: Mon, 16 Aug 2021 22:58:21 +0800 Subject: [PATCH 01/18] reorder primitive --- include/tvm/tir/schedule/schedule.h | 13 + python/tvm/tir/schedule/schedule.py | 58 ++++ src/tir/schedule/concrete_schedule.cc | 7 + src/tir/schedule/concrete_schedule.h | 1 + src/tir/schedule/primitive.h | 15 + .../schedule/primitive/loop_transformation.cc | 300 +++++++++++++++++- src/tir/schedule/schedule.cc | 2 + src/tir/schedule/traced_schedule.cc | 10 + src/tir/schedule/traced_schedule.h | 1 + .../unittest/test_tir_schedule_reorder.py | 265 ++++++++++++++++ 10 files changed, 666 insertions(+), 6 deletions(-) create mode 100644 tests/python/unittest/test_tir_schedule_reorder.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 1ac3f80ecf39..670f9e5217ba 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -219,6 +219,19 @@ class ScheduleNode : public runtime::Object { * \return The new loops after split */ virtual Array Split(const LoopRV& loop_rv, const Array>& factors) = 0; + /*! + * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. + * It requires: + * 1) The loops are in the same line. That means: the loops can be ordered to [l_1, l_2, ... , + * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between + * l_1 and l_n (which also indicates they are under the same scope). + * 2) In the new order, an outer loop cannot depend on inner loops. + * 3) The block below the loops have affine bindings and only have data-parallel or reduction + * block iters + * 4) A loop cannot appear multiple times in the input array. + * \param ordered_loop_rvs The loops in the new order + */ + virtual void Reorder(const Array& ordered_loop_rvs) = 0; /******** Schedule: Manipulate ForKind ********/ /*! * \brief Parallelize the input loop. It requires: diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 46e5fd6fddcb..88401a90c96b 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -442,6 +442,64 @@ def after_split(a: ty.handle, b: ty.handle) -> None: # that there is at most one None in `factors` return _ffi_api.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member + def reorder(self, *loops: List[LoopRV]) -> None: + """ + Reorder a list of loops. It doesn't require the loops to be consecutive. + It requires: + 1) The loops are in the same line. That means: the loops can be ordered to [l_1, l_2, ... , + l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between + l_1 and l_n (which also indicates they are under the same scope). + 2) In the new order, an outer loop cannot depend on inner loops. + 3) The block below the loops have affine bindings and only have data-parallel or reduction block + iters + 4) A loop cannot appear multiple times in the input array. + + Parameters + ---------- + *loops : List[LoopRV] + The loops in the new order + + Examples + -------- + + Before reorder, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_reorder(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do reorder: + + .. 
code-block:: python + + sch = tir.Schedule(before_reorder) + i, j = sch.get_loops(sch.get_block("B")) + sch.reorder(j, i) + print(tvm.script.asscript(sch.mod["main"])) + + After applying reorder, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_reorder(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + # Here j and i are reordered + for j, i in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + """ + _ffi_api.ScheduleReorder(self, loops) # type: ignore # pylint: disable=no-member + ########## Schedule: Manipulate ForKind ########## def parallel(self, loop: LoopRV) -> None: diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index b18090dd7215..084d0b0eec6a 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -346,6 +346,13 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, return CreateRV(results); } +void ConcreteScheduleNode::Reorder(const Array& ordered_loop_rvs) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Reorder(state_, GetSRefs(ordered_loop_rvs)); + TVM_TIR_SCHEDULE_END("reorder", this->error_render_level_); + this->state_->DebugVerify(); +} + /******** Schedule: Manipulate ForKind ********/ void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 2af4675ddcca..97819d63edb6 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -81,6 +81,7 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs) override; Array Split(const LoopRV& loop_rv, const Array>& factors) override; + void Reorder(const Array& ordered_loop_rvs) override; /******** Schedule: Manipulate ForKind ********/ void Parallel(const LoopRV& loop_rv) override; void Vectorize(const LoopRV& loop_rv) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 04c38f67da7d..93cc937194e8 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -63,6 +63,21 @@ TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, * \return The sref to the fused loop */ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs); +/*! + * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. + * It requires: + * 1) The loops are in the same line. That means: the loops can be ordered to [l_1, l_2, ... , + * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between + * l_1 and l_n (which also indicates they are under the same scope). + * 2) In the new order, an outer loop cannot depend on inner loops. + * 3) The block below the loops have affine bindings and only have data-parallel or reduction block + * iters + * 4) A loop cannot appear multiple times in the input array. + * \param self The state of the schedule + * \param ordered_loop_srefs An array of srefs which indicates the new order of loops + */ +TVM_DLL void Reorder(ScheduleState self, const Array& ordered_loop_srefs); + /******** Schedule: Manipulate ForKind ********/ /*! * \brief Parallelize the input loop. 
It requires: diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index d1875df61ac7..6f328a477091 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -131,10 +131,32 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { Map loop_var2extent_; }; +class BlockIterTypeError : public ScheduleError { + public: + explicit BlockIterTypeError(IRModule mod, Block block) + : mod_(std::move(mod)), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The block under the loops to be reordered have block iter type other " + "than data-parallel or reduction"; + } + + String DetailRenderTemplate() const final { + return "The block {0} under the loops to be reordered have block iter type other than " + "data-parallel or reduction"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; +}; + class HasAnnotationOrThreadBindingError : public ScheduleError { public: explicit HasAnnotationOrThreadBindingError(IRModule mod, For loop) - : mod_(mod), loop_(std::move(loop)) {} + : mod_(std::move(mod)), loop_(std::move(loop)) {} String FastErrorString() const final { return "ScheduleError: The primitive can't be applied because the loop has annotation or " @@ -155,7 +177,7 @@ class HasAnnotationOrThreadBindingError : public ScheduleError { class OuterNotInnerParent : public ScheduleError { public: explicit OuterNotInnerParent(IRModule mod, For outer, For inner) - : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} + : mod_(std::move(mod)), outer_(std::move(outer)), inner_(std::move(inner)) {} String FastErrorString() const final { return "ScheduleError: The outer loop is not the parent of the inner loop"; @@ -177,7 +199,7 @@ class OuterNotInnerParent : public ScheduleError { class NotOnlyChildError : public ScheduleError { public: explicit NotOnlyChildError(IRModule mod, For outer, For inner) - : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} + : mod_(std::move(mod)), outer_(std::move(outer)), inner_(std::move(inner)) {} String FastErrorString() const final { return "ScheduleError: The inner loop is not the only child of outer loop"; @@ -198,7 +220,8 @@ class NotOnlyChildError : public ScheduleError { class LoopNotStartWithZeroError : public ScheduleError { public: - explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + explicit LoopNotStartWithZeroError(IRModule mod, For loop) + : mod_(std::move(mod)), loop_(std::move(loop)) {} String FastErrorString() const final { return "ScheduleError: The primitive only supports loop starting with 0"; @@ -217,7 +240,7 @@ class LoopNotStartWithZeroError : public ScheduleError { class NotSingleInferFactorError : public ScheduleError { public: - explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} + explicit NotSingleInferFactorError(IRModule mod) : mod_(std::move(mod)) {} String FastErrorString() const final { return "ScheduleError: only one factor can be specified as -1 or none"; @@ -235,7 +258,8 @@ class NotSingleInferFactorError : public ScheduleError { class WrongFactorProductError : public ScheduleError { public: - explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + explicit WrongFactorProductError(IRModule mod, For loop) + : mod_(std::move(mod)), 
loop_(std::move(loop)) {} String FastErrorString() const final { return "ScheduleError: The product of factors is not larger than or equal to the extent of " @@ -253,6 +277,137 @@ class WrongFactorProductError : public ScheduleError { For loop_; }; +class LoopMultiAppearanceError : public ScheduleError { + public: + explicit LoopMultiAppearanceError(IRModule mod, For loop) + : mod_(std::move(mod)), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: Some loop appears in the input array for multiple times."; + } + + String DetailRenderTemplate() const final { + return "Loop {0} appears in the input array for multiple times."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class LoopsNotALineError : public ScheduleError { + public: + enum ProblemKind { kNotUnderAScope, kHaveNonSingleBranchStmt }; + + explicit LoopsNotALineError(IRModule mod, Optional problematic_loop, ProblemKind kind) + : mod_(std::move(mod)), problematic_loop_(std::move(problematic_loop)), kind_(kind) {} + + String FastErrorString() const final { return "ScheduleError: the loops are not in a line"; } + + String DetailRenderTemplate() const final { + std::stringstream ss; + ss << "The loops are not in a line because"; + if (kind_ == kNotUnderAScope) { + ss << " they are not under the same scope."; + } else { + ss << " there is a non-single-branch stmt in between. Problematic stmt: {0}"; + } + return ss.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { + if (kind_ == kNotUnderAScope) { + return {}; + } else { + ICHECK(problematic_loop_.defined()); + return {problematic_loop_.value()}; + } + } + + IRModule mod_; + Optional problematic_loop_; + ProblemKind kind_; +}; + +class DependentLoopError : public ScheduleError { + public: + explicit DependentLoopError(IRModule mod, For loop, String inner_var) + : mod_(std::move(mod)), loop_(std::move(loop)), inner_var_(std::move(inner_var)) {} + + String FastErrorString() const final { + return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop " + "in the new order"; + } + + String DetailRenderTemplate() const final { + return "Outer Loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ + + " in the new order"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; + String inner_var_; +}; + +/*! + * \brief Collect all loops under a specific block scope in the inverse pre-order + * \param self The state of the schedule + * \param root_block_sref the sref to the root of block scope + * \return The array of srefs of all loops under the block scope, in inverse pre-order + */ +std::vector GetLoopsInversePreOrderUnderScope( + const ScheduleState& self, const StmtSRef& root_block_sref) { + std::vector loops; + const BlockNode* root_block = TVM_SREF_TO_BLOCK(root_block, root_block_sref); + // Gather all the loops under parent_block + PreOrderVisit(root_block->body, [&loops, self](const ObjectRef& node) { + // Stops at a new BlockNode + if (node->IsInstance()) { + return false; + } + // Collects every ForNode + if (const auto* loop = node.as()) { + loops.push_back(self->stmt2ref.at(loop).operator->()); + } + return true; + }); + // Reverse to get inverse preorder + std::reverse(loops.begin(), loops.end()); + return loops; +} +/*! 
+ * \brief Check that all the blocks under the specific stmt have affine bindings and only have + * data-parallel or reduction block iters + * \param self The state of the schedule + * \param sref The sref to the specific stmt + */ +void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, const StmtSRefNode* sref) { + class BlockIterTypeAndAffineBindingChecker : public StmtVisitor { + public: + explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state) : state_(state) {} + + private: + void VisitStmt_(const BlockNode* op) final { + for (const IterVar& iter_var : op->iter_vars) { + if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { + throw BlockIterTypeError(state_->mod, GetRef(op)); + } + CheckAffineBinding(state_, GetRef(op)); + } + } + const ScheduleState& state_; + }; + + BlockIterTypeAndAffineBindingChecker checker(self); + checker(GetRef(sref->stmt)); +} + Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array& factors) { // Invariance @@ -385,6 +540,108 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } +void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { + std::unordered_set loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } + // Step 1. check uniqueness + for (const StmtSRef loop_sref : ordered_loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + // uniqueness check + auto inserted = loop_srefs.insert(loop_sref.get()); + if (!inserted.second) { + throw LoopMultiAppearanceError(self->mod, GetRef(loop)); + } + } + // Step 2. gather loops to be reordered + // The algorithm is to scan the inverse preorder of the whole loop tree in the scope. + // For some Loop x, it is potentially in the reorder range if + // - x is in the reorder list + // - x has only one child which is a loop and is potentially in the reorder range + // After the inverse DFS, we can know the exact reorder range + // `top` and `bottom` denote the boundary of the loop range that need reordering + const StmtSRefNode* top = nullptr; + const StmtSRefNode* bottom = nullptr; + // Maps a parent sref to its child sref + std::unordered_map successor; + int n_loops_not_found = ordered_loop_srefs.size(); + // Gather all the loops under the block scope + std::vector inverse_preorder_loops = GetLoopsInversePreOrderUnderScope( + self, GetScopeRoot(self, ordered_loop_srefs[0], /*require_stage_pipeline=*/true)); + for (const StmtSRefNode* loop : inverse_preorder_loops) { + bool is_in_reorder_list = loop_srefs.count(loop); + bool has_successor_in_reorder_list = successor.count(loop); + if (is_in_reorder_list || has_successor_in_reorder_list) { + const StmtSRefNode* parent = loop->parent; + // If the successor of `parent` exists, then `parent` can't be a single-branch loop + auto inserted = successor.insert({parent, loop}); + if (!inserted.second) { + throw LoopsNotALineError(self->mod, GetRef(parent->stmt), + LoopsNotALineError::kHaveNonSingleBranchStmt); + } + // `bottom` is the first loop encountered + if (bottom == nullptr) { + bottom = loop; + } + // `top` is the last loop encountered + if (is_in_reorder_list) { + top = loop; + --n_loops_not_found; + } + } + } + // Step 3. Check loops are in the same block scope + if (n_loops_not_found != 0) { + throw LoopsNotALineError(self->mod, NullOpt, LoopsNotALineError::kNotUnderAScope); + } + // Step 4. 
Check that loops are single-branch + const ForNode* outer_loop = TVM_SREF_TO_FOR(outer_loop, GetRef(top)); + for (const StmtSRefNode* loop_sref = top; loop_sref != bottom;) { + loop_sref = successor[loop_sref]; + const ForNode* inner_loop = TVM_SREF_TO_FOR(inner_loop, GetRef(loop_sref)); + if (outer_loop->body.get() != inner_loop) { + throw LoopsNotALineError(self->mod, GetRef(outer_loop), + LoopsNotALineError::kHaveNonSingleBranchStmt); + } + outer_loop = inner_loop; + } + // Step 5. Check the block below has all its block_var to be data-parallel or reduction + CheckBlockIterTypeAndAffineBinding(self, bottom); + // Step 6. Replace the original loops with the reordered loops and check that outer loop is + // not dependent on inner loop + std::unordered_set inner_vars; + std::function f_reorder = + [&bottom, &loop_srefs, &successor, &ordered_loop_srefs, &inner_vars, &self, &f_reorder]( + const StmtSRefNode* loop, int index) -> Stmt { + const ForNode* copy = loop_srefs.count(loop) ? ordered_loop_srefs[index++]->StmtAs() + : loop->StmtAs(); + ObjectPtr n = make_object(*copy); + if (loop == bottom) { + // stop recursion at bottom loop + n->body = loop->StmtAs()->body; + } else { + // reorder recursively + n->body = f_reorder(successor.at(loop), index); + } + const VarNode* used_var; + auto f_contain = [&inner_vars, &used_var](const VarNode* var) { + if (inner_vars.count(var)) { + used_var = var; + return true; + } + return false; + }; + if (UsesVar(copy->min, f_contain) || UsesVar(copy->extent, f_contain)) { + throw DependentLoopError(self->mod, GetRef(copy), used_var->name_hint); + } + inner_vars.insert(copy->loop_var.get()); + return Stmt(std::move(n)); + }; + self->Replace(GetRef(top), f_reorder(top, 0), {}); +} + /******** Instruction Registration ********/ struct SplitTraits : public UnpackedInstTraits { @@ -456,8 +713,39 @@ struct FuseTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct ReorderTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Reorder"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + template + static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, + const Array& inputs) { + setter(delta, inputs); + } + + static void UnpackedApplyToSchedule(Schedule sch, Array loop_rvs) { + return sch->Reorder(loop_rvs); + } + + static String UnpackedAsPython(Array outputs, Array loop_rvs) { + PythonAPICall py("reorder"); + for (const String& loop_rv : loop_rvs) { + py.Input("", loop_rv); + } + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(SplitTraits); TVM_REGISTER_INST_KIND_TRAITS(FuseTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index f21a4c370a5b..29681fdf0926 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -125,6 +125,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") /******** (FFI) Transform loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder") + .set_body_method(&ScheduleNode::Reorder); /******** (FFI) Manipulate ForKind ********/ 
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel") .set_body_method(&ScheduleNode::Parallel); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index e3f675e8628f..ae6a194b9888 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -99,6 +99,16 @@ Array TracedScheduleNode::Split(const LoopRV& loop_rv, return results; } +void TracedScheduleNode::Reorder(const Array& ordered_loop_rvs) { + ConcreteScheduleNode::Reorder(ordered_loop_rvs); + + static const InstructionKind& kind = InstructionKind::Get("Reorder"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{ordered_loop_rvs.begin(), ordered_loop_rvs.end()}, + /*attrs=*/{}, + /*outputs=*/{})); +} + /******** Schedule: Manipulate ForKind ********/ void TracedScheduleNode::Parallel(const LoopRV& loop_rv) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index f5f31abe1556..11128ba32fad 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -54,6 +54,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs) final; Array Split(const LoopRV& loop_rv, const Array>& factor_rvs) final; + void Reorder(const Array& ordered_loop_rvs) final; /******** Schedule: Manipulate ForKind ********/ void Parallel(const LoopRV& loop_rv) final; void Vectorize(const LoopRV& loop_rv) final; diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py new file mode 100644 index 000000000000..f463ab111d16 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -0,0 +1,265 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def elementwise_not_affine(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in tir.grid(128, 128, 128, 8): + with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.bind(vl, l * 16) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def elementwise_dependent_loop(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + for i in tir.serial(0, 128): + for j, k, l in tir.grid(128, i, 128): + with tir.block([128, 128, i, 128], "B") as [vi, vj, vk, vl]: + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def elementwise_predicate(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in tir.grid(128, 128, 128, 128): + with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def elementwise_non_single_branch(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + C = tir.alloc_buffer((128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "C") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_loops_not_same_scope(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "A") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + for k in tir.serial(0, 128): + with tir.block([128], "B") as [vk]: + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_wrong_block_var_type(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j, k in tir.grid(128, 128, 128): + with tir.block([128, 128, tir.scan_axis(0, 128)], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_reordered(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + for l, j, k, i in tir.grid(128, 128, 128, 128): + with tir.block([128, 
128, 128, 128], "B") as [vi, vj, vk, vl]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.bind(vl, l) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def elementwise_reordered_with_predicate(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + for l, j, k, i in tir.grid(128, 128, 128, 128): + with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100) + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.bind(vl, l) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16], "float32") + B = tir.match_buffer(b, [16, 16], "float32") + with tir.block([16, 16], "A") as [vi, vj]: + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, vi * 16 + vj, 1) + with tir.block([16, 16], "B") as [vi, vj]: + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + + +@tvm.script.tir +def opaque_access_reorder(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16], "float32") + B = tir.match_buffer(b, [16, 16], "float32") + for j, i in tir.grid(16, 16): + with tir.block([16, 16], "A") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, vi * 16 + vj, 1) + for j, i in tir.grid(16, 16): + with tir.block([16, 16], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_reorder(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + sch.reorder(l, i) + tvm.ir.assert_structural_equal(elementwise_reordered, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_reorder_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mask="all") + block_a = sch.get_block("A") + i, j = sch.get_loops(block_a) + sch.reorder(j, i) + block_b = sch.get_block("B") + i, j = sch.get_loops(block_b) + sch.reorder(j, i) + tvm.ir.assert_structural_equal(opaque_access_reorder, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=opaque_access) + + +def test_reorder_with_predicate(): + sch = tir.Schedule(elementwise_predicate, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + sch.reorder(l, i) + tvm.ir.assert_structural_equal(elementwise_reordered_with_predicate, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_predicate) + + +def test_reorder_fail_with_multi_appearance_loops(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(k, i, i) + + +def test_reorder_fail_with_non_single_branch_loop(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(k, i) + + +def test_reorder_fail_with_loops_not_under_same_scope(): + sch = tir.Schedule(elementwise_with_loops_not_same_scope, debug_mask="all") + block_b = sch.get_block("B") + block_a = sch.get_block("A") + i, j = 
sch.get_loops(block_a) + k = sch.get_loops(block_b)[0] + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(k, i) + + +def test_reorder_fail_with_wrong_block_var_type(): + sch = tir.Schedule(elementwise_with_wrong_block_var_type, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(k, i) + + +def test_reorder_fail_with_dependent_loops(): + sch = tir.Schedule(elementwise_dependent_loop, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(l, i) + + +def test_reorder_fail_not_affine_bindings(): + sch = tir.Schedule(elementwise_not_affine, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(l, i) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From b4e2f6917599fc796113046638b1f6ef087d2f92 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Tue, 17 Aug 2021 14:06:19 +0800 Subject: [PATCH 02/18] format --- python/tvm/tir/schedule/schedule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 88401a90c96b..97076ae0933f 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -450,8 +450,8 @@ def reorder(self, *loops: List[LoopRV]) -> None: l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between l_1 and l_n (which also indicates they are under the same scope). 2) In the new order, an outer loop cannot depend on inner loops. - 3) The block below the loops have affine bindings and only have data-parallel or reduction block - iters + 3) The block below the loops have affine bindings and only have data-parallel or reduction + block iters 4) A loop cannot appear multiple times in the input array. Parameters From c005b10c6e0e96a87b1bdac950663d2f36b5b7e1 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Tue, 17 Aug 2021 17:31:35 +0800 Subject: [PATCH 03/18] fix compilation --- src/tir/schedule/primitive/loop_transformation.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 6f328a477091..a10e6dc2c9f3 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -740,7 +740,8 @@ struct ReorderTraits : public UnpackedInstTraits { return py.Str(); } - friend struct UnpackedInstTraits; + template + friend struct ::tvm::tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(SplitTraits); From 6808eb774738f80a8856ebdd3e59d7fb8d4aba13 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 21 Aug 2021 00:15:02 +0800 Subject: [PATCH 04/18] new implementation of gathering reorder range --- .../schedule/primitive/loop_transformation.cc | 74 ++++++++++--------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index a10e6dc2c9f3..c76f74884782 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -547,7 +547,7 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { return; } // Step 1. 
check uniqueness - for (const StmtSRef loop_sref : ordered_loop_srefs) { + for (const StmtSRef& loop_sref : ordered_loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); // uniqueness check auto inserted = loop_srefs.insert(loop_sref.get()); @@ -556,47 +556,51 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { } } // Step 2. gather loops to be reordered - // The algorithm is to scan the inverse preorder of the whole loop tree in the scope. - // For some Loop x, it is potentially in the reorder range if - // - x is in the reorder list - // - x has only one child which is a loop and is potentially in the reorder range - // After the inverse DFS, we can know the exact reorder range - // `top` and `bottom` denote the boundary of the loop range that need reordering + // For each loop, traverse upwards along the parent pointer, and stop on either a block, or a + // previously-visited loop + // - the top of the reorder range is the last loop visited in the first traverse which exists in + // the input array + // - the bottom of the reorder range is the last loop in the input array which is not visited in + // the previous traverses const StmtSRefNode* top = nullptr; const StmtSRefNode* bottom = nullptr; // Maps a parent sref to its child sref std::unordered_map successor; - int n_loops_not_found = ordered_loop_srefs.size(); - // Gather all the loops under the block scope - std::vector inverse_preorder_loops = GetLoopsInversePreOrderUnderScope( - self, GetScopeRoot(self, ordered_loop_srefs[0], /*require_stage_pipeline=*/true)); - for (const StmtSRefNode* loop : inverse_preorder_loops) { - bool is_in_reorder_list = loop_srefs.count(loop); - bool has_successor_in_reorder_list = successor.count(loop); - if (is_in_reorder_list || has_successor_in_reorder_list) { - const StmtSRefNode* parent = loop->parent; - // If the successor of `parent` exists, then `parent` can't be a single-branch loop - auto inserted = successor.insert({parent, loop}); - if (!inserted.second) { - throw LoopsNotALineError(self->mod, GetRef(parent->stmt), - LoopsNotALineError::kHaveNonSingleBranchStmt); + for (size_t i = 0; i < ordered_loop_srefs.size(); i++) { + const StmtSRefNode* sref = ordered_loop_srefs[i].get(); + // if sref is not visited before, update `bottom` + if (!successor.count(sref->parent)) { + bottom = sref; + } + while (true) { + // stop at blocknode + if (sref->stmt->IsInstance()) { + if (i != 0) { + throw LoopsNotALineError(self->mod, NullOpt, LoopsNotALineError::kNotUnderAScope); + } else { + break; + } } - // `bottom` is the first loop encountered - if (bottom == nullptr) { - bottom = loop; + const StmtSRefNode* parent_sref = sref->parent; + // stop at previously-visited loop + if (successor.count(parent_sref)) { + if (successor[parent_sref] == sref) { + break; + } else { + throw LoopsNotALineError(self->mod, GetRef(parent_sref->stmt), + LoopsNotALineError::kHaveNonSingleBranchStmt); + } + } else { + successor[parent_sref] = sref; } - // `top` is the last loop encountered - if (is_in_reorder_list) { - top = loop; - --n_loops_not_found; + // if it's the first traverse and the loop is in the input array, update `top` + if (loop_srefs.count(sref) && i == 0) { + top = sref; } + sref = parent_sref; } } - // Step 3. Check loops are in the same block scope - if (n_loops_not_found != 0) { - throw LoopsNotALineError(self->mod, NullOpt, LoopsNotALineError::kNotUnderAScope); - } - // Step 4. Check that loops are single-branch + // Step 3. 
Check that loops are single-branch const ForNode* outer_loop = TVM_SREF_TO_FOR(outer_loop, GetRef(top)); for (const StmtSRefNode* loop_sref = top; loop_sref != bottom;) { loop_sref = successor[loop_sref]; @@ -607,9 +611,9 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { } outer_loop = inner_loop; } - // Step 5. Check the block below has all its block_var to be data-parallel or reduction + // Step 4. Check the block below has all its block_var to be data-parallel or reduction CheckBlockIterTypeAndAffineBinding(self, bottom); - // Step 6. Replace the original loops with the reordered loops and check that outer loop is + // Step 5. Replace the original loops with the reordered loops and check that outer loop is // not dependent on inner loop std::unordered_set inner_vars; std::function f_reorder = From 2c4ffd15a20481ce05bdb7441b3fc0ca8c099fa7 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 21 Aug 2021 00:23:50 +0800 Subject: [PATCH 05/18] address comments --- include/tvm/tir/schedule/schedule.h | 8 +- python/tvm/tir/schedule/schedule.py | 14 +- src/tir/schedule/primitive.h | 8 +- .../schedule/primitive/loop_transformation.cc | 122 +++++++----------- 4 files changed, 62 insertions(+), 90 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 670f9e5217ba..41906f30bd59 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -225,10 +225,10 @@ class ScheduleNode : public runtime::Object { * 1) The loops are in the same line. That means: the loops can be ordered to [l_1, l_2, ... , * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between * l_1 and l_n (which also indicates they are under the same scope). - * 2) In the new order, an outer loop cannot depend on inner loops. - * 3) The block below the loops have affine bindings and only have data-parallel or reduction - * block iters - * 4) A loop cannot appear multiple times in the input array. + * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops + * 3) For every block under the loop nests, its block binding must be affine, and the block + * variables must be either data parallel or reduction. + * 4) No duplicated loops are allowed in the arguments. * \param ordered_loop_rvs The loops in the new order */ virtual void Reorder(const Array& ordered_loop_rvs) = 0; diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 97076ae0933f..5d0eb87eb12f 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -442,21 +442,21 @@ def after_split(a: ty.handle, b: ty.handle) -> None: # that there is at most one None in `factors` return _ffi_api.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member - def reorder(self, *loops: List[LoopRV]) -> None: + def reorder(self, *ordered_loops: List[LoopRV]) -> None: """ Reorder a list of loops. It doesn't require the loops to be consecutive. It requires: 1) The loops are in the same line. That means: the loops can be ordered to [l_1, l_2, ... , l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between l_1 and l_n (which also indicates they are under the same scope). - 2) In the new order, an outer loop cannot depend on inner loops. 
- 3) The block below the loops have affine bindings and only have data-parallel or reduction - block iters - 4) A loop cannot appear multiple times in the input array. + 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops + 3) For every block under the loop nests, its block binding must be affine, and the block + variables must be either data parallel or reduction. + 4) No duplicated loops are allowed in the arguments. Parameters ---------- - *loops : List[LoopRV] + *ordered_loops : List[LoopRV] The loops in the new order Examples @@ -498,7 +498,7 @@ def after_reorder(a: ty.handle, b: ty.handle) -> None: tir.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 """ - _ffi_api.ScheduleReorder(self, loops) # type: ignore # pylint: disable=no-member + _ffi_api.ScheduleReorder(self, ordered_loops) # type: ignore # pylint: disable=no-member ########## Schedule: Manipulate ForKind ########## diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 93cc937194e8..0ed455fd3722 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -69,10 +69,10 @@ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs); * 1) The loops are in the same line. That means: the loops can be ordered to [l_1, l_2, ... , * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between * l_1 and l_n (which also indicates they are under the same scope). - * 2) In the new order, an outer loop cannot depend on inner loops. - * 3) The block below the loops have affine bindings and only have data-parallel or reduction block - * iters - * 4) A loop cannot appear multiple times in the input array. + * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops + * 3) For every block under the loop nests, its block binding must be affine, and the block + * variables must be either data parallel or reduction. + * 4) No duplicated loops are allowed in the arguments. * \param self The state of the schedule * \param ordered_loop_srefs An array of srefs which indicates the new order of loops */ diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index c76f74884782..828cd8bb2cb6 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -131,10 +131,37 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { Map loop_var2extent_; }; -class BlockIterTypeError : public ScheduleError { +class BlockPropertyError : public ScheduleError { public: - explicit BlockIterTypeError(IRModule mod, Block block) - : mod_(std::move(mod)), block_(std::move(block)) {} + /*! 
+ * \brief Check that all the blocks under the specific stmt have affine bindings and only have + * data-parallel or reduction block iters + * \param self The state of the schedule + * \param sref The sref to the specific stmt + */ + static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, + const StmtSRefNode* sref) { + class BlockIterTypeAndAffineBindingChecker : public StmtVisitor { + public: + explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state) : state_(state) {} + + private: + void VisitStmt_(const BlockNode* op) final { + for (const IterVar& iter_var : op->iter_vars) { + if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { + throw BlockPropertyError(state_->mod, GetRef(op)); + } + CheckAffineBinding(state_, GetRef(op)); + } + } + const ScheduleState& state_; + }; + + BlockIterTypeAndAffineBindingChecker checker(self); + checker(GetRef(sref->stmt)); + } + + explicit BlockPropertyError(IRModule mod, Block block) : mod_(mod), block_(std::move(block)) {} String FastErrorString() const final { return "ScheduleError: The block under the loops to be reordered have block iter type other " @@ -156,7 +183,7 @@ class BlockIterTypeError : public ScheduleError { class HasAnnotationOrThreadBindingError : public ScheduleError { public: explicit HasAnnotationOrThreadBindingError(IRModule mod, For loop) - : mod_(std::move(mod)), loop_(std::move(loop)) {} + : mod_(mod), loop_(std::move(loop)) {} String FastErrorString() const final { return "ScheduleError: The primitive can't be applied because the loop has annotation or " @@ -177,7 +204,7 @@ class HasAnnotationOrThreadBindingError : public ScheduleError { class OuterNotInnerParent : public ScheduleError { public: explicit OuterNotInnerParent(IRModule mod, For outer, For inner) - : mod_(std::move(mod)), outer_(std::move(outer)), inner_(std::move(inner)) {} + : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} String FastErrorString() const final { return "ScheduleError: The outer loop is not the parent of the inner loop"; @@ -199,7 +226,7 @@ class OuterNotInnerParent : public ScheduleError { class NotOnlyChildError : public ScheduleError { public: explicit NotOnlyChildError(IRModule mod, For outer, For inner) - : mod_(std::move(mod)), outer_(std::move(outer)), inner_(std::move(inner)) {} + : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} String FastErrorString() const final { return "ScheduleError: The inner loop is not the only child of outer loop"; @@ -220,8 +247,7 @@ class NotOnlyChildError : public ScheduleError { class LoopNotStartWithZeroError : public ScheduleError { public: - explicit LoopNotStartWithZeroError(IRModule mod, For loop) - : mod_(std::move(mod)), loop_(std::move(loop)) {} + explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} String FastErrorString() const final { return "ScheduleError: The primitive only supports loop starting with 0"; @@ -240,7 +266,7 @@ class LoopNotStartWithZeroError : public ScheduleError { class NotSingleInferFactorError : public ScheduleError { public: - explicit NotSingleInferFactorError(IRModule mod) : mod_(std::move(mod)) {} + explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} String FastErrorString() const final { return "ScheduleError: only one factor can be specified as -1 or none"; @@ -258,8 +284,7 @@ class NotSingleInferFactorError : public ScheduleError { class WrongFactorProductError : public ScheduleError { public: - explicit 
WrongFactorProductError(IRModule mod, For loop) - : mod_(std::move(mod)), loop_(std::move(loop)) {} + explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} String FastErrorString() const final { return "ScheduleError: The product of factors is not larger than or equal to the extent of " @@ -279,8 +304,7 @@ class WrongFactorProductError : public ScheduleError { class LoopMultiAppearanceError : public ScheduleError { public: - explicit LoopMultiAppearanceError(IRModule mod, For loop) - : mod_(std::move(mod)), loop_(std::move(loop)) {} + explicit LoopMultiAppearanceError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} String FastErrorString() const final { return "ScheduleError: Some loop appears in the input array for multiple times."; @@ -299,17 +323,17 @@ class LoopMultiAppearanceError : public ScheduleError { class LoopsNotALineError : public ScheduleError { public: - enum ProblemKind { kNotUnderAScope, kHaveNonSingleBranchStmt }; + enum class ProblemKind { kNotUnderAScope, kHaveNonSingleBranchStmt }; explicit LoopsNotALineError(IRModule mod, Optional problematic_loop, ProblemKind kind) - : mod_(std::move(mod)), problematic_loop_(std::move(problematic_loop)), kind_(kind) {} + : mod_(mod), problematic_loop_(std::move(problematic_loop)), kind_(kind) {} String FastErrorString() const final { return "ScheduleError: the loops are not in a line"; } String DetailRenderTemplate() const final { std::stringstream ss; ss << "The loops are not in a line because"; - if (kind_ == kNotUnderAScope) { + if (kind_ == ProblemKind::kNotUnderAScope) { ss << " they are not under the same scope."; } else { ss << " there is a non-single-branch stmt in between. Problematic stmt: {0}"; @@ -319,7 +343,7 @@ class LoopsNotALineError : public ScheduleError { IRModule mod() const final { return mod_; } Array LocationsOfInterest() const final { - if (kind_ == kNotUnderAScope) { + if (kind_ == ProblemKind::kNotUnderAScope) { return {}; } else { ICHECK(problematic_loop_.defined()); @@ -335,7 +359,7 @@ class LoopsNotALineError : public ScheduleError { class DependentLoopError : public ScheduleError { public: explicit DependentLoopError(IRModule mod, For loop, String inner_var) - : mod_(std::move(mod)), loop_(std::move(loop)), inner_var_(std::move(inner_var)) {} + : mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)) {} String FastErrorString() const final { return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop " @@ -355,59 +379,6 @@ class DependentLoopError : public ScheduleError { String inner_var_; }; -/*! 
- * \brief Collect all loops under a specific block scope in the inverse pre-order - * \param self The state of the schedule - * \param root_block_sref the sref to the root of block scope - * \return The array of srefs of all loops under the block scope, in inverse pre-order - */ -std::vector GetLoopsInversePreOrderUnderScope( - const ScheduleState& self, const StmtSRef& root_block_sref) { - std::vector loops; - const BlockNode* root_block = TVM_SREF_TO_BLOCK(root_block, root_block_sref); - // Gather all the loops under parent_block - PreOrderVisit(root_block->body, [&loops, self](const ObjectRef& node) { - // Stops at a new BlockNode - if (node->IsInstance()) { - return false; - } - // Collects every ForNode - if (const auto* loop = node.as()) { - loops.push_back(self->stmt2ref.at(loop).operator->()); - } - return true; - }); - // Reverse to get inverse preorder - std::reverse(loops.begin(), loops.end()); - return loops; -} -/*! - * \brief Check that all the blocks under the specific stmt have affine bindings and only have - * data-parallel or reduction block iters - * \param self The state of the schedule - * \param sref The sref to the specific stmt - */ -void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, const StmtSRefNode* sref) { - class BlockIterTypeAndAffineBindingChecker : public StmtVisitor { - public: - explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state) : state_(state) {} - - private: - void VisitStmt_(const BlockNode* op) final { - for (const IterVar& iter_var : op->iter_vars) { - if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { - throw BlockIterTypeError(state_->mod, GetRef(op)); - } - CheckAffineBinding(state_, GetRef(op)); - } - } - const ScheduleState& state_; - }; - - BlockIterTypeAndAffineBindingChecker checker(self); - checker(GetRef(sref->stmt)); -} - Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array& factors) { // Invariance @@ -576,7 +547,8 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { // stop at blocknode if (sref->stmt->IsInstance()) { if (i != 0) { - throw LoopsNotALineError(self->mod, NullOpt, LoopsNotALineError::kNotUnderAScope); + throw LoopsNotALineError(self->mod, NullOpt, + LoopsNotALineError::ProblemKind::kNotUnderAScope); } else { break; } @@ -588,7 +560,7 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { break; } else { throw LoopsNotALineError(self->mod, GetRef(parent_sref->stmt), - LoopsNotALineError::kHaveNonSingleBranchStmt); + LoopsNotALineError::ProblemKind::kHaveNonSingleBranchStmt); } } else { successor[parent_sref] = sref; @@ -607,12 +579,12 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { const ForNode* inner_loop = TVM_SREF_TO_FOR(inner_loop, GetRef(loop_sref)); if (outer_loop->body.get() != inner_loop) { throw LoopsNotALineError(self->mod, GetRef(outer_loop), - LoopsNotALineError::kHaveNonSingleBranchStmt); + LoopsNotALineError::ProblemKind::kHaveNonSingleBranchStmt); } outer_loop = inner_loop; } // Step 4. Check the block below has all its block_var to be data-parallel or reduction - CheckBlockIterTypeAndAffineBinding(self, bottom); + BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom); // Step 5. 
Replace the original loops with the reordered loops and check that outer loop is // not dependent on inner loop std::unordered_set inner_vars; From e9074f21dee7dac1424164bea3d29d7be27aa5b7 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 21 Aug 2021 09:59:42 +0800 Subject: [PATCH 06/18] fix ci --- python/tvm/tir/schedule/schedule.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 5d0eb87eb12f..0617c220a06a 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -447,11 +447,11 @@ def reorder(self, *ordered_loops: List[LoopRV]) -> None: Reorder a list of loops. It doesn't require the loops to be consecutive. It requires: 1) The loops are in the same line. That means: the loops can be ordered to [l_1, l_2, ... , - l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between - l_1 and l_n (which also indicates they are under the same scope). + l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between + l_1 and l_n (which also indicates they are under the same scope). 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops 3) For every block under the loop nests, its block binding must be affine, and the block - variables must be either data parallel or reduction. + variables must be either data parallel or reduction. 4) No duplicated loops are allowed in the arguments. Parameters @@ -497,6 +497,7 @@ def after_reorder(a: ty.handle, b: ty.handle) -> None: tir.bind(vi, i) tir.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 + """ _ffi_api.ScheduleReorder(self, ordered_loops) # type: ignore # pylint: disable=no-member From 87dd6c16f2c6e3bca8a07eaea909a8befb4198f0 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 21 Aug 2021 10:02:41 +0800 Subject: [PATCH 07/18] address comments --- include/tvm/tir/schedule/schedule.h | 2 +- python/tvm/tir/schedule/schedule.py | 2 +- src/tir/schedule/primitive.h | 2 +- .../schedule/primitive/loop_transformation.cc | 20 +++++++++---------- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 41906f30bd59..de8df69362b4 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -222,7 +222,7 @@ class ScheduleNode : public runtime::Object { /*! * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. * It requires: - * 1) The loops are in the same line. That means: the loops can be ordered to [l_1, l_2, ... , + * 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... , * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between * l_1 and l_n (which also indicates they are under the same scope). * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 0617c220a06a..b8b731341c5f 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -446,7 +446,7 @@ def reorder(self, *ordered_loops: List[LoopRV]) -> None: """ Reorder a list of loops. It doesn't require the loops to be consecutive. It requires: - 1) The loops are in the same line. That means: the loops can be ordered to [l_1, l_2, ... , + 1) The loops are in the same chain. 
That means: the loops can be ordered to [l_1, l_2, ... , l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between l_1 and l_n (which also indicates they are under the same scope). 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 0ed455fd3722..c0c06cdc38a2 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -66,7 +66,7 @@ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs); /*! * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. * It requires: - * 1) The loops are in the same line. That means: the loops can be ordered to [l_1, l_2, ... , + * 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... , * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between * l_1 and l_n (which also indicates they are under the same scope). * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 828cd8bb2cb6..68358ec3fe69 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -321,18 +321,18 @@ class LoopMultiAppearanceError : public ScheduleError { For loop_; }; -class LoopsNotALineError : public ScheduleError { +class LoopsNotAChainError : public ScheduleError { public: enum class ProblemKind { kNotUnderAScope, kHaveNonSingleBranchStmt }; - explicit LoopsNotALineError(IRModule mod, Optional problematic_loop, ProblemKind kind) + explicit LoopsNotAChainError(IRModule mod, Optional problematic_loop, ProblemKind kind) : mod_(mod), problematic_loop_(std::move(problematic_loop)), kind_(kind) {} - String FastErrorString() const final { return "ScheduleError: the loops are not in a line"; } + String FastErrorString() const final { return "ScheduleError: the loops are not in a chain"; } String DetailRenderTemplate() const final { std::stringstream ss; - ss << "The loops are not in a line because"; + ss << "The loops are not in a chain because"; if (kind_ == ProblemKind::kNotUnderAScope) { ss << " they are not under the same scope."; } else { @@ -547,8 +547,8 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { // stop at blocknode if (sref->stmt->IsInstance()) { if (i != 0) { - throw LoopsNotALineError(self->mod, NullOpt, - LoopsNotALineError::ProblemKind::kNotUnderAScope); + throw LoopsNotAChainError(self->mod, NullOpt, + LoopsNotAChainError::ProblemKind::kNotUnderAScope); } else { break; } @@ -559,8 +559,8 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { if (successor[parent_sref] == sref) { break; } else { - throw LoopsNotALineError(self->mod, GetRef(parent_sref->stmt), - LoopsNotALineError::ProblemKind::kHaveNonSingleBranchStmt); + throw LoopsNotAChainError(self->mod, GetRef(parent_sref->stmt), + LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); } } else { successor[parent_sref] = sref; @@ -578,8 +578,8 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { loop_sref = successor[loop_sref]; const ForNode* inner_loop = TVM_SREF_TO_FOR(inner_loop, GetRef(loop_sref)); if (outer_loop->body.get() != inner_loop) { - throw LoopsNotALineError(self->mod, GetRef(outer_loop), - LoopsNotALineError::ProblemKind::kHaveNonSingleBranchStmt); + throw 
LoopsNotAChainError(self->mod, GetRef(outer_loop), + LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); } outer_loop = inner_loop; } From 78e5ca982ff885cad8b93d2f6a9e420b434900dc Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 21 Aug 2021 22:37:11 +0800 Subject: [PATCH 08/18] address comments --- include/tvm/tir/schedule/schedule.h | 2 +- python/tvm/tir/schedule/schedule.py | 2 +- src/tir/schedule/primitive.h | 2 +- .../schedule/primitive/loop_transformation.cc | 119 +++++++++--------- 4 files changed, 64 insertions(+), 61 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index de8df69362b4..5e223c98d74d 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -225,7 +225,7 @@ class ScheduleNode : public runtime::Object { * 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... , * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between * l_1 and l_n (which also indicates they are under the same scope). - * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops + * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. * 3) For every block under the loop nests, its block binding must be affine, and the block * variables must be either data parallel or reduction. * 4) No duplicated loops are allowed in the arguments. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b8b731341c5f..c9cbf45b9055 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -449,7 +449,7 @@ def reorder(self, *ordered_loops: List[LoopRV]) -> None: 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... , l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between l_1 and l_n (which also indicates they are under the same scope). - 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops + 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. 3) For every block under the loop nests, its block binding must be affine, and the block variables must be either data parallel or reduction. 4) No duplicated loops are allowed in the arguments. diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index c0c06cdc38a2..2cf59f0b27c0 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -69,7 +69,7 @@ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs); * 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... , * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between * l_1 and l_n (which also indicates they are under the same scope). - * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops + * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. * 3) For every block under the loop nests, its block binding must be affine, and the block * variables must be either data parallel or reduction. * 4) No duplicated loops are allowed in the arguments. 
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 68358ec3fe69..e4752dbeb5d9 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -512,94 +512,97 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { } void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { - std::unordered_set loop_srefs; - loop_srefs.reserve(ordered_loop_srefs.size()); if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { return; } - // Step 1. check uniqueness + std::unordered_set loop_srefs; + loop_srefs.reserve(ordered_loop_srefs.size()); + // Step 1. Check uniqueness. for (const StmtSRef& loop_sref : ordered_loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); - // uniqueness check auto inserted = loop_srefs.insert(loop_sref.get()); if (!inserted.second) { throw LoopMultiAppearanceError(self->mod, GetRef(loop)); } } - // Step 2. gather loops to be reordered - // For each loop, traverse upwards along the parent pointer, and stop on either a block, or a - // previously-visited loop - // - the top of the reorder range is the last loop visited in the first traverse which exists in + // Step 2. Gather loops to be reordered + // For each loop sref in the input sref array, traverse upwards along its parent pointer in the + // sref tree, and stop on either a block, or a previously-visited loop + // - the top of the reorder range is the last loop visited in the first traversal which exists in // the input array // - the bottom of the reorder range is the last loop in the input array which is not visited in - // the previous traverses + // the previous traversals const StmtSRefNode* top = nullptr; - const StmtSRefNode* bottom = nullptr; - // Maps a parent sref to its child sref - std::unordered_map successor; + const StmtSRefNode* bottom = ordered_loop_srefs[0].get(); + std::unordered_set visited; + bool scope_block_visited = false; for (size_t i = 0; i < ordered_loop_srefs.size(); i++) { - const StmtSRefNode* sref = ordered_loop_srefs[i].get(); - // if sref is not visited before, update `bottom` - if (!successor.count(sref->parent)) { - bottom = sref; + const StmtSRefNode* loop_sref = ordered_loop_srefs[i].get(); + if (visited.count(loop_sref)) { + continue; } - while (true) { - // stop at blocknode - if (sref->stmt->IsInstance()) { - if (i != 0) { + for (const StmtSRefNode* v = loop_sref;; v = v->parent) { + // Case 1. If `v` corresponds to a block, stop traversal. + if (v->stmt->IsInstance()) { + if (scope_block_visited) { throw LoopsNotAChainError(self->mod, NullOpt, LoopsNotAChainError::ProblemKind::kNotUnderAScope); - } else { - break; } + scope_block_visited = true; + break; } - const StmtSRefNode* parent_sref = sref->parent; - // stop at previously-visited loop - if (successor.count(parent_sref)) { - if (successor[parent_sref] == sref) { - break; - } else { - throw LoopsNotAChainError(self->mod, GetRef(parent_sref->stmt), + // Case 2. If `v` corresponds to a previously-visited loop, stop traversal and update + // `bottom`. + if (visited.count(v)) { + if (v == bottom) { + throw LoopsNotAChainError(self->mod, GetRef(v->stmt), LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); } - } else { - successor[parent_sref] = sref; + bottom = loop_sref; + break; } - // if it's the first traverse and the loop is in the input array, update `top` - if (loop_srefs.count(sref) && i == 0) { - top = sref; + // Case 3. 
Add `v` into `visited` + visited.insert(v); + // If it's the first traversal and the loop corresponding to `v` is in the input array, + // update `top`. + if (loop_srefs.count(v) && i == 0) { + top = v; } - sref = parent_sref; } } - // Step 3. Check that loops are single-branch - const ForNode* outer_loop = TVM_SREF_TO_FOR(outer_loop, GetRef(top)); - for (const StmtSRefNode* loop_sref = top; loop_sref != bottom;) { - loop_sref = successor[loop_sref]; - const ForNode* inner_loop = TVM_SREF_TO_FOR(inner_loop, GetRef(loop_sref)); - if (outer_loop->body.get() != inner_loop) { - throw LoopsNotAChainError(self->mod, GetRef(outer_loop), - LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); + // Step 3. Collect all loops in the chain and check the loops are single-branch + std::vector chain; + for (const StmtSRefNode* loop_sref = bottom;; loop_sref = loop_sref->parent) { + if (!chain.empty()) { + const ForNode* outer_loop = TVM_SREF_TO_FOR(outer_loop, GetRef(loop_sref)); + const ForNode* inner_loop = TVM_SREF_TO_FOR(inner_loop, GetRef(chain.back())); + if (outer_loop->body.get() != inner_loop) { + throw LoopsNotAChainError(self->mod, GetRef(outer_loop), + LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); + } + } + chain.push_back(loop_sref); + if (loop_sref == top) { + break; } - outer_loop = inner_loop; } - // Step 4. Check the block below has all its block_var to be data-parallel or reduction + // Step 4. Check the block below has all its block_var to be data-parallel or reduction, + // and the block has an affine binding. BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom); // Step 5. Replace the original loops with the reordered loops and check that outer loop is // not dependent on inner loop std::unordered_set inner_vars; - std::function f_reorder = - [&bottom, &loop_srefs, &successor, &ordered_loop_srefs, &inner_vars, &self, &f_reorder]( - const StmtSRefNode* loop, int index) -> Stmt { - const ForNode* copy = loop_srefs.count(loop) ? ordered_loop_srefs[index++]->StmtAs() - : loop->StmtAs(); + For new_loop; + int index = ordered_loop_srefs.size() - 1; + for (const StmtSRefNode* loop_sref : chain) { + const ForNode* copy = loop_srefs.count(loop_sref) + ? 
ordered_loop_srefs[index--]->StmtAs() + : loop_sref->StmtAs(); ObjectPtr n = make_object(*copy); - if (loop == bottom) { - // stop recursion at bottom loop - n->body = loop->StmtAs()->body; + if (new_loop.defined()) { + n->body = new_loop; } else { - // reorder recursively - n->body = f_reorder(successor.at(loop), index); + n->body = loop_sref->StmtAs()->body; } const VarNode* used_var; auto f_contain = [&inner_vars, &used_var](const VarNode* var) { @@ -613,9 +616,9 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { throw DependentLoopError(self->mod, GetRef(copy), used_var->name_hint); } inner_vars.insert(copy->loop_var.get()); - return Stmt(std::move(n)); - }; - self->Replace(GetRef(top), f_reorder(top, 0), {}); + new_loop = For(std::move(n)); + } + self->Replace(GetRef(top), new_loop, {}); } /******** Instruction Registration ********/ From 11a5fc26888fccc4625c6665f87b6bbb49579daf Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sun, 22 Aug 2021 14:29:16 +0800 Subject: [PATCH 09/18] address comments --- .../schedule/primitive/loop_transformation.cc | 46 +++++++++++-------- .../unittest/test_tir_schedule_reorder.py | 29 ++++++++++++ 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index e4752dbeb5d9..b84bb50304a4 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -519,9 +519,9 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { loop_srefs.reserve(ordered_loop_srefs.size()); // Step 1. Check uniqueness. for (const StmtSRef& loop_sref : ordered_loop_srefs) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); auto inserted = loop_srefs.insert(loop_sref.get()); if (!inserted.second) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); throw LoopMultiAppearanceError(self->mod, GetRef(loop)); } } @@ -554,7 +554,7 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { // Case 2. If `v` corresponds to a previously-visited loop, stop traversal and update // `bottom`. if (visited.count(v)) { - if (v == bottom) { + if (v != bottom) { throw LoopsNotAChainError(self->mod, GetRef(v->stmt), LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); } @@ -565,46 +565,52 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { visited.insert(v); // If it's the first traversal and the loop corresponding to `v` is in the input array, // update `top`. - if (loop_srefs.count(v) && i == 0) { + if (i == 0 && loop_srefs.count(v)) { top = v; } } } // Step 3. 
Collect all loops in the chain and check the loops are single-branch std::vector chain; - for (const StmtSRefNode* loop_sref = bottom;; loop_sref = loop_sref->parent) { - if (!chain.empty()) { - const ForNode* outer_loop = TVM_SREF_TO_FOR(outer_loop, GetRef(loop_sref)); - const ForNode* inner_loop = TVM_SREF_TO_FOR(inner_loop, GetRef(chain.back())); - if (outer_loop->body.get() != inner_loop) { - throw LoopsNotAChainError(self->mod, GetRef(outer_loop), - LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); - } + chain.reserve(visited.size()); + for (const StmtSRefNode* loop_sref = bottom; loop_sref != top;) { + const StmtSRefNode* parent_loop_sref = loop_sref->parent; + const ForNode* outer = parent_loop_sref->StmtAs(); + const ForNode* inner = loop_sref->StmtAs(); + ICHECK(outer != nullptr && inner != nullptr); + if (outer->body.get() != inner) { + throw LoopsNotAChainError(self->mod, GetRef(outer), + LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); } chain.push_back(loop_sref); - if (loop_sref == top) { - break; - } + loop_sref = parent_loop_sref; } + chain.push_back(top); // Step 4. Check the block below has all its block_var to be data-parallel or reduction, // and the block has an affine binding. BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom); // Step 5. Replace the original loops with the reordered loops and check that outer loop is // not dependent on inner loop std::unordered_set inner_vars; - For new_loop; - int index = ordered_loop_srefs.size() - 1; + inner_vars.reserve(chain.size()); + For new_loop{nullptr}; + int index = static_cast(ordered_loop_srefs.size()) - 1; for (const StmtSRefNode* loop_sref : chain) { - const ForNode* copy = loop_srefs.count(loop_sref) - ? ordered_loop_srefs[index--]->StmtAs() - : loop_sref->StmtAs(); + const ForNode* copy = nullptr; + if (loop_srefs.count(loop_sref)) { + copy = ordered_loop_srefs[index]->StmtAs(); + --index; + } else { + copy = loop_sref->StmtAs(); + } + ICHECK(copy != nullptr); ObjectPtr n = make_object(*copy); if (new_loop.defined()) { n->body = new_loop; } else { n->body = loop_sref->StmtAs()->body; } - const VarNode* used_var; + const VarNode* used_var = nullptr; auto f_contain = [&inner_vars, &used_var](const VarNode* var) { if (inner_vars.count(var)) { used_var = var; diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index f463ab111d16..091a77df2030 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -130,6 +130,19 @@ def elementwise_reordered(a: ty.handle, b: ty.handle) -> None: B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 +@tvm.script.tir +def elementwise_reordered2(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128, 128)) + for k, j, i, l in tir.grid(128, 128, 128, 128): + with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.bind(vl, l) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + @tvm.script.tir def elementwise_reordered_with_predicate(a: ty.handle, b: ty.handle) -> None: A = tir.match_buffer(a, (128, 128, 128, 128)) @@ -190,6 +203,15 @@ def test_reorder(): verify_trace_roundtrip(sch=sch, mod=elementwise) +def test_reorder2(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + sch.reorder(k, i, l) + 
tvm.ir.assert_structural_equal(elementwise_reordered2, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + def test_reorder_with_opaque_access(): sch = tir.Schedule(opaque_access, debug_mask="all") block_a = sch.get_block("A") @@ -225,6 +247,13 @@ def test_reorder_fail_with_non_single_branch_loop(): i, j, k = sch.get_loops(block_b) with pytest.raises(tvm.tir.ScheduleError): sch.reorder(k, i) + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + block_c = sch.get_block("C") + i, j, k1 = sch.get_loops(block_b) + _, _, k2 = sch.get_loops(block_c) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(k1, i, k2) def test_reorder_fail_with_loops_not_under_same_scope(): From 98e951a0c96077c6c18b53cb4eaefed371aa7c6b Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sun, 22 Aug 2021 22:51:49 +0800 Subject: [PATCH 10/18] address comments --- .../schedule/primitive/loop_transformation.cc | 108 ++++++++++++++---- 1 file changed, 84 insertions(+), 24 deletions(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index b84bb50304a4..1046af86c537 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -510,14 +510,17 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse); return self->stmt2ref.at(new_stmt.get()); } - -void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { - if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { - return; - } +/*! + * \brief collect an array of loop srefs into a set + * \param self The schedule state + * \param ordered_loop_srefs The array of loop srefs + * \return A set containing all loops in the array + * \throws ScheduleError If there are duplicate loops in the array + */ +std::unordered_set CollectLoopsIntoSet( + const ScheduleState& self, const Array& ordered_loop_srefs) { std::unordered_set loop_srefs; loop_srefs.reserve(ordered_loop_srefs.size()); - // Step 1. Check uniqueness. for (const StmtSRef& loop_sref : ordered_loop_srefs) { auto inserted = loop_srefs.insert(loop_sref.get()); if (!inserted.second) { @@ -525,19 +528,24 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { throw LoopMultiAppearanceError(self->mod, GetRef(loop)); } } - // Step 2. Gather loops to be reordered - // For each loop sref in the input sref array, traverse upwards along its parent pointer in the - // sref tree, and stop on either a block, or a previously-visited loop - // - the top of the reorder range is the last loop visited in the first traversal which exists in - // the input array - // - the bottom of the reorder range is the last loop in the input array which is not visited in - // the previous traversals + return loop_srefs; +} + +/*! 
+ * \brief Get the top and bottom boundary of reorder range (which should be a chain) + * \param self The schedule state + * \param loop_srefs The set containing the srefs to the loops to be reordered + * \return a pair containing the top and bottom boundary of the reorder range + * \throws ScheduleError If the loops to be reordered is not in a chain + */ +std::pair GetBoundaryOfReorderRange( + const ScheduleState& self, const std::unordered_set& loop_srefs) { const StmtSRefNode* top = nullptr; - const StmtSRefNode* bottom = ordered_loop_srefs[0].get(); + const StmtSRefNode* bottom = *loop_srefs.begin(); std::unordered_set visited; bool scope_block_visited = false; - for (size_t i = 0; i < ordered_loop_srefs.size(); i++) { - const StmtSRefNode* loop_sref = ordered_loop_srefs[i].get(); + bool first_traversal = true; + for (const StmtSRefNode* loop_sref : loop_srefs) { if (visited.count(loop_sref)) { continue; } @@ -565,14 +573,27 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { visited.insert(v); // If it's the first traversal and the loop corresponding to `v` is in the input array, // update `top`. - if (i == 0 && loop_srefs.count(v)) { + if (first_traversal && loop_srefs.count(v)) { top = v; } } + first_traversal = false; } - // Step 3. Collect all loops in the chain and check the loops are single-branch + return std::make_pair(top, bottom); +} + +/*! + * \brief get all the loops in the reorder range + * \param self The schedule state + * \param top The top boundary of the reorder range + * \param bottom The bottom boundary of the reorder range + * \return an array containing all the loops in the reorder range + * \throws ScheduleError If some loop in the reorder range is not single-branch + */ +std::vector GetLoopsInReorderRange(const ScheduleState& self, + const StmtSRefNode* top, + const StmtSRefNode* bottom) { std::vector chain; - chain.reserve(visited.size()); for (const StmtSRefNode* loop_sref = bottom; loop_sref != top;) { const StmtSRefNode* parent_loop_sref = loop_sref->parent; const ForNode* outer = parent_loop_sref->StmtAs(); @@ -586,11 +607,22 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { loop_sref = parent_loop_sref; } chain.push_back(top); - // Step 4. Check the block below has all its block_var to be data-parallel or reduction, - // and the block has an affine binding. - BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom); - // Step 5. Replace the original loops with the reordered loops and check that outer loop is - // not dependent on inner loop + return chain; +} + +/*! 
+ * \brief Construct a loop chain in the new order + * \param self The schedule state + * \param chain The loops in the reorder range + * \param ordered_loop_srefs The loop srefs to be reordered + * \param loop_srefs The set containing loop srefs to be reordered + * \return the new loop chain + * \throws ScheduleError If the domain of an outer loop depends on any of the inner loops after + * reordering + */ +For ConstructNewLoopChain(const ScheduleState& self, std::vector chain, + const Array& ordered_loop_srefs, + const std::unordered_set& loop_srefs) { std::unordered_set inner_vars; inner_vars.reserve(chain.size()); For new_loop{nullptr}; @@ -624,6 +656,34 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { inner_vars.insert(copy->loop_var.get()); new_loop = For(std::move(n)); } + return new_loop; +} + +void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { + if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + return; + } + // Step 1. Check uniqueness.and collect the input loop srefs into a set + std::unordered_set loop_srefs = + CollectLoopsIntoSet(self, ordered_loop_srefs); + // Step 2. Gather loops to be reordered + // For each loop sref in the input sref array, traverse upwards along its parent pointer in the + // sref tree, and stop on either a block, or a previously-visited loop + // - the top of the reorder range is the last loop visited in the first traversal which exists in + // the input array + // - the bottom of the reorder range is the last loop in the input array which is not visited in + // the previous traversals + auto pair = GetBoundaryOfReorderRange(self, loop_srefs); + const StmtSRefNode* top = pair.first; + const StmtSRefNode* bottom = pair.second; + // Step 3. Collect all loops in the chain and check the loops are single-branch + std::vector chain = GetLoopsInReorderRange(self, top, bottom); + // Step 4. Check the block below has all its block_var to be data-parallel or reduction, + // and the block has an affine binding. + BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom); + // Step 5. Replace the original loops with the reordered loops and check that outer loop is + // not dependent on inner loop + For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs); self->Replace(GetRef(top), new_loop, {}); } From 15c99db169d63352f080fc594ef5b970a30a10af Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sun, 22 Aug 2021 22:52:31 +0800 Subject: [PATCH 11/18] address comments --- src/tir/schedule/primitive/loop_transformation.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 1046af86c537..3051fd68431b 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -660,7 +660,7 @@ For ConstructNewLoopChain(const ScheduleState& self, std::vector& ordered_loop_srefs) { - if (ordered_loop_srefs.empty() || ordered_loop_srefs.size() == 1) { + if (ordered_loop_srefs.size() <= 1) { return; } // Step 1. 
Check uniqueness.and collect the input loop srefs into a set From 3d11d604ef03937063ed50a0f510da1f02f0615c Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 22 Aug 2021 12:22:57 -0700 Subject: [PATCH 12/18] Update src/tir/schedule/primitive/loop_transformation.cc Co-authored-by: Ruihang Lai --- src/tir/schedule/primitive/loop_transformation.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 3051fd68431b..55d597f54709 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -511,7 +511,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { return self->stmt2ref.at(new_stmt.get()); } /*! - * \brief collect an array of loop srefs into a set + * \brief Collect an array of loop srefs into a set * \param self The schedule state * \param ordered_loop_srefs The array of loop srefs * \return A set containing all loops in the array From 9c69dc3968f26b0d58815ca933334d1816bfcd44 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 22 Aug 2021 12:23:05 -0700 Subject: [PATCH 13/18] Update src/tir/schedule/primitive/loop_transformation.cc Co-authored-by: Ruihang Lai --- src/tir/schedule/primitive/loop_transformation.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 55d597f54709..b99a45ad7425 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -663,7 +663,7 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { if (ordered_loop_srefs.size() <= 1) { return; } - // Step 1. Check uniqueness.and collect the input loop srefs into a set + // Step 1. Check uniqueness and collect the input loop srefs into a set std::unordered_set loop_srefs = CollectLoopsIntoSet(self, ordered_loop_srefs); // Step 2. 
Gather loops to be reordered From 59a8afa7ff32ec81f6ebb3e5eb7d0f46fce2fc99 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 22 Aug 2021 12:23:14 -0700 Subject: [PATCH 14/18] Update src/tir/schedule/primitive/loop_transformation.cc Co-authored-by: Ruihang Lai --- src/tir/schedule/primitive/loop_transformation.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index b99a45ad7425..b79e09e4a6eb 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -616,7 +616,7 @@ std::vector GetLoopsInReorderRange(const ScheduleState& sel * \param chain The loops in the reorder range * \param ordered_loop_srefs The loop srefs to be reordered * \param loop_srefs The set containing loop srefs to be reordered - * \return the new loop chain + * \return The new loop chain * \throws ScheduleError If the domain of an outer loop depends on any of the inner loops after * reordering */ From c6727f2987d9745dbbf94de6b1a9531ab9338a19 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 22 Aug 2021 12:23:21 -0700 Subject: [PATCH 15/18] Update src/tir/schedule/primitive/loop_transformation.cc Co-authored-by: Ruihang Lai --- src/tir/schedule/primitive/loop_transformation.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index b79e09e4a6eb..a499aacedc64 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -583,7 +583,7 @@ std::pair GetBoundaryOfReorderRange( } /*! - * \brief get all the loops in the reorder range + * \brief Get all the loops in the reorder range * \param self The schedule state * \param top The top boundary of the reorder range * \param bottom The bottom boundary of the reorder range From ee0f584c9f71ce98a3240e137860d061163209bc Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 22 Aug 2021 12:23:27 -0700 Subject: [PATCH 16/18] Update src/tir/schedule/primitive/loop_transformation.cc Co-authored-by: Ruihang Lai --- src/tir/schedule/primitive/loop_transformation.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index a499aacedc64..e9324c548d3f 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -535,7 +535,7 @@ std::unordered_set CollectLoopsIntoSet( * \brief Get the top and bottom boundary of reorder range (which should be a chain) * \param self The schedule state * \param loop_srefs The set containing the srefs to the loops to be reordered - * \return a pair containing the top and bottom boundary of the reorder range + * \return A pair containing the top and bottom boundary of the reorder range * \throws ScheduleError If the loops to be reordered is not in a chain */ std::pair GetBoundaryOfReorderRange( From 8a6a151bd789ce3dbf3917818332363078694adc Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 22 Aug 2021 12:23:36 -0700 Subject: [PATCH 17/18] Update src/tir/schedule/primitive/loop_transformation.cc Co-authored-by: Ruihang Lai --- src/tir/schedule/primitive/loop_transformation.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc 
b/src/tir/schedule/primitive/loop_transformation.cc index e9324c548d3f..2c6699d45c8d 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -587,7 +587,7 @@ std::pair GetBoundaryOfReorderRange( * \param self The schedule state * \param top The top boundary of the reorder range * \param bottom The bottom boundary of the reorder range - * \return an array containing all the loops in the reorder range + * \return An array containing all the loops in the reorder range * \throws ScheduleError If some loop in the reorder range is not single-branch */ std::vector GetLoopsInReorderRange(const ScheduleState& self, From c6d3c0da27cf7d0ce7f3459284048d8820045b55 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 22 Aug 2021 12:24:01 -0700 Subject: [PATCH 18/18] Update src/tir/schedule/primitive/loop_transformation.cc Co-authored-by: Ruihang Lai --- src/tir/schedule/primitive/loop_transformation.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 2c6699d45c8d..7c2b61344427 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -673,9 +673,9 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { // the input array // - the bottom of the reorder range is the last loop in the input array which is not visited in // the previous traversals - auto pair = GetBoundaryOfReorderRange(self, loop_srefs); - const StmtSRefNode* top = pair.first; - const StmtSRefNode* bottom = pair.second; + const StmtSRefNode* top = nullptr; + const StmtSRefNode* bottom = nullptr; + std::tie(top, bottom) = GetBoundaryOfReorderRange(self, loop_srefs); // Step 3. Collect all loops in the chain and check the loops are single-branch std::vector chain = GetLoopsInReorderRange(self, top, bottom); // Step 4. Check the block below has all its block_var to be data-parallel or reduction,
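
To make the non-consecutive case concrete, the following minimal sketch (not part of any patch in this series) strings together the schedule calls exercised by `test_reorder2` above. The `elementwise4d` fixture is hypothetical and written only for illustration; it mirrors the 4-D elementwise workloads used in the unit tests. Because `j` is omitted from the argument list it keeps its position in the nest, so `reorder(k, i, l)` is expected to produce the `(k, j, i, l)` order asserted against `elementwise_reordered2`.

.. code-block:: python

    import tvm
    from tvm import tir
    from tvm.script import ty


    # Hypothetical fixture for illustration only; mirrors the 4-D elementwise
    # workloads used in tests/python/unittest/test_tir_schedule_reorder.py.
    @tvm.script.tir
    def elementwise4d(a: ty.handle, b: ty.handle) -> None:
        A = tir.match_buffer(a, (128, 128, 128, 128))
        B = tir.match_buffer(b, (128, 128, 128, 128))
        for i, j, k, l in tir.grid(128, 128, 128, 128):
            with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.bind(vl, l)
                B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0


    sch = tir.Schedule(elementwise4d, debug_mask="all")
    i, j, k, l = sch.get_loops(sch.get_block("B"))
    # Non-consecutive reorder: only i, k, l are rearranged; j is untouched,
    # so the resulting loop nest is (k, j, i, l).
    sch.reorder(k, i, l)
    print(tvm.script.asscript(sch.mod["main"]))

This is also why the primitive only requires the loops to lie on one chain rather than to be consecutive: loops left out of the argument list simply keep their original positions while the listed ones are permuted among the slots they occupied.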