diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 13e1e2510e29..45ef1ae93461 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -340,6 +340,14 @@ TVM_DLL Pass NarrowDataType(int target_bits); */ TVM_DLL Pass PointerValueTypeRewrite(); +/*! + * \brief Hoist loop-invariant IfThenElse nodes to + * outside the corresponding loops. + * + * \return The pass. + */ +TVM_DLL Pass HoistIfThenElse(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 216cad992d98..345c472e774a 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -187,6 +187,8 @@ def lower(sch, pass_list += [ tvm.tir.transform.VectorizeLoop(not cfg.disable_vectorize), tvm.tir.transform.InjectVirtualThread(), + tvm.tir.transform.HoistIfThenElse(), # After InjectVirtualThread + # to protect vthread loops tvm.tir.transform.InjectDoubleBuffer(cfg.double_buffer_split_loop), tvm.tir.transform.StorageRewrite(), tvm.tir.transform.UnrollLoop( diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 6d797f8772ec..2efe714a661e 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -478,3 +478,13 @@ def VerifyMemory(): The result pass """ return _ffi_api.VerifyMemory() + +def HoistIfThenElse(): + """Hoist loop-invariant IfThenElse nodes to outside the corresponding loops. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.HoistIfThenElse() diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc similarity index 79% rename from src/tir/pass/hoist_if_then_else.cc rename to src/tir/transforms/hoist_if_then_else.cc index 67a88f5d922e..020b7be532de 100644 --- a/src/tir/pass/hoist_if_then_else.cc +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -162,6 +163,83 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array{"For"}); } +template +static bool no_intersect(const std::vector &vec, const std::unordered_set &set) { + for (auto &&item : vec) { + if (set.count(item)) + return false; + } + return true; +} + +// Rename all the Var defined in the else case, to meet the SSA requirement +class Renamer : public StmtExprMutator { + public: + explicit Renamer(const std::string &suffix) + : suffix_(suffix) {} + + Stmt Rename(Stmt stmt) { + stmt = operator()(std::move(stmt)); + return Substitute(std::move(stmt), var_map_); + } + + protected: + Stmt VisitStmt_(const ForNode *op) override { + depth_++; + auto ret = StmtExprMutator::VisitStmt_(op); + depth_--; + if (depth_ >= 1) { + return ret; + } + op = ret.as(); + Var new_var(op->loop_var->name_hint + suffix_); + var_map_.Set(op->loop_var, new_var); + return ForNode::make(new_var, op->min, op->extent, op->for_type, + op->device_api, op->body); + } + + Stmt VisitStmt_(const AllocateNode *op) override { + auto ret = StmtExprMutator::VisitStmt_(op); + if (depth_ >= 1) { + return ret; + } + op = ret.as(); + Var new_var(op->buffer_var->name_hint + suffix_); + var_map_.Set(op->buffer_var, new_var); + return AllocateNode::make(new_var, op->dtype, op->extents, + op->condition, op->body); + } + + Stmt VisitStmt_(const LetStmtNode *op) override { + auto ret = StmtExprMutator::VisitStmt_(op); + if (depth_ >= 1) { + return ret; + } + op = ret.as(); + Var new_var(op->var->name_hint + suffix_); + var_map_.Set(op->var, new_var); + return LetStmtNode::make(new_var, op->value, op->body); + } + + PrimExpr VisitExpr_(const LetNode *op) override { + auto ret = StmtExprMutator::VisitExpr_(op); + if (depth_ >= 1) { + return ret; + } + op = ret.as(); + Var new_var(op->var->name_hint + suffix_); + var_map_.Set(op->var, new_var); + return LetNode::make(new_var, op->value, op->body); + } + + private: + int depth_ = 0; // how may For nodes we are in + // we only rename the out-most loop, because + // Rename is called iteratively + const std::string &suffix_; // name suffix + Map var_map_; // old var -> new var +}; + // Remove IfThenElse node from a For node. // A pair of For nodes will be generated. std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { @@ -186,6 +264,7 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array{"IfThenElse"}); if (if_stmt.as()->else_case.defined()) { else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array{"IfThenElse"}); + else_for = Renamer(".else").Rename(std::move(else_for)); } return std::make_pair(then_for, else_for); @@ -198,6 +277,8 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { if (!for_node) return; std::queue tracker; + std::vector var_def; // don't hoist thread indices out of their + // definition region tracker.push(for_node->body); Stmt for_stmt = Downcast(node); for2if_map_.insert({for_stmt.get(), std::vector()}); @@ -206,19 +287,17 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { tracker.pop(); if (head->IsInstance()) { for (const auto& if_stmt : for2if_map_.at(head.get())) { - for2if_map_[for_stmt.get()].push_back(if_stmt); + if (no_intersect(var_def, cond_var_map_[if_stmt.get()])) { + for2if_map_[for_stmt.get()].push_back(if_stmt); + } } - } else if (head->IsInstance()) { - const AttrStmtNode* attr_node = head.as(); - tracker.push(attr_node->body); - } else if (head->IsInstance()) { - for2if_map_[for_stmt.get()].push_back(head); - const IfThenElseNode* if_node = head.as(); - tracker.push(if_node->then_case); - if (if_node->else_case.defined()) { - tracker.push(if_node->else_case); + } else if (auto attr_node = head.as()) { + if (attr_node->attr_key == attr::thread_extent) { + IterVar iv = Downcast(attr_node->node); + var_def.push_back(iv->var.get()); } - + tracker.push(attr_node->body); + } else if (auto if_node = head.as()) { // Record condition variables. if (!cond_var_map_.count(head.get())) { std::unordered_set new_var_set; @@ -229,6 +308,14 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { } }); } + + if (no_intersect(var_def, cond_var_map_[head.get()])) { + for2if_map_[for_stmt.get()].push_back(head); + } + tracker.push(if_node->then_case); + if (if_node->else_case.defined()) { + tracker.push(if_node->else_case); + } } else { continue; } @@ -292,7 +379,8 @@ void IfThenElseHoist::LocateTopFor() { } else { std::vector actual_if_list; for (const Stmt& if_stmt : if_list) { - if (if_position_map.count(if_stmt.get())) { + if (if_position_map.count(if_stmt.get()) && + if_position_map.at(if_stmt.get()).as()->loop_var.get() == top_for_var) { actual_if_list.push_back(if_stmt); } } @@ -399,6 +487,21 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); } +namespace transform { + +Pass HoistIfThenElse() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = HoistIfThenElse(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElse", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse); + +} // namespace transform + TVM_REGISTER_GLOBAL("testing.HoistIfThenElse").set_body_typed(HoistIfThenElse); } // namespace tir diff --git a/tests/python/unittest/test_te_build_lower.py b/tests/python/unittest/test_te_build_lower.py index b1d754605a46..3880e01400bd 100644 --- a/tests/python/unittest/test_te_build_lower.py +++ b/tests/python/unittest/test_te_build_lower.py @@ -17,6 +17,11 @@ import tvm from tvm import te +def collect_visit(stmt, f): + ret = [] + tvm.tir.stmt_functor.post_order_visit(stmt, lambda x : ret.append(f(x))) + return ret + def test_lower_rfactor(): n = te.size_var("n") m = te.size_var("m") @@ -49,7 +54,7 @@ def test_split_uneven_unique_likely(): sch = te.create_schedule(c.op) xo, xi = sch[c].split(x, 5) stmt = tvm.lower(sch, [a, b, c])["main"].body - assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse) + assert(any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_pass_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py similarity index 90% rename from tests/python/unittest/test_tir_pass_hoist_if.py rename to tests/python/unittest/test_tir_transform_hoist_if.py index 346239d302cf..526fba0fd530 100644 --- a/tests/python/unittest/test_tir_pass_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -69,7 +69,8 @@ def test_basic(): stmt = ib.get() new_stmt = tvm.testing.HoistIfThenElse(stmt) expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), - ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')), + ('For', 'k.else'): (None,), ('For', 'j.else'): (('For', 'k.else'),), + ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j.else')), ('For', 'i'): (('IfThenElse', ('i',)),)} verify_structure(new_stmt, expected_struct) @@ -114,7 +115,8 @@ def test_attr_stmt(): stmt = ib.get() new_stmt = tvm.testing.HoistIfThenElse(stmt) - expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')), + expected_struct = {('For', 'k'): (None,), ('For', 'k.else'): (None,), + ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k.else')), ('For', 'j'): (('IfThenElse', ('i', 'j')),), ('For', 'i'): (('For', 'j'),), ('AttrStmt', 'thread_extent', 64): (('For', 'i'),), ('AttrStmt', 'thread_extent', 32): (('AttrStmt', 'thread_extent', 64),)} @@ -177,6 +179,20 @@ def test_if_block(): ('IfThenElse', ('n',)): (('For', 'j'), None)} verify_structure(new_stmt, expected_struct) +def test_multi_if(): + ib = tvm.tir.ir_builder.create() + data = ib.pointer("float32", name="data") + + with ib.for_range(0, 10, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.for_range(0, 10, "k") as k: + with ib.if_scope(i >= 3): + with ib.if_scope(j >= 3): + data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5 + + stmt = ib.get() + new_stmt = tvm.testing.HoistIfThenElse(stmt) + if __name__ == "__main__": test_basic() @@ -184,3 +200,4 @@ def test_if_block(): test_attr_stmt() test_nested_for() test_if_block() + test_multi_if()