Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 40 additions & 28 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,67 +63,79 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
}

bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
if (functions.size() != other->functions.size()) return false;
if (!equal(this->attrs, other->attrs)) return false;
if (equal.IsPathTracingEnabled()) {
const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
for (const auto& kv : this->functions) {
if (!other->ContainGlobalVar(kv.first->name_hint)) return false;

if (functions.size() != other->functions.size()) return false;
// Update GlobalVar remap
for (const auto& gv : this->GetGlobalVars()) {
if (!other->ContainGlobalVar(gv->name_hint)) return false;
if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false;
}
// Checking functions
for (const auto& kv : this->functions) {
if (equal.IsPathTracingEnabled()) {
const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first),
obj_path_pair->rhs_path->Attr("functions")
->MapValue(other->GetGlobalVar(kv.first->name_hint))};
if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false;
} else {
if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
}
if (type_definitions.size() != other->type_definitions.size()) return false;
for (const auto& kv : this->type_definitions) {
if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
ObjectPathPair type_def_paths = {
}

if (type_definitions.size() != other->type_definitions.size()) return false;
// Update GlobalTypeVar remap
for (const auto& gtv : this->GetGlobalTypeVars()) {
if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false;
if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false;
}
// Checking type_definitions
for (const auto& kv : this->type_definitions) {
if (equal.IsPathTracingEnabled()) {
const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
ObjectPathPair type_paths = {
obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first),
obj_path_pair->rhs_path->Attr("type_definitions")
->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))};
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_def_paths))
return false;
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_paths)) return false;
} else {
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
}
return true;
}
for (const auto& kv : this->functions) {
if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
}
if (type_definitions.size() != other->type_definitions.size()) return false;
for (const auto& kv : this->type_definitions) {
if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
}
return true;
}

void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
using KV = std::pair<std::string, ObjectRef>;
using KV = std::tuple<std::string, ObjectRef, ObjectRef>;
// hash the functions.
std::vector<KV> temp;

auto reduce_temp = [&]() {
// sort by the hash key of the keys.
std::sort(temp.begin(), temp.end(),
[](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; });
[](const KV& lhs, const KV& rhs) { return std::get<0>(lhs) < std::get<0>(rhs); });

hash_reduce(static_cast<uint64_t>(temp.size()));
// hash the content
// Defhash the GlobalVar/GlobalTypeVar
for (size_t i = 0; i < temp.size(); ++i) {
hash_reduce.DefHash(std::get<1>(temp[i]));
}
// hash the name and content
for (size_t i = 0; i < temp.size(); ++i) {
hash_reduce(temp[i].first);
hash_reduce(temp[i].second);
hash_reduce(std::get<0>(temp[i]));
hash_reduce(std::get<2>(temp[i]));
}
};

for (const auto& kv : this->functions) {
temp.emplace_back(kv.first->name_hint, kv.second);
temp.emplace_back(kv.first->name_hint, kv.first, kv.second);
}
reduce_temp();

temp.clear();
for (const auto& kv : this->type_definitions) {
temp.emplace_back(kv.first->name_hint, kv.second);
temp.emplace_back(kv.first->name_hint, kv.first, kv.second);
}
reduce_temp();
hash_reduce(this->attrs);
Expand Down