Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,14 @@ TVM_DLL Pass NarrowDataType(int target_bits);
*/
TVM_DLL Pass PointerValueTypeRewrite();

/*!
* \brief Hoist loop-invariant IfThenElse nodes to
* outside the corresponding loops.
*
* \return The pass.
*/
TVM_DLL Pass HoistIfThenElse();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def lower(sch,
pass_list += [
tvm.tir.transform.VectorizeLoop(not cfg.disable_vectorize),
tvm.tir.transform.InjectVirtualThread(),
tvm.tir.transform.HoistIfThenElse(), # After InjectVirtualThread
# to protect vthread loops
tvm.tir.transform.InjectDoubleBuffer(cfg.double_buffer_split_loop),
tvm.tir.transform.StorageRewrite(),
tvm.tir.transform.UnrollLoop(
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,13 @@ def VerifyMemory():
The result pass
"""
return _ffi_api.VerifyMemory()

def HoistIfThenElse():
"""Hoist loop-invariant IfThenElse nodes to outside the corresponding loops.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.HoistIfThenElse()
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>

#include <queue>
Expand Down Expand Up @@ -162,6 +163,83 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array<String>{"For"});
}

template <class T>
static bool no_intersect(const std::vector<T> &vec, const std::unordered_set<T> &set) {
for (auto &&item : vec) {
if (set.count(item))
return false;
}
return true;
}

// Rename all the Var defined in the else case, to meet the SSA requirement
class Renamer : public StmtExprMutator {
public:
explicit Renamer(const std::string &suffix)
: suffix_(suffix) {}

Stmt Rename(Stmt stmt) {
stmt = operator()(std::move(stmt));
return Substitute(std::move(stmt), var_map_);
}

protected:
Stmt VisitStmt_(const ForNode *op) override {
depth_++;
auto ret = StmtExprMutator::VisitStmt_(op);
depth_--;
if (depth_ >= 1) {
return ret;
}
op = ret.as<ForNode>();
Var new_var(op->loop_var->name_hint + suffix_);
var_map_.Set(op->loop_var, new_var);
return ForNode::make(new_var, op->min, op->extent, op->for_type,
op->device_api, op->body);
}

Stmt VisitStmt_(const AllocateNode *op) override {
auto ret = StmtExprMutator::VisitStmt_(op);
if (depth_ >= 1) {
return ret;
}
op = ret.as<AllocateNode>();
Var new_var(op->buffer_var->name_hint + suffix_);
var_map_.Set(op->buffer_var, new_var);
return AllocateNode::make(new_var, op->dtype, op->extents,
op->condition, op->body);
}

Stmt VisitStmt_(const LetStmtNode *op) override {
auto ret = StmtExprMutator::VisitStmt_(op);
if (depth_ >= 1) {
return ret;
}
op = ret.as<LetStmtNode>();
Var new_var(op->var->name_hint + suffix_);
var_map_.Set(op->var, new_var);
return LetStmtNode::make(new_var, op->value, op->body);
}

PrimExpr VisitExpr_(const LetNode *op) override {
auto ret = StmtExprMutator::VisitExpr_(op);
if (depth_ >= 1) {
return ret;
}
op = ret.as<LetNode>();
Var new_var(op->var->name_hint + suffix_);
var_map_.Set(op->var, new_var);
return LetNode::make(new_var, op->value, op->body);
}

private:
int depth_ = 0; // how may For nodes we are in
// we only rename the out-most loop, because
// Rename is called iteratively
const std::string &suffix_; // name suffix
Map<Var, PrimExpr> var_map_; // old var -> new var
};

// Remove IfThenElse node from a For node.
// A pair of For nodes will be generated.
std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
Expand All @@ -186,6 +264,7 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array<String>{"IfThenElse"});
if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array<String>{"IfThenElse"});
else_for = Renamer(".else").Rename(std::move(else_for));
}

return std::make_pair(then_for, else_for);
Expand All @@ -198,6 +277,8 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
if (!for_node) return;

