From fca1570a70063380c02225b652d28fdeac322c5c Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 3 Aug 2023 15:05:28 -0400 Subject: [PATCH 1/6] Implement fixpoint simplification pass --- include/tvm/relax/transform.h | 16 ++++ python/tvm/relax/transform/transform.py | 31 +++++++ src/relax/transform/simplify_fixpoint.cc | 78 +++++++++++++++++ .../test_transform_fixpoint_simplification.py | 86 +++++++++++++++++++ 4 files changed, 211 insertions(+) create mode 100644 src/relax/transform/simplify_fixpoint.cc create mode 100644 tests/python/relax/test_transform_fixpoint_simplification.py diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 8d01262aab5c..1b98b5ea6747 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -163,6 +163,22 @@ TVM_DLL Pass Normalize(); */ TVM_DLL Pass CanonicalizeBindings(); +/*! + * \brief Applies multiple simplification passes one after another until reaching + * fixpoint (i.e., until no function in the module changes anymore). + * + * Includes the following passes: + * * DeadCodeElimination + * * CanonicalizeBindings + * * EliminateCommonSubexpressions + * * FoldDataflowBlockOutput + * + * \param entry_functions Entry points to the module, for dead code elimination + * \param call_only Whether to apply common subexpression elimination only to calls + * \return The Pass. + */ +TVM_DLL Pass FixpointSimplification(Array entry_functions, bool call_only); + /*! * Eliminate common subexpressions within functions. * \return The pass that eliminates common subexpressions. diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index f512e42bf69b..6fa3592d6154 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -312,6 +312,37 @@ def EliminateCommonSubexpr(call_only=False) -> FunctionPass: return _ffi_api.EliminateCommonSubexpr(call_only) # type: ignore +def FixpointSimplification( + entry_functions: Optional[List[str]] = None, call_only: bool = False +) -> tvm.ir.transform.Pass: + """ + Applies multiple simplification passes one after another until reaching + fixpoint (i.e., until no function in the module changes anymore). + + Includes the following passes: + * DeadCodeElimination + * CanonicalizeBindings + * EliminateCommonSubexpressions + * FoldDataflowBlockOutput + + Parameters + ---------- + entry_functions: List[str] + Entry points to the module, for dead code elimination + + call_only: bool + Whether to apply common subexpression elimination only to calls + + Returns + ------- + ret: Pass + The pass + """ + if entry_functions is None: + entry_functions = ["main"] + return _ffi_api.FixpointSimplification(entry_functions, call_only) + + def RewriteDataflowReshape() -> tvm.ir.transform.Pass: """Convert all reshape-like call_tir to VM reshape operator call. The VM reshape operator calls will be further lowered to a CreateView diff --git a/src/relax/transform/simplify_fixpoint.cc b/src/relax/transform/simplify_fixpoint.cc new file mode 100644 index 000000000000..006d14756cc9 --- /dev/null +++ b/src/relax/transform/simplify_fixpoint.cc @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/simplify_fixpoint.cc + * \brief Pass that applies other simplification passes until fixpoint. + * Presently, this subsumes the following passes: + * * FoldDataflowBlockOutput + * * CanonicalizeBIndings + * * EliminateCommonSubexpr + * * DeadCodeElimination + */ +#include +#include +#include +#include + +#include "utils.h" + +namespace tvm { +namespace relax { + +uint64_t Hash(const IRModule& mod) { return SHashHandlerDefault().Hash(mod, true); } + +IRModule FixpointSimplification(const IRModule& mod, Array entry_funcs, + bool call_only) { + // apply passes until it stops changing + IRModule current_mod = mod; + transform::Pass cse = transform::EliminateCommonSubexpr(call_only); + transform::Pass canonicalize_bindings = transform::CanonicalizeBindings(); + transform::Pass dce = transform::DeadCodeElimination(entry_funcs); + transform::Pass fold_df_output = transform::FoldDataflowBlockOutput(); + + while (true) { + uint64_t last_hash = Hash(current_mod); + current_mod = std::move(fold_df_output(cse(canonicalize_bindings(dce(current_mod))))); + uint64_t current_hash = Hash(current_mod); + if (current_hash == last_hash) { + break; + } + } + + return current_mod; +} + +namespace transform { + +Pass FixpointSimplification(Array entry_functions, bool call_only) { + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext pc) { + return relax::FixpointSimplification(m, entry_functions, call_only); + }; + return CreateModulePass(pass_func, 1, "FixpointSimplification", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FixpointSimplification") + .set_body_typed(FixpointSimplification); + +} // namespace transform +} // namespace relax +} // namespace tvm \ No newline at end of file diff --git a/tests/python/relax/test_transform_fixpoint_simplification.py b/tests/python/relax/test_transform_fixpoint_simplification.py new file mode 100644 index 000000000000..dbdaabcdeac3 --- /dev/null +++ b/tests/python/relax/test_transform_fixpoint_simplification.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test fixpoint simplification pass""" +import tvm +import tvm.testing +from tvm.relax.transform import CanonicalizeBindings, DeadCodeElimination, FixpointSimplification +from tvm.script.parser import ir as I, relax as R, tir as T + + +def verify(input: tvm.IRModule, expected: tvm.IRModule) -> None: + actual = FixpointSimplification()(input) + tvm.ir.assert_structural_equal(actual, expected, map_free_vars=True) + + +def test_chain_assignment(): + # test case from binding canonicalization, except it will simplify all the way + @I.ir_module + class TestChainAssignments: + @R.function + def main(x: R.Tensor): + # need the dataflow block for DCE to work + with R.dataflow(): + y = x + z = y + q = z + p = q + o = p + R.output(o) + return o + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + return x + + verify(TestChainAssignments, Expected) + + +def test_eliminate_trivial_check(): + # another case from canonicalize bindings that can be further simplified + @I.ir_module + class TestSameShape: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + # need the dataflow block for DCE to work + with R.dataflow(): + m, n = T.int64(), T.int64() + y = x + # trivial check, eliminated by canonicalize bindings + z = R.match_cast(x, R.Tensor((m, n), "float32")) + w = z + q = R.add(w, y) + r = R.add(q, w) + R.output(r) + return r + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + with R.dataflow(): + q = R.add(x, x) + r = R.add(q, x) + R.output(r) + return r + + verify(TestSameShape, Expected) + + +if __name__ == "__main__": + tvm.testing.main() From 1d7381839d3de875fc23ebfdde8b69b454a74ece Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 3 Aug 2023 16:46:24 -0400 Subject: [PATCH 2/6] Do not repeatedly compute the same hash --- src/relax/transform/simplify_fixpoint.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relax/transform/simplify_fixpoint.cc b/src/relax/transform/simplify_fixpoint.cc index 006d14756cc9..ad64cbee8e48 100644 --- a/src/relax/transform/simplify_fixpoint.cc +++ b/src/relax/transform/simplify_fixpoint.cc @@ -48,13 +48,14 @@ IRModule FixpointSimplification(const IRModule& mod, Array entr transform::Pass dce = transform::DeadCodeElimination(entry_funcs); transform::Pass fold_df_output = transform::FoldDataflowBlockOutput(); + uint64_t last_hash = Hash(current_mod); while (true) { - uint64_t last_hash = Hash(current_mod); current_mod = std::move(fold_df_output(cse(canonicalize_bindings(dce(current_mod))))); uint64_t current_hash = Hash(current_mod); if (current_hash == last_hash) { break; } + last_hash = current_hash; } return current_mod; From 267c83f6ec38082feadcb33c9e90c928bf6dbc0f Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 3 Aug 2023 19:06:40 -0400 Subject: [PATCH 3/6] Fix test function name --- tests/python/relax/test_transform_canonicalize_bindings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index 5e1d1b881e2c..6197c819ed00 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -89,7 +89,7 @@ def main(x: R.Tensor): assert_structural_equal(new_mod, Expected) -def test_assign_to_output_indataflow_block(): +def test_assign_to_output_in_dataflow_block(): @tvm.script.ir_module class TestDataflowAssignments: @R.function From c7e87bea429a9fa50b9df6dd49db4f3beab6add7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 4 Aug 2023 13:16:43 -0400 Subject: [PATCH 4/6] Even more whitespace --- include/tvm/relax/transform.h | 2 +- src/relax/transform/simplify_fixpoint.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 1b98b5ea6747..73fbb2c4beb9 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -166,7 +166,7 @@ TVM_DLL Pass CanonicalizeBindings(); /*! * \brief Applies multiple simplification passes one after another until reaching * fixpoint (i.e., until no function in the module changes anymore). - * + * * Includes the following passes: * * DeadCodeElimination * * CanonicalizeBindings diff --git a/src/relax/transform/simplify_fixpoint.cc b/src/relax/transform/simplify_fixpoint.cc index ad64cbee8e48..ab8675a8b843 100644 --- a/src/relax/transform/simplify_fixpoint.cc +++ b/src/relax/transform/simplify_fixpoint.cc @@ -76,4 +76,4 @@ TVM_REGISTER_GLOBAL("relax.transform.FixpointSimplification") } // namespace transform } // namespace relax -} // namespace tvm \ No newline at end of file +} // namespace tvm From 56e0938cb7bc7a1f3205c1fd938572ed6b477656 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 7 Aug 2023 15:29:37 -0400 Subject: [PATCH 5/6] Include limit on the number of iterations for safety --- include/tvm/relax/transform.h | 4 +++- python/tvm/relax/transform/transform.py | 7 +++++-- src/relax/transform/simplify_fixpoint.cc | 17 ++++++++++++----- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 73fbb2c4beb9..069bccd9789c 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -173,11 +173,13 @@ TVM_DLL Pass CanonicalizeBindings(); * * EliminateCommonSubexpressions * * FoldDataflowBlockOutput * + * \param iteration_limit Upper bound on number of iterations in case the loop does not converge * \param entry_functions Entry points to the module, for dead code elimination * \param call_only Whether to apply common subexpression elimination only to calls * \return The Pass. */ -TVM_DLL Pass FixpointSimplification(Array entry_functions, bool call_only); +TVM_DLL Pass FixpointSimplification(int iteration_limit, Array entry_functions, + bool call_only); /*! * Eliminate common subexpressions within functions. diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 6fa3592d6154..c3c45d906436 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -313,7 +313,7 @@ def EliminateCommonSubexpr(call_only=False) -> FunctionPass: def FixpointSimplification( - entry_functions: Optional[List[str]] = None, call_only: bool = False + iteration_limit=10, entry_functions: Optional[List[str]] = None, call_only: bool = False ) -> tvm.ir.transform.Pass: """ Applies multiple simplification passes one after another until reaching @@ -327,6 +327,9 @@ def FixpointSimplification( Parameters ---------- + iteration_limit: int + Upper bound on number of iterations in case the loop does not converge. + entry_functions: List[str] Entry points to the module, for dead code elimination @@ -340,7 +343,7 @@ def FixpointSimplification( """ if entry_functions is None: entry_functions = ["main"] - return _ffi_api.FixpointSimplification(entry_functions, call_only) + return _ffi_api.FixpointSimplification(iteration_limit, entry_functions, call_only) def RewriteDataflowReshape() -> tvm.ir.transform.Pass: diff --git a/src/relax/transform/simplify_fixpoint.cc b/src/relax/transform/simplify_fixpoint.cc index ab8675a8b843..8dc4f8d6de49 100644 --- a/src/relax/transform/simplify_fixpoint.cc +++ b/src/relax/transform/simplify_fixpoint.cc @@ -39,8 +39,8 @@ namespace relax { uint64_t Hash(const IRModule& mod) { return SHashHandlerDefault().Hash(mod, true); } -IRModule FixpointSimplification(const IRModule& mod, Array entry_funcs, - bool call_only) { +IRModule FixpointSimplification(const IRModule& mod, int iteration_limit, + Array entry_funcs, bool call_only) { // apply passes until it stops changing IRModule current_mod = mod; transform::Pass cse = transform::EliminateCommonSubexpr(call_only); @@ -49,13 +49,19 @@ IRModule FixpointSimplification(const IRModule& mod, Array entr transform::Pass fold_df_output = transform::FoldDataflowBlockOutput(); uint64_t last_hash = Hash(current_mod); - while (true) { + int i = 0; + while (i < iteration_limit) { current_mod = std::move(fold_df_output(cse(canonicalize_bindings(dce(current_mod))))); uint64_t current_hash = Hash(current_mod); if (current_hash == last_hash) { break; } last_hash = current_hash; + i++; + } + if (i == iteration_limit) { + LOG(WARNING) << "Iteration limit reached, suggesting FixpointSimplification likely did not " + "actually reach a fixpoint."; } return current_mod; @@ -63,10 +69,11 @@ IRModule FixpointSimplification(const IRModule& mod, Array entr namespace transform { -Pass FixpointSimplification(Array entry_functions, bool call_only) { +Pass FixpointSimplification(int iteration_limit, Array entry_functions, + bool call_only) { runtime::TypedPackedFunc pass_func = [=](IRModule m, PassContext pc) { - return relax::FixpointSimplification(m, entry_functions, call_only); + return relax::FixpointSimplification(m, iteration_limit, entry_functions, call_only); }; return CreateModulePass(pass_func, 1, "FixpointSimplification", {}); } From 2b0b4ca73557087e936ea90ff72c010c9bdcefc2 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 10 Aug 2023 13:53:05 -0400 Subject: [PATCH 6/6] Change doc comments (to retrigger CI) --- include/tvm/relax/transform.h | 8 ++++---- python/tvm/relax/transform/transform.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 069bccd9789c..7ffd0cbe8151 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -168,10 +168,10 @@ TVM_DLL Pass CanonicalizeBindings(); * fixpoint (i.e., until no function in the module changes anymore). * * Includes the following passes: - * * DeadCodeElimination - * * CanonicalizeBindings - * * EliminateCommonSubexpressions - * * FoldDataflowBlockOutput + * - DeadCodeElimination + * - CanonicalizeBindings + * - EliminateCommonSubexpressions + * - FoldDataflowBlockOutput * * \param iteration_limit Upper bound on number of iterations in case the loop does not converge * \param entry_functions Entry points to the module, for dead code elimination diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index c3c45d906436..52975d99a43d 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -320,10 +320,10 @@ def FixpointSimplification( fixpoint (i.e., until no function in the module changes anymore). Includes the following passes: - * DeadCodeElimination - * CanonicalizeBindings - * EliminateCommonSubexpressions - * FoldDataflowBlockOutput + - DeadCodeElimination + - CanonicalizeBindings + - EliminateCommonSubexpressions + - FoldDataflowBlockOutput Parameters ----------