diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 8ea8f168b6ee..576ac1716e69 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -189,6 +189,7 @@ bool DetectClipBound(const PrimExpr& cond, PostOrderVisit(cond, fvisit); if (flag != 1) return false; // canonical form: exp >= 0 + bool is_eq = false; PrimExpr canonical; if (const LTNode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; @@ -202,6 +203,10 @@ bool DetectClipBound(const PrimExpr& cond, } else if (const GENode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->a - op->b; + } else if (const EQNode* op = cond.as()) { + if (!op->a.dtype().is_int()) return false; + canonical = op->a - op->b; + is_eq = true; } else { return false; } @@ -210,25 +215,40 @@ bool DetectClipBound(const PrimExpr& cond, if (!LinearEqDetector(var).Detect(canonical, &ret)) return false; ret.coeff = analyzer.Simplify(ret.coeff); IntervalEntry& p = (*bmap)[var.get()]; + + Optional min_value; + Optional max_value; if (is_const_int(ret.coeff, 1)) { // var + shift >=0 -> var >= -shift + min_value = -ret.base; + if (is_eq) { + max_value = min_value; + } + } else if (is_const_int(ret.coeff, -1)) { + // -var + shift >=0 -> var <= shift + max_value = ret.base; + if (is_eq) { + min_value = max_value; + } + } + if (!min_value.defined() && !max_value.defined()) { + return false; + } + if (min_value.defined()) { if (p.min_value.defined()) { - p.min_value = max(p.min_value, -ret.base); + p.min_value = max(p.min_value, min_value.value()); } else { - p.min_value = -ret.base; + p.min_value = min_value.value(); } - return true; } - if (is_const_int(ret.coeff, -1)) { - // -var + shift >=0 -> var <= shift + if (max_value.defined()) { if (p.max_value.defined()) { - p.max_value = min(p.max_value, ret.base); + p.max_value = min(p.max_value, max_value.value()); } else { - p.max_value = ret.base; + p.max_value = max_value.value(); } - return true; } - return false; + return true; } template diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/tests/python/unittest/test_arith_detect_clip_bound.py index 0a9d75fcea54..03fff11f77e5 100644 --- a/tests/python/unittest/test_arith_detect_clip_bound.py +++ b/tests/python/unittest/test_arith_detect_clip_bound.py @@ -39,5 +39,18 @@ def test_basic(): tvm.testing.assert_prim_expr_equal(m[2], 4) +def test_trivial_eq(): + a = te.var("a") + b = te.var("b") + m = tvm.arith.detect_clip_bound(b == 3, [a, b]) + tvm.testing.assert_prim_expr_equal(m[2], 3) + tvm.testing.assert_prim_expr_equal(m[3], 3) + m = tvm.arith.detect_clip_bound(tvm.tir.all(a == 4, b == 3), [a, b]) + tvm.testing.assert_prim_expr_equal(m[0], 4) + tvm.testing.assert_prim_expr_equal(m[1], 4) + tvm.testing.assert_prim_expr_equal(m[2], 3) + tvm.testing.assert_prim_expr_equal(m[3], 3) + + if __name__ == "__main__": test_basic()