From 7114a7eb8fbd6914c6afa12d8984e87495f3ba77 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 15 Apr 2024 16:31:24 +0800 Subject: [PATCH] [Relax] Stabilize relax pass mutation order The current implementation of the relax pass is not stable, to be more specific, the order of the mutation is not stable. This PR aims to stabilize the mutation order of the relax pass, and further stabilize the output of the relax pass. Also fixes a minor doc typo in NN frontend --- include/tvm/ir/module.h | 3 ++- python/tvm/relax/frontend/nn/core.py | 6 +++--- src/ir/module.cc | 4 ++++ src/relax/transform/alter_op_impl.cc | 3 ++- src/relax/transform/dead_code_elimination.cc | 3 ++- src/relax/transform/fuse_ops.cc | 22 +++++++++++--------- src/relax/transform/fuse_tir.cc | 3 ++- src/relax/transform/legalize_ops.cc | 8 ++++--- 8 files changed, 32 insertions(+), 20 deletions(-) diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 2a5412a5671f..8fd87a6304dd 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -249,7 +249,8 @@ class IRModuleNode : public Object { TVM_DLL GlobalVar GetGlobalVar(const String& str) const; /*! - * \brief Collect all global vars defined in this module. + * \brief Collect all global vars defined in this module, ordered by + * the global variable name. * \returns An array of global vars */ TVM_DLL Array GetGlobalVars() const; diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index b7b3f411ed41..4953c1c81701 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -475,10 +475,10 @@ def export_tvm( ------- irmodule : tvm.ir.IRModule The converted tvm IR representation of the model. - params : Dict[str, tvm.nd.array] - A dictionary of parameters corresponding to the weights of - the model. + params : List[Tuple[str, Parameter]] + A list of Parameters corresponding to the weights of the model. ext_mods : List[nn.ExternModule] + A list of ExternModules that are used in the model. """ # pylint: disable=import-outside-toplevel from . import spec as _spec diff --git a/src/ir/module.cc b/src/ir/module.cc index 2e60441e94d3..261fbfe087c6 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -183,6 +184,9 @@ tvm::Array IRModuleNode::GetGlobalVars() const { for (const auto& pair : global_var_map_) { global_vars.push_back(pair.second); } + std::sort(global_vars.begin(), global_vars.end(), [](const GlobalVar& lhs, const GlobalVar& rhs) { + return lhs->name_hint < rhs->name_hint; + }); return tvm::Array(global_vars); } diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index 8b5518212cc8..2cb226d56e27 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -89,7 +89,8 @@ class AlterOpImplMutator : public ExprMutator { op_buffer_axis_separators__(axis_separators_) {} IRModule Run() { - for (const auto& [gv, func] : mod_->functions) { + for (const auto& gv : mod_->GetGlobalVars()) { + const auto& func = mod_->Lookup(gv); if (func->IsInstance()) { relax::Function update_func = Downcast(VisitExpr(func)); builder_->UpdateFunction(gv, update_func); diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index 28c7d74ef8d0..876c714c61e3 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -148,7 +148,8 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, Array ent for (const auto& name : entry_function_names) { entry_functions.insert(mod->GetGlobalVar(name)); } - for (const auto& [gv, func] : mod->functions) { + for (const auto& gv : mod->GetGlobalVars()) { + const auto& func = mod->Lookup(gv); if (func.as() || func->GetLinkageType() == LinkageType::kExternal) { entry_functions.insert(gv); } diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index a2a3e96dd567..3e762778d849 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -691,7 +691,8 @@ class OperatorFusor : public ExprMutator { * \return The new IRModule after transformation */ IRModule Transform() { - for (const auto& [gv, func] : mod_->functions) { + for (const auto& gv : mod_->GetGlobalVars()) { + const auto& func = mod_->Lookup(gv); // Only visit Relax function without attr kPrimitive. if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive)) { auto updated_func = Downcast(VisitExpr(func)); @@ -1196,9 +1197,9 @@ class CompositeFunctionAnnotator : public ExprMutator { IRModule Run() { auto mod = builder_->GetContextIRModule(); - auto all_functions = mod->functions; - for (const auto& entry : all_functions) { - if (const auto* func = entry.second.as()) { + for (const auto& gv : mod->GetGlobalVars()) { + const auto& base_func = mod->Lookup(gv); + if (const auto* func = base_func.as()) { if (func->GetAttr(attr::kComposite).defined() || func->GetAttr(attr::kCodegen).defined()) { continue; @@ -1208,7 +1209,7 @@ class CompositeFunctionAnnotator : public ExprMutator { if (!new_body.same_as(func->body)) { auto new_func = Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs, func->span); - builder_->UpdateFunction(entry.first, new_func); + builder_->UpdateFunction(gv, new_func); } } } @@ -1272,11 +1273,12 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, support::Arena arena; for (const auto& pattern : patterns) { OperatorFusor::GroupMap group_map; - for (const auto& entry : mod->functions) { - if (entry.second->IsInstance()) { + for (const auto& gv : mod->GetGlobalVars()) { + const auto& base_func = mod->Lookup(gv); + if (base_func->IsInstance()) { continue; } - const FunctionNode* function = entry.second.as(); + const FunctionNode* function = base_func.as(); if (function->GetAttr(attr::kPrimitive).defined() || function->GetAttr(attr::kComposite).defined() || function->GetAttr(attr::kCodegen).defined()) { @@ -1285,8 +1287,8 @@ IRModule FuseOpsByPattern(const tvm::Array& patterns, auto map = PatternBasedPartitioner::Run(pattern->name, pattern->pattern, pattern->annotation_patterns, - pattern->check.value_or(nullptr), entry.second, - &arena, pattern->attrs_getter.value_or(nullptr)); + pattern->check.value_or(nullptr), base_func, &arena, + pattern->attrs_getter.value_or(nullptr)); for (const auto& [key, value] : map) { CHECK(!group_map.count(key)) << "ValueError: " diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 11785ab73ac6..3df17b29ca52 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -964,7 +964,8 @@ class TIRFuseMutator : public ExprMutator { static IRModule Transform(IRModule mod) { // Collect all primitive relax functions Map primitive_relax; - for (const auto& [gvar, base_func] : mod->functions) { + for (const auto& gvar : mod->GetGlobalVars()) { + const auto& base_func = mod->Lookup(gvar); // Only fuse primitive relax functions if (base_func->HasNonzeroAttr(attr::kPrimitive)) { if (auto func = base_func.as()) { diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 343c18acd7a9..e2e463ff2b2f 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -67,16 +67,18 @@ class LegalizeMutator : public ExprMutator { } IRModule Transform() { - for (const auto& [gv, func] : mod_->functions) { + for (const auto& gv : mod_->GetGlobalVars()) { + const auto& func = mod_->Lookup(gv); if (func->IsInstance()) { auto updated_func = Downcast(this->VisitExpr(func)); builder_->UpdateFunction(gv, Downcast(updated_func)); } } // Fill the "kTarget" attribute of PrimFunc - for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) { + const auto& mod = builder_->GetContextIRModule(); + for (const auto& gv : mod->GetGlobalVars()) { const tir::PrimFuncNode* prim_func; - if (tmap_.count(gv) && (prim_func = func.as())) { + if (tmap_.count(gv) && (prim_func = mod->Lookup(gv).as())) { auto f = WithAttr(GetRef(prim_func), tvm::attr::kTarget, tmap_[gv]); builder_->UpdateFunction(gv, f); }