From 2b6674410d3aca3f03a205c5d49161fc5109ba14 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 12 Mar 2024 01:24:44 -0400 Subject: [PATCH] [Fix][Arith] Fix canonical simplification of LE PR #15471 enhances the simplification for LE, while missed a case where the upper bound `kPosInf` is divisible by a factor. Therefore, prior to this PR, when simplifying `x * 1024 + y < z * 7168`, it will fails with the error message ``` InternalError: Check failed: value < 1LL << (dtype.bits() - 1) (8589934591 vs. 2147483648) : ValueError: Literal value 8589934591 exceeds maximum of int32 ``` This is just because the upper bound 7 here divides `kPosInf` the maximum value of int64, which passes an "if" condition in #15471 unexpectedly. This PR fixes the issue. --- src/arith/canonical_simplify.cc | 1 + tests/python/arith/test_arith_canonical_simplify.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 0d972b491ae6..b11708398fe9 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1422,6 +1422,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) { divisible.CopyOnWrite()->DivideBy(gcd); return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype)); } else if (extra->args.size() == 1 && + extra->args[0]->upper_factor != ConstIntBoundNode::kPosInf && extra->args[0]->upper_factor % (gcd * extra->args[0]->lower_factor) == 0) { // Case 2. xn == yn % m, where m % d == 0 divisible.CopyOnWrite()->DivideBy(gcd); diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py index 052d2895bfa0..23321ce823c3 100644 --- a/tests/python/arith/test_arith_canonical_simplify.py +++ b/tests/python/arith/test_arith_canonical_simplify.py @@ -461,6 +461,11 @@ def test_simplify_le(): ) ck.verify(tx // 2 % 8 + vec < 8, tx % 16 // 2 + vec < 8) + # Case 3. No failure + x, y, z = te.var("x"), te.var("y"), te.var("z") + ck.analyzer.bind(y, tvm.ir.Range(0, 1024)) + ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0) + if __name__ == "__main__": tvm.testing.main()