diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index 73f66d2ef362..28c7d74ef8d0 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -50,12 +50,22 @@ class CallTracer : public ExprVisitor { explicit CallTracer(IRModule mod) : mod_{mod}, called_funcs_{}, visiting_{} {} void VisitExpr_(const GlobalVarNode* op) final { - called_funcs_.insert(GetRef(op)); - auto func = mod_->Lookup(op->name_hint); - if (const auto* function_node = func.as()) { - VisitExpr(GetRef(function_node)); + auto gvar = GetRef(op); + called_funcs_.insert(gvar); + if (auto func = mod_->functions.Get(gvar)) { + if (const auto* function_node = func.as()) { + VisitExpr(GetRef(function_node)); + } + // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein. + } else { + // The GlobalVar is not contained in the IRModule. While the + // input IRModule is ill-formed, this specific case is allowed + // for use with `relax.transform.ApplyPassToFunction`. If this + // occurs, DCE should not remove any internal functions from the + // IRModule, as their removal is only valid if we have a + // complete call graph. + all_callees_found_ = false; } - // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein. } void VisitExpr_(const CallNode* call_node) final { ExprVisitor::VisitExpr_(call_node); } @@ -77,11 +87,24 @@ class CallTracer : public ExprVisitor { VisitExpr(main_func); } - bool check_if_called(GlobalVar gv) { return called_funcs_.count(gv) > 0; } + /* \brief Check if a function is unreachable + * + * \param gvar The function to be checked + * + * \return True if the function can be proven to be unreachable, + * either directly or indirectly, from an external caller. + * Otherwise, false. + */ + bool CheckIfProvablyUnreachable(const GlobalVar& gvar) const { + return all_callees_found_ && !called_funcs_.count(gvar); + } private: IRModule mod_; + /* \brief Whether all callees could be located within the IRModule */ + bool all_callees_found_{true}; + // Record the names of all encountered functions. std::unordered_set called_funcs_; @@ -101,7 +124,7 @@ IRModule RemoveUnusedFunctions( // The tracer contains all user-provided entry functions, all // externally-callable functions, and anything that is directly or // indirectly accessible from an entry function. - if (!tracer.check_if_called(kv.first)) { + if (tracer.CheckIfProvablyUnreachable(kv.first)) { to_remove.push_back(kv.first); } } diff --git a/tests/python/relax/conftest.py b/tests/python/relax/conftest.py index 1e12a95e524b..bb5a04ef7679 100644 --- a/tests/python/relax/conftest.py +++ b/tests/python/relax/conftest.py @@ -37,7 +37,14 @@ def pytest_configure(config): "markers", ( "skip_well_formed_check_before_transform: " - "Only check for well-formed IRModule after a transform" + "Suppress the default well-formed check before a IRModule transform" + ), + ) + config.addinivalue_line( + "markers", + ( + "skip_well_formed_check_after_transform: " + "Suppress the default well-formed check after a IRModule transform" ), ) @@ -54,15 +61,20 @@ def pytest_configure(config): # `@pytest.mark.skip_well_formed_check_before_transform` @pytest.fixture(autouse=True) def apply_instrument_well_formed(unit_test_marks): - validate_before_transform = "skip_well_formed_check_before_transform" not in unit_test_marks + validate_after_transform = "skip_well_formed_check_after_transform" not in unit_test_marks - instrument = WellFormedInstrument(validate_before_transform=validate_before_transform) current = tvm.transform.PassContext.current() + instruments = list(current.instruments) + + if validate_before_transform or validate_after_transform: + instruments.append( + WellFormedInstrument(validate_before_transform=validate_before_transform) + ) override = tvm.transform.PassContext( - # Append the new instrument - instruments=[*current.instruments, instrument], + # With the new WellFormedInstrument appended + instruments=instruments, # Forward all other parameters opt_level=current.opt_level, required_pass=current.required_pass, diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index c0a2d47b19f1..2dae252cadd1 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +import pytest + import tvm import tvm.testing from tvm.relax.transform import DeadCodeElimination @@ -507,5 +509,158 @@ def test_extern_func(): verify(before, before) +@pytest.mark.skip_well_formed_check_before_transform +@pytest.mark.skip_well_formed_check_after_transform +def test_compatibility_with_apply_pass_to_function(): + """DeadCodeElimination can be used with ApplyPassToFunction + + The `ApplyPassToFunction` utility calls another transform, where + only the specified functions are exposed to the internal + transform. This intermediate does not contain `cls.subroutine`, + and so the intermediate is ill-formed. + + In general, IRModule transformations may assume that their inputs + are well-formed. In specific cases, IRModule transformations may + accept IRModules that are ill-formed. The `DeadCodeElimination` + transform allows IRModule arguments that are ill-formed due to + a dangling GlobalVar. + + After `DeadCodeElimination` completes, the resulting function is + inserted in the original IRModule, providing a well-formed output + from `ApplyPassToFunction`. + + """ + + @I.ir_module + class Before: + @R.function + def to_be_transformed(A: R.Tensor): + cls = Before + + B = R.add(A, A) + C = cls.subroutine(B) + D = R.multiply(C, C) + return C + + @R.function + def to_be_ignored(A: R.Tensor): + cls = Before + + B = R.add(A, A) + C = cls.subroutine(B) + D = R.multiply(C, C) + return C + + @R.function(private=True) + def subroutine(arg: R.Tensor) -> R.Tensor: + return R.add(arg, arg) + + @I.ir_module + class Expected: + @R.function + def to_be_transformed(A: R.Tensor): + cls = Expected + + B = R.add(A, A) + C = cls.subroutine(B) + return C + + @R.function + def to_be_ignored(A: R.Tensor): + cls = Expected + + B = R.add(A, A) + C = cls.subroutine(B) + D = R.multiply(C, C) + return C + + @R.function(private=True) + def subroutine(arg: R.Tensor) -> R.Tensor: + return R.add(arg, arg) + + # The well-formed check in conftest.py must be disabled, to avoid + # triggering on the ill-formed intermediate, so this unit test + # checks it explicitly. + assert tvm.relax.analysis.well_formed(Before) + After = tvm.ir.transform.ApplyPassToFunction( + tvm.relax.transform.DeadCodeElimination(), + "to_be_transformed", + )(Before) + assert tvm.relax.analysis.well_formed(After) + tvm.ir.assert_structural_equal(Expected, After) + + +@pytest.mark.skip_well_formed_check_before_transform +@pytest.mark.skip_well_formed_check_after_transform +def test_well_formed_output_with_restricted_scope(): + """DeadCodeElimination can be used with ApplyPassToFunction + + If the call graph cannot be completely traced, private functions + should not be removed. + + See `test_compatibility_with_apply_pass_to_function` for full + description of `DeadCodeElimination` and `ApplyPassToFunction`. + + """ + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor): + cls = Before + + B = R.add(A, A) + C = cls.subroutine(B) + D = R.multiply(C, C) + return C + + @R.function(private=True) + def subroutine(A: R.Tensor) -> R.Tensor: + cls = Before + + B = R.add(A, A) + C = cls.subsubroutine(B) + D = R.multiply(C, C) + return C + + @R.function(private=True) + def subsubroutine(A: R.Tensor) -> R.Tensor: + B = R.add(A, A) + C = R.multiply(B, B) + return B + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor): + cls = Expected + + B = R.add(A, A) + C = cls.subroutine(B) + return C + + @R.function(private=True) + def subroutine(A: R.Tensor) -> R.Tensor: + cls = Expected + + B = R.add(A, A) + C = cls.subsubroutine(B) + D = R.multiply(C, C) + return C + + @R.function(private=True) + def subsubroutine(A: R.Tensor) -> R.Tensor: + B = R.add(A, A) + return B + + assert tvm.relax.analysis.well_formed(Before) + After = tvm.ir.transform.ApplyPassToFunction( + tvm.relax.transform.DeadCodeElimination(), + "main|subsubroutine", + )(Before) + assert tvm.relax.analysis.well_formed(After) + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main()