From d3b70d72003adcf924e767bec6f92b5dd345684d Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Thu, 5 Jan 2023 14:18:12 -0600 Subject: [PATCH 1/2] [Arith] Simplify to positive numerators in floordiv/floormod Negative numerators to modulo/remainder operations are not supported by the Vulkan API. While the SPIR-V instructions [`OpSRem`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSRem) and [`OpSMod`](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSMod) have identical semantics to `tir::Mod` and `tir::FloorMod`, respectively, use of either instruction with a negative operand within Vulkan results in undefined behavior. From the [Vulkan spec](https://registry.khronos.org/vulkan/specs/1.3/html/chap37.html#spirvenv-op-prec): > For the OpSRem and OpSMod instructions, if either operand is > negative the result is undefined. > > Note: While the OpSRem and OpSMod instructions are supported by the > Vulkan environment, they require non-negative values and thus do not > enable additional functionality beyond what OpUMod provides. This issue was first noticed in https://github.com/apache/tvm/pull/13530, where use of integer arithmetic resulted in negative numerators. This hadn't caused issues previously, because most uses of div/mod use a denominator that is a power of two. In these cases, `tir.LowerIntrin` implements floordiv and floormod using only bitwise operations. When the denominator isn't a power of two, both `tir::FloorDiv` and `tir::FloorMod` are implemented in terms of `tir::Mod`, which triggers the undefined behavior for negative numerators. This commit implements additional simplification rules that preferentially remove negative values from the numerators. For example, `floormod(i - 2, 8)` is simplified to `floormod(i + 6, 8)`, and `floordiv(i - 2, 8)` is simplified to `floordiv(i + 6, 8) - 1`. These handle the most common case, where some index variable is being offset by a negative constant. 
--- src/arith/rewrite_simplify.cc | 38 ++++++++++++++--- .../unittest/test_arith_rewrite_simplify.py | 15 +++++-- .../unittest/test_target_codegen_vulkan.py | 42 +++++++++++++++++++ 3 files changed, 85 insertions(+), 10 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index f1838f5a9099..9d4d7c0e868d 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -221,12 +221,21 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y), c1.Eval()->value == -c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y), - c1.Eval()->value == -c2.Eval()->value); TVM_TRY_REWRITE_IF(min(y + z * c1, x) + z * c2, min(x + z * c2, y), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(x, z * c1 + y) + z * c2, min(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(z * c1 + y, x) + z * c2, min(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); + + TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); TVM_TRY_REWRITE_IF(max(y + z * c1, x) + z * c2, max(x + z * c2, y), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x, z * c1 + y) + z * c2, max(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(z * c1 + y, x) + z * c2, max(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y); TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y); @@ -241,6 +250,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { // constant folding // NOTE: canonicalization might better at this. 
TVM_TRY_REWRITE((x + c1) + c2, x + (c1 + c2)); + TVM_TRY_REWRITE((c1 - y) + c2, (c1 + c2) - y); + TVM_TRY_REWRITE((y - c1) + c2, y + (c2 - c1)); // mul co-efficient folding TVM_TRY_REWRITE(x + x, x * 2); @@ -267,10 +278,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { // canonicalization rule // will try rewrite again after canonicalization. + TVM_TRY_RECURSIVE_REWRITE(c1 + x, x + c1); TVM_TRY_RECURSIVE_REWRITE(x + (c1 - y), (x - y) + c1); TVM_TRY_RECURSIVE_REWRITE((c1 - y) + x, (x - y) + c1); - TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1); - TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1); + TVM_TRY_RECURSIVE_REWRITE(x + (y - c1), (x + y) - c1); + TVM_TRY_RECURSIVE_REWRITE((y - c1) + x, (x + y) - c1); + TVM_TRY_RECURSIVE_REWRITE((x + c1) + y, (x + y) + c1); + TVM_TRY_RECURSIVE_REWRITE(x + (y + c1), (x + y) + c1); + TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x); TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x); @@ -493,6 +508,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { // canonicalization rule // will try rewrite again after canonicalization. 
TVM_TRY_REWRITE(x - c1, x + (0 - c1)); + TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) - c1); TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1); TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y); TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1)); @@ -903,6 +919,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x + floormod(c1, c2), c2) + floordiv(c1, c2), + c2.Eval()->value > 0 && + (c1.Eval()->value < 0 /* || c1.Eval()->value >= c2.Eval()->value*/)); + TVM_TRY_REWRITE_IF(floordiv(x - c1, c2), + floordiv(x + floormod(-1 * c1, c2), c2) + floordiv(-1 * c1, c2), + c2.Eval()->value > 0); + // Rules involving 3-operands. TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -1013,8 +1036,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF( + floormod(x + c1, c2), floormod(x + floormod(c1, c2), c2), + c2.Eval()->value > 0 && (c1.Eval()->value < 0 || c1.Eval()->value >= c2.Eval()->value)); + TVM_TRY_REWRITE_IF(floormod(x - c1, c2), floormod(x + floormod(-1 * c1, c2), c2), + c2.Eval()->value > 0); TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2), c2.Eval()->value > 0); diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index d6c2cfe8bbdd..a8e7268097e0 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ 
b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -381,8 +381,8 @@ def test_sub_index_simplify(): ck.verify(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1) ck.verify(fld(y, 3) * 3 - y, 0 - flm(y, 3)) - ck.verify(y - fld(y - 6, 5) * 5, flm(y + (-6), 5) + 6) - ck.verify(fld(y - 6, 5) * 5 - y, (-6) - flm(y + (-6), 5)) + ck.verify(y - fld(y - 6, 5) * 5, flm(y + 4, 5) + 6) + ck.verify(fld(y - 6, 5) * 5 - y, (-6) - flm(y + 4, 5)) ck.verify(y - fld(y + z, 5) * 5, flm(y + z, 5) - z) ck.verify(fld(y + z, 5) * 5 - y, z - flm(y + z, 5)) ck.verify(y - fld(y - z, 5) * 5, flm(y - z, 5) + z) @@ -471,14 +471,14 @@ def test_floordiv_index_simplify(): ck.verify(fld(x * 4, 2), x * 2) ck.verify(fld(x * 8 + 7, 16), fld(x, 2)) ck.verify(fld(x * 8 + 39, 16), fld(x, 2) + 2) - ck.verify(fld(x * 8 - 1, 16), fld(x * 8 + -1, 16)) + ck.verify(fld(x * 8 - 1, 16), fld(x * 8 + 15, 16) + -1) ck.verify(fld(x * 8 - 9, 16), fld(x, 2) + -1) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 7), override=True) ck.verify(fld(x * 360 + y, 16), x * 22) ck.verify(fld(x * 360 + y, 25), x * 14) - ck.verify(fld(x * 360 - 8, 25), fld(x * 360 + -8, 25)) + ck.verify(fld(x * 360 - 8, 25), fld(x * 360 + 17, 25) + -1) ck.verify(fld(x * 4 + y, 2), x * 2 + fld(y, 2)) ck.verify(fld(tvm.te.min(x * 6, y), 2), tvm.te.min(x * 3, fld(y, 2))) @@ -488,6 +488,9 @@ def test_floordiv_index_simplify(): ck.verify(fld(tvm.te.min(y, x * 6), 2), tvm.te.min(fld(y, 2), x * 3)) ck.verify(fld(tvm.te.max(y, x * 6), 2), tvm.te.max(fld(y, 2), x * 3)) + # removal of negative offsets + ck.verify(fld(x - 17, 5), fld(x + 3, 5) + -4) + # 3-operands ck.verify(fld(x * 6 + y + z, 2), x * 3 + fld(y + z, 2)) ck.verify(fld(x * 6 - y + (y + 3), 2), x * 3 + 1) @@ -574,6 +577,10 @@ def test_floormod_index_simplify(): ck.verify(flm(x + (-10), 2), flm(x, 2)) ck.verify(flm(x + y * (-10), 2), flm(x, 2)) + # removal of negative offsets + ck.verify(flm(x - 17, 5), flm(x + 3, 
5)) + ck.verify(flm(x + 17, 5), flm(x + 2, 5)) + ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 31), override=True) ck.verify(flm(x * 32 + y, 64), flm(x, 2) * 32 + y) ck.verify(flm(x * 32 - y, 64), flm(x * 32 - y, 64)) diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 76cad250e053..7b71f4d4ab17 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -28,6 +28,7 @@ import tvm.testing from tvm import relay, te from tvm.topi.math import cast +from tvm.script import tir as T dtype = tvm.testing.parameter("float32", "int32", "float16", "int8") @@ -558,5 +559,46 @@ def do_compute(ins, outs): tvm.build(s, [Out], target) +def test_negative_operand_divmod(target, dev): + """Test handling of negative offsets to floormod/floordiv + + Even though the SPIR-V spec states that OpSRem and OpSMod can give + the signed modulo, the Vulkan spec states that any use of negative + operands is undefined behavior. This test starts with negative + operands to floordiv, validating that they are simplified into the + corresponding positive operands, such that the final TIR can be + expressed using only positive operands. 
+ + SPIR-V: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSRem + Vulkan: https://registry.khronos.org/vulkan/specs/1.3/html/chap37.html#spirvenv-op-prec + """ + + N = 32 + offset = 16 + divisor = 5 + + @T.prim_func + def func(A: T.Buffer[(N, 2), "int32"]): + for i in T.serial(N): + with T.block("A"): + v_i = T.axis.spatial(N, i) + A[v_i, 0] = T.floordiv(v_i - offset, divisor) + A[v_i, 1] = T.floormod(v_i - offset, divisor) + + if "gpu" in tvm.target.Target(target).keys: + sch = tvm.tir.Schedule(func) + sch.bind(sch.get_loops("A")[0], "threadIdx.x") + func = sch.mod["main"] + + built = tvm.build(func, target=target) + + a_dev = tvm.nd.empty([N, 2], "int32", dev) + built(a_dev) + a = a_dev.numpy() + + np.testing.assert_array_equal(a[:, 0], (np.arange(N) - offset) // divisor) + np.testing.assert_array_equal(a[:, 1], (np.arange(N) - offset) % divisor) + + if __name__ == "__main__": tvm.testing.main() From e59358948ec6da26a4170fe0ac52620fe997009b Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Fri, 6 Jan 2023 13:17:34 -0600 Subject: [PATCH 2/2] Additional changes to resolve breakage, may need either here or in #13708 --- src/arith/rewrite_simplify.cc | 117 +++++++++++++----- .../unittest/test_arith_rewrite_simplify.py | 20 +-- 2 files changed, 99 insertions(+), 38 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 9d4d7c0e868d..482e50be7bf4 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -253,12 +253,19 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE((c1 - y) + c2, (c1 + c2) - y); TVM_TRY_REWRITE((y - c1) + c2, y + (c2 - c1)); + TVM_TRY_REWRITE(c1 + (x + c2), x + (c1 + c2)); + TVM_TRY_REWRITE((x + c1) + c2, x + (c1 + c2)); + TVM_TRY_REWRITE(c1 + (x - c2), x + (c1 - c2)); + TVM_TRY_REWRITE((x - c1) + c2, x + (c2 - c1)); + TVM_TRY_REWRITE(c1 + (c2 - x), (c1 + c2) - x); + TVM_TRY_REWRITE((c1 - x) + c2, (c1 + c2) - x); + // mul 
co-efficient folding TVM_TRY_REWRITE(x + x, x * 2); TVM_TRY_REWRITE(x * y + x, x * (y + 1)); - TVM_TRY_REWRITE(y * x + x, x * (y + 1)); - TVM_TRY_REWRITE(x + y * x, x * (1 + y)); - TVM_TRY_REWRITE(x + x * y, x * (1 + y)); + TVM_TRY_REWRITE(y * x + x, (y + 1) * x); + TVM_TRY_REWRITE(x + y * x, (y + 1) * x); + TVM_TRY_REWRITE(x + x * y, x * (y + 1)); TVM_TRY_REWRITE(x * y + x * z, x * (y + z)); TVM_TRY_REWRITE(y * x + x * z, x * (y + z)); TVM_TRY_REWRITE(x * y + z * x, x * (y + z)); @@ -276,6 +283,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2), c2.Eval()->value > 0); + auto one = PConst(make_const(op->dtype, 1)); + TVM_TRY_REWRITE(floormod(x, 2) + floormod(x + 1, 2), one); + TVM_TRY_REWRITE(floormod(x + 1, 2) + floormod(x, 2), one); + // canonicalization rule // will try rewrite again after canonicalization. TVM_TRY_RECURSIVE_REWRITE(c1 + x, x + c1); @@ -374,6 +385,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE(x - min(x, y), max(0, x - y)); TVM_TRY_REWRITE(y - min(x, y), max(y - x, 0)); + // constant cancelation + TVM_TRY_REWRITE(c1 - (x + c2), (c1 - c2) - x); + TVM_TRY_REWRITE(c1 - (x - c2), (c1 + c2) - x); + TVM_TRY_REWRITE(c1 - (c2 - x), x + (c1 - c2)); + TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2)); + TVM_TRY_REWRITE((x - c1) - c2, x - (c1 + c2)); + TVM_TRY_REWRITE((c1 - x) - c2, (c1 - c2) - x); + TVM_TRY_RECURSIVE_REWRITE((c1 - x) - (c2 - y), (y - x) + (c1 - c2)); + // mul co-efficient folding TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x)); TVM_TRY_REWRITE(x * y - x, x * (y - 1)); @@ -385,10 +405,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE(x * y - z * x, x * (y - z)); TVM_TRY_REWRITE(y * x - z * x, x * (y - z)); - // constant cancelation - TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2)); - TVM_TRY_REWRITE((c1 - x) - (c2 - y), (y - x) + (c1 - c2)); - // cancelization 
rule involving 4 operands TVM_TRY_REWRITE((x + y) - (x + z), y - z); TVM_TRY_REWRITE((x + y) - (z + x), y - z); @@ -552,7 +568,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { if (IsIndexType(op->dtype)) { // constant simplification rule - TVM_TRY_REWRITE((x + c1) * c2, x * c2 + c1 * c2); + TVM_TRY_REWRITE_IF((x + c1) * c2, x * c2 + c1 * c2, c1.Eval()->value != 1); TVM_TRY_REWRITE((x * c1) * c2, x * (c1 * c2)); TVM_TRY_REWRITE(min(x, y) * max(x, y), x * y); TVM_TRY_REWRITE(max(x, y) * min(x, y), x * y); @@ -886,7 +902,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 0 : floordiv(yval, c2val); auto bound = analyzer_->const_int_bound(residue); if (bound.defined() && bound->max_value == bound->min_value) { - return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value)); + PrimExpr out = x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value)); + return RecursiveRewrite(out); } // try simplify divisor @@ -898,7 +915,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { // ==> x' + d + (b * c1 + e) // c2 // ==> x' + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1 // ==> x // (c2 // c1) + (y // c2) - return floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div; + PrimExpr out = floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div; + return RecursiveRewrite(out); } } @@ -919,12 +937,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x + floormod(c1, c2), c2) + floordiv(c1, c2), - c2.Eval()->value > 0 && - (c1.Eval()->value < 0 /* || c1.Eval()->value >= c2.Eval()->value*/)); - TVM_TRY_REWRITE_IF(floordiv(x - c1, c2), - floordiv(x + floormod(-1 * c1, c2), c2) + floordiv(-1 * c1, 
c2), - c2.Eval()->value > 0); + TVM_TRY_RECURSIVE_REWRITE_IF( + floordiv(x + c1, c2), floordiv(x + floormod(c1, c2), c2) + floordiv(c1, c2), + c2.Eval()->value > 0 && (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value))); + TVM_TRY_RECURSIVE_REWRITE_IF(floordiv(x - c1, c2), + floordiv(x + floormod(-1 * c1, c2), c2) + floordiv(-1 * c1, c2), + c2.Eval()->value > 0); // Rules involving 3-operands. TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), @@ -974,6 +992,28 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c1), c1), floordiv(x, c1), c1.Eval()->value != 0); + + if (floordiv(x * c1, c2).Match(ret)) { + auto c1val = c1.Eval()->value; + auto c2val = c2.Eval()->value; + IntImm gcd(x.Eval()->dtype, ZeroAwareGCD(c1val, c2val)); + + if (gcd->value > 1) { + // floormod(i*gcd, j*gcd) == floormod(i,j) * gcd + return floordiv(x.Eval() * floordiv(c1val, gcd), floordiv(c2val, gcd)); + } + } + + if (floordiv(x, c1).Match(ret)) { + ModularSet mod = analyzer_->modular_set(x.Eval()); + int64_t c1val = c1.Eval()->value; + IntImm gcd(x.Eval()->dtype, ZeroAwareGCD(c1val, mod->coeff)); + if (gcd->value > 1) { + // floordiv(i*gcd, j*gcd) == floordiv(i,j) + PrimExpr out = floordiv(floordiv(x.Eval(), gcd), floordiv(c1.Eval(), gcd)); + return RecursiveRewrite(out); + } + } } return ret; } @@ -1025,27 +1065,38 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (IsIndexType(op->dtype)) { // Be-aware of the division rules: we use floordiv/floormod here - TVM_TRY_REWRITE_IF(floormod(x * c1, c2), floormod(x * floormod(c1, c2), c2), - c2.Eval()->value != 0); + TVM_TRY_REWRITE_IF( + floormod(x * c1, c2), floormod(x * floormod(c1, c2), c2), + c2.Eval()->value != 0 && c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value)); TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 
+ y, c1.Eval()->value > 0 && c2.Eval()->value > 0 && c2.Eval()->value % c1.Eval()->value == 0 && CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); - TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), - c2.Eval()->value > 0); + TVM_TRY_RECURSIVE_REWRITE_IF( + floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), + c2.Eval()->value > 0 && (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value))); - TVM_TRY_REWRITE_IF( + TVM_TRY_RECURSIVE_REWRITE_IF( floormod(x + c1, c2), floormod(x + floormod(c1, c2), c2), - c2.Eval()->value > 0 && (c1.Eval()->value < 0 || c1.Eval()->value >= c2.Eval()->value)); - TVM_TRY_REWRITE_IF(floormod(x - c1, c2), floormod(x + floormod(-1 * c1, c2), c2), - c2.Eval()->value > 0); + c2.Eval()->value > 0 && (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value))); + TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x - c1, c2), floormod(x + floormod(-1 * c1, c2), c2), + c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2), - c2.Eval()->value > 0); + TVM_TRY_RECURSIVE_REWRITE_IF( + floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2), + c2.Eval()->value > 0 && (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value))); + + TVM_TRY_RECURSIVE_REWRITE_IF( + floormod(x * c1, x * c2), x * floormod(c1, c2), + c2.Eval()->value != 0 && + (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value))); - TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); + TVM_TRY_RECURSIVE_REWRITE_IF(floormod(floormod(x, c1) + y, c2), floormod(x + y, c2), + floormod(c1.Eval()->value, c2.Eval()->value) == 0); + TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + floormod(y, c1), c2), floormod(x + y, c2), + floormod(c1.Eval()->value, c2.Eval()->value) == 0); TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x)); TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y)); @@ -1057,6 +1108,13 @@ PrimExpr 
RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (mod->coeff % c1val == 0 && c1val > 0) { return floormod(mod->base, c1).Eval(); } + if (c1val > 0 && mod->coeff > 1) { + IntImm gcd(x.Eval()->dtype, ZeroAwareGCD(c1val, mod->coeff)); + if (gcd->value > 1) { + PrimExpr out = floormod(floordiv(x.Eval(), gcd), floordiv(c1.Eval(), gcd)) * gcd; + return RecursiveRewrite(out); + } + } } } return ret; @@ -1428,7 +1486,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1; + PVar c1, c2, c3; PVar lanes; // vector rule @@ -1450,6 +1508,9 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { TVM_TRY_REWRITE(c1 - x == 0, x == c1); TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1); TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0); + + TVM_TRY_RECURSIVE_REWRITE(floormod(x + c1, c2) == c3, + c3 == floormod(c3, c2) && floormod(x, c2) == floormod(c3 - c1, c2)); } return std::move(ret); } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index a8e7268097e0..d6ab1126e86d 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -82,7 +82,7 @@ def test_vector_simplify(): ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5)) ck.verify( fld(tvm.tir.Ramp(flm(x * 4, 256), 1, 4), tvm.tir.Broadcast(8, 4)), - tvm.tir.Broadcast(fld(flm(x * 4, 256), 8), 4), + tvm.tir.Broadcast(fld(flm(x, 64), 2), 4), ) ck.verify( fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), @@ -136,10 +136,10 @@ def test_vector_simplify(): flm(tvm.tir.Ramp(3, 1, 4), tvm.tir.Broadcast(4, 4)), ) ck.verify( - flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 4, 64), 1, 4) + flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x, 16) * 4, 1, 4) ) ck.verify( - 
flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 8, 64), 2, 4) + flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x, 8) * 8, 2, 4) ) ck.verify( flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), @@ -255,7 +255,7 @@ def test_add_index_simplify(): ck.verify(x * y + 10 * x, x * (y + 10)) ck.verify((2 * z) + tvm.te.min(x, y - (2 * z)), tvm.te.min(x + (z * 2), y)) - ck.verify(y * x + x, x * (y + 1)) + ck.verify(y * x + x, (y + 1) * x) ck.verify(x * y + x, x * (y + 1)) ck.verify((x + 10) + 13, x + 23) ck.verify((x + 10) + (13 + z), x + z + 23) @@ -281,7 +281,7 @@ def test_add_index_simplify(): flm = tvm.te.floormod ck.verify(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10)) ck.verify(fld(x, 8) * 8 + flm(x, 8), x) - ck.verify(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 7, 2)) + ck.verify(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 1, 2) + 3) def test_sub_index_simplify(): @@ -377,8 +377,8 @@ def test_sub_index_simplify(): ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(-1000, 1000), override=True) ck.verify(x - fld(x, 3) * 3, flm(x, 3)) - ck.verify(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3)) - ck.verify(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1) + ck.verify(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 2, 3) + 1) + ck.verify(fld(x + 5, 3) - fld(x + 2, 3), 1) ck.verify(fld(y, 3) * 3 - y, 0 - flm(y, 3)) ck.verify(y - fld(y - 6, 5) * 5, flm(y + 4, 5) + 6) @@ -471,14 +471,14 @@ def test_floordiv_index_simplify(): ck.verify(fld(x * 4, 2), x * 2) ck.verify(fld(x * 8 + 7, 16), fld(x, 2)) ck.verify(fld(x * 8 + 39, 16), fld(x, 2) + 2) - ck.verify(fld(x * 8 - 1, 16), fld(x * 8 + 15, 16) + -1) + ck.verify(fld(x * 8 - 1, 16), fld(x + 1, 2) + -1) ck.verify(fld(x * 8 - 9, 16), fld(x, 2) + -1) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 7), override=True) 
ck.verify(fld(x * 360 + y, 16), x * 22) ck.verify(fld(x * 360 + y, 25), x * 14) - ck.verify(fld(x * 360 - 8, 25), fld(x * 360 + 17, 25) + -1) + ck.verify(fld(x * 360 - 8, 25), fld(x * 72 + 3, 5) + -1) ck.verify(fld(x * 4 + y, 2), x * 2 + fld(y, 2)) ck.verify(fld(tvm.te.min(x * 6, y), 2), tvm.te.min(x * 3, fld(y, 2))) @@ -565,7 +565,7 @@ def test_floormod_index_simplify(): x, y, nx, ny, z = te.var("x"), te.var("y"), te.var("nx"), te.var("ny"), te.var("z") ck.verify(flm(x * 10, 2), 0) - ck.verify(flm(x * 9600, 6400), flm(x * 3200, 6400)) + ck.verify(flm(x * 9600, 6400), flm(x, 2) * 3200) ck.verify(flm(x * 10 + y, 2), flm(y, 2)) ck.verify(flm(x * 360 + y, 16), flm(x * 8 + y, 16)) ck.verify(flm(x + 10, 2), flm(x, 2))