Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,43 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {

class PartitionFinder : public IRVisitor {
public:
explicit PartitionFinder(VarExpr loop_var,
explicit PartitionFinder(VarExpr current_var,
const std::unordered_map<const Variable*, IntSet>& dom_map)
: target_var_(loop_var), out_vars_(dom_map.size()), hint_map_(dom_map) {
: current_var_(current_var), out_vars_(dom_map.size()), hint_map_(dom_map) {
for (const auto& kv : dom_map) out_vars_.insert(kv.first);
}

void Visit_(const For* op) {
if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;

hint_map_.insert({op->loop_var.get(),
IntSet::interval(op->min, op->min + op->extent - 1)});
relax_map_.insert({op->loop_var.get(),
IntSet::interval(op->min, op->min + op->extent - 1)});
const Variable* var = op->loop_var.get();
hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
IRVisitor::Visit_(op);
relax_map_.erase(op->loop_var.get());
hint_map_.erase(op->loop_var.get());
relax_map_.erase(var);
hint_map_.erase(var);
}

void Visit_(const AttrStmt* op) {
// handle thread_axis
if (op->attr_key == attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis);
const Variable* var = thread_axis->var.get();
IntSet dom = IntSet::range(Range(make_zero(op->value.type()), op->value));
hint_map_.insert({var, dom});
relax_map_.insert({var, dom});
IRVisitor::Visit_(op);
relax_map_.erase(var);
hint_map_.erase(var);
} else {
IRVisitor::Visit_(op);
}
}

void Visit_(const IfThenElse* op) {
if (ExprUseVars(op->condition, std::unordered_set<const Variable*>({target_var_.get()}))) {
IntSet interval = DeduceBound(target_var_, op->condition, hint_map_, relax_map_);
if (ExprUseVars(op->condition, std::unordered_set<const Variable*>({current_var_.get()}))) {
IntSet interval = DeduceBound(current_var_, op->condition, hint_map_, relax_map_);
partitions[op->condition.get()] = Partition{op->condition, interval};
} else {
IRVisitor::Visit_(op);
Expand All @@ -69,7 +85,7 @@ class PartitionFinder : public IRVisitor {
std::unordered_map<const Node*, Partition> partitions;

private:
VarExpr target_var_;
VarExpr current_var_;
std::unordered_set<const Variable*> out_vars_;
std::unordered_map<const Variable*, IntSet> hint_map_;
std::unordered_map<const Variable*, IntSet> relax_map_;
Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_pass_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,28 @@ def test_multi_if():
assert('if' not in str(stmt.body.first))
print(stmt)

def test_thread_axis():
    """Run LoopPartition on a schedule with an axis bound to threadIdx.x.

    Asserts that the text of ``body.body.body.first`` of the partitioned
    statement contains no 'if', then prints the partitioned statement.
    """
    rows = tvm.Var('m')
    cols = tvm.Var('l')
    A = tvm.placeholder((rows, cols), name='A')
    B = tvm.compute((rows, cols), lambda i, j: A[i, j] + 3, name='B')

    sched = tvm.Schedule(B.op)

    sched[B].set_scope("shared")
    num_thread = 16
    # Split the first axis by a factor of 32, then split the inner part into
    # num_thread parts and bind the outer of those to the thread axis.
    outer, inner = sched[B].split(B.op.axis[0], 32)
    inner_outer, inner_inner = sched[B].split(inner, nparts=num_thread)
    sched[B].bind(inner_outer, tvm.thread_axis("threadIdx.x"))

    bounds = tvm.schedule.InferBound(sched)
    lowered = tvm.schedule.ScheduleOps(sched, bounds)
    partitioned = tvm.ir_pass.LoopPartition(lowered)
    assert('if' not in str(partitioned.body.body.body.first))
    print(partitioned)

if __name__ == "__main__":
    # Run the tests one after another when executed as a script.
    for run_test in (test_basic, test_multi_loop, test_multi_if, test_thread_axis):
        run_test()