-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[MetaSchedule][M3a] TuneContext #9053
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fcb9cd4
72738ab
ae58a95
c261133
66eb668
caa7b87
7754e05
9bc31d3
1d32969
f60a6fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """Meta Schedule tuning context.""" | ||
|
|
||
| from typing import Optional | ||
|
|
||
| from tvm import IRModule | ||
| from tvm.runtime import Object | ||
| from tvm.target import Target | ||
| from tvm.meta_schedule.utils import cpu_count | ||
| from tvm._ffi import register_object | ||
|
|
||
| from . import _ffi_api | ||
|
|
||
|
|
||
| @register_object("meta_schedule.TuneContext") | ||
| class TuneContext(Object): | ||
| """ | ||
| The tune context class is designed to contain all resources for a tuning task. | ||
|
|
||
| Different tuning tasks are separated in different TuneContext classes, but different classes in | ||
| the same task can interact with each other through tune context. Most classes have a function | ||
| to initialize with a tune context. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| mod : Optional[IRModule] = None | ||
| The workload to be optimized. | ||
| target : Optional[Target] = None | ||
| The target to be optimized for. | ||
| task_name : Optional[str] = None | ||
| The name of the tuning task. | ||
| rand_state : int = -1 | ||
| The random state. | ||
| Need to be in integer in [1, 2^31-1], -1 means using random number. | ||
| num_threads : int = None | ||
| The number of threads to be used, None means using the logical cpu count. | ||
|
|
||
| Note | ||
| ---- | ||
| In most cases, mod and target should be available in the tuning context. They are "Optional" | ||
| because we allow the user to customize the tuning context, along with other classes, sometimes | ||
| without mod and target. E.g., we can have a stand alone search strategy that generates measure | ||
| candidates without initializing with the tune context. | ||
| """ | ||
|
|
||
| mod: Optional[IRModule] | ||
| target: Optional[Target] | ||
| task_name: Optional[str] | ||
| rand_state: int | ||
| num_threads: int | ||
|
|
||
| def __init__( | ||
| self, | ||
| mod: Optional[IRModule] = None, | ||
| target: Optional[Target] = None, | ||
| task_name: Optional[str] = None, | ||
| rand_state: int = -1, | ||
| num_threads: Optional[int] = None, | ||
| ): | ||
| """Constructor. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| mod : Optional[IRModule] = None | ||
| The workload to be optimized. | ||
| target : Optional[Target] = None | ||
| The target to be optimized for. | ||
| task_name : Optional[str] = None | ||
| The name of the tuning task. | ||
| rand_state : int = -1 | ||
| The random state. | ||
| Need to be in integer in [1, 2^31-1], -1 means using random number. | ||
| num_threads : Optional[int] = None | ||
| The number of threads to be used, None means using the logical cpu count. | ||
| """ | ||
| if num_threads is None: | ||
| num_threads = cpu_count() | ||
|
|
||
| self.__init_handle_by_constructor__( | ||
| _ffi_api.TuneContext, # type: ignore # pylint: disable=no-member | ||
| mod, | ||
| target, | ||
| task_name, | ||
| rand_state, | ||
| num_threads, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
| #include "./tune_context.h" | ||
|
|
||
| #include <random> | ||
| #include <utility> | ||
|
|
||
| namespace tvm { | ||
| namespace meta_schedule { | ||
|
|
||
| /*! | ||
| * \brief Constructor function of TuneContext class. | ||
| * \param mod The mod to be optimized. | ||
| * \param target The target to be optimized for. | ||
| * \param task_name The name of the tuning task. | ||
| * \param rand_state The random state. | ||
| * \param num_threads The number of threads to be used. | ||
| * \param verbose The verbosity level. | ||
| */ | ||
| TuneContext::TuneContext(Optional<IRModule> mod, // | ||
| Optional<Target> target, // | ||
| Optional<String> task_name, // | ||
| support::LinearCongruentialEngine::TRandState rand_state, // | ||
| int num_threads) { | ||
| ObjectPtr<TuneContextNode> n = make_object<TuneContextNode>(); | ||
| n->mod = mod; | ||
| n->target = target; | ||
| n->task_name = task_name; | ||
| if (rand_state == -1) { | ||
| rand_state = std::random_device()(); | ||
| } | ||
| support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state); | ||
| n->num_threads = num_threads; | ||
| data_ = std::move(n); | ||
| } | ||
|
|
||
| TVM_REGISTER_NODE_TYPE(TuneContextNode); | ||
|
|
||
| TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") | ||
| .set_body_typed([](Optional<IRModule> mod, // | ||
| Optional<Target> target, // | ||
| Optional<String> task_name, // | ||
| support::LinearCongruentialEngine::TRandState rand_state, // | ||
| int num_threads) -> TuneContext { | ||
| return TuneContext(mod, target, task_name, rand_state, num_threads); | ||
| }); | ||
| } // namespace meta_schedule | ||
| } // namespace tvm |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
| #ifndef TVM_META_SCHEDULE_TUNE_CONTEXT_H_ | ||
| #define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ | ||
|
|
||
| #include <tvm/ir/module.h> | ||
| #include <tvm/support/random_engine.h> | ||
| #include <tvm/target/target.h> | ||
|
|
||
| namespace tvm { | ||
| namespace meta_schedule { | ||
|
|
||
| /*! \brief The auto tuning context. */ | ||
| class TuneContextNode : public runtime::Object { | ||
| public: | ||
| /*! \brief The workload to be tuned. */ | ||
| Optional<IRModule> mod; | ||
| /*! \brief The target to be tuned for. */ | ||
| Optional<Target> target; | ||
| /*! \brief The name of the tuning task. */ | ||
| Optional<String> task_name; | ||
| /*! \brief The random state. */ | ||
| support::LinearCongruentialEngine::TRandState rand_state; | ||
| /*! \brief The number of threads to be used. */ | ||
| int num_threads; | ||
|
|
||
| void VisitAttrs(tvm::AttrVisitor* v) { | ||
| v->Visit("mod", &mod); | ||
| v->Visit("target", &target); | ||
| v->Visit("task_name", &task_name); | ||
| v->Visit("rand_state", &rand_state); | ||
| v->Visit("num_threads", &num_threads); | ||
| } | ||
|
|
||
| static constexpr const char* _type_key = "meta_schedule.TuneContext"; | ||
| TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object); | ||
| }; | ||
|
|
||
| /*! | ||
| * \brief Managed reference to TuneContextNode. | ||
| * \sa TuneContextNode | ||
| */ | ||
| class TuneContext : public runtime::ObjectRef { | ||
| public: | ||
| /*! | ||
| * \brief Constructor. | ||
| * \param mod The workload to be tuned. | ||
| * \param target The target to be tuned for. | ||
| * \param task_name The name of the tuning task. | ||
| * \param rand_state The random state. | ||
| * \param num_threads The number of threads to be used. | ||
| */ | ||
| TVM_DLL explicit TuneContext(Optional<IRModule> mod, // | ||
| Optional<Target> target, // | ||
| Optional<String> task_name, // | ||
| support::LinearCongruentialEngine::TRandState rand_state, // | ||
| int num_threads); | ||
| TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode); | ||
| }; | ||
|
|
||
| } // namespace meta_schedule | ||
| } // namespace tvm | ||
|
|
||
| #endif // TVM_META_SCHEDULE_TUNE_CONTEXT_H_ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """Test the tune context of meta schedule.""" | ||
|
|
||
| import sys | ||
| import pytest | ||
|
|
||
| import tvm | ||
| from tvm import tir | ||
| from tvm.script import ty | ||
| from tvm.target import Target | ||
| from tvm.meta_schedule import TuneContext | ||
|
|
||
| # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring | ||
|
|
||
|
|
||
| @tvm.script.tir | ||
| class Matmul: | ||
| def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument | ||
| tir.func_attr({"global_symbol": "main", "tir.noalias": True}) | ||
| A = tir.match_buffer(a, (1024, 1024), "float32") | ||
| B = tir.match_buffer(b, (1024, 1024), "float32") | ||
| C = tir.match_buffer(c, (1024, 1024), "float32") | ||
| with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: | ||
| with tir.init(): | ||
| C[vi, vj] = 0.0 | ||
| C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] | ||
|
|
||
|
|
||
| # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring | ||
|
|
||
|
|
||
| def test_tune_context_create(): | ||
| mod = Matmul() | ||
| context = TuneContext(mod, Target("llvm"), "Test Task") | ||
| assert context.num_threads > 0 | ||
| assert context.rand_state != -1 | ||
| assert context.task_name == "Test Task" | ||
| assert context.mod == mod or tvm.ir.structural_equal(context.mod, mod) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| sys.exit(pytest.main([__file__] + sys.argv[1:])) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,5 +29,6 @@ mypy --check-untyped-defs python/tvm/tir/analysis/ | |
| echo "Checking MyPy Type defs in the transform package." | ||
| mypy --check-untyped-defs python/tvm/tir/transform/ | ||
|
|
||
| echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package." | ||
| mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/ | ||
| #TODO(@mikepapadim): This is failing atm | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please create an issue if things like this happens and possibly ping codeowners please, we were not aware that any of the checked in code was not tested until recently.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. I thought we did, but it turned out haven't...Please make sure to report in time :-)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh actually we reported in this thread: #9050. definitely should submit it as a separate PR though
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea but that PR was closed saying not needed. Thus, I was under the impression that was never merged |
||
| # echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package." | ||
| # mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/ | ||
Uh oh!
There was an error while loading. Please reload this page.