Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 55 additions & 16 deletions csrc/expr_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@
#include <unordered_set>
#include <vector>

namespace std {
template <typename X, typename Y>
struct hash<std::pair<X, Y>> {
std::size_t operator()(const std::pair<X, Y>& pair) const {
std::size_t h1 = std::hash<X>()(pair.first);
std::size_t h2 = std::hash<Y>()(pair.second);
nvfuser::hashCombine(h1, h2);
return h1;
}
};
} // namespace std

namespace nvfuser {

namespace debug_print {
Expand Down Expand Up @@ -251,6 +263,9 @@ class Context {
return less_equal_;
}

mutable std::unordered_map<std::pair<Val*, Val*>, bool> less_than_cache_;
mutable std::unordered_map<std::pair<Val*, Val*>, bool> less_equal_cache_;

private:
void assume(Val* a) {
auto def = a->definition();
Expand Down Expand Up @@ -1527,77 +1542,99 @@ bool hasCompatibleSign(Val* x, Val* y, const Context& context) {
return isNonNegative(x, context) && isNonNegative(y, context);
}

#define CACHE_AND_RETURN_LT(value) \
context.less_than_cache_.emplace(std::make_pair(x, y), value); \
return value

bool lessThan(Val* x, Val* y, const Context& context) {
auto cache_it = context.less_than_cache_.find({x, y});
if (cache_it != context.less_than_cache_.end()) {
return cache_it->second;
}

x = foldConstants(x);
y = foldConstants(y);
if (x->value().hasValue() && y->value().hasValue()) {
return x->value() < y->value();
bool result = x->value() < y->value();
CACHE_AND_RETURN_LT(result);
}
x = maybeUnwrapMagicZero(x);
y = maybeUnwrapMagicZero(y);
if (x->isZero() && isPositiveHelper(y, context)) {
return true;
CACHE_AND_RETURN_LT(true);
}
// i1 % i2 < i2
if (auto bop = dynamic_cast<BinaryOp*>(x->definition());
bop != nullptr && bop->getBinaryOpType() == BinaryOpType::Mod) {
auto denominator = bop->rhs();
if (denominator->sameAs(y) && isValidDenominator(denominator, context) &&
isNonNegative(y, context)) {
return true;
CACHE_AND_RETURN_LT(true);
}
}
// x <= a & a < b & b <= y --> x < y
for (const auto& [a, b] : context.getKnownLessThan()) {
if (lessEqual(x, a, context) && lessEqual(b, y, context)) {
return true;
CACHE_AND_RETURN_LT(true);
}
}
return false;
CACHE_AND_RETURN_LT(false);
}

#undef CACHE_AND_RETURN_LT

#define CACHE_AND_RETURN_LE(value) \
context.less_equal_cache_.emplace(std::make_pair(x, y), value); \
return value

bool lessEqual(Val* x, Val* y, const Context& context) {
auto cache_it = context.less_equal_cache_.find({x, y});
if (cache_it != context.less_equal_cache_.end()) {
return cache_it->second;
}

x = foldConstants(x);
y = foldConstants(y);
if (x->value().hasValue() && y->value().hasValue()) {
return x->value() <= y->value();
bool result = x->value() <= y->value();
CACHE_AND_RETURN_LE(result);
}
x = maybeUnwrapMagicZero(x);
y = maybeUnwrapMagicZero(y);
// x == y -> x <= y
if (x->sameAs(y)) {
return true;
CACHE_AND_RETURN_LE(true);
}
if (x->isZero() && isNonNegativeHelper(y, context)) {
return true;
CACHE_AND_RETURN_LE(true);
}
for (const auto& [a, b] : context.getKnownLessThan()) {
// x < y --> x <= y
if (a->sameAs(x) && b->sameAs(y)) {
return true;
CACHE_AND_RETURN_LE(true);
}
}
for (const auto& [a, b] : context.getKnownLessEqual()) {
if (a->sameAs(x) && b->sameAs(y)) {
return true;
CACHE_AND_RETURN_LE(true);
}
}
for (const auto& [a, b] : context.getKnownLessThan()) {
// x < b & b <= y --> x <= y
if (a->sameAs(x) && lessEqual(b, y, context)) {
return true;
CACHE_AND_RETURN_LE(true);
}
}
for (const auto& [a, b] : context.getKnownLessEqual()) {
// x <= b & b <= y --> x <= y
if (a->sameAs(x) && lessEqual(b, y, context)) {
return true;
CACHE_AND_RETURN_LE(true);
}
}
// if i is an integer, i > 0, then i >= 1
if (x->isOneInt() && y->isIntegralScalar()) {
if (isPositiveHelper(y, context)) {
return true;
CACHE_AND_RETURN_LE(true);
}
}
// if a >= 0, b >= 1, then a <= a * b
Expand All @@ -1619,7 +1656,7 @@ bool lessEqual(Val* x, Val* y, const Context& context) {
maybeFlattenedOpOf(BinaryOpType::Mul, std::move(remaining_inputs));
auto one = IrBuilder::create<Val>(1L, *remaining->getDataType());
if (lessEqual(one, remaining, context)) {
return true;
CACHE_AND_RETURN_LE(true);
}
}
}
Expand All @@ -1643,14 +1680,16 @@ bool lessEqual(Val* x, Val* y, const Context& context) {
maybeFlattenedOpOf(BinaryOpType::Mul, std::move(remaining_inputs));
auto one = IrBuilder::create<Val>(1L, *remaining->getDataType());
if (lessEqual(one, remaining, context)) {
return true;
CACHE_AND_RETURN_LE(true);
}
}
}
}
return false;
CACHE_AND_RETURN_LE(false);
}

#undef CACHE_AND_RETURN_LE

} // namespace prove

namespace {
Expand Down