diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index b043765a6990..8c6417d19d14 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -275,6 +275,12 @@ TVM_DLL Pass LiftTransformParams(); */ TVM_DLL Pass UpdateVDevice(VDevice new_vdevice, int64_t index); +/*! \brief Expand tuple arguments to internal functions + * + * \return The Pass + */ +TVM_DLL Pass ExpandTupleArguments(); + /*! \brief Remove unused outputs from internal functions * * \return The Pass diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 0ce0ebba1105..b6887160c8b5 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -33,6 +33,7 @@ DecomposeOpsForInference, DecomposeOpsForTraining, EliminateCommonSubexpr, + ExpandTupleArguments, FewShotTuning, FoldConstant, FunctionPass, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 428f8c24efd7..1beb535f0bb7 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -558,6 +558,16 @@ def FoldConstant() -> tvm.ir.transform.Pass: return _ffi_api.FoldConstant() # type: ignore +def ExpandTupleArguments() -> tvm.ir.transform.Pass: + """Expand tuple arguments to internal functions + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.ExpandTupleArguments() # type: ignore + + def RemoveUnusedOutputs() -> tvm.ir.transform.Pass: """Remove unused outputs from internal functions diff --git a/src/relax/transform/expand_tuple_arguments.cc b/src/relax/transform/expand_tuple_arguments.cc new file mode 100644 index 000000000000..c61832bbab52 --- /dev/null +++ b/src/relax/transform/expand_tuple_arguments.cc @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +namespace { + +template +using PMap = std::unordered_map; + +Optional ExpandParams(Function func) { + bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).defined(); + if (is_exposed) return NullOpt; + + bool has_tuple_param = std::any_of( + func->params.begin(), func->params.end(), + [](const Var& param) -> bool { return param->struct_info_.as(); }); + + if (!has_tuple_param) return NullOpt; + + Array params; + Array bindings; + + std::function expand_param = [&](const Var& param) { + if (auto sinfo = param->struct_info_.as()) { + Array internal_tuple; + for (size_t i = 0; i < sinfo->fields.size(); i++) { + auto name = static_cast(std::stringstream() + << param->name_hint() << "_" << i) + .str(); + Var new_param(name, sinfo->fields[i]); + internal_tuple.push_back(new_param); + expand_param(new_param); + } + bindings.push_back(VarBinding(param, Tuple(internal_tuple))); + } else { + params.push_back(param); + } + }; + + for (const auto& param : func->params) { + expand_param(param); + } + + FuncStructInfo new_sinfo(params.Map([](const auto& var) { return GetStructInfo(var); }), + func->ret_struct_info, + Downcast(func->struct_info_)->purity); + + auto write_ptr = func.CopyOnWrite(); + write_ptr->params = params; + write_ptr->body = SeqExpr({BindingBlock(bindings)}, func->body); + write_ptr->struct_info_ = new_sinfo; + + return func; +} + +class TupleExpander : public ExprMutator { + public: + explicit TupleExpander(PMap callees) : replacements_(callees) {} + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* op) override { + auto node = Downcast(ExprMutator::VisitExpr_(op)); + + if (auto gvar = node->op.as()) { + if (auto it = replacements_.find(gvar.value()); it != replacements_.end()) { + Array new_args; + + std::function expand_arg = [&](const Expr& arg) { + if (auto sinfo = arg->struct_info_.as()) { + for (size_t i = 0; i < sinfo->fields.size(); i++) { + expand_arg(TupleGetItem(arg, i)); + } + } else { + new_args.push_back(arg); + } + }; + + for (const auto& arg : node->args) { + expand_arg(arg); + } + + auto write_ptr = node.CopyOnWrite(); + write_ptr->op = it->second; + write_ptr->args = new_args; + } + } + + return node; + } + + PMap replacements_; +}; + +} // namespace + +namespace transform { + +Pass ExpandTupleArguments() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) -> IRModule { + PMap gvar_replacements; + + { + PMap new_callees; + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + if (auto opt = ExpandParams(func.value())) { + auto new_func = opt.value(); + GlobalVar new_gvar(gvar->name_hint, new_func->checked_type_); + new_gvar->struct_info_ = new_func->struct_info_; + gvar_replacements[gvar] = new_gvar; + new_callees[new_gvar] = new_func; + } + } + } + + if (gvar_replacements.empty()) { + return mod; + } + auto write_ptr = mod.CopyOnWrite(); + for (auto [old_gvar, new_gvar] : gvar_replacements) { + write_ptr->Remove(old_gvar); + write_ptr->Add(new_gvar, new_callees.at(new_gvar)); + } + } + + TupleExpander mutator(std::move(gvar_replacements)); + + IRModule caller_updates; + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + auto mutated = Downcast(mutator.VisitExpr(func.value())); + if (!mutated.same_as(base_func)) { + caller_updates->Add(gvar, mutated); + } + } + } + + if (caller_updates->functions.size()) { + mod.CopyOnWrite()->Update(caller_updates); + } + return mod; + }; + auto inner_pass = CreateModulePass(pass_func, 0, "ExpandTupleArgumentsInner", {}); + + return tvm::transform::Sequential( + { + inner_pass, + CanonicalizeBindings(), + DeadCodeElimination({}), + }, + "ExpandTupleArguments"); +} + +TVM_REGISTER_GLOBAL("relax.transform.ExpandTupleArguments").set_body_typed(ExpandTupleArguments); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_expand_tuple_args.py b/tests/python/relax/test_transform_expand_tuple_args.py new file mode 100644 index 000000000000..a90db1d84d47 --- /dev/null +++ b/tests/python/relax/test_transform_expand_tuple_args.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.script import ir as I, relax as R, tir as T + + +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.relax.transform.ExpandTupleArguments() + + +class TestSimple(BaseCompare): + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor, B: R.Tensor): + return Before.func((A, B)) + + @R.function(private=True) + def func(args: R.Tuple([R.Tensor, R.Tensor])) -> R.Tensor: + return args[0] + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor, B: R.Tensor): + return Expected.func(A, B) + + @R.function(private=True) + def func(A: R.Tensor, B: R.Tensor) -> R.Tensor: + return A + + +class TestNested(BaseCompare): + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor, B: R.Tensor, C: R.Tensor, D: R.Tensor) -> R.Tensor: + return Before.func(((A, B), (C, D))) + + @R.function(private=True) + def func( + args: R.Tuple( + [ + R.Tuple([R.Tensor, R.Tensor]), + R.Tuple([R.Tensor, R.Tensor]), + ] + ) + ) -> R.Tensor: + return args[0][1] + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor, B: R.Tensor, C: R.Tensor, D: R.Tensor) -> R.Tensor: + return Expected.func(A, B, C, D) + + @R.function(private=True) + def func(A: R.Tensor, B: R.Tensor, C: R.Tensor, D: R.Tensor) -> R.Tensor: + return B + + +if __name__ == "__main__": + tvm.testing.main()