From 2e85fc56b30a7f88ac67f590b5852210f07bb76c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 20 Mar 2024 14:25:32 -0700 Subject: [PATCH 1/2] Add naive cache for provers --- csrc/expr_simplifier.cpp | 71 +++++++++++++++++++++++++++++++--------- 1 file changed, 55 insertions(+), 16 deletions(-) diff --git a/csrc/expr_simplifier.cpp b/csrc/expr_simplifier.cpp index c037ff365d7..a05de9bdbd1 100644 --- a/csrc/expr_simplifier.cpp +++ b/csrc/expr_simplifier.cpp @@ -28,6 +28,18 @@ #include #include +namespace std { +template +struct hash> { + std::size_t operator()(const std::pair& pair) const { + std::size_t h1 = std::hash()(pair.first); + std::size_t h2 = std::hash()(pair.second); + nvfuser::hashCombine(h1, h2); + return h1; + } +}; +} // namespace std + namespace nvfuser { namespace debug_print { @@ -251,6 +263,9 @@ class Context { return less_equal_; } + mutable std::unordered_map, bool> less_than_cache_; + mutable std::unordered_map, bool> less_equal_cache_; + private: void assume(Val* a) { auto def = a->definition(); @@ -1527,16 +1542,26 @@ bool hasCompatibleSign(Val* x, Val* y, const Context& context) { return isNonNegative(x, context) && isNonNegative(y, context); } +#define CACHE_AND_RETURN(value) \ + context.less_than_cache_.emplace(std::make_pair(x, y), value); \ + return value + bool lessThan(Val* x, Val* y, const Context& context) { + auto cache_it = context.less_than_cache_.find({x, y}); + if (cache_it != context.less_than_cache_.end()) { + return cache_it->second; + } + x = foldConstants(x); y = foldConstants(y); if (x->value().hasValue() && y->value().hasValue()) { - return x->value() < y->value(); + bool result = x->value() < y->value(); + CACHE_AND_RETURN(result); } x = maybeUnwrapMagicZero(x); y = maybeUnwrapMagicZero(y); if (x->isZero() && isPositiveHelper(y, context)) { - return true; + CACHE_AND_RETURN(true); } // i1 % i2 < i2 if (auto bop = dynamic_cast(x->definition()); @@ -1544,60 +1569,72 @@ bool lessThan(Val* x, Val* y, const Context& context) { auto denominator = bop->rhs(); if (denominator->sameAs(y) && isValidDenominator(denominator, context) && isNonNegative(y, context)) { - return true; + CACHE_AND_RETURN(true); } } // x <= a & a < b & b <= y --> x < y for (const auto& [a, b] : context.getKnownLessThan()) { if (lessEqual(x, a, context) && lessEqual(b, y, context)) { - return true; + CACHE_AND_RETURN(true); } } - return false; + CACHE_AND_RETURN(false); } +#undef CACHE_AND_RETURN + +#define CACHE_AND_RETURN(value) \ + context.less_equal_cache_.emplace(std::make_pair(x, y), value); \ + return value + bool lessEqual(Val* x, Val* y, const Context& context) { + auto cache_it = context.less_equal_cache_.find({x, y}); + if (cache_it != context.less_equal_cache_.end()) { + return cache_it->second; + } + x = foldConstants(x); y = foldConstants(y); if (x->value().hasValue() && y->value().hasValue()) { - return x->value() <= y->value(); + bool result = x->value() <= y->value(); + CACHE_AND_RETURN(result); } x = maybeUnwrapMagicZero(x); y = maybeUnwrapMagicZero(y); // x == y -> x <= y if (x->sameAs(y)) { - return true; + CACHE_AND_RETURN(true); } if (x->isZero() && isNonNegativeHelper(y, context)) { - return true; + CACHE_AND_RETURN(true); } for (const auto& [a, b] : context.getKnownLessThan()) { // x < y --> x <= y if (a->sameAs(x) && b->sameAs(y)) { - return true; + CACHE_AND_RETURN(true); } } for (const auto& [a, b] : context.getKnownLessEqual()) { if (a->sameAs(x) && b->sameAs(y)) { - return true; + CACHE_AND_RETURN(true); } } for (const auto& [a, b] : context.getKnownLessThan()) { // x < b & b <= y --> x <= y if (a->sameAs(x) && lessEqual(b, y, context)) { - return true; + CACHE_AND_RETURN(true); } } for (const auto& [a, b] : context.getKnownLessEqual()) { // x <= b & b <= y --> x <= y if (a->sameAs(x) && lessEqual(b, y, context)) { - return true; + CACHE_AND_RETURN(true); } } // if i is an integer, i > 0, then i >= 1 if (x->isOneInt() && y->isIntegralScalar()) { if (isPositiveHelper(y, context)) { - return true; + CACHE_AND_RETURN(true); } } // if a >= 0, b >= 1, then a <= a * b @@ -1619,7 +1656,7 @@ bool lessEqual(Val* x, Val* y, const Context& context) { maybeFlattenedOpOf(BinaryOpType::Mul, std::move(remaining_inputs)); auto one = IrBuilder::create(1L, *remaining->getDataType()); if (lessEqual(one, remaining, context)) { - return true; + CACHE_AND_RETURN(true); } } } @@ -1643,14 +1680,16 @@ bool lessEqual(Val* x, Val* y, const Context& context) { maybeFlattenedOpOf(BinaryOpType::Mul, std::move(remaining_inputs)); auto one = IrBuilder::create(1L, *remaining->getDataType()); if (lessEqual(one, remaining, context)) { - return true; + CACHE_AND_RETURN(true); } } } } - return false; + CACHE_AND_RETURN(false); } +#undef CACHE_AND_RETURN + } // namespace prove namespace { From 922475aa8634f3b3a837193d44af35a1edd48298 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 20 Mar 2024 17:15:40 -0700 Subject: [PATCH 2/2] rename macro --- csrc/expr_simplifier.cpp | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/csrc/expr_simplifier.cpp b/csrc/expr_simplifier.cpp index a05de9bdbd1..765dc9c08ba 100644 --- a/csrc/expr_simplifier.cpp +++ b/csrc/expr_simplifier.cpp @@ -1542,7 +1542,7 @@ bool hasCompatibleSign(Val* x, Val* y, const Context& context) { return isNonNegative(x, context) && isNonNegative(y, context); } -#define CACHE_AND_RETURN(value) \ +#define CACHE_AND_RETURN_LT(value) \ context.less_than_cache_.emplace(std::make_pair(x, y), value); \ return value @@ -1556,12 +1556,12 @@ bool lessThan(Val* x, Val* y, const Context& context) { y = foldConstants(y); if (x->value().hasValue() && y->value().hasValue()) { bool result = x->value() < y->value(); - CACHE_AND_RETURN(result); + CACHE_AND_RETURN_LT(result); } x = maybeUnwrapMagicZero(x); y = maybeUnwrapMagicZero(y); if (x->isZero() && isPositiveHelper(y, context)) { - CACHE_AND_RETURN(true); + CACHE_AND_RETURN_LT(true); } // i1 % i2 < i2 if (auto bop = dynamic_cast(x->definition()); @@ -1569,21 +1569,21 @@ bool lessThan(Val* x, Val* y, const Context& context) { auto denominator = bop->rhs(); if (denominator->sameAs(y) && isValidDenominator(denominator, context) && isNonNegative(y, context)) { - CACHE_AND_RETURN(true); + CACHE_AND_RETURN_LT(true); } } // x <= a & a < b & b <= y --> x < y for (const auto& [a, b] : context.getKnownLessThan()) { if (lessEqual(x, a, context) && lessEqual(b, y, context)) { - CACHE_AND_RETURN(true); + CACHE_AND_RETURN_LT(true); } } - CACHE_AND_RETURN(false); + CACHE_AND_RETURN_LT(false); } -#undef CACHE_AND_RETURN +#undef CACHE_AND_RETURN_LT -#define CACHE_AND_RETURN(value) \ +#define CACHE_AND_RETURN_LE(value) \ context.less_equal_cache_.emplace(std::make_pair(x, y), value); \ return value @@ -1597,44 +1597,44 @@ bool lessEqual(Val* x, Val* y, const Context& context) { y = foldConstants(y); if (x->value().hasValue() && y->value().hasValue()) { bool result = x->value() <= y->value(); - CACHE_AND_RETURN(result); + CACHE_AND_RETURN_LE(result); } x = maybeUnwrapMagicZero(x); y = maybeUnwrapMagicZero(y); // x == y -> x <= y if (x->sameAs(y)) { - CACHE_AND_RETURN(true); + CACHE_AND_RETURN_LE(true); } if (x->isZero() && isNonNegativeHelper(y, context)) { - CACHE_AND_RETURN(true); + CACHE_AND_RETURN_LE(true); } for (const auto& [a, b] : context.getKnownLessThan()) { // x < y --> x <= y if (a->sameAs(x) && b->sameAs(y)) { - CACHE_AND_RETURN(true); + CACHE_AND_RETURN_LE(true); } } for (const auto& [a, b] : context.getKnownLessEqual()) { if (a->sameAs(x) && b->sameAs(y)) { - CACHE_AND_RETURN(true); + CACHE_AND_RETURN_LE(true); } } for (const auto& [a, b] : context.getKnownLessThan()) { // x < b & b <= y --> x <= y if (a->sameAs(x) && lessEqual(b, y, context)) { - CACHE_AND_RETURN(true); + CACHE_AND_RETURN_LE(true); } } for (const auto& [a, b] : context.getKnownLessEqual()) { // x <= b & b <= y --> x <= y if (a->sameAs(x) && lessEqual(b, y, context)) { - CACHE_AND_RETURN(true); + CACHE_AND_RETURN_LE(true); } } // if i is an integer, i > 0, then i >= 1 if (x->isOneInt() && y->isIntegralScalar()) { if (isPositiveHelper(y, context)) { - CACHE_AND_RETURN(true); + CACHE_AND_RETURN_LE(true); } } // if a >= 0, b >= 1, then a <= a * b @@ -1656,7 +1656,7 @@ bool lessEqual(Val* x, Val* y, const Context& context) { maybeFlattenedOpOf(BinaryOpType::Mul, std::move(remaining_inputs)); auto one = IrBuilder::create(1L, *remaining->getDataType()); if (lessEqual(one, remaining, context)) { - CACHE_AND_RETURN(true); + CACHE_AND_RETURN_LE(true); } } } @@ -1680,15 +1680,15 @@ bool lessEqual(Val* x, Val* y, const Context& context) { maybeFlattenedOpOf(BinaryOpType::Mul, std::move(remaining_inputs)); auto one = IrBuilder::create(1L, *remaining->getDataType()); if (lessEqual(one, remaining, context)) { - CACHE_AND_RETURN(true); + CACHE_AND_RETURN_LE(true); } } } } - CACHE_AND_RETURN(false); + CACHE_AND_RETURN_LE(false); } -#undef CACHE_AND_RETURN +#undef CACHE_AND_RETURN_LE } // namespace prove