diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 4e6d8caf3772..6f4d3cfb53bb 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -111,8 +111,9 @@ class LinearEqDetector return ComputeExpr(a, b); } Expr SubCombine(Expr a, Expr b) { - if (!a.defined()) return -b; + // Check b first in case they are both undefined if (!b.defined()) return a; + if (!a.defined()) return -b; return ComputeExpr(a, b); } Expr MulCombine(Expr a, Expr b) { diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index 9d875c910d1c..2b0f327b65b2 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -38,6 +38,10 @@ def test_multivariate(): assert(m[2].value == 2) assert(m[len(m)-1].value == 2) + m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [v[2]]) + assert(m[0].value == 0) + assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0) + if __name__ == "__main__": test_basic() test_multivariate()