Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/arith/solve_linear_equation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,10 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol

// We have to transform ranges of the old variables into relations over new variables because
// new ranges are not enough usually.
for (const auto& p : system_to_solve->ranges) {
const Var& old_var = p.first;
const Range& old_range = p.second;
if (old_to_new_map.count(old_var)) {
PrimExpr express_by_new_vars = old_to_new_map[old_var];
for (const auto& old_var : system_to_solve->variables) {
if (system_to_solve->ranges.find(old_var) != system_to_solve->ranges.end()) {
const Range& old_range = system_to_solve->ranges.at(old_var);
PrimExpr express_by_new_vars = old_to_new_map.at(old_var);
PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars);
PrimExpr upper_cond =
analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent);
Expand Down
54 changes: 27 additions & 27 deletions src/arith/solve_linear_inequality.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,10 @@ struct ExprLess {
}
};

void DebugPrint(
const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set,
const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& next_ineq_set,
const std::vector<PrimExpr>& rest, const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos,
const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) {
void DebugPrint(const std::vector<PrimExpr>& current_ineq_set,
const std::vector<PrimExpr>& next_ineq_set, const std::vector<PrimExpr>& rest,
const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos,
const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) {
std::cout << "Current ineq set:\n[";
for (auto& ineq : current_ineq_set) {
std::cout << ineq << ", ";
Expand Down Expand Up @@ -148,9 +147,12 @@ class NormalizeComparisons : public ExprMutator {
arith::Analyzer analyzer_;
};

void AddInequality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* inequality_set,
const PrimExpr& new_ineq, Analyzer* analyzer) {
if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) != inequality_set->end()) {
void AddInequality(std::vector<PrimExpr>* inequality_set, const PrimExpr& new_ineq,
Analyzer* analyzer) {
if (analyzer->CanProve(new_ineq) ||
std::find_if(inequality_set->begin(), inequality_set->end(), [&](const PrimExpr& e) {
return StructuralEqual()(e, new_ineq);
}) != inequality_set->end()) {
// redundant: follows from the vranges
// or has already been added
return;
Expand All @@ -168,15 +170,13 @@ void AddInequality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>
}
}

inequality_set->insert(new_ineq);
inequality_set->push_back(new_ineq);
}

void ClassifyByPolarity(
const Var& var,
const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set,
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* next_ineq_set,
std::vector<PrimExpr>* rest, std::vector<std::pair<int64_t, PrimExpr>>* coef_pos,
std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) {
void ClassifyByPolarity(const Var& var, const std::vector<PrimExpr>& current_ineq_set,
std::vector<PrimExpr>* next_ineq_set, std::vector<PrimExpr>* rest,
std::vector<std::pair<int64_t, PrimExpr>>* coef_pos,
std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) {
// Take formulas from current_ineq_set and classify them according to polarity wrt var
// and store to coef_pos and coef_neg respectively.
for (const PrimExpr& ineq : current_ineq_set) {
Expand Down Expand Up @@ -218,14 +218,14 @@ void ClassifyByPolarity(
}
}

void MoveEquality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* upper_bounds,
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* lower_bounds,
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* equalities) {
void MoveEquality(std::vector<PrimExpr>* upper_bounds, std::vector<PrimExpr>* lower_bounds,
std::vector<PrimExpr>* equalities) {
// those exist in both upper & lower bounds will be moved to equalities
for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) {
auto lb = lower_bounds->find(*ub);
auto lb = std::find_if(lower_bounds->begin(), lower_bounds->end(),
[&](const PrimExpr& e) { return StructuralEqual()(e, *ub); });
if (lb != lower_bounds->end()) {
equalities->insert(*lb);
equalities->push_back(*lb);
lower_bounds->erase(lb);
ub = upper_bounds->erase(ub);
} else {
Expand All @@ -249,8 +249,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
// and move to the next variable.

// normalized inequality
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> current_ineq_set_to_solve;
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> next_ineq_set_to_solve;
std::vector<PrimExpr> current_ineq_set_to_solve;
std::vector<PrimExpr> next_ineq_set_to_solve;
// A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0
std::vector<std::pair<int64_t, PrimExpr>> coef_pos;
// A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0
Expand Down Expand Up @@ -321,8 +321,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
}

// The resulting lower and upper bounds
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> upper_bounds;
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> lower_bounds;
std::vector<PrimExpr> upper_bounds;
std::vector<PrimExpr> lower_bounds;
upper_bounds.reserve(coef_pos.size());
lower_bounds.reserve(coef_neg.size());

Expand All @@ -345,7 +345,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
}
}
// Add the upper bound
upper_bounds.insert(bound);
upper_bounds.push_back(bound);
}
for (const auto& neg : coef_neg) {
PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second;
Expand All @@ -366,10 +366,10 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
}
}
// Add the lower bound
lower_bounds.insert(bound);
lower_bounds.push_back(bound);
}

std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> equal;
std::vector<PrimExpr> equal;
equal.reserve(std::min(upper_bounds.size(), lower_bounds.size()));
MoveEquality(&upper_bounds, &lower_bounds, &equal);
std::vector<PrimExpr> equal_list(equal.begin(), equal.end());
Expand Down
26 changes: 15 additions & 11 deletions src/te/autodiff/ad_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,15 +413,17 @@ class FactorOutAtomicFormulasFunctor
auto res_b = VisitExpr(op->b);

// For the And case we return the union of the sets of atomic formulas
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set;
res_a_set.reserve(res_a.atomic_formulas.size());
std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
std::inserter(res_set, res_set.end()));
std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
std::inserter(res_set, res_set.end()));

std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
std::inserter(res_a_set, res_a_set.end()));

std::vector<PrimExpr> res = res_a.atomic_formulas;
for (const auto& e : res_b.atomic_formulas) {
if (res_a_set.find(e) == res_a_set.end()) {
res.emplace_back(e);
}
}
// And the residuals are combined with &&
return {res, res_a.rest && res_b.rest};
}
Expand All @@ -443,32 +445,34 @@ class FactorOutAtomicFormulasFunctor

// For the Or case we intersect the sets of atomic formulas
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
std::vector<PrimExpr> res;
res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
for (const auto& res_b_formula : res_b_set) {
res.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
for (const auto& res_b_formula : res_b.atomic_formulas) {
if (res_a_set.count(res_b_formula)) {
res_set.insert(res_b_formula);
res.push_back(res_b_formula);
}
}

// Computing the residual is more complex: we have to compute the sets of atomic formulas
// which are left behind, and then combine them with the residuals into the new residual.
std::vector<PrimExpr> new_cond_a;
new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
for (const auto& formula : res_a_set) {
for (const auto& formula : res_a.atomic_formulas) {
if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
}

std::vector<PrimExpr> new_cond_b;
new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
for (const auto& formula : res_b_set) {
for (const auto& formula : res_b.atomic_formulas) {
if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
}

res_a.atomic_formulas = std::move(new_cond_a);
res_b.atomic_formulas = std::move(new_cond_b);

PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
std::vector<PrimExpr> res{res_set.begin(), res_set.end()};

return {res, new_rest};
}
Expand Down