std::queue<Stmt> tracker;
std::vector<const Object*> var_def; // don't hoist thread indices out of their
// definition region
tracker.push(for_node->body);
Stmt for_stmt = Downcast<Stmt, ObjectRef>(node);
for2if_map_.insert({for_stmt.get(), std::vector<Stmt>()});
Expand All @@ -206,19 +287,17 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
tracker.pop();
if (head->IsInstance<ForNode>()) {
for (const auto& if_stmt : for2if_map_.at(head.get())) {
for2if_map_[for_stmt.get()].push_back(if_stmt);
if (no_intersect(var_def, cond_var_map_[if_stmt.get()])) {
for2if_map_[for_stmt.get()].push_back(if_stmt);
}
}
} else if (head->IsInstance<AttrStmtNode>()) {
const AttrStmtNode* attr_node = head.as<AttrStmtNode>();
tracker.push(attr_node->body);
} else if (head->IsInstance<IfThenElseNode>()) {
for2if_map_[for_stmt.get()].push_back(head);
const IfThenElseNode* if_node = head.as<IfThenElseNode>();
tracker.push(if_node->then_case);
if (if_node->else_case.defined()) {
tracker.push(if_node->else_case);
} else if (auto attr_node = head.as<AttrStmtNode>()) {
if (attr_node->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(attr_node->node);
var_def.push_back(iv->var.get());
}

tracker.push(attr_node->body);
} else if (auto if_node = head.as<IfThenElseNode>()) {
// Record condition variables.
if (!cond_var_map_.count(head.get())) {
std::unordered_set<const Object*> new_var_set;
Expand All @@ -229,6 +308,14 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
}
});
}

if (no_intersect(var_def, cond_var_map_[head.get()])) {
for2if_map_[for_stmt.get()].push_back(head);
}
tracker.push(if_node->then_case);
if (if_node->else_case.defined()) {
tracker.push(if_node->else_case);
}
} else {
continue;
}
Expand Down Expand Up @@ -292,7 +379,8 @@ void IfThenElseHoist::LocateTopFor() {
} else {
std::vector<Stmt> actual_if_list;
for (const Stmt& if_stmt : if_list) {
if (if_position_map.count(if_stmt.get())) {
if (if_position_map.count(if_stmt.get()) &&
if_position_map.at(if_stmt.get()).as<ForNode>()->loop_var.get() == top_for_var) {
actual_if_list.push_back(if_stmt);
}
}
Expand Down Expand Up @@ -399,6 +487,21 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {

Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); }

namespace transform {

Pass HoistIfThenElse() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = HoistIfThenElse(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElse", {});
}

TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse);

} // namespace transform

TVM_REGISTER_GLOBAL("testing.HoistIfThenElse").set_body_typed(HoistIfThenElse);

} // namespace tir
Expand Down
7 changes: 6 additions & 1 deletion tests/python/unittest/test_te_build_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
import tvm
from tvm import te

def collect_visit(stmt, f):
ret = []
tvm.tir.stmt_functor.post_order_visit(stmt, lambda x : ret.append(f(x)))
return ret

def test_lower_rfactor():
n = te.size_var("n")
m = te.size_var("m")
Expand Down Expand Up @@ -49,7 +54,7 @@ def test_split_uneven_unique_likely():
sch = te.create_schedule(c.op)
xo, xi = sch[c].split(x, 5)
stmt = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse)
assert(any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))))


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def test_basic():
stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)
expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),),
('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')),
('For', 'k.else'): (None,), ('For', 'j.else'): (('For', 'k.else'),),
('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j.else')),
('For', 'i'): (('IfThenElse', ('i',)),)}
verify_structure(new_stmt, expected_struct)

Expand Down Expand Up @@ -114,7 +115,8 @@ def test_attr_stmt():

stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)
expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')),
expected_struct = {('For', 'k'): (None,), ('For', 'k.else'): (None,),
('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k.else')),
('For', 'j'): (('IfThenElse', ('i', 'j')),), ('For', 'i'): (('For', 'j'),),
('AttrStmt', 'thread_extent', 64): (('For', 'i'),),
('AttrStmt', 'thread_extent', 32): (('AttrStmt', 'thread_extent', 64),)}
Expand Down Expand Up @@ -177,10 +179,25 @@ def test_if_block():
('IfThenElse', ('n',)): (('For', 'j'), None)}
verify_structure(new_stmt, expected_struct)

def test_multi_if():
ib = tvm.tir.ir_builder.create()
data = ib.pointer("float32", name="data")

with ib.for_range(0, 10, "i") as i:
with ib.for_range(0, 10, "j") as j:
with ib.for_range(0, 10, "k") as k:
with ib.if_scope(i >= 3):
with ib.if_scope(j >= 3):
data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5

stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)


if __name__ == "__main__":
test_basic()
test_no_else()
test_attr_stmt()
test_nested_for()
test_if_block()
test_multi_if()