[MetaSchedule] Consolidate module hashing and equality testing #13050
  ObjectPtr<WorkloadNode> n = runtime::make_object<WorkloadNode>();
- n->shash = tvm::StructuralHash()(mod);
  n->mod = mod;
+ n->shash = ModuleEquality::Create("structural")->Hash(mod);
Since this constructor is not used anywhere, I'm hard coding the instance here to avoid passing the mod_eq_name string. I want to remove this constructor if there is no objection.
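The `ModuleEquality::Create("structural")->Hash(mod)` call above suggests a name-keyed factory behind a single interface. A minimal, self-contained sketch of that pattern follows. It is not TVM's implementation: `Module` is a plain `std::string` standing in for `IRModule`, and `ToyModuleEquality`, `ToyStructural`, `CreateToyEquality`, and `StructuralHashOf` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>

// Stand-in for IRModule; an assumption made only for this sketch.
using Module = std::string;

// Hypothetical interface: one object owns both hashing and equality testing.
class ToyModuleEquality {
 public:
  virtual ~ToyModuleEquality() = default;
  virtual std::size_t Hash(const Module& mod) const = 0;
  virtual bool Equal(const Module& lhs, const Module& rhs) const = 0;
};

// "structural" variant: compares and hashes the whole module.
class ToyStructural : public ToyModuleEquality {
 public:
  std::size_t Hash(const Module& mod) const override { return std::hash<Module>()(mod); }
  bool Equal(const Module& lhs, const Module& rhs) const override { return lhs == rhs; }
};

// Name-keyed factory; unknown names are rejected rather than silently defaulted.
inline std::unique_ptr<ToyModuleEquality> CreateToyEquality(const std::string& name) {
  if (name == "structural") return std::make_unique<ToyStructural>();
  throw std::invalid_argument("unknown module equality: " + name);
}

// Mirrors the hard-coded call in the constructor: always hash with "structural".
inline std::size_t StructuralHashOf(const Module& mod) {
  std::unique_ptr<ToyModuleEquality> eq = CreateToyEquality("structural");
  assert(eq != nullptr);  // the factory either returns an instance or throws
  return eq->Hash(mod);
}
```

With this shape, callers that already own an equality object reuse it, while a convenience constructor can fall back to the hard-coded variant without taking a name argument.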
/*! \brief The default database implementation, which mimics two database tables with two files. */
class JSONDatabaseNode : public DatabaseNode {
 public:
  explicit JSONDatabaseNode(String mod_eq_name = "structural")
The default value is required, since we need to be able to default-construct JSONDatabaseNode and other database nodes due to the use of TVM_REGISTER_NODE_TYPE macro.
Is it possible to bypass it via some trick like a unique_ptr? (Sorry, I haven't thought about this very deeply, so feel free to object.)
I didn't get what you meant here.
I just noticed that DatabaseNode::mod_eq_ is already a unique_ptr, so we don't need an explicit default constructor DatabaseNode::DatabaseNode; instead we could use inline initialization:
- std::unique_ptr<ModuleEquality> mod_eq_;
+ std::unique_ptr<ModuleEquality> mod_eq_{nullptr};
BTW, just to nitpick, shall we change mod_eq_ from private to protected so that derived classes can access it?
Since JSONDatabaseNode is constructed by Database class,
tvm/src/meta_schedule/database/json_database.cc
Lines 147 to 150 in eecb7fd
Database Database::JSONDatabase(String path_workload, String path_tuning_record,
                                bool allow_missing) {
  int num_threads = std::thread::hardware_concurrency();
  ObjectPtr<JSONDatabaseNode> n = make_object<JSONDatabaseNode>();
and Database is not a subclass of DatabaseNode, setting mod_eq after JSONDatabaseNode is constructed would require making mod_eq a public member or introducing a setter method in DatabaseNode.
Moreover, it would also complicate the initialization of workloads2idx_ (see #13050 (comment)).
Just tried: Removing the construction of workloads2idx_, workloads2idx_(/*bucket_count*/ 0, WorkloadHash(), WorkloadEqual(GetModuleEquality())) leads to compilation failure since workloads2idx_ cannot be default-constructed (due to WorkloadEqual which requires mod_eq to be initialized).
I see. Then I'm fine with the current implementation
 public:
  explicit JSONDatabaseNode(String mod_eq_name = "structural")
      : DatabaseNode(mod_eq_name),
        workloads2idx_(/*bucket_count*/ 0, WorkloadHash(), WorkloadEqual(GetModuleEquality())) {}
Since WorkloadEqual needs to be aware of the module hashing / equality testing being used, we need to construct this unordered_map after the ModuleEquality instance is constructed in DatabaseNode by L72.
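The ordering constraint described here can be reproduced outside TVM. The sketch below assumes nothing about the real classes: `ToyEquality`, `ToyWorkloadHash`, `ToyWorkloadEqual`, `ToyDatabaseBase`, and `ToyJSONDatabase` are hypothetical stand-ins, and workloads are plain strings. It shows a comparator without a default constructor being passed to the `std::unordered_map` constructor, and relies on the C++ rule that a base class subobject is constructed before any member of the derived class.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>

// Hypothetical stand-in for the object that decides module equality.
struct ToyEquality {
  bool Equal(const std::string& a, const std::string& b) const { return a == b; }
};

struct ToyWorkloadHash {
  std::size_t operator()(const std::string& w) const { return std::hash<std::string>()(w); }
};

// The comparator carries a reference to the equality object, so it has no
// default constructor and the map cannot be default-constructed either.
struct ToyWorkloadEqual {
  explicit ToyWorkloadEqual(const ToyEquality& eq) : eq_(eq) {}
  bool operator()(const std::string& a, const std::string& b) const { return eq_.Equal(a, b); }
  const ToyEquality& eq_;
};
static_assert(!std::is_default_constructible<ToyWorkloadEqual>::value,
              "the comparator needs an equality object");

class ToyDatabaseBase {
 public:
  ToyDatabaseBase() : mod_eq_(std::make_unique<ToyEquality>()) {}
  const ToyEquality& GetModuleEquality() const {
    assert(mod_eq_ != nullptr);
    return *mod_eq_;
  }

 private:
  std::unique_ptr<ToyEquality> mod_eq_;
};

class ToyJSONDatabase : public ToyDatabaseBase {
 public:
  // The base subobject is fully constructed before any member is initialized,
  // so calling GetModuleEquality() in the member initializer is safe.
  ToyJSONDatabase()
      : ToyDatabaseBase(),
        workloads2idx_(/*bucket_count=*/0, ToyWorkloadHash(),
                       ToyWorkloadEqual(GetModuleEquality())) {}

  // Returns the index of the workload, inserting it if it is new.
  int Commit(const std::string& workload) {
    auto it = workloads2idx_.find(workload);
    if (it != workloads2idx_.end()) return it->second;
    int idx = static_cast<int>(workloads2idx_.size());
    workloads2idx_.emplace(workload, idx);
    return idx;
  }

 private:
  std::unordered_map<std::string, int, ToyWorkloadHash, ToyWorkloadEqual> workloads2idx_;
};
```

Committing the same workload twice returns the same index, because lookups go through the comparator that was bound at construction time.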
3c79ad5 to
b71ca1b
Compare
shash = tvm::StructuralHash()(mod);
String recalc_shash = SHash2Str(shash);
CHECK_EQ(recalc_shash, str_shash) << "ValueError: Structural hash changed. Given: " << str_shash
                                  << "; Recalculated: " << recalc_shash;
Since ModuleEquality information is not written to json (Workload doesn't have ModuleEquality instance), we cannot determine the appropriate hashing method here. So I had to remove this verification.
Perhaps we could leave a TODO here in case of future changes
I think the removal makes sense. On the other hand, we are always using the structural hash as the shash, so I think we may not really need the ability to specify the shash directly in the constructor.
- What about constructing it from the mod and the mod_eq_name string, so that we can obtain the customized hash result?
- Maybe we can store the mod_eq_name in Workload, so that when we parse it we can still check the shash results?
What do you think?
To keep things clean, I want to minimize the number of unique ModuleEquality instances scattered throughout MS infra. Currently, only task extraction and Database "own" a ModuleEquality instance.
Given this goal, I think having Database compute the hash of a Workload using its ModuleEquality is more natural than passing mod_eq_name to Workload and letting it compute the hash. Moreover, every Workload would end up with the same mod_eq_name, so it feels redundant.
One way to restore hash verification is to do verification inside DatabaseNode after Workload::FromJSON is called, since the database does know what the ModuleEquality instance is supposed to be.
Restored hash verification in JSONDatabase, the only place where Workload::FromJSON is used currently.
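The idea of verifying at the database level, after parsing, can be sketched as follows. This is an illustration under stated assumptions, not the code from this PR: `ToyRecord`, `ToyEquality`, `HashToStr`, `VerifyOnLoad`, and `MakeRecord` are hypothetical names, the module is a plain string, and the stored hash is a decimal string.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <string>

// Toy record as it might look after parsing one stored entry:
// the hash string that was written out, plus the module text.
struct ToyRecord {
  std::string str_shash;
  std::string mod;
};

// Hypothetical equality object owned by the database.
struct ToyEquality {
  std::size_t Hash(const std::string& mod) const { return std::hash<std::string>()(mod); }
};

inline std::string HashToStr(std::size_t h) { return std::to_string(h); }

// Recompute the hash with the database's own equality object and compare it
// with the stored string. The parser stays unaware of the hashing method.
inline std::size_t VerifyOnLoad(const ToyEquality& mod_eq, const ToyRecord& rec) {
  std::size_t recalc = mod_eq.Hash(rec.mod);
  if (HashToStr(recalc) != rec.str_shash) {
    throw std::runtime_error("ValueError: hash changed. Given: " + rec.str_shash +
                             "; Recalculated: " + HashToStr(recalc));
  }
  return recalc;
}

// Build the record a writer would store for `mod`.
inline ToyRecord MakeRecord(const ToyEquality& mod_eq, const std::string& mod) {
  ToyRecord rec{HashToStr(mod_eq.Hash(mod)), mod};
  assert(VerifyOnLoad(mod_eq, rec) == mod_eq.Hash(mod));  // a fresh record must verify
  return rec;
}
```

Keeping the check next to the owner of the equality object means the stored record does not need to carry the name of the hashing method.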
b71ca1b to
0c8ee7f
Compare
junrushao
left a comment
Otherwise LGTM!
zxybazh
left a comment
Thanks for sending the new feature! It's very exciting to see such improvements!
Generally it's looking good. I have some minor suggestions on the Workload usage and a question about the hash map of workloads.
Other than these, the module equality functions can currently only be defined on the C++ side. Do you think we could also allow easier customization of the hash and equal functions, by providing FFI functions or by overriding the class from the Python side (i.e., providing a PyModuleEquality class)?
@zxybazh This can be considered in future, but I don't expect that people would want to add a custom |
3058a68 to
95c01b7
Compare
zxybazh
left a comment
Thanks @masahi for addressing my questions and adding the verification. The PR looks good to me! Looking forward to the introduction of new module equality functions!
junrushao
left a comment
LGTM
…ng NDArray raw data (#13091) A follow up to #13050, also builds on #13001. This PR enables the functionality in #12706 without changing the existing `StructuralEqual/Hash`. A question for discussion: Should this be the default ModuleEquality used by MS? It has no effect for the `link-params = False` case, and it simplifies the MS tuning API usage for the `link-params = True` case (Hexagon etc).
Currently, MS uses `StructuralEqual/Hash` in task extraction / evo search / database. Sometimes, we want to use different hashing and equality testing methods, for example (1) to ignore NDArray (#12706) or (2) to enable anchor-op only tuning (identify `conv2d` and `conv2d -> add` subgraphs as equal).

To enable such flexibility, this PR consolidates raw calls to `StructuralEqual/Hash` into one place, which for now is named `ModuleEquality`. Since hashing is also done for equality testing, I think it is appropriate to call the component responsible for hashing / equality testing that way. But other suggestions are welcome.

Importantly, task extraction and database are now using the same hashing / equality method based on the TIR mod, while previously task extraction was using a cache keyed on the relay mod.
cc @junrushao @zxybazh @tqchen
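To make motivation (1) concrete, here is a small sketch of two interchangeable equality variants, one that looks at constant data and one that ignores it. Everything in it is an assumption for illustration: `ToyModule` (op names plus float weights) stands in for a TIR module, and `ToyEquality`, `FullEquality`, `IgnoreWeightsEquality`, `Combine`, and `SameKey` are hypothetical names, not TVM APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Toy module: op names plus constant data. A stand-in, not IRModule.
struct ToyModule {
  std::vector<std::string> ops;
  std::vector<float> weights;
};

// Simple hash mixing step (boost-style hash_combine).
inline std::size_t Combine(std::size_t seed, std::size_t h) {
  return seed ^ (h + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

struct ToyEquality {
  virtual ~ToyEquality() = default;
  virtual std::size_t Hash(const ToyModule& m) const = 0;
  virtual bool Equal(const ToyModule& a, const ToyModule& b) const = 0;
};

// Compares ops and constant data.
struct FullEquality : ToyEquality {
  std::size_t Hash(const ToyModule& m) const override {
    std::size_t h = 0;
    for (const auto& op : m.ops) h = Combine(h, std::hash<std::string>()(op));
    for (float w : m.weights) h = Combine(h, std::hash<float>()(w));
    return h;
  }
  bool Equal(const ToyModule& a, const ToyModule& b) const override {
    return a.ops == b.ops && a.weights == b.weights;
  }
};

// Ignores constant data: modules with the same ops share one key.
struct IgnoreWeightsEquality : ToyEquality {
  std::size_t Hash(const ToyModule& m) const override {
    std::size_t h = 0;
    for (const auto& op : m.ops) h = Combine(h, std::hash<std::string>()(op));
    return h;
  }
  bool Equal(const ToyModule& a, const ToyModule& b) const override { return a.ops == b.ops; }
};

// Contract every variant must keep: modules that compare equal hash equally.
inline bool SameKey(const ToyEquality& eq, const ToyModule& a, const ToyModule& b) {
  bool equal = eq.Equal(a, b);
  assert(!equal || eq.Hash(a) == eq.Hash(b));
  return equal;
}
```

Two modules that differ only in their weights are distinct under `FullEquality` but collapse to one key under `IgnoreWeightsEquality`, which is the kind of behavior switch a single consolidated component makes possible.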