From e9145f9881540ba8a8f658d9273c3e3bfdc3f003 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 3 Aug 2022 10:39:17 -0500 Subject: [PATCH] [Arith][TIR] IntSetAnalyzer, delay intersection of IntSet until use Follow-up from https://github.com/apache/tvm/pull/11970, to improve performance. In the initial implementation, the `analyzer->int_set` would compute the intersection of all scope-based constraints when entering the scope, even if they weren't actually used. This commit delays the call to `Intersect` until required, following the same behavior as `ConstIntBound`. --- src/arith/int_set.cc | 126 ++++++++++++++++++------------------------- 1 file changed, 52 insertions(+), 74 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 35b12bb35238..7d601d9a8bae 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -362,8 +362,13 @@ using namespace tir; // We might use better set analysis in the future to replace the intervalset. class IntervalSetEvaluator : public ExprFunctor { public: - IntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map, bool eval_vec = false) - : analyzer_(analyzer), dom_map_(dom_map), eval_vec_(eval_vec) {} + IntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map, + const std::vector>* dom_constraints = nullptr, + bool eval_vec = false) + : analyzer_(analyzer), + dom_map_(dom_map), + dom_constraints_(dom_constraints), + eval_vec_(eval_vec) {} IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); } // evaluate and relax the set @@ -383,18 +388,40 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet VisitExpr_(const VarNode* op) final { Var var = GetRef(op); + + Array values; + if (dom_constraints_) { + for (const auto& constraint : *dom_constraints_) { + if (var.same_as(constraint.first)) { + values.push_back(constraint.second); + } + } + } + auto it = dom_map_.find(var); if (it != dom_map_.end()) { - IntervalSet res = ToIntervalSet((*it).second); - if (res->min_value.same_as(var) && res->max_value.same_as(var)) { - return res; - } - // recursively evaluate mapped result - // in case the domain contains variables to be relaxed. - return Eval(res); - } else { + values.push_back((*it).second); + } + + if (values.empty()) { return IntervalSet::SinglePoint(var); } + + IntSet intersection = [&]() { + if (values.size() == 1) { + return values.front(); + } else { + return Intersect(values); + } + }(); + + IntervalSet res = ToIntervalSet(intersection); + if (res->min_value.same_as(var) && res->max_value.same_as(var)) { + return res; + } + // recursively evaluate mapped result + // in case the domain contains variables to be relaxed. + return Eval(res); } IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_(op); } @@ -517,6 +544,7 @@ class IntervalSetEvaluator : public ExprFunctor { // analyzer Analyzer* analyzer_; const Map& dom_map_; + const std::vector>* dom_constraints_; bool eval_vec_{false}; }; @@ -529,7 +557,7 @@ class IntSetAnalyzer::Impl { } IntSet Eval(const PrimExpr& expr) const { - return IntervalSetEvaluator(analyzer_, GetCurrentBounds(), true).Eval(expr); + return IntervalSetEvaluator(analyzer_, dom_map_, &dom_constraints_, true).Eval(expr); } void Bind(const Var& var, const Range& range, bool allow_override) { @@ -541,10 +569,6 @@ class IntSetAnalyzer::Impl { std::function EnterConstraint(const PrimExpr& constraint); private: - // Get the current variable bounds, including both global bounds and - // scope-dependent bounds. - Map GetCurrentBounds() const; - // Utility function to split a boolean condition into the domain // bounds implied by that condition. static std::vector> DetectBoundInfo(const PrimExpr& cond); @@ -556,9 +580,11 @@ class IntSetAnalyzer::Impl { // ranges) Map dom_map_; - // Map of variables to implicit scope-dependent bounds (e.g. inside - // the body of an if-statement) - Map constraints_; + // List of implicit scope-dependent bounds (e.g. inside the body of + // an if-statement). Maintained as a list of constraints, rather + // than as a `Map`, to avoid computing an Intersection + // until required. + std::vector> dom_constraints_; }; IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} @@ -603,29 +629,6 @@ void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_o Update(var, Eval(expr), can_override); } -Map IntSetAnalyzer::Impl::GetCurrentBounds() const { - // If either constraints_ or dom_map_ is empty, return the other to - // avoid constructing a new map. - if (constraints_.empty()) { - return dom_map_; - } else if (dom_map_.empty()) { - return constraints_; - } - - // If neither is empty, construct a merged domain map with - // information from both sources. - Map merged = dom_map_; - for (const auto& pair : constraints_) { - auto it = merged.find(pair.first); - if (it == merged.end()) { - merged.Set(pair.first, pair.second); - } else { - merged.Set(pair.first, Intersect({pair.second, (*it).second})); - } - } - return merged; -} - std::vector> IntSetAnalyzer::Impl::DetectBoundInfo( const PrimExpr& constraint) { PVar x; @@ -665,41 +668,16 @@ std::function IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint } std::function IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& constraint) { - Map cached_values; - auto bounds = DetectBoundInfo(constraint); if (bounds.size() == 0) return nullptr; - // Collect the current values of each var that is changes by this - // constraint. - for (const auto& pair : bounds) { - auto it = constraints_.find(pair.first); - if (it == constraints_.end()) { - cached_values.Set(pair.first, IntSet()); - } else { - cached_values.Set(pair.first, (*it).second); - } - } - - // Update all constraints - for (const auto& pair : bounds) { - auto it = constraints_.find(pair.first); - if (it == constraints_.end()) { - constraints_.Set(pair.first, pair.second); - } else { - constraints_.Set(pair.first, Intersect({pair.second, (*it).second})); - } - } - - auto frecover = [cached_values, this]() { - for (const auto& it : cached_values) { - if (it.second.defined()) { - constraints_.Set(it.first, it.second); - } else { - constraints_.erase(it.first); - } - } + size_t old_size = dom_constraints_.size(); + dom_constraints_.insert(dom_constraints_.end(), bounds.begin(), bounds.end()); + size_t new_size = dom_constraints_.size(); + auto frecover = [old_size, new_size, this]() { + ICHECK_EQ(dom_constraints_.size(), new_size); + dom_constraints_.resize(old_size); }; return frecover; } @@ -960,13 +938,13 @@ Map ConvertDomMap(const std::unordered_map& IntSet EvalSet(PrimExpr e, const Map& dom_map) { Analyzer ana; - return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); + return IntervalSetEvaluator(&ana, dom_map, {}, false).Eval(e); } IntSet IntSet::Vector(PrimExpr x) { Analyzer ana; Map dmap; - return IntervalSetEvaluator(&ana, dmap, true).Eval(x); + return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); } IntSet EvalSet(PrimExpr e, const Map& dom_map) {