diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 0d972b491ae6..b11708398fe9 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1422,6 +1422,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) { divisible.CopyOnWrite()->DivideBy(gcd); return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype)); } else if (extra->args.size() == 1 && + extra->args[0]->upper_factor != ConstIntBoundNode::kPosInf && extra->args[0]->upper_factor % (gcd * extra->args[0]->lower_factor) == 0) { // Case 2. xn == yn % m, where m % d == 0 divisible.CopyOnWrite()->DivideBy(gcd); diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py index 052d2895bfa0..23321ce823c3 100644 --- a/tests/python/arith/test_arith_canonical_simplify.py +++ b/tests/python/arith/test_arith_canonical_simplify.py @@ -461,6 +461,11 @@ def test_simplify_le(): ) ck.verify(tx // 2 % 8 + vec < 8, tx % 16 // 2 + vec < 8) + # Case 3. No failure + x, y, z = te.var("x"), te.var("y"), te.var("z") + ck.analyzer.bind(y, tvm.ir.Range(0, 1024)) + ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0) + if __name__ == "__main__": tvm.testing.main()