From 01eee1951e596b1fd73989e0f6581e3e73bef5e0 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Fri, 30 Aug 2019 14:06:07 -0700 Subject: [PATCH 1/8] Add LiftIfThenElse pass --- include/tvm/ir_pass.h | 7 + src/api/api_pass.cc | 1 + src/pass/lift_if_then_else.cc | 295 +++++++++++++++++++++ tests/python/unittest/test_pass_lift_if.py | 147 ++++++++++ 4 files changed, 450 insertions(+) create mode 100644 src/pass/lift_if_then_else.cc create mode 100644 tests/python/unittest/test_pass_lift_if.py diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 5ac71fdce47b..64e460689f50 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -377,6 +377,13 @@ Stmt LowerStorageAccessInfo(Stmt stmt); */ Stmt DecorateDeviceScope(Stmt stmt); +/*! + * \brief Loop invariant code motion which locates and lifts if statements. + * \param stmt The stmt to do if statement iifting. + * \return Transformed stmt. + */ +Stmt LiftIfThenElse(Stmt stmt); + /*! * \brief Make an user callable API LoweredFunc. * diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 25cd5838385f..08f0afca317a 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -160,5 +160,6 @@ REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); +REGISTER_PASS(LiftIfThenElse); } // namespace ir } // namespace tvm diff --git a/src/pass/lift_if_then_else.cc b/src/pass/lift_if_then_else.cc new file mode 100644 index 000000000000..3ce1861b891f --- /dev/null +++ b/src/pass/lift_if_then_else.cc @@ -0,0 +1,295 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../arithmetic/int_set.h" +#include "../runtime/thread_storage_scope.h" + +namespace tvm { +namespace ir { + +using LifterMap = std::unordered_map>; +using VarMap = std::unordered_map>; + +class IfThenElseLifter : public IRMutator { + public: + Stmt VisitAndMutate(const Stmt& stmt) { + GenerateInternalData(stmt); + return PostOrderMutate(stmt); + } + + Stmt PostOrderMutate(const Stmt& stmt) { + PackedFunc replace_top_for = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& current_for = args[0]; + if (current_for.as()) { + const For* for_node = current_for.as(); + if (top_for_map_.count(for_node->loop_var.get())) { + std::vector new_if_list; + for (const Stmt& if_stmt : top_for_map_[for_node->loop_var.get()]) { + new_if_list.emplace_back(LiftIf(if_stmt)); + } + + const IfThenElse* next_if_node; + const IfThenElse* current_if_node = new_if_list.back().as(); + Stmt new_for = Stmt(); + for (size_t i = new_if_list.size() - 1; i > 0; --i) { + const Stmt current_if_stmt = IfThenElse::make(current_if_node->condition, + current_if_node->then_case, + current_if_node->else_case); + next_if_node = new_if_list[i - 1].as(); + new_for = IfThenElse::make(next_if_node->condition, current_if_stmt, next_if_node->else_case); + current_if_node = new_for.as(); + } + + if (!new_for.get()) { + const IfThenElse* first_if_node = new_if_list[0].as(); + new_for = IfThenElse::make(first_if_node->condition, + first_if_node->then_case, + first_if_node->else_case); + } + *ret = new_for; + } + } + }); + return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")}); + } + + private: + void GenerateInternalData(const Stmt& stmt); + size_t GetActualFor(const Stmt& for_stmt, const Stmt& if_stmt); + Stmt LiftIf(const Stmt& if_stmt); + + LifterMap if2for_map_; + LifterMap top_for_map_; + LifterMap for_tracking_map_; + LifterMap for2if_map_; + std::vector ordered_for_list_; + VarMap cond_var_map_; + +}; + +// Check whether a given IfThenElse stmt is the first one appearing in a For stmt. +bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { + std::vector if_hash_list; + + PostOrderVisit(for_stmt.as()->body, [&](const NodeRef& node) { + if (node.as()) { + if_hash_list.push_back(node.hash()); + } + }); + return if_hash_list.empty() ? false : if_stmt.hash() == if_hash_list.back(); +} + +// Update upper level for loop when current for loop is modified. +Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { + std::vector for_hash_list; + + PostOrderVisit(parent_for_stmt.as()->body, [&](const NodeRef& node) { + if (node.as()) { + for_hash_list.push_back(node.hash()); + } + }); + + PackedFunc replace_target_for = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& current_for = args[0]; + if (current_for.hash() == for_hash_list.back()) { + *ret = new_if_stmt; + } + }); + + return IRTransform(parent_for_stmt, nullptr, replace_target_for, {Expr("For")}); +} + +// Remove If statement from a for statement +std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { + const For* for_node = for_stmt.as(); + const Stmt make_for = For::make(for_node->loop_var, for_node->min, for_node->extent, + for_node->for_type, for_node->device_api, for_node->body); + + Stmt then_for; + Stmt else_for; + + PackedFunc replace_then_case = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->then_case; + } + }); + + PackedFunc replace_else_case = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->else_case; + } + }); + + then_for = IRTransform(make_for, nullptr, replace_then_case, {Expr("IfThenElse")}); + if (if_stmt.as()->else_case) { + else_for = IRTransform(make_for, nullptr, replace_else_case, {Expr("IfThenElse")}); + } + + return std::make_pair(then_for, else_for); +} + +void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { + std::unordered_map if_position_map; + std::unordered_set top_for_var_set; + + PostOrderVisit(stmt, [&](const NodeRef& node){ + const For* for_node = node.as(); + if (for_node) { + std::queue tracker; + tracker.push(for_node->body); + Stmt for_stmt = Downcast(node); + for2if_map_.insert({for_stmt.get(), std::vector()}); + while(!tracker.empty()) { + Stmt head = tracker.front(); + tracker.pop(); + if (head->is_type()) { + for (const auto& if_stmt : for2if_map_.at(head.get())) { + for2if_map_[for_stmt.get()].push_back(if_stmt); + } + } else if (head->is_type()) { + const AttrStmt* attr_node = head.as(); + tracker.push(attr_node->body); + } else if (head->is_type()) { + for2if_map_[for_stmt.get()].push_back(head); + const IfThenElse* if_node = head.as(); + tracker.push(if_node->then_case); + if (if_node->else_case) { + tracker.push(if_node->else_case); + } + + if (!cond_var_map_.count(head.get())) { + std::unordered_set new_var_set; + cond_var_map_.insert({head.get(), new_var_set}); + PostOrderVisit(if_node->condition, [&](const NodeRef& var) { + if (var.as()) { + cond_var_map_[head.get()].insert(var.get()); + } + }); + } + } else { + continue; + } + } + ordered_for_list_.emplace_back(Downcast(node)); + } + }); + + + for (const Stmt& for_stmt : ordered_for_list_) { + std::vector if_list = for2if_map_[for_stmt.get()]; + top_for_map_.insert({for_stmt.as()->loop_var.get(), if_list}); + for (const Stmt& if_stmt : if_list) { + if (!if2for_map_.count(if_stmt.get())) { + std::vector new_for_list; + if2for_map_.insert({if_stmt.get(), new_for_list}); + } + if2for_map_[if_stmt.get()].push_back(for_stmt); + } + } + + for (const auto& item : if2for_map_) { + Stmt top_for; + const Node* if_stmt = item.first; + std::vector for_list = item.second; + for (size_t i = 0; i < for_list.size(); ++i) { + const Stmt& for_stmt = for_list.at(i); + std::vector new_for_list{for_stmt}; + for_tracking_map_.insert({for_stmt.get(), new_for_list}); + if (cond_var_map_[if_stmt] + .count(for_stmt.as()->loop_var.get())) { + std::vector updated_for_list(for_list.begin(), for_list.begin() + i); + if2for_map_[if_stmt] = updated_for_list; + break; + } else { + top_for = for_stmt; + } + } + if (top_for.as()) { + if_position_map.insert({if_stmt, top_for}); + } + } + + for ( const auto& item : if_position_map) { + top_for_var_set.insert(item.second.as()->loop_var.get()); + } + + std::vector removed_for_var_list; + for (const auto& item : top_for_map_) { + const Node* top_for_var = item.first; + std::vector if_list = item.second; + if (!top_for_var_set.count(top_for_var)) { + removed_for_var_list.push_back(top_for_var); + } else { + std::vector actual_if_list; + for (const Stmt& if_stmt : if_list) { + if (if_position_map.count(if_stmt.get())) { + actual_if_list.push_back(if_stmt); + } + } + top_for_map_[top_for_var] = actual_if_list; + } + } + for (const Node* top_for_var : removed_for_var_list) { + top_for_map_.erase(top_for_var); + } +} + +size_t IfThenElseLifter::GetActualFor(const Stmt& for_stmt, const Stmt& if_stmt) { + std::vector tracked_for_list = for_tracking_map_[for_stmt.get()]; + size_t actual_idx = 0; + for (size_t i = 0; i < tracked_for_list.size(); ++i) { + const Stmt& current_for = tracked_for_list.at(tracked_for_list.size() - 1 - i); + if (is_first_if(current_for, if_stmt)) { + actual_idx = tracked_for_list.size() - 1 - i; + break; + } + } + return actual_idx; +} + +Stmt IfThenElseLifter::LiftIf(const Stmt& if_stmt) { + Stmt new_if = if_stmt; + + for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) { + const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i); + size_t actual_for_idx = GetActualFor(for_stmt, new_if); + const Stmt& actual_for_node = for_tracking_map_[for_stmt.get()].at(actual_for_idx); + auto generated_for_pair = RemoveIf(actual_for_node, new_if); + const Stmt& then_for = generated_for_pair.first; + const Stmt& else_for = generated_for_pair.second;; + for_tracking_map_[for_stmt.get()].at(actual_for_idx) = then_for; + + if (else_for.get()) { + for_tracking_map_[for_stmt.get()].push_back(else_for); + } + new_if = IfThenElse::make(new_if.as()->condition, then_for, else_for); + if (i < if2for_map_[if_stmt.get()].size() - 1) { + const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1); + const Stmt& actual_next_for = for_tracking_map_[original_next_for.get()].at(actual_for_idx); + Stmt update_for_stmt = update_for(actual_next_for, new_if); + + for_tracking_map_[original_next_for.get()].at(actual_for_idx) = update_for_stmt; + } + } + return new_if; +} + + +Stmt LiftIfThenElse(Stmt stmt) { + return IfThenElseLifter().VisitAndMutate(stmt); +} + +} // namespace ir +} // namespace tvm \ No newline at end of file diff --git a/tests/python/unittest/test_pass_lift_if.py b/tests/python/unittest/test_pass_lift_if.py new file mode 100644 index 000000000000..e1e56f3d5927 --- /dev/null +++ b/tests/python/unittest/test_pass_lift_if.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm + + +var_list = [] + +def verify_structure(stmt, expected_struct): + node_dict = {} + struct = {} + def _extract_vars(op): + global var_list + if isinstance(op, tvm.expr.Var): + var_list.append(op.name) + + def _visit(op): + key = op + if isinstance(op, tvm.stmt.IfThenElse): + global var_list + tvm.ir_pass.PostOrderVisit(op.condition, _extract_vars) + val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))] + var_list.clear() + elif isinstance(op, tvm.stmt.For): + val = [(op.body,), ("For", op.loop_var.name)] + elif isinstance(op, tvm.stmt.AttrStmt): + val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))] + else: + return + node_dict[key] = val + + tvm.ir_pass.PostOrderVisit(stmt, _visit) + for key, val in node_dict.items(): + struct[val[1]] = tuple(node_dict[child][1] if child in node_dict + else None for child in val[0]) + + assert struct == expected_struct, "Structure mismatch: expect %s but got %s" \ + % (expected_struct, struct) + +def test_basic(): + ib = tvm.ir_builder.create() + l = tvm.var('l') + m = tvm.var('m') + n = tvm.var('n') + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(ib.likely(i < 2)): + ib.emit(tvm.make.Evaluate(m)) + with ib.else_scope(): + ib.emit(tvm.make.Evaluate(n)) + + stmt = ib.get() + new_stmt = tvm.ir_pass.LiftIfThenElse(stmt) + expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), + ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')), + ('For', 'i'): (('IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_no_else(): + ib = tvm.ir_builder.create() + l = tvm.var('l') + m = tvm.var('m') + n = tvm.var('n') + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(ib.likely(i < 2)): + ib.emit(tvm.make.Evaluate(m)) + + stmt = ib.get() + new_stmt = tvm.ir_pass.LiftIfThenElse(stmt) + expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), + ('IfThenElse', ('i',)): (('For', 'j'), None), + ('For', 'i'): (('IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_attr_stmt(): + ib = tvm.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = tvm.var('l') + m = tvm.var('m') + n = tvm.var('n') + + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(tvm.any(i < 4, j >= 8)): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.5 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0 + + stmt = ib.get() + new_stmt = tvm.ir_pass.LiftIfThenElse(stmt) + expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')), + ('For', 'j'): (('IfThenElse', ('i', 'j')),), ('For', 'i'): (('For', 'j'),), + ('AttrStmt', 'thread_extent', 64): (('For', 'i'),), + ('AttrStmt', 'thread_extent', 32): (('AttrStmt', 'thread_extent', 64),)} + verify_structure(new_stmt, expected_struct) + +def test_nested_for(): + ib = tvm.ir_builder.create() + data = ib.pointer("float32", name="data") + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.if_scope(i >= 3): + data[i * 3 + j] = data[i *3 + j] + 0.5 + with ib.for_range(0, 15, "k") as k: + with ib.for_range(0, 20, "l") as l: + with ib.if_scope(tvm.any(i < 4, j >= 8)): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2 + with ib.else_scope(): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 + + stmt = ib.get() + new_stmt = tvm.ir_pass.LiftIfThenElse(stmt) + expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('For', 'l'): (('IfThenElse', ('i', 'j')),), + ('For', 'k'): (('For', 'l'),), ('For', 'j'): (None,), ('IfThenElse', ('i',)): (('For', 'j'), None), + ('For', 'i'): (('IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + + +if __name__ == "__main__": + test_basic() + test_no_else() + test_attr_stmt() + test_nested_for() \ No newline at end of file From 1ad29c4039e98a9d202642b4f0354bac003391e4 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Fri, 30 Aug 2019 19:36:32 -0700 Subject: [PATCH 2/8] Add more comments --- src/pass/lift_if_then_else.cc | 264 +++++++++++++++------ tests/python/unittest/test_pass_lift_if.py | 40 +++- 2 files changed, 224 insertions(+), 80 deletions(-) diff --git a/src/pass/lift_if_then_else.cc b/src/pass/lift_if_then_else.cc index 3ce1861b891f..f8138d05b18e 100644 --- a/src/pass/lift_if_then_else.cc +++ b/src/pass/lift_if_then_else.cc @@ -1,3 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file lift_if_then_else.cc + */ #include #include #include @@ -16,65 +39,83 @@ namespace ir { using LifterMap = std::unordered_map>; using VarMap = std::unordered_map>; -class IfThenElseLifter : public IRMutator { +/* + * This pass tries to lift IfThenElse stmt out of For loop if condition is loop invariant. + * For example, given the following block: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt. + * Then we lift IfThenElse stmt by one For stmt each step: + * + * Step 1: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * if (likely(i*2 < 4)) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * Step 2: + * for (i = 0; i < 3; i++) + * if (likely(i*2 < 4)) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * In this pass, we only continue detecting possible lifting chance when visiting For, + * IfThenElse or AttrStmt Node. For example, for the following block: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * A[i + j] = A[i + j] - 1 + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * Only the For with k variable will be considered and the resulting stmt would be: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * A[i + j] = A[i + j] - 1 + * if (likely(i*2 < 4)) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * This pass doesn't do lifting for consecutive IfThenElse stmt. The following + * block won't be optimized: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * if (likely(j > 2)) + * A[i+j+k] = B[i+j+k] + * + */ +class IfThenElseLifter { public: Stmt VisitAndMutate(const Stmt& stmt) { GenerateInternalData(stmt); return PostOrderMutate(stmt); } - Stmt PostOrderMutate(const Stmt& stmt) { - PackedFunc replace_top_for = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const NodeRef& current_for = args[0]; - if (current_for.as()) { - const For* for_node = current_for.as(); - if (top_for_map_.count(for_node->loop_var.get())) { - std::vector new_if_list; - for (const Stmt& if_stmt : top_for_map_[for_node->loop_var.get()]) { - new_if_list.emplace_back(LiftIf(if_stmt)); - } - - const IfThenElse* next_if_node; - const IfThenElse* current_if_node = new_if_list.back().as(); - Stmt new_for = Stmt(); - for (size_t i = new_if_list.size() - 1; i > 0; --i) { - const Stmt current_if_stmt = IfThenElse::make(current_if_node->condition, - current_if_node->then_case, - current_if_node->else_case); - next_if_node = new_if_list[i - 1].as(); - new_for = IfThenElse::make(next_if_node->condition, current_if_stmt, next_if_node->else_case); - current_if_node = new_for.as(); - } - - if (!new_for.get()) { - const IfThenElse* first_if_node = new_if_list[0].as(); - new_for = IfThenElse::make(first_if_node->condition, - first_if_node->then_case, - first_if_node->else_case); - } - *ret = new_for; - } - } - }); - return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")}); - } - private: void GenerateInternalData(const Stmt& stmt); - size_t GetActualFor(const Stmt& for_stmt, const Stmt& if_stmt); + Stmt PostOrderMutate(const Stmt& stmt); + size_t GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt); Stmt LiftIf(const Stmt& if_stmt); LifterMap if2for_map_; - LifterMap top_for_map_; + LifterMap top_for_var_map_; LifterMap for_tracking_map_; LifterMap for2if_map_; - std::vector ordered_for_list_; VarMap cond_var_map_; - + std::vector ordered_for_list_; }; -// Check whether a given IfThenElse stmt is the first one appearing in a For stmt. +// Check whether a given IfThenElse stmt is the first one appearing +// in a For stmt. bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { std::vector if_hash_list; @@ -86,7 +127,9 @@ bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { return if_hash_list.empty() ? false : if_stmt.hash() == if_hash_list.back(); } -// Update upper level for loop when current for loop is modified. +// Update upper level For node when current For node is modified. +// With this function we only need to visit and mutate top level For node +// in the main VisitAndMutate function. Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { std::vector for_hash_list; @@ -104,15 +147,13 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { } }); - return IRTransform(parent_for_stmt, nullptr, replace_target_for, {Expr("For")}); + return IRTransform(parent_for_stmt, nullptr, replace_target_for, + {Expr("For")}); } -// Remove If statement from a for statement +// Remove IfThenElse node from a For node. +// A pair of For nodes will be generated. std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { - const For* for_node = for_stmt.as(); - const Stmt make_for = For::make(for_node->loop_var, for_node->min, for_node->extent, - for_node->for_type, for_node->device_api, for_node->body); - Stmt then_for; Stmt else_for; @@ -132,18 +173,22 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { } }); - then_for = IRTransform(make_for, nullptr, replace_then_case, {Expr("IfThenElse")}); + then_for = IRTransform(for_stmt, nullptr, replace_then_case, + {Expr("IfThenElse")}); if (if_stmt.as()->else_case) { - else_for = IRTransform(make_for, nullptr, replace_else_case, {Expr("IfThenElse")}); + else_for = IRTransform(for_stmt, nullptr, replace_else_case, + {Expr("IfThenElse")}); } return std::make_pair(then_for, else_for); } +// Generate internal data structures for lifter. void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { std::unordered_map if_position_map; std::unordered_set top_for_var_set; + // Locate all For nodes and capture child IfThenElse nodes. PostOrderVisit(stmt, [&](const NodeRef& node){ const For* for_node = node.as(); if (for_node) { @@ -151,7 +196,7 @@ void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { tracker.push(for_node->body); Stmt for_stmt = Downcast(node); for2if_map_.insert({for_stmt.get(), std::vector()}); - while(!tracker.empty()) { + while (!tracker.empty()) { Stmt head = tracker.front(); tracker.pop(); if (head->is_type()) { @@ -169,12 +214,13 @@ void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { tracker.push(if_node->else_case); } + // Record condition variables. if (!cond_var_map_.count(head.get())) { std::unordered_set new_var_set; cond_var_map_.insert({head.get(), new_var_set}); - PostOrderVisit(if_node->condition, [&](const NodeRef& var) { - if (var.as()) { - cond_var_map_[head.get()].insert(var.get()); + PostOrderVisit(if_node->condition, [&](const NodeRef& cond_node) { + if (cond_node.as()) { + cond_var_map_[head.get()].insert(cond_node.get()); } }); } @@ -187,15 +233,17 @@ void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { }); + // Create candidate For nodes to be lifted for each IfThenElse node. for (const Stmt& for_stmt : ordered_for_list_) { std::vector if_list = for2if_map_[for_stmt.get()]; - top_for_map_.insert({for_stmt.as()->loop_var.get(), if_list}); + top_for_var_map_.insert({for_stmt.as()->loop_var.get(), if_list}); for (const Stmt& if_stmt : if_list) { - if (!if2for_map_.count(if_stmt.get())) { + const Node* if_node = if_stmt.get(); + if (!if2for_map_.count(if_node)) { std::vector new_for_list; - if2for_map_.insert({if_stmt.get(), new_for_list}); + if2for_map_.insert({if_node, new_for_list}); } - if2for_map_[if_stmt.get()].push_back(for_stmt); + if2for_map_[if_node].push_back(for_stmt); } } @@ -209,7 +257,8 @@ void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { for_tracking_map_.insert({for_stmt.get(), new_for_list}); if (cond_var_map_[if_stmt] .count(for_stmt.as()->loop_var.get())) { - std::vector updated_for_list(for_list.begin(), for_list.begin() + i); + std::vector updated_for_list(for_list.begin(), + for_list.begin() + i); if2for_map_[if_stmt] = updated_for_list; break; } else { @@ -221,12 +270,14 @@ void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { } } - for ( const auto& item : if_position_map) { + for (const auto& item : if_position_map) { top_for_var_set.insert(item.second.as()->loop_var.get()); } + // For each IfThenElse node, find the highest For node which + // is loop invariant. std::vector removed_for_var_list; - for (const auto& item : top_for_map_) { + for (const auto& item : top_for_var_map_) { const Node* top_for_var = item.first; std::vector if_list = item.second; if (!top_for_var_set.count(top_for_var)) { @@ -238,58 +289,115 @@ void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { actual_if_list.push_back(if_stmt); } } - top_for_map_[top_for_var] = actual_if_list; + top_for_var_map_[top_for_var] = actual_if_list; } } for (const Node* top_for_var : removed_for_var_list) { - top_for_map_.erase(top_for_var); + top_for_var_map_.erase(top_for_var); } } -size_t IfThenElseLifter::GetActualFor(const Stmt& for_stmt, const Stmt& if_stmt) { +// When we try to mutate a For node, some child For nodes can have already +// been mutated. This function is to get the updated For node and further +// lifting can be done based on this new node. +// We keep all For nodes tracing in for_tracking_map_. When we get a +// lifted IfThenElse, we matching it with tracing For nodes to pick +// the updated one. +size_t IfThenElseLifter::GetUpdatedFor(const Stmt& for_stmt, + const Stmt& if_stmt) { std::vector tracked_for_list = for_tracking_map_[for_stmt.get()]; - size_t actual_idx = 0; + size_t updated_for_idx = 0; for (size_t i = 0; i < tracked_for_list.size(); ++i) { - const Stmt& current_for = tracked_for_list.at(tracked_for_list.size() - 1 - i); + const Stmt& current_for = + tracked_for_list.at(tracked_for_list.size() - 1 - i); if (is_first_if(current_for, if_stmt)) { - actual_idx = tracked_for_list.size() - 1 - i; + updated_for_idx = tracked_for_list.size() - 1 - i; break; } } - return actual_idx; + return updated_for_idx; } +// Lift a IfThenElse node as high as possible. +// This function iterates on all candidate For nodes. For each For node, +// it first removes IfThenElse nodes. Then it generates a new IfThenElse +// node using mutated For nodes. Stmt IfThenElseLifter::LiftIf(const Stmt& if_stmt) { Stmt new_if = if_stmt; for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) { const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i); - size_t actual_for_idx = GetActualFor(for_stmt, new_if); - const Stmt& actual_for_node = for_tracking_map_[for_stmt.get()].at(actual_for_idx); - auto generated_for_pair = RemoveIf(actual_for_node, new_if); + size_t updated_for_idx = GetUpdatedFor(for_stmt, new_if); + const Stmt& updated_for_node = + for_tracking_map_[for_stmt.get()].at(updated_for_idx); + auto generated_for_pair = RemoveIf(updated_for_node, new_if); const Stmt& then_for = generated_for_pair.first; const Stmt& else_for = generated_for_pair.second;; - for_tracking_map_[for_stmt.get()].at(actual_for_idx) = then_for; + for_tracking_map_[for_stmt.get()].at(updated_for_idx) = then_for; if (else_for.get()) { for_tracking_map_[for_stmt.get()].push_back(else_for); } - new_if = IfThenElse::make(new_if.as()->condition, then_for, else_for); + new_if = IfThenElse::make(new_if.as()->condition, + then_for, else_for); if (i < if2for_map_[if_stmt.get()].size() - 1) { const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1); - const Stmt& actual_next_for = for_tracking_map_[original_next_for.get()].at(actual_for_idx); + const Stmt& actual_next_for = + for_tracking_map_[original_next_for.get()].at(updated_for_idx); Stmt update_for_stmt = update_for(actual_next_for, new_if); - for_tracking_map_[original_next_for.get()].at(actual_for_idx) = update_for_stmt; + for_tracking_map_[original_next_for.get()]. + at(updated_for_idx) = update_for_stmt; } } return new_if; } +// Mutate For nodes in post order DFS manner. +Stmt IfThenElseLifter::PostOrderMutate(const Stmt& stmt) { + PackedFunc replace_top_for = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& current_for = args[0]; + if (current_for.as()) { + const For* for_node = current_for.as(); + if (top_for_var_map_.count(for_node->loop_var.get())) { + std::vector new_if_list; + for (const Stmt& if_stmt : + top_for_var_map_[for_node->loop_var.get()]) { + new_if_list.emplace_back(LiftIf(if_stmt)); + } + + const IfThenElse* next_if_node; + const IfThenElse* current_if_node = + new_if_list.back().as(); + Stmt new_for = Stmt(); + for (size_t i = new_if_list.size() - 1; i > 0; --i) { + const Stmt current_if_stmt = + IfThenElse::make(current_if_node->condition, + current_if_node->then_case, + current_if_node->else_case); + next_if_node = new_if_list[i - 1].as(); + new_for = IfThenElse::make(next_if_node->condition, current_if_stmt, + next_if_node->else_case); + current_if_node = new_for.as(); + } + + if (!new_for.get()) { + const IfThenElse* first_if_node = new_if_list[0].as(); + new_for = IfThenElse::make(first_if_node->condition, + first_if_node->then_case, + first_if_node->else_case); + } + *ret = new_for; + } + } + }); + return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")}); +} Stmt LiftIfThenElse(Stmt stmt) { return IfThenElseLifter().VisitAndMutate(stmt); } } // namespace ir -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/tests/python/unittest/test_pass_lift_if.py b/tests/python/unittest/test_pass_lift_if.py index e1e56f3d5927..bf2d7649f729 100644 --- a/tests/python/unittest/test_pass_lift_if.py +++ b/tests/python/unittest/test_pass_lift_if.py @@ -124,7 +124,7 @@ def test_nested_for(): with ib.for_range(0, 5, "i") as i: with ib.for_range(0, 10, "j") as j: with ib.if_scope(i >= 3): - data[i * 3 + j] = data[i *3 + j] + 0.5 + data[i * 3 + j] = data[i * 3 + j] + 0.5 with ib.for_range(0, 15, "k") as k: with ib.for_range(0, 20, "l") as l: with ib.if_scope(tvm.any(i < 4, j >= 8)): @@ -139,9 +139,45 @@ def test_nested_for(): ('For', 'i'): (('IfThenElse', ('i',)),)} verify_structure(new_stmt, expected_struct) +def test_block(): + ib = tvm.ir_builder.create() + data = ib.pointer("float32", name="data") + + n = tvm.var("n") + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.if_scope(i >= 3): + data[i * 3 + j] = data[i * 3 + j] + 0.5 + with ib.for_range(0, 15, "k") as k: + with ib.for_range(0, 20, "l") as l: + with ib.if_scope(tvm.any(i < 4, j >= 8)): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2 + with ib.else_scope(): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 + with ib.if_scope(j <5): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] - 1 + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.for_range(0, 15, "k") as k: + with ib.if_scope(n >= 3): + data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6 + + stmt = ib.get() + new_stmt = tvm.ir_pass.LiftIfThenElse(stmt) + expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('IfThenElse', ('j',)): (None, None), + ('For', 'l'): (None,), ('For', 'k'): (None,), ('For', 'j'): (('For', 'j'),), + ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),), + ('IfThenElse', ('n',)): (('For', 'j'), None)} + verify_structure(new_stmt, expected_struct) + if __name__ == "__main__": test_basic() test_no_else() test_attr_stmt() - test_nested_for() \ No newline at end of file + test_nested_for() + test_block() \ No newline at end of file From 443e7bf59ba81ba14d4bd7c7248214521a33c865 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Tue, 3 Sep 2019 13:24:32 -0700 Subject: [PATCH 3/8] Rename and refactor --- include/tvm/ir_pass.h | 6 +- src/api/api_pass.cc | 2 +- src/pass/lift_if_then_else.cc | 113 ++++++++++-------- ..._pass_lift_if.py => test_pass_hoist_if.py} | 10 +- 4 files changed, 75 insertions(+), 56 deletions(-) rename tests/python/unittest/{test_pass_lift_if.py => test_pass_hoist_if.py} (96%) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 64e460689f50..03078b8be41f 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -378,11 +378,11 @@ Stmt LowerStorageAccessInfo(Stmt stmt); Stmt DecorateDeviceScope(Stmt stmt); /*! - * \brief Loop invariant code motion which locates and lifts if statements. - * \param stmt The stmt to do if statement iifting. + * \brief Loop invariant code motion which locates and hoists if statements. + * \param stmt The stmt to do if statement hoisting. * \return Transformed stmt. */ -Stmt LiftIfThenElse(Stmt stmt); +Stmt HoistIfThenElse(Stmt stmt); /*! * \brief Make an user callable API LoweredFunc. diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 08f0afca317a..d2352496c2b4 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -160,6 +160,6 @@ REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); -REGISTER_PASS(LiftIfThenElse); +REGISTER_PASS(HoistIfThenElse); } // namespace ir } // namespace tvm diff --git a/src/pass/lift_if_then_else.cc b/src/pass/lift_if_then_else.cc index f8138d05b18e..83ed0ffef106 100644 --- a/src/pass/lift_if_then_else.cc +++ b/src/pass/lift_if_then_else.cc @@ -19,7 +19,7 @@ /*! * Copyright (c) 2019 by Contributors - * \file lift_if_then_else.cc + * \file hoist_if_then_else.cc */ #include #include @@ -36,11 +36,11 @@ namespace tvm { namespace ir { -using LifterMap = std::unordered_map>; +using HoistMap = std::unordered_map>; using VarMap = std::unordered_map>; /* - * This pass tries to lift IfThenElse stmt out of For loop if condition is loop invariant. + * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant. * For example, given the following block: * for (i = 0; i < 3; i++) * for (j = 0; j < 4; j++) @@ -49,7 +49,7 @@ using VarMap = std::unordered_map>; * A[3*i+2j+k] = B[7*i+3j+k] * * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt. - * Then we lift IfThenElse stmt by one For stmt each step: + * Then we hoist IfThenElse stmt by one For stmt each step: * * Step 1: * for (i = 0; i < 3; i++) @@ -65,7 +65,7 @@ using VarMap = std::unordered_map>; * for (k = 0; k < 5; k++) * A[3*i+2j+k] = B[7*i+3j+k] * - * In this pass, we only continue detecting possible lifting chance when visiting For, + * In this pass, we only continue detecting possible hoisting chance when visiting For, * IfThenElse or AttrStmt Node. For example, for the following block: * for (i = 0; i < 3; i++) * for (j = 0; j < 4; j++) @@ -82,7 +82,7 @@ using VarMap = std::unordered_map>; * for (k = 0; k < 5; k++) * A[3*i+2j+k] = B[7*i+3j+k] * - * This pass doesn't do lifting for consecutive IfThenElse stmt. The following + * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following * block won't be optimized: * for (i = 0; i < 3; i++) * for (j = 0; j < 4; j++) @@ -93,23 +93,25 @@ using VarMap = std::unordered_map>; * A[i+j+k] = B[i+j+k] * */ -class IfThenElseLifter { +class IfThenElseLHoist { public: Stmt VisitAndMutate(const Stmt& stmt) { - GenerateInternalData(stmt); + SelectCandidates(stmt); + LocateTopFor(); return PostOrderMutate(stmt); } private: - void GenerateInternalData(const Stmt& stmt); + void SelectCandidates(const Stmt& stmt); + void LocateTopFor(); Stmt PostOrderMutate(const Stmt& stmt); size_t GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt); - Stmt LiftIf(const Stmt& if_stmt); + Stmt HoistIf(const Stmt& if_stmt); - LifterMap if2for_map_; - LifterMap top_for_var_map_; - LifterMap for_tracking_map_; - LifterMap for2if_map_; + HoistMap if2for_map_; + HoistMap top_for_var_map_; + HoistMap for_tracking_map_; + HoistMap for2if_map_; VarMap cond_var_map_; std::vector ordered_for_list_; }; @@ -117,32 +119,38 @@ class IfThenElseLifter { // Check whether a given IfThenElse stmt is the first one appearing // in a For stmt. bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { - std::vector if_hash_list; + std::vector if_node_list; + const For* for_node = for_stmt.as(); + CHECK(for_node); + CHECK(if_stmt.as()); - PostOrderVisit(for_stmt.as()->body, [&](const NodeRef& node) { + PostOrderVisit(for_node->body, [&](const NodeRef& node) { if (node.as()) { - if_hash_list.push_back(node.hash()); + if_node_list.push_back(node.get()); } }); - return if_hash_list.empty() ? false : if_stmt.hash() == if_hash_list.back(); + return if_node_list.empty() ? false : if_stmt.get() == if_node_list.back(); } // Update upper level For node when current For node is modified. // With this function we only need to visit and mutate top level For node // in the main VisitAndMutate function. Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { - std::vector for_hash_list; + std::vector for_node_list; + const For* parent_for_node = parent_for_stmt.as(); + CHECK(parent_for_node); + CHECK(new_if_stmt.as()); - PostOrderVisit(parent_for_stmt.as()->body, [&](const NodeRef& node) { + PostOrderVisit(parent_for_node->body, [&](const NodeRef& node) { if (node.as()) { - for_hash_list.push_back(node.hash()); + for_node_list.push_back(node.get()); } }); PackedFunc replace_target_for = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ const NodeRef& current_for = args[0]; - if (current_for.hash() == for_hash_list.back()) { + if (current_for.get() == for_node_list.back()) { *ret = new_if_stmt; } }); @@ -156,6 +164,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { Stmt then_for; Stmt else_for; + CHECK(if_stmt.as()); PackedFunc replace_then_case = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ @@ -183,12 +192,8 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { return std::make_pair(then_for, else_for); } -// Generate internal data structures for lifter. -void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { - std::unordered_map if_position_map; - std::unordered_set top_for_var_set; - - // Locate all For nodes and capture child IfThenElse nodes. +// Locate all For nodes and capture child IfThenElse nodes. +void IfThenElseLHoist::SelectCandidates(const Stmt& stmt) { PostOrderVisit(stmt, [&](const NodeRef& node){ const For* for_node = node.as(); if (for_node) { @@ -231,12 +236,20 @@ void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { ordered_for_list_.emplace_back(Downcast(node)); } }); +} +// For each IfThenElse node, find the highest For node which +// meets loop invariant condition. +void IfThenElseLHoist::LocateTopFor() { + std::unordered_map if_position_map; + std::unordered_set top_for_var_set; - // Create candidate For nodes to be lifted for each IfThenElse node. + // Create IfThenElse -> For map. for (const Stmt& for_stmt : ordered_for_list_) { std::vector if_list = for2if_map_[for_stmt.get()]; - top_for_var_map_.insert({for_stmt.as()->loop_var.get(), if_list}); + const For* for_node = for_stmt.as(); + CHECK(for_node); + top_for_var_map_.insert({for_node->loop_var.get(), if_list}); for (const Stmt& if_stmt : if_list) { const Node* if_node = if_stmt.get(); if (!if2for_map_.count(if_node)) { @@ -247,16 +260,19 @@ void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { } } + // Locate the highest For node which is loop invariant. for (const auto& item : if2for_map_) { Stmt top_for; const Node* if_stmt = item.first; std::vector for_list = item.second; for (size_t i = 0; i < for_list.size(); ++i) { const Stmt& for_stmt = for_list.at(i); + const For* for_node = for_stmt.as(); + CHECK(for_node); std::vector new_for_list{for_stmt}; for_tracking_map_.insert({for_stmt.get(), new_for_list}); if (cond_var_map_[if_stmt] - .count(for_stmt.as()->loop_var.get())) { + .count(for_node->loop_var.get())) { std::vector updated_for_list(for_list.begin(), for_list.begin() + i); if2for_map_[if_stmt] = updated_for_list; @@ -266,7 +282,7 @@ void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { } } if (top_for.as()) { - if_position_map.insert({if_stmt, top_for}); + if_position_map.insert({if_stmt, top_for}); } } @@ -274,8 +290,6 @@ void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { top_for_var_set.insert(item.second.as()->loop_var.get()); } - // For each IfThenElse node, find the highest For node which - // is loop invariant. std::vector removed_for_var_list; for (const auto& item : top_for_var_map_) { const Node* top_for_var = item.first; @@ -299,11 +313,11 @@ void IfThenElseLifter::GenerateInternalData(const Stmt& stmt) { // When we try to mutate a For node, some child For nodes can have already // been mutated. This function is to get the updated For node and further -// lifting can be done based on this new node. +// hoisting can be done based on this new node. // We keep all For nodes tracing in for_tracking_map_. When we get a -// lifted IfThenElse, we matching it with tracing For nodes to pick +// hoisted IfThenElse, we matching it with tracing For nodes to pick // the updated one. -size_t IfThenElseLifter::GetUpdatedFor(const Stmt& for_stmt, +size_t IfThenElseLHoist::GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt) { std::vector tracked_for_list = for_tracking_map_[for_stmt.get()]; size_t updated_for_idx = 0; @@ -318,11 +332,11 @@ size_t IfThenElseLifter::GetUpdatedFor(const Stmt& for_stmt, return updated_for_idx; } -// Lift a IfThenElse node as high as possible. +// Hoist an IfThenElse node as high as possible. // This function iterates on all candidate For nodes. For each For node, // it first removes IfThenElse nodes. Then it generates a new IfThenElse // node using mutated For nodes. -Stmt IfThenElseLifter::LiftIf(const Stmt& if_stmt) { +Stmt IfThenElseLHoist::HoistIf(const Stmt& if_stmt) { Stmt new_if = if_stmt; for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) { @@ -338,8 +352,10 @@ Stmt IfThenElseLifter::LiftIf(const Stmt& if_stmt) { if (else_for.get()) { for_tracking_map_[for_stmt.get()].push_back(else_for); } - new_if = IfThenElse::make(new_if.as()->condition, - then_for, else_for); + + const IfThenElse* new_if_node = new_if.as(); + CHECK(new_if_node); + new_if = IfThenElse::make(new_if_node->condition, then_for, else_for); if (i < if2for_map_[if_stmt.get()].size() - 1) { const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1); const Stmt& actual_next_for = @@ -354,17 +370,17 @@ Stmt IfThenElseLifter::LiftIf(const Stmt& if_stmt) { } // Mutate For nodes in post order DFS manner. -Stmt IfThenElseLifter::PostOrderMutate(const Stmt& stmt) { +Stmt IfThenElseLHoist::PostOrderMutate(const Stmt& stmt) { PackedFunc replace_top_for = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ const NodeRef& current_for = args[0]; - if (current_for.as()) { - const For* for_node = current_for.as(); + const For* for_node = current_for.as(); + if (for_node) { if (top_for_var_map_.count(for_node->loop_var.get())) { std::vector new_if_list; for (const Stmt& if_stmt : top_for_var_map_[for_node->loop_var.get()]) { - new_if_list.emplace_back(LiftIf(if_stmt)); + new_if_list.emplace_back(HoistIf(if_stmt)); } const IfThenElse* next_if_node; @@ -372,11 +388,13 @@ Stmt IfThenElseLifter::PostOrderMutate(const Stmt& stmt) { new_if_list.back().as(); Stmt new_for = Stmt(); for (size_t i = new_if_list.size() - 1; i > 0; --i) { + CHECK(current_if_node); const Stmt current_if_stmt = IfThenElse::make(current_if_node->condition, current_if_node->then_case, current_if_node->else_case); next_if_node = new_if_list[i - 1].as(); + CHECK(next_if_node); new_for = IfThenElse::make(next_if_node->condition, current_if_stmt, next_if_node->else_case); current_if_node = new_for.as(); @@ -384,6 +402,7 @@ Stmt IfThenElseLifter::PostOrderMutate(const Stmt& stmt) { if (!new_for.get()) { const IfThenElse* first_if_node = new_if_list[0].as(); + CHECK(first_if_node); new_for = IfThenElse::make(first_if_node->condition, first_if_node->then_case, first_if_node->else_case); @@ -395,8 +414,8 @@ Stmt IfThenElseLifter::PostOrderMutate(const Stmt& stmt) { return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")}); } -Stmt LiftIfThenElse(Stmt stmt) { - return IfThenElseLifter().VisitAndMutate(stmt); +Stmt HoistIfThenElse(Stmt stmt) { + return IfThenElseLHoist().VisitAndMutate(stmt); } } // namespace ir diff --git a/tests/python/unittest/test_pass_lift_if.py b/tests/python/unittest/test_pass_hoist_if.py similarity index 96% rename from tests/python/unittest/test_pass_lift_if.py rename to tests/python/unittest/test_pass_hoist_if.py index bf2d7649f729..8a37c8be07e4 100644 --- a/tests/python/unittest/test_pass_lift_if.py +++ b/tests/python/unittest/test_pass_hoist_if.py @@ -64,7 +64,7 @@ def test_basic(): ib.emit(tvm.make.Evaluate(n)) stmt = ib.get() - new_stmt = tvm.ir_pass.LiftIfThenElse(stmt) + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')), ('For', 'i'): (('IfThenElse', ('i',)),)} @@ -82,7 +82,7 @@ def test_no_else(): ib.emit(tvm.make.Evaluate(m)) stmt = ib.get() - new_stmt = tvm.ir_pass.LiftIfThenElse(stmt) + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),)} @@ -109,7 +109,7 @@ def test_attr_stmt(): data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0 stmt = ib.get() - new_stmt = tvm.ir_pass.LiftIfThenElse(stmt) + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')), ('For', 'j'): (('IfThenElse', ('i', 'j')),), ('For', 'i'): (('For', 'j'),), ('AttrStmt', 'thread_extent', 64): (('For', 'i'),), @@ -133,7 +133,7 @@ def test_nested_for(): data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 stmt = ib.get() - new_stmt = tvm.ir_pass.LiftIfThenElse(stmt) + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('For', 'l'): (('IfThenElse', ('i', 'j')),), ('For', 'k'): (('For', 'l'),), ('For', 'j'): (None,), ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),)} @@ -167,7 +167,7 @@ def test_block(): data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6 stmt = ib.get() - new_stmt = tvm.ir_pass.LiftIfThenElse(stmt) + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('IfThenElse', ('j',)): (None, None), ('For', 'l'): (None,), ('For', 'k'): (None,), ('For', 'j'): (('For', 'j'),), ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),), From a777398393083359486301c6daab3bc170279afb Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Wed, 4 Sep 2019 12:43:19 -0700 Subject: [PATCH 4/8] Add description for internal data structure --- ..._if_then_else.cc => hoist_if_then_else.cc} | 162 +++++++++--------- 1 file changed, 84 insertions(+), 78 deletions(-) rename src/pass/{lift_if_then_else.cc => hoist_if_then_else.cc} (76%) diff --git a/src/pass/lift_if_then_else.cc b/src/pass/hoist_if_then_else.cc similarity index 76% rename from src/pass/lift_if_then_else.cc rename to src/pass/hoist_if_then_else.cc index 83ed0ffef106..4726063d21e3 100644 --- a/src/pass/lift_if_then_else.cc +++ b/src/pass/hoist_if_then_else.cc @@ -93,7 +93,7 @@ using VarMap = std::unordered_map>; * A[i+j+k] = B[i+j+k] * */ -class IfThenElseLHoist { +class IfThenElseHoist { public: Stmt VisitAndMutate(const Stmt& stmt) { SelectCandidates(stmt); @@ -108,11 +108,17 @@ class IfThenElseLHoist { size_t GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt); Stmt HoistIf(const Stmt& if_stmt); + // Map of all For nodes to all child IfThenElse nodes. + HoistMap for2if_map_; + // Map of all IfThenElse nodes to all For nodes which are loop invariant. HoistMap if2for_map_; + // Map of highest loop invariant For to child IfThenElse. HoistMap top_for_var_map_; + // Map of original For to list of update For nodes. HoistMap for_tracking_map_; - HoistMap for2if_map_; + // Map of all IfThenElse nodes to condition variable nodes. VarMap cond_var_map_; + // List of For nodes added in post order DFS visiting. std::vector ordered_for_list_; }; @@ -193,54 +199,54 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { } // Locate all For nodes and capture child IfThenElse nodes. -void IfThenElseLHoist::SelectCandidates(const Stmt& stmt) { +void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { PostOrderVisit(stmt, [&](const NodeRef& node){ const For* for_node = node.as(); - if (for_node) { - std::queue tracker; - tracker.push(for_node->body); - Stmt for_stmt = Downcast(node); - for2if_map_.insert({for_stmt.get(), std::vector()}); - while (!tracker.empty()) { - Stmt head = tracker.front(); - tracker.pop(); - if (head->is_type()) { - for (const auto& if_stmt : for2if_map_.at(head.get())) { - for2if_map_[for_stmt.get()].push_back(if_stmt); - } - } else if (head->is_type()) { - const AttrStmt* attr_node = head.as(); - tracker.push(attr_node->body); - } else if (head->is_type()) { - for2if_map_[for_stmt.get()].push_back(head); - const IfThenElse* if_node = head.as(); - tracker.push(if_node->then_case); - if (if_node->else_case) { - tracker.push(if_node->else_case); - } - - // Record condition variables. - if (!cond_var_map_.count(head.get())) { - std::unordered_set new_var_set; - cond_var_map_.insert({head.get(), new_var_set}); - PostOrderVisit(if_node->condition, [&](const NodeRef& cond_node) { - if (cond_node.as()) { - cond_var_map_[head.get()].insert(cond_node.get()); - } - }); - } - } else { - continue; + if (!for_node) return; + + std::queue tracker; + tracker.push(for_node->body); + Stmt for_stmt = Downcast(node); + for2if_map_.insert({for_stmt.get(), std::vector()}); + while (!tracker.empty()) { + Stmt head = tracker.front(); + tracker.pop(); + if (head->is_type()) { + for (const auto& if_stmt : for2if_map_.at(head.get())) { + for2if_map_[for_stmt.get()].push_back(if_stmt); + } + } else if (head->is_type()) { + const AttrStmt* attr_node = head.as(); + tracker.push(attr_node->body); + } else if (head->is_type()) { + for2if_map_[for_stmt.get()].push_back(head); + const IfThenElse* if_node = head.as(); + tracker.push(if_node->then_case); + if (if_node->else_case) { + tracker.push(if_node->else_case); } + + // Record condition variables. + if (!cond_var_map_.count(head.get())) { + std::unordered_set new_var_set; + cond_var_map_.insert({head.get(), new_var_set}); + PostOrderVisit(if_node->condition, [&](const NodeRef& cond_node) { + if (cond_node.as()) { + cond_var_map_[head.get()].insert(cond_node.get()); + } + }); + } + } else { + continue; } - ordered_for_list_.emplace_back(Downcast(node)); } + ordered_for_list_.emplace_back(Downcast(node)); }); } // For each IfThenElse node, find the highest For node which // meets loop invariant condition. -void IfThenElseLHoist::LocateTopFor() { +void IfThenElseHoist::LocateTopFor() { std::unordered_map if_position_map; std::unordered_set top_for_var_set; @@ -315,9 +321,9 @@ void IfThenElseLHoist::LocateTopFor() { // been mutated. This function is to get the updated For node and further // hoisting can be done based on this new node. // We keep all For nodes tracing in for_tracking_map_. When we get a -// hoisted IfThenElse, we matching it with tracing For nodes to pick +// hoisted IfThenElse, we match it with tracing For nodes to pick // the updated one. -size_t IfThenElseLHoist::GetUpdatedFor(const Stmt& for_stmt, +size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt) { std::vector tracked_for_list = for_tracking_map_[for_stmt.get()]; size_t updated_for_idx = 0; @@ -336,7 +342,7 @@ size_t IfThenElseLHoist::GetUpdatedFor(const Stmt& for_stmt, // This function iterates on all candidate For nodes. For each For node, // it first removes IfThenElse nodes. Then it generates a new IfThenElse // node using mutated For nodes. -Stmt IfThenElseLHoist::HoistIf(const Stmt& if_stmt) { +Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { Stmt new_if = if_stmt; for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) { @@ -370,52 +376,52 @@ Stmt IfThenElseLHoist::HoistIf(const Stmt& if_stmt) { } // Mutate For nodes in post order DFS manner. -Stmt IfThenElseLHoist::PostOrderMutate(const Stmt& stmt) { +Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { PackedFunc replace_top_for = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ const NodeRef& current_for = args[0]; const For* for_node = current_for.as(); - if (for_node) { - if (top_for_var_map_.count(for_node->loop_var.get())) { - std::vector new_if_list; - for (const Stmt& if_stmt : - top_for_var_map_[for_node->loop_var.get()]) { - new_if_list.emplace_back(HoistIf(if_stmt)); - } - - const IfThenElse* next_if_node; - const IfThenElse* current_if_node = - new_if_list.back().as(); - Stmt new_for = Stmt(); - for (size_t i = new_if_list.size() - 1; i > 0; --i) { - CHECK(current_if_node); - const Stmt current_if_stmt = - IfThenElse::make(current_if_node->condition, - current_if_node->then_case, - current_if_node->else_case); - next_if_node = new_if_list[i - 1].as(); - CHECK(next_if_node); - new_for = IfThenElse::make(next_if_node->condition, current_if_stmt, - next_if_node->else_case); - current_if_node = new_for.as(); - } - - if (!new_for.get()) { - const IfThenElse* first_if_node = new_if_list[0].as(); - CHECK(first_if_node); - new_for = IfThenElse::make(first_if_node->condition, - first_if_node->then_case, - first_if_node->else_case); - } - *ret = new_for; + if (!for_node) return; + + if (top_for_var_map_.count(for_node->loop_var.get())) { + std::vector new_if_list; + for (const Stmt& if_stmt : + top_for_var_map_[for_node->loop_var.get()]) { + new_if_list.emplace_back(HoistIf(if_stmt)); + } + + const IfThenElse* next_if_node; + const IfThenElse* current_if_node = + new_if_list.back().as(); + Stmt new_for = Stmt(); + for (size_t i = new_if_list.size() - 1; i > 0; --i) { + CHECK(current_if_node); + const Stmt current_if_stmt = + IfThenElse::make(current_if_node->condition, + current_if_node->then_case, + current_if_node->else_case); + next_if_node = new_if_list[i - 1].as(); + CHECK(next_if_node); + new_for = IfThenElse::make(next_if_node->condition, current_if_stmt, + next_if_node->else_case); + current_if_node = new_for.as(); + } + + if (!new_for.get()) { + const IfThenElse* first_if_node = new_if_list[0].as(); + CHECK(first_if_node); + new_for = IfThenElse::make(first_if_node->condition, + first_if_node->then_case, + first_if_node->else_case); } + *ret = new_for; } }); return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")}); } Stmt HoistIfThenElse(Stmt stmt) { - return IfThenElseLHoist().VisitAndMutate(stmt); + return IfThenElseHoist().VisitAndMutate(stmt); } } // namespace ir From f1b0066fe58ee82b0bfa11659aa9ff601f592ec2 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Wed, 4 Sep 2019 17:08:29 -0700 Subject: [PATCH 5/8] Rename a test --- tests/python/unittest/test_pass_hoist_if.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_pass_hoist_if.py b/tests/python/unittest/test_pass_hoist_if.py index 8a37c8be07e4..cc7708f2f00b 100644 --- a/tests/python/unittest/test_pass_hoist_if.py +++ b/tests/python/unittest/test_pass_hoist_if.py @@ -139,7 +139,7 @@ def test_nested_for(): ('For', 'i'): (('IfThenElse', ('i',)),)} verify_structure(new_stmt, expected_struct) -def test_block(): +def test_if_block(): ib = tvm.ir_builder.create() data = ib.pointer("float32", name="data") @@ -180,4 +180,4 @@ def test_block(): test_no_else() test_attr_stmt() test_nested_for() - test_block() \ No newline at end of file + test_if_block() \ No newline at end of file From e65484e130af0e89184c6f63e98741dde1fc17e5 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Fri, 6 Sep 2019 10:42:15 -0700 Subject: [PATCH 6/8] Minor change --- tests/python/unittest/test_pass_hoist_if.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_pass_hoist_if.py b/tests/python/unittest/test_pass_hoist_if.py index cc7708f2f00b..6332d0225e7a 100644 --- a/tests/python/unittest/test_pass_hoist_if.py +++ b/tests/python/unittest/test_pass_hoist_if.py @@ -55,6 +55,7 @@ def test_basic(): l = tvm.var('l') m = tvm.var('m') n = tvm.var('n') + with ib.for_range(0, l, "i") as i: with ib.for_range(0, m, "j") as j: with ib.for_range(0, n, "k") as k: @@ -75,6 +76,7 @@ def test_no_else(): l = tvm.var('l') m = tvm.var('m') n = tvm.var('n') + with ib.for_range(0, l, "i") as i: with ib.for_range(0, m, "j") as j: with ib.for_range(0, n, "k") as k: @@ -142,7 +144,6 @@ def test_nested_for(): def test_if_block(): ib = tvm.ir_builder.create() data = ib.pointer("float32", name="data") - n = tvm.var("n") From 404527e3d079b18ee068b1560f340d161d2ca4d6 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Thu, 19 Sep 2019 21:38:23 +0000 Subject: [PATCH 7/8] Address comments --- src/pass/hoist_if_then_else.cc | 4 ---- tests/python/unittest/test_pass_hoist_if.py | 3 ++- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/pass/hoist_if_then_else.cc b/src/pass/hoist_if_then_else.cc index 4726063d21e3..5793fe1eaa5c 100644 --- a/src/pass/hoist_if_then_else.cc +++ b/src/pass/hoist_if_then_else.cc @@ -258,10 +258,6 @@ void IfThenElseHoist::LocateTopFor() { top_for_var_map_.insert({for_node->loop_var.get(), if_list}); for (const Stmt& if_stmt : if_list) { const Node* if_node = if_stmt.get(); - if (!if2for_map_.count(if_node)) { - std::vector new_for_list; - if2for_map_.insert({if_node, new_for_list}); - } if2for_map_[if_node].push_back(for_stmt); } } diff --git a/tests/python/unittest/test_pass_hoist_if.py b/tests/python/unittest/test_pass_hoist_if.py index 6332d0225e7a..4a28cf6b318a 100644 --- a/tests/python/unittest/test_pass_hoist_if.py +++ b/tests/python/unittest/test_pass_hoist_if.py @@ -49,6 +49,7 @@ def _visit(op): assert struct == expected_struct, "Structure mismatch: expect %s but got %s" \ % (expected_struct, struct) + var_list.clear() def test_basic(): ib = tvm.ir_builder.create() @@ -181,4 +182,4 @@ def test_if_block(): test_no_else() test_attr_stmt() test_nested_for() - test_if_block() \ No newline at end of file + test_if_block() From fdd94e170d6ed1176b9225b3d262fa52577f7842 Mon Sep 17 00:00:00 2001 From: Wang Yao Date: Tue, 15 Oct 2019 17:44:35 -0700 Subject: [PATCH 8/8] Improve update_for --- src/pass/hoist_if_then_else.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pass/hoist_if_then_else.cc b/src/pass/hoist_if_then_else.cc index 5793fe1eaa5c..bbdb609e9a08 100644 --- a/src/pass/hoist_if_then_else.cc +++ b/src/pass/hoist_if_then_else.cc @@ -142,21 +142,21 @@ bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { // With this function we only need to visit and mutate top level For node // in the main VisitAndMutate function. Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { - std::vector for_node_list; + const Node* top_for_node; const For* parent_for_node = parent_for_stmt.as(); CHECK(parent_for_node); CHECK(new_if_stmt.as()); PostOrderVisit(parent_for_node->body, [&](const NodeRef& node) { if (node.as()) { - for_node_list.push_back(node.get()); + top_for_node = node.get(); } }); PackedFunc replace_target_for = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ const NodeRef& current_for = args[0]; - if (current_for.get() == for_node_list.back()) { + if (current_for.get() == top_for_node) { *ret = new_if_stmt; } });