-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[RELAY] Add structural hashing for Relay #1977
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Edit: Discussed with @jroesch, it seems that being able to check hash map membership by structural equality is exactly the point here. It might be helpful to have comments indicating this |
src/relay/ir/hash.cc
Outdated
| } | ||
|
|
||
| size_t VisitExpr_(const VarNode* var) final { | ||
| return std::hash<int>()(var_map_[GetRef<Var>(var)]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use .at()
there is the other case (hashing term with free var), it should be error with at() or just hash by ptr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a matter of fact, when VisitExpr is called, this means var is not in the hash_map_, it is a free var, we could simply hash by type_annotation and name hint to be safe.
src/relay/ir/hash.cc
Outdated
|
|
||
| using AttrsHashHandler::VisitAttr_; | ||
| size_t VisitAttr_(const Variable* lhs) final { | ||
| return 0; // return LeafNodeEqual(GetRef<NodeRef>(lhs), other); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tqchen I don't fully understand this case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are two possible cases in here:
- Variable is a var defined in say TypeParams, in this case, then we will need to put the hash value at the declaration site(likely you can put everything under the same hash_map_ and reuse it for both Attr and Expr
- Variable is a free variable, in that case, depending on the equality semantics(graph equal vs alpha equal)
- In the case of alpha_equal, free variables do not map to each other, and hash by the pointer is a good choice
- In the case of graph_equal, free variables might match each other, maybe we could hash by name_hint
- By considering both cases, hash by name_hint might not be a bad choice
src/relay/ir/hash.cc
Outdated
| } | ||
|
|
||
| size_t VisitExpr_(const VarNode* var) final { | ||
| return std::hash<int>()(var_map_[GetRef<Var>(var)]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a matter of fact, when VisitExpr is called, this means var is not in the hash_map_, it is a free var, we could simply hash by type_annotation and name hint to be safe.
src/relay/ir/hash.cc
Outdated
|
|
||
| private: | ||
| // whether to map open terms. | ||
| bool map_free_var_{false}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
map_free_var has less of a use in here, in alpha equality case, it means if we would like to also call BindVar when we meet free variables
src/relay/ir/hash.cc
Outdated
| bool map_free_var_{false}; | ||
| // renaming of NodeRef to indicate two nodes equals to each other | ||
| std::unordered_map<NodeRef, size_t, NodeHash, NodeEqual> hash_map_; | ||
| std::unordered_map<NodeRef, int, NodeHash, NodeEqual> var_map_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
likely we can collapse var_map into hash_map
src/relay/ir/hash.cc
Outdated
| size_t VisitType_(const FuncTypeNode* func_type) final { | ||
| size_t hash = std::hash<std::string>()(func_type->_type_key); | ||
| for (auto type_param : func_type->type_params) { | ||
| hash = Combine(hash, TypeHash(type_param)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a declaration side, need to record the hash value of type_param into the hash_map_
src/relay/ir/hash.cc
Outdated
| } | ||
|
|
||
| size_t VisitExpr_(const GlobalVarNode* global) final { | ||
| return GetRef<GlobalVar>(global).hash(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consider hash by global->name_hint instead (so two environment might still be able to match)
python/tvm/relay/ir_pass.py
Outdated
| """ | ||
| return bool(_make._graph_equal(lhs, rhs)) | ||
|
|
||
| def expr_hash(expr): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we need a good name for this, expr_hash is a bit generic. Maybe structural_hash/? Most users get used to same function name being overloaded for types, so we could use it for both type and Expr
include/tvm/relay/pass.h
Outdated
| * | ||
| * \return the hash value. | ||
| */ | ||
| size_t HashType(const Type& type); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can have an overloaded name for both Type and Expr, StructuralHash?
src/relay/ir/hash.cc
Outdated
| } | ||
|
|
||
| size_t VisitType_(const IncompleteTypeNode* incomplete) final { | ||
| return GetRef<IncompleteType>(incomplete); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
conversion of pointer to size_t? The simplest approach to make this work for both graph and renaming is to hash Kind(ignore pointer for now)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Likely, we can also directly use BindVar(Incomplete) and combine it with kind and key
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is just a bug, was intending to call hash, will do that.
src/relay/ir/hash.cc
Outdated
| } | ||
|
|
||
| size_t VisitType_(const TypeVarNode* tyvar) final { | ||
| int index = BindVar(GetRef<TypeVar>(tyvar)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since both TypeVar/Var/Variable have two possible ways of occurrence:
- The point of declaration
- The point of visit
There are certain implications in this function that is not necessarily apparent from the code. Because Visit already checks the hashmap, this means the TypeVar itself is unbound. And we hash it by the free variable index, which implies graph equality.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider leave a comment here about its implications
src/relay/ir/hash.cc
Outdated
|
|
||
| size_t BindVar(const NodeRef& var) { | ||
| size_t hash = std::hash<int>()(var_counter++); | ||
| CHECK(hash_map_.find(var) == hash_map_.end()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hash_map_.count(var) == 0
src/relay/ir/hash.cc
Outdated
| return it->second; | ||
| } | ||
|
|
||
| return std::hash<std::string>()(var->name_hint); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively, we can use BindVar here as well, and combine it with name_hint and key
src/relay/ir/hash.cc
Outdated
| } | ||
|
|
||
| size_t VisitType_(const TypeVarNode* tyvar) final { | ||
| int index = BindVar(GetRef<TypeVar>(tyvar)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shold be size_t
src/relay/ir/hash.cc
Outdated
|
|
||
| size_t VisitType_(const TypeVarNode* tyvar) final { | ||
| int index = BindVar(GetRef<TypeVar>(tyvar)); | ||
| return std::hash<int>()(index); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to hash again, as BindVar already hashes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah this was small hold over from previous version. Fixed in commit.
Add an implementation of structural hashing for Relay.
This PR also extends the alpha equality tests to ensure the hash matches when equal, and does not match when not equal.
cc @MarisaKirisame @tqchen @slyubomirsky
Should be good to go.