From d767e28884c23adc8679edfdc954d7b9a33c997b Mon Sep 17 00:00:00 2001 From: yuanfz98 Date: Wed, 9 Mar 2022 10:35:17 +0100 Subject: [PATCH 1/5] commit --- .../transforms/common_subexpr_elim_tools.cc | 6 +- .../test_tir_transform_common_subexpr_elim.py | 88 +++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index 218667c331a5..690dc31075c3 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -25,6 +25,7 @@ #include "common_subexpr_elim_tools.h" +#include #include // For the class Pass and the class PassContext #include #include // For the ExprDeepEqual analysis @@ -727,7 +728,10 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) { // For now, we just check the syntactic equality, but that could later become a semantic test, // for instance identifying computations modulo commutativity (like x+y and y+x), or modulo // associativity (like (x+y)+z and x+(y+z)), etc. - return EqualTerms(a, b); + arith::Analyzer analyser; + PrimExpr a_simplified = analyser.Simplify(a); + PrimExpr b_simplified = analyser.Simplify(b); + return EqualTerms(a_simplified, b_simplified); } /*! diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index 17c0cbdd99c6..848fd2ce016c 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -16,6 +16,55 @@ # under the License. import tvm from tvm import te +from tvm.script import tir as T + +@T.prim_func +def func_distributivity( + i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 +) -> None: + B = T.buffer_decl((50,), "int32") + B[i1] = x * (y + z) + B[i2] = x * y + x * z + + +@T.prim_func +def func_distributivity_expected( + i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 +) -> None: + B = T.buffer_decl((50,), "int32") + cse_var_1 = T.var("int32") + with T.let(cse_var_1, x * (y + z)): + B[i1] = cse_var_1 + B[i2] = cse_var_1 + +@T.prim_func +def func_associativity( + i1: T.int32, i2: T.int32, i3: T.int32, x: T.int32, y: T.int32, z: T.int32 +) -> None: + B = T.buffer_decl((50,), "int32") + B[i1] = (x+y) + z + B[i2] = (y + z) + x + B[i3] = (x+z)+y + +@T.prim_func +def func_associativity_expected( + i1: T.int32, i2: T.int32, i3: T.int32, x: T.int32, y: T.int32, z: T.int32 +) -> None: + B = T.buffer_decl((50,), "int32") + cse_var_1 = T.var("int32") + with T.let(cse_var_1, x+(y+z)): + B[i1] = cse_var_1 + B[i2] = cse_var_1 + B[i3] = cse_var_1 + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + print(mod) + mod = tvm.tir.transform.CommonSubexprElimTIR()(mod) + print(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed) + # A test program which gives the opportunity for the CSE pass to introduce two new variables, at two different levels def test_cse(): @@ -70,6 +119,8 @@ def test_cse(): # We will check all of that underneath and more, making also sure that nothing else has been changed mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, z3], body)) + print(mod) + body = tvm.tir.transform.CommonSubexprElimTIR()(mod) tvm.transform.PrintIR()(body) @@ -167,6 +218,7 @@ def test_cse_ifNode_1(): ) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, i3, y, z], body)) + print(mod) body = tvm.tir.transform.CommonSubexprElimTIR()(mod) tvm.transform.PrintIR()(body) @@ -226,6 +278,7 @@ def test_cse_ifNode_2(): ) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, i3, y, z], body)) + print(mod) body = tvm.tir.transform.CommonSubexprElimTIR()(mod) tvm.transform.PrintIR()(body) @@ -243,6 +296,28 @@ def test_cse_ifNode_2(): # Test commoning in cascade : after having introduced a big exp ((x+y)+z) into a new variable, # it will become possible to do another commoning for (x+y) which appears both in the new variable # and in the rest of the program. + + +@T.prim_func +def func_cse_cascade( + i1: T.int32, i2: T.int32, i3: T.int32, x: T.int32, y: T.int32, z: T.int32 +) -> None: + B = T.buffer_decl((50,), "int32") + B[i1] = (x + y) + z + B[i2] = (x + y) + z + B[i3] = x + y + + +@T.prim_func +def func_cse_cascade_expected( + i1: T.int32, i2: T.int32, i3: T.int32, x: T.int32, y: T.int32, z: T.int32 +) -> None: + B = T.buffer_decl((50,), "int32") + B[i1] = (x + y) + z + B[i2] = (x + y) + z + B[i3] = x + y + + def test_cse_cascade(): i1 = te.var("i1") i2 = te.var("i2") @@ -265,6 +340,7 @@ def test_cse_cascade(): ) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, i3, x, y, z], body)) + print(mod) body = tvm.tir.transform.CommonSubexprElimTIR()(mod) tvm.transform.PrintIR()(body) @@ -305,8 +381,20 @@ def test_cse_cascade(): assert tvm.ir.structural_equal(store3.value, cse_var_2) + + + +def test_semantic_equiv_distributivity(): + _check(func_distributivity, func_distributivity_expected) + + +def test_semantic_equiv_associativity(): + _check(func_associativity, func_associativity_expected) + + if __name__ == "__main__": test_cse() test_cse_ifNode_1() test_cse_ifNode_2() test_cse_cascade() + test_semantic_equiv() From cd13186fa3a9f42967763c1ec7825779e644d308 Mon Sep 17 00:00:00 2001 From: yuanfz98 Date: Wed, 9 Mar 2022 11:06:23 +0100 Subject: [PATCH 2/5] black format --- .../test_tir_transform_common_subexpr_elim.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index 848fd2ce016c..b9b3e0e7eedf 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -18,10 +18,9 @@ from tvm import te from tvm.script import tir as T + @T.prim_func -def func_distributivity( - i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 -) -> None: +def func_distributivity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None: B = T.buffer_decl((50,), "int32") B[i1] = x * (y + z) B[i2] = x * y + x * z @@ -37,25 +36,24 @@ def func_distributivity_expected( B[i1] = cse_var_1 B[i2] = cse_var_1 + @T.prim_func -def func_associativity( - i1: T.int32, i2: T.int32, i3: T.int32, x: T.int32, y: T.int32, z: T.int32 -) -> None: +def func_associativity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None: B = T.buffer_decl((50,), "int32") - B[i1] = (x+y) + z - B[i2] = (y + z) + x - B[i3] = (x+z)+y + B[i1] = (x + y) + z + B[i2] = x + (y + z) + @T.prim_func def func_associativity_expected( - i1: T.int32, i2: T.int32, i3: T.int32, x: T.int32, y: T.int32, z: T.int32 + i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: B = T.buffer_decl((50,), "int32") cse_var_1 = T.var("int32") - with T.let(cse_var_1, x+(y+z)): + with T.let(cse_var_1, (x + y) + z): B[i1] = cse_var_1 B[i2] = cse_var_1 - B[i3] = cse_var_1 + def _check(original, transformed): func = original @@ -381,9 +379,6 @@ def test_cse_cascade(): assert tvm.ir.structural_equal(store3.value, cse_var_2) - - - def test_semantic_equiv_distributivity(): _check(func_distributivity, func_distributivity_expected) @@ -397,4 +392,5 @@ def test_semantic_equiv_associativity(): test_cse_ifNode_1() test_cse_ifNode_2() test_cse_cascade() - test_semantic_equiv() + test_semantic_equiv_distributivity() + test_semantic_equiv_associativity() From 6a794b644757f35e3ba4b5b217e58790fb17a275 Mon Sep 17 00:00:00 2001 From: yuanfz98 Date: Wed, 9 Mar 2022 11:10:43 +0100 Subject: [PATCH 3/5] rm prints --- .../test_tir_transform_common_subexpr_elim.py | 24 +------------------ 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index b9b3e0e7eedf..147f52a0c5d7 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -58,9 +58,7 @@ def func_associativity_expected( def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func) - print(mod) mod = tvm.tir.transform.CommonSubexprElimTIR()(mod) - print(mod) tvm.ir.assert_structural_equal(mod["main"], transformed) @@ -117,7 +115,6 @@ def test_cse(): # We will check all of that underneath and more, making also sure that nothing else has been changed mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, z3], body)) - print(mod) body = tvm.tir.transform.CommonSubexprElimTIR()(mod) @@ -216,7 +213,6 @@ def test_cse_ifNode_1(): ) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, i3, y, z], body)) - print(mod) body = tvm.tir.transform.CommonSubexprElimTIR()(mod) tvm.transform.PrintIR()(body) @@ -276,7 +272,6 @@ def test_cse_ifNode_2(): ) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, i3, y, z], body)) - print(mod) body = tvm.tir.transform.CommonSubexprElimTIR()(mod) tvm.transform.PrintIR()(body) @@ -296,24 +291,7 @@ def test_cse_ifNode_2(): # and in the rest of the program. -@T.prim_func -def func_cse_cascade( - i1: T.int32, i2: T.int32, i3: T.int32, x: T.int32, y: T.int32, z: T.int32 -) -> None: - B = T.buffer_decl((50,), "int32") - B[i1] = (x + y) + z - B[i2] = (x + y) + z - B[i3] = x + y - -@T.prim_func -def func_cse_cascade_expected( - i1: T.int32, i2: T.int32, i3: T.int32, x: T.int32, y: T.int32, z: T.int32 -) -> None: - B = T.buffer_decl((50,), "int32") - B[i1] = (x + y) + z - B[i2] = (x + y) + z - B[i3] = x + y def test_cse_cascade(): @@ -338,7 +316,7 @@ def test_cse_cascade(): ) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, i3, x, y, z], body)) - print(mod) + body = tvm.tir.transform.CommonSubexprElimTIR()(mod) tvm.transform.PrintIR()(body) From e14af519706a4fe6cb898c8bb42adddd4faee3e1 Mon Sep 17 00:00:00 2001 From: yuanfz98 Date: Wed, 9 Mar 2022 11:11:17 +0100 Subject: [PATCH 4/5] format --- .../unittest/test_tir_transform_common_subexpr_elim.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index 147f52a0c5d7..6ca760cdae0c 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -291,9 +291,6 @@ def test_cse_ifNode_2(): # and in the rest of the program. - - - def test_cse_cascade(): i1 = te.var("i1") i2 = te.var("i2") @@ -316,7 +313,7 @@ def test_cse_cascade(): ) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, i3, x, y, z], body)) - + body = tvm.tir.transform.CommonSubexprElimTIR()(mod) tvm.transform.PrintIR()(body) From 9a75aca75cbaebe958f3fe7b1e972f202a3c746b Mon Sep 17 00:00:00 2001 From: yuanfz98 Date: Wed, 9 Mar 2022 11:12:57 +0100 Subject: [PATCH 5/5] format --- .../python/unittest/test_tir_transform_common_subexpr_elim.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index 6ca760cdae0c..01c231d9629c 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -115,7 +115,6 @@ def test_cse(): # We will check all of that underneath and more, making also sure that nothing else has been changed mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, z3], body)) - body = tvm.tir.transform.CommonSubexprElimTIR()(mod) tvm.transform.PrintIR()(body) @@ -289,8 +288,6 @@ def test_cse_ifNode_2(): # Test commoning in cascade : after having introduced a big exp ((x+y)+z) into a new variable, # it will become possible to do another commoning for (x+y) which appears both in the new variable # and in the rest of the program. - - def test_cse_cascade(): i1 = te.var("i1") i2 = te.var("i2") @@ -313,7 +310,6 @@ def test_cse_cascade(): ) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, i3, x, y, z], body)) - body = tvm.tir.transform.CommonSubexprElimTIR()(mod) tvm.transform.PrintIR()(body)