diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index 7cfe8681bea3..d52ae7e6fde3 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -344,6 +344,10 @@ void BoundDeducer::Deduce() { expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); this->VisitExpr(expr_); + + if (success_) { + result_ = analyzer_.Simplify(result_); + } } void BoundDeducer::Relax() { diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 11fb041511f9..14c91934d3b2 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -633,6 +633,27 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { */ void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible, SumExpr* out_non_divisible); + /*! + * \brief Pattern match and check whether lhs is fully divisible by + * rhs using prod pattern simiplification expressions. + * + * The following two relations holds for floordiv/mod and truncdiv/mod + * Note that the relation do not hold for euclidean divide and mod. + * + * This is because the floordiv/mod and truncdiv/mod result can be + * uniquely determined by the value of the realdiv result and the + * relation holds for realdiv. + * + * - div((a0 * a1 * c), (b0 * b1 * c)) = div((a0 * a1), (b0 * b1)) + * - mod((a0 * a1 * c), (b0 * b1 * c)) = mod((a0 * a1), (b0 * b1)) * c + * + * \param lhs The left operand to be updated. + * \param rhs The right operand to be updated. + * \param common_scale The common scale between lhs and rhs. + * \returns The simplified result if it is successful. + * \note This simplification mainly target when rhs is symbolic. + */ + bool ProdDivSimplify(PrimExpr* lhs, PrimExpr* rhs, PrimExpr* common_scale); /*! * \brief Normalize expr to normal expr. * \param expr The input expression. @@ -862,6 +883,66 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, return lhs; } +bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, + PrimExpr* common_scale) { + // the constant rhs case is covered by other simplifier so + // we just skip to save the time + if (prhs->as()) return false; + // collect lhs products and try to eliminate by matching them to prod in rhs + Array> lhs_prods; + PrimExpr new_rhs = make_const(prhs->dtype(), 1); + PrimExpr new_common_scale = make_const(prhs->dtype(), 1); + int64_t lhs_cscale = 1, rhs_cscale = 1; + int num_elimination = 0; + + // collect lhs product and constant scale. + auto fcollect_lhs = [&](PrimExpr value) { + if (auto* intimm = value.as()) { + lhs_cscale *= intimm->value; + } else { + lhs_prods.push_back(value); + } + }; + UnpackReduction(*plhs, fcollect_lhs); + + // collect rhs product and try to eliminate when possible + PEqualChecker deep_equal; + auto fcollect_rhs = [&](PrimExpr value) { + if (auto* intimm = value.as()) { + rhs_cscale *= intimm->value; + } else { + // try eliminate from lhs + for (size_t i = 0; i < lhs_prods.size(); ++i) { + if (lhs_prods[i].defined() && deep_equal(value, lhs_prods[i].value())) { + lhs_prods.Set(i, NullOpt); + ++num_elimination; + new_common_scale = new_common_scale * value; + return; + } + } + // if elimination is not possible then construct the expression. + new_rhs = new_rhs * value; + } + }; + UnpackReduction(*prhs, fcollect_rhs); + // find gcd of const scales. + int64_t cscale_gcd = ZeroAwareGCD(lhs_cscale, rhs_cscale); + lhs_cscale /= cscale_gcd; + rhs_cscale /= cscale_gcd; + // if no elimination is possible + if (num_elimination == 0 && cscale_gcd == 1) return false; + + // construct prod via canonical form + PrimExpr new_lhs = make_const(plhs->dtype(), 1); + for (Optional val : lhs_prods) { + if (val.defined()) new_lhs = new_lhs * val.value(); + } + *plhs = new_lhs * make_const(plhs->dtype(), lhs_cscale); + *prhs = new_rhs * make_const(prhs->dtype(), rhs_cscale); + *common_scale = new_common_scale * make_const(prhs->dtype(), cscale_gcd); + return true; +} + PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); @@ -913,6 +994,12 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { // normal path a = Normalize(a); b = Normalize(b); + PrimExpr scale; + // note this is the case where b is not constant + if (ProdDivSimplify(&a, &b, &scale)) { + // use operator ver so it can constant fold if b == 1 + return truncdiv(a, b); + } if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { @@ -967,6 +1054,11 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { // normal path a = Normalize(a); b = Normalize(b); + PrimExpr scale; + if (ProdDivSimplify(&a, &b, &scale)) { + // use operator ver so it can const fold. + return floordiv(a, b); + } if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { @@ -1088,6 +1180,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { // normal path a = Normalize(a); b = Normalize(b); + + PrimExpr scale; + if (ProdDivSimplify(&a, &b, &scale)) { + // use operator version here so it can const fold b == 1 + return truncmod(a, b) * scale; + } + if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { @@ -1146,6 +1245,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { // normal path a = Normalize(a); b = Normalize(b); + + PrimExpr scale; + if (ProdDivSimplify(&a, &b, &scale)) { + // use operator version here so it can const fold b == 1 + return floormod(a, b) * scale; + } + if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 55b51d7a315b..0bb172e56053 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -915,6 +915,23 @@ matches_one_of(const TPattern&... patterns) { return PMatchesOneOf(patterns...); } +/*! + * \brief Unpack reduction by calling each leaf via fleaf. + * + * \param value The expression value. + * \tparam TNode the reduction node to match. + * \tparam FLeaf The callback function at leaf. + */ +template +inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) { + if (const TNode* node = value.as()) { + UnpackReduction(node->a, fleaf); + UnpackReduction(node->b, fleaf); + } else { + fleaf(value); + } +} + } // namespace arith } // namespace tvm #endif // TVM_ARITH_PATTERN_MATCH_H_ diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 45ecb6275549..a36fd214794b 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -114,12 +114,10 @@ def test_deduce(): assert str(res9.max_value) == "neg_inf" assert str(res9.min_value) == "pos_inf" - # Unsatisfiable Mul in `EQ` - res10 = tvm.arith.deduce_bound( - a, (b * a == b), {b: b_s}, {} - ) # simplifier is not able to prove that (b % b == 0) - assert str(res10.max_value) == "neg_inf" - assert str(res10.min_value) == "pos_inf" + res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {}) + # simplifier is now able to prove symbolic relation (b * a % b == 0) + tvm.testing.assert_prim_expr_equal(res10.max_value, 1) + tvm.testing.assert_prim_expr_equal(res10.min_value, 1) def test_check():