From 13bf880065d9bdd248785e777787a1ba29dce4e0 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 7 Apr 2017 17:24:36 +0000 Subject: [PATCH 1/2] [PASS] Support for partition loops with thread_axis --- src/pass/loop_partition.cc | 33 ++++++++++++------- .../unittest/test_pass_loop_partition.py | 19 +++++++++++ 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 34a317824198..9dc7f6bade86 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -39,27 +39,38 @@ bool ExprUseVars(Expr expr, const std::unordered_set& vars) { class PartitionFinder : public IRVisitor { public: - explicit PartitionFinder(VarExpr loop_var, + explicit PartitionFinder(VarExpr current_var, const std::unordered_map& dom_map) - : target_var_(loop_var), out_vars_(dom_map.size()), hint_map_(dom_map) { + : current_var_(current_var), out_vars_(dom_map.size()), hint_map_(dom_map) { for (const auto& kv : dom_map) out_vars_.insert(kv.first); } void Visit_(const For* op) { if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; - hint_map_.insert({op->loop_var.get(), - IntSet::interval(op->min, op->min + op->extent - 1)}); - relax_map_.insert({op->loop_var.get(), - IntSet::interval(op->min, op->min + op->extent - 1)}); + const Variable* var = op->loop_var.get(); + hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)}); + relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)}); IRVisitor::Visit_(op); - relax_map_.erase(op->loop_var.get()); - hint_map_.erase(op->loop_var.get()); + relax_map_.erase(var); + hint_map_.erase(var); + } + + void Visit_(const AttrStmt* op) { + // handle thread_axis + if (const IterVarNode* thread_axis = op->node.as()) { + const Variable* var = thread_axis->var.get(); + hint_map_.insert({var, IntSet::range(thread_axis->dom)}); + relax_map_.insert({var, IntSet::range(thread_axis->dom)}); + IRVisitor::Visit_(op); + relax_map_.erase(var); + hint_map_.erase(var); + } } void Visit_(const IfThenElse* op) { - if (ExprUseVars(op->condition, std::unordered_set({target_var_.get()}))) { - IntSet interval = DeduceBound(target_var_, op->condition, hint_map_, relax_map_); + if (ExprUseVars(op->condition, std::unordered_set({current_var_.get()}))) { + IntSet interval = DeduceBound(current_var_, op->condition, hint_map_, relax_map_); partitions[op->condition.get()] = Partition{op->condition, interval}; } else { IRVisitor::Visit_(op); @@ -69,7 +80,7 @@ class PartitionFinder : public IRVisitor { std::unordered_map partitions; private: - VarExpr target_var_; + VarExpr current_var_; std::unordered_set out_vars_; std::unordered_map hint_map_; std::unordered_map relax_map_; diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index fd0662c8d906..177fb30fecbf 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -53,8 +53,27 @@ def test_multi_if(): assert('if' not in str(stmt.body.first)) print(stmt) +def test_thread_axis(): + m = tvm.Var('m') + l = tvm.Var('l') + A = tvm.placeholder((m, l), name='A') + B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B') + + s = tvm.Schedule(B.op) + + s[B].set_scope("shared") + thread_x = tvm.thread_axis((0, 16), "threadIdx.x") + xo, xi = s[B].split(B.op.axis[0], 32) + xi0, xi1 = s[B].split(xi, outer=thread_x) + + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + stmt_ = tvm.ir_pass.LoopPartition(stmt) + assert('if' not in str(stmt_.body.body.body.first)) + print(stmt_) if __name__ == "__main__": test_basic() test_multi_loop() test_multi_if() + test_thread_axis() From 3456ea7afe3d38d6c3bfe37cc6c7af58540f4116 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Sun, 9 Apr 2017 19:25:29 +0000 Subject: [PATCH 2/2] Add check for AttrStmt.attr_key --- src/pass/loop_partition.cc | 11 ++++++++--- tests/python/unittest/test_pass_loop_partition.py | 5 +++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 9dc7f6bade86..3a8f30e7d46b 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -58,13 +58,18 @@ class PartitionFinder : public IRVisitor { void Visit_(const AttrStmt* op) { // handle thread_axis - if (const IterVarNode* thread_axis = op->node.as()) { + if (op->attr_key == attr::thread_extent) { + const IterVarNode* thread_axis = op->node.as(); + CHECK(thread_axis); const Variable* var = thread_axis->var.get(); - hint_map_.insert({var, IntSet::range(thread_axis->dom)}); - relax_map_.insert({var, IntSet::range(thread_axis->dom)}); + IntSet dom = IntSet::range(Range(make_zero(op->value.type()), op->value)); + hint_map_.insert({var, dom}); + relax_map_.insert({var, dom}); IRVisitor::Visit_(op); relax_map_.erase(var); hint_map_.erase(var); + } else { + IRVisitor::Visit_(op); } } diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index 177fb30fecbf..9a3c6bbdd82c 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -62,9 +62,10 @@ def test_thread_axis(): s = tvm.Schedule(B.op) s[B].set_scope("shared") - thread_x = tvm.thread_axis((0, 16), "threadIdx.x") + num_thread = 16 xo, xi = s[B].split(B.op.axis[0], 32) - xi0, xi1 = s[B].split(xi, outer=thread_x) + xi0, xi1 = s[B].split(xi, nparts=num_thread) + s[B].bind(xi0, tvm.thread_axis("threadIdx.x")) bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds)