From e6e8c1df72c16a67df781e571ff2f45dd9c05ee6 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 14 Mar 2023 13:48:45 +0800 Subject: [PATCH] [IR] Enhance IRModule SEqual/SHash to support cross function calls As GlobalVars are defined under IRModule, we need to define it during the IRModule SEqual/SHash step via (`DefEqual` and `DefHash`). --- src/ir/module.cc | 68 ++++++++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/src/ir/module.cc b/src/ir/module.cc index 42ced9612045..7a973da29dfa 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -63,67 +63,79 @@ IRModule::IRModule(tvm::Map functions, } bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { - if (functions.size() != other->functions.size()) return false; if (!equal(this->attrs, other->attrs)) return false; - if (equal.IsPathTracingEnabled()) { - const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); - for (const auto& kv : this->functions) { - if (!other->ContainGlobalVar(kv.first->name_hint)) return false; + + if (functions.size() != other->functions.size()) return false; + // Update GlobalVar remap + for (const auto& gv : this->GetGlobalVars()) { + if (!other->ContainGlobalVar(gv->name_hint)) return false; + if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + } + // Checking functions + for (const auto& kv : this->functions) { + if (equal.IsPathTracingEnabled()) { + const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first), obj_path_pair->rhs_path->Attr("functions") ->MapValue(other->GetGlobalVar(kv.first->name_hint))}; if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false; + } else { + if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; } - if (type_definitions.size() != other->type_definitions.size()) return false; - for (const auto& kv : this->type_definitions) { - if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false; - ObjectPathPair type_def_paths = { + } + + if (type_definitions.size() != other->type_definitions.size()) return false; + // Update GlobalTypeVar remap + for (const auto& gtv : this->GetGlobalTypeVars()) { + if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false; + if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + } + // Checking type_definitions + for (const auto& kv : this->type_definitions) { + if (equal.IsPathTracingEnabled()) { + const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); + ObjectPathPair type_paths = { obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first), obj_path_pair->rhs_path->Attr("type_definitions") ->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))}; - if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_def_paths)) - return false; + if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_paths)) return false; + } else { + if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false; } - return true; - } - for (const auto& kv : this->functions) { - if (!other->ContainGlobalVar(kv.first->name_hint)) return false; - if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; - } - if (type_definitions.size() != other->type_definitions.size()) return false; - for (const auto& kv : this->type_definitions) { - if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false; - if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false; } return true; } void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { - using KV = std::pair; + using KV = std::tuple; // hash the functions. std::vector temp; auto reduce_temp = [&]() { // sort by the hash key of the keys. std::sort(temp.begin(), temp.end(), - [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); + [](const KV& lhs, const KV& rhs) { return std::get<0>(lhs) < std::get<0>(rhs); }); hash_reduce(static_cast(temp.size())); - // hash the content + // Defhash the GlobalVar/GlobalTypeVar + for (size_t i = 0; i < temp.size(); ++i) { + hash_reduce.DefHash(std::get<1>(temp[i])); + } + // hash the name and content for (size_t i = 0; i < temp.size(); ++i) { - hash_reduce(temp[i].first); - hash_reduce(temp[i].second); + hash_reduce(std::get<0>(temp[i])); + hash_reduce(std::get<2>(temp[i])); } }; for (const auto& kv : this->functions) { - temp.emplace_back(kv.first->name_hint, kv.second); + temp.emplace_back(kv.first->name_hint, kv.first, kv.second); } reduce_temp(); temp.clear(); for (const auto& kv : this->type_definitions) { - temp.emplace_back(kv.first->name_hint, kv.second); + temp.emplace_back(kv.first->name_hint, kv.first, kv.second); } reduce_temp(); hash_reduce(this->attrs);