Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/tir/expr_functor.h>

#include <algorithm>
#include <optional>

#include "constraint_extract.h"
#include "int_operator.h"
Expand Down Expand Up @@ -81,6 +82,16 @@ struct ConstIntBoundAnalyzer::Entry {
bool operator==(const Entry& other) const {
return min_value == other.min_value && max_value == other.max_value;
}

friend std::ostream& operator<<(std::ostream& os, const Entry& entry) {
os << "Entry[";
PrintBoundValue(os, entry.min_value);
os << ", ";
PrintBoundValue(os, entry.max_value);
os << "]";

return os;
}
};

class ConstIntBoundAnalyzer::Impl
Expand Down Expand Up @@ -228,6 +239,11 @@ class ConstIntBoundAnalyzer::Impl
Entry ret;
ret.min_value = InfAwareAdd(a.min_value, b.min_value);
ret.max_value = InfAwareAdd(a.max_value, b.max_value);

if (auto bound = BoundUsingReciprocal(GetRef<PrimExpr>(op))) {
ret = Intersect(ret, bound.value());
}

return ret;
}

Expand All @@ -237,6 +253,13 @@ class ConstIntBoundAnalyzer::Impl
Entry ret;
ret.min_value = InfAwareAdd(a.min_value, -b.max_value);
ret.max_value = InfAwareAdd(a.max_value, -b.min_value);

if (auto bound = BoundUsingReciprocal(GetRef<Sub>(op))) {
ret = Intersect(ret, bound.value());
}
if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) {
ret = Intersect(ret, Negative(bound.value()));
}
return ret;
}

