Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/arith/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,10 @@ void BoundDeducer::Deduce() {
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);

this->VisitExpr(expr_);

if (success_) {
result_ = analyzer_.Simplify(result_);
}
}

void BoundDeducer::Relax() {
Expand Down
106 changes: 106 additions & 0 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,27 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
*/
void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible,
SumExpr* out_non_divisible);
/*!
* \brief Pattern match and check whether lhs is fully divisible by
* rhs using prod pattern simiplification expressions.
*
* The following two relations holds for floordiv/mod and truncdiv/mod
* Note that the relation do not hold for euclidean divide and mod.
*
* This is because the floordiv/mod and truncdiv/mod result can be
* uniquely determined by the value of the realdiv result and the
* relation holds for realdiv.
*
* - div((a0 * a1 * c), (b0 * b1 * c)) = div((a0 * a1), (b0 * b1))
* - mod((a0 * a1 * c), (b0 * b1 * c)) = mod((a0 * a1), (b0 * b1)) * c
*
* \param lhs The left operand to be updated.
* \param rhs The right operand to be updated.
* \param common_scale The common scale between lhs and rhs.
* \returns The simplified result if it is successful.
* \note This simplification mainly target when rhs is symbolic.
*/
bool ProdDivSimplify(PrimExpr* lhs, PrimExpr* rhs, PrimExpr* common_scale);
/*!
* \brief Normalize expr to normal expr.
* \param expr The input expression.
Expand Down Expand Up @@ -862,6 +883,66 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval,
return lhs;
}

bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs,
PrimExpr* common_scale) {
// the constant rhs case is covered by other simplifier so
// we just skip to save the time
if (prhs->as<IntImmNode>()) return false;
// collect lhs products and try to eliminate by matching them to prod in rhs
Array<Optional<PrimExpr>> lhs_prods;
PrimExpr new_rhs = make_const(prhs->dtype(), 1);
PrimExpr new_common_scale = make_const(prhs->dtype(), 1);
int64_t lhs_cscale = 1, rhs_cscale = 1;
int num_elimination = 0;

// collect lhs product and constant scale.
auto fcollect_lhs = [&](PrimExpr value) {
if (auto* intimm = value.as<tir::IntImmNode>()) {
lhs_cscale *= intimm->value;
} else {
lhs_prods.push_back(value);
}
};
UnpackReduction<tir::MulNode>(*plhs, fcollect_lhs);

// collect rhs product and try to eliminate when possible
PEqualChecker<PrimExpr> deep_equal;
auto fcollect_rhs = [&](PrimExpr value) {
if (auto* intimm = value.as<tir::IntImmNode>()) {
rhs_cscale *= intimm->value;
} else {
// try eliminate from lhs
for (size_t i = 0; i < lhs_prods.size(); ++i) {
if (lhs_prods[i].defined() && deep_equal(value, lhs_prods[i].value())) {
lhs_prods.Set(i, NullOpt);
++num_elimination;
new_common_scale = new_common_scale * value;
return;
}
}
// if elimination is not possible then construct the expression.
new_rhs = new_rhs * value;
}
};
UnpackReduction<tir::MulNode>(*prhs, fcollect_rhs);
// find gcd of const scales.
int64_t cscale_gcd = ZeroAwareGCD(lhs_cscale, rhs_cscale);
lhs_cscale /= cscale_gcd;
rhs_cscale /= cscale_gcd;
// if no elimination is possible
if (num_elimination == 0 && cscale_gcd == 1) return false;

// construct prod via canonical form
PrimExpr new_lhs = make_const(plhs->dtype(), 1);
for (Optional<PrimExpr> val : lhs_prods) {
if (val.defined()) new_lhs = new_lhs * val.value();
}
*plhs = new_lhs * make_const(plhs->dtype(), lhs_cscale);
*prhs = new_rhs * make_const(prhs->dtype(), rhs_cscale);
*common_scale = new_common_scale * make_const(prhs->dtype(), cscale_gcd);
return true;
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
Expand Down Expand Up @@ -913,6 +994,12 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
// normal path
a = Normalize(a);
b = Normalize(b);
PrimExpr scale;
// note this is the case where b is not constant
if (ProdDivSimplify(&a, &b, &scale)) {
// use operator ver so it can constant fold if b == 1
return truncdiv(a, b);
}
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<PrimExpr>(op);
} else {
Expand Down Expand Up @@ -967,6 +1054,11 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
// normal path
a = Normalize(a);
b = Normalize(b);
PrimExpr scale;
if (ProdDivSimplify(&a, &b, &scale)) {
// use operator ver so it can const fold.
return floordiv(a, b);
}
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<PrimExpr>(op);
} else {
Expand Down Expand Up @@ -1088,6 +1180,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
// normal path
a = Normalize(a);
b = Normalize(b);

PrimExpr scale;
if (ProdDivSimplify(&a, &b, &scale)) {
// use operator version here so it can const fold b == 1
return truncmod(a, b) * scale;
}

if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<PrimExpr>(op);
} else {
Expand Down Expand Up @@ -1146,6 +1245,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
// normal path
a = Normalize(a);
b = Normalize(b);

PrimExpr scale;
if (ProdDivSimplify(&a, &b, &scale)) {
// use operator version here so it can const fold b == 1
return floormod(a, b) * scale;
}

if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<PrimExpr>(op);
} else {
Expand Down
17 changes: 17 additions & 0 deletions src/arith/pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,23 @@ matches_one_of(const TPattern&... patterns) {
return PMatchesOneOf<TPattern...>(patterns...);
}

/*!
* \brief Unpack reduction by calling each leaf via fleaf.
*
* \param value The expression value.
* \tparam TNode the reduction node to match.
* \tparam FLeaf The callback function at leaf.
*/
template <typename TNode, typename FLeaf>
inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) {
if (const TNode* node = value.as<TNode>()) {
UnpackReduction<TNode, FLeaf>(node->a, fleaf);
UnpackReduction<TNode, FLeaf>(node->b, fleaf);
} else {
fleaf(value);
}
}

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_PATTERN_MATCH_H_
10 changes: 4 additions & 6 deletions tests/python/unittest/test_arith_deduce_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,10 @@ def test_deduce():
assert str(res9.max_value) == "neg_inf"
assert str(res9.min_value) == "pos_inf"

# Unsatisfiable Mul in `EQ`
res10 = tvm.arith.deduce_bound(
a, (b * a == b), {b: b_s}, {}
) # simplifier is not able to prove that (b % b == 0)
assert str(res10.max_value) == "neg_inf"
assert str(res10.min_value) == "pos_inf"
res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {})
# simplifier is now able to prove symbolic relation (b * a % b == 0)
tvm.testing.assert_prim_expr_equal(res10.max_value, 1)
tvm.testing.assert_prim_expr_equal(res10.min_value, 1)


def test_check():
Expand Down