From bf740d595619f97fdde2e482f485225f748184a6 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Tue, 10 Jan 2023 18:12:11 +0800 Subject: [PATCH 1/2] Support eq in detect_clip_bound --- src/arith/detect_linear_equation.cc | 10 ++++++++++ .../python/unittest/test_arith_detect_clip_bound.py | 13 +++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 8ea8f168b6ee..da9864f921e3 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -189,6 +189,7 @@ bool DetectClipBound(const PrimExpr& cond, PostOrderVisit(cond, fvisit); if (flag != 1) return false; // canonical form: exp >= 0 + bool is_eq = false; PrimExpr canonical; if (const LTNode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; @@ -202,6 +203,9 @@ bool DetectClipBound(const PrimExpr& cond, } else if (const GENode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->a - op->b; + } else if (const EQNode* op = cond.as()) { + canonical = op->a - op->b; + is_eq = true; } else { return false; } @@ -217,6 +221,9 @@ bool DetectClipBound(const PrimExpr& cond, } else { p.min_value = -ret.base; } + if (is_eq) { + p.max_value = p.min_value; + } return true; } if (is_const_int(ret.coeff, -1)) { @@ -226,6 +233,9 @@ bool DetectClipBound(const PrimExpr& cond, } else { p.max_value = ret.base; } + if (is_eq) { + p.min_value = p.max_value; + } return true; } return false; diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/tests/python/unittest/test_arith_detect_clip_bound.py index 0a9d75fcea54..03fff11f77e5 100644 --- a/tests/python/unittest/test_arith_detect_clip_bound.py +++ b/tests/python/unittest/test_arith_detect_clip_bound.py @@ -39,5 +39,18 @@ def test_basic(): tvm.testing.assert_prim_expr_equal(m[2], 4) +def test_trivial_eq(): + a = te.var("a") + b = te.var("b") + m = tvm.arith.detect_clip_bound(b == 3, [a, b]) + tvm.testing.assert_prim_expr_equal(m[2], 3) + tvm.testing.assert_prim_expr_equal(m[3], 3) + m = tvm.arith.detect_clip_bound(tvm.tir.all(a == 4, b == 3), [a, b]) + tvm.testing.assert_prim_expr_equal(m[0], 4) + tvm.testing.assert_prim_expr_equal(m[1], 4) + tvm.testing.assert_prim_expr_equal(m[2], 3) + tvm.testing.assert_prim_expr_equal(m[3], 3) + + if __name__ == "__main__": test_basic() From db4cd4aeecbc84a256f8b1c6e1f728345d2f5e3f Mon Sep 17 00:00:00 2001 From: wrongtest Date: Fri, 20 Jan 2023 09:28:27 +0800 Subject: [PATCH 2/2] follow review suggestion --- src/arith/detect_linear_equation.cc | 40 ++++++++++++++++++----------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index da9864f921e3..576ac1716e69 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -204,6 +204,7 @@ bool DetectClipBound(const PrimExpr& cond, if (!op->a.dtype().is_int()) return false; canonical = op->a - op->b; } else if (const EQNode* op = cond.as()) { + if (!op->a.dtype().is_int()) return false; canonical = op->a - op->b; is_eq = true; } else { @@ -214,31 +215,40 @@ bool DetectClipBound(const PrimExpr& cond, if (!LinearEqDetector(var).Detect(canonical, &ret)) return false; ret.coeff = analyzer.Simplify(ret.coeff); IntervalEntry& p = (*bmap)[var.get()]; + + Optional min_value; + Optional max_value; if (is_const_int(ret.coeff, 1)) { // var + shift >=0 -> var >= -shift - if (p.min_value.defined()) { - p.min_value = max(p.min_value, -ret.base); - } else { - p.min_value = -ret.base; + min_value = -ret.base; + if (is_eq) { + max_value = min_value; } + } else if (is_const_int(ret.coeff, -1)) { + // -var + shift >=0 -> var <= shift + max_value = ret.base; if (is_eq) { - p.max_value = p.min_value; + min_value = max_value; } - return true; } - if (is_const_int(ret.coeff, -1)) { - // -var + shift >=0 -> var <= shift - if (p.max_value.defined()) { - p.max_value = min(p.max_value, ret.base); + if (!min_value.defined() && !max_value.defined()) { + return false; + } + if (min_value.defined()) { + if (p.min_value.defined()) { + p.min_value = max(p.min_value, min_value.value()); } else { - p.max_value = ret.base; + p.min_value = min_value.value(); } - if (is_eq) { - p.min_value = p.max_value; + } + if (max_value.defined()) { + if (p.max_value.defined()) { + p.max_value = min(p.max_value, max_value.value()); + } else { + p.max_value = max_value.value(); } - return true; } - return false; + return true; } template