Expand Down Expand Up @@ -628,6 +651,25 @@ class ConstIntBoundAnalyzer::Impl
ret.max_value = std::min(a.max_value, b.max_value);
return ret;
}
/*!
* \brief Flip the sign of a set.
* \param entry The set of values
*/
static Entry Negative(Entry entry) {
Entry ret;
if (entry.max_value == kPosInf) {
ret.min_value = kNegInf;
} else {
ret.min_value = -entry.max_value;
}
if (entry.min_value == kNegInf) {
ret.max_value = kPosInf;
} else {
ret.max_value = -entry.min_value;
}

return ret;
}
/*!
* \brief return everything dtype can represent.
* \param dtype The data type.
Expand Down Expand Up @@ -733,6 +775,164 @@ class ConstIntBoundAnalyzer::Impl
std::ceil(std::log2(arg_bounds.max_value)));
}
}

std::optional<Entry> BoundUsingReciprocal(PrimExpr expr) {
// Match expressions of the form `(A+B)*C - (A*B)*D`. Depending on
// previous simplifications, the exact form of the expression may vary.
auto opt_special_case = [&]() -> std::optional<std::tuple<Entry, Entry, Entry, Entry>> {
PVar<PrimExpr> A, B, C, D;

if (PMatchesOneOf{
(A + B) * C - (A * B) * D,
(A + B) * C - (B * A) * D,
}
.Match(expr)) {
return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
VisitExpr(D.Eval())};
} else if (PMatchesOneOf{
(A + B) * C - A * B,
(A + B) * C - B * A,
}
.Match(expr)) {
return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
MakeBound(1, 1)};
} else if (PMatchesOneOf{
(A * B) * D - (A + B) * C,
(B * A) * D - (A + B) * C,
}
.Match(expr)) {
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
Negative(VisitExpr(C.Eval())), Negative(VisitExpr(D.Eval()))};
} else if (PMatchesOneOf{
A * B - (A + B) * C,
B * A - (A + B) * C,
}
.Match(expr)) {
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
Negative(VisitExpr(C.Eval())), MakeBound(-1, -1)};
} else if (PMatchesOneOf{
(A * B) * D + (A + B) * C,
(B * A) * D + (A + B) * C,
(A + B) * C + (A * B) * D,
(A + B) * C + (B * A) * D,
}
.Match(expr)) {
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
VisitExpr(C.Eval()), Negative(VisitExpr(D.Eval()))};
} else if (PMatchesOneOf{
(A * B) + (A + B) * C,
(B * A) + (A + B) * C,
(A + B) * C + (A * B),
(A + B) * C + (B * A),
}
.Match(expr)) {
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
VisitExpr(C.Eval()), MakeBound(-1, -1)};
} else {
return std::nullopt;
}
}();

if (!opt_special_case.has_value()) {
return std::nullopt;
}
// Unpacking the tuple would be cleaner with a structured binding.
// However, until C++20, structured bindings cannot be captured for
// use in a lambda function.
auto A_bound = std::get<0>(*opt_special_case);
auto B_bound = std::get<1>(*opt_special_case);
auto C_bound = std::get<2>(*opt_special_case);
auto D_bound = std::get<3>(*opt_special_case);

// If C and D have different signs, flip the signs of A/B/C so
// that C will match the sign of D.
if ((D_bound.max_value < 0 && C_bound.min_value > 0) ||
(D_bound.min_value > 0 && C_bound.max_value < 0)) {
A_bound = Negative(A_bound);
B_bound = Negative(B_bound);
C_bound = Negative(C_bound);
}

// If all terms are negative, then we'll be providing an upper bound
// rather than a lower bound. To avoid code duplication, flip all the
// signs here, find a lower bound, then flip the sign to produce the
// upper bound of the original expression.
bool all_terms_negative = (A_bound.max_value < 0 && B_bound.max_value < 0 &&
C_bound.max_value < 0 && D_bound.max_value < 0);
if (all_terms_negative) {
A_bound = Negative(A_bound);
B_bound = Negative(B_bound);
C_bound = Negative(C_bound);
D_bound = Negative(D_bound);
}

bool all_terms_positive = (A_bound.min_value > 0 && B_bound.min_value > 0 &&
C_bound.min_value > 0 && D_bound.min_value > 0);
if (!all_terms_positive) {
return std::nullopt;
}

// (A + B) * C - (A * B) * D
// (A*B*C*D) * ( (A+B)/(A*B*D) - 1/C )
// (A*B*C*D) * ( (1/A + 1/B)/D - 1/C )
// (A*B*C*D) * (1/(A*D) + 1/(B*D) - 1/C)
//
// The constant (A*B*C*D) is positive, and its minimum value is the
// product of the minimum values of A, B, C, and D. If the reciprocal
// term (1/(A*D) + 1/(B*D) - 1/C) is positive, then this constant can
// be used to provide a lower bound on the expression.

bool reciprocal_term_is_positive = [&]() {
if (D_bound.max_value == ConstIntBound::kPosInf) {
// If D can grow without bound, the `1/(A*D)` and `1/(B*D)`
// terms will approach zero, at which point the `-1/C` term
// will determine the sign the sign.
return false;
}

if (std::min(A_bound.max_value, B_bound.max_value) * D_bound.max_value <= C_bound.min_value) {
// 1/(A*D) + 1/(B*D) - 1/C is positive if 1/C < 1/(A*D) + 1/(B*D).
// Since each term is positive, this condition can hold if either
// A*D <= C or B*D <= C.
return true;
}
if (A_bound.max_value != ConstIntBound::kPosInf &&
B_bound.max_value != ConstIntBound::kPosInf) {
// Even if neither term is sufficient on its own, if both A and B
// have known upper bounds, the inequality 1/C < 1/(A*D) + 1/(B*D)
// may still be provable.
//
// The maximum value of the LHS is found when C is minimized. The
// minimum value of the RHS is found when A, B, and D are
// maximized. If the condition holds in this case, then it holds
// in all cases.
//
// 1/C_min < 1/(A_max * D_max) + 1/(B_max*D_max)
// A_max*B_max*D_max < C_min*B_max + C_min*A_max
// A_max*B_max*D_max < C_min*(A_max + B_max)
//
if (A_bound.max_value * B_bound.max_value * D_bound.max_value <
C_bound.min_value * (A_bound.max_value + B_bound.max_value)) {
return true;
}
}
return false;
}();

if (!reciprocal_term_is_positive) {
return std::nullopt;
}

auto ret = Everything(expr->dtype);
ret.min_value = A_bound.min_value * B_bound.min_value * C_bound.min_value * D_bound.min_value;

// If we flipped the sign of the original expression, flip the sign of
// the resulting set of possible values.
if (all_terms_negative) {
ret = Negative(ret);
}
return ret;
}
};

ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const {
Expand Down
11 changes: 11 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1768,6 +1768,17 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) {
if (merge_constants) {
return RecursiveRewrite(merge_constants.value());
}

auto common_factor = [&]() -> int64_t {
auto modular_a = analyzer_->modular_set(ret->a);
auto modular_b = analyzer_->modular_set(ret->b);
auto gcd_lhs = ZeroAwareGCD(modular_a->base, modular_a->coeff);
auto gcd_rhs = ZeroAwareGCD(modular_b->base, modular_b->coeff);
return ZeroAwareGCD(gcd_lhs, gcd_rhs);
}();
if (common_factor > 1) {
return RecursiveRewrite(floordiv(ret->a, common_factor) < floordiv(ret->b, common_factor));
}
}
return std::move(ret);
}
Expand Down
Loading