diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 8c6417d19d14..f743bb53d089 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -281,6 +281,12 @@ TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);
  */
 TVM_DLL Pass ExpandTupleArguments();
 
+/*! \brief Remove unused parameters from internal functions
+ *
+ * \return The Pass
+ */
+TVM_DLL Pass RemoveUnusedParameters();
+
 /*! \brief Remove unused outputs from internal functions
  *
  * \return The Pass
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index b6887160c8b5..c3f037da5f64 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -56,6 +56,7 @@
     PatternCheckContext,
     RealizeVDevice,
     RemovePurityChecking,
     RemoveUnusedOutputs,
+    RemoveUnusedParameters,
     RewriteCUDAGraph,
     RewriteDataflowReshape,
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 1beb535f0bb7..0af89b7d9e8d 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -568,6 +568,22 @@ def ExpandTupleArguments() -> tvm.ir.transform.Pass:
     return _ffi_api.ExpandTupleArguments()  # type: ignore
 
 
+def RemoveUnusedParameters() -> tvm.ir.transform.Pass:
+    """Remove unused parameters from internal functions
+
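+    For example, a private function ``func(A, B)`` that never uses
+    ``B`` is rewritten to ``func(A)``, and all of its call sites are
+    updated to pass only ``A``.  If an unused parameter provided the
+    only definition of a symbolic shape variable, the callee instead
+    receives an ``R.Prim`` parameter carrying that value.
+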
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    return _ffi_api.RemoveUnusedParameters()  # type: ignore
+
+
 def RemoveUnusedOutputs() -> tvm.ir.transform.Pass:
     """Remove unused outputs from internal functions
 
diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc
new file mode 100644
index 000000000000..d053d56f3205
--- /dev/null
+++ b/src/relax/transform/remove_unused_parameters.cc
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <optional>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+template <typename T>
+using PSet = std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>;
+
+template <typename T, typename U>
+using PMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;
+
+/* \brief Describes the modifications to be made for a function */
+struct CalleeAnalysis {
+  /* \brief The updated private function */
+  Function func;
+
+  /* \brief A function that updates the callsite arguments
+   *
+   * \param The arguments used to call the original function
+   *
+   * \return The arguments to be used for the modified function
+   */
+  std::function<Array<Expr>(Array<Expr>)> arg_updater;
+};
+
+std::optional<CalleeAnalysis> AnalyzeCallee(Function func) {
+  bool is_exposed = func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+  if (is_exposed) return std::nullopt;
+
+  auto free_relax_vars = [&]() -> PSet<Var> {
+    auto array_free_vars = FreeVars(func->body);
+    return {array_free_vars.begin(), array_free_vars.end()};
+  }();
+
+  std::vector<bool> parameter_mask;
+  parameter_mask.reserve(func->params.size());
+
+  Array<Var> params;
+  for (const auto& param : func->params) {
+    bool is_used = free_relax_vars.count(param);
+    parameter_mask.push_back(is_used);
+    if (is_used) {
+      params.push_back(param);
+    }
+  }
+
+  if (func->params.size() == params.size()) {
+    // Early bail-out for the common case where the function uses all
+    // of its parameters.
+    return std::nullopt;
+  }
+
+  // Even if a parameter is unused, it may provide definitions for
+  // symbolic variables.  We still want to remove the relax variable
+  // to reduce computational steps in the parent, but we still need
+  // to provide the symbolic variables for the other steps.
+  auto defined_tir_params = [&]() -> PSet<tir::Var> {
+    auto param_sinfo =
+        TupleStructInfo(params.Map([](const auto& var) { return GetStructInfo(var); }));
+    auto arr = DefinableTIRVarsInStructInfo(param_sinfo);
+    return {arr.begin(), arr.end()};
+  }();
+
+  // Use an array to define the order of the symbolic variables
+  Array<tir::Var> free_tir_vars;
+  for (const auto& tir_var : FreeSymbolicVars(func->body)) {
+    if (!defined_tir_params.count(tir_var)) {
+      free_tir_vars.push_back(tir_var);
+    }
+  }
+
+  for (const auto& tir_var : free_tir_vars) {
+    Var relax_var("param_" + tir_var->name_hint, PrimStructInfo(tir_var));
+    params.push_back(relax_var);
+  }
+
+  FuncStructInfo new_sinfo(params.Map([](const auto& var) { return GetStructInfo(var); }),
+                           func->ret_struct_info,
+                           Downcast<FuncStructInfo>(func->struct_info_)->purity);
+
+  auto arg_updater = [parameter_mask, old_relax_params = func->params,
+                      free_tir_vars](Array<Expr> old_args) -> Array<Expr> {
+    ICHECK_EQ(old_args.size(), parameter_mask.size())
+        << "Call provides " << old_args.size() << " arguments, but the callee accepts "
+        << parameter_mask.size() << " parameters";
+
+    Array<Expr> new_args;
+    for (size_t i = 0; i < old_args.size(); i++) {
+      if (parameter_mask.at(i)) {
+        new_args.push_back(old_args[i]);
+      }
+    }
+
+    if (free_tir_vars.size()) {
+      Map<Var, Expr> old_binding;
+      for (size_t i = 0; i < old_relax_params.size(); i++) {
+        old_binding.Set(old_relax_params[i], old_args[i]);
+      }
+      arith::Analyzer analyzer;
+      auto tir_binding = InferSymbolicVarMap(old_binding, &analyzer);
+
+      for (const auto& tir_var : free_tir_vars) {
+        new_args.push_back(PrimValue(tir_binding.at(tir_var)));
+      }
+    }
+
+    return new_args;
+  };
+
+  auto write_ptr = func.CopyOnWrite();
+  write_ptr->params = params;
+  write_ptr->struct_info_ = new_sinfo;
+
+  return CalleeAnalysis{func, arg_updater};
+}
+
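+/* \brief Rewrites call sites of updated callees
+ *
+ * For each call to a GlobalVar that has a registered updater,
+ * replaces the callee with its updated GlobalVar and rewrites the
+ * argument list through the corresponding arg_updater.
+ */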
+class CallSiteMutator : public ExprMutator {
+ public:
+  explicit CallSiteMutator(PMap<GlobalVar, std::function<Call(Call)>> callsite_updaters)
+      : callsite_updaters_(callsite_updaters) {}
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const FunctionNode* op) override {
+    auto node = ExprMutator::VisitExpr_(op);
+
+    // If a function was modified, that means it called into a private
+    // function that now takes a reduced number of arguments.  Some
+    // bindings in the calling scope, previously used to define those
+    // unused arguments, may be removable as a result.
+    if (node.get() != op) {
+      node = RemoveAllUnused(node);
+    }
+    return node;
+  }
+
+  Expr VisitExpr_(const CallNode* op) override {
+    auto node = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+    if (auto gvar = node->op.as<GlobalVar>()) {
+      if (auto it = callsite_updaters_.find(gvar.value()); it != callsite_updaters_.end()) {
+        node = it->second(std::move(node));
+      }
+    }
+
+    return node;
+  }
+
+  PMap<GlobalVar, std::function<Call(Call)>> callsite_updaters_;
+};
+
+}  // namespace
+
+namespace transform {
+
+Pass RemoveUnusedParameters() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) -> IRModule {
+    PMap<GlobalVar, std::function<Call(Call)>> callsite_updaters;
+
+    {
+      IRModule new_callees;
+
+      for (const auto& [gvar, base_func] : mod->functions) {
+        if (auto func = base_func.as<Function>()) {
+          if (auto callee_res = AnalyzeCallee(func.value())) {
+            auto new_func = callee_res->func;
+            GlobalVar new_gvar(gvar->name_hint, new_func->checked_type_);
+            new_gvar->struct_info_ = new_func->struct_info_;
+            new_callees->Add(new_gvar, new_func);
+
+            callsite_updaters[gvar] = [old_gvar = gvar, new_gvar,
+                                       arg_updater = callee_res->arg_updater](Call call) -> Call {
+              ICHECK(call->op.same_as(old_gvar)) << "InternalError: "
+                                                 << "Updater should be applied to " << old_gvar
+                                                 << ", but was applied to " << call->op;
+              auto write_ptr = call.CopyOnWrite();
+              write_ptr->op = new_gvar;
+              write_ptr->args = arg_updater(call->args);
+              return call;
+            };
+          }
+        }
+      }
+
+      if (callsite_updaters.empty()) {
+        return mod;
+      }
+      auto write_ptr = mod.CopyOnWrite();
+
+      // Remove any private subroutines that have unused parameters,
+      // then add the updated versions.  The new private functions
+      // have the same name, but require a new GlobalVar to hold the
+      // updated StructInfo.  As a result, calling `Update()` without
+      // first calling `Remove()` would introduce a duplicate name and
+      // produce an error.
+      for (const auto& it : callsite_updaters) {
+        write_ptr->Remove(it.first);
+      }
+      write_ptr->Update(new_callees);
+    }
+
+    CallSiteMutator mutator(std::move(callsite_updaters));
+
+    IRModule caller_updates;
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto func = base_func.as<Function>()) {
+        auto mutated = Downcast<Function>(mutator.VisitExpr(func.value()));
+        if (!mutated.same_as(base_func)) {
+          caller_updates->Add(gvar, mutated);
+        }
+      }
+    }
+
+    if (caller_updates->functions.size()) {
+      mod.CopyOnWrite()->Update(caller_updates);
+    }
+    return mod;
+  };
+  return CreateModulePass(pass_func, 0, "RemoveUnusedParameters", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedParameters")
+    .set_body_typed(RemoveUnusedParameters);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_remove_unused_parameters.py b/tests/python/relax/test_transform_remove_unused_parameters.py
new file mode 100644
index 000000000000..82c8d0bd1d29
--- /dev/null
+++ b/tests/python/relax/test_transform_remove_unused_parameters.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I, relax as R, tir as T
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.relax.transform.RemoveUnusedParameters()
+
+
+class TestSimple(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor, B: R.Tensor):
+            return Before.func(A, B)
+
+        @R.function(private=True)
+        def func(A: R.Tensor, B: R.Tensor) -> R.Tensor:
+            return A
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor, B: R.Tensor):
+            return Expected.func(A)
+
+        @R.function(private=True)
+        def func(A: R.Tensor) -> R.Tensor:
+            return A
+
+
+class TestSymbolicVariables(BaseCompare):
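+    """Unused parameters that define symbolic variables are replaced.
+
+    The tensor parameter is unused and is removed, but it provided
+    the only definitions of ``m`` and ``n``.  The callee instead
+    receives ``R.Prim`` parameters carrying those symbolic values.
+    """
+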
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
+            return Before.func(A)
+
+        @R.function(private=True)
+        def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
+            m = T.int64()
+            n = T.int64()
+            return R.zeros(R.shape([m, n]), dtype="float32")
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
+            m = T.int64()
+            n = T.int64()
+            out: R.Tensor([m, n], "float32") = Expected.func(R.prim_value(n), R.prim_value(m))
+            return out
+
+        @R.function(private=True)
+        def func(
+            param_n: R.Prim(value="n"), param_m: R.Prim(value="m")
+        ) -> R.Tensor(["m", "n"], "float32"):
+            m = T.int64()
+            n = T.int64()
+            return R.zeros(R.shape([m, n]), dtype="float32")
+
+
+class TestNoExtraSymbolicVariables(BaseCompare):
+    """Don't add symbolic variables if they can be inferred."""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
+            return Before.func(A)
+
+        @R.function(private=True)
+        def func(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"):
+            m = T.int64()
+            n = T.int64()
+            zeros = R.zeros(R.shape([m, n]), dtype="float32")
+            out = R.add(A, zeros)
+            return out
+
+    Expected = Before
+
+
+if __name__ == "__main__":
+    tvm.testing.main()