From c3bdf42c23f5f093214abe2b9d006b2b24c338db Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 26 Sep 2021 14:48:42 -0700 Subject: [PATCH 1/8] Add c++ side SearchStrategy. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- include/tvm/meta_schedule/runner.h | 57 ++++++ include/tvm/meta_schedule/search_strategy.h | 187 ++++++++++++++++++ include/tvm/support/random_engine.h | 10 + .../search_strategy/replay_trace.cc | 149 ++++++++++++++ .../search_strategy/search_strategy.cc | 68 +++++++ src/meta_schedule/utils.h | 74 +++++++ src/tir/schedule/concrete_schedule.cc | 4 +- src/tir/schedule/primitive.h | 8 + src/tir/schedule/primitive/sampling.cc | 12 ++ 9 files changed, 566 insertions(+), 3 deletions(-) create mode 100644 include/tvm/meta_schedule/runner.h create mode 100644 include/tvm/meta_schedule/search_strategy.h create mode 100644 src/meta_schedule/search_strategy/replay_trace.cc create mode 100644 src/meta_schedule/search_strategy/search_strategy.cc diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h new file mode 100644 index 000000000000..111bc40e2016 --- /dev/null +++ b/include/tvm/meta_schedule/runner.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_RUNNER_H_ +#define TVM_META_SCHEDULE_RUNNER_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +/*! \brief The runner's result. */ +class RunnerResultNode : public runtime::Object { + public: + /*! \brief The run time in seconds.*/ + Optional> run_secs; + /*! \brief The error message, if any. */ + Optional error_msg; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("run_secs", &run_secs); + v->Visit("error_msg", &error_msg); + } + + static constexpr const char* _type_key = "meta_schedule.RunnerResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(RunnerResultNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerResultNode + * \sa RunnerResultNode + */ +class RunnerResult : public runtime::ObjectRef { + public: + TVM_DLL explicit RunnerResult(Optional> run_secs, Optional error_msg); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_RUNNER_H_ diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h new file mode 100644 index 000000000000..7e87181c6d3b --- /dev/null +++ b/include/tvm/meta_schedule/search_strategy.h @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ +#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ + +#include + +#include "./arg_info.h" +#include "./runner.h" + +namespace tvm { +namespace meta_schedule { + +// Forward declaration +class TuneContext; + +class MeasureCandidateNode : public runtime::Object { + public: + tir::Schedule sch; + Array args_info; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("sch", &sch); + v->Visit("args_info", &args_info); + } + + static constexpr const char* _type_key = "meta_schedule.MeasureCandidate"; + TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object); +}; + +class MeasureCandidate : public runtime::ObjectRef { + public: + TVM_DLL MeasureCandidate(tir::Schedule sch, Array args_info); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode); +}; + +/*! \brief The search strategy for measure candidates generation. */ +class SearchStrategyNode : public runtime::Object { + public: + /*! \brief Virtual destructor */ + virtual ~SearchStrategyNode() = default; + + /*! + * \brief Initialize the search strategy with tuning context. + * \param tune_context The tuning context for initialization. + */ + virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; + + /*! + * \brief Pre-tuning for the search strategy. + * \param design_spaces The design spaces for pre-tuning. + */ + virtual void PreTuning(const Array& design_spaces) = 0; + + /*! \brief Post-tuning for the search strategy. */ + virtual void PostTuning() = 0; + + /*! + * \brief Generate measure candidates from design spaces for measurement. + * \return The measure candidates generated, nullptr if finished. + */ + virtual Optional> GenerateMeasureCandidates() = 0; + + /*! + * \brief Update the search strategy with profiling results. + * \param results The profiling results from the runner. + */ + virtual void NotifyRunnerResults(const Array& results) = 0; + + static constexpr const char* _type_key = "meta_schedule.SearchStrategy"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object); +}; + +/*! \brief The python side customizable class for measure candidate generation */ +class PySearchStrategyNode : public SearchStrategyNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief The function type of `PreTuning` method. + * \param design_spaces The design spaces for pre-tuning. + */ + using FPreTuning = runtime::TypedPackedFunc&)>; + /*! \brief The function type of `PostTuning` method. */ + using FPostTuning = runtime::TypedPackedFunc; + /*! + * \brief The function type of `GenerateMeasureCandidates` method. + * \return The measure candidates generated, nullptr if finished. + */ + using FGenerateMeasureCandidates = runtime::TypedPackedFunc>()>; + /*! + * \brief The function type of `NotifyRunnerResults` method. + * \param results The profiling results from the runner. + */ + using FNotifyRunnerResults = runtime::TypedPackedFunc&)>; + + /*! \brief The packed function to the `InitializeWithTuneContext` method. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `PreTuning` method. */ + FPreTuning f_pre_tuning; + /*! \brief The packed function to the `PostTuning` method. */ + FPostTuning f_post_tuning; + /*! \brief The packed function to the `GenerateMeasureCandidates` method. */ + FGenerateMeasureCandidates f_generate_measure_candidates; + /*! \brief The packed function to the `NotifyRunnerResults` method. */ + FNotifyRunnerResults f_notify_runner_results; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_pre_tuning` is not visited + // `f_post_tuning` is not visited + // `f_generate_measure_candidates` is not visited + // `f_notify_runner_results` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + this->f_initialize_with_tune_context(context); + } + + void PreTuning(const Array& design_spaces) final { + this->f_pre_tuning(design_spaces); + } + + void PostTuning() final { this->f_post_tuning(); } + + Optional> GenerateMeasureCandidates() final { + return this->f_generate_measure_candidates(); + } + + void NotifyRunnerResults(const Array& results) final { + this->f_notify_runner_results(results); + } + + static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; + TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode); +}; + +/*! + * \brief Managed reference to SearchStrategyNode. + * \sa SearchStrategyNode + */ +class SearchStrategy : public runtime::ObjectRef { + public: + /*! + * \brief Create a search strategy with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_pre_tuning The packed function of `PreTuning`. + * \param f_post_tuning The packed function of `PostTuning`. + * \param f_generate_measure_candidates The packed function of `GenerateMeasureCandidates`. + * \param f_notify_runner_results The packed function of `NotifyRunnerResults`. + * \return The search strategy created. + */ + TVM_DLL static SearchStrategy PySearchStrategy( + PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PySearchStrategyNode::FPreTuning f_pre_tuning, // + PySearchStrategyNode::FPostTuning f_post_tuning, // + PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // + PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results); + + TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index 6b733d074f6a..fcd2326050ed 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -102,6 +102,16 @@ class LinearCongruentialEngine { *rand_state_ptr_ = rand_state; // Change pointed random state to given random state value. } + /*! + * \brief Fork a new seed for another RNG from current random state. + * \return The forked seed. + */ + TRandState ForkSeed() { + // In order for reproducibility, we computer the new seed using RNG's random state and a + // different set of parameters. Note that both 32767 and 1999999973 are prime numbers. + return ((*this)() * 32767) % 1999999973; + } + /*! * \brief Construct a random number generator with a random state pointer. * \param rand_state_ptr The random state pointer given in result_type*. diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc new file mode 100644 index 000000000000..e10b6af3ba76 --- /dev/null +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief A search strategy that replays the trace. */ +class ReplayTraceNode : public SearchStrategyNode { + public: + using TRandState = support::LinearCongruentialEngine::TRandState; + + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + ReplayTraceNode* self; + /*! \brief The design spaces. */ + Array design_spaces; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(ReplayTraceNode* self, Array design_spaces) + : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} + + inline Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const Array& results); + }; + + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + + /*! \brief The module to be tuned. */ + IRModule mod_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief The number of threads to use. */ + int num_threads_ = -1; + /*! \brief The random state */ + TRandState rand_state_ = -1; + /*! \brief The state of the search strategy. */ + std::unique_ptr state_ = nullptr; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("num_trials_total", &num_trials_total); + // `mod_` is not visited + // `args_info_` is not visited + // `num_threads_` is not visited + // `rand_state_` is not visited + // `state_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.ReplayTrace"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); + + public: + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->mod_ = tune_context->mod.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_)); + this->num_threads_ = tune_context->num_threads; + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(!design_spaces.empty()); + ICHECK(this->state_ == nullptr); + this->state_ = std::make_unique(this, design_spaces); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(results); + } +}; + +inline Optional> ReplayTraceNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + ed = std::min(ed, self->num_trials_total); + ICHECK_LT(st, ed); + std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); + Array per_task_result(ed - st, MeasureCandidate{nullptr}); + auto f_worker = [this, &per_thread_rand_state, &per_task_result](int thread_id, + int task_id) -> void { + TRandState& rand_state = per_thread_rand_state[thread_id]; + int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); + tir::Trace trace = design_spaces[design_space_index]->trace().value(); + tir::Trace new_trace = tir::Trace(trace->insts, {}); + tir::Schedule sch = tir::Schedule::Traced( // + self->mod_, // + /*rand_state=*/ForkSeed(&rand_state), // + /*debug_mode=*/0, // + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + new_trace->ApplyToSchedule(sch, /*remove_postproc=*/true); + per_task_result.Set(task_id, MeasureCandidate(sch, self->args_info_)); + }; + support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker); + return per_task_result; +} + +inline void ReplayTraceNode::State::NotifyRunnerResults(const Array& results) { + st += self->num_trials_per_iter; + ed += self->num_trials_per_iter; +} + +SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int num_trials_total) { + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(ReplayTraceNode); +TVM_REGISTER_GLOBAL("meta_schedule.ReplayTrace").set_body_typed(SearchStrategy::ReplayTrace); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc new file mode 100644 index 000000000000..fefe8dfce76e --- /dev/null +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array args_info) { + ObjectPtr n = make_object(); + n->sch = sch; + n->args_info = args_info; + data_ = std::move(n); +} + +SearchStrategy SearchStrategy::PySearchStrategy( + PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PySearchStrategyNode::FPreTuning f_pre_tuning, // + PySearchStrategyNode::FPostTuning f_post_tuning, // + PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // + PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = f_initialize_with_tune_context; + n->f_pre_tuning = f_pre_tuning; + n->f_post_tuning = f_post_tuning; + n->f_generate_measure_candidates = f_generate_measure_candidates; + n->f_notify_runner_results = f_notify_runner_results; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(MeasureCandidateNode); +TVM_REGISTER_OBJECT_TYPE(SearchStrategyNode); +TVM_REGISTER_NODE_TYPE(PySearchStrategyNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCandidate") + .set_body_typed([](tir::Schedule sch, Array args_info) -> MeasureCandidate { + return MeasureCandidate(sch, args_info); + }); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPySearchStrategy") + .set_body_typed(SearchStrategy::PySearchStrategy); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyInitializeWithTuneContext") + .set_body_method(&SearchStrategyNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPreTuning") + .set_body_method(&SearchStrategyNode::PreTuning); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPostTuning") + .set_body_method(&SearchStrategyNode::PostTuning); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates") + .set_body_method(&SearchStrategyNode::GenerateMeasureCandidates); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults") + .set_body_method(&SearchStrategyNode::NotifyRunnerResults); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 4c9e1e2c10a1..90a27bb77207 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -23,16 +23,21 @@ #include #include #include +#include +#include #include #include #include #include +#include #include #include +#include "../printer/text_printer.h" #include "../support/array.h" #include "../support/base64.h" +#include "../tir/schedule/primitive.h" namespace tvm { namespace meta_schedule { @@ -131,6 +136,75 @@ inline String JSONObj2Str(const ObjectRef& json_obj) { */ inline String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); } +/*! + * \brief Find the entry function of the given IRModule. + * \param mod The IRModule to find the entry function. + * \return The entry function. + */ +inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { + // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc` + int num_prim_func = 0; + const tir::PrimFuncNode* main_func = nullptr; + const tir::PrimFuncNode* last_func = nullptr; + for (const auto& kv : mod->functions) { + GlobalVar gv = kv.first; + BaseFunc base_func = kv.second; + if (const auto* func = base_func.as()) { + last_func = func; + if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + return GetRef(func); + } + if (gv->name_hint == "main") { + main_func = func; + } + ++num_prim_func; + } + } + // Priority 2: PrimFunc whose name is `main` + if (main_func != nullptr) { + return GetRef(main_func); + } + // Priority 3: The only PrimFunc in the IRModule + if (num_prim_func == 0) { + LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " + << tir::AsTVMScript(mod); + } + if (num_prim_func > 1) { + LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are " + "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" + << tir::AsTVMScript(mod); + } + return GetRef(last_func); +} + +/*! + * \brief Fork a random state into another, i.e. PRNG splitting. + * The given random state is also mutated. + * \param rand_state The random state to be forked + * \return The forked random state + */ +inline support::LinearCongruentialEngine::TRandState ForkSeed( + support::LinearCongruentialEngine::TRandState* rand_state) { + return support::LinearCongruentialEngine(rand_state).ForkSeed(); +} + +/*! + * \brief Fork a random state into another ones, i.e. PRNG splitting. + * The given random state is also mutated. + * \param rand_state The random state to be forked + * \param n The number of forks + * \return The forked random states + */ +inline std::vector ForkSeed( + support::LinearCongruentialEngine::TRandState* rand_state, int n) { + std::vector results; + results.reserve(n); + for (int i = 0; i < n; ++i) { + results.push_back(support::LinearCongruentialEngine(rand_state).ForkSeed()); + } + return results; +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 07af73ebabb6..93eba520f9d1 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -220,9 +220,7 @@ void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState se } support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { - // In order for reproducibility, we computer the new seed using RNG's random state and a different - // set of parameters. Note that both 32767 and 1999999973 are prime numbers. - return (support::LinearCongruentialEngine(&rand_state_)() * 32767) % 1999999973; + return support::LinearCongruentialEngine(&rand_state_).ForkSeed(); } ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 8ad6bdf7d37f..8d8acd2693f4 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -26,6 +26,14 @@ namespace tvm { namespace tir { /******** Schedule: Sampling ********/ +/*! + * \brief Sample a random integer from a given range. + * \param min_inclusive The minimum value of the range, inclusive. + * \param max_exclusive The maximum value of the range, exclusive. + * \return The random integer sampled in the given range. + */ +TVM_DLL int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, + int max_exclusive); /*! * \brief Sample once category from candidates according to the probability weights. * \param self The schedule to update diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 8843ac613179..6ac6226118cd 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -24,6 +24,18 @@ namespace tvm { namespace tir { +int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, + int max_exclusive) { + CHECK(min_inclusive < max_exclusive) + << "ValueError: max_exclusive must be greater than min_inclusive."; + if (min_inclusive + 1 == max_exclusive) { + return min_inclusive; + } + support::LinearCongruentialEngine rand_(rand_state); + std::uniform_int_distribution dist(min_inclusive, max_exclusive - 1); + return dist(rand_); +} + int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { From 580545cf046eebb8c2c3204186fa6580b6bc26e1 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 26 Sep 2021 15:06:46 -0700 Subject: [PATCH 2/8] Add python-side code & test. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- python/tvm/meta_schedule/__init__.py | 2 + python/tvm/meta_schedule/runner/__init__.py | 18 ++ python/tvm/meta_schedule/runner/runner.py | 60 +++++++ .../meta_schedule/search_strategy/__init__.py | 20 +++ .../search_strategy/replay_trace.py | 47 +++++ .../search_strategy/search_strategy.py | 166 ++++++++++++++++++ src/meta_schedule/runner/runner.cc | 41 +++++ .../test_meta_schedule_search_strategy.py | 98 +++++++++++ 8 files changed, 452 insertions(+) create mode 100644 python/tvm/meta_schedule/runner/__init__.py create mode 100644 python/tvm/meta_schedule/runner/runner.py create mode 100644 python/tvm/meta_schedule/search_strategy/__init__.py create mode 100644 python/tvm/meta_schedule/search_strategy/replay_trace.py create mode 100644 python/tvm/meta_schedule/search_strategy/search_strategy.py create mode 100644 src/meta_schedule/runner/runner.cc create mode 100644 tests/python/unittest/test_meta_schedule_search_strategy.py diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index f8b2b026c83b..c22cc205bf35 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -19,5 +19,7 @@ from . import builder from . import database from . import space_generator +from . import search_strategy +from . import runner from .database import TuningRecord from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/runner/__init__.py b/python/tvm/meta_schedule/runner/__init__.py new file mode 100644 index 000000000000..65d2ef04e04c --- /dev/null +++ b/python/tvm/meta_schedule/runner/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""meta_schedule.runner""" +from .runner import RunnerResult diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py new file mode 100644 index 000000000000..e8170203f543 --- /dev/null +++ b/python/tvm/meta_schedule/runner/runner.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Runners""" +from typing import List, Optional + +from tvm._ffi import register_object +from tvm.runtime import Object + +from .. import _ffi_api +from ..arg_info import ArgInfo + + +@register_object("meta_schedule.RunnerResult") +class RunnerResult(Object): + """The runner's result + + Parameters + ---------- + run_secs : Optional[List[float]] + The run time in seconds. + error_msg : Optional[str] + The error message, if any. + """ + + run_secs: Optional[List[float]] + error_msg: Optional[str] + + def __init__( + self, + run_secs: Optional[List[float]], + error_msg: Optional[str], + ) -> None: + """Constructor + + Parameters + ---------- + run_secs : Optional[List[float]] + The run time in seconds. + error_msg : Optional[str] + The error message, if any. + """ + self.__init_handle_by_constructor__( + _ffi_api.RunnerResult, # type: ignore # pylint: disable=no-member + run_secs, + error_msg, + ) diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py new file mode 100644 index 000000000000..40f21da0b2d1 --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Search Strategy""" + +from .search_strategy import SearchStrategy, PySearchStrategy +from .replay_trace import ReplayTrace diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py new file mode 100644 index 000000000000..3afdff6de77e --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Replay Trace Search Strategy""" + +from tvm._ffi import register_object +from .search_strategy import SearchStrategy +from .. import _ffi_api + + +@register_object("meta_schedule.ReplayTrace") +class ReplayTrace(SearchStrategy): + """ + Replay Trace Search Strategy is a search strategy that always replays the trace by removing its + decisions so that the decisions would be randomly re-generated. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. + """ + + num_trials_per_iter: int + num_trials_total: int + + def __init__(self, num_trials_per_iter: int, num_trials_total: int): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.ReplayTrace, # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + ) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py new file mode 100644 index 000000000000..72713155c41d --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Search Strategy""" + +from typing import List, Optional, TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule + +from .. import _ffi_api +from ..arg_info import ArgInfo +from ..runner import RunnerResult + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.MeasureCandidate") +class MeasureCandidate(Object): + """Measure candidate class. + + Parameters + ---------- + sch : Schedule + The schedule to be measured. + args_info : List[ArgInfo] + The argument information. + """ + + sch: Schedule + args_info: List[ArgInfo] + + def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None: + """Constructor. + + Parameters + ---------- + sch : Schedule + The schedule to be measured. + args_info : List[ArgInfo] + The argument information. + """ + self.__init_handle_by_constructor__( + _ffi_api.MeasureCandidate, # pylint: disable=no-member + sch, + args_info, + ) + + +@register_object("meta_schedule.SearchStrategy") +class SearchStrategy(Object): + """ + Search strategy is the class that generates the measure candidates. It has to be pre-tuned + before usage and post-tuned after usage. + """ + + def initialize_with_tune_context( + self, + tune_context: "TuneContext", + ) -> None: + """Initialize the search strategy with tuning context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initialization. + """ + _ffi_api.SearchStrategyInitializeWithTuneContext( # pylint: disable=no-member + self, tune_context + ) + + def pre_tuning(self, design_spaces: List[Schedule]) -> None: + """Pre-tuning for the search strategy. + + Parameters + ---------- + design_spaces : List[Schedule] + The design spaces for pre-tuning. + """ + _ffi_api.SearchStrategyPreTuning(self, design_spaces) # pylint: disable=no-member + + def post_tuning(self) -> None: + """Post-tuning for the search strategy.""" + _ffi_api.SearchStrategyPostTuning(self) # pylint: disable=no-member + + def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]: + """Generate measure candidates from design spaces for measurement. + + Returns + ------- + measure_candidates : Optional[List[IRModule]] + The measure candidates generated, None if finished. + """ + return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # pylint: disable=no-member + + def notify_runner_results(self, results: List[RunnerResult]) -> None: + """Update the search strategy with profiling results. + + Parameters + ---------- + results : List[RunnerResult] + The profiling results from the runner. + """ + _ffi_api.SearchStrategyNotifyRunnerResults(self, results) # pylint: disable=no-member + + +@register_object("meta_schedule.PySearchStrategy") +class PySearchStrategy(SearchStrategy): + """An abstract search strategy with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + def f_initialize_with_tune_context(context: "TuneContext") -> None: + self.initialize_with_tune_context(context) + + def f_pre_tuning(design_spaces: List[Schedule]) -> None: + self.pre_tuning(design_spaces) + + def f_post_tuning() -> None: + self.post_tuning() + + def f_generate_measure_candidates() -> List[MeasureCandidate]: + return self.generate_measure_candidates() + + def f_notify_runner_results(results: List["RunnerResult"]) -> None: + self.notify_runner_results(results) + + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyPySearchStrategy, # pylint: disable=no-member + f_initialize_with_tune_context, + f_pre_tuning, + f_post_tuning, + f_generate_measure_candidates, + f_notify_runner_results, + ) + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + raise NotImplementedError + + def pre_tuning(self, design_spaces: List[Schedule]) -> None: + raise NotImplementedError + + def post_tuning(self) -> None: + raise NotImplementedError + + def generate_measure_candidates(self) -> List[MeasureCandidate]: + raise NotImplementedError + + def notify_runner_results(self, results: List["RunnerResult"]) -> None: + raise NotImplementedError diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc new file mode 100644 index 000000000000..8f509bdd7b84 --- /dev/null +++ b/src/meta_schedule/runner/runner.cc @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +RunnerResult::RunnerResult(Optional> run_secs, Optional error_msg) { + ObjectPtr n = make_object(); + n->run_secs = run_secs; + n->error_msg = error_msg; + this->data_ = n; +} + +TVM_REGISTER_NODE_TYPE(RunnerResultNode); + +TVM_REGISTER_GLOBAL("meta_schedule.RunnerResult") + .set_body_typed([](Array run_secs, Optional error_msg) -> RunnerResult { + return RunnerResult(run_secs, error_msg); + }); + +} // namespace meta_schedule +} // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py new file mode 100644 index 000000000000..6e90bddb84b4 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule SearchStrategy """ +# pylint: disable=missing-function-docstring +from typing import List + +import sys + +import pytest + +import tvm +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.space_generator import ScheduleFn +from tvm.meta_schedule.search_strategy import SearchStrategy, ReplayTrace + +from tvm.script import ty +from tvm.tir.schedule import Schedule, Trace + + +MATMUL_M = 32 + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking +# fmt: off + +@tvm.script.tir +class Matmul: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + tir.func_attr({"global_symbol": "main"}) + A = tir.match_buffer(a, (32, 32), "float32") + B = tir.match_buffer(b, (32, 32), "float32") + C = tir.match_buffer(c, (32, 32), "float32") + with tir.block([32, 32, tir.reduce_axis(0, 32)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool: + trace_1 = Trace(sch_1.trace.insts, {}) + trace_2 = Trace(sch_2.trace.insts, {}) + return str(trace_1) == str(trace_2) + + +def _schedule_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k = sch.get_loops(block=block) + # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + + +def test_meta_schedule_replay_trace(): + num_trials_per_iter = 7 + num_trials_total = 20 + + (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul()) + replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) + tune_context = TuneContext(mod=Matmul()) + replay.initialize_with_tune_context(tune_context) + + num_trials_each_round: List[int] = [] + replay.pre_tuning([example_sch]) + while True: + candidates = replay.generate_measure_candidates() + if candidates is None: + break + num_trials_each_round.append(len(candidates)) + runner_results: List[RunnerResult] = [] + for candidate in candidates: + assert _is_trace_equal(candidate.sch, example_sch) + runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None)) + replay.notify_runner_results(runner_results) + replay.post_tuning() + assert num_trials_each_round == [7, 7, 6] + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 11570e5d515d01ac8baffcfa241bff3b415e7e6b Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 27 Sep 2021 10:57:44 -0700 Subject: [PATCH 3/8] Add docs. --- include/tvm/meta_schedule/runner.h | 5 +++++ include/tvm/meta_schedule/search_strategy.h | 12 ++++++++++++ src/meta_schedule/utils.h | 2 ++ 3 files changed, 19 insertions(+) diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index 111bc40e2016..390fa661a0a8 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -47,6 +47,11 @@ class RunnerResultNode : public runtime::Object { */ class RunnerResult : public runtime::ObjectRef { public: + /*! + * \brief Constructor for RunnerResult. + * \param run_secs The run time in seconds. + * \param error_msg The error message, if any. + */ TVM_DLL explicit RunnerResult(Optional> run_secs, Optional error_msg); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode); }; diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 7e87181c6d3b..2ea659d8a16e 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -30,9 +30,12 @@ namespace meta_schedule { // Forward declaration class TuneContext; +/*! \brief The measure candidate class. */ class MeasureCandidateNode : public runtime::Object { public: + /*! \brief The schedule for profiling. */ tir::Schedule sch; + /*! \brief The argument information. */ Array args_info; void VisitAttrs(tvm::AttrVisitor* v) { @@ -44,8 +47,17 @@ class MeasureCandidateNode : public runtime::Object { TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object); }; +/*! + * \brief Managed reference to MeasureCandidateNode. + * \sa MeasureCandidateNode + */ class MeasureCandidate : public runtime::ObjectRef { public: + /*! + * \brief Constructor of MeasureCandidate. + * \param sch The schedule for profiling. + * \param args_info The argument information. + */ TVM_DLL MeasureCandidate(tir::Schedule sch, Array args_info); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode); }; diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 90a27bb77207..6cc5fd99ca69 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -34,6 +34,8 @@ #include +#include + #include "../printer/text_printer.h" #include "../support/array.h" #include "../support/base64.h" From 9affcfb87625f8e95683828bc506cbcf56019173 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 27 Sep 2021 11:03:41 -0700 Subject: [PATCH 4/8] Minor fix. --- python/tvm/meta_schedule/runner/runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index e8170203f543..b756c6e6b011 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -21,7 +21,6 @@ from tvm.runtime import Object from .. import _ffi_api -from ..arg_info import ArgInfo @register_object("meta_schedule.RunnerResult") From 62b95457a343efc1fd55e04e00b9f277ef735c61 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 27 Sep 2021 22:18:37 -0700 Subject: [PATCH 5/8] Add workflow. --- include/tvm/meta_schedule/search_strategy.h | 39 ++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 2ea659d8a16e..19a7074bd1b1 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -62,7 +62,44 @@ class MeasureCandidate : public runtime::ObjectRef { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode); }; -/*! \brief The search strategy for measure candidates generation. */ +/*! + +* \brief The search strategy for measure candidates generation. +* \note The relationship between SearchStrategy and other classes are as follows: + + ┌──────────────────────────────────────────────────────────────┐ + ┌──┴───────────────────────────────────────────────────────────┐ │ +┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ +│ ┌─────────────────────┐ │ │ │ +│ │ │ Generate │ │ │ +│ │ Space Generator ├──────────────┐ │ │ │ +│ │ │ │ │ │ │ +│ └─────────────────────┘ ▼ │ │ │ +│ Design Space │ │ │ +│ ┌─────────────────────┐ │ │ │ │ +│ Generate │ │ Pretuning │ │ │ │ +│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ +│ │ │ │ │ ├──┘ +│ │ └─────────────────────┘ ├──┘ +└────┼─────────────────────────────────────────────────────────┘ + │ + │ +┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ +│ │ ┌───────────┐ │ +│ │ Send to │ │ Send to │ +│ ▼ ┌─────────────►│ Builder ├──────────┐ │ +│ Measure Candidate │ Builder │ │ Runner │ │ +│ │ │ └───────────┘ │ │ +│ │ ┌────────────┴────────┐ │ │ +│ │ │ │ ┌───────────┐ │ │ +│ └────►│ Task Scheduler │ │ │ │ │ +│ │ │ │ Runner │◄─────────┘ │ +│ └─────────────────────┘ │ │ │ +│ ▲ └─────┬─────┘ │ +│ │ │ │ +│ └─── Runner Future ◄────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +*/ class SearchStrategyNode : public runtime::Object { public: /*! \brief Virtual destructor */ From aadb2182d0abfdc67dd3cfd2df6322a375ba7cc4 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 27 Sep 2021 22:45:58 -0700 Subject: [PATCH 6/8] Add docs. --- include/tvm/meta_schedule/builder.h | 4 +- include/tvm/meta_schedule/runner.h | 6 +-- include/tvm/meta_schedule/search_strategy.h | 23 ++++++----- include/tvm/meta_schedule/space_generator.h | 38 ++++++++++++++++++- .../search_strategy/replay_trace.cc | 7 ++-- src/meta_schedule/utils.h | 4 +- 6 files changed, 58 insertions(+), 24 deletions(-) diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index 9186c9d039e0..19358552df10 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -25,7 +25,7 @@ namespace tvm { namespace meta_schedule { -/*! \brief The builder's input. */ +/*! \brief The builder's input, containing an IRModule and the target. */ class BuilderInputNode : public runtime::Object { public: /*! \brief The IRModule to be built. */ @@ -57,7 +57,7 @@ class BuilderInput : public runtime::ObjectRef { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderInput, runtime::ObjectRef, BuilderInputNode); }; -/*! \brief The builder's output. */ +/*! \brief The builder's output, containing the artifact path or error message if any. */ class BuilderResultNode : public runtime::Object { public: /*! \brief The path to the built artifact. */ diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index 390fa661a0a8..36d07024559d 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -24,12 +24,12 @@ namespace tvm { namespace meta_schedule { -/*! \brief The runner's result. */ +/*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */ class RunnerResultNode : public runtime::Object { public: - /*! \brief The run time in seconds.*/ + /*! \brief The run time in seconds. If not None, error_msg should be None. */ Optional> run_secs; - /*! \brief The error message, if any. */ + /*! \brief The error message, if any. If not None, run_secs should be None. */ Optional error_msg; void VisitAttrs(tvm::AttrVisitor* v) { diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 19a7074bd1b1..7e9d2486ead0 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -30,12 +30,12 @@ namespace meta_schedule { // Forward declaration class TuneContext; -/*! \brief The measure candidate class. */ +/*! \brief The schedule (with input shapes) to be measured. */ class MeasureCandidateNode : public runtime::Object { public: - /*! \brief The schedule for profiling. */ + /*! \brief The schedule for measurement. */ tir::Schedule sch; - /*! \brief The argument information. */ + /*! \brief The argument information, e.g., (shape, dtype) for tensors. */ Array args_info; void VisitAttrs(tvm::AttrVisitor* v) { @@ -55,18 +55,16 @@ class MeasureCandidate : public runtime::ObjectRef { public: /*! * \brief Constructor of MeasureCandidate. - * \param sch The schedule for profiling. - * \param args_info The argument information. + * \param sch The schedule for measurement. + * \param args_info The argument information, e.g., (shape, dtype) for tensors. */ TVM_DLL MeasureCandidate(tir::Schedule sch, Array args_info); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode); }; /*! - -* \brief The search strategy for measure candidates generation. -* \note The relationship between SearchStrategy and other classes are as follows: - + * \brief The search strategy for measure candidates generation. + * \note The relationship between SearchStrategy and other classes are as follows: ┌──────────────────────────────────────────────────────────────┐ ┌──┴───────────────────────────────────────────────────────────┐ │ ┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ @@ -108,6 +106,7 @@ class SearchStrategyNode : public runtime::Object { /*! * \brief Initialize the search strategy with tuning context. * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. */ virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; @@ -127,8 +126,8 @@ class SearchStrategyNode : public runtime::Object { virtual Optional> GenerateMeasureCandidates() = 0; /*! - * \brief Update the search strategy with profiling results. - * \param results The profiling results from the runner. + * \brief Update the search strategy with measurement results. + * \param results The measurement results from the runner. */ virtual void NotifyRunnerResults(const Array& results) = 0; @@ -158,7 +157,7 @@ class PySearchStrategyNode : public SearchStrategyNode { using FGenerateMeasureCandidates = runtime::TypedPackedFunc>()>; /*! * \brief The function type of `NotifyRunnerResults` method. - * \param results The profiling results from the runner. + * \param results The measurement results from the runner. */ using FNotifyRunnerResults = runtime::TypedPackedFunc&)>; diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 9528be2a85ad..3dc181e05d8a 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -28,7 +28,42 @@ namespace meta_schedule { // Forward declaration class TuneContext; -/*! \brief The abstract class for design space generation. */ +/*! + * \brief The abstract class for design space generation. + * \note The relationship between SpaceGenerator and other classes are as follows: + ┌──────────────────────────────────────────────────────────────┐ + ┌──┴───────────────────────────────────────────────────────────┐ │ +┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ +│ ┌─────────────────────┐ │ │ │ +│ │ │ Generate │ │ │ +│ │ Space Generator ├──────────────┐ │ │ │ +│ │ │ │ │ │ │ +│ └─────────────────────┘ ▼ │ │ │ +│ Design Space │ │ │ +│ ┌─────────────────────┐ │ │ │ │ +│ Generate │ │ Pretuning │ │ │ │ +│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ +│ │ │ │ │ ├──┘ +│ │ └─────────────────────┘ ├──┘ +└────┼─────────────────────────────────────────────────────────┘ + │ + │ +┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ +│ │ ┌───────────┐ │ +│ │ Send to │ │ Send to │ +│ ▼ ┌─────────────►│ Builder ├──────────┐ │ +│ Measure Candidate │ Builder │ │ Runner │ │ +│ │ │ └───────────┘ │ │ +│ │ ┌────────────┴────────┐ │ │ +│ │ │ │ ┌───────────┐ │ │ +│ └────►│ Task Scheduler │ │ │ │ │ +│ │ │ │ Runner │◄─────────┘ │ +│ └─────────────────────┘ │ │ │ +│ ▲ └─────┬─────┘ │ +│ │ │ │ +│ └─── Runner Future ◄────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +*/ class SpaceGeneratorNode : public Object { public: /*! \brief Default destructor */ @@ -37,6 +72,7 @@ class SpaceGeneratorNode : public Object { /*! * \brief Initialize the design space generator with tuning context. * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. */ virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index e10b6af3ba76..1c83aee8c0fd 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -21,7 +21,7 @@ namespace tvm { namespace meta_schedule { -/*! \brief A search strategy that replays the trace. */ +/*! \brief A search strategy that generates measure candidates using trace and random decisions. */ class ReplayTraceNode : public SearchStrategyNode { public: using TRandState = support::LinearCongruentialEngine::TRandState; @@ -53,9 +53,9 @@ class ReplayTraceNode : public SearchStrategyNode { IRModule mod_{nullptr}; /*! \brief The metadata of the function arguments. */ Array args_info_{nullptr}; - /*! \brief The number of threads to use. */ + /*! \brief The number of threads to use. -1 means using logical cpu number. */ int num_threads_ = -1; - /*! \brief The random state */ + /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; /*! \brief The state of the search strategy. */ std::unique_ptr state_ = nullptr; @@ -73,7 +73,6 @@ class ReplayTraceNode : public SearchStrategyNode { static constexpr const char* _type_key = "meta_schedule.ReplayTrace"; TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); - public: void InitializeWithTuneContext(const TuneContext& tune_context) final { this->mod_ = tune_context->mod.value(); this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_)); diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 6cc5fd99ca69..30294b8f91e1 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -33,7 +33,6 @@ #include #include - #include #include "../printer/text_printer.h" @@ -139,7 +138,8 @@ inline String JSONObj2Str(const ObjectRef& json_obj) { inline String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); } /*! - * \brief Find the entry function of the given IRModule. + * \brief Find the entry function of the given IRModule, i.e, functions marked by + * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. * \param mod The IRModule to find the entry function. * \return The entry function. */ From 4d95b27a3c9729daa2e2e4bc558b722b9f4d0d5b Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 28 Sep 2021 10:41:22 -0700 Subject: [PATCH 7/8] Fix docs. --- include/tvm/meta_schedule/search_strategy.h | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 7e9d2486ead0..c337c013a1c4 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -19,11 +19,10 @@ #ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ #define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ +#include +#include #include -#include "./arg_info.h" -#include "./runner.h" - namespace tvm { namespace meta_schedule { @@ -224,6 +223,11 @@ class SearchStrategy : public runtime::ObjectRef { PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results); + /*! + * \brief Constructor of replay trace search strategy. + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. + * \param num_trials_total The total number of trials for trace replaying. + */ TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); From 826399414adbace2fb7fa8b19620a6227498f09e Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 28 Sep 2021 13:29:52 -0700 Subject: [PATCH 8/8] Add notes. --- include/tvm/meta_schedule/search_strategy.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index c337c013a1c4..941dae4336e1 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -112,10 +112,17 @@ class SearchStrategyNode : public runtime::Object { /*! * \brief Pre-tuning for the search strategy. * \param design_spaces The design spaces for pre-tuning. + * \note Pre-tuning is supposed to be called before the tuning process and after the + * initialization. Because the search strategy is stateful, we can always call pretuning + * and reset the search strategy. */ virtual void PreTuning(const Array& design_spaces) = 0; - /*! \brief Post-tuning for the search strategy. */ + /*! + * \brief Post-tuning for the search strategy. + * \note Post-tuning is supposed to be called after the tuning process and before we reset the + * search strategy with another pre-tuning. Post-tuning can be empty. + */ virtual void PostTuning() = 0; /*!