From e3861760476eadf896acd0a2952e4243131278a9 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Mon, 25 Dec 2017 23:21:28 +0800 Subject: [PATCH 01/18] when there is no intrin func, using body for initialization. For issue 714. --- src/op/op_util.cc | 26 +++++++++++++++++++++++ src/op/op_util.h | 14 +++++++++++++ src/op/tensorize.cc | 50 ++++++++++++++++++++++++++++++--------------- 3 files changed, 73 insertions(+), 17 deletions(-) diff --git a/src/op/op_util.cc b/src/op/op_util.cc index 78e092ca844e..79544cc89c97 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -208,5 +208,31 @@ Stmt Substitute(Stmt s, return ir::Substitute(s, init); } +Stmt TransformUpdate(const Stage& stage, + const std::unordered_map& dom_map, + Stmt body, + Stmt update) +{ + Expr condition = (make_zero(Int(32)) == make_zero(Int(32))); //will try Bool(1) + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + IterVar iv = stage->leaf_iter_vars[i]; + auto iit = stage->iter_var_attrs.find(iv); + if (iit != stage->iter_var_attrs.end()) { + const IterVarAttr& attr = (*iit).second; + if (attr->iter_type == kTensorized) { + break; + } + } + if (iv->iter_type == kCommReduce) { + auto vit = dom_map.find(iv); + CHECK(vit != dom_map.end()); + const Range& vrange = vit->second; + Expr newcond = ( iv->var == vrange->min); + condition = condition && newcond; + } + } + return IfThenElse::make(condition, body, update); +} + } // namespace op } // namespace tvm diff --git a/src/op/op_util.h b/src/op/op_util.h index 783fbb989422..fccc9a98de30 100644 --- a/src/op/op_util.h +++ b/src/op/op_util.h @@ -70,6 +70,20 @@ Expr ReplaceTensor(Expr expr, Stmt Substitute(Stmt stmt, const std::unordered_map& value_map); + +/*! + * \brief Transform the update part when there is no init func in tensorizing + * \param stage The stage for tensorizing. + * \param dom_map The range of each iter var. 
+ * \param body The body func in tensorize intrin + * \param update The update func in tensorize intrin + * \return Transformed result. + */ +Stmt TransformUpdate(const Stage& stage, + const std::unordered_map& dom_map, + Stmt body, + Stmt update); + } // namespace op } // namespace tvm #endif // TVM_OP_OP_UTIL_H_ diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index b4527f76e808..cb72bdd86ddf 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -416,32 +416,48 @@ Stmt MakeTensorize(const ComputeOpNode* self, return MergeNest(nest, body); } else { // Need to split reduction - CHECK(intrin->reduce_init.defined()) - << "Reduction init op for intrin " << intrin << " is not defined"; + // Comment out the following check for the case when there is no init func + //CHECK(intrin->reduce_init.defined()) + // << "Reduction init op for intrin " << intrin << " is not defined"; CHECK(intrin->reduce_update.defined()) << "Reduction update op for intrin " << intrin << " is not defined"; // Need init and update steps CHECK_NE(self->reduce_axis.size(), 0U); std::vector > common( n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); - // init nest - std::vector > init_nest( - n.init_nest.begin(), n.init_nest.begin() + tloc + 1); - init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); - Stmt init = MergeNest(output_bind_nest, intrin->reduce_init); - init = Substitute(init, n.init_vmap); - init = MergeNest(init_nest, init); - // The update std::vector > update_nest( n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); update_nest.emplace_back(op::MakeIfNest(n.main_predicates)); - Stmt update = MergeNest(output_bind_nest, intrin->reduce_update); - update = MergeNest(input_bind_nest, update); - update = Substitute(update, vmap); - update = MergeNest(binder.asserts(), update); - update = Substitute(update, n.main_vmap); - update = MergeNest(update_nest, update); - return MergeNest(common, Block::make(init, update)); + + if 
(intrin->reduce_init.defined()) { + // init nest + std::vector > init_nest( + n.init_nest.begin(), n.init_nest.begin() + tloc + 1); + init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); + Stmt init = MergeNest(output_bind_nest, intrin->reduce_init); + init = Substitute(init, n.init_vmap); + init = MergeNest(init_nest, init); + // The update + Stmt update = MergeNest(output_bind_nest, intrin->reduce_update); + update = MergeNest(input_bind_nest, update); + update = Substitute(update, vmap); + update = MergeNest(binder.asserts(), update); + update = Substitute(update, n.main_vmap); + update = MergeNest(update_nest, update); + return MergeNest(common, Block::make(init, update)); + } else { + // The update + Stmt update = TransformUpdate (stage, dom_map, + intrin->body, + intrin->reduce_update); + update = MergeNest(output_bind_nest, update); + update = MergeNest(input_bind_nest, update); + update = Substitute(update, vmap); + update = MergeNest(binder.asserts(), update); + update = Substitute(update, n.main_vmap); + update = MergeNest(update_nest, update); + return MergeNest(common, update); + } } } From 784a0fdaab3e541453650f6d39c5081ab5cb6416 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Wed, 27 Dec 2017 00:15:34 +0800 Subject: [PATCH 02/18] Refine code per review comments, and add a test case. 
--- src/op/op_util.cc | 26 ------ src/op/op_util.h | 14 --- src/op/tensorize.cc | 56 ++++++++++-- .../test_schedule_tensorize_init_none.py | 90 +++++++++++++++++++ 4 files changed, 141 insertions(+), 45 deletions(-) create mode 100644 tests/python/unittest/test_schedule_tensorize_init_none.py diff --git a/src/op/op_util.cc b/src/op/op_util.cc index 79544cc89c97..78e092ca844e 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -208,31 +208,5 @@ Stmt Substitute(Stmt s, return ir::Substitute(s, init); } -Stmt TransformUpdate(const Stage& stage, - const std::unordered_map& dom_map, - Stmt body, - Stmt update) -{ - Expr condition = (make_zero(Int(32)) == make_zero(Int(32))); //will try Bool(1) - for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { - IterVar iv = stage->leaf_iter_vars[i]; - auto iit = stage->iter_var_attrs.find(iv); - if (iit != stage->iter_var_attrs.end()) { - const IterVarAttr& attr = (*iit).second; - if (attr->iter_type == kTensorized) { - break; - } - } - if (iv->iter_type == kCommReduce) { - auto vit = dom_map.find(iv); - CHECK(vit != dom_map.end()); - const Range& vrange = vit->second; - Expr newcond = ( iv->var == vrange->min); - condition = condition && newcond; - } - } - return IfThenElse::make(condition, body, update); -} - } // namespace op } // namespace tvm diff --git a/src/op/op_util.h b/src/op/op_util.h index fccc9a98de30..783fbb989422 100644 --- a/src/op/op_util.h +++ b/src/op/op_util.h @@ -70,20 +70,6 @@ Expr ReplaceTensor(Expr expr, Stmt Substitute(Stmt stmt, const std::unordered_map& value_map); - -/*! - * \brief Transform the update part when there is no init func in tensorizing - * \param stage The stage for tensorizing. - * \param dom_map The range of each iter var. - * \param body The body func in tensorize intrin - * \param update The update func in tensorize intrin - * \return Transformed result. 
- */ -Stmt TransformUpdate(const Stage& stage, - const std::unordered_map& dom_map, - Stmt body, - Stmt update); - } // namespace op } // namespace tvm #endif // TVM_OP_OP_UTIL_H_ diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index cb72bdd86ddf..9bf8b4979740 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -10,6 +10,7 @@ #include "./op_util.h" #include "./compute_op.h" #include "../schedule/message_passing.h" +#include "../arithmetic/compute_expr.h" namespace tvm { @@ -322,6 +323,52 @@ void VerifyTensorizeBody( } } +/*! + * \brief Transform the update part when there is no init func in tensorizing + * \param stage The stage for tensorizing. + * \param dom_map The range of each iter var. + * \param n The loop nest structured used in compute. + * \param body The body func in tensorize intrin + * \param update The update func in tensorize intrin + * \return Transformed result. + */ +Stmt TransformUpdate(const Stage& stage, + const std::unordered_map& dom_map, + const ComputeLoopNest& n, + Stmt body, + Stmt update) +{ + Array conds; + std::unordered_set banned; + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + IterVar iv = stage->leaf_iter_vars[i]; + auto iit = stage->iter_var_attrs.find(iv); + if (iit != stage->iter_var_attrs.end()) { + const IterVarAttr& attr = (*iit).second; + if (attr->iter_type == kTensorized) { + break; + } + } + if (iv->iter_type == kCommReduce) { + auto vit = dom_map.find(iv); + CHECK(vit != dom_map.end()); + const Range& vrange = vit->second; + conds.push_back(likely(iv->var > vrange->min)); + banned.insert(iv->var.get()); + } + } + + for (const Expr& pred : n.main_predicates) { + if (ir::ExprUseVar(pred, banned)) { + LOG(FATAL) << "Tensorize update transform failed, the condition " + << pred << " has a conflict with the reset condition"; + } + } + + return IfThenElse::make(arith::ComputeReduce(conds, const_true(1)), + update, body); +} + Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, 
const std::unordered_map& dom_map) { @@ -416,9 +463,6 @@ Stmt MakeTensorize(const ComputeOpNode* self, return MergeNest(nest, body); } else { // Need to split reduction - // Comment out the following check for the case when there is no init func - //CHECK(intrin->reduce_init.defined()) - // << "Reduction init op for intrin " << intrin << " is not defined"; CHECK(intrin->reduce_update.defined()) << "Reduction update op for intrin " << intrin << " is not defined"; // Need init and update steps @@ -446,8 +490,10 @@ Stmt MakeTensorize(const ComputeOpNode* self, update = MergeNest(update_nest, update); return MergeNest(common, Block::make(init, update)); } else { - // The update - Stmt update = TransformUpdate (stage, dom_map, + // When init op is not available, use body op for reset in the first iter. + CHECK(intrin->body.defined()) + << "Normal body op for intrin " << intrin << " is not defined"; + Stmt update = TransformUpdate (stage, dom_map, n, intrin->body, intrin->reduce_update); update = MergeNest(output_bind_nest, update); diff --git a/tests/python/unittest/test_schedule_tensorize_init_none.py b/tests/python/unittest/test_schedule_tensorize_init_none.py new file mode 100644 index 000000000000..ce1d5633173a --- /dev/null +++ b/tests/python/unittest/test_schedule_tensorize_init_none.py @@ -0,0 +1,90 @@ +import tvm + +def intrin_gemv(m, n): + w = tvm.placeholder((m, n), name='w') + x = tvm.placeholder((n,), name='x') + k = tvm.reduce_axis((0, n), name='k') + z = tvm.compute((m,), lambda i: + tvm.sum(w[i, k] * x[k], axis=k), name='z') + Wb = tvm.decl_buffer(w.shape, w.dtype, + name="W", + offset_factor=16, + strides=[tvm.var('ldw'), 1]) + def intrin_func(ins, outs): + ww, xx = ins + zz = outs[0] + ww_ptr = ww.access_ptr("r") + xx_ptr = xx.access_ptr("r") + zz_ptr = zz.access_ptr("w") + body = tvm.call_packed( + "gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) + update = tvm.call_packed( + "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) + return body, None, 
update + + with tvm.build_config(data_alignment=16, + offset_factor=16): + return tvm.decl_tensor_intrin(z.op, intrin_func, + binds={w: Wb}) + + +def test_tensorize_matmul(): + n = 1024 + m = n + l = n + A = tvm.placeholder((n, l), name='A') + B = tvm.placeholder((m, l), name='B') + k = tvm.reduce_axis((0, l), name='k') + C = tvm.compute((n, m), lambda i, j: + tvm.sum(B[j, k] * A[i, k], axis=k), name='C') + + def check(factor): + s = tvm.create_schedule(C.op) + x, y = C.op.axis + yo, yi = s[C].split(y, factor=factor) + gemv = intrin_gemv(factor, l) + s[C].tensorize(yi, gemv) + s = s.normalize() + dom_map = tvm.schedule.InferBound(s) + finfer = tvm.get_global_func("test.op.InferTensorizeRegion") + out_dom, in_dom = finfer(s[C], dom_map) + assert tvm.ir_pass.Equal(out_dom[x].extent, 1) + assert tvm.ir_pass.Equal(out_dom[y].extent, factor) + assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor) + fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") + body = fmatch(s[C], out_dom, in_dom, gemv) + assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]), + tvm.ir_pass.CanonicalSimplify(gemv.op.body[0])) + stmt = tvm.schedule.ScheduleOps(s, dom_map) + tvm.lower(s, [A, B, C]) + + + def check_rfactor(factor, rfactor): + s = tvm.create_schedule(C.op) + x, y = C.op.axis + rk = C.op.reduce_axis[0] + yo, yi = s[C].split(y, factor=factor) + ro, ri = s[C].split(rk, factor=rfactor) + s[C].reorder(yo, ro, yi, ri) + gemv = intrin_gemv(factor, rfactor) + s[C].tensorize(yi, gemv) + s = s.normalize() + dom_map = tvm.schedule.InferBound(s) + finfer = tvm.get_global_func("test.op.InferTensorizeRegion") + out_dom, in_dom = finfer(s[C], dom_map) + assert tvm.ir_pass.Equal(out_dom[x].extent, 1) + assert tvm.ir_pass.Equal(out_dom[y].extent, factor) + assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor) + fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") + body = fmatch(s[C], out_dom, in_dom, gemv) + assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]), + 
tvm.ir_pass.CanonicalSimplify(gemv.op.body[0])) + stmt = tvm.schedule.ScheduleOps(s, dom_map) + tvm.lower(s, [A, B, C]) + + check(16) + check_rfactor(16, 16) + + +if __name__ == "__main__": + test_tensorize_matmul() From ab49e9ffc91a355a54220149a3223630ca00b041 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Wed, 27 Dec 2017 13:44:41 +0800 Subject: [PATCH 03/18] Fix lint issues. --- src/op/tensorize.cc | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index 9bf8b4979740..6fa5459829fc 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -336,11 +336,10 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map& dom_map, const ComputeLoopNest& n, Stmt body, - Stmt update) -{ - Array conds; - std::unordered_set banned; - for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + Stmt update) { + Array conds; + std::unordered_set banned; + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { IterVar iv = stage->leaf_iter_vars[i]; auto iit = stage->iter_var_attrs.find(iv); if (iit != stage->iter_var_attrs.end()) { @@ -357,7 +356,6 @@ Stmt TransformUpdate(const Stage& stage, banned.insert(iv->var.get()); } } - for (const Expr& pred : n.main_predicates) { if (ir::ExprUseVar(pred, banned)) { LOG(FATAL) << "Tensorize update transform failed, the condition " @@ -493,16 +491,16 @@ Stmt MakeTensorize(const ComputeOpNode* self, // When init op is not available, use body op for reset in the first iter. 
CHECK(intrin->body.defined()) << "Normal body op for intrin " << intrin << " is not defined"; - Stmt update = TransformUpdate (stage, dom_map, n, - intrin->body, - intrin->reduce_update); + Stmt update = TransformUpdate(stage, dom_map, n, + intrin->body, + intrin->reduce_update); update = MergeNest(output_bind_nest, update); update = MergeNest(input_bind_nest, update); update = Substitute(update, vmap); update = MergeNest(binder.asserts(), update); update = Substitute(update, n.main_vmap); update = MergeNest(update_nest, update); - return MergeNest(common, update); + return MergeNest(common, update); } } } From dd7aefd7b339c51089403e70f39d1bffefedc8c4 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Wed, 27 Dec 2017 23:23:21 +0800 Subject: [PATCH 04/18] Re-organize the tensorize test cases, and add a new case for none-reset mode. --- .../unittest/test_schedule_tensorize.py | 76 ++++++++++++++++ .../test_schedule_tensorize_init_none.py | 90 ------------------- 2 files changed, 76 insertions(+), 90 deletions(-) delete mode 100644 tests/python/unittest/test_schedule_tensorize_init_none.py diff --git a/tests/python/unittest/test_schedule_tensorize.py b/tests/python/unittest/test_schedule_tensorize.py index 71ae493e51ae..47b135e3c7c5 100644 --- a/tests/python/unittest/test_schedule_tensorize.py +++ b/tests/python/unittest/test_schedule_tensorize.py @@ -40,6 +40,33 @@ def intrin_func(ins, outs): return tvm.decl_tensor_intrin(z.op, intrin_func, binds={w: Wb}) +def intrin_gemv_no_reset(m, n): + w = tvm.placeholder((m, n), name='w') + x = tvm.placeholder((n,), name='x') + k = tvm.reduce_axis((0, n), name='k') + z = tvm.compute((m,), lambda i: + tvm.sum(w[i, k] * x[k], axis=k), name='z') + Wb = tvm.decl_buffer(w.shape, w.dtype, + name="W", + offset_factor=16, + strides=[tvm.var('ldw'), 1]) + def intrin_func(ins, outs): + ww, xx = ins + zz = outs[0] + ww_ptr = ww.access_ptr("r") + xx_ptr = xx.access_ptr("r") + zz_ptr = zz.access_ptr("w") + body = tvm.call_packed( + "gemv", ww_ptr, 
xx_ptr, zz_ptr, n, ww.strides[0]) + update = tvm.call_packed( + "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) + return body, None, update + + with tvm.build_config(data_alignment=16, + offset_factor=16): + return tvm.decl_tensor_intrin(z.op, intrin_func, + binds={w: Wb}) + def test_tensorize_vadd(): m = 128 @@ -123,8 +150,57 @@ def check_rfactor(factor, rfactor): stmt = tvm.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) + def check_rfactor_no_reset(factor, rfactor): + s = tvm.create_schedule(C.op) + x, y = C.op.axis + rk = C.op.reduce_axis[0] + yo, yi = s[C].split(y, factor=factor) + ro, ri = s[C].split(rk, factor=rfactor) + s[C].reorder(yo, ro, yi, ri) + gemv = intrin_gemv_no_reset(factor, rfactor) + s[C].tensorize(yi, gemv) + s = s.normalize() + dom_map = tvm.schedule.InferBound(s) + finfer = tvm.get_global_func("test.op.InferTensorizeRegion") + out_dom, in_dom = finfer(s[C], dom_map) + assert tvm.ir_pass.Equal(out_dom[x].extent, 1) + assert tvm.ir_pass.Equal(out_dom[y].extent, factor) + assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor) + fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") + body = fmatch(s[C], out_dom, in_dom, gemv) + assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]), + tvm.ir_pass.CanonicalSimplify(gemv.op.body[0])) + stmt = tvm.schedule.ScheduleOps(s, dom_map) + tvm.lower(s, [A, B, C]) + + def check_rfactor_no_reset_multi_reduction(factor, rfactor): + s = tvm.create_schedule(C.op) + x, y = C.op.axis + rk = C.op.reduce_axis[0] + yo, yi = s[C].split(y, factor=factor) + ro, ri = s[C].split(rk, factor=rfactor) + roo, roi = s[C].split(ro, factor=2) + s[C].reorder(yo, roo, roi, yi, ri) + gemv = intrin_gemv_no_reset(factor, rfactor) + s[C].tensorize(yi, gemv) + s = s.normalize() + dom_map = tvm.schedule.InferBound(s) + finfer = tvm.get_global_func("test.op.InferTensorizeRegion") + out_dom, in_dom = finfer(s[C], dom_map) + assert tvm.ir_pass.Equal(out_dom[x].extent, 1) + assert 
tvm.ir_pass.Equal(out_dom[y].extent, factor) + assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor) + fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") + body = fmatch(s[C], out_dom, in_dom, gemv) + assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]), + tvm.ir_pass.CanonicalSimplify(gemv.op.body[0])) + stmt = tvm.schedule.ScheduleOps(s, dom_map) + tvm.lower(s, [A, B, C]) + check(16) check_rfactor(16, 16) + check_rfactor_no_reset(16, 16) + check_rfactor_no_reset_multi_reduction(16, 16) # This tests whether algorithm and intrinsics expressions are simplified # as much as possible first and then checked for equality. See Issue #696 diff --git a/tests/python/unittest/test_schedule_tensorize_init_none.py b/tests/python/unittest/test_schedule_tensorize_init_none.py deleted file mode 100644 index ce1d5633173a..000000000000 --- a/tests/python/unittest/test_schedule_tensorize_init_none.py +++ /dev/null @@ -1,90 +0,0 @@ -import tvm - -def intrin_gemv(m, n): - w = tvm.placeholder((m, n), name='w') - x = tvm.placeholder((n,), name='x') - k = tvm.reduce_axis((0, n), name='k') - z = tvm.compute((m,), lambda i: - tvm.sum(w[i, k] * x[k], axis=k), name='z') - Wb = tvm.decl_buffer(w.shape, w.dtype, - name="W", - offset_factor=16, - strides=[tvm.var('ldw'), 1]) - def intrin_func(ins, outs): - ww, xx = ins - zz = outs[0] - ww_ptr = ww.access_ptr("r") - xx_ptr = xx.access_ptr("r") - zz_ptr = zz.access_ptr("w") - body = tvm.call_packed( - "gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) - update = tvm.call_packed( - "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) - return body, None, update - - with tvm.build_config(data_alignment=16, - offset_factor=16): - return tvm.decl_tensor_intrin(z.op, intrin_func, - binds={w: Wb}) - - -def test_tensorize_matmul(): - n = 1024 - m = n - l = n - A = tvm.placeholder((n, l), name='A') - B = tvm.placeholder((m, l), name='B') - k = tvm.reduce_axis((0, l), name='k') - C = tvm.compute((n, m), lambda i, j: - tvm.sum(B[j, k] * 
A[i, k], axis=k), name='C') - - def check(factor): - s = tvm.create_schedule(C.op) - x, y = C.op.axis - yo, yi = s[C].split(y, factor=factor) - gemv = intrin_gemv(factor, l) - s[C].tensorize(yi, gemv) - s = s.normalize() - dom_map = tvm.schedule.InferBound(s) - finfer = tvm.get_global_func("test.op.InferTensorizeRegion") - out_dom, in_dom = finfer(s[C], dom_map) - assert tvm.ir_pass.Equal(out_dom[x].extent, 1) - assert tvm.ir_pass.Equal(out_dom[y].extent, factor) - assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor) - fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") - body = fmatch(s[C], out_dom, in_dom, gemv) - assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]), - tvm.ir_pass.CanonicalSimplify(gemv.op.body[0])) - stmt = tvm.schedule.ScheduleOps(s, dom_map) - tvm.lower(s, [A, B, C]) - - - def check_rfactor(factor, rfactor): - s = tvm.create_schedule(C.op) - x, y = C.op.axis - rk = C.op.reduce_axis[0] - yo, yi = s[C].split(y, factor=factor) - ro, ri = s[C].split(rk, factor=rfactor) - s[C].reorder(yo, ro, yi, ri) - gemv = intrin_gemv(factor, rfactor) - s[C].tensorize(yi, gemv) - s = s.normalize() - dom_map = tvm.schedule.InferBound(s) - finfer = tvm.get_global_func("test.op.InferTensorizeRegion") - out_dom, in_dom = finfer(s[C], dom_map) - assert tvm.ir_pass.Equal(out_dom[x].extent, 1) - assert tvm.ir_pass.Equal(out_dom[y].extent, factor) - assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor) - fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") - body = fmatch(s[C], out_dom, in_dom, gemv) - assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]), - tvm.ir_pass.CanonicalSimplify(gemv.op.body[0])) - stmt = tvm.schedule.ScheduleOps(s, dom_map) - tvm.lower(s, [A, B, C]) - - check(16) - check_rfactor(16, 16) - - -if __name__ == "__main__": - test_tensorize_matmul() From e17ae7fc41b08ee13d376699fe4c7ca20c214a4e Mon Sep 17 00:00:00 2001 From: kun-zh Date: Thu, 28 Dec 2017 21:53:54 +0800 Subject: [PATCH 05/18] Fix a typo. 
--- tests/python/unittest/test_schedule_tensorize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_schedule_tensorize.py b/tests/python/unittest/test_schedule_tensorize.py index 47b135e3c7c5..ca5836143ef3 100644 --- a/tests/python/unittest/test_schedule_tensorize.py +++ b/tests/python/unittest/test_schedule_tensorize.py @@ -173,7 +173,7 @@ def check_rfactor_no_reset(factor, rfactor): stmt = tvm.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) - def check_rfactor_no_reset_multi_reduction(factor, rfactor): + def check_rfactor_no_reset_multi_reduction(factor, rfactor): s = tvm.create_schedule(C.op) x, y = C.op.axis rk = C.op.reduce_axis[0] From e59bdddce4cd8394bfe5a5712a80c31e4b612bc5 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Thu, 28 Dec 2017 22:41:11 +0800 Subject: [PATCH 06/18] Delete the unit case because merged it into test_schedule_tensorize.py already. --- .../test_schedule_tensorize_init_none.py | 90 ------------------- 1 file changed, 90 deletions(-) delete mode 100644 tests/python/unittest/test_schedule_tensorize_init_none.py diff --git a/tests/python/unittest/test_schedule_tensorize_init_none.py b/tests/python/unittest/test_schedule_tensorize_init_none.py deleted file mode 100644 index ce1d5633173a..000000000000 --- a/tests/python/unittest/test_schedule_tensorize_init_none.py +++ /dev/null @@ -1,90 +0,0 @@ -import tvm - -def intrin_gemv(m, n): - w = tvm.placeholder((m, n), name='w') - x = tvm.placeholder((n,), name='x') - k = tvm.reduce_axis((0, n), name='k') - z = tvm.compute((m,), lambda i: - tvm.sum(w[i, k] * x[k], axis=k), name='z') - Wb = tvm.decl_buffer(w.shape, w.dtype, - name="W", - offset_factor=16, - strides=[tvm.var('ldw'), 1]) - def intrin_func(ins, outs): - ww, xx = ins - zz = outs[0] - ww_ptr = ww.access_ptr("r") - xx_ptr = xx.access_ptr("r") - zz_ptr = zz.access_ptr("w") - body = tvm.call_packed( - "gemv", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) - update = tvm.call_packed( - 
"gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) - return body, None, update - - with tvm.build_config(data_alignment=16, - offset_factor=16): - return tvm.decl_tensor_intrin(z.op, intrin_func, - binds={w: Wb}) - - -def test_tensorize_matmul(): - n = 1024 - m = n - l = n - A = tvm.placeholder((n, l), name='A') - B = tvm.placeholder((m, l), name='B') - k = tvm.reduce_axis((0, l), name='k') - C = tvm.compute((n, m), lambda i, j: - tvm.sum(B[j, k] * A[i, k], axis=k), name='C') - - def check(factor): - s = tvm.create_schedule(C.op) - x, y = C.op.axis - yo, yi = s[C].split(y, factor=factor) - gemv = intrin_gemv(factor, l) - s[C].tensorize(yi, gemv) - s = s.normalize() - dom_map = tvm.schedule.InferBound(s) - finfer = tvm.get_global_func("test.op.InferTensorizeRegion") - out_dom, in_dom = finfer(s[C], dom_map) - assert tvm.ir_pass.Equal(out_dom[x].extent, 1) - assert tvm.ir_pass.Equal(out_dom[y].extent, factor) - assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor) - fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") - body = fmatch(s[C], out_dom, in_dom, gemv) - assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]), - tvm.ir_pass.CanonicalSimplify(gemv.op.body[0])) - stmt = tvm.schedule.ScheduleOps(s, dom_map) - tvm.lower(s, [A, B, C]) - - - def check_rfactor(factor, rfactor): - s = tvm.create_schedule(C.op) - x, y = C.op.axis - rk = C.op.reduce_axis[0] - yo, yi = s[C].split(y, factor=factor) - ro, ri = s[C].split(rk, factor=rfactor) - s[C].reorder(yo, ro, yi, ri) - gemv = intrin_gemv(factor, rfactor) - s[C].tensorize(yi, gemv) - s = s.normalize() - dom_map = tvm.schedule.InferBound(s) - finfer = tvm.get_global_func("test.op.InferTensorizeRegion") - out_dom, in_dom = finfer(s[C], dom_map) - assert tvm.ir_pass.Equal(out_dom[x].extent, 1) - assert tvm.ir_pass.Equal(out_dom[y].extent, factor) - assert tvm.ir_pass.Equal(out_dom[y].min, yo * factor) - fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") - body = fmatch(s[C], out_dom, in_dom, 
gemv) - assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]), - tvm.ir_pass.CanonicalSimplify(gemv.op.body[0])) - stmt = tvm.schedule.ScheduleOps(s, dom_map) - tvm.lower(s, [A, B, C]) - - check(16) - check_rfactor(16, 16) - - -if __name__ == "__main__": - test_tensorize_matmul() From 42adaf05c464587c1f82665d3c9aa8882f2933d6 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Wed, 3 Jan 2018 00:27:50 +0800 Subject: [PATCH 07/18] always use new tensor in its stage when rewrite for cache read --- src/schedule/schedule_dataflow_rewrite.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index d1a69ecf0203..1f2357ae10bc 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -86,7 +86,8 @@ Tensor Schedule::cache_read(const Tensor& tensor, return tensor(Array(i.begin(), i.end())); }, os.str()); std::unordered_map vsub; - vsub[tensor] = cache; + Stage tensor_stage = operator[](tensor->op); + vsub[tensor_stage->op.output(0)] = cache; std::unordered_map vmap; for (Operation op : readers) { From edb22fac4cdb8ff3a72117b59b00043ab534858e Mon Sep 17 00:00:00 2001 From: kun-zh Date: Fri, 5 Jan 2018 22:58:47 +0800 Subject: [PATCH 08/18] revert previous changes to sync up with master --- src/schedule/schedule_dataflow_rewrite.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 1f2357ae10bc..d1a69ecf0203 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -86,8 +86,7 @@ Tensor Schedule::cache_read(const Tensor& tensor, return tensor(Array(i.begin(), i.end())); }, os.str()); std::unordered_map vsub; - Stage tensor_stage = operator[](tensor->op); - vsub[tensor_stage->op.output(0)] = cache; + vsub[tensor] = cache; std::unordered_map vmap; for (Operation op : readers) 
{ From 1a860fea4be3b6e01d97d0137db0ca301f039b68 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Sat, 27 Jan 2018 00:04:56 +0800 Subject: [PATCH 09/18] support using the ptr with an original offset --- include/tvm/buffer.h | 3 ++- python/tvm/schedule.py | 7 +++++-- src/api/api_lang.cc | 2 +- src/lang/buffer.cc | 4 ++-- tests/python/unittest/test_lang_buffer.py | 8 ++++++++ 5 files changed, 18 insertions(+), 6 deletions(-) diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index f2790f6df7d1..d737341e1c0e 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -52,9 +52,10 @@ class Buffer : public NodeRef { * \param access_mask The access mask * \param ptr_type The type of the pointer. * \param content_lanes The number of lanes for the (data) type. + * \param offset The offset of ptr. */ TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(), - int content_lanes = 1) const; + int content_lanes = 1, int offset = 0) const; /*! * \brief Create an Expr that does a vector load at begin index. * \param begin The beginning index diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 0fc6692d950e..9b8866559125 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -25,7 +25,7 @@ class Buffer(NodeBase): READ = 1 WRITE = 2 - def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1): + def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0): """Get an access pointer to the head of buffer. This is the recommended method to get buffer data @@ -45,6 +45,9 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1): The number of lanes for the data type. This value is greater than one for vector types. + offset: int, optional + The offset of pointer. + Examples -------- .. 
code-block:: python @@ -68,7 +71,7 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1): raise ValueError("Unknown access_mask %s" % access_mask) access_mask = mask return _api_internal._BufferAccessPtr(self, access_mask, ptr_type, - content_lanes) + content_lanes, offset) def vload(self, begin, dtype=None): """Generate an Expr that loads dtype from begin index. diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 37a21cedf3db..3b5916ea5fec 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -159,7 +159,7 @@ TVM_REGISTER_API("_Buffer") TVM_REGISTER_API("_BufferAccessPtr") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator Buffer() - .access_ptr(args[1], args[2], args[3]); + .access_ptr(args[1], args[2], args[3], args[4]); }); TVM_REGISTER_API("_BufferVLoad") diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index af76dcc94f71..07e455e25384 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -335,7 +335,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const { 0); } -Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const { +Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, int offset) const { const BufferNode* self = operator->(); Expr e_dtype; Expr extent; @@ -348,7 +348,7 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const } else { extent = arith::ComputeReduce(self->shape, Expr()); } - Expr elem_offset = self->elem_offset; + Expr elem_offset = self->elem_offset + offset; if (content_lanes > 1) { e_dtype = make_zero(self->dtype.with_lanes(content_lanes)); extent = extent / make_const(self->elem_offset.type(), content_lanes); diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py index c3f00ac2f166..9c47a7377e2e 100644 --- a/tests/python/unittest/test_lang_buffer.py +++ b/tests/python/unittest/test_lang_buffer.py @@ -23,6 +23,14 @@ def test_buffer_access_ptr(): 
aptr = Ab.access_ptr("w") assert aptr.args[4].value == Buffer.WRITE +def test_buffer_access_ptr_offset(): + m = tvm.var('m') + n = tvm.var('n') + Ab = tvm.decl_buffer((m, n), tvm.float32) + aptr = Ab.access_ptr("rw", handle(), 1, 100) + assert tvm.ir_pass.Equal(aptr.args[2], 100) + assert aptr.args[4].value == Buffer.READ | Buffer.WRITE + def test_buffer_index_merge_mult_mod(): m = tvm.var('m') n = tvm.var('n') From 0ff0d89e5c447de189ac773f18984750681b2ad6 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Sat, 27 Jan 2018 16:22:30 +0800 Subject: [PATCH 10/18] update test case and fix CI error --- python/tvm/schedule.py | 3 ++- tests/python/unittest/test_lang_buffer.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 9b8866559125..dda5f67d1b89 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -46,7 +46,8 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0): is greater than one for vector types. offset: int, optional - The offset of pointer. + The offset of pointer. We can use it to offset by + the number of elements from the address of ptr. 
Examples -------- diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py index 9c47a7377e2e..fe0f1f0b759c 100644 --- a/tests/python/unittest/test_lang_buffer.py +++ b/tests/python/unittest/test_lang_buffer.py @@ -27,8 +27,9 @@ def test_buffer_access_ptr_offset(): m = tvm.var('m') n = tvm.var('n') Ab = tvm.decl_buffer((m, n), tvm.float32) - aptr = Ab.access_ptr("rw", handle(), 1, 100) - assert tvm.ir_pass.Equal(aptr.args[2], 100) + aptr = Ab.access_ptr("rw", offset=100) + offset = tvm.ir_pass.Simplify(aptr.args[2]) + assert tvm.ir_pass.Equal(offset, 100) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE def test_buffer_index_merge_mult_mod(): @@ -65,4 +66,5 @@ def assert_simplified_equal(index_simplified, index_direct): if __name__ == "__main__": test_buffer() test_buffer_access_ptr() + test_buffer_access_ptr_offset() test_buffer_index_merge_mult_mod() From 9909c9887f3fe9ab5654714a8664f885a00b9d1d Mon Sep 17 00:00:00 2001 From: kun-zh Date: Thu, 29 Mar 2018 01:25:46 +0800 Subject: [PATCH 11/18] fix a bug in ReplaceDataFlow for issue 1043 --- src/schedule/schedule_dataflow_rewrite.cc | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 562eff417dd2..7b280bb02e63 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -58,12 +58,18 @@ Expr InjectPredicate(const Array& predicates, // Replace data flow appears in all stages given the tensor change. // Also update vmap if subsequent dataflow need to be replaced. 
void ReplaceDataFlow(const Array& stages, - std::unordered_map* vmap) { + std::unordered_map* vmap, + std::unordered_map* rvmap) { for (Stage s : stages) { Operation op = s->op->ReplaceInputs(s->op, *vmap); if (!op.same_as(s->op)) { for (int i = 0; i < op->num_outputs(); ++i) { - (*vmap)[s->op.output(i)] = op.output(i); + if((*rvmap).find(s->op.output(i)) != (*rvmap).end()) { + (*vmap)[((*rvmap)[s->op.output(i)])] = op.output(i); + } else { + (*vmap)[s->op.output(i)] = op.output(i); + (*rvmap)[op.output(i)] = s->op.output(i); + } } s->op = op; } @@ -91,6 +97,7 @@ Tensor Schedule::cache_read(const Tensor& tensor, vsub[sugar_tensor] = cache; std::unordered_map vmap; + std::unordered_map rvmap; for (Operation op : readers) { Stage s = operator[](op); Operation repl_op = s->op->ReplaceInputs(s->op, vsub); @@ -98,9 +105,10 @@ Tensor Schedule::cache_read(const Tensor& tensor, << "Cannot find " << tensor << " in the inputs of " << s->op; vmap[s->op.output(0)] = repl_op.output(0); + rvmap[repl_op.output(0)] = s->op.output(0); s->op = repl_op; } - ReplaceDataFlow((*this)->stages, &vmap); + ReplaceDataFlow((*this)->stages, &vmap, &rvmap); ArrayNode* stages = (*this)->stages.CopyOnWrite(); Stage op_stage = operator[](tensor->op); size_t pos = FindNodeRef(stages, op_stage); @@ -197,8 +205,10 @@ Tensor CacheWriteWithReLayout(Schedule sch, {cache_tensor(args)}); // The replace of the dataflow std::unordered_map vmap; + std::unordered_map rvmap; vmap[orig_stage->op.output(0)] = orig_new_op.output(0); - ReplaceDataFlow(sch->stages, &vmap); + rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); + ReplaceDataFlow(sch->stages, &vmap, &rvmap); // mutate orig stage orig_stage->op = orig_new_op; orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); @@ -583,10 +593,12 @@ Array Schedule::rfactor(const Tensor& tensor, }, reduce_stage->op->name + ".repl"); std::unordered_map vmap; + std::unordered_map rvmap; for (int idx = 0; idx < size; ++idx) { vmap[old_tensors[idx]] = 
repl_tensors[idx]; + rvmap[repl_tensors[idx]] = old_tensors[idx]; } - ReplaceDataFlow((*this)->stages, &vmap); + ReplaceDataFlow((*this)->stages, &vmap, &rvmap); // revamp the reduction stage. reduce_stage->op = repl_tensors[0]->op; reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars(); From 740a6300ccbe06f1750fbce38c4867c304127736 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Thu, 29 Mar 2018 01:38:08 +0800 Subject: [PATCH 12/18] fix lint error --- src/schedule/schedule_dataflow_rewrite.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 7b280bb02e63..d85e7119b511 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -64,7 +64,7 @@ void ReplaceDataFlow(const Array& stages, Operation op = s->op->ReplaceInputs(s->op, *vmap); if (!op.same_as(s->op)) { for (int i = 0; i < op->num_outputs(); ++i) { - if((*rvmap).find(s->op.output(i)) != (*rvmap).end()) { + if ((*rvmap).find(s->op.output(i)) != (*rvmap).end()) { (*vmap)[((*rvmap)[s->op.output(i)])] = op.output(i); } else { (*vmap)[s->op.output(i)] = op.output(i); From 6ae29f49733d466575a20f245f1adb5cc8eb6eed Mon Sep 17 00:00:00 2001 From: kun-zh Date: Thu, 29 Mar 2018 22:14:13 +0800 Subject: [PATCH 13/18] add the test case in issue 1043 --- tests/python/unittest/test_pass_storage_rewrite.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py index 994db4ce5f15..2bb02998982f 100644 --- a/tests/python/unittest/test_pass_storage_rewrite.py +++ b/tests/python/unittest/test_pass_storage_rewrite.py @@ -442,6 +442,19 @@ def verify(n): tvm.ir_pass.PostOrderVisit(body, verify) assert num_alloc[0] == 1 +def test_replace_dataflow(): + shape = (255,) + A = tvm.placeholder(shape, name = "A") + B = tvm.compute(shape, lambda i: A[i] + A[i], name = 
"B") + C = tvm.compute(shape, lambda i: A[i] + B[i], name = "C") + D = tvm.compute(shape, lambda i: A[i] + C[i], name = "D") + E = tvm.compute(shape, lambda i: A[i] + D[i], name = "E") + + s = tvm.create_schedule(E.op) + s.cache_read(A, "local", [B, C, D, E]) + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + if __name__ == "__main__": test_alloc_seq() @@ -456,3 +469,4 @@ def verify(n): test_alloc_seq_type() test_alloc_seq_type2() test_reuse_small_buffer() + test_replace_dataflow() From d8fb1dcab76cf9527f32728564465f9b13f4f3f5 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Sat, 31 Mar 2018 18:08:46 +0800 Subject: [PATCH 14/18] refine code per review suggestions --- src/schedule/schedule_dataflow_rewrite.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index d85e7119b511..1d9dfed4e5e1 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -57,6 +57,7 @@ Expr InjectPredicate(const Array& predicates, // Replace data flow appears in all stages given the tensor change. // Also update vmap if subsequent dataflow need to be replaced. +// Need to keep an up-to-date transitive closure property on the vmap by a reverse map.
void ReplaceDataFlow(const Array& stages, std::unordered_map* vmap, std::unordered_map* rvmap) { @@ -64,8 +65,9 @@ void ReplaceDataFlow(const Array& stages, Operation op = s->op->ReplaceInputs(s->op, *vmap); if (!op.same_as(s->op)) { for (int i = 0; i < op->num_outputs(); ++i) { - if ((*rvmap).find(s->op.output(i)) != (*rvmap).end()) { - (*vmap)[((*rvmap)[s->op.output(i)])] = op.output(i); + auto it = rvmap->find(s->op.output(i)); + if (it != rvmap->end()) { + (*vmap)[it->second] = op.output(i); } else { (*vmap)[s->op.output(i)] = op.output(i); (*rvmap)[op.output(i)] = s->op.output(i); From 5f4191d7997e6a22325650d77309c45c568cb12e Mon Sep 17 00:00:00 2001 From: kun-zh Date: Sun, 8 Apr 2018 22:37:37 +0800 Subject: [PATCH 15/18] Generate Lower Bound Conditions to fix issue 1014 --- src/schedule/message_passing.cc | 4 ++++ tests/python/unittest/test_schedule_schedule_ops.py | 13 +++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index a144e7fc40d1..2d1102391e62 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -477,7 +477,11 @@ std::vector MakeBoundCheck( CHECK(iv->dom.defined()); if (!skip_ivar_domain && !iv->dom.same_as(dom)) { Expr value = ComputeExpr(value_map.at(iv), iv->dom->min); + Expr vmin = EvalSet(value, iset_dmap).min(); Expr vmax = EvalSet(value, iset_dmap).max(); + if (vmin.type() != value.type() || !can_prove(vmin > (iv->dom->min - 1))) { + preds.emplace_back(value > (iv->dom->min - 1)); + } if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) { preds.emplace_back(value < iv->dom->extent); } diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 03b8dbf48c8c..dd57c35c39b6 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -249,6 +249,18 @@ def test_schedule_cache_relayout3(): stmt = 
tvm.schedule.ScheduleOps(s, bounds) +def test_schedule_bound_condition(): + A = tvm.placeholder((64,), name='A', dtype="float32") + Apad = tvm.compute((66,), lambda i: tvm.select(tvm.all(i>0, i < 65), A[i-1], tvm.const(0.)), name='Apad') + Apad2 = tvm.compute((66,), lambda i: Apad[i]*2, name='Apad2') + s = tvm.create_schedule(Apad2.op) + AL1 = s.cache_read(A,"local",[Apad]) + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + stmt = tvm.ir_pass.Simplify(stmt) + assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse)) + if __name__ == "__main__": test_schedule_middle_cache() test_inline_multi_reduce() @@ -265,3 +277,4 @@ def test_schedule_cache_relayout3(): test_schedule1() test_schedule2() test_schedule_cache() + test_schedule_bound_condition() From e9de7372542a1db6e0569166308eeeeddc765e5b Mon Sep 17 00:00:00 2001 From: kun-zh Date: Sun, 8 Apr 2018 23:26:38 +0800 Subject: [PATCH 16/18] update test case to fix regression issue --- tests/python/unittest/test_pass_inject_copy_intrin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_pass_inject_copy_intrin.py b/tests/python/unittest/test_pass_inject_copy_intrin.py index c6ed19d65b69..a44f3899c282 100644 --- a/tests/python/unittest/test_pass_inject_copy_intrin.py +++ b/tests/python/unittest/test_pass_inject_copy_intrin.py @@ -82,6 +82,7 @@ def test_copy_pad_split(): Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.CanonicalSimplify(stmt) def cb(src, dst, pad_before, pad_after, pad_value): assert(dst.elem_offset.value == 0) From f1afda9187fa11ac2155851c44a39f0e651c1889 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Wed, 11 Apr 2018 21:23:04 +0800 Subject: [PATCH 17/18] fix regression failure --- src/schedule/message_passing.cc | 6 +++--- 1 file changed, 3 
insertions(+), 3 deletions(-) diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index 2d1102391e62..9e6b15cef22f 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -479,11 +479,11 @@ std::vector MakeBoundCheck( Expr value = ComputeExpr(value_map.at(iv), iv->dom->min); Expr vmin = EvalSet(value, iset_dmap).min(); Expr vmax = EvalSet(value, iset_dmap).max(); - if (vmin.type() != value.type() || !can_prove(vmin > (iv->dom->min - 1))) { - preds.emplace_back(value > (iv->dom->min - 1)); + if (vmin.type() != value.type() || !can_prove(vmin >= iv->dom->min)) { + preds.emplace_back(value >= 0); } if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) { - preds.emplace_back(value < iv->dom->extent); + preds.emplace_back(value < (iv->dom->extent - iv->dom->min)); } } } From d72912750db5d24f62d4c76e66f16aa00134a7f6 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Thu, 12 Apr 2018 08:03:50 +0800 Subject: [PATCH 18/18] refine code --- src/schedule/message_passing.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index 9e6b15cef22f..3cea560318d8 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -477,8 +477,9 @@ std::vector MakeBoundCheck( CHECK(iv->dom.defined()); if (!skip_ivar_domain && !iv->dom.same_as(dom)) { Expr value = ComputeExpr(value_map.at(iv), iv->dom->min); - Expr vmin = EvalSet(value, iset_dmap).min(); - Expr vmax = EvalSet(value, iset_dmap).max(); + IntSet s = EvalSet(value, iset_dmap); + Expr vmin = s.min(); + Expr vmax = s.max(); if (vmin.type() != value.type() || !can_prove(vmin >= iv->dom->min)) { preds.emplace_back(value >= 0); }