From 8bc712d0d076383471986377d798d8718ff59eae Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 29 Jun 2022 14:37:35 -0500 Subject: [PATCH 01/14] [Arith] Allow binding of Var in IntSetAnalyzer The other four subanalyzers in `arith::Analyzer` can each be provided with variable bindings/constraints that are remembered internally. This adds the same capability to `IntSetAnalyzer`, rather than requiring users to independently track and maintain a `Map` containing the domain of each variable, and applies bindings/constraints alongside the other subanalyzers. --- include/tvm/arith/analyzer.h | 38 ++++++++-- src/arith/analyzer.cc | 6 +- src/arith/int_set.cc | 131 +++++++++++++++++++++++++++++++++++ 3 files changed, 170 insertions(+), 5 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 3704eff33ec2..828ef22ca1b8 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -135,7 +135,7 @@ class ConstIntBoundAnalyzer { * * \param var The variable of interest. * \param info The bound information. - * \param allow_override Whether do we allow override of existing information. + * \param allow_override whether we allow override of existing information. */ TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false); /*! @@ -224,7 +224,7 @@ class ModularSetAnalyzer { * * \param var The variable of interest. * \param info The bound information. - * \param allow_override Whether do we allow override of existing information. + * \param allow_override whether we allow override of existing information. */ TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false); @@ -263,7 +263,7 @@ class RewriteSimplifier { * * \param var The variable of interest. * \param new_expr - * \param allow_override Whether do we allow override of existing information. + * \param allow_override Whether we allow override of existing information. */ TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false); @@ -297,7 +297,7 @@ class CanonicalSimplifier { * * \param var The variable of interest. * \param new_expr - * \param allow_override Whether do we allow override of existing information. + * \param allow_override whether we allow override of existing information. */ TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false); @@ -365,6 +365,36 @@ class IntSetAnalyzer { */ TVM_DLL IntSet operator()(const PrimExpr& expr, const Map& dom_map); + /*! + * \brief Find a symbolic integer set that contains all possible + * values of expr given the domain of each variables, using + * the domain map defined by bound variables. + * + * \param expr The expression of interest. + * \return the result of the analysis. + */ + TVM_DLL IntSet operator()(const PrimExpr& expr); + + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param new_interval_set The set of allowed values for this var. + * \param allow_override whether we allow override of existing information. + */ + TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool allow_override = false); + + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param new_range The range of allowed values for this var. + * \param allow_override whether we allow override of existing information. + */ + TVM_DLL void Update(const Var& var, const Range& new_range, bool allow_override = false); + + std::function EnterConstraint(const PrimExpr& constraint); + private: friend class Analyzer; explicit IntSetAnalyzer(Analyzer* parent); diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index b922138057e9..7158ed6c657a 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -44,6 +44,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { this->modular_set.Update(var, this->modular_set(new_expr), allow_override); this->rewrite_simplify.Update(var, new_expr, allow_override); this->canonical_simplify.Update(var, new_expr, allow_override); + this->int_set.Update(var, this->int_set(new_expr), allow_override); } void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { @@ -52,6 +53,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { this->Bind(var, range->min, allow_override); } else { this->const_int_bound.Bind(var, range, allow_override); + this->int_set.Update(var, range, allow_override); } // skip modular_set // skip rewrite simplify @@ -69,8 +71,10 @@ void ConstraintContext::EnterWithScope() { auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_); auto f1 = analyzer_->modular_set.EnterConstraint(constraint_); auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_); + auto f3 = analyzer_->int_set.EnterConstraint(constraint_); // recovery function. - exit_ = [f0, f1, f2]() { + exit_ = [f0, f1, f2, f3]() { + if (f3 != nullptr) f3(); if (f2 != nullptr) f2(); if (f1 != nullptr) f1(); if (f0 != nullptr) f0(); diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 48fae479b042..1ccc5fbff097 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -31,6 +31,7 @@ #include #include +#include "constraint_extract.h" #include "interval_set.h" #include "pattern_match.h" @@ -509,8 +510,26 @@ class IntSetAnalyzer::Impl { return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); } + IntSet Eval(const PrimExpr& expr) const { + return IntervalSetEvaluator(analyzer_, dom_map_).Eval(expr); + } + + void Bind(const Var& var, const Range& range, bool allow_override) { + IntSet min = Eval(range->min); + IntSet extent = Eval(range->extent); + + Bind(var, IntervalSet(min.min(), min.max() + extent.max() - 1), allow_override); + } + + void Bind(const Var& var, const IntSet& info, bool override_info); + void Bind(const Var& var, const PrimExpr& expr, bool override_info); + std::function EnterConstraint(const PrimExpr& constraint); + private: + static std::vector> DetectBoundInfo(const PrimExpr& cond); + Analyzer* analyzer_; + Map dom_map_; }; IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} @@ -521,6 +540,118 @@ IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map& return impl_->Eval(expr, dom_map); } +IntSet IntSetAnalyzer::operator()(const PrimExpr& expr) { return impl_->Eval(expr); } + +void IntSetAnalyzer::Update(const Var& var, const IntSet& info, bool allow_override) { + impl_->Bind(var, info, allow_override); +} + +void IntSetAnalyzer::Update(const Var& var, const Range& range, bool allow_override) { + impl_->Bind(var, range, allow_override); +} + +void IntSetAnalyzer::Impl::Bind(const Var& var, const IntSet& info, bool can_override) { + if (!can_override) { + auto it = dom_map_.find(var); + if (it != dom_map_.end()) { + const IntSet& old_info = (*it).second; + + ICHECK(ExprDeepEqual()(old_info.min(), info.min())) + << "Trying to update var \'" << var << "\'" + << " with a different minimum value: " + << "original=" << old_info.min() << ", new=" << info.min(); + + ICHECK(ExprDeepEqual()(old_info.max(), info.max())) + << "Trying to update var \'" << var << "\'" + << " with a different maximum value: " + << "original=" << old_info.max() << ", new=" << info.max(); + } + } + dom_map_.Set(var, info); +} + +void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_override) { + Bind(var, Eval(expr), can_override); +} + +std::vector> IntSetAnalyzer::Impl::DetectBoundInfo( + const PrimExpr& constraint) { + PVar x; + PVar limit; + + std::vector> bounds; + for (const PrimExpr& subconstraint : ExtractConstraints(constraint)) { + if ((x <= limit).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval())}); + } else if ((x < limit).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval() - 1)}); + } else if ((x >= limit).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(), SymbolicLimits::pos_inf_)}); + } else if ((x > limit).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1, SymbolicLimits::pos_inf_)}); + } else if ((x == limit).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())}); + } + + if ((limit >= x).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval())}); + } else if ((limit > x).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval() - 1)}); + } else if ((limit <= x).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(), SymbolicLimits::pos_inf_)}); + } else if ((limit < x).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1, SymbolicLimits::pos_inf_)}); + } else if ((limit == x).Match(subconstraint)) { + bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())}); + } + } + return bounds; +} + +std::function IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint) { + return impl_->EnterConstraint(constraint); +} + +std::function IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& constraint) { + Map cached_values; + + auto bounds = DetectBoundInfo(constraint); + + if (bounds.size() == 0) return nullptr; + + // Collect the current values of each var that is changes by this + // constraint. + for (const auto& pair : bounds) { + auto it = dom_map_.find(pair.first); + if (it == dom_map_.end()) { + cached_values.Set(pair.first, IntSet()); + } else { + cached_values.Set(pair.first, (*it).second); + } + } + + // Update all constraints + for (const auto& pair : bounds) { + auto it = dom_map_.find(pair.first); + if (it == dom_map_.end()) { + dom_map_.Set(pair.first, pair.second); + } else { + dom_map_.Set(pair.first, Intersect({pair.second, (*it).second})); + } + } + + auto frecover = [cached_values, this]() { + for (const auto& it : cached_values) { + if (it.second.defined()) { + dom_map_.Set(it.first, it.second); + } else { + dom_map_.erase(it.first); + } + } + }; + return frecover; +} + // Quickly adapt to IntSet interface // TODO(tqchen): revisit IntSet interface as well. Range IntSet::CoverRange(Range max_range) const { From 42a7474fcfc0fee1d8332b4a340159259fc91b5f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 29 Jun 2022 14:52:10 -0500 Subject: [PATCH 02/14] [Arith] Updated IRVisitorWithAnalyzer to mimic IRMutatorWithAnalyzer Previously, `IRVisitorWithAnalyzer` did not allow subclassing, and could only be used to collect bounds of variables along an entire statement, and could not be used to perform scope-dependent analysis. This commit removes `final` from `IRVisitorWithAnalyzer` and provides the same scope-based constraints/bindings during iteration as are provided by `IRMutatorWithAnalyzer`. --- src/arith/ir_visitor_with_analyzer.cc | 140 ++++++++++++++++++++++++++ src/arith/ir_visitor_with_analyzer.h | 40 +++----- 2 files changed, 156 insertions(+), 24 deletions(-) create mode 100644 src/arith/ir_visitor_with_analyzer.cc diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc new file mode 100644 index 000000000000..fb5d15c0d55c --- /dev/null +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/arith/ir_visitor_with_analyzer.cc + */ +#include "ir_visitor_with_analyzer.h" + +#include +#include +#include + +namespace tvm { +namespace tir { + +using namespace arith; + +void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) { + analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + StmtExprVisitor::VisitStmt_(op); +} + +void IRVisitorWithAnalyzer::VisitStmt_(const BlockNode* op) { + for (const auto& iter_var : op->iter_vars) { + analyzer_.Bind(iter_var->var, iter_var->dom); + } + StmtExprVisitor::VisitStmt_(op); +} + +void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { + this->VisitExpr(op->value); + analyzer_.Bind(op->var, op->value); + this->VisitStmt(op->body); +} + +void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { + this->VisitExpr(op->condition); + + PrimExpr real_condition = ExtractRealCondition(op->condition); + + { + With constraint(&analyzer_, real_condition); + this->VisitStmt(op->then_case); + } + if (op->else_case.defined()) { + With constraint(&analyzer_, analyzer_.rewrite_simplify(Not(real_condition))); + this->VisitStmt(op->else_case); + } +} + +void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { + IterVar iv = Downcast(op->node); + ICHECK_NE(iv->thread_tag.length(), 0U); + analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value)); + } + StmtExprVisitor::VisitStmt_(op); +} + +void IRVisitorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { + this->VisitExpr(op->condition); + this->VisitExpr(op->message); + With constraint(&analyzer_, op->condition); + this->VisitStmt(op->body); +} + +void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) { + // add condition context to if_then_else + static auto op_if_then_else = Op::Get("tir.if_then_else"); + if (op->op.same_as(op_if_then_else)) { + PrimExpr cond = op->args[0]; + this->VisitExpr(op->args[0]); + { + With constraint(&analyzer_, cond); + this->VisitExpr(op->args[1]); + } + { + With constraint(&analyzer_, analyzer_.rewrite_simplify(Not(cond))); + this->VisitExpr(op->args[2]); + } + } else { + StmtExprVisitor::VisitExpr_(op); + } +} + +void IRVisitorWithAnalyzer::VisitExpr_(const LetNode* op) { + this->VisitExpr(op->value); + analyzer_.Bind(op->var, op->value); + this->VisitExpr(op->body); +} + +void IRVisitorWithAnalyzer::VisitExpr_(const SelectNode* op) { + this->VisitExpr(op->condition); + + auto real_condition = ExtractRealCondition(op->condition); + { + With constraint(&analyzer_, real_condition); + VisitExpr(op->true_value); + } + { + With constraint(&analyzer_, analyzer_.rewrite_simplify(Not(real_condition))); + VisitExpr(op->false_value); + } +} + +void IRVisitorWithAnalyzer::VisitExpr_(const ReduceNode* op) { + for (const IterVar& iv : op->axis) { + analyzer_.Bind(iv->var, iv->dom); + } + StmtExprVisitor::VisitExpr_(op); +} + +PrimExpr IRVisitorWithAnalyzer::ExtractRealCondition(PrimExpr condition) const { + if (auto call = condition.as()) { + if (call->op.same_as(builtin::likely())) { + return call->args[0]; + } + } + + return condition; +} + +} // namespace tir +} // namespace tvm diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index 058abc8c7d20..458c7c48c1fd 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -32,38 +32,30 @@ namespace tvm { namespace tir { -class IRVisitorWithAnalyzer final : public StmtExprVisitor { +class IRVisitorWithAnalyzer : public StmtExprVisitor { public: PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); } - void VisitStmt_(const ForNode* op) { - analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); - return StmtExprVisitor::VisitStmt_(op); - } + using StmtExprVisitor::VisitExpr_; + using StmtExprVisitor::VisitStmt_; - void VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { - IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); - analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value)); - StmtExprVisitor::VisitStmt_(op); - } else { - StmtExprVisitor::VisitStmt_(op); - } - } - - void VisitExpr_(const ReduceNode* op) { - // Setup the domain information before simplification. - for (const IterVar& iv : op->axis) { - analyzer_.Bind(iv->var, iv->dom); - } - // Recursively call simplification when necessary. - StmtExprVisitor::VisitExpr_(op); - } + void VisitStmt_(const ForNode* op); + void VisitStmt_(const BlockNode* op); + void VisitStmt_(const LetStmtNode* op); + void VisitStmt_(const IfThenElseNode* op); + void VisitStmt_(const AttrStmtNode* op); + void VisitStmt_(const AssertStmtNode* op); + void VisitExpr_(const CallNode* op); + void VisitExpr_(const LetNode* op); + void VisitExpr_(const SelectNode* op); + void VisitExpr_(const ReduceNode* op); protected: /*! \brief internal analyzer field. */ arith::Analyzer analyzer_; + + private: + PrimExpr ExtractRealCondition(PrimExpr condition) const; }; } // namespace tir From 9d3740badde6c9279afb2325a0494aee18878f22 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 29 Jun 2022 15:02:12 -0500 Subject: [PATCH 03/14] [Arith] Moved IRVisitorWithAnalyzer to tvm::arith namespace Changing for consistency, since `IRVisitorWithAnalyzer` it is part of the `src/arith` directory and the analogous `IRMutatorWithAnalyzer` is already part of the `arith` namespace. --- src/arith/ir_visitor_with_analyzer.cc | 6 +++--- src/arith/ir_visitor_with_analyzer.h | 26 +++++++++++++------------- src/tir/transforms/storage_flatten.cc | 1 + src/tir/transforms/texture_flatten.cc | 1 + 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index fb5d15c0d55c..429e461562c0 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -27,9 +27,9 @@ #include namespace tvm { -namespace tir { +namespace arith { -using namespace arith; +using namespace tir; void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) { analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); @@ -136,5 +136,5 @@ PrimExpr IRVisitorWithAnalyzer::ExtractRealCondition(PrimExpr condition) const { return condition; } -} // namespace tir +} // namespace arith } // namespace tvm diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index 458c7c48c1fd..d57944dd6dc4 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -30,25 +30,25 @@ #include namespace tvm { -namespace tir { +namespace arith { -class IRVisitorWithAnalyzer : public StmtExprVisitor { +class IRVisitorWithAnalyzer : public tir::StmtExprVisitor { public: PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); } using StmtExprVisitor::VisitExpr_; using StmtExprVisitor::VisitStmt_; - void VisitStmt_(const ForNode* op); - void VisitStmt_(const BlockNode* op); - void VisitStmt_(const LetStmtNode* op); - void VisitStmt_(const IfThenElseNode* op); - void VisitStmt_(const AttrStmtNode* op); - void VisitStmt_(const AssertStmtNode* op); - void VisitExpr_(const CallNode* op); - void VisitExpr_(const LetNode* op); - void VisitExpr_(const SelectNode* op); - void VisitExpr_(const ReduceNode* op); + void VisitStmt_(const tir::ForNode* op); + void VisitStmt_(const tir::BlockNode* op); + void VisitStmt_(const tir::LetStmtNode* op); + void VisitStmt_(const tir::IfThenElseNode* op); + void VisitStmt_(const tir::AttrStmtNode* op); + void VisitStmt_(const tir::AssertStmtNode* op); + void VisitExpr_(const tir::CallNode* op); + void VisitExpr_(const tir::LetNode* op); + void VisitExpr_(const tir::SelectNode* op); + void VisitExpr_(const tir::ReduceNode* op); protected: /*! \brief internal analyzer field. */ @@ -58,6 +58,6 @@ class IRVisitorWithAnalyzer : public StmtExprVisitor { PrimExpr ExtractRealCondition(PrimExpr condition) const; }; -} // namespace tir +} // namespace arith } // namespace tvm #endif // TVM_ARITH_IR_VISITOR_WITH_ANALYZER_H_ diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index f2d9aba4fba8..dd236537e9c2 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -47,6 +47,7 @@ namespace tvm { namespace tir { +using arith::IRVisitorWithAnalyzer; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; diff --git a/src/tir/transforms/texture_flatten.cc b/src/tir/transforms/texture_flatten.cc index a607e5914b39..3c35b73bc8d7 100644 --- a/src/tir/transforms/texture_flatten.cc +++ b/src/tir/transforms/texture_flatten.cc @@ -38,6 +38,7 @@ namespace tvm { namespace tir { +using arith::IRVisitorWithAnalyzer; using runtime::ApplyTexture2DFlattening; using runtime::DefaultTextureLayoutSeparator; using runtime::IsTextureStorage; From 56235e83b92717131b93490bd3cb2c65bc7ceae8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 29 Jun 2022 15:05:24 -0500 Subject: [PATCH 04/14] [Arith] Updated BufferDomainTouched to use IRVisitorWithAnalyzer This used the earlier changes to allow subclasses of `IRVisitorWithAnalyzer`, and to expose binding/constraints to `IntSetAnalyzer`. --- src/arith/domain_touched.cc | 43 ++++++++++--------------------------- 1 file changed, 11 insertions(+), 32 deletions(-) diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 403ea47f4e61..d2c5d79a0960 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -30,6 +30,8 @@ #include #include +#include "ir_visitor_with_analyzer.h" + namespace tvm { namespace arith { @@ -56,7 +58,7 @@ using BufferDomainAccess = std::tuple; } // namespace // Find Read region of the tensor in the stmt. -class BufferTouchedDomain final : public StmtExprVisitor { +class BufferTouchedDomain final : public IRVisitorWithAnalyzer { public: BufferTouchedDomain(const Stmt& stmt) { operator()(stmt); } @@ -90,39 +92,17 @@ class BufferTouchedDomain final : public StmtExprVisitor { return ret; } - void VisitStmt_(const ForNode* op) final { - const VarNode* var = op->loop_var.get(); - dom_map_[var] = IntSet::FromRange(Range::FromMinExtent(op->min, op->extent)); - StmtExprVisitor::VisitStmt_(op); - dom_map_.erase(var); - } - - void VisitStmt_(const LetStmtNode* op) final { - dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_); - StmtExprVisitor::VisitStmt_(op); - dom_map_.erase(op->var.get()); - } - - /* TODO: Thread extent unitest not generated.*/ - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent) { - const IterVarNode* thread_axis = op->node.as(); - ICHECK(thread_axis); - const VarNode* var = thread_axis->var.get(); - dom_map_[var] = IntSet::FromRange(Range(make_zero(op->value.dtype()), op->value)); - StmtExprVisitor::VisitStmt_(op); - dom_map_.erase(var); - } else { - StmtExprVisitor::VisitStmt_(op); - } - } + private: + using Parent = IRVisitorWithAnalyzer; + using Parent::VisitExpr_; + using Parent::VisitStmt_; void VisitExpr_(const BufferLoadNode* op) final { // Record load-exclusive buffer access Touch(&std::get(buffer_access_map_[op->buffer.get()]).set, op->indices); // Record load-store inclusive buffer access Touch(&std::get(buffer_access_map_[op->buffer.get()]).set, op->indices); - StmtExprVisitor::VisitExpr_(op); + Parent::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode* op) final { @@ -130,11 +110,11 @@ class BufferTouchedDomain final : public StmtExprVisitor { Touch(&std::get(buffer_access_map_[op->buffer.get()]).set, op->indices); // Record load-store inclusive buffer access Touch(&std::get(buffer_access_map_[op->buffer.get()]).set, op->indices); - StmtExprVisitor::VisitStmt_(op); + Parent::VisitStmt_(op); } private: - void Touch(BufferTouches* bounds, const Array& args) const { + void Touch(BufferTouches* bounds, const Array& args) { if (args.size() > bounds->size()) { bounds->resize(args.size()); } @@ -142,13 +122,12 @@ class BufferTouchedDomain final : public StmtExprVisitor { if (args[i].as()) { (*bounds)[i].emplace_back(IntSet::Vector(args[i])); } else { - (*bounds)[i].emplace_back(EvalSet(args[i], dom_map_)); + (*bounds)[i].emplace_back(analyzer_.int_set(args[i])); } } } std::unordered_map buffer_access_map_; - std::unordered_map dom_map_; }; Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, From 5e70b850fa37e9d0df1fe5794fb805f7f26aa914 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 1 Jul 2022 08:26:02 -0500 Subject: [PATCH 05/14] Avoid accidental Bind with dynamic Range --- src/arith/int_set.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 1ccc5fbff097..dea7abb32502 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -515,10 +515,7 @@ class IntSetAnalyzer::Impl { } void Bind(const Var& var, const Range& range, bool allow_override) { - IntSet min = Eval(range->min); - IntSet extent = Eval(range->extent); - - Bind(var, IntervalSet(min.min(), min.max() + extent.max() - 1), allow_override); + Bind(var, IntSet::FromRange(range), allow_override); } void Bind(const Var& var, const IntSet& info, bool override_info); From 981cf3c0b4b7c8be8865fc45012bb8fe88db3769 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 1 Jul 2022 09:35:44 -0500 Subject: [PATCH 06/14] [Arith] Do not visit SelectNode in IRVisitorWithAnalyzer Because both sides of a `Select` node are visited regardless of the condition, the `SelectNode::condition` should not be treated as a known value. --- src/arith/ir_visitor_with_analyzer.cc | 14 -------------- src/arith/ir_visitor_with_analyzer.h | 5 ++++- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index 429e461562c0..75ae22ef9915 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -105,20 +105,6 @@ void IRVisitorWithAnalyzer::VisitExpr_(const LetNode* op) { this->VisitExpr(op->body); } -void IRVisitorWithAnalyzer::VisitExpr_(const SelectNode* op) { - this->VisitExpr(op->condition); - - auto real_condition = ExtractRealCondition(op->condition); - { - With constraint(&analyzer_, real_condition); - VisitExpr(op->true_value); - } - { - With constraint(&analyzer_, analyzer_.rewrite_simplify(Not(real_condition))); - VisitExpr(op->false_value); - } -} - void IRVisitorWithAnalyzer::VisitExpr_(const ReduceNode* op) { for (const IterVar& iv : op->axis) { analyzer_.Bind(iv->var, iv->dom); diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index d57944dd6dc4..f41a628f3cc6 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -47,9 +47,12 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor { void VisitStmt_(const tir::AssertStmtNode* op); void VisitExpr_(const tir::CallNode* op); void VisitExpr_(const tir::LetNode* op); - void VisitExpr_(const tir::SelectNode* op); void VisitExpr_(const tir::ReduceNode* op); + // IRVisitorWithAnalyzer deliberately does not handle Select nodes, + // because both sides of a Select node are visited regardless of the + // condition. + protected: /*! \brief internal analyzer field. */ arith::Analyzer analyzer_; From 5d1948bd4936968c13879385ae80f2336108d462 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 5 Jul 2022 15:33:07 -0500 Subject: [PATCH 07/14] [Arith][IntSet] Track global and scope-dependent bounds separately Resolves a bug that was found in CI, where an earlier scope-dependent constraint was treated as a conflict by a later global bound. --- src/arith/int_set.cc | 55 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index dea7abb32502..39e0a770bd4d 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -511,7 +511,7 @@ class IntSetAnalyzer::Impl { } IntSet Eval(const PrimExpr& expr) const { - return IntervalSetEvaluator(analyzer_, dom_map_).Eval(expr); + return IntervalSetEvaluator(analyzer_, GetCurrentBounds()).Eval(expr); } void Bind(const Var& var, const Range& range, bool allow_override) { @@ -523,10 +523,24 @@ class IntSetAnalyzer::Impl { std::function EnterConstraint(const PrimExpr& constraint); private: + // Get the current variable bounds, including both global bounds and + // scope-dependent bounds. + Map GetCurrentBounds() const; + + // Utility function to split a boolean condition into the domain + // bounds implied by that condition. static std::vector> DetectBoundInfo(const PrimExpr& cond); + // The parent arith::Analyzer Analyzer* analyzer_; + + // Map of variables to global variable bounds (e.g. loop iterator + // ranges) Map dom_map_; + + // Map of variables to implicit scope-dependent bounds (e.g. inside + // the body of an if-statement) + Map constraints_; }; IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} @@ -571,6 +585,29 @@ void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_o Bind(var, Eval(expr), can_override); } +Map IntSetAnalyzer::Impl::GetCurrentBounds() const { + // If either constraints_ or dom_map_ is empty, return the other to + // avoid constructing a new map. + if (constraints_.empty()) { + return dom_map_; + } else if (dom_map_.empty()) { + return constraints_; + } + + // If neither is empty, construct a merged domain map with + // information from both sources. + Map merged = dom_map_; + for (const auto& pair : constraints_) { + auto it = merged.find(pair.first); + if (it == merged.end()) { + merged.Set(pair.first, pair.second); + } else { + merged.Set(pair.first, Intersect({pair.second, (*it).second})); + } + } + return merged; +} + std::vector> IntSetAnalyzer::Impl::DetectBoundInfo( const PrimExpr& constraint) { PVar x; @@ -619,8 +656,8 @@ std::function IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& cons // Collect the current values of each var that is changes by this // constraint. for (const auto& pair : bounds) { - auto it = dom_map_.find(pair.first); - if (it == dom_map_.end()) { + auto it = constraints_.find(pair.first); + if (it == constraints_.end()) { cached_values.Set(pair.first, IntSet()); } else { cached_values.Set(pair.first, (*it).second); @@ -629,20 +666,20 @@ std::function IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& cons // Update all constraints for (const auto& pair : bounds) { - auto it = dom_map_.find(pair.first); - if (it == dom_map_.end()) { - dom_map_.Set(pair.first, pair.second); + auto it = constraints_.find(pair.first); + if (it == constraints_.end()) { + constraints_.Set(pair.first, pair.second); } else { - dom_map_.Set(pair.first, Intersect({pair.second, (*it).second})); + constraints_.Set(pair.first, Intersect({pair.second, (*it).second})); } } auto frecover = [cached_values, this]() { for (const auto& it : cached_values) { if (it.second.defined()) { - dom_map_.Set(it.first, it.second); + constraints_.Set(it.first, it.second); } else { - dom_map_.erase(it.first); + constraints_.erase(it.first); } } }; From bd6f84f77ced4a4c3599f928d33fda2a4d8780a6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 6 Jul 2022 10:14:09 -0500 Subject: [PATCH 08/14] [Arith] Recovery function for each subanalyzer This way, if a subanalyzer throws an exception during `EnterConstraint`, the other subanalyzers are still appropriately backed out of the constraint. --- include/tvm/arith/analyzer.h | 2 +- src/arith/analyzer.cc | 26 ++++++++++++-------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 828ef22ca1b8..6b50921c5907 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -347,7 +347,7 @@ class ConstraintContext { /*! \brief The constraint */ PrimExpr constraint_; /*! \brief function to be called in recovery */ - std::function exit_; + std::vector> recovery_functions_; }; /*! diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 7158ed6c657a..8b9c32651efd 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -66,24 +66,22 @@ void Analyzer::Bind(const Map& variables, bool allow_override) { } void ConstraintContext::EnterWithScope() { - ICHECK(exit_ == nullptr); + ICHECK(recovery_functions_.size() == 0); // entering the scope. - auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_); - auto f1 = analyzer_->modular_set.EnterConstraint(constraint_); - auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_); - auto f3 = analyzer_->int_set.EnterConstraint(constraint_); - // recovery function. - exit_ = [f0, f1, f2, f3]() { - if (f3 != nullptr) f3(); - if (f2 != nullptr) f2(); - if (f1 != nullptr) f1(); - if (f0 != nullptr) f0(); - }; + recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_)); } void ConstraintContext::ExitWithScope() { - ICHECK(exit_ != nullptr); - exit_(); + while (recovery_functions_.size()) { + auto& func = recovery_functions_.back(); + if (func) { + func(); + } + recovery_functions_.pop_back(); + } } bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { From 254e1bd42ce739e0650bba86546df695395b5939 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 6 Jul 2022 10:15:52 -0500 Subject: [PATCH 09/14] [Arith][IntSet] Use CanProve instead of CanProveGreaterEqual The `min_value - max_value` in the `CanProveGreaterEqual` argument can result in an exception being thrown for unsigned integers where subtraction would wrap. --- src/arith/int_set.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 39e0a770bd4d..84e5c2f840ce 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -64,7 +64,7 @@ IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { PrimExpr min_value = max(a->min_value, b->min_value); if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) && (min_value.dtype().is_int() || min_value.dtype().is_uint()) && - analyzer->CanProveGreaterEqual(min_value - max_value, 1)) { + analyzer->CanProve(max_value < min_value)) { return IntervalSet::Empty(); } else { return IntervalSet(min_value, max_value); From 445a94aa27bb0cd808891bf6b33dab85e7e9fb35 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 7 Jul 2022 07:23:13 -0500 Subject: [PATCH 10/14] [Arith] Allow vector expressions in IntSet::operator(PrimExpr) Since these are tracked when lowering expressions, should allow post-vectorization expressions. To maintain previous behavior, this only applies when using the automatically tracked `Map dom_map_`. If an explicit domain map is passed, the previous behavior of raising an error for vectorized expressions still occurs. --- src/arith/int_set.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 84e5c2f840ce..41b434b08346 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -511,7 +511,7 @@ class IntSetAnalyzer::Impl { } IntSet Eval(const PrimExpr& expr) const { - return IntervalSetEvaluator(analyzer_, GetCurrentBounds()).Eval(expr); + return IntervalSetEvaluator(analyzer_, GetCurrentBounds(), true).Eval(expr); } void Bind(const Var& var, const Range& range, bool allow_override) { From e41aaed7e2444c98838f07557ec254bf98ab48aa Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 8 Jul 2022 12:27:24 -0500 Subject: [PATCH 11/14] Avoid comparisons between integer and handle datatypes --- src/arith/int_set.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 41b434b08346..012fd4aa11f0 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -295,7 +295,10 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int // a mod b = a - (a / b) * b if a_max / b == a_min / b auto qmax = a->HasUpperBound() ? floordiv(a->max_value, divisor) : pos_inf(); auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) : neg_inf(); - if (analyzer->CanProve(qmax == qmin)) { + // We can compare +/- inf against each other, but cannot use + // operator== between the symbolic limits and an integer. + bool compatible_dtypes = !(qmin.dtype().is_handle() ^ qmax.dtype().is_handle()); + if (compatible_dtypes && analyzer->CanProve(qmax == qmin)) { auto tmax = a->max_value - divisor * qmin; auto tmin = a->min_value - divisor * qmin; return IntervalSet(tmin, tmax); From 47f2ad8801ec318bbb5f549ddd3a3f9c762b9351 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 11 Jul 2022 08:33:36 -0500 Subject: [PATCH 12/14] [Arith] IntSet, Combine() extension Previously, the Combine() method didn't handle values without a known lower bound, for boolean operators. --- src/arith/int_set.cc | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 012fd4aa11f0..8bc9808dd0d3 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -106,14 +106,14 @@ TVM_DECLARE_LOGICAL_OP(Not); * \note this can possibly relax the set. */ template -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) { if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr res = TryConstFold(a->min_value, b->min_value); if (!res.defined()) res = Op(a->min_value, b->min_value); return IntervalSet::SinglePoint(res); } if (is_logical_op::value) { - return IntervalSet(make_const(a->min_value.dtype(), 0), make_const(a->min_value.dtype(), 1)); + return IntervalSet(make_const(dtype, 0), make_const(dtype, 1)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; @@ -123,7 +123,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { } template <> -inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value + b->min_value); } @@ -137,7 +138,8 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalS } template <> -inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value - b->min_value); } @@ -151,7 +153,8 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalS } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value * b->min_value); } @@ -184,7 +187,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value / b->min_value); } @@ -217,7 +221,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); } @@ -245,7 +250,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); } @@ -278,7 +284,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); } @@ -315,7 +322,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int } template <> -inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } @@ -325,7 +333,8 @@ inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, + DataType /* dtype */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } @@ -427,10 +436,12 @@ class IntervalSetEvaluator : public ExprFunctor { int64_t vstride = stride.Eval()->value; if (vstride > 0) { return Combine(analyzer_, base, - IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); + IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1)), + op->dtype); } else { return Combine(analyzer_, base, - IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); + IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)), + op->dtype); } } DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); @@ -494,7 +505,7 @@ class IntervalSetEvaluator : public ExprFunctor { if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { return IntervalSet::SinglePoint(GetRef(op)); } - return Combine(analyzer_, a, b); + return Combine(analyzer_, a, b, op->dtype); } // recursive depth From a4c2e58e1da5a7caf9a03d3f8e9d0d8f355be80a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 13 Jul 2022 08:09:40 -0500 Subject: [PATCH 13/14] Added docstring --- include/tvm/arith/analyzer.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 6b50921c5907..d1c9e945e603 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -267,6 +267,12 @@ class RewriteSimplifier { */ TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false); + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. + * + * \return an exit function that must be called to cleanup the constraint can be nullptr. + */ std::function EnterConstraint(const PrimExpr& constraint); private: From 348061163fdf536b0891976334ab5560b6810ca4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 13 Jul 2022 08:17:22 -0500 Subject: [PATCH 14/14] Naming consistency of `IntSetAnalyzer` methods. To be consistent with other subanalyzers, using "Update" when providing the analyzer with the same data structure as is used internally, and "Bind" used when providing it with something that must be converted to the internal data structure. --- include/tvm/arith/analyzer.h | 2 +- src/arith/analyzer.cc | 2 +- src/arith/int_set.cc | 12 ++++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index d1c9e945e603..ceb9f574f2c9 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -397,7 +397,7 @@ class IntSetAnalyzer { * \param new_range The range of allowed values for this var. * \param allow_override whether we allow override of existing information. */ - TVM_DLL void Update(const Var& var, const Range& new_range, bool allow_override = false); + TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false); std::function EnterConstraint(const PrimExpr& constraint); diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 8b9c32651efd..f32c9b2ff4cf 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -53,7 +53,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { this->Bind(var, range->min, allow_override); } else { this->const_int_bound.Bind(var, range, allow_override); - this->int_set.Update(var, range, allow_override); + this->int_set.Bind(var, range, allow_override); } // skip modular_set // skip rewrite simplify diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 8bc9808dd0d3..6d48ad1ed151 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -529,10 +529,10 @@ class IntSetAnalyzer::Impl { } void Bind(const Var& var, const Range& range, bool allow_override) { - Bind(var, IntSet::FromRange(range), allow_override); + Update(var, IntSet::FromRange(range), allow_override); } - void Bind(const Var& var, const IntSet& info, bool override_info); + void Update(const Var& var, const IntSet& info, bool override_info); void Bind(const Var& var, const PrimExpr& expr, bool override_info); std::function EnterConstraint(const PrimExpr& constraint); @@ -568,14 +568,14 @@ IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map& IntSet IntSetAnalyzer::operator()(const PrimExpr& expr) { return impl_->Eval(expr); } void IntSetAnalyzer::Update(const Var& var, const IntSet& info, bool allow_override) { - impl_->Bind(var, info, allow_override); + impl_->Update(var, info, allow_override); } -void IntSetAnalyzer::Update(const Var& var, const Range& range, bool allow_override) { +void IntSetAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { impl_->Bind(var, range, allow_override); } -void IntSetAnalyzer::Impl::Bind(const Var& var, const IntSet& info, bool can_override) { +void IntSetAnalyzer::Impl::Update(const Var& var, const IntSet& info, bool can_override) { if (!can_override) { auto it = dom_map_.find(var); if (it != dom_map_.end()) { @@ -596,7 +596,7 @@ void IntSetAnalyzer::Impl::Bind(const Var& var, const IntSet& info, bool can_ove } void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_override) { - Bind(var, Eval(expr), can_override); + Update(var, Eval(expr), can_override); } Map IntSetAnalyzer::Impl::GetCurrentBounds() const {