From e35663f8a07d807acb12378580ea9d4e1c155f1a Mon Sep 17 00:00:00 2001 From: syang-ng Date: Wed, 8 Sep 2021 09:08:31 +0000 Subject: [PATCH 1/3] fix div zero error in rewrite_simplify --- src/arith/rewrite_simplify.cc | 2 ++ .../unittest/test_arith_rewrite_simplify.py | 27 ++++++++----------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 1d3475b13dad..1cdd6f0b4cb5 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -474,6 +474,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; + // If divisor is equal to zero + ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return ramp(div(b1, c2), div(c1, c2), lanes).Eval(); } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 641eed51d5cf..ba42b7cb920c 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import py +import pytest import tvm from tvm import te @@ -931,20 +933,13 @@ def test_shift_left_simplify(): ck.verify(z, tvm.tir.const(1 << 10, "int32")) +def test_div_zero_simplify(): + ck = RewriteChecker() + + with pytest.raises(tvm.error.TVMError) as cm: + ck.analyzer.rewrite_simplify(tvm.tir.Div(tvm.tir.Ramp(1,1,2), tvm.tir.Broadcast(0, 2))) + assert "division by zero" in str(cm.execption) + + if __name__ == "__main__": - test_floordiv_index_simplify() - test_floormod_index_simplify() - test_cmp_simplify() - test_vector_simplify() - test_add_index_simplify() - test_sub_index_simplify() - test_mul_index_simplify() - test_div_index_simplify() - test_max_index_simplify() - test_min_index_simplify() - test_mod_index_simplify() - test_select_simplify() - test_logical_simplify() - test_let_simplify() - test_cast_simplify() - test_shift_left_simplify() + pytest.main([__file__]) From 17a517470ae9300fff6afc2b8424074bf58a87a3 Mon Sep 17 00:00:00 2001 From: syang-ng Date: Wed, 8 Sep 2021 10:55:13 +0000 Subject: [PATCH 2/3] update the style to fix ci error --- tests/python/unittest/test_arith_rewrite_simplify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index ba42b7cb920c..95d24d68ee7f 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -937,7 +937,7 @@ def test_div_zero_simplify(): ck = RewriteChecker() with pytest.raises(tvm.error.TVMError) as cm: - ck.analyzer.rewrite_simplify(tvm.tir.Div(tvm.tir.Ramp(1,1,2), tvm.tir.Broadcast(0, 2))) + ck.analyzer.rewrite_simplify(tvm.tir.Div(tvm.tir.Ramp(1, 1, 2), tvm.tir.Broadcast(0, 2))) assert "division by zero" in str(cm.execption) From 564bf4a0de08628a3ff14bcb20a423a2595f1590 Mon Sep 17 00:00:00 2001 From: syang-ng Date: Thu, 9 Sep 2021 00:13:37 +0000 Subject: [PATCH 3/3] remove useless code and comment --- src/arith/rewrite_simplify.cc | 1 - tests/python/unittest/test_arith_rewrite_simplify.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 1cdd6f0b4cb5..0087866ea4f8 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -474,7 +474,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - // If divisor is equal to zero ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return ramp(div(b1, c2), div(c1, c2), lanes).Eval(); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 95d24d68ee7f..9ff9ff18e5b5 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import py import pytest import tvm from tvm import te