diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 8d01262aab5c..7ffd0cbe8151 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -163,6 +163,24 @@ TVM_DLL Pass Normalize(); */ TVM_DLL Pass CanonicalizeBindings(); +/*! + * \brief Applies multiple simplification passes one after another until reaching + * fixpoint (i.e., until no function in the module changes anymore). + * + * Includes the following passes: + * - DeadCodeElimination + * - CanonicalizeBindings + * - EliminateCommonSubexpressions + * - FoldDataflowBlockOutput + * + * \param iteration_limit Upper bound on number of iterations in case the loop does not converge + * \param entry_functions Entry points to the module, for dead code elimination + * \param call_only Whether to apply common subexpression elimination only to calls + * \return The Pass. + */ +TVM_DLL Pass FixpointSimplification(int iteration_limit, Array entry_functions, + bool call_only); + /*! * Eliminate common subexpressions within functions. * \return The pass that eliminates common subexpressions. diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index f512e42bf69b..52975d99a43d 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -312,6 +312,40 @@ def EliminateCommonSubexpr(call_only=False) -> FunctionPass: return _ffi_api.EliminateCommonSubexpr(call_only) # type: ignore +def FixpointSimplification( + iteration_limit=10, entry_functions: Optional[List[str]] = None, call_only: bool = False +) -> tvm.ir.transform.Pass: + """ + Applies multiple simplification passes one after another until reaching + fixpoint (i.e., until no function in the module changes anymore). + + Includes the following passes: + - DeadCodeElimination + - CanonicalizeBindings + - EliminateCommonSubexpressions + - FoldDataflowBlockOutput + + Parameters + ---------- + iteration_limit: int + Upper bound on number of iterations in case the loop does not converge. + + entry_functions: List[str] + Entry points to the module, for dead code elimination + + call_only: bool + Whether to apply common subexpression elimination only to calls + + Returns + ------- + ret: Pass + The pass + """ + if entry_functions is None: + entry_functions = ["main"] + return _ffi_api.FixpointSimplification(iteration_limit, entry_functions, call_only) + + def RewriteDataflowReshape() -> tvm.ir.transform.Pass: """Convert all reshape-like call_tir to VM reshape operator call. The VM reshape operator calls will be further lowered to a CreateView diff --git a/src/relax/transform/simplify_fixpoint.cc b/src/relax/transform/simplify_fixpoint.cc new file mode 100644 index 000000000000..8dc4f8d6de49 --- /dev/null +++ b/src/relax/transform/simplify_fixpoint.cc @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/simplify_fixpoint.cc + * \brief Pass that applies other simplification passes until fixpoint. + * Presently, this subsumes the following passes: + * * FoldDataflowBlockOutput + * * CanonicalizeBIndings + * * EliminateCommonSubexpr + * * DeadCodeElimination + */ +#include +#include +#include +#include + +#include "utils.h" + +namespace tvm { +namespace relax { + +uint64_t Hash(const IRModule& mod) { return SHashHandlerDefault().Hash(mod, true); } + +IRModule FixpointSimplification(const IRModule& mod, int iteration_limit, + Array entry_funcs, bool call_only) { + // apply passes until it stops changing + IRModule current_mod = mod; + transform::Pass cse = transform::EliminateCommonSubexpr(call_only); + transform::Pass canonicalize_bindings = transform::CanonicalizeBindings(); + transform::Pass dce = transform::DeadCodeElimination(entry_funcs); + transform::Pass fold_df_output = transform::FoldDataflowBlockOutput(); + + uint64_t last_hash = Hash(current_mod); + int i = 0; + while (i < iteration_limit) { + current_mod = std::move(fold_df_output(cse(canonicalize_bindings(dce(current_mod))))); + uint64_t current_hash = Hash(current_mod); + if (current_hash == last_hash) { + break; + } + last_hash = current_hash; + i++; + } + if (i == iteration_limit) { + LOG(WARNING) << "Iteration limit reached, suggesting FixpointSimplification likely did not " + "actually reach a fixpoint."; + } + + return current_mod; +} + +namespace transform { + +Pass FixpointSimplification(int iteration_limit, Array entry_functions, + bool call_only) { + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext pc) { + return relax::FixpointSimplification(m, iteration_limit, entry_functions, call_only); + }; + return CreateModulePass(pass_func, 1, "FixpointSimplification", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FixpointSimplification") + .set_body_typed(FixpointSimplification); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index 5e1d1b881e2c..6197c819ed00 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -89,7 +89,7 @@ def main(x: R.Tensor): assert_structural_equal(new_mod, Expected) -def test_assign_to_output_indataflow_block(): +def test_assign_to_output_in_dataflow_block(): @tvm.script.ir_module class TestDataflowAssignments: @R.function diff --git a/tests/python/relax/test_transform_fixpoint_simplification.py b/tests/python/relax/test_transform_fixpoint_simplification.py new file mode 100644 index 000000000000..dbdaabcdeac3 --- /dev/null +++ b/tests/python/relax/test_transform_fixpoint_simplification.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test fixpoint simplification pass""" +import tvm +import tvm.testing +from tvm.relax.transform import CanonicalizeBindings, DeadCodeElimination, FixpointSimplification +from tvm.script.parser import ir as I, relax as R, tir as T + + +def verify(input: tvm.IRModule, expected: tvm.IRModule) -> None: + actual = FixpointSimplification()(input) + tvm.ir.assert_structural_equal(actual, expected, map_free_vars=True) + + +def test_chain_assignment(): + # test case from binding canonicalization, except it will simplify all the way + @I.ir_module + class TestChainAssignments: + @R.function + def main(x: R.Tensor): + # need the dataflow block for DCE to work + with R.dataflow(): + y = x + z = y + q = z + p = q + o = p + R.output(o) + return o + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor): + return x + + verify(TestChainAssignments, Expected) + + +def test_eliminate_trivial_check(): + # another case from canonicalize bindings that can be further simplified + @I.ir_module + class TestSameShape: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + # need the dataflow block for DCE to work + with R.dataflow(): + m, n = T.int64(), T.int64() + y = x + # trivial check, eliminated by canonicalize bindings + z = R.match_cast(x, R.Tensor((m, n), "float32")) + w = z + q = R.add(w, y) + r = R.add(q, w) + R.output(r) + return r + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): + with R.dataflow(): + q = R.add(x, x) + r = R.add(q, x) + R.output(r) + return r + + verify(TestSameShape, Expected) + + +if __name__ == "__main__": + tvm.testing.main()