From 62eecf80c1f63e3025c78a4128f5aca3d55f56d6 Mon Sep 17 00:00:00 2001 From: Lucien0 <16538059+Lucien0@users.noreply.github.com> Date: Tue, 15 Aug 2023 15:21:22 +0800 Subject: [PATCH] fix detect linear equation with uint var --- src/arith/detect_linear_equation.cc | 3 ++- tests/python/unittest/test_arith_detect_linear_equation.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 576ac1716e69..4d3164cbd382 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -100,7 +100,8 @@ class LinearEqDetector : public ExprFunctordtype, 1); + auto dtype = op->dtype; + ret.coeff = make_const(DataType::Int(dtype.bits(), dtype.lanes()), 1); } else { ret.base = e; } diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index cedb55782989..829b101af341 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -43,6 +43,10 @@ def test_basic(): assert len(m) == 1 tvm.testing.assert_prim_expr_equal(m[0], b * 7) + c = te.var("c", "uint32") + m = tvm.arith.detect_linear_equation(128 - c, [c]) + assert m[0].value == -1 + def test_multivariate(): v = [te.var("v%d" % i) for i in range(4)]