Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,24 @@ TVM_DLL Pass Normalize();
*/
TVM_DLL Pass CanonicalizeBindings();

/*!
* \brief Applies multiple simplification passes one after another until reaching
* fixpoint (i.e., until no function in the module changes anymore).
*
* Includes the following passes:
* - DeadCodeElimination
* - CanonicalizeBindings
* - EliminateCommonSubexpressions
* - FoldDataflowBlockOutput
*
* \param iteration_limit Upper bound on number of iterations in case the loop does not converge
* \param entry_functions Entry points to the module, for dead code elimination
* \param call_only Whether to apply common subexpression elimination only to calls
* \return The Pass.
*/
TVM_DLL Pass FixpointSimplification(int iteration_limit, Array<runtime::String> entry_functions,
bool call_only);

/*!
* Eliminate common subexpressions within functions.
* \return The pass that eliminates common subexpressions.
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,40 @@ def EliminateCommonSubexpr(call_only=False) -> FunctionPass:
return _ffi_api.EliminateCommonSubexpr(call_only) # type: ignore


def FixpointSimplification(
iteration_limit=10, entry_functions: Optional[List[str]] = None, call_only: bool = False
) -> tvm.ir.transform.Pass:
"""
Applies multiple simplification passes one after another until reaching
fixpoint (i.e., until no function in the module changes anymore).

Includes the following passes:
- DeadCodeElimination
- CanonicalizeBindings
- EliminateCommonSubexpressions
- FoldDataflowBlockOutput

Parameters
----------
iteration_limit: int
Upper bound on number of iterations in case the loop does not converge.

entry_functions: List[str]
Entry points to the module, for dead code elimination

call_only: bool
Whether to apply common subexpression elimination only to calls

Returns
-------
ret: Pass
The pass
"""
if entry_functions is None:
entry_functions = ["main"]
return _ffi_api.FixpointSimplification(iteration_limit, entry_functions, call_only)


def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
"""Convert all reshape-like call_tir to VM reshape operator call.
The VM reshape operator calls will be further lowered to a CreateView
Expand Down
86 changes: 86 additions & 0 deletions src/relax/transform/simplify_fixpoint.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
*
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/transform/simplify_fixpoint.cc
* \brief Pass that applies other simplification passes until fixpoint.
* Presently, this subsumes the following passes:
* * FoldDataflowBlockOutput
* * CanonicalizeBIndings
* * EliminateCommonSubexpr
* * DeadCodeElimination
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/utils.h>

#include "utils.h"

namespace tvm {
namespace relax {

uint64_t Hash(const IRModule& mod) { return SHashHandlerDefault().Hash(mod, true); }

IRModule FixpointSimplification(const IRModule& mod, int iteration_limit,
Array<runtime::String> entry_funcs, bool call_only) {
// apply passes until it stops changing
IRModule current_mod = mod;
transform::Pass cse = transform::EliminateCommonSubexpr(call_only);
transform::Pass canonicalize_bindings = transform::CanonicalizeBindings();
transform::Pass dce = transform::DeadCodeElimination(entry_funcs);
transform::Pass fold_df_output = transform::FoldDataflowBlockOutput();

uint64_t last_hash = Hash(current_mod);
int i = 0;
while (i < iteration_limit) {
current_mod = std::move(fold_df_output(cse(canonicalize_bindings(dce(current_mod)))));
uint64_t current_hash = Hash(current_mod);
if (current_hash == last_hash) {
break;
}
last_hash = current_hash;
i++;
}
if (i == iteration_limit) {
LOG(WARNING) << "Iteration limit reached, suggesting FixpointSimplification likely did not "
"actually reach a fixpoint.";
}

return current_mod;
}

namespace transform {

Pass FixpointSimplification(int iteration_limit, Array<runtime::String> entry_functions,
bool call_only) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m,
PassContext pc) {
return relax::FixpointSimplification(m, iteration_limit, entry_functions, call_only);
};
return CreateModulePass(pass_func, 1, "FixpointSimplification", {});
}

TVM_REGISTER_GLOBAL("relax.transform.FixpointSimplification")
.set_body_typed(FixpointSimplification);

} // namespace transform
} // namespace relax
} // namespace tvm
2 changes: 1 addition & 1 deletion tests/python/relax/test_transform_canonicalize_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def main(x: R.Tensor):
assert_structural_equal(new_mod, Expected)


def test_assign_to_output_indataflow_block():
def test_assign_to_output_in_dataflow_block():
@tvm.script.ir_module
class TestDataflowAssignments:
@R.function
Expand Down
86 changes: 86 additions & 0 deletions tests/python/relax/test_transform_fixpoint_simplification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test fixpoint simplification pass"""
import tvm
import tvm.testing
from tvm.relax.transform import CanonicalizeBindings, DeadCodeElimination, FixpointSimplification
from tvm.script.parser import ir as I, relax as R, tir as T


def verify(input: tvm.IRModule, expected: tvm.IRModule) -> None:
actual = FixpointSimplification()(input)
tvm.ir.assert_structural_equal(actual, expected, map_free_vars=True)


def test_chain_assignment():
# test case from binding canonicalization, except it will simplify all the way
@I.ir_module
class TestChainAssignments:
@R.function
def main(x: R.Tensor):
# need the dataflow block for DCE to work
with R.dataflow():
y = x
z = y
q = z
p = q
o = p
R.output(o)
return o

@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor):
return x

verify(TestChainAssignments, Expected)


def test_eliminate_trivial_check():
# another case from canonicalize bindings that can be further simplified
@I.ir_module
class TestSameShape:
@R.function
def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
# need the dataflow block for DCE to work
with R.dataflow():
m, n = T.int64(), T.int64()
y = x
# trivial check, eliminated by canonicalize bindings
z = R.match_cast(x, R.Tensor((m, n), "float32"))
w = z
q = R.add(w, y)
r = R.add(q, w)
R.output(r)
return r

@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
with R.dataflow():
q = R.add(x, x)
r = R.add(q, x)
R.output(r)
return r

verify(TestSameShape, Expected)


if __name__ == "__main__":
tvm.testing.main()