Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 52 additions & 74 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,13 @@ using namespace tir;
// We might use better set analysis in the future to replace the intervalset.
class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
public:
IntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map, bool eval_vec = false)
: analyzer_(analyzer), dom_map_(dom_map), eval_vec_(eval_vec) {}
IntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map,
const std::vector<std::pair<Var, IntSet>>* dom_constraints = nullptr,
bool eval_vec = false)
: analyzer_(analyzer),
dom_map_(dom_map),
dom_constraints_(dom_constraints),
eval_vec_(eval_vec) {}

IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); }
// evaluate and relax the set
Expand All @@ -383,18 +388,40 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {

IntervalSet VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);

Array<IntSet> values;
if (dom_constraints_) {
for (const auto& constraint : *dom_constraints_) {
if (var.same_as(constraint.first)) {
values.push_back(constraint.second);
}
}
}

auto it = dom_map_.find(var);
if (it != dom_map_.end()) {
IntervalSet res = ToIntervalSet((*it).second);
if (res->min_value.same_as(var) && res->max_value.same_as(var)) {
return res;
}
// recursively evaluate mapped result
// in case the domain contains variables to be relaxed.
return Eval(res);
} else {
values.push_back((*it).second);
}

if (values.empty()) {
return IntervalSet::SinglePoint(var);
}

IntSet intersection = [&]() {
if (values.size() == 1) {
return values.front();
} else {
return Intersect(values);
}
}();

IntervalSet res = ToIntervalSet(intersection);
if (res->min_value.same_as(var) && res->max_value.same_as(var)) {
return res;
}
// recursively evaluate mapped result
// in case the domain contains variables to be relaxed.
return Eval(res);
}

IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_<Add>(op); }
Expand Down Expand Up @@ -517,6 +544,7 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
// analyzer
Analyzer* analyzer_;
const Map<Var, IntSet>& dom_map_;
const std::vector<std::pair<Var, IntSet>>* dom_constraints_;
bool eval_vec_{false};
};

Expand All @@ -529,7 +557,7 @@ class IntSetAnalyzer::Impl {
}

IntSet Eval(const PrimExpr& expr) const {
return IntervalSetEvaluator(analyzer_, GetCurrentBounds(), true).Eval(expr);
return IntervalSetEvaluator(analyzer_, dom_map_, &dom_constraints_, true).Eval(expr);
}

void Bind(const Var& var, const Range& range, bool allow_override) {
Expand All @@ -541,10 +569,6 @@ class IntSetAnalyzer::Impl {
std::function<void()> EnterConstraint(const PrimExpr& constraint);

private:
// Get the current variable bounds, including both global bounds and
// scope-dependent bounds.
Map<Var, IntSet> GetCurrentBounds() const;

// Utility function to split a boolean condition into the domain
// bounds implied by that condition.
static std::vector<std::pair<Var, IntSet>> DetectBoundInfo(const PrimExpr& cond);
Expand All @@ -556,9 +580,11 @@ class IntSetAnalyzer::Impl {
// ranges)
Map<Var, IntSet> dom_map_;

// Map of variables to implicit scope-dependent bounds (e.g. inside
// the body of an if-statement)
Map<Var, IntSet> constraints_;
// List of implicit scope-dependent bounds (e.g. inside the body of
// an if-statement). Maintained as a list of constraints, rather
// than as a `Map<Var,IntSet>`, to avoid computing an Intersection
// until required.
std::vector<std::pair<Var, IntSet>> dom_constraints_;
};

IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {}
Expand Down Expand Up @@ -603,29 +629,6 @@ void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_o
Update(var, Eval(expr), can_override);
}

Map<Var, IntSet> IntSetAnalyzer::Impl::GetCurrentBounds() const {
// If either constraints_ or dom_map_ is empty, return the other to
// avoid constructing a new map.
if (constraints_.empty()) {
return dom_map_;
} else if (dom_map_.empty()) {
return constraints_;
}

// If neither is empty, construct a merged domain map with
// information from both sources.
Map<Var, IntSet> merged = dom_map_;
for (const auto& pair : constraints_) {
auto it = merged.find(pair.first);
if (it == merged.end()) {
merged.Set(pair.first, pair.second);
} else {
merged.Set(pair.first, Intersect({pair.second, (*it).second}));
}
}
return merged;
}

std::vector<std::pair<Var, IntSet>> IntSetAnalyzer::Impl::DetectBoundInfo(
const PrimExpr& constraint) {
PVar<Var> x;
Expand Down Expand Up @@ -665,41 +668,16 @@ std::function<void()> IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint
}

std::function<void()> IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& constraint) {
Map<Var, IntSet> cached_values;

auto bounds = DetectBoundInfo(constraint);

if (bounds.size() == 0) return nullptr;

// Collect the current values of each var that is changes by this
// constraint.
for (const auto& pair : bounds) {
auto it = constraints_.find(pair.first);
if (it == constraints_.end()) {
cached_values.Set(pair.first, IntSet());
} else {
cached_values.Set(pair.first, (*it).second);
}
}

// Update all constraints
for (const auto& pair : bounds) {
auto it = constraints_.find(pair.first);
if (it == constraints_.end()) {
constraints_.Set(pair.first, pair.second);
} else {
constraints_.Set(pair.first, Intersect({pair.second, (*it).second}));
}
}

auto frecover = [cached_values, this]() {
for (const auto& it : cached_values) {
if (it.second.defined()) {
constraints_.Set(it.first, it.second);
} else {
constraints_.erase(it.first);
}
}
size_t old_size = dom_constraints_.size();
dom_constraints_.insert(dom_constraints_.end(), bounds.begin(), bounds.end());
size_t new_size = dom_constraints_.size();
auto frecover = [old_size, new_size, this]() {
ICHECK_EQ(dom_constraints_.size(), new_size);
dom_constraints_.resize(old_size);
};
return frecover;
}
Expand Down Expand Up @@ -960,13 +938,13 @@ Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>&

IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map) {
Analyzer ana;
return IntervalSetEvaluator(&ana, dom_map, false).Eval(e);
return IntervalSetEvaluator(&ana, dom_map, {}, false).Eval(e);
}

IntSet IntSet::Vector(PrimExpr x) {
Analyzer ana;
Map<Var, IntSet> dmap;
return IntervalSetEvaluator(&ana, dmap, true).Eval(x);
return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x);
}

IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map) {
Expand Down