From 7dbadf666a29d213de54296fc6f726c28a5730fd Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Mon, 24 Feb 2020 12:21:46 -0500 Subject: [PATCH 1/9] Set split node's range to minimum of ext and split factor or split nparts, but only when PassDownDomain is called with allow_missing == false, i.e. by InferBound. Add a helper PassUpThreadBinding() to get a map telling whether an IterVar has at least one leaf IterVar deriving from it binding to a thread. Add two unit tests. --- src/te/schedule/message_passing.cc | 58 ++++++++++++++++++- .../unittest/test_schedule_bound_inference.py | 26 +++++++++ 2 files changed, 81 insertions(+), 3 deletions(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 5b6fa861895a..315bca4ec8cb 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -51,6 +51,38 @@ void Update(std::unordered_map* p_state, } } +void PassUpThreadBinding(const Stage& stage, std::unordered_map* p_state) { + auto bound_to_thread = [stage](const IterVar& iv) { + bool bound = false; + auto it = stage->iter_var_attrs.find(iv); + if (it != stage->iter_var_attrs.end()) { + bound = (*it).second->bind_thread.defined(); + } + return bound; + }; + + auto& state = *p_state; + // Fill p_state with leaf itervars + for (IterVar iv : stage->leaf_iter_vars) { + state[iv] = bound_to_thread(iv); + } + + for (size_t i = stage->relations.size(); i != 0; --i) { + IterVarRelation rel = stage->relations[i - 1]; + if (const SplitNode* s = rel.as()) { + state[s->parent] = state[s->inner] || state[s->outer]; + } else if (const FuseNode* s = rel.as()) { + state[s->inner] = state[s->fused]; + state[s->outer] = state[s->fused]; + } else if (const RebaseNode* s = rel.as()) { + state[s->parent] = state[s->rebased]; + } else if (rel.as()) { + } else { + LOG(FATAL) << "unknown relation type"; + } + } +} + void PassDownDomain(const Stage& stage, std::unordered_map* p_state, arith::Analyzer* actx, @@ -62,6 +94,17 @@ void 
PassDownDomain(const Stage& stage, return actx->Simplify(indexdiv(a + (b - 1), b)); }; + auto minimum_or_later = [actx](PrimExpr a, PrimExpr b) { + if (actx->CanProve(a < b)) { + return actx->Simplify(a); + } + return actx->Simplify(b); + }; + + // Construct a map: IterVar -> whether dominating a leaf iterVar binding to a thread + std::unordered_map dominating_thread; + PassUpThreadBinding(stage, &dominating_thread); + auto& state = *p_state; // forwar iteration on relations for (IterVarRelation rel : stage->relations) { @@ -73,13 +116,22 @@ void PassDownDomain(const Stage& stage, CHECK(!state.count(r->inner)); const Range& range_parent = state.at(r->parent); if (r->factor.defined()) { - Update(p_state, r->inner, - Range::make_by_min_extent(0, r->factor), actx); + Update( + p_state, r->inner, + Range::make_by_min_extent(0, dominating_thread[r->inner] || allow_missing + ? r->factor + : minimum_or_later(range_parent->extent, r->factor)), + actx); Update(p_state, r->outer, Range::make_by_min_extent( 0, ceil_div(range_parent->extent, r->factor)), actx); } else { - Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx); + Update( + p_state, r->outer, + Range::make_by_min_extent(0, dominating_thread[r->outer] || allow_missing + ? 
r->nparts + : minimum_or_later(range_parent->extent, r->nparts)), + actx); Update(p_state, r->inner, Range::make_by_min_extent( 0, ceil_div(range_parent->extent, r->nparts)), actx); diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index 484aa503e066..03ec136756b3 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -70,6 +70,32 @@ def test_bound3(): assert(bounds[A1.op.axis[0]].extent.value==32) assert(bounds[A1.op.axis[1]].extent.value==16) +def test_bound_split_ext_less_than_factor(): + m = 8 + I = tvm.placeholder((m,), name='I') + EF = tvm.compute((m,), lambda i: I[i] * 2, name = "EF") + E = tvm.compute((m,), lambda i: EF[i] * 2, name = "E") + s = tvm.create_schedule([E.op]) + xo, xi = s[E].split(s[E].op.axis[0], factor = 32) + s[EF].compute_at(s[E], xo) + + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + assert bounds[xi].extent.value == m + +def test_bound_split_ext_less_than_naprts(): + m = 8 + I = tvm.placeholder((m,), name='I') + EF = tvm.compute((m,), lambda i: I[i] * 2, name = "EF") + E = tvm.compute((m,), lambda i: EF[i] * 2, name = "E") + s = tvm.create_schedule([E.op]) + xo, xi = s[E].split(s[E].op.axis[0], nparts = 32) + s[EF].compute_at(s[E], xo) + + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + assert bounds[xo].extent.value == m + def test_bound_split_divisible(): m = te.var('m') l = te.var('l') From 9dc54f76861177a426817f9973939b7e43979fa9 Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Mon, 24 Feb 2020 17:01:04 -0500 Subject: [PATCH 2/9] Enhance LoopVectorizer for vectorizing by 0. Found at least one case from topi/tests/python/test_topi_transform.py::test_tile.
--- src/tir/pass/vectorize_loop.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/tir/pass/vectorize_loop.cc b/src/tir/pass/vectorize_loop.cc index d62bd1f2584e..9102fc33dc64 100644 --- a/src/tir/pass/vectorize_loop.cc +++ b/src/tir/pass/vectorize_loop.cc @@ -524,9 +524,13 @@ class LoopVectorizer : public StmtMutator { CHECK(is_zero(op->min)); int lanes = 0; bool succ = arith::GetConstInt(op->extent, &lanes); - if (!succ || lanes < 1) { + if (!succ || lanes < 0) { LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent; } + if (lanes == 0) { + // Nothing to run. This may happen when a tensor has 0-sized dimension. + return Stmt(); + } return Vectorizer(op->loop_var, lanes)(op->body); } else { return StmtMutator::VisitStmt_(op); From 96ab20b67de6201994607d232665781b7a2d7f91 Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Mon, 24 Feb 2020 21:51:43 -0500 Subject: [PATCH 3/9] Revert changes vectorize_loop.cc; when parent's ext is zero, set split's range to the factor or nparts. --- src/te/schedule/message_passing.cc | 24 ++++++++++++------------ src/tir/pass/vectorize_loop.cc | 6 +----- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 315bca4ec8cb..5c2f8ef94183 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -116,22 +116,22 @@ void PassDownDomain(const Stage& stage, CHECK(!state.count(r->inner)); const Range& range_parent = state.at(r->parent); if (r->factor.defined()) { - Update( - p_state, r->inner, - Range::make_by_min_extent(0, dominating_thread[r->inner] || allow_missing - ? r->factor - : minimum_or_later(range_parent->extent, r->factor)), - actx); + Update(p_state, r->inner, + Range::make_by_min_extent( + 0, dominating_thread[r->inner] || allow_missing || is_zero(range_parent->extent) + ? 
r->factor + : minimum_or_later(range_parent->extent, r->factor)), + actx); Update(p_state, r->outer, Range::make_by_min_extent( 0, ceil_div(range_parent->extent, r->factor)), actx); } else { - Update( - p_state, r->outer, - Range::make_by_min_extent(0, dominating_thread[r->outer] || allow_missing - ? r->nparts - : minimum_or_later(range_parent->extent, r->nparts)), - actx); + Update(p_state, r->outer, + Range::make_by_min_extent( + 0, dominating_thread[r->outer] || allow_missing || is_zero(range_parent->extent) + ? r->nparts + : minimum_or_later(range_parent->extent, r->nparts)), + actx); Update(p_state, r->inner, Range::make_by_min_extent( 0, ceil_div(range_parent->extent, r->nparts)), actx); diff --git a/src/tir/pass/vectorize_loop.cc b/src/tir/pass/vectorize_loop.cc index 9102fc33dc64..d62bd1f2584e 100644 --- a/src/tir/pass/vectorize_loop.cc +++ b/src/tir/pass/vectorize_loop.cc @@ -524,13 +524,9 @@ class LoopVectorizer : public StmtMutator { CHECK(is_zero(op->min)); int lanes = 0; bool succ = arith::GetConstInt(op->extent, &lanes); - if (!succ || lanes < 0) { + if (!succ || lanes < 1) { LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent; } - if (lanes == 0) { - // Nothing to run. This may happen when a tensor has 0-sized dimension. - return Stmt(); - } return Vectorizer(op->loop_var, lanes)(op->body); } else { return StmtMutator::VisitStmt_(op); From 85a4d1ef63e77ba13003603e4b7505fc0ef3e626 Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Tue, 25 Feb 2020 02:28:14 -0500 Subject: [PATCH 4/9] Update with comments. --- src/te/schedule/message_passing.cc | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 5c2f8ef94183..ef236580a60c 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -51,6 +51,13 @@ void Update(std::unordered_map* p_state, } } +/*! 
+ * \param Upward propagating whether an IterVar derives at least one leaf IterVar that binds to + * a thread. + * + * \param stage The stage to operate on. + * \param p_state The propagation result of each IterVar. + */ void PassUpThreadBinding(const Stage& stage, std::unordered_map* p_state) { auto bound_to_thread = [stage](const IterVar& iv) { bool bound = false; @@ -66,7 +73,7 @@ void PassUpThreadBinding(const Stage& stage, std::unordered_map* for (IterVar iv : stage->leaf_iter_vars) { state[iv] = bound_to_thread(iv); } - + // Traverse the graph bottom-up to propagate thread binding information for (size_t i = stage->relations.size(); i != 0; --i) { IterVarRelation rel = stage->relations[i - 1]; if (const SplitNode* s = rel.as()) { @@ -101,7 +108,6 @@ void PassDownDomain(const Stage& stage, return actx->Simplify(b); }; - // Construct a map: IterVar -> whether dominating a leaf iterVar binding to a thread std::unordered_map dominating_thread; PassUpThreadBinding(stage, &dominating_thread); @@ -116,6 +122,17 @@ void PassDownDomain(const Stage& stage, CHECK(!state.count(r->inner)); const Range& range_parent = state.at(r->parent); if (r->factor.defined()) { + // Tighten r->inner's range to min(range_parent->extent, r->factor), only if all of the + // following conditions are met. Same reason for r->out in the split with nparts mode. + // 1. no leaf IterVar derived from r->inner binds to any thread. People may use split + // to force an IterVar extent to match the number of allocated threads to fuse stages + // that require different number of threads. + // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound, + // rather than by an early compiler phase, such as rfactor(). We don't want an IterVar + // to be tightened in an early phase, but bind to a thread later. + // 3. range_parent's extent is not 0. At lest one Topi test has a case where a tensor has one + // zero-sized dimension. 
Split creates r->inner with a positive extent to avoid zero-extent + // IterVar. We don't touch it. Update(p_state, r->inner, Range::make_by_min_extent( 0, dominating_thread[r->inner] || allow_missing || is_zero(range_parent->extent) From 0a3a2eab2da7f00afbf517c94a82d7ad47e9709f Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Tue, 25 Feb 2020 11:58:10 -0500 Subject: [PATCH 5/9] Refactor the ext tightening predicate. --- src/te/schedule/message_passing.cc | 35 +++++++++++++++--------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index ef236580a60c..38747535e962 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -121,23 +121,26 @@ void PassDownDomain(const Stage& stage, } CHECK(!state.count(r->inner)); const Range& range_parent = state.at(r->parent); + // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the + // following conditions are met: + // 1. no leaf IterVar derived from iv binds to any thread. People may use split + // to force an IterVar extent to match the number of allocated threads to fuse stages + // that require different number of threads. We don't want to change these extents. + // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound, + // rather than by an early compiler phase, such as rfactor(). We don't want to tighten an + // IterVar in an early phase allowing missing IterVars, because it may bind to a thread later. + // 3. range_parent's extent is not 0. At lest one Topi test has a case where a tensor has one + // zero-sized dimension. Split creates iv with a positive extent to avoid zero-extent + // IterVar. We don't touch it. + auto resolve_min_extent_for_split = [&](IterVar iv, PrimExpr factor_or_nparts) { + return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent) + ? 
factor_or_nparts + : minimum_or_later(range_parent->extent, factor_or_nparts); + }; if (r->factor.defined()) { - // Tighten r->inner's range to min(range_parent->extent, r->factor), only if all of the - // following conditions are met. Same reason for r->out in the split with nparts mode. - // 1. no leaf IterVar derived from r->inner binds to any thread. People may use split - // to force an IterVar extent to match the number of allocated threads to fuse stages - // that require different number of threads. - // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound, - // rather than by an early compiler phase, such as rfactor(). We don't want an IterVar - // to be tightened in an early phase, but bind to a thread later. - // 3. range_parent's extent is not 0. At lest one Topi test has a case where a tensor has one - // zero-sized dimension. Split creates r->inner with a positive extent to avoid zero-extent - // IterVar. We don't touch it. Update(p_state, r->inner, Range::make_by_min_extent( - 0, dominating_thread[r->inner] || allow_missing || is_zero(range_parent->extent) - ? r->factor - : minimum_or_later(range_parent->extent, r->factor)), + 0, resolve_min_extent_for_split(r->inner, r->factor)), actx); Update(p_state, r->outer, Range::make_by_min_extent( @@ -145,9 +148,7 @@ void PassDownDomain(const Stage& stage, } else { Update(p_state, r->outer, Range::make_by_min_extent( - 0, dominating_thread[r->outer] || allow_missing || is_zero(range_parent->extent) - ? r->nparts - : minimum_or_later(range_parent->extent, r->nparts)), + 0, resolve_min_extent_for_split(r->outer, r->nparts)), actx); Update(p_state, r->inner, Range::make_by_min_extent( From 4deda0d058f2c699ac5ac9d4b7f170afdecfa3bd Mon Sep 17 00:00:00 2001 From: root Date: Fri, 28 Feb 2020 11:23:28 -0500 Subject: [PATCH 6/9] Fix reference types. 
--- src/te/schedule/message_passing.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 38747535e962..110486114fa3 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -59,7 +59,7 @@ void Update(std::unordered_map* p_state, * \param p_state The propagation result of each IterVar. */ void PassUpThreadBinding(const Stage& stage, std::unordered_map* p_state) { - auto bound_to_thread = [stage](const IterVar& iv) { + auto bound_to_thread = [&stage](const IterVar& iv) { bool bound = false; auto it = stage->iter_var_attrs.find(iv); if (it != stage->iter_var_attrs.end()) { @@ -70,7 +70,7 @@ void PassUpThreadBinding(const Stage& stage, std::unordered_map* auto& state = *p_state; // Fill p_state with leaf itervars - for (IterVar iv : stage->leaf_iter_vars) { + for (const IterVar& iv : stage->leaf_iter_vars) { state[iv] = bound_to_thread(iv); } // Traverse the graph bottom-up to propagate thread binding information @@ -94,14 +94,14 @@ void PassDownDomain(const Stage& stage, std::unordered_map* p_state, arith::Analyzer* actx, bool allow_missing) { - auto ceil_div = [actx](PrimExpr a, PrimExpr b) { + auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) { if (actx->CanProve(indexmod(a, b) == 0)) { return actx->Simplify(indexdiv(a, b)); } return actx->Simplify(indexdiv(a + (b - 1), b)); }; - auto minimum_or_later = [actx](PrimExpr a, PrimExpr b) { + auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) { if (actx->CanProve(a < b)) { return actx->Simplify(a); } @@ -132,7 +132,7 @@ void PassDownDomain(const Stage& stage, // 3. range_parent's extent is not 0. At lest one Topi test has a case where a tensor has one // zero-sized dimension. Split creates iv with a positive extent to avoid zero-extent // IterVar. We don't touch it. 
- auto resolve_min_extent_for_split = [&](IterVar iv, PrimExpr factor_or_nparts) { + auto resolve_min_extent_for_split = [&](const IterVar& iv, const PrimExpr& factor_or_nparts) { return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent) ? factor_or_nparts : minimum_or_later(range_parent->extent, factor_or_nparts); From 34b568c247346b4ab8948f3b1c8563a45587c6c0 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 28 Feb 2020 11:53:57 -0500 Subject: [PATCH 7/9] Integrate tvm.te changes. --- .../unittest/test_schedule_bound_inference.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index 03ec136756b3..edae527c0183 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -72,27 +72,27 @@ def test_bound3(): def test_bound_split_ext_less_than_factor(): m = 8 - I = tvm.placeholder((m,), name='I') - EF = tvm.compute((m,), lambda i: I[i] * 2, name = "EF") - E = tvm.compute((m,), lambda i: EF[i] * 2, name = "E") - s = tvm.create_schedule([E.op]) + I = te.placeholder((m,), name='I') + EF = te.compute((m,), lambda i: I[i] * 2, name = "EF") + E = te.compute((m,), lambda i: EF[i] * 2, name = "E") + s = te.create_schedule([E.op]) xo, xi = s[E].split(s[E].op.axis[0], factor = 32) s[EF].compute_at(s[E], xo) - bounds = tvm.schedule.InferBound(s) + bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) assert bounds[xi].extent.value == m def test_bound_split_ext_less_than_naprts(): m = 8 - I = tvm.placeholder((m,), name='I') - EF = tvm.compute((m,), lambda i: I[i] * 2, name = "EF") - E = tvm.compute((m,), lambda i: EF[i] * 2, name = "E") - s = tvm.create_schedule([E.op]) + I = te.placeholder((m,), name='I') + EF = te.compute((m,), lambda i: I[i] * 2, name = "EF") + E = te.compute((m,), lambda i: EF[i] * 2, name = 
"E") + s = te.create_schedule([E.op]) xo, xi = s[E].split(s[E].op.axis[0], nparts = 32) s[EF].compute_at(s[E], xo) - bounds = tvm.schedule.InferBound(s) + bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) assert bounds[xo].extent.value == m From c95f7e6f182bba65daeed7c68c3c359980a9d154 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 29 Feb 2020 12:08:16 -0500 Subject: [PATCH 8/9] Trivial comment change to trigger CI. --- src/te/schedule/message_passing.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 110486114fa3..10769dff67ae 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -123,7 +123,7 @@ void PassDownDomain(const Stage& stage, const Range& range_parent = state.at(r->parent); // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the // following conditions are met: - // 1. no leaf IterVar derived from iv binds to any thread. People may use split + // 1. No leaf IterVar is derived from iv binds to any thread. People may use split // to force an IterVar extent to match the number of allocated threads to fuse stages // that require different number of threads. We don't want to change these extents. // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound, From 40ecacc296fa9ee4b8e79df4fbceda83fe030536 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 4 Mar 2020 10:21:01 -0500 Subject: [PATCH 9/9] Trivial comment correction to trigger testing. 
--- src/te/schedule/message_passing.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 10769dff67ae..a7b248285c4d 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -123,7 +123,7 @@ void PassDownDomain(const Stage& stage, const Range& range_parent = state.at(r->parent); // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the // following conditions are met: - // 1. No leaf IterVar is derived from iv binds to any thread. People may use split + // 1. No leaf IterVar derived from iv binds to any thread. People may use split // to force an IterVar extent to match the number of allocated threads to fuse stages // that require different number of threads. We don't want to change these extents. // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound,