From dafa2c346dc9c825dd8d7f27c03e1ab10e8b60d3 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 18 Apr 2020 03:39:03 -0400 Subject: [PATCH 01/13] Improve floormod. --- include/tvm/arith/analyzer.h | 13 ++++ include/tvm/arith/int_set.h | 8 ++ include/tvm/te/operation.h | 7 ++ src/arith/analyzer.cc | 9 +++ src/arith/const_int_bound.cc | 11 ++- src/arith/int_set.cc | 91 ++++++++++++++++++++-- src/te/operation/compute_op.cc | 5 +- src/te/operation/extern_op.cc | 1 + src/te/operation/hybrid_op.cc | 1 + src/te/operation/placeholder_op.cc | 1 + src/te/operation/scan_op.cc | 1 + src/te/operation/tensor_compute_op.cc | 1 + src/te/operation/tensorize.cc | 3 +- src/te/schedule/bound.cc | 4 +- src/te/schedule/message_passing.cc | 2 +- tests/python/unittest/test_arith_intset.py | 3 + 16 files changed, 143 insertions(+), 18 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 6ca3ba9cfd55..100c135bc274 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -442,6 +442,19 @@ class TVM_DLL Analyzer { * \note Analyzer will call into sub-analyzers to get the result. */ bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound); + /*! + * \brief Whether can we prove expr < val. + + * Non-negative proof is very useful in integer analysis + * to lower divisions and mods given difference in trunc and ceil mode. + * + * \param expr The expression. + * \param upper_bound The upper bound. + * \return Whether we can prove it. + * + * \note Analyzer will call into sub-analyzers to get the result. + */ + bool CanProveLess(const PrimExpr& expr, int64_t upper_bound); /*! * \brief Whether can we prove condition. * diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 86ef906fef0a..a8206be05b1b 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -167,10 +167,15 @@ IntSet EvalSet(PrimExpr e, * * \param e The expression to be evaluated. 
* \param dom_map The domain of each variable. + * \param rmap The range of each variable. * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); +IntSet EvalSet(PrimExpr e, + const std::unordered_map& dom_map, const std::unordered_map& rmap); +IntSet EvalSet(PrimExpr e, + const std::unordered_map& dom_map, const Map& rmap); /*! * \brief Find an symbolic integer set that contains is union over @@ -198,10 +203,13 @@ IntSet EvalSet(IntSet s, * * \param r The range to be evaluated. * \param dom_map The domain of each variable. + * \param rmap The range of each variable. * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(Range r, const std::unordered_map& dom_map); +IntSet EvalSet(Range r, + const std::unordered_map& dom_map, const std::unordered_map& rmap); /*! \brief Map from Expr to IntSet */ using ExprIntSetMap = std::unordered_map; diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 205589928f01..8abee1529c48 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -112,6 +112,7 @@ class OperationNode : public tir::FunctionBaseNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, + const std::unordered_map& rmap, std::unordered_map* out_dom_map) const = 0; /*! * \brief Gather the bound from output tensor. 
@@ -176,6 +177,7 @@ class PlaceholderOpNode : public OperationNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, + const std::unordered_map& rmap, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, @@ -254,6 +256,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, + const std::unordered_map& rmap, std::unordered_map* out_dom_map) const final; Stmt BuildProvide( const Stage& stage, @@ -307,6 +310,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, + const std::unordered_map& rmap, std::unordered_map* out_dom_map) const final; Stmt BuildProvide( const Stage& stage, @@ -382,6 +386,7 @@ class ScanOpNode : public OperationNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, + const std::unordered_map& rmap, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, @@ -449,6 +454,7 @@ class ExternOpNode : public OperationNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, + const std::unordered_map& rmap, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, @@ -517,6 +523,7 @@ class HybridOpNode : public OperationNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, + const std::unordered_map& rmap, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 83dfc64009cf..442b2b5516b3 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -92,6 +92,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { return false; } +bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) { + if (const 
auto* ptr = expr.as()) { + return ptr->value < upper_bound; + } + auto bd = this->const_int_bound(this->rewrite_simplify(expr)); + if (bd->max_value < upper_bound) return true; + return false; +} + bool Analyzer::CanProve(const PrimExpr& expr) { if (const auto* ptr = expr.as()) { return ptr->value != 0; diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 57dfc157fc21..a1199558dc47 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -150,10 +150,13 @@ class ConstIntBoundAnalyzer::Impl : const PrimExprNode* op = expr.as(); auto val = bound_->find(op); if (val != bound_->end()) { - CHECK(val->second->min_value == res.min_value && - val->second->max_value == res.max_value) - << "Detected bound for " << expr - << "conflicts with memorization"; + auto everything = Everything(op->dtype); + CHECK( + (val->second->min_value == res.min_value && + val->second->max_value == res.max_value) || + (val->second->min_value == everything.min_value && + val->second->max_value == everything.max_value)) + << "Detected bound for " << expr << "conflicts with memorization"; } (*bound_)[op] = ConstIntBound(res.min_value, res.max_value); } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 027259a4d225..27a747a49533 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -311,6 +311,21 @@ inline IntervalSet Combine(Analyzer* analyzer, LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } if (analyzer->CanProveGreaterEqual(divisor, 0)) { + if (const auto* ptr = b->min_value.as()) { + // a % b = a - b * (a/b) if + // (a) a_max - a_min < b, i.e. that before mod, a's range doesn't cover [0, b) + // and (b) a_min % b <= a_max % b, i.e. 
that a's range is still continuous after mod + auto tmax = a->max_value - b->min_value * floordiv(a->max_value, b->min_value); + tmax = analyzer->Simplify(tmax); + auto tmin = a->min_value - b->min_value * floordiv(a->min_value, b->min_value); + tmin = analyzer->Simplify(tmin); + auto tset = IntervalSet(tmin, tmax); + bool within_range = analyzer->CanProveLess(a->max_value - a->min_value, ptr->value); + bool wrap_around = analyzer->CanProve(tset->max_value < tset->min_value); + if (within_range && !wrap_around) { + return tset; + } + } return IntervalSet(make_zero(divisor.dtype()), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; @@ -372,7 +387,27 @@ class IntervalSetEvaluator : } IntervalSet Eval(const PrimExpr& val) { - return this->VisitExpr(val); + IntervalSet result = this->VisitExpr(val); + // Use the IterVar range info bound to analyzer to further simplify + // and reduce the interval + auto min_value_expr = analyzer_->Simplify(result->min_value); + auto max_value_expr = analyzer_->Simplify(result->max_value); + auto min_bd = analyzer_->const_int_bound(min_value_expr); + auto max_bd = analyzer_->const_int_bound(max_value_expr); + if (min_bd->max_value == min_bd->min_value && max_bd->max_value == max_bd->min_value) { + const auto* min_ptr = result->min_value.as(); + const auto* max_ptr = result->max_value.as(); + // The following if statement is necessary. When result is a single point of IntImm, such as + // [0, 0], both 0s refer the same ObjectRef. We really don't want to create a new [0, 0] + // IntervalSet and have 0s refer two different ObjectRef. They will confuse APIs, such as + // IntervalSetEvaluator::MatchPoint() and IntervalSetNode::IsSinglePoint(). 
+ if (min_ptr && max_ptr && min_bd->min_value == min_ptr->value && + max_bd->max_value == max_ptr->value) { + return result; + } + return IntervalSet(static_cast(min_bd->min_value), static_cast(max_bd->max_value)); + } + return result; } // evaluate and relax the set IntervalSet Eval(IntervalSet val) { @@ -736,12 +771,30 @@ Map ConvertDomMap( return dmap; } -IntSet EvalSet(PrimExpr e, - const Map& dom_map) { +IntSet EvalSet(PrimExpr e, const Map& dom_map, + const std::unordered_map& rmap) { Analyzer ana; + // Bind ana with rmap + for (auto entry : rmap) { + ana.Bind(entry.first->var, entry.second); + } return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); } +IntSet EvalSet(PrimExpr e, const Map& dom_map, const Map& rmap) { + Analyzer ana; + // Bind ana with rmap + for (auto entry : rmap) { + ana.Bind(entry.first->var, entry.second); + } + return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); +} + +IntSet EvalSet(PrimExpr e, + const Map& dom_map) { + return EvalSet(e, dom_map, std::unordered_map()); +} + IntSet IntSet::vector(PrimExpr x) { Analyzer ana; Map dmap; @@ -753,14 +806,27 @@ IntSet EvalSet(PrimExpr e, return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(PrimExpr e, - const std::unordered_map& dom_map) { +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(Range r, - const Map& dom_map) { +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map, + const std::unordered_map& rmap) { + return EvalSet(e, ConvertDomMap(dom_map), rmap); +} + +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map, + const Map& rmap) { + return EvalSet(e, ConvertDomMap(dom_map), rmap); +} + +IntSet EvalSet(Range r, const Map& dom_map, + const std::unordered_map& rmap) { Analyzer ana; + // Bind ana with rmap + for (auto entry : rmap) { + ana.Bind(entry.first->var, entry.second); + } IntervalSetEvaluator m(&ana, dom_map); // Simplifying first can give tighter bounds if r->min and 
r->extent share variables PrimExpr sum = r->min + r->extent - 1; @@ -769,10 +835,19 @@ IntSet EvalSet(Range r, } IntSet EvalSet(Range r, - const std::unordered_map& dom_map) { + const Map& dom_map) { + return EvalSet(r, dom_map, std::unordered_map()); +} + +IntSet EvalSet(Range r, const std::unordered_map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } +IntSet EvalSet(Range r, const std::unordered_map& dom_map, + const std::unordered_map& rmap) { + return EvalSet(r, ConvertDomMap(dom_map), rmap); +} + IntSet EvalSet(IntSet s, const std::unordered_map& dom_map) { Analyzer ana; diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 6f703c9ec4e3..25194dd3014a 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -218,9 +218,10 @@ void ComputeOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, + const std::unordered_map& rmap, std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); - auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) { + auto fvisit = [&dom_map, &rmap, out_dom_map, analyzer](const ObjectRef& n) { auto *call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Downcast(call->func).output(call->value_index); @@ -231,7 +232,7 @@ void ComputeOpNode::PropBoundToInputs( // undefined behaviour), so we can intersect the estimated set of the argument with the // range expected by the tensor. However, intersection may result in overly complex // expressions, so we perform a more relaxed form of intersection. 
- IntSet arg_intset = EvalSet(call->args[i], dom_map); + IntSet arg_intset = EvalSet(call->args[i], dom_map, rmap); const arith::IntervalSetNode* arg_interval = arg_intset.as(); if (arg_interval) { PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype()); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 9d95e329c8f2..9cc6fb0297b7 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -120,6 +120,7 @@ void ExternOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, + const std::unordered_map& rmap, std::unordered_map* out_dom_map) const { for (Tensor t : this->inputs) { auto it = out_dom_map->find(t); diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 4da127ea0a85..fc9646275146 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -136,6 +136,7 @@ void HybridOpNode::PropBoundToInputs( const Operation &self, arith::Analyzer* analyzer, const std::unordered_map &dom_map, + const std::unordered_map& rmap, std::unordered_map* out_dom_map) const { auto curr_inputs = InputTensors(); for (Tensor t : curr_inputs) { diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index d48be4c53668..1efb6909209a 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -87,6 +87,7 @@ void PlaceholderOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, + const std::unordered_map& rmap, std::unordered_map* out_dom_map) const { } diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 1916b4a4823e..5974af8a2945 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -183,6 +183,7 @@ void ScanOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, + const std::unordered_map& rmap, 
std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) { diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 4cdc9e1f8d32..5e41fc7e2354 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -116,6 +116,7 @@ void TensorComputeOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, + const std::unordered_map& rmap, std::unordered_map* out_dom_map) const { for (size_t i = 0; i < this->inputs.size(); ++i) { Tensor t = this->inputs[i]; diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index b66406969c76..65514b833b06 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -99,7 +99,8 @@ size_t InferTensorizeRegion( temp_dmap[iv->var.get()] = iset; } // Input domains - self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom); + self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, std::unordered_map(), + &in_dom); Range none; for (const auto& kv : in_dom) { Array vec; diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc index 50cbafd2b654..15023ebbfbcb 100644 --- a/src/te/schedule/bound.cc +++ b/src/te/schedule/bound.cc @@ -199,13 +199,13 @@ void InferRootBound(const Stage& stage, r = iv->dom; } if (relax_set.size() != 0) { - dom_map[iv->var.get()] = EvalSet(r, relax_set); + dom_map[iv->var.get()] = EvalSet(r, relax_set, *rmap); } else { dom_map[iv->var.get()] = IntSet::range(r); } analyzer.Bind(iv->var, r); } - op->PropBoundToInputs(op, &analyzer, dom_map, &tmap); + op->PropBoundToInputs(op, &analyzer, dom_map, *rmap, &tmap); } stage->op->GatherBound(stage->op, tmap, rmap); } diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 6ed9438ec90f..fc8b2ab628b2 100644 --- a/src/te/schedule/message_passing.cc +++ 
b/src/te/schedule/message_passing.cc @@ -604,7 +604,7 @@ std::vector MakeBoundCheck( CHECK(iv->dom.defined()); if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) { PrimExpr value = value_map.at(iv) - iv->dom->min; - IntSet s = EvalSet(value, iset_dmap); + IntSet s = EvalSet(value, iset_dmap, dom_map); PrimExpr vmin = s.min(); PrimExpr vmax = s.max(); // The range of `value` resides in [vmin, vmax] diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index e57dcef75994..276e38bebd60 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -90,6 +90,9 @@ def test_mod(): flm = tvm.te.floormod ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 5)}, (3, 5)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9)) def test_max_min(): From 40b76753c6d34a02091e67a2f24a746ba1b33de1 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 18 Apr 2020 15:05:07 -0400 Subject: [PATCH 02/13] Add tests and lint fix --- include/tvm/arith/int_set.h | 12 ++++++------ src/arith/const_int_bound.cc | 5 ++--- tests/python/unittest/test_arith_intset.py | 9 +++++++++ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index a8206be05b1b..afd251668c21 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -172,10 +172,10 @@ IntSet EvalSet(PrimExpr e, */ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); -IntSet EvalSet(PrimExpr e, - const std::unordered_map& dom_map, const std::unordered_map& rmap); -IntSet EvalSet(PrimExpr e, - const std::unordered_map& dom_map, const Map& rmap); +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map, + const std::unordered_map& rmap); +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map, + const Map& 
rmap); /*! * \brief Find an symbolic integer set that contains is union over @@ -208,8 +208,8 @@ IntSet EvalSet(IntSet s, */ IntSet EvalSet(Range r, const std::unordered_map& dom_map); -IntSet EvalSet(Range r, - const std::unordered_map& dom_map, const std::unordered_map& rmap); +IntSet EvalSet(Range r, const std::unordered_map& dom_map, + const std::unordered_map& rmap); /*! \brief Map from Expr to IntSet */ using ExprIntSetMap = std::unordered_map; diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index a1199558dc47..69aa85e86e25 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -152,10 +152,9 @@ class ConstIntBoundAnalyzer::Impl : if (val != bound_->end()) { auto everything = Everything(op->dtype); CHECK( - (val->second->min_value == res.min_value && - val->second->max_value == res.max_value) || + (val->second->min_value == res.min_value && val->second->max_value == res.max_value) || (val->second->min_value == everything.min_value && - val->second->max_value == everything.max_value)) + val->second->max_value == everything.max_value)) << "Detected bound for " << expr << "conflicts with memorization"; } (*bound_)[op] = ConstIntBound(res.min_value, res.max_value); diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 276e38bebd60..1ac64edfc1a6 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -94,6 +94,15 @@ def test_mod(): ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9)) ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9)) + floordiv = tvm.te.floordiv + z = te.var("z") + ck.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 3)) + ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, + (x*4-floordiv(x, 2)*8, x*4+3-floordiv(x*4+3, 8)*8)) + ck1 = IntSetChecker() + ck1.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 2)) + ck1.verify(flm(y, 8), {y : 
tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3)) + def test_max_min(): ck = IntSetChecker() From 71f899890dbcab120f1cb2ef959e6b9da659ff57 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 18 Apr 2020 20:08:37 -0400 Subject: [PATCH 03/13] Add a test locking down the improved floormod behavior during bound inference --- ...test_te_schedule_bound_inference_tiling.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 tests/python/unittest/test_te_schedule_bound_inference_tiling.py diff --git a/tests/python/unittest/test_te_schedule_bound_inference_tiling.py b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py new file mode 100644 index 000000000000..fbadf995d32d --- /dev/null +++ b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +from tvm import te + +def test_bound_tile_mod(): + def compute(M_tiles, N_tiles, L, factor, dtype): + # Algo + M = M_tiles * factor + N = N_tiles * factor + + A = tvm.te.placeholder((M, L), name='A', dtype=dtype) + C = tvm.te.compute((N, M), lambda n, m: A[n, m], name='C') + s = tvm.te.create_schedule(C.op) + + return s, A, C + + def schedule(s, L, factor, padding, A, C): + switch = True + C_local = s.cache_write(C, "local") + + + n, m = C.op.axis + if switch: + bn, bm, ni, mi = s[C].tile(n, m, factor, factor) + else: + ni, mi = C.op.axis + nio, nii = s[C].split(ni, 2) + n = s[C].fuse(nii, mi) + C_shared = s.cache_write(C, "shared") + if switch: + bn, bm, ni, mi = C_shared.op.axis + else: + ni, mi = C_shared.op.axis + + s[C_shared].storage_align(ni, factor * 2, padding) + + n, m = s[C].op.axis + bn, bm, ni, mi = s[C].tile(n, m, factor, factor) + s[C].set_scope("global") + niio, niii = s[C].split(ni, 32) + + + s[C_shared].compute_at(s[C], niio) + + + return s + + s, A, C = compute(2, 2, 64, 128, "float16") + s = schedule(s, 64, 128, 8, A, C) + + bounds = tvm.te.schedule.InferBound(s) + check = (bounds[s.stages[2].op.axis[2]].extent == 16) + if(not check): + print(tvm.lower(s, [A, C], simple_mode=True)) + assert(check) + +if __name__ == "__main__": + test_bound_tile_mod() From 4a8e173e9814a4435b56eac2f0ae853151887d15 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 18 Apr 2020 20:36:13 -0400 Subject: [PATCH 04/13] Add missing comments --- include/tvm/arith/int_set.h | 26 ++++++++++++++++++++++++-- include/tvm/te/operation.h | 1 + 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index afd251668c21..450558062c25 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -167,13 +167,28 @@ IntSet EvalSet(PrimExpr e, * * \param e The expression to be evaluated. * \param dom_map The domain of each variable. - * \param rmap The range of each variable. 
* \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); +/*! + * \brief Same as EvalSet, but takes unordered_map + * + * \param e The expression to be evaluated. + * \param dom_map The domain of each variable. + * \param rmap The range of each variable. + * \return An integer set that can cover all the possible values of e. + */ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map, const std::unordered_map& rmap); +/*! + * \brief Same as EvalSet, but takes unordered_map + * + * \param e The expression to be evaluated. + * \param dom_map The domain of each variable. + * \param rmap The range of each variable. + * \return An integer set that can cover all the possible values of e. + */ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map, const Map& rmap); @@ -203,11 +218,18 @@ IntSet EvalSet(IntSet s, * * \param r The range to be evaluated. * \param dom_map The domain of each variable. - * \param rmap The range of each variable. * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(Range r, const std::unordered_map& dom_map); +/*! + * \brief Same as EvalSet, but takes unordered_map + * + * \param r The range to be evaluated. + * \param dom_map The domain of each variable. + * \param rmap The range of each variable. + * \return An integer set that can cover all the possible values of e. + */ IntSet EvalSet(Range r, const std::unordered_map& dom_map, const std::unordered_map& rmap); diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 8abee1529c48..7d88216bcd19 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -104,6 +104,7 @@ class OperationNode : public tir::FunctionBaseNode { * \param self The reference to self. * \param analyzer The analyzer to be used in the function. 
* \param dom_map the domain map of Variables(corresponds to root_iter_vars) + * \param rmap The range of variables (not only root_iter_vars) to improve propagation accuracy. * \param out_dom_map The output domain. * The function is only asked to fill the bounds for Tensors that * is already in the out_dom_map From d04348865e50fd8e96329649cd6ba91eaab66267 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 19 Apr 2020 00:57:44 -0400 Subject: [PATCH 05/13] Update tests and comments --- src/arith/int_set.cc | 6 ++--- tests/python/unittest/test_arith_intset.py | 1 + ...test_te_schedule_bound_inference_tiling.py | 27 +++++-------------- 3 files changed, 11 insertions(+), 23 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 27a747a49533..8eb0b6ac3158 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -312,9 +312,9 @@ inline IntervalSet Combine(Analyzer* analyzer, } if (analyzer->CanProveGreaterEqual(divisor, 0)) { if (const auto* ptr = b->min_value.as()) { - // a % b = a - b * (a/b) if - // (a) a_max - a_min < b, i.e. that before mod, a's range doesn't cover [0, b) - // and (b) a_min % b <= a_max % b, i.e. that a's range is still continuous after mod + // a mod b = a - b * (a/b) if + // (i) a_max - a_min < b, i.e. that before mod, a's range doesn't cover [0, b) + // and (ii) a_min mod b <= a_max mod b, i.e. 
that a's range is still continuous after mod auto tmax = a->max_value - b->min_value * floordiv(a->max_value, b->min_value); tmax = analyzer->Simplify(tmax); auto tmin = a->min_value - b->min_value * floordiv(a->min_value, b->min_value); diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 1ac64edfc1a6..00b7d9ca646e 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -91,6 +91,7 @@ def test_mod(): flm = tvm.te.floormod ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9)) ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 5)}, (3, 5)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(13, 15)}, (3, 5)) ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9)) ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9)) diff --git a/tests/python/unittest/test_te_schedule_bound_inference_tiling.py b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py index fbadf995d32d..3893bb6befda 100644 --- a/tests/python/unittest/test_te_schedule_bound_inference_tiling.py +++ b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py @@ -18,51 +18,38 @@ from tvm import te def test_bound_tile_mod(): - def compute(M_tiles, N_tiles, L, factor, dtype): + def compute(M_tiles, N_tiles, factor, dtype): # Algo M = M_tiles * factor N = N_tiles * factor - A = tvm.te.placeholder((M, L), name='A', dtype=dtype) + A = tvm.te.placeholder((N, M), name='A', dtype=dtype) C = tvm.te.compute((N, M), lambda n, m: A[n, m], name='C') s = tvm.te.create_schedule(C.op) return s, A, C - def schedule(s, L, factor, padding, A, C): - switch = True + def schedule(s, factor, padding, A, C): C_local = s.cache_write(C, "local") - n, m = C.op.axis - if switch: - bn, bm, ni, mi = s[C].tile(n, m, factor, factor) - else: - ni, mi = C.op.axis + bn, bm, ni, mi = s[C].tile(n, m, factor, factor) nio, nii = s[C].split(ni, 2) n = s[C].fuse(nii, mi) C_shared = 
s.cache_write(C, "shared") - if switch: - bn, bm, ni, mi = C_shared.op.axis - else: - ni, mi = C_shared.op.axis - + bn, bm, ni, mi = C_shared.op.axis s[C_shared].storage_align(ni, factor * 2, padding) n, m = s[C].op.axis bn, bm, ni, mi = s[C].tile(n, m, factor, factor) s[C].set_scope("global") niio, niii = s[C].split(ni, 32) - - s[C_shared].compute_at(s[C], niio) - return s - s, A, C = compute(2, 2, 64, 128, "float16") - s = schedule(s, 64, 128, 8, A, C) - + s, A, C = compute(2, 2, 128, "float16") + s = schedule(s, 128, 8, A, C) bounds = tvm.te.schedule.InferBound(s) check = (bounds[s.stages[2].op.axis[2]].extent == 16) if(not check): From ca883d565cec19ac97bf5bfb68386d91d25e5b3f Mon Sep 17 00:00:00 2001 From: root Date: Wed, 22 Apr 2020 03:58:41 -0400 Subject: [PATCH 06/13] Address review comments: use analyzer->int_set instead of EvalSet. --- include/tvm/arith/analyzer.h | 8 +-- include/tvm/arith/int_set.h | 51 ++++++------------ include/tvm/te/operation.h | 8 --- src/arith/analyzer.cc | 20 +++---- src/arith/const_int_bound.cc | 8 +-- src/arith/int_set.cc | 75 +++++++-------------------- src/te/operation/compute_op.cc | 13 ++--- src/te/operation/extern_op.cc | 1 - src/te/operation/hybrid_op.cc | 1 - src/te/operation/placeholder_op.cc | 1 - src/te/operation/scan_op.cc | 1 - src/te/operation/tensor_compute_op.cc | 1 - src/te/operation/tensorize.cc | 3 +- src/te/schedule/bound.cc | 17 +++--- src/te/schedule/message_passing.cc | 12 +++-- 15 files changed, 78 insertions(+), 142 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 100c135bc274..258fb47529d5 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -139,7 +139,7 @@ class ConstIntBoundAnalyzer { * \param var The variable. * \param range The range we bind to. 
*/ - TVM_DLL void Bind(const Var& var, const Range& range); + TVM_DLL void Bind(const Var& var, const Range& range, bool override = false); private: friend class Analyzer; @@ -412,7 +412,7 @@ class TVM_DLL Analyzer { * \param var The variable. * \param expr The expression we bind to. */ - void Bind(const Var& var, const PrimExpr& expr); + void Bind(const Var& var, const PrimExpr& expr, bool override = false); /*! * \brief Notify all the sub-analyzers that var * is created and binded to a range. @@ -422,13 +422,13 @@ class TVM_DLL Analyzer { * \param var The variable. * \param range The range we bind to. */ - void Bind(const Var& var, const Range& range); + void Bind(const Var& var, const Range& range, bool override = false); /*! * \brief Bind all the vars in the Map * * \param variables The {variable -> range} map. */ - void Bind(const Map& variables); + void Bind(const Map& variables, bool override = false); /*! * \brief Whether can we prove expr >= val. diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 450558062c25..b253ca636f52 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -152,6 +152,22 @@ class IntSet : public ObjectRef { //----------------------------------------------- // Integer set legacy API. //------------------------------------------------ + /*! + * \brief Convert std::unordered_map to Map + * + * \param dom_map The domain map to convert. + * \return The converted map. + */ +Map ConvertDomMap(const std::unordered_map& dom_map); +// /*! +// * \brief Find an symbolic integer set that contains all possible values of +// * e given the domain of each iteration variables. +// * +// * \param e The expression to be evaluated. +// * \param dom_map The domain of each variable. +// * \return An integer set that can cover all the possible values of e. +// */ +// IntSet EvalSet(PrimExpr e, const Map& dom_map); /*! 
* \brief Find an symbolic integer set that contains all possible values of * e given the domain of each iteration variables. @@ -160,8 +176,7 @@ class IntSet : public ObjectRef { * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(PrimExpr e, - const Map& dom_map); +IntSet EvalSet(PrimExpr e, const Map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map * @@ -171,27 +186,6 @@ IntSet EvalSet(PrimExpr e, */ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); -/*! - * \brief Same as EvalSet, but takes unordered_map - * - * \param e The expression to be evaluated. - * \param dom_map The domain of each variable. - * \param rmap The range of each variable. - * \return An integer set that can cover all the possible values of e. - */ -IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map, - const std::unordered_map& rmap); -/*! - * \brief Same as EvalSet, but takes unordered_map - * - * \param e The expression to be evaluated. - * \param dom_map The domain of each variable. - * \param rmap The range of each variable. - * \return An integer set that can cover all the possible values of e. - */ -IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map, - const Map& rmap); - /*! * \brief Find an symbolic integer set that contains is union over * all the possible conditional values in dom_map. @@ -222,17 +216,6 @@ IntSet EvalSet(IntSet s, */ IntSet EvalSet(Range r, const std::unordered_map& dom_map); -/*! - * \brief Same as EvalSet, but takes unordered_map - * - * \param r The range to be evaluated. - * \param dom_map The domain of each variable. - * \param rmap The range of each variable. - * \return An integer set that can cover all the possible values of e. - */ -IntSet EvalSet(Range r, const std::unordered_map& dom_map, - const std::unordered_map& rmap); - /*! \brief Map from Expr to IntSet */ using ExprIntSetMap = std::unordered_map; /*! 
diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 7d88216bcd19..205589928f01 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -104,7 +104,6 @@ class OperationNode : public tir::FunctionBaseNode { * \param self The reference to self. * \param analyzer The analyzer to be used in the function. * \param dom_map the domain map of Variables(corresponds to root_iter_vars) - * \param rmap The range of variables (not only root_iter_vars) to improve propagation accuracy. * \param out_dom_map The output domain. * The function is only asked to fill the bounds for Tensors that * is already in the out_dom_map @@ -113,7 +112,6 @@ class OperationNode : public tir::FunctionBaseNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const = 0; /*! * \brief Gather the bound from output tensor. @@ -178,7 +176,6 @@ class PlaceholderOpNode : public OperationNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, @@ -257,7 +254,6 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const final; Stmt BuildProvide( const Stage& stage, @@ -311,7 +307,6 @@ class TensorComputeOpNode : public BaseComputeOpNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const final; Stmt BuildProvide( const Stage& stage, @@ -387,7 +382,6 @@ class ScanOpNode : public OperationNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const 
final; void GatherBound( const Operation& self, @@ -455,7 +449,6 @@ class ExternOpNode : public OperationNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, @@ -524,7 +517,6 @@ class HybridOpNode : public OperationNode { const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 442b2b5516b3..9199bace4997 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -36,31 +36,31 @@ Analyzer::Analyzer() int_set(this) { } -void Analyzer::Bind(const Var& var, const PrimExpr& expr) { +void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) { PrimExpr new_expr = expr; new_expr = this->canonical_simplify(new_expr); new_expr = this->rewrite_simplify(new_expr); - this->const_int_bound.Update(var, this->const_int_bound(new_expr)); - this->modular_set.Update(var, this->modular_set(new_expr)); - this->rewrite_simplify.Update(var, new_expr); - this->canonical_simplify.Update(var, new_expr); + this->const_int_bound.Update(var, this->const_int_bound(new_expr), override); + this->modular_set.Update(var, this->modular_set(new_expr), override); + this->rewrite_simplify.Update(var, new_expr, override); + this->canonical_simplify.Update(var, new_expr, override); } -void Analyzer::Bind(const Var& var, const Range& range) { +void Analyzer::Bind(const Var& var, const Range& range, bool override) { CHECK(range.defined()); if (tir::is_one(range->extent)) { - this->Bind(var, range->min); + this->Bind(var, range->min, override); } else { - this->const_int_bound.Bind(var, range); + this->const_int_bound.Bind(var, range, override); } // skip modular_set // skip rewrite simplify } -void 
Analyzer::Bind(const Map& variables) { +void Analyzer::Bind(const Map& variables, bool override) { for (const auto& iter : variables) { - this->Bind(iter.first, iter.second); + this->Bind(iter.first, iter.second, override); } } diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 69aa85e86e25..bb7c3dde17e0 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -99,13 +99,13 @@ class ConstIntBoundAnalyzer::Impl : } }; - void Bind(const Var& var, const Range& range) { + void Bind(const Var& var, const Range& range, bool override) { Entry a = VisitExpr(range->min); Entry b = VisitExpr(range->extent); Entry ret; ret.min_value = a.min_value; ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1)); - Update(var, ret, false); + Update(var, ret, override); } void Update(const Var& var, @@ -576,8 +576,8 @@ void ConstIntBoundAnalyzer::Update(const Var& var, impl_->Update(var, info, override); } -void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) { - impl_->Bind(var, range); +void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) { + impl_->Bind(var, range, override); } std::function ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) { diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 8eb0b6ac3158..0687c04636d5 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -45,6 +45,17 @@ PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) { auto node = make_object(); + if (!min_value.same_as(max_value) && min_value->IsInstance() && + max_value->IsInstance()) { + const auto* min_ptr = min_value.as(); + const auto* max_ptr = max_value.as(); + if (min_ptr->value == max_ptr->value) { + node->min_value = std::move(min_value); + node->max_value = node->min_value; + data_ = node; + return; + } + } node->min_value = std::move(min_value); node->max_value = 
std::move(max_value); data_ = std::move(node); @@ -395,16 +406,6 @@ class IntervalSetEvaluator : auto min_bd = analyzer_->const_int_bound(min_value_expr); auto max_bd = analyzer_->const_int_bound(max_value_expr); if (min_bd->max_value == min_bd->min_value && max_bd->max_value == max_bd->min_value) { - const auto* min_ptr = result->min_value.as(); - const auto* max_ptr = result->max_value.as(); - // The following if statement is necessary. When result is a single point of IntImm, such as - // [0, 0], both 0s refer the same ObjectRef. We really don't want to create a new [0, 0] - // IntervalSet and have 0s refer two different ObjectRef. They will confuse APIs, such as - // IntervalSetEvaluator::MatchPoint() and IntervalSetNode::IsSinglePoint(). - if (min_ptr && max_ptr && min_bd->min_value == min_ptr->value && - max_bd->max_value == max_ptr->value) { - return result; - } return IntervalSet(static_cast(min_bd->min_value), static_cast(max_bd->max_value)); } return result; @@ -771,28 +772,10 @@ Map ConvertDomMap( return dmap; } -IntSet EvalSet(PrimExpr e, const Map& dom_map, - const std::unordered_map& rmap) { - Analyzer ana; - // Bind ana with rmap - for (auto entry : rmap) { - ana.Bind(entry.first->var, entry.second); - } - return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); -} - -IntSet EvalSet(PrimExpr e, const Map& dom_map, const Map& rmap) { - Analyzer ana; - // Bind ana with rmap - for (auto entry : rmap) { - ana.Bind(entry.first->var, entry.second); - } - return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); -} - IntSet EvalSet(PrimExpr e, const Map& dom_map) { - return EvalSet(e, dom_map, std::unordered_map()); + Analyzer ana; + return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); } IntSet IntSet::vector(PrimExpr x) { @@ -806,27 +789,14 @@ IntSet EvalSet(PrimExpr e, return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map) { +IntSet EvalSet(PrimExpr e, + const std::unordered_map& dom_map) { 
return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map, - const std::unordered_map& rmap) { - return EvalSet(e, ConvertDomMap(dom_map), rmap); -} - -IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map, - const Map& rmap) { - return EvalSet(e, ConvertDomMap(dom_map), rmap); -} - -IntSet EvalSet(Range r, const Map& dom_map, - const std::unordered_map& rmap) { +IntSet EvalSet(Range r, + const Map& dom_map) { Analyzer ana; - // Bind ana with rmap - for (auto entry : rmap) { - ana.Bind(entry.first->var, entry.second); - } IntervalSetEvaluator m(&ana, dom_map); // Simplifying first can give tighter bounds if r->min and r->extent share variables PrimExpr sum = r->min + r->extent - 1; @@ -835,19 +805,10 @@ IntSet EvalSet(Range r, const Map& dom_map, } IntSet EvalSet(Range r, - const Map& dom_map) { - return EvalSet(r, dom_map, std::unordered_map()); -} - -IntSet EvalSet(Range r, const std::unordered_map& dom_map) { + const std::unordered_map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } -IntSet EvalSet(Range r, const std::unordered_map& dom_map, - const std::unordered_map& rmap) { - return EvalSet(r, ConvertDomMap(dom_map), rmap); -} - IntSet EvalSet(IntSet s, const std::unordered_map& dom_map) { Analyzer ana; diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 25194dd3014a..8ff67e8ee007 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -218,10 +218,9 @@ void ComputeOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); - auto fvisit = [&dom_map, &rmap, out_dom_map, analyzer](const ObjectRef& n) { + auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) { auto *call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = 
Downcast(call->func).output(call->value_index); @@ -232,7 +231,7 @@ void ComputeOpNode::PropBoundToInputs( // undefined behaviour), so we can intersect the estimated set of the argument with the // range expected by the tensor. However, intersection may result in overly complex // expressions, so we perform a more relaxed form of intersection. - IntSet arg_intset = EvalSet(call->args[i], dom_map, rmap); + IntSet arg_intset = analyzer->int_set(call->args[i], ConvertDomMap(dom_map)); const arith::IntervalSetNode* arg_interval = arg_intset.as(); if (arg_interval) { PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype()); @@ -240,12 +239,10 @@ void ComputeOpNode::PropBoundToInputs( PrimExpr min_value = arg_interval->min_value; PrimExpr max_value = arg_interval->max_value; // Prefer the shape bounds only when we can prove they are tighter. - if (arith::is_neg_inf(min_value) || - analyzer->CanProve(shape_i_min_value >= min_value)) { + if (arith::is_pos_inf(max_value) || arith::is_neg_inf(min_value) || + (analyzer->CanProve(shape_i_min_value >= min_value) && + analyzer->CanProve(shape_i_max_value <= max_value))) { min_value = shape_i_min_value; - } - if (arith::is_pos_inf(max_value) || - analyzer->CanProve(shape_i_max_value <= max_value)) { max_value = shape_i_max_value; } dom.data[i].push_back(IntSet::interval(min_value, max_value)); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 9cc6fb0297b7..9d95e329c8f2 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -120,7 +120,6 @@ void ExternOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const { for (Tensor t : this->inputs) { auto it = out_dom_map->find(t); diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index fc9646275146..4da127ea0a85 100644 --- a/src/te/operation/hybrid_op.cc +++ 
b/src/te/operation/hybrid_op.cc @@ -136,7 +136,6 @@ void HybridOpNode::PropBoundToInputs( const Operation &self, arith::Analyzer* analyzer, const std::unordered_map &dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const { auto curr_inputs = InputTensors(); for (Tensor t : curr_inputs) { diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 1efb6909209a..d48be4c53668 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -87,7 +87,6 @@ void PlaceholderOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const { } diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 5974af8a2945..1916b4a4823e 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -183,7 +183,6 @@ void ScanOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) { diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 5e41fc7e2354..4cdc9e1f8d32 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -116,7 +116,6 @@ void TensorComputeOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - const std::unordered_map& rmap, std::unordered_map* out_dom_map) const { for (size_t i = 0; i < this->inputs.size(); ++i) { Tensor t = this->inputs[i]; diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 65514b833b06..b66406969c76 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -99,8 +99,7 @@ size_t InferTensorizeRegion( 
temp_dmap[iv->var.get()] = iset; } // Input domains - self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, std::unordered_map(), - &in_dom); + self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom); Range none; for (const auto& kv : in_dom) { Array vec; diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc index 15023ebbfbcb..7ccf0942a820 100644 --- a/src/te/schedule/bound.cc +++ b/src/te/schedule/bound.cc @@ -138,7 +138,7 @@ void InferRootBound(const Stage& stage, Array stage_attach = ctx.attach_path.at(stage->op); // The parent set. for (const Operation& op : consumers) { - std::unordered_map relax_set; + Map relax_set; std::unordered_map up_state; bool found_attach = false; CHECK(ctx.op2stage_.count(op.get())); @@ -177,9 +177,9 @@ void InferRootBound(const Stage& stage, << "InferBound requires every leaf iter var's min equals 0, " << "call schedule.normalize to achieve this."; if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) { - relax_set[iv->var.get()] = IntSet::range(vrange); + relax_set.Set(iv->var, IntSet::range(vrange)); if (ctx.bind_map.count(iv)) { - relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange); + relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::range(vrange)); } } } @@ -191,6 +191,9 @@ void InferRootBound(const Stage& stage, // Relax if needed. 
std::unordered_map dom_map; arith::Analyzer analyzer; + for (auto entry : *rmap) { + analyzer.Bind(entry.first->var, entry.second); + } for (auto iv : op->root_iter_vars()) { Range r; if (up_state.count(iv)) { @@ -199,13 +202,15 @@ void InferRootBound(const Stage& stage, r = iv->dom; } if (relax_set.size() != 0) { - dom_map[iv->var.get()] = EvalSet(r, relax_set, *rmap); + dom_map[iv->var.get()] = IntSet::interval( + analyzer.int_set(r->min, relax_set).min(), + analyzer.int_set(r->min + r->extent - 1, relax_set).max()); } else { dom_map[iv->var.get()] = IntSet::range(r); } - analyzer.Bind(iv->var, r); + analyzer.Bind(iv->var, r, true); } - op->PropBoundToInputs(op, &analyzer, dom_map, *rmap, &tmap); + op->PropBoundToInputs(op, &analyzer, dom_map, &tmap); } stage->op->GatherBound(stage->op, tmap, rmap); } diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index fc8b2ab628b2..d16f84bf12c5 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -580,11 +580,15 @@ std::vector MakeBoundCheck( PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer); std::vector preds; - std::unordered_map iset_dmap; + Map iset_dmap; // setup domain map for set analysis for (const auto& kv : dom_map) { - iset_dmap[kv.first->var.get()] = IntSet::range(kv.second); + iset_dmap.Set(kv.first->var, IntSet::range(kv.second)); + } + + for (auto entry : dom_map) { + analyzer.Bind(entry.first->var, entry.second); } for (const IterVar& iv : stage->all_iter_vars) { @@ -592,7 +596,7 @@ std::vector MakeBoundCheck( if (bound_state.at(iv)) { Range dom = dom_map.at(iv); PrimExpr value = value_map.at(iv) - dom->min; - PrimExpr vmax = EvalSet(value, iset_dmap).max(); + PrimExpr vmax = analyzer.int_set(value, iset_dmap).max(); if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) { preds.emplace_back(value < dom->extent); } @@ -604,7 +608,7 @@ std::vector MakeBoundCheck( CHECK(iv->dom.defined()); if 
(!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) { PrimExpr value = value_map.at(iv) - iv->dom->min; - IntSet s = EvalSet(value, iset_dmap, dom_map); + IntSet s = analyzer.int_set(value, iset_dmap); PrimExpr vmin = s.min(); PrimExpr vmax = s.max(); // The range of `value` resides in [vmin, vmax] From a8cd3c968dc599e23c0b1fb5a954e782046f3697 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 22 Apr 2020 10:41:13 -0400 Subject: [PATCH 07/13] Update comments. --- include/tvm/arith/analyzer.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 258fb47529d5..4a43c49ca41e 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -138,6 +138,7 @@ class ConstIntBoundAnalyzer { * * \param var The variable. * \param range The range we bind to. + * \param override Whether do we allow override of existing information. */ TVM_DLL void Bind(const Var& var, const Range& range, bool override = false); @@ -411,6 +412,7 @@ class TVM_DLL Analyzer { * * \param var The variable. * \param expr The expression we bind to. + * \param override Whether do we allow override of existing information. */ void Bind(const Var& var, const PrimExpr& expr, bool override = false); /*! @@ -421,12 +423,14 @@ class TVM_DLL Analyzer { * * \param var The variable. * \param range The range we bind to. + * \param override Whether do we allow override of existing information. */ void Bind(const Var& var, const Range& range, bool override = false); /*! * \brief Bind all the vars in the Map * * \param variables The {variable -> range} map. + * \param override Whether do we allow override of existing information. */ void Bind(const Map& variables, bool override = false); /*! From e31ec5d75290d6136524682b39a71e7a67731693 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 22 Apr 2020 12:53:32 -0400 Subject: [PATCH 08/13] Moving IntImm handling to MatchPoint and IsSinglePoint. 
--- src/arith/int_set.cc | 32 ++++++---------------- src/arith/interval_set.h | 7 ++++- tests/python/unittest/test_arith_intset.py | 1 + 3 files changed, 16 insertions(+), 24 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 0687c04636d5..b1ac5bca61c7 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -45,17 +45,6 @@ PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) { auto node = make_object(); - if (!min_value.same_as(max_value) && min_value->IsInstance() && - max_value->IsInstance()) { - const auto* min_ptr = min_value.as(); - const auto* max_ptr = max_value.as(); - if (min_ptr->value == max_ptr->value) { - node->min_value = std::move(min_value); - node->max_value = node->min_value; - data_ = node; - return; - } - } node->min_value = std::move(min_value); node->max_value = std::move(max_value); data_ = std::move(node); @@ -398,17 +387,7 @@ class IntervalSetEvaluator : } IntervalSet Eval(const PrimExpr& val) { - IntervalSet result = this->VisitExpr(val); - // Use the IterVar range info bound to analyzer to further simplify - // and reduce the interval - auto min_value_expr = analyzer_->Simplify(result->min_value); - auto max_value_expr = analyzer_->Simplify(result->max_value); - auto min_bd = analyzer_->const_int_bound(min_value_expr); - auto max_bd = analyzer_->const_int_bound(max_value_expr); - if (min_bd->max_value == min_bd->min_value && max_bd->max_value == max_bd->min_value) { - return IntervalSet(static_cast(min_bd->min_value), static_cast(max_bd->max_value)); - } - return result; + return this->VisitExpr(val); } // evaluate and relax the set IntervalSet Eval(IntervalSet val) { @@ -554,7 +533,14 @@ class IntervalSetEvaluator : // whether set is exactly single point that equals value. 
bool MatchPoint(const IntervalSet& set, const PrimExpr& value) const { - return set->min_value.same_as(value) && set->max_value.same_as(value); + if (set->min_value.same_as(value) && set->max_value.same_as(value)) { + return true; + } + const auto* min_ptr = set->min_value.as(); + const auto* max_ptr = set->max_value.as(); + const auto* value_ptr = value.as(); + return (min_ptr && max_ptr && value_ptr && value_ptr->value == min_ptr->value && + value_ptr->value == max_ptr->value); } template diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index 51b500adb412..7e82abfa8d7c 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -62,7 +62,12 @@ class IntervalSetNode : public IntSetNode { } /*! \return Whether the interval is a single point. */ bool IsSinglePoint() const { - return min_value.same_as(max_value); + if (min_value.same_as(max_value)) { + return true; + } + const auto* min_ptr = min_value.as(); + const auto* max_ptr = max_value.as(); + return (min_ptr && max_ptr && max_ptr->value == min_ptr->value); } /*! \return whether interval represent nothing */ bool IsEmpty() const { diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 00b7d9ca646e..f46c8f8962be 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -94,6 +94,7 @@ def test_mod(): ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(13, 15)}, (3, 5)) ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9)) ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(1, 21)}, (0, 9)) floordiv = tvm.te.floordiv z = te.var("z") From da901cf550a3af93d8e194ccbca0ed0882390994 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 22 Apr 2020 15:44:22 -0400 Subject: [PATCH 09/13] Use CanProve to simplify the addtional comparison in MatchPoint. 
--- src/arith/int_set.cc | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index b1ac5bca61c7..c79b2e599907 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -533,14 +533,9 @@ class IntervalSetEvaluator : // whether set is exactly single point that equals value. bool MatchPoint(const IntervalSet& set, const PrimExpr& value) const { - if (set->min_value.same_as(value) && set->max_value.same_as(value)) { - return true; - } - const auto* min_ptr = set->min_value.as(); - const auto* max_ptr = set->max_value.as(); - const auto* value_ptr = value.as(); - return (min_ptr && max_ptr && value_ptr && value_ptr->value == min_ptr->value && - value_ptr->value == max_ptr->value); + return (set->min_value.same_as(value) && set->max_value.same_as(value)) || + (analyzer_->CanProve(set->min_value == value) && + analyzer_->CanProve(set->max_value == value)); } template From c948f03510a599d70ae6c32a4717142181e159c8 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 22 Apr 2020 17:17:27 -0400 Subject: [PATCH 10/13] Revert the unecessary changes to MatchPoint and IsSinglePoint, as we don't create single point interval set with different min/max. --- src/arith/int_set.cc | 4 +--- src/arith/interval_set.h | 7 +------ 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index c79b2e599907..2b3877ca3c98 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -533,9 +533,7 @@ class IntervalSetEvaluator : // whether set is exactly single point that equals value. 
bool MatchPoint(const IntervalSet& set, const PrimExpr& value) const { - return (set->min_value.same_as(value) && set->max_value.same_as(value)) || - (analyzer_->CanProve(set->min_value == value) && - analyzer_->CanProve(set->max_value == value)); + return set->min_value.same_as(value) && set->max_value.same_as(value); } template diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index 7e82abfa8d7c..51b500adb412 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -62,12 +62,7 @@ class IntervalSetNode : public IntSetNode { } /*! \return Whether the interval is a single point. */ bool IsSinglePoint() const { - if (min_value.same_as(max_value)) { - return true; - } - const auto* min_ptr = min_value.as(); - const auto* max_ptr = max_value.as(); - return (min_ptr && max_ptr && max_ptr->value == min_ptr->value); + return min_value.same_as(max_value); } /*! \return whether interval represent nothing */ bool IsEmpty() const { From 84951aa403f5954f9746ac26cc7112fdba03fa6d Mon Sep 17 00:00:00 2001 From: root Date: Thu, 23 Apr 2020 23:08:37 -0400 Subject: [PATCH 11/13] Update floormod with simpler implementation; update corresponding test; clean unecessary comments. --- include/tvm/arith/int_set.h | 21 ++++++--------------- src/arith/int_set.cc | 21 ++++++++------------- tests/python/unittest/test_arith_intset.py | 2 +- 3 files changed, 15 insertions(+), 29 deletions(-) diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index b253ca636f52..ab73b070f187 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -152,22 +152,13 @@ class IntSet : public ObjectRef { //----------------------------------------------- // Integer set legacy API. //------------------------------------------------ - /*! - * \brief Convert std::unordered_map to Map - * - * \param dom_map The domain map to convert. - * \return The converted map. - */ +/*! 
+ * \brief Convert std::unordered_map to Map + * + * \param dom_map The domain map to convert. + * \return The converted map. + */ Map ConvertDomMap(const std::unordered_map& dom_map); -// /*! -// * \brief Find an symbolic integer set that contains all possible values of -// * e given the domain of each iteration variables. -// * -// * \param e The expression to be evaluated. -// * \param dom_map The domain of each variable. -// * \return An integer set that can cover all the possible values of e. -// */ -// IntSet EvalSet(PrimExpr e, const Map& dom_map); /*! * \brief Find an symbolic integer set that contains all possible values of * e given the domain of each iteration variables. diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 2b3877ca3c98..f4081c7df659 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -311,19 +311,14 @@ inline IntervalSet Combine(Analyzer* analyzer, LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } if (analyzer->CanProveGreaterEqual(divisor, 0)) { - if (const auto* ptr = b->min_value.as()) { - // a mod b = a - b * (a/b) if - // (i) a_max - a_min < b, i.e. that before mod, a's range doesn't cover [0, b) - // and (ii) a_min mod b <= a_max mod b, i.e. 
that a's range is still continuous after mod - auto tmax = a->max_value - b->min_value * floordiv(a->max_value, b->min_value); - tmax = analyzer->Simplify(tmax); - auto tmin = a->min_value - b->min_value * floordiv(a->min_value, b->min_value); - tmin = analyzer->Simplify(tmin); - auto tset = IntervalSet(tmin, tmax); - bool within_range = analyzer->CanProveLess(a->max_value - a->min_value, ptr->value); - bool wrap_around = analyzer->CanProve(tset->max_value < tset->min_value); - if (within_range && !wrap_around) { - return tset; + if (b->min_value.as()) { + // a mod b = a - (a / b) * b if a_max / b == a_min / b + auto qmax = floordiv(a->max_value, b->min_value); + auto qmin = floordiv(a->min_value, b->min_value); + if (analyzer->CanProve(qmax == qmin)) { + auto tmax = a->max_value - b->min_value * qmin; + auto tmin = a->min_value - b->min_value * qmin; + return IntervalSet(tmin, tmax); } } return IntervalSet(make_zero(divisor.dtype()), divisor - 1); diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index f46c8f8962be..9919c7b96cf1 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -100,7 +100,7 @@ def test_mod(): z = te.var("z") ck.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 3)) ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, - (x*4-floordiv(x, 2)*8, x*4+3-floordiv(x*4+3, 8)*8)) + (0, 7)) ck1 = IntSetChecker() ck1.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 2)) ck1.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3)) From 3d12c26bde49c58dafdfba59be2c6f4b8fa9a00d Mon Sep 17 00:00:00 2001 From: root Date: Fri, 24 Apr 2020 11:46:18 -0400 Subject: [PATCH 12/13] Tighten the condition to use shape bounds. 
--- src/te/operation/compute_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 8ff67e8ee007..0f6fa490d57c 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -239,7 +239,7 @@ void ComputeOpNode::PropBoundToInputs( PrimExpr min_value = arg_interval->min_value; PrimExpr max_value = arg_interval->max_value; // Prefer the shape bounds only when we can prove they are tighter. - if (arith::is_pos_inf(max_value) || arith::is_neg_inf(min_value) || + if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) || (analyzer->CanProve(shape_i_min_value >= min_value) && analyzer->CanProve(shape_i_max_value <= max_value))) { min_value = shape_i_min_value; From 6b5fb1830ac26707eb6c1df4e0550799c80bf434 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 24 Apr 2020 21:58:35 -0400 Subject: [PATCH 13/13] Update comments and a variable. --- include/tvm/arith/analyzer.h | 8 ++++---- src/arith/int_set.cc | 10 +++++----- src/te/operation/compute_op.cc | 4 ++++ 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 4a43c49ca41e..c08c0d6d347b 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -138,7 +138,7 @@ class ConstIntBoundAnalyzer { * * \param var The variable. * \param range The range we bind to. - * \param override Whether do we allow override of existing information. + * \param override Whether we allow overriding an existing var's range. */ TVM_DLL void Bind(const Var& var, const Range& range, bool override = false); @@ -412,7 +412,7 @@ class TVM_DLL Analyzer { * * \param var The variable. * \param expr The expression we bind to. - * \param override Whether do we allow override of existing information. + * \param override Whether we allow overriding an existing var's expression. 
*/ void Bind(const Var& var, const PrimExpr& expr, bool override = false); /*! @@ -423,14 +423,14 @@ class TVM_DLL Analyzer { * * \param var The variable. * \param range The range we bind to. - * \param override Whether do we allow override of existing information. + * \param override Whether we allow overriding an existing var's range. */ void Bind(const Var& var, const Range& range, bool override = false); /*! * \brief Bind all the vars in the Map * * \param variables The {variable -> range} map. - * \param override Whether do we allow override of existing information. + * \param override Whether we allow overriding the existing vars' ranges. */ void Bind(const Map& variables, bool override = false); /*! diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index f4081c7df659..d2d43d6a537a 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -311,13 +311,13 @@ inline IntervalSet Combine(Analyzer* analyzer, LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } if (analyzer->CanProveGreaterEqual(divisor, 0)) { - if (b->min_value.as()) { + if (divisor.as()) { // a mod b = a - (a / b) * b if a_max / b == a_min / b - auto qmax = floordiv(a->max_value, b->min_value); - auto qmin = floordiv(a->min_value, b->min_value); + auto qmax = floordiv(a->max_value, divisor); + auto qmin = floordiv(a->min_value, divisor); if (analyzer->CanProve(qmax == qmin)) { - auto tmax = a->max_value - b->min_value * qmin; - auto tmin = a->min_value - b->min_value * qmin; + auto tmax = a->max_value - divisor * qmin; + auto tmin = a->min_value - divisor * qmin; return IntervalSet(tmin, tmax); } } diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 0f6fa490d57c..87d0af344eb8 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -239,6 +239,10 @@ void ComputeOpNode::PropBoundToInputs( PrimExpr min_value = arg_interval->min_value; PrimExpr max_value = arg_interval->max_value; // Prefer the shape bounds only when
we can prove they are tighter. + // We must update both ends of the bound together. Here is a counterexample: shape_i is + // [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range is + // [0, 7]. If we allowed updating one end, the bound would become [threadIdx.y, 0], + // awkward for further analysis. if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) || (analyzer->CanProve(shape_i_min_value >= min_value) && analyzer->CanProve(shape_i_max_value <= max_value))) {