From 3ba4640103ee40dd1d661eb45e67df96d2f3ca21 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 27 Oct 2023 11:39:59 -0500 Subject: [PATCH 1/2] [Unity] Implement RemoveUnusedParameters transform Currently, the `FuseOps` and `FuseTIR` passes have a large amount of added complexity to identify and handle partial use of tuple arguments. The handling of partial use of tuples could be significantly simpler if performed in multiple steps. 1. Perform `FuseOps`. Any tuple variables that are used by the fused function are passed as-is. 2. Expand any parameters that are passed as a tuple. Any unused tensors that were included in a partially-used tuple will be converted to unused parameters. 3. Remove any unused parameters. Any unused tensors that were included in a partially-used tuple will be removed in this step. 4. Perform `FuseTIR`. No checking for tuple arguments, either partial or full, is required at this step. This PR implements `relax.transform.RemoveUnusedParameters`, which is step (3) in this process. --- include/tvm/relax/transform.h | 6 + python/tvm/relax/transform/__init__.py | 1 + python/tvm/relax/transform/transform.py | 10 + .../transform/remove_unused_parameters.cc | 239 ++++++++++++++++++ ...test_transform_remove_unused_parameters.py | 101 ++++++++ 5 files changed, 357 insertions(+) create mode 100644 src/relax/transform/remove_unused_parameters.cc create mode 100644 tests/python/relax/test_transform_remove_unused_parameters.py diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index b043765a6990..5a878d9160a8 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -275,6 +275,12 @@ TVM_DLL Pass LiftTransformParams(); */ TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index); +/*! \brief Remove unused parameters to internal functions + * + * \return The Pass + */ +TVM_DLL Pass RemoveUnusedParameters(); + +/*! 
\brief Remove unused outputs from internal functions * * \return The Pass diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 0ce0ebba1105..d1171c221249 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -55,6 +55,7 @@ PatternCheckContext, RealizeVDevice, RemovePurityChecking, + RemoveUnusedParameters, RemoveUnusedOutputs, RewriteCUDAGraph, RewriteDataflowReshape, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 428f8c24efd7..e6636d14f423 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -558,6 +558,16 @@ def FoldConstant() -> tvm.ir.transform.Pass: return _ffi_api.FoldConstant() # type: ignore +def RemoveUnusedParameters() -> tvm.ir.transform.Pass: + """Remove unused arguments to internal functions + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.RemoveUnusedParameters() # type: ignore + + def RemoveUnusedOutputs() -> tvm.ir.transform.Pass: """Remove unused outputs from internal functions diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc new file mode 100644 index 000000000000..33e3ecb8cd0d --- /dev/null +++ b/src/relax/transform/remove_unused_parameters.cc @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relax { + +namespace { + +template +using PSet = std::unordered_set; + +template +using PMap = std::unordered_map; + +struct CalleeAnalysis { + // The updated function + Function func; + + // A mutator that updates calls at the call site. + std::function(Array)> arg_updater; +}; + +std::optional AnalyzeCallee(Function func) { + bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).defined(); + if (is_exposed) return std::nullopt; + + auto free_relax_vars = [&]() -> PSet { + auto array_free_vars = FreeVars(func->body); + return {array_free_vars.begin(), array_free_vars.end()}; + }(); + + std::vector parameter_mask; + parameter_mask.reserve(func->params.size()); + + Array params; + for (const auto& param : func->params) { + bool is_used = free_relax_vars.count(param); + parameter_mask.push_back(is_used); + if (is_used) { + params.push_back(param); + } + } + + if (func->params.size() == params.size()) { + // Early bail-out for the common case where the function uses all + // of its parameters. + return std::nullopt; + } + + // Even if a parameter is unused, it may provide definitions for + // symbolic variables. We still want to remove the relax variable + // to reduce computational steps in the parent, but we need to + // provide the symbolic variables through some other means. 
+ auto defined_tir_params = [&]() -> PSet { + auto param_sinfo = + TupleStructInfo(params.Map([](const auto& var) { return GetStructInfo(var); })); + auto arr = DefinableTIRVarsInStructInfo(param_sinfo); + return {arr.begin(), arr.end()}; + }(); + + // Use an array to define the order of the symbolic variables + Array free_tir_vars; + for (const auto& tir_var : FreeSymbolicVars(func->body)) { + if (!defined_tir_params.count(tir_var)) { + free_tir_vars.push_back(tir_var); + } + } + + for (const auto& tir_var : free_tir_vars) { + Var relax_var("param_" + tir_var->name_hint, PrimStructInfo(tir_var)); + params.push_back(relax_var); + } + + FuncStructInfo new_sinfo(params.Map([](const auto& var) { return GetStructInfo(var); }), + func->ret_struct_info, + Downcast(func->struct_info_)->purity); + + auto arg_updater = [parameter_mask, old_relax_params = func->params, + free_tir_vars](Array old_args) -> Array { + ICHECK_EQ(old_args.size(), parameter_mask.size()) + << "Call provides " << old_args.size() << ", but the callee accepts " + << parameter_mask.size() << " parameters"; + + Array new_args; + for (size_t i = 0; i < old_args.size(); i++) { + if (parameter_mask.at(i)) { + new_args.push_back(old_args[i]); + } + } + + if (free_tir_vars.size()) { + Map old_binding; + for (size_t i = 0; i < old_relax_params.size(); i++) { + old_binding.Set(old_relax_params[i], old_args[i]); + } + arith::Analyzer analyzer; + auto tir_binding = InferSymbolicVarMap(old_binding, &analyzer); + + for (const auto& tir_var : free_tir_vars) { + new_args.push_back(PrimValue(tir_binding.at(tir_var))); + } + } + + return new_args; + }; + + auto write_ptr = func.CopyOnWrite(); + write_ptr->params = params; + write_ptr->struct_info_ = new_sinfo; + + return CalleeAnalysis{func, arg_updater}; +} + +class CallSiteMutator : public ExprMutator { + public: + explicit CallSiteMutator(PMap> callsite_updaters) + : callsite_updaters_(callsite_updaters) {} + + using ExprMutator::VisitExpr_; + + Expr 
VisitExpr_(const CallNode* op) override { + auto node = Downcast(ExprMutator::VisitExpr_(op)); + + if (auto gvar = node->op.as()) { + if (auto it = callsite_updaters_.find(gvar.value()); it != callsite_updaters_.end()) { + node = it->second(std::move(node)); + } + } + + return node; + } + + PMap> callsite_updaters_; +}; + +} // namespace + +namespace transform { + +Pass RemoveUnusedParameters() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) -> IRModule { + PMap> callsite_updaters; + + { + IRModule new_callees; + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + if (auto callee_res = AnalyzeCallee(func.value())) { + auto new_func = callee_res->func; + GlobalVar new_gvar(gvar->name_hint, new_func->checked_type_); + new_gvar->struct_info_ = new_func->struct_info_; + new_callees->Add(new_gvar, new_func); + + callsite_updaters[gvar] = [old_gvar = gvar, new_gvar, + arg_updater = callee_res->arg_updater](Call call) -> Call { + ICHECK(call->op.same_as(old_gvar)) << "InternalError: " + << "Updater should be applied to " << old_gvar + << ", but was applied to " << call->op; + auto write_ptr = call.CopyOnWrite(); + write_ptr->op = new_gvar; + write_ptr->args = arg_updater(call->args); + return call; + }; + } + } + } + + if (callsite_updaters.empty()) { + return mod; + } + auto write_ptr = mod.CopyOnWrite(); + for (const auto& it : callsite_updaters) { + write_ptr->Remove(it.first); + } + write_ptr->Update(new_callees); + } + + CallSiteMutator mutator(std::move(callsite_updaters)); + + IRModule caller_updates; + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + auto mutated = Downcast(mutator.VisitExpr(func.value())); + if (!mutated.same_as(base_func)) { + caller_updates->Add(gvar, mutated); + } + } + } + + if (caller_updates->functions.size()) { + mod.CopyOnWrite()->Update(caller_updates); + } + return mod; + }; + auto inner_pass = 
CreateModulePass(pass_func, 0, "RemoveUnusedParametersInner", {}); + return tvm::transform::Sequential( + { + inner_pass, + CanonicalizeBindings(), + DeadCodeElimination({}), + }, + "RemoveUnusedParameters"); +} + +TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedParameters") + .set_body_typed(RemoveUnusedParameters); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_remove_unused_parameters.py b/tests/python/relax/test_transform_remove_unused_parameters.py new file mode 100644 index 000000000000..82c8d0bd1d29 --- /dev/null +++ b/tests/python/relax/test_transform_remove_unused_parameters.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +import tvm.testing +from tvm.script import ir as I, relax as R, tir as T + + +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.relax.transform.RemoveUnusedParameters() + + +class TestSimple(BaseCompare): + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor, B: R.Tensor): + return Before.func(A, B) + + @R.function(private=True) + def func(A: R.Tensor, B: R.Tensor) -> R.Tensor: + return A + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor, B: R.Tensor): + return Expected.func(A) + + @R.function(private=True) + def func(A: R.Tensor) -> R.Tensor: + return A + + +class TestSymbolicVariables(BaseCompare): + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): + return Before.func(A) + + @R.function(private=True) + def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): + m = T.int64() + n = T.int64() + return R.zeros(R.shape([m, n]), dtype="float32") + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): + m = T.int64() + n = T.int64() + out: R.Tensor([m, n], "float32") = Expected.func(R.prim_value(n), R.prim_value(m)) + return out + + @R.function(private=True) + def func( + param_n: R.Prim(value="n"), param_m: R.Prim(value="m") + ) -> R.Tensor(["m", "n"], "float32"): + m = T.int64() + n = T.int64() + return R.zeros(R.shape([m, n]), dtype="float32") + + +class TestNoExtraSymbolicVariables(BaseCompare): + """Don't add symbolic variables if they can be inferred.""" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): + return Before.func(A) + + @R.function(private=True) + def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): + m = T.int64() + n = T.int64() + zeros = R.zeros(R.shape([m, n]), dtype="float32") + out = R.add(A, 
zeros) + return out + + Expected = Before + + +if __name__ == "__main__": + tvm.testing.main() From 970db4cde7208d16462e405516db8c11468701ea Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 30 Nov 2023 21:55:24 -0600 Subject: [PATCH 2/2] Update based on review comments --- .../transform/remove_unused_parameters.cc | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index 33e3ecb8cd0d..d053d56f3205 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -26,6 +26,8 @@ #include #include +#include "utils.h" + namespace tvm { namespace relax { @@ -37,11 +39,17 @@ using PSet = std::unordered_set; template using PMap = std::unordered_map; +/* \brief Describes the modifications to be made for a function */ struct CalleeAnalysis { - // The updated function + /* \brief The updated private function */ Function func; - // A mutator that updates calls at the call site. + /* \brief A function that updates the callsite arguments + * + * \param The arguments used to call the original function + * + * \return The arguments to be used for the modified function + */ std::function(Array)> arg_updater; }; @@ -143,6 +151,19 @@ class CallSiteMutator : public ExprMutator { using ExprMutator::VisitExpr_; + Expr VisitExpr_(const FunctionNode* op) override { + auto node = ExprMutator::VisitExpr_(op); + + // If a function was modified, that means it called into a private + // function that now takes a reduced number of arguments. Some + // bindings in the calling scope, previously used to define those + // unused arguments, may be able to be removed as a result. 
+ if (node.get() != op) { + node = RemoveAllUnused(node); + } + return node; + } + Expr VisitExpr_(const CallNode* op) override { auto node = Downcast(ExprMutator::VisitExpr_(op)); @@ -196,6 +217,13 @@ Pass RemoveUnusedParameters() { return mod; } auto write_ptr = mod.CopyOnWrite(); + + // Remove any private subroutines that have unused parameters, + // then add the updated versions. The new private functions + // have the same name, but require a new GlobalVar to hold the + // updated StructInfo. As a result, calling `Update()` without + // first calling `Remove()` would introduce a duplicate name and + // produce an error. for (const auto& it : callsite_updaters) { write_ptr->Remove(it.first); } @@ -220,14 +248,7 @@ Pass RemoveUnusedParameters() { } return mod; }; - auto inner_pass = CreateModulePass(pass_func, 0, "RemoveUnusedParametersInner", {}); - return tvm::transform::Sequential( - { - inner_pass, - CanonicalizeBindings(), - DeadCodeElimination({}), - }, - "RemoveUnusedParameters"); + return CreateModulePass(pass_func, 0, "RemoveUnusedParameters", {}); } TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedParameters")