From 68c255c7a45a71bfe7ff505a94f6db459b7b5847 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Thu, 26 Aug 2021 11:15:07 +0800 Subject: [PATCH 1/4] Update rewrite_simplify.cc --- src/arith/rewrite_simplify.cc | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index ff6536ab066b..1d3475b13dad 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -858,14 +858,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { ModularSet bmod = analyzer_->modular_set(b1.Eval()); int64_t ramp_min = floordiv(bmod->base, c2val); int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val); - if (bmod->coeff % c2val == 0) { - if (ramp_min == ramp_max) { + if (ramp_min == ramp_max) { + // If b1 can devide c2 + if (bmod->coeff % c2val == 0) { return ramp(floormod(bmod->base, c2), c1, lanes).Eval(); - } else { - return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } - } else if (c2val % bmod->coeff == 0 && ramp_min == ramp_max) { - return ramp(floormod(b1, c2), c1, lanes).Eval(); + // If all indices can be guaranteed to settle inside a coeff range + if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) { + return ramp(floormod(b1, c2), c1, lanes).Eval(); + } + } + if (bmod->coeff % c2val == 0) { + return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } } } From ed83401d45957b14806d39917c0494bc25755ff1 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Thu, 26 Aug 2021 11:17:51 +0800 Subject: [PATCH 2/4] Update test_arith_rewrite_simplify.py --- .../unittest/test_arith_rewrite_simplify.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 231c376c50ca..fb3462b9c41e 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -28,6 +28,12 @@ def verify(self, data, expected): data, res, expected ) + def verify_not(self, data, expected): + res = self.analyzer.rewrite_simplify(data) + assert not tvm.ir.structural_equal(res, expected), "data={}, res={}, expected={}".format( + data, res, expected + ) + def test_vector_simplify(): ck = RewriteChecker() @@ -135,13 +141,17 @@ def test_vector_simplify(): ck.verify( flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 8, 64), 2, 4) ) - ck.verify( + ck.verify_not( flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), tvm.tir.Ramp(flm(x * 4, 64), 1, 5) - ) - ck.verify( + ) # Example negative case: x = 15; [60, 61, 62, 63, 64] % 64 = [60, 61, 62, 63, 0] + ck.verify_not( flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 4 + 3, 64), 1, 4), - ) + ) # Example negative case: x = 15; [63, 64, 65, 66] % 64 = [63, 0, 1, 2] + ck.verify_not( + flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)), + tvm.tir.Ramp(flm(x * 2, 20), 1, 8), + ) # Example negative case: x = 9; [18, 19, 20, ..., 25] % 20 = [18, 19, 0, 1, ..., 5] ck.verify( flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), From d056704dc70cb37f515da961b26fc02e8f119736 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Thu, 26 Aug 2021 11:26:37 +0800 Subject: [PATCH 3/4] Update test_arith_rewrite_simplify.py --- .../unittest/test_arith_rewrite_simplify.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index fb3462b9c41e..cfede5a6d46f 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -28,12 +28,6 @@ def verify(self, data, expected): data, res, expected ) - def verify_not(self, data, expected): - res = self.analyzer.rewrite_simplify(data) - assert not tvm.ir.structural_equal(res, expected), "data={}, res={}, expected={}".format( - data, res, expected - ) - def test_vector_simplify(): ck = RewriteChecker() @@ -141,16 +135,17 @@ def test_vector_simplify(): ck.verify( flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 8, 64), 2, 4) ) - ck.verify_not( - flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), tvm.tir.Ramp(flm(x * 4, 64), 1, 5) + ck.verify( + flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), + flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), ) # Example negative case: x = 15; [60, 61, 62, 63, 64] % 64 = [60, 61, 62, 63, 0] - ck.verify_not( + ck.verify( + flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), - tvm.tir.Ramp(flm(x * 4 + 3, 64), 1, 4), ) # Example negative case: x = 15; [63, 64, 65, 66] % 64 = [63, 0, 1, 2] - ck.verify_not( + ck.verify( + flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)), flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)), - tvm.tir.Ramp(flm(x * 2, 20), 1, 8), ) # Example negative case: x = 9; [18, 19, 20, ..., 25] % 20 = [18, 19, 0, 1, ..., 5] ck.verify( flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), From a324b32c6784f74ee571bfba9cb5ae2e06c2b552 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Thu, 26 Aug 2021 11:32:43 +0800 Subject: [PATCH 4/4] Update test_arith_rewrite_simplify.py --- tests/python/unittest/test_arith_rewrite_simplify.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index cfede5a6d46f..641eed51d5cf 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -101,15 +101,16 @@ def test_vector_simplify(): ck.verify( fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), - ) + ) # Example negative case: x = 15; [60, 61, 62, 63, 64] / 64 = [0, 0, 0, 0, 1] ck.verify( fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), - ) + ) # Example negative case: x = 15; [63, 64, 65, 66] % 64 = [0, 1, 1, 1] ck.verify( fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), - ) + ) # Example negative case: x = 9; [63, 70, 77, 84] % 64 = [0, 1, 1, 1] + # floor mod ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2")) ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(flm(x, 2), 4)) @@ -150,7 +151,7 @@ def test_vector_simplify(): ck.verify( flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), - ) + ) # Example negative case: x = 9; [63, 70, 77, 84] % 64 = [63, 6, 13, 20] # Min/Max rules vx = te.var("vx", dtype="int32x2")