diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index ae6d4059a633..b043765a6990 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -275,6 +275,12 @@ TVM_DLL Pass LiftTransformParams();
  */
 TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index);
 
+/*! \brief Remove unused outputs from internal functions
+ *
+ * \return The Pass
+ */
+TVM_DLL Pass RemoveUnusedOutputs();
+
 /*!
  * \brief Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
  * \note It is an auto-detect pass for "unscheduled prim_funcs", the op_pattern will be
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 4d841a8e7b92..0ce0ebba1105 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -55,6 +55,7 @@
     PatternCheckContext,
     RealizeVDevice,
     RemovePurityChecking,
+    RemoveUnusedOutputs,
     RewriteCUDAGraph,
     RewriteDataflowReshape,
     RunCodegen,
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 049bce7428da..428f8c24efd7 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -558,6 +558,16 @@ def FoldConstant() -> tvm.ir.transform.Pass:
     return _ffi_api.FoldConstant()  # type: ignore
 
 
+def RemoveUnusedOutputs() -> tvm.ir.transform.Pass:
+    """Remove unused outputs from internal functions
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    return _ffi_api.RemoveUnusedOutputs()  # type: ignore
+
+
 def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
     """Annotate Op Pattern Kind for TIR functions
 
diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py
index 863c249975a7..1a0c3cea8e0b 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -367,6 +367,9 @@ def visit_if(self: Parser, node: doc.If) -> None:
 @dispatch.register(token="relax", type_name="enter_token")
 def enter_token(self: Parser) -> Dict[str, Any]:
     def relax_call(self, *args) -> Expr:
+
+        args = [convert_to_expr(arg) if isinstance(arg, tuple) else arg for arg in args]
+
         if all(isinstance(x, Expr) for x in args):
             return relax.Call(self, args)
         arg_types = [type(x) for x in args]
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 7c8b0e883c71..29c9463ba5c8 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1938,7 +1938,7 @@ def __init_subclass__(cls):
 
     @classmethod
     def _normalize_ir_module(cls, func):
-        if isinstance(func, tvm.tir.PrimFunc):
+        if isinstance(func, (tvm.tir.PrimFunc, tvm.IRModule)):
 
             def inner(self):
                 # pylint: disable=unused-argument
@@ -2042,8 +2042,7 @@ def inner(self):
 
     @staticmethod
     def _is_method(func):
-        sig = inspect.signature(func)
-        return "self" in sig.parameters
+        return callable(func) and "self" in inspect.signature(func).parameters
 
     def test_compare(self, before, expected, transform):
         """Unit test to compare the expected TIR PrimFunc to actual"""
diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc
new file mode 100644
index 000000000000..e3bf12382c67
--- /dev/null
+++ b/src/relax/transform/remove_unused_outputs.cc
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/utils.h>
+
+#include <algorithm>
+#include <unordered_map>
+#include <unordered_set>
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+template <typename T>
+using PSet = std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>;
+
+template <typename Key, typename Value>
+using PMap = std::unordered_map<Key, Value, ObjectPtrHash, ObjectPtrEqual>;
+
+class PartialTupleUsageCollector : ExprVisitor {
+ public:
+  static PMap<GlobalVar, std::vector<bool>> Collect(const IRModule& mod) {
+    PMap<GlobalVar, size_t> num_outputs;
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      bool is_exposed = base_func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+
+      if (!is_exposed) {
+        if (auto relax_func = base_func.as<relax::FunctionNode>()) {
+          if (auto out_tuple = relax_func->ret_struct_info.as<TupleStructInfoNode>()) {
+            num_outputs[gvar] = out_tuple->fields.size();
+          }
+        }
+      }
+    }
+
+    if (num_outputs.empty()) {
+      // Early bail-out if the module has no private functions that
+      // return tuples.
+      return {};
+    }
+
+    PartialTupleUsageCollector visitor(std::move(num_outputs));
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto func = base_func.as<relax::Function>()) {
+        visitor.VisitExpr(func.value());
+      }
+    }
+
+    PMap<GlobalVar, std::vector<bool>> to_update;
+    for (const auto& [gvar, mask] : visitor.output_usage_mask_) {
+      bool has_unused_output =
+          std::any_of(mask.begin(), mask.end(), [](const bool is_used) { return !is_used; });
+      if (has_unused_output) {
+        to_update[gvar] = mask;
+      }
+    }
+
+    return to_update;
+  }
+
+ private:
+  explicit PartialTupleUsageCollector(PMap<GlobalVar, size_t> num_outputs) {
+    for (const auto& [gvar, num_output] : num_outputs) {
+      output_usage_mask_[gvar] = std::vector<bool>(num_output, false);
+    }
+  }
+
+  void VisitBinding(const Binding& binding) override {
+    ExprVisitor::VisitBinding(binding);
+    known_bindings_.Set(binding->var, GetBoundValue(binding));
+  }
+
+  void VisitExpr_(const TupleGetItemNode* op) override {
+    Expr tuple = UnwrapBindings(op->tuple);
+
+    if (auto call = tuple.as<CallNode>()) {
+      if (auto opt_callee = call->op.as<GlobalVar>()) {
+        auto callee = opt_callee.value();
+        if (auto it = output_usage_mask_.find(callee); it != output_usage_mask_.end()) {
+          auto& used_indices = it->second;
+
+          CHECK_GE(op->index, 0) << "IndexError: "
+                                 << "Indices for TupleGetItem must be non-negative, "
+                                 << "but expression " << GetRef<Expr>(op)
+                                 << " uses a tuple index of " << op->index;
+          size_t index = op->index;
+
+          CHECK_LT(index, used_indices.size())
+              << "IndexError: "
+              << "Indices for TupleGetItem must be less than the size of the tuple, "
+              << "but expression " << GetRef<Expr>(op) << " uses a tuple index of " << op->index
+              << " for a tuple of size " << used_indices.size();
+          used_indices[index] = true;
+        }
+      }
+    }
+  }
+
+  Expr UnwrapBindings(Expr expr) const {
+    auto get_bound_value = [&](const Expr& expr) -> Optional<Expr> {
+      if (auto var = expr.as<Var>()) {
+        if (auto known_binding = known_bindings_.Get(var.value())) {
+          return known_binding.value();
+        }
+      }
+      return NullOpt;
+    };
+
+    while (auto unwrapped = get_bound_value(expr)) {
+      expr = unwrapped.value();
+    }
+    return expr;
+  }
+
+  Map<Var, Expr> known_bindings_;
+  PMap<GlobalVar, std::vector<bool>> output_usage_mask_;
+};
+
+Function UpdateCallee(Function func, const std::vector<bool>& usage_mask) {
+  auto old_func_sinfo = func->struct_info_.as<FuncStructInfoNode>();
+
+  auto old_ret_sinfo = func->ret_struct_info.as<TupleStructInfoNode>();
+  ICHECK(old_ret_sinfo) << "All functions returning non-tuple outputs "
+                        << "should have been pruned already by PartialTupleUsageCollector";
+
+  Array<Expr> outputs;
+
+  // This helper variable will be removed by the post-proc of
+  // CanonicalizeBindings and DeadCodeElimination.
+  Var previous_outputs("previous_outputs", func->ret_struct_info);
+
+  for (size_t i = 0; i < usage_mask.size(); i++) {
+    if (usage_mask[i]) {
+      outputs.push_back(TupleGetItem(previous_outputs, i));
+    }
+  }
+
+  Expr new_output = outputs.size() == 1 ? outputs[0] : Tuple(outputs);
+  StructInfo new_return_sinfo =
+      outputs.size() == 1 ? GetStructInfo(outputs[0]) : TupleStructInfo(outputs.Map(GetStructInfo));
+
+  VarBinding binding(previous_outputs, func->body);
+  BindingBlock binding_block({binding});
+  SeqExpr new_body({binding_block}, new_output);
+
+  auto old_sinfo = Downcast<FuncStructInfo>(func->struct_info_);
+  FuncStructInfo new_sinfo(old_func_sinfo->params.value(), new_return_sinfo,
+                           old_func_sinfo->purity);
+
+  auto write_ptr = func.CopyOnWrite();
+  write_ptr->struct_info_ = new_sinfo;
+  write_ptr->body = new_body;
+
+  return func;
+}
+
+class CallSiteMutator : public ExprMutator {
+ public:
+  explicit CallSiteMutator(PMap<GlobalVar, std::function<Expr(Call)>> callsite_updaters)
+      : callsite_updaters_(callsite_updaters) {}
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* op) override {
+    auto node = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+    if (auto gvar = node->op.as<GlobalVar>()) {
+      if (auto it = callsite_updaters_.find(gvar.value()); it != callsite_updaters_.end()) {
+        return it->second(node);
+      }
+    }
+
+    return node;
+  }
+
+  PMap<GlobalVar, std::function<Expr(Call)>> callsite_updaters_;
+};
+
+}  // namespace
+
+namespace transform {
+
+Pass RemoveUnusedOutputs() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) -> IRModule {
+    auto usage = PartialTupleUsageCollector::Collect(mod);
+
+    if (usage.empty()) {
+      // Early bail-out if there are no updates to make.
+      return mod;
+    }
+
+    PMap<GlobalVar, std::function<Expr(Call)>> callsite_updaters;
+
+    {
+      IRModule new_callees;
+
+      for (const auto& [gvar, base_func] : mod->functions) {
+        if (auto func = base_func.as<relax::Function>()) {
+          if (auto it = usage.find(gvar); it != usage.end()) {
+            const auto& usage_mask = it->second;
+            auto new_func = UpdateCallee(func.value(), usage_mask);
+
+            GlobalVar new_gvar(gvar->name_hint, new_func->checked_type_);
+            new_gvar->struct_info_ = new_func->struct_info_;
+            new_callees->Add(new_gvar, new_func);
+
+            callsite_updaters[gvar] = [old_gvar = gvar, new_gvar, usage_mask](Call call) -> Expr {
+              ICHECK(call->op.same_as(old_gvar)) << "InternalError: "
+                                                 << "Updater should be applied to " << old_gvar
+                                                 << ", but was applied to " << call->op;
+
+              auto old_call_sinfo = call->struct_info_.as<TupleStructInfoNode>();
+              ICHECK(old_call_sinfo)
+                  << "InternalError: "
+                  << "Updater should be applied to Call producing an output tuple, "
+                  << "but " << call << " has struct info " << call->struct_info_;
+              CHECK_EQ(usage_mask.size(), old_call_sinfo->fields.size())
+                  << "Function " << call->op << " produces " << usage_mask.size() << " outputs, "
+                  << "but " << call << " was used in a context expecting "
+                  << old_call_sinfo->fields.size() << " outputs.";
+
+              Call new_call(new_gvar, call->args);
+
+              int num_outputs_used = 0;
+              for (bool used : usage_mask) {
+                num_outputs_used += used;
+              }
+
+              Array<Expr> new_results;
+              int new_result_index = 0;
+              for (size_t i = 0; i < usage_mask.size(); i++) {
+                if (usage_mask[i]) {
+                  // This element of the old output tuple was used.  We replace
+                  // it either with access into the new output tuple, if the
+                  // callee still produces multiple outputs, or with the output
+                  // itself, if the callee has been reduced to producing a
+                  // single output.
+                  auto replacement = [&]() -> Expr {
+                    if (num_outputs_used == 1) {
+                      return new_call;
+                    } else {
+                      return TupleGetItem(new_call, new_result_index);
+                    }
+                  }();
+                  new_results.push_back(replacement);
+                  new_result_index++;
+                } else {
+                  // This element of the tuple was unused in the old output,
+                  // and is no longer generated from the modified callee.  We
+                  // could remember the index mapping and re-index any access
+                  // into the old tuple, but it's simpler to just let
+                  // CanonicalizeBindings and DCE handle it.
+                  new_results.push_back(
+                      relax::PrimValue(FloatImm(DataType::Float(64), std::nan(""))));
+                }
+              }
+
+              return Tuple(new_results);
+            };
+          }
+        }
+      }
+
+      auto write_ptr = mod.CopyOnWrite();
+      for (const auto& [gvar, callee] : new_callees->functions) {
+        write_ptr->Remove(write_ptr->GetGlobalVar(gvar->name_hint));
+        write_ptr->Add(gvar, callee);
+      }
+    }
+
+    CallSiteMutator mutator(std::move(callsite_updaters));
+
+    IRModule caller_updates;
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto func = base_func.as<relax::Function>()) {
+        auto mutated = Downcast<relax::Function>(mutator.VisitExpr(func.value()));
+        if (!mutated.same_as(base_func)) {
+          caller_updates->Add(gvar, mutated);
+        }
+      }
+    }
+
+    if (caller_updates->functions.size()) {
+      mod.CopyOnWrite()->Update(caller_updates);
+    }
+    return mod;
+  };
+  auto inner_pass = CreateModulePass(pass_func, 0, "RemoveUnusedOutputsInner", {});
+  return tvm::transform::Sequential(
+      {
+          inner_pass,
+          CanonicalizeBindings(),
+          DeadCodeElimination({}),
+      },
+      "RemoveUnusedOutputs");
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedOutputs").set_body_typed(RemoveUnusedOutputs);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_remove_unused_outputs.py b/tests/python/relax/test_transform_remove_unused_outputs.py
new file mode 100644
index 000000000000..c0405ca58d00
--- /dev/null
+++ b/tests/python/relax/test_transform_remove_unused_outputs.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I, relax as R, tir as T
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.relax.transform.RemoveUnusedOutputs()
+
+
+class TestSimple(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main():
+            args = Before.func()
+            return args[0]
+
+        @R.function(private=True)
+        def func() -> R.Tuple([R.Tensor, R.Tensor]):
+            A = R.zeros([16, 16], "int32")
+            B = R.ones([16, 16], "int32")
+            return (A, B)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main():
+            A = Expected.func()
+            return A
+
+        @R.function(private=True)
+        def func() -> R.Tensor([16, 16], "int32"):
+            A = R.zeros([16, 16], "int32")
+            return A
+
+
+class TestUseMultipleOutputs(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main():
+            args = Before.func()
+            return (args[0], args[2])
+
+        @R.function(private=True)
+        def func() -> R.Tuple([R.Tensor, R.Tensor, R.Tensor]):
+            A = R.zeros([16, 16], "int32")
+            B = R.ones([16, 16], "int32")
+            C = R.zeros([32, 32], "int32")
+            return (A, B, C)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main():
+            args = Expected.func()
+            return (args[0], args[1])
+
+        @R.function(private=True)
+        def func() -> R.Tuple([R.Tensor([16, 16], "int32"), R.Tensor([32, 32], "int32")]):
+            A = R.zeros([16, 16], "int32")
+            C = R.zeros([32, 32], "int32")
+            return (A, C)
+
+
+class TestMultipleCallSites(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main_a():
+            args = Before.func()
+            return args[0]
+
+        @R.function
+        def main_b():
+            args = Before.func()
+            return args[2]
+
+        @R.function(private=True)
+        def func() -> R.Tuple([R.Tensor, R.Tensor, R.Tensor]):
+            A = R.zeros([16, 16], "int32")
+            B = R.ones([16, 16], "int32")
+            C = R.zeros([32, 32], "int32")
+            return (A, B, C)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main_a():
+            args = Expected.func()
+            return args[0]
+
+        @R.function
+        def main_b():
+            args = Expected.func()
+            return args[1]
+
+        @R.function(private=True)
+        def func() -> R.Tuple([R.Tensor([16, 16], "int32"), R.Tensor([32, 32], "int32")]):
+            A = R.zeros([16, 16], "int32")
+            C = R.zeros([32, 32], "int32")
+            return (A, C)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
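
A minimal usage sketch of the new pass (the module below is illustrative only, mirroring the new test cases; it is not part of the diff):

    import tvm
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main():
            args = Module.func()
            return args[0]

        @R.function(private=True)
        def func() -> R.Tuple([R.Tensor, R.Tensor]):
            A = R.zeros([16, 16], "int32")
            B = R.ones([16, 16], "int32")
            return (A, B)

    # RemoveUnusedOutputs rewrites the private `func` to return only the
    # output that `main` actually uses; the bundled CanonicalizeBindings and
    # DeadCodeElimination passes then clean up the rewritten call site.
    mod = tvm.relax.transform.RemoveUnusedOutputs()(Module)
    mod.show()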