From 0b60f777e8e6d7f0e9bbc8ce64ebd380b3553922 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 21 Dec 2021 22:27:30 -0800 Subject: [PATCH 1/7] Modify TuneContext, TaskScheduler & SearchStrategy functions. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- include/tvm/meta_schedule/tune_context.h | 2 +- python/tvm/auto_scheduler/search_task.py | 3 ++- python/tvm/auto_scheduler/workload_registry.py | 5 ++++- src/meta_schedule/search_strategy/replay_trace.cc | 1 + 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 428a2e80f4dd..5a2cb99644a6 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -53,7 +53,7 @@ class TuneContextNode : public runtime::Object { /*! \brief The probability of using certain mutator. */ Map mutator_probs; /*! \brief The name of the tuning task. */ - Optional task_name; + String task_name; /*! \brief The random state. */ support::LinearCongruentialEngine::TRandState rand_state; /*! \brief The number of threads to be used. 
*/ diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index f1156998bdac..0e9c4abebbe1 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -543,7 +543,8 @@ def print_best(self, log_file, print_mode="schedule"): code: str The best schedule code in python API or CUDA source code """ - inp, _ = load_best_record(log_file, self.workload_key) + inp, res = load_best_record(log_file, self.workload_key) + print("Best codes (ms):", [float(c) * 1000.0 for c in res.costs]) if inp is None: raise RuntimeError( "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file) diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 885eb0d1d0f8..75702b0a21af 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -194,7 +194,10 @@ def workload_key_to_tensors(workload_key): assert callable(value) args = deserialize_args(workload[1:]) - return value(*args) + result = value(*args) + if isinstance(result, tuple): + result = list(result) + return result def serialize_workload_registry_entry(workload_key): diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 1eac10d1ad82..8c9e2d8949e9 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include "tvm/tir/schedule/schedule.h" namespace tvm { namespace meta_schedule { From 0da5a9100b3ea37364e076c2cd99cdb4fddcab76 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 22 Dec 2021 01:33:19 -0800 Subject: [PATCH 2/7] Retrigger CI. 
From e88ed56a98c3b6e5c6ea767497e51102b1f654cb Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 22 Dec 2021 20:31:00 -0800 Subject: [PATCH 3/7] Add ReplayFunc and EvolutionarySearch strategy. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- .../search_strategy/evolutionary_search.py | 117 +++ .../search_strategy/evolutionary_search.cc | 673 ++++++++++++++++++ src/tir/schedule/primitive.h | 17 + src/tir/schedule/primitive/sampling.cc | 22 + .../test_meta_schedule_search_strategy.py | 160 ++++- 5 files changed, 988 insertions(+), 1 deletion(-) create mode 100644 python/tvm/meta_schedule/search_strategy/evolutionary_search.py create mode 100644 src/meta_schedule/search_strategy/evolutionary_search.cc diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py new file mode 100644 index 000000000000..a679c1970951 --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Evolutionary Search Strategy""" + +from typing import NamedTuple + +from tvm._ffi import register_object + +from .. import _ffi_api +from .search_strategy import SearchStrategy + + +@register_object("meta_schedule.EvolutionarySearch") +class EvolutionarySearch(SearchStrategy): + """ + Replay Trace Search Strategy is a search strategy that always replays the trace by removing its + decisions so that the decisions would be randomly re-generated. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. + population_size : int + The initial population of traces from measured samples and randomly generated samples. + init_measured_ratio : int + The ratio of measured samples in the initial population. + init_max_fail_count : int + The maximum number to fail trace replaying. + genetic_num_iters : int + The number of iterations for genetic algorithm. + genetic_mutate_prob : float + The probability of mutation. + genetic_max_fail_count : int + The maximum number to retry mutation. + eps_greedy : float + The ratio of greedy selected samples in the final picks. 
+ """ + + num_trials_per_iter: int + num_trials_total: int + population_size: int + init_measured_ratio: int + init_max_fail_count: int + genetic_num_iters: int + genetic_mutate_prob: float + genetic_max_fail_count: int + eps_greedy: float + + def __init__( + self, + *, + num_trials_per_iter: int, + num_trials_total: int, + population_size: int, + init_measured_ratio: float, + init_max_fail_count: int, + genetic_num_iters: int, + genetic_mutate_prob: float, + genetic_max_fail_count: int, + eps_greedy: float, + ) -> None: + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + population_size, + init_measured_ratio, + init_max_fail_count, + genetic_num_iters, + genetic_mutate_prob, + genetic_max_fail_count, + eps_greedy, + ) + + +class EvolutionarySearchConfig(NamedTuple): + """Configuration for EvolutionarySearch""" + + num_trials_per_iter: int + num_trials_total: int + population_size: int = 2048 + init_measured_ratio: float = 0.2 + init_max_fail_count: int = 64 + genetic_num_iters: int = 4 + genetic_mutate_prob: float = 0.85 + genetic_max_fail_count: int = 10 + eps_greedy: float = 0.05 + + def create_strategy(self) -> EvolutionarySearch: + return EvolutionarySearch( + num_trials_per_iter=self.num_trials_per_iter, + num_trials_total=self.num_trials_total, + population_size=self.population_size, + init_measured_ratio=self.init_measured_ratio, + init_max_fail_count=self.init_max_fail_count, + genetic_num_iters=self.genetic_num_iters, + genetic_mutate_prob=self.genetic_mutate_prob, + genetic_max_fail_count=self.genetic_max_fail_count, + eps_greedy=self.eps_greedy, + ) diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc new file mode 100644 index 000000000000..cb35406c1d8f --- /dev/null +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -0,0 
+1,673 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../utils.h" + +#define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name) \ + CHECK(0.0 <= (p) && (p) <= 1.0) << "ValueError: name should be within [0, 1], " \ + << "but get `" << #p << " = " << (p) << '\''; + +namespace tvm { +namespace meta_schedule { + +using tir::Schedule; + +/**************** Data Structure ****************/ + +/*! + * \brief A heap with a size up-limit. If overflow happens, it evicted the worst items. + * \note It maintains a min heap in terms of `Item::score`. Therefore, when + * overflow happens, the element evicted is the one with the min `Item::score`. + * As time goes, the elements in the heap are going to be larger. + */ +class SizedHeap { + public: + struct Item { + Schedule sch; + IRModule mod; + size_t shash; + double score; + bool operator<(const Item& other) const { return score > other.score; } + }; + + struct ItemHash { + size_t operator()(const Item& hash) const { return hash.shash; } + }; + + struct ItemEqual { + bool operator()(const Item& lhs, const Item& rhs) const { + return lhs.shash == rhs.shash && StructuralEqual()(lhs.mod, rhs.mod); + } + }; + /*! 
+ * \brief Constructor + * \param size_limit The up-limit of the heap size + */ + explicit SizedHeap(int size_limit) : size_limit(size_limit) { heap.reserve(size_limit); } + + /*! + * \brief Push the specific item to the heap if its key did not appears in the heap + * \param item The item to be pushed + */ + void Push(Schedule sch, IRModule mod, double score) { + Item item{sch, mod, StructuralHash()(mod), score}; + if (!in_heap.insert(item).second) { + return; + } + int size = heap.size(); + if (size < size_limit) { + // Heap is not full, just push + heap.emplace_back(item); + std::push_heap(heap.begin(), heap.end()); + } else if (item.score > heap.front().score) { + // if the item is better than the worst one in the heap, we can safely kick it out + std::pop_heap(heap.begin(), heap.end()); + heap.back() = item; + std::push_heap(heap.begin(), heap.end()); + } + // Otherwise, the item is worse than any other element in the heap + } + + /*! \brief Up-limit of the heap size */ + int size_limit; + /*! \brief The heap, the worse the topper */ + std::vector heap; + /*! \brief The traces that are in the heap */ + std::unordered_set in_heap; +}; + +struct PerThreadData { + IRModule mod{nullptr}; + TRandState rand_state{-1}; + std::function trace_sampler = nullptr; + std::function()> mutator_sampler = nullptr; + + /*! + * \brief Set the value for the trace and mutator samplers per thread. + * \param scores The predicted score for the given samples. + * \param genetic_mutate_prob The probability of mutation. + * \param mutator_probs The probability of each mutator as a dict. + */ + void Set(const std::vector& scores, double genetic_mutate_prob, + const Map& mutator_probs) { + trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores); + mutator_sampler = MakeMutatorSampler(genetic_mutate_prob, mutator_probs, &rand_state); + } + + private: + /*! 
+ * \brief Create a sampler function that picks mutators according to the mass function + * \param rand_state The random state for sampling + * \return The sampler created + */ + static std::function()> MakeMutatorSampler( + double genetic_mutate_prob, // + const Map& mutator_probs, // + TRandState* rand_state) { + std::vector> mutators; + std::vector masses; + mutators.push_back(NullOpt); + masses.push_back(1.0 - genetic_mutate_prob); + double total_mass_mutator = 0.0; + if (genetic_mutate_prob > 0) { + for (const auto& kv : mutator_probs) { + Mutator mutator = kv.first; + double mass = kv.second->value; + total_mass_mutator += mass; + mutators.push_back(mutator); + masses.push_back(mass * genetic_mutate_prob); + } + } + // Normalize the sum to 1.0 + if (total_mass_mutator == 0.0) { + masses[0] = 1.0; + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] = 0.0; + } + } else if (total_mass_mutator != 1.0) { + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] /= total_mass_mutator; + } + } + return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses), + mutators = std::move(mutators)]() -> Optional { + int i = idx_sampler(); + return mutators[i]; + }; + } +}; + +struct ConcurrentBitmask { + /*! The bit width. */ + static constexpr const int kBitWidth = 64; + /*! \brief The size of the concurrent bitmask. */ + int size; + /*! \brief The bitmasks. */ + std::vector bitmask; + /*! \brief The mutexes, one per kBitWidth(64 here) bitmasks. */ + std::vector mutexes; + + /*! + * \brief Constructor + * \param n The total slots managed by the concurrent bitmask. + */ + explicit ConcurrentBitmask(int n) + : size((n + kBitWidth - 1) / kBitWidth), bitmask(size, 0), mutexes(size) {} + /*! + * \brief Query and mark the given index if not visited before. + * \param x The index to concurrently check if used. If not, mark as used. + * \return Whether the index has been used before. 
+ */ + bool QueryAndMark(int x) { + constexpr uint64_t one = 1; + std::unique_lock lock(mutexes[x / kBitWidth]); + if (bitmask[x / kBitWidth] & (one << (x % kBitWidth))) { + return false; + } else { + bitmask[x / kBitWidth] |= one << (x % kBitWidth); + return true; + } + } +}; + +/**************** Util Functions ****************/ + +/*! + * \brief Assemble measure candidates from the given candidate traces. + * \param traces The picked candidate traces. + * \return The assembled measure candidates. + */ +Array AssembleCandidates(const std::vector& picks, + const Array& args_info) { + Array measure_inputs; + measure_inputs.reserve(picks.size()); + for (const Schedule& sch : picks) { + measure_inputs.push_back(MeasureCandidate(sch, args_info)); + } + return measure_inputs; +} + +/*! + * \brief Predict the normalized score of each candidate. + * \param candidates The candidates for prediction + * \param task The search task + * \param space The search space + * \return The normalized score in the prediction + */ +std::vector PredictNormalizedScore(const std::vector& candidates, + const TuneContext& context, const CostModel& cost_model, + const Array& args_info) { + ICHECK(!candidates.empty()) << "Candidates given for score prediction can not be empty list!"; + std::vector scores = + cost_model->Predict(context, AssembleCandidates(candidates, args_info)); + for (double& score : scores) { + score = std::max(0.0, score); + } + return scores; +} + +/**************** Evolutionary Search ****************/ + +/*!\brief A search strategy that generates measure candidates using evolutionary search. */ +class EvolutionarySearchNode : public SearchStrategyNode { + public: + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + EvolutionarySearchNode* self; + /*! \brief The design spaces. Decisions are not used so traces only. */ + Array design_spaces; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. 
*/ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(EvolutionarySearchNode* self, Array design_spaces) + : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} + + /*! + * \brief Pick up best candidates from database. + * \param num The number of traces to produce. + * \return The picked best candidates. + */ + inline std::vector PickBestFromDatabase(int num); + /*! + * \brief Sample the initial population from previous measured results and randomly generated + * traces via trace replaying. + * \param num The number of traces to produce. + * \return The initial population of traces sampled. + */ + inline std::vector SampleInitPopulation(int num); + /*! + * \brief Evolve the initial population using mutators and samplers. + * \param population The initial population of traces sampled. + * \param num The number of traces to produce. + * \return The evolved traces from initial population. + */ + inline std::vector EvolveWithCostModel(std::vector population, int num); + /*! + * \brief Pick final candidates from the given initial population and bests of evolved ones. + * \param inits The initial population of traces sampled. + * \param bests The best candidates predicted from evolved traces. + * \param num The number of traces to produce. + * \return The final picked candidates with a ratio of both. + */ + inline std::vector PickWithEpsGreedy(const std::vector& inits, + const std::vector& bests, int num); + /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ + inline Optional> GenerateMeasureCandidates(); + /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ + inline void NotifyRunnerResults(const TuneContext& context, + const Array& measure_candidates, + const Array& results); + }; + + /*! \brief The tuning context of the evolutionary search strategy. 
*/ + const TuneContextNode* context_{nullptr}; + /*! \brief The target for the workload. */ + Target target_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief A Database for selecting useful candidates. */ + Database database_{nullptr}; + /*! \brief A cost model helping to explore the search space */ + CostModel cost_model_{nullptr}; + /*! \brief The postprocessors. */ + Array postprocs_{nullptr}; + /*! \brief Mutators and their probability mass */ + Map mutator_probs_{nullptr}; + /*! \brief The number of threads to use. To be initialized with TuneContext. */ + int num_threads_; + /*! \brief The random state. To be initialized with TuneContext. */ + TRandState rand_state_; + /*! \brief Pre thread data including module to be tuned and random state. */ + std::vector per_thread_data_; + /*! \brief The state of the search strategy. */ + std::unique_ptr state_ = nullptr; + /*! \brief The token registered for the given workload in database. */ + Workload token_{nullptr}; + + /*** Configuration: global ***/ + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + /*! \brief The population size in the evolutionary search. */ + int population_size; + /*** Configuration: the initial population ***/ + /*! \brief The ratio of measured states used in the initial population */ + double init_measured_ratio; + /*! \brief The maximum number to fail trace replaying. */ + int init_max_fail_count; + /*** Configuration: evolution ***/ + /*! \brief The number of iterations performed by generic algorithm. */ + int genetic_num_iters; + /*! \brief The probability to perform mutation */ + double genetic_mutate_prob; + /*! \brief The maximum number to try evolving the given trace. */ + int genetic_max_fail_count; + /*** Configuration: pick states for measurement ***/ + /*! \brief The ratio of measurements to use randomly sampled states. 
*/ + double eps_greedy; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `context_` is not visited + // `target_` is not visited + // `args_info_` is not visited + // `database` is not visited + // `cost_model` is not visited + // `postprocs` is not visited + // `mutator_probs_` is not visited + // `num_threads` is not visited + // `rand_state_` is not visited + // `per_thread_data_` is not visited + // `state_` is not visited + + /*** Configuration: global ***/ + v->Visit("num_trials_total", &num_trials_total); + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("population_size", &population_size); + /*** Configuration: the initial population ***/ + v->Visit("init_measured_ratio", &init_measured_ratio); + v->Visit("init_max_fail_count", &init_max_fail_count); + /*** Configuration: evolution ***/ + v->Visit("genetic_num_iters", &genetic_num_iters); + v->Visit("genetic_mutate_prob", &genetic_mutate_prob); + v->Visit("genetic_max_fail_count", &genetic_max_fail_count); + /*** Configuration: pick states for measurement ***/ + v->Visit("eps_greedy", &eps_greedy); + } + + static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch"; + TVM_DECLARE_FINAL_OBJECT_INFO(EvolutionarySearchNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& context) final { + CHECK(context.defined()) << "TuneContext must be defined!"; + CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0."; + CHECK(context->target.defined()) << "Target must be defined!"; + this->context_ = context.get(); + this->target_ = context->target.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value())); + this->mutator_probs_ = context->mutator_probs; + this->postprocs_ = context->postprocs; + this->num_threads_ = context->num_threads; + this->rand_state_ = ForkSeed(&context->rand_state); + this->cost_model_ = context->task_scheduler->cost_model.value(); + this->database_ = 
context->task_scheduler->database; + this->token_ = this->database_->CommitWorkload(context->mod.value()); + this->per_thread_data_.resize(this->num_threads_); + for (const auto& kv : this->mutator_probs_) { + double mass = kv.second->value; + TVM_META_SCHEDULE_CHECK_PROB_RANGE(mass, "mutator_probs"); + } + for (PerThreadData& data : this->per_thread_data_) { + data.mod = DeepCopyIRModule(context->mod.value()); + data.rand_state = ForkSeed(&this->rand_state_); + } + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(!design_spaces.empty()); + ICHECK(this->state_ == nullptr); + // Change to traces + Array design_space_traces; + design_space_traces.reserve(design_spaces.size()); + for (const Schedule& space : design_spaces) { + design_space_traces.push_back(space->trace().value()->Simplified(true)); + } + this->state_ = std::make_unique(this, design_space_traces); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const TuneContext& context, + const Array& measure_candidates, + const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(context, measure_candidates, results); + } +}; + +std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int num) { + std::vector measured_traces; + measured_traces.reserve(num); + Array top_records = self->database_->GetTopK(self->token_, num); + for (TuningRecord record : top_records) { + measured_traces.push_back(record->trace); + } + int actual_num = measured_traces.size(); + ThreadedTraceApply pp(self->postprocs_); + std::vector results(actual_num, Schedule{nullptr}); + auto f_proc_measured = [this, &measured_traces, &results, &pp](int thread_id, + int trace_id) -> void { + PerThreadData& data = 
self->per_thread_data_.at(thread_id); + TRandState* rand_state = &data.rand_state; + const IRModule& mod = data.mod; + tir::Trace trace = measured_traces.at(trace_id); + Schedule& result = results.at(trace_id); + ICHECK(!result.defined()); + if (Optional sch = pp.Apply(mod, trace, rand_state)) { + result = sch.value(); + } else { + LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; + throw; + } + }; + support::parallel_for_dynamic(0, actual_num, self->num_threads_, f_proc_measured); + return results; +} + +std::vector EvolutionarySearchNode::State::SampleInitPopulation(int num) { + ThreadedTraceApply pp(self->postprocs_); + std::vector results(num, Schedule{nullptr}); + auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void { + PerThreadData& data = self->per_thread_data_.at(thread_id); + TRandState* rand_state = &data.rand_state; + const IRModule& mod = data.mod; + Schedule& result = results.at(trace_id); + ICHECK(!result.defined()); + for (int fail_count = 0; fail_count <= self->init_max_fail_count; ++fail_count) { + int design_space_index = tir::SampleInt(rand_state, 0, design_spaces.size()); + tir::Trace trace(design_spaces[design_space_index]->insts, {}); + if (Optional sch = pp.Apply(mod, trace, rand_state)) { + result = sch.value(); + break; + } + } + if (!result.defined()) { + LOG(FATAL) << "Sample-Init-Population failed over the maximum limit! 
Summary:\n" + << pp.SummarizeFailures(); + } + }; + support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured); + LOG(INFO) << "Sample-Init-Population summary:\n" << pp.SummarizeFailures(); + return results; +} + +std::vector EvolutionarySearchNode::State::EvolveWithCostModel( + std::vector population, int num) { + ICHECK_GT(num, 0); + // The heap to record best schedule, we do not consider schedules that are already measured + // Also we use `in_heap` to make sure items in the heap are de-duplicated + SizedHeap heap(num); + for (int iter = 0;; ++iter) { + // Predict normalized score with the cost model, + std::vector scores = PredictNormalizedScore(population, // + GetRef(self->context_), // + self->cost_model_, // + self->args_info_); + ICHECK_EQ(scores.size(), population.size()); + for (int i = 0, n = population.size(); i < n; ++i) { + Schedule sch = population.at(i); + IRModule mod = sch->mod(); + double score = scores.at(i); + if (!self->database_->HasWorkload(mod)) { + heap.Push(sch, mod, score); + } + } + // Discontinue once it reaches end of search + if (iter == self->genetic_num_iters) { + break; + } + // Set threaded samplers, with probability from predicated normalized throughputs + for (PerThreadData& data : self->per_thread_data_) { + data.Set(scores, self->genetic_mutate_prob, self->mutator_probs_); + } + ThreadedTraceApply pp(self->postprocs_); + ConcurrentBitmask cbmask(self->population_size); + std::vector next_population(self->population_size, Schedule{nullptr}); + // The worker function + auto f_find_candidate = [&cbmask, &population, &next_population, &pp, this](int thread_id, + int trace_id) { + // Prepare samplers + PerThreadData& data = self->per_thread_data_.at(thread_id); + TRandState* rand_state = &data.rand_state; + const IRModule& mod = data.mod; + std::function& trace_sampler = data.trace_sampler; + std::function()>& mutator_sampler = data.mutator_sampler; + Schedule& result = next_population.at(trace_id); + int 
sampled_trace_id = -1; + // Loop until success + for (int fail_count = 0; fail_count <= self->genetic_max_fail_count; ++fail_count) { + sampled_trace_id = trace_sampler(); + tir::Trace trace = population.at(sampled_trace_id)->trace().value(); + if (Optional opt_mutator = mutator_sampler()) { + // Decision: mutate + Mutator mutator = opt_mutator.value(); + if (Optional new_trace = mutator->Apply(trace, rand_state)) { + if (Optional sch = pp.Apply(mod, new_trace.value(), rand_state)) { + // note that sch's trace is different from new_trace + // because it contains post-processing information + result = sch.value(); + break; + } + } + } else if (cbmask.QueryAndMark(sampled_trace_id)) { + // Decision: do not mutate + break; + } + } + // if retry count exceeds the limit, reuse an old sample + if (!result.defined()) { + result = population.at(sampled_trace_id); + } + }; + support::parallel_for_dynamic(0, self->population_size, self->num_threads_, f_find_candidate); + population.swap(next_population); + LOG(INFO) << "Evolve iter #" << iter << " done. 
Summary:\n" << pp.SummarizeFailures(); + } + // Return the best states from the heap, sorting from higher score to lower ones + std::sort(heap.heap.begin(), heap.heap.end()); + std::vector results; + results.reserve(num); + for (const SizedHeap::Item& item : heap.heap) { + results.push_back(item.sch); + } + + constexpr int kNumScoresPerLine = 16; + std::ostringstream os; + int n = heap.heap.size(); + for (int st = 0; st < n; st += kNumScoresPerLine) { + os << std::endl; + int ed = std::min(st + kNumScoresPerLine, n); + os << "[" << (st + 1) << " : " << ed << "]:\t"; + for (int i = st; i < ed; ++i) { + if (i != st) { + os << " "; + } + os << std::fixed << std::setprecision(4) << heap.heap.at(i).score; + } + } + LOG(INFO) << "Scores of the best " << n << " candidates:" << os.str(); + return results; +} + +std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( + const std::vector& unmeasured, const std::vector& bests, int num) { + int num_rands = num * self->eps_greedy; + int num_bests = num - num_rands; + std::vector rands = + tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size()); + std::vector results; + results.reserve(num); + for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) { + bool has_best = i_bests < static_cast(bests.size()); + bool has_rand = i_rands < static_cast(rands.size()); + // Pick a schedule + Schedule sch{nullptr}; + // If needs `bests`, then prefer `bests` + if (i < num_bests) { + if (has_best) { + sch = bests[i_bests++]; + } else if (has_rand) { + sch = unmeasured[rands[i_rands++]]; + } else { + break; + } + } else { + // Else prefer `rands` + if (has_rand) { + sch = unmeasured[rands[i_rands++]]; + } else if (has_best) { + sch = bests[i_bests++]; + } else { + break; + } + } + results.push_back(sch); + } + return results; +} + +Optional> EvolutionarySearchNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + int sample_num = self->num_trials_per_iter; 
+ if (ed > self->num_trials_total) { + sample_num = self->num_trials_total - st; + ed = self->num_trials_total; + } + ICHECK_LT(st, ed); + int pop = self->population_size; + std::vector inits; + inits.reserve(pop); + + LOG(INFO) << "Generating candidates......"; + std::vector measured = PickBestFromDatabase(pop * self->init_measured_ratio); + LOG(INFO) << "Picked top " << measured.size() << " candidate(s) from database"; + std::vector unmeasured = SampleInitPopulation(pop - measured.size()); + LOG(INFO) << "Sampled " << unmeasured.size() << " candidate(s)"; + inits.insert(inits.end(), measured.begin(), measured.end()); + inits.insert(inits.end(), unmeasured.begin(), unmeasured.end()); + ICHECK_EQ(inits.size(), self->population_size); + std::vector bests = EvolveWithCostModel(inits, sample_num); + LOG(INFO) << "Got " << bests.size() << " candidate(s) with evolutionary search"; + std::vector picks = PickWithEpsGreedy(unmeasured, bests, sample_num); + LOG(INFO) << "Sending " << picks.size() << " candidates(s) for measurement"; + return AssembleCandidates(picks, self->args_info_); +} + +void EvolutionarySearchNode::State::NotifyRunnerResults( + const TuneContext& context, const Array& measure_candidates, + const Array& results) { + st += results.size(); + ed += results.size(); +} + +SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, // + int num_trials_total, // + int population_size, // + double init_measured_ratio, // + int init_max_fail_count, // + int genetic_num_iters, // + double genetic_mutate_prob, // + int genetic_max_fail_count, // + double eps_greedy) { + TVM_META_SCHEDULE_CHECK_PROB_RANGE(init_measured_ratio, "Initial measured ratio"); + TVM_META_SCHEDULE_CHECK_PROB_RANGE(genetic_mutate_prob, "Mutation probability"); + TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick probability"); + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + n->population_size = 
population_size; + n->init_measured_ratio = init_measured_ratio; + n->init_max_fail_count = init_max_fail_count; + n->genetic_num_iters = genetic_num_iters; + n->genetic_max_fail_count = genetic_max_fail_count; + n->genetic_mutate_prob = genetic_mutate_prob; + n->eps_greedy = eps_greedy; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") + .set_body_typed(SearchStrategy::EvolutionarySearch); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 212e53aa500f..45efd9f76cef 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -36,6 +36,15 @@ namespace tir { */ TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, int32_t max_exclusive); +/*! + * \brief Sample k random integers from given range without replacement, i.e, no duplication. + * \param rand_state The pointer to schedule's random state + * \param n The range is defined as 0 to n-1. + * \param k The total number of samples. + * \return The randomly selected samples from the n candidates. + */ +std::vector SampleWithoutReplacement( + support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k); /*! * \brief Sample once category from candidates according to the probability weights. * \param rand_state The pointer to schedule's random state @@ -47,6 +56,14 @@ TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_st TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); +/*! + * \brief Create a sampling function that does multinomial sampling. + * \param rand_state The random state. + * \param weights The weights for multinomial sampling. + * \return The multinomial sampling function. 
+ */ +TVM_DLL std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights); /*! * \brief Sample the factors to perfect tile a specific loop * \param rand_state The random state diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 171838572dbb..83ef1e20be60 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -187,6 +187,28 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st return candidates[i]; } +std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights) { + ICHECK(!weights.empty()); + std::vector sums; + sums.reserve(weights.size()); + double sum = 0.0; + for (double w : weights) { + sums.push_back(sum += w); + } + return [rng = support::LinearCongruentialEngine(rand_state).ForkSeed(), + dist = std::uniform_real_distribution(0.0, sum), + sums = std::move(sums)]() mutable -> int32_t { + support::LinearCongruentialEngine rand_(&rng); + double p = dist(rand_); + int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); + int32_t n = sums.size(); + CHECK_LE(0, idx); + CHECK_LE(idx, n); + return (idx == n) ? 
(n - 1) : idx; + }; +} + std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, int32_t extent, int32_t n_splits) { CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index a4d32175eb0b..53b27ff375da 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -19,14 +19,20 @@ import sys import pytest import tvm +from tvm.ir import IRModule from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.builder import LocalBuilder +from tvm.meta_schedule.cost_model import PyCostModel +from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload +from tvm.meta_schedule.mutator.mutator import PyMutator +from tvm.meta_schedule.runner import LocalRunner, RunnerResult from tvm.meta_schedule.search_strategy import ( ReplayFunc, ReplayTrace, SearchStrategy, ) from tvm.meta_schedule.space_generator import ScheduleFn +from tvm.meta_schedule.task_scheduler import RoundRobin from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace @@ -105,5 +111,157 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl assert num_trials_each_iter == [7, 7, 6] +def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name + class DummyMutator(PyMutator): + """Dummy Mutator for testing""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + return Trace(trace.insts, {}) + + class DummyDatabase(PyDatabase): + """Dummy Database for testing""" + + def __init__(self): + super().__init__() + self.records = [] + self.workload_reg = [] + + def has_workload(self, mod: IRModule) -> bool: + for workload in 
self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True + return False + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.records.append(record) + + def commit_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return workload + workload = Workload(mod) + self.workload_reg.append(workload) + return workload + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + return list( + filter( + lambda x: x.workload == workload, + sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), + ) + )[: int(top_k)] + + def __len__(self) -> int: + return len(self.records) + + def print_results(self) -> None: + print("\n".join([str(r) for r in self.records])) + + class RandomModel(PyCostModel): + """Random cost model for testing""" + + random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] + path: Optional[str] + + def __init__( + self, + *, + seed: Optional[int] = None, + path: Optional[str] = None, + max_range: Optional[int] = 100, + ): + super().__init__() + if path is not None: + self.load(path) + else: + np.random.seed(seed) + self.random_state = np.random.get_state() + self.max_range = max_range + + def load(self, path: str) -> None: + self.random_state = tuple(np.load(path, allow_pickle=True)) + + def save(self, path: str) -> None: + np.save(path, np.array(self.random_state, dtype=object), allow_pickle=True) + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + pass + + def predict( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> np.ndarray: + np.random.set_state(self.random_state) + result = np.random.rand(len(candidates)) * self.max_range + self.random_state = np.random.get_state() + return result + + num_trials_per_iter = 10 + num_trials_total = 100 + + strategy = EvolutionarySearch( + 
num_trials_per_iter=num_trials_per_iter, + num_trials_total=num_trials_total, + population_size=5, + init_measured_ratio=0.1, + init_max_fail_count=10, + genetic_num_iters=3, + genetic_mutate_prob=0.5, + genetic_max_fail_count=10, + eps_greedy=0.9, + ) + tune_context = TuneContext( + mod=Matmul, + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + mutator_probs={ + DummyMutator(): 1.0, + }, + target=tvm.target.Target("llvm"), + num_threads=1, # because we are using a mutator from the python side + ) + _scheduler = RoundRobin( + tasks=[tune_context], + builder=LocalBuilder(), + runner=LocalRunner(), + database=DummyDatabase(), + cost_model=RandomModel(), + measure_callbacks=[], + ) + tune_context.space_generator.initialize_with_tune_context(tune_context) + spaces = tune_context.space_generator.generate_design_space(tune_context.mod) + + strategy.initialize_with_tune_context(tune_context) + strategy.pre_tuning(spaces) + (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + num_trials_each_iter: List[int] = [] + candidates = strategy.generate_measure_candidates() + while candidates is not None: + num_trials_each_iter.append(len(candidates)) + runner_results: List[RunnerResult] = [] + for candidate in candidates: + _is_trace_equal( + candidate.sch, + correct_sch, + remove_decisions=(isinstance(strategy, ReplayTrace)), + ) + runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) + strategy.notify_runner_results(tune_context, candidates, runner_results) + candidates = strategy.generate_measure_candidates() + strategy.post_tuning() + print(num_trials_each_iter) + correct_count = 10 # For each iteration except the last one + assert num_trials_each_iter == [correct_count] * (num_trials_total // correct_count) + ( + [num_trials_total % correct_count] if num_trials_total % correct_count != 0 else [] + ) + del _scheduler + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 
9d19d9b94833c20438dfe71074c74fbf9d977eac Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 4 Jan 2022 14:37:33 -0800 Subject: [PATCH 4/7] Fix optional task name. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- include/tvm/meta_schedule/tune_context.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 5a2cb99644a6..428a2e80f4dd 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -53,7 +53,7 @@ class TuneContextNode : public runtime::Object { /*! \brief The probability of using certain mutator. */ Map mutator_probs; /*! \brief The name of the tuning task. */ - String task_name; + Optional task_name; /*! \brief The random state. */ support::LinearCongruentialEngine::TRandState rand_state; /*! \brief The number of threads to be used. */ From e1e96362831ad9694e8d3c1affc518b9c2c5f821 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 4 Jan 2022 14:50:18 -0800 Subject: [PATCH 5/7] Remove extra files. 
--- python/tvm/auto_scheduler/search_task.py | 3 +- .../tvm/auto_scheduler/workload_registry.py | 5 +- .../search_strategy/evolutionary_search.py | 117 --- .../search_strategy/evolutionary_search.cc | 673 ------------------ .../test_meta_schedule_search_strategy.py | 160 +---- 5 files changed, 3 insertions(+), 955 deletions(-) delete mode 100644 python/tvm/meta_schedule/search_strategy/evolutionary_search.py delete mode 100644 src/meta_schedule/search_strategy/evolutionary_search.cc diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 0e9c4abebbe1..f1156998bdac 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -543,8 +543,7 @@ def print_best(self, log_file, print_mode="schedule"): code: str The best schedule code in python API or CUDA source code """ - inp, res = load_best_record(log_file, self.workload_key) - print("Best codes (ms):", [float(c) * 1000.0 for c in res.costs]) + inp, _ = load_best_record(log_file, self.workload_key) if inp is None: raise RuntimeError( "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file) diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 75702b0a21af..885eb0d1d0f8 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -194,10 +194,7 @@ def workload_key_to_tensors(workload_key): assert callable(value) args = deserialize_args(workload[1:]) - result = value(*args) - if isinstance(result, tuple): - result = list(result) - return result + return value(*args) def serialize_workload_registry_entry(workload_key): diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py deleted file mode 100644 index a679c1970951..000000000000 --- a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py +++ 
/dev/null @@ -1,117 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Evolutionary Search Strategy""" - -from typing import NamedTuple - -from tvm._ffi import register_object - -from .. import _ffi_api -from .search_strategy import SearchStrategy - - -@register_object("meta_schedule.EvolutionarySearch") -class EvolutionarySearch(SearchStrategy): - """ - Replay Trace Search Strategy is a search strategy that always replays the trace by removing its - decisions so that the decisions would be randomly re-generated. - - Parameters - ---------- - num_trials_per_iter : int - Number of trials per iteration. - num_trials_total : int - Total number of trials. - population_size : int - The initial population of traces from measured samples and randomly generated samples. - init_measured_ratio : int - The ratio of measured samples in the initial population. - init_max_fail_count : int - The maximum number to fail trace replaying. - genetic_num_iters : int - The number of iterations for genetic algorithm. - genetic_mutate_prob : float - The probability of mutation. - genetic_max_fail_count : int - The maximum number to retry mutation. - eps_greedy : float - The ratio of greedy selected samples in the final picks. 
- """ - - num_trials_per_iter: int - num_trials_total: int - population_size: int - init_measured_ratio: int - init_max_fail_count: int - genetic_num_iters: int - genetic_mutate_prob: float - genetic_max_fail_count: int - eps_greedy: float - - def __init__( - self, - *, - num_trials_per_iter: int, - num_trials_total: int, - population_size: int, - init_measured_ratio: float, - init_max_fail_count: int, - genetic_num_iters: int, - genetic_mutate_prob: float, - genetic_max_fail_count: int, - eps_greedy: float, - ) -> None: - """Constructor""" - self.__init_handle_by_constructor__( - _ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member - num_trials_per_iter, - num_trials_total, - population_size, - init_measured_ratio, - init_max_fail_count, - genetic_num_iters, - genetic_mutate_prob, - genetic_max_fail_count, - eps_greedy, - ) - - -class EvolutionarySearchConfig(NamedTuple): - """Configuration for EvolutionarySearch""" - - num_trials_per_iter: int - num_trials_total: int - population_size: int = 2048 - init_measured_ratio: float = 0.2 - init_max_fail_count: int = 64 - genetic_num_iters: int = 4 - genetic_mutate_prob: float = 0.85 - genetic_max_fail_count: int = 10 - eps_greedy: float = 0.05 - - def create_strategy(self) -> EvolutionarySearch: - return EvolutionarySearch( - num_trials_per_iter=self.num_trials_per_iter, - num_trials_total=self.num_trials_total, - population_size=self.population_size, - init_measured_ratio=self.init_measured_ratio, - init_max_fail_count=self.init_max_fail_count, - genetic_num_iters=self.genetic_num_iters, - genetic_mutate_prob=self.genetic_mutate_prob, - genetic_max_fail_count=self.genetic_max_fail_count, - eps_greedy=self.eps_greedy, - ) diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc deleted file mode 100644 index cb35406c1d8f..000000000000 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ /dev/null @@ 
-1,673 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "../utils.h" - -#define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name) \ - CHECK(0.0 <= (p) && (p) <= 1.0) << "ValueError: name should be within [0, 1], " \ - << "but get `" << #p << " = " << (p) << '\''; - -namespace tvm { -namespace meta_schedule { - -using tir::Schedule; - -/**************** Data Structure ****************/ - -/*! - * \brief A heap with a size up-limit. If overflow happens, it evicted the worst items. - * \note It maintains a min heap in terms of `Item::score`. Therefore, when - * overflow happens, the element evicted is the one with the min `Item::score`. - * As time goes, the elements in the heap are going to be larger. - */ -class SizedHeap { - public: - struct Item { - Schedule sch; - IRModule mod; - size_t shash; - double score; - bool operator<(const Item& other) const { return score > other.score; } - }; - - struct ItemHash { - size_t operator()(const Item& hash) const { return hash.shash; } - }; - - struct ItemEqual { - bool operator()(const Item& lhs, const Item& rhs) const { - return lhs.shash == rhs.shash && StructuralEqual()(lhs.mod, rhs.mod); - } - }; - /*! 
- * \brief Constructor - * \param size_limit The up-limit of the heap size - */ - explicit SizedHeap(int size_limit) : size_limit(size_limit) { heap.reserve(size_limit); } - - /*! - * \brief Push the specific item to the heap if its key did not appears in the heap - * \param item The item to be pushed - */ - void Push(Schedule sch, IRModule mod, double score) { - Item item{sch, mod, StructuralHash()(mod), score}; - if (!in_heap.insert(item).second) { - return; - } - int size = heap.size(); - if (size < size_limit) { - // Heap is not full, just push - heap.emplace_back(item); - std::push_heap(heap.begin(), heap.end()); - } else if (item.score > heap.front().score) { - // if the item is better than the worst one in the heap, we can safely kick it out - std::pop_heap(heap.begin(), heap.end()); - heap.back() = item; - std::push_heap(heap.begin(), heap.end()); - } - // Otherwise, the item is worse than any other element in the heap - } - - /*! \brief Up-limit of the heap size */ - int size_limit; - /*! \brief The heap, the worse the topper */ - std::vector heap; - /*! \brief The traces that are in the heap */ - std::unordered_set in_heap; -}; - -struct PerThreadData { - IRModule mod{nullptr}; - TRandState rand_state{-1}; - std::function trace_sampler = nullptr; - std::function()> mutator_sampler = nullptr; - - /*! - * \brief Set the value for the trace and mutator samplers per thread. - * \param scores The predicted score for the given samples. - * \param genetic_mutate_prob The probability of mutation. - * \param mutator_probs The probability of each mutator as a dict. - */ - void Set(const std::vector& scores, double genetic_mutate_prob, - const Map& mutator_probs) { - trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores); - mutator_sampler = MakeMutatorSampler(genetic_mutate_prob, mutator_probs, &rand_state); - } - - private: - /*! 
- * \brief Create a sampler function that picks mutators according to the mass function - * \param rand_state The random state for sampling - * \return The sampler created - */ - static std::function()> MakeMutatorSampler( - double genetic_mutate_prob, // - const Map& mutator_probs, // - TRandState* rand_state) { - std::vector> mutators; - std::vector masses; - mutators.push_back(NullOpt); - masses.push_back(1.0 - genetic_mutate_prob); - double total_mass_mutator = 0.0; - if (genetic_mutate_prob > 0) { - for (const auto& kv : mutator_probs) { - Mutator mutator = kv.first; - double mass = kv.second->value; - total_mass_mutator += mass; - mutators.push_back(mutator); - masses.push_back(mass * genetic_mutate_prob); - } - } - // Normalize the sum to 1.0 - if (total_mass_mutator == 0.0) { - masses[0] = 1.0; - for (int i = 1, n = masses.size(); i < n; ++i) { - masses[i] = 0.0; - } - } else if (total_mass_mutator != 1.0) { - for (int i = 1, n = masses.size(); i < n; ++i) { - masses[i] /= total_mass_mutator; - } - } - return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses), - mutators = std::move(mutators)]() -> Optional { - int i = idx_sampler(); - return mutators[i]; - }; - } -}; - -struct ConcurrentBitmask { - /*! The bit width. */ - static constexpr const int kBitWidth = 64; - /*! \brief The size of the concurrent bitmask. */ - int size; - /*! \brief The bitmasks. */ - std::vector bitmask; - /*! \brief The mutexes, one per kBitWidth(64 here) bitmasks. */ - std::vector mutexes; - - /*! - * \brief Constructor - * \param n The total slots managed by the concurrent bitmask. - */ - explicit ConcurrentBitmask(int n) - : size((n + kBitWidth - 1) / kBitWidth), bitmask(size, 0), mutexes(size) {} - /*! - * \brief Query and mark the given index if not visited before. - * \param x The index to concurrently check if used. If not, mark as used. - * \return Whether the index has been used before. 
- */ - bool QueryAndMark(int x) { - constexpr uint64_t one = 1; - std::unique_lock lock(mutexes[x / kBitWidth]); - if (bitmask[x / kBitWidth] & (one << (x % kBitWidth))) { - return false; - } else { - bitmask[x / kBitWidth] |= one << (x % kBitWidth); - return true; - } - } -}; - -/**************** Util Functions ****************/ - -/*! - * \brief Assemble measure candidates from the given candidate traces. - * \param traces The picked candidate traces. - * \return The assembled measure candidates. - */ -Array AssembleCandidates(const std::vector& picks, - const Array& args_info) { - Array measure_inputs; - measure_inputs.reserve(picks.size()); - for (const Schedule& sch : picks) { - measure_inputs.push_back(MeasureCandidate(sch, args_info)); - } - return measure_inputs; -} - -/*! - * \brief Predict the normalized score of each candidate. - * \param candidates The candidates for prediction - * \param task The search task - * \param space The search space - * \return The normalized score in the prediction - */ -std::vector PredictNormalizedScore(const std::vector& candidates, - const TuneContext& context, const CostModel& cost_model, - const Array& args_info) { - ICHECK(!candidates.empty()) << "Candidates given for score prediction can not be empty list!"; - std::vector scores = - cost_model->Predict(context, AssembleCandidates(candidates, args_info)); - for (double& score : scores) { - score = std::max(0.0, score); - } - return scores; -} - -/**************** Evolutionary Search ****************/ - -/*!\brief A search strategy that generates measure candidates using evolutionary search. */ -class EvolutionarySearchNode : public SearchStrategyNode { - public: - /*! \brief The state of the search strategy. */ - struct State { - /*! \brief The search strategy itself */ - EvolutionarySearchNode* self; - /*! \brief The design spaces. Decisions are not used so traces only. */ - Array design_spaces; - /*! \brief `[st, ed)` are the indices of the next batch of candidates. 
*/ - int st; - /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ - int ed; - - explicit State(EvolutionarySearchNode* self, Array design_spaces) - : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} - - /*! - * \brief Pick up best candidates from database. - * \param num The number of traces to produce. - * \return The picked best candidates. - */ - inline std::vector PickBestFromDatabase(int num); - /*! - * \brief Sample the initial population from previous measured results and randomly generated - * traces via trace replaying. - * \param num The number of traces to produce. - * \return The initial population of traces sampled. - */ - inline std::vector SampleInitPopulation(int num); - /*! - * \brief Evolve the initial population using mutators and samplers. - * \param population The initial population of traces sampled. - * \param num The number of traces to produce. - * \return The evolved traces from initial population. - */ - inline std::vector EvolveWithCostModel(std::vector population, int num); - /*! - * \brief Pick final candidates from the given initial population and bests of evolved ones. - * \param inits The initial population of traces sampled. - * \param bests The best candidates predicted from evolved traces. - * \param num The number of traces to produce. - * \return The final picked candidates with a ratio of both. - */ - inline std::vector PickWithEpsGreedy(const std::vector& inits, - const std::vector& bests, int num); - /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ - inline Optional> GenerateMeasureCandidates(); - /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ - inline void NotifyRunnerResults(const TuneContext& context, - const Array& measure_candidates, - const Array& results); - }; - - /*! \brief The tuning context of the evolutionary search strategy. 
*/ - const TuneContextNode* context_{nullptr}; - /*! \brief The target for the workload. */ - Target target_{nullptr}; - /*! \brief The metadata of the function arguments. */ - Array args_info_{nullptr}; - /*! \brief A Database for selecting useful candidates. */ - Database database_{nullptr}; - /*! \brief A cost model helping to explore the search space */ - CostModel cost_model_{nullptr}; - /*! \brief The postprocessors. */ - Array postprocs_{nullptr}; - /*! \brief Mutators and their probability mass */ - Map mutator_probs_{nullptr}; - /*! \brief The number of threads to use. To be initialized with TuneContext. */ - int num_threads_; - /*! \brief The random state. To be initialized with TuneContext. */ - TRandState rand_state_; - /*! \brief Pre thread data including module to be tuned and random state. */ - std::vector per_thread_data_; - /*! \brief The state of the search strategy. */ - std::unique_ptr state_ = nullptr; - /*! \brief The token registered for the given workload in database. */ - Workload token_{nullptr}; - - /*** Configuration: global ***/ - /*! \brief The number of trials per iteration. */ - int num_trials_per_iter; - /*! \brief The number of total trials. */ - int num_trials_total; - /*! \brief The population size in the evolutionary search. */ - int population_size; - /*** Configuration: the initial population ***/ - /*! \brief The ratio of measured states used in the initial population */ - double init_measured_ratio; - /*! \brief The maximum number to fail trace replaying. */ - int init_max_fail_count; - /*** Configuration: evolution ***/ - /*! \brief The number of iterations performed by generic algorithm. */ - int genetic_num_iters; - /*! \brief The probability to perform mutation */ - double genetic_mutate_prob; - /*! \brief The maximum number to try evolving the given trace. */ - int genetic_max_fail_count; - /*** Configuration: pick states for measurement ***/ - /*! \brief The ratio of measurements to use randomly sampled states. 
*/ - double eps_greedy; - - void VisitAttrs(tvm::AttrVisitor* v) { - // `context_` is not visited - // `target_` is not visited - // `args_info_` is not visited - // `database` is not visited - // `cost_model` is not visited - // `postprocs` is not visited - // `mutator_probs_` is not visited - // `num_threads` is not visited - // `rand_state_` is not visited - // `per_thread_data_` is not visited - // `state_` is not visited - - /*** Configuration: global ***/ - v->Visit("num_trials_total", &num_trials_total); - v->Visit("num_trials_per_iter", &num_trials_per_iter); - v->Visit("population_size", &population_size); - /*** Configuration: the initial population ***/ - v->Visit("init_measured_ratio", &init_measured_ratio); - v->Visit("init_max_fail_count", &init_max_fail_count); - /*** Configuration: evolution ***/ - v->Visit("genetic_num_iters", &genetic_num_iters); - v->Visit("genetic_mutate_prob", &genetic_mutate_prob); - v->Visit("genetic_max_fail_count", &genetic_max_fail_count); - /*** Configuration: pick states for measurement ***/ - v->Visit("eps_greedy", &eps_greedy); - } - - static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch"; - TVM_DECLARE_FINAL_OBJECT_INFO(EvolutionarySearchNode, SearchStrategyNode); - - void InitializeWithTuneContext(const TuneContext& context) final { - CHECK(context.defined()) << "TuneContext must be defined!"; - CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0."; - CHECK(context->target.defined()) << "Target must be defined!"; - this->context_ = context.get(); - this->target_ = context->target.value(); - this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value())); - this->mutator_probs_ = context->mutator_probs; - this->postprocs_ = context->postprocs; - this->num_threads_ = context->num_threads; - this->rand_state_ = ForkSeed(&context->rand_state); - this->cost_model_ = context->task_scheduler->cost_model.value(); - this->database_ = 
context->task_scheduler->database; - this->token_ = this->database_->CommitWorkload(context->mod.value()); - this->per_thread_data_.resize(this->num_threads_); - for (const auto& kv : this->mutator_probs_) { - double mass = kv.second->value; - TVM_META_SCHEDULE_CHECK_PROB_RANGE(mass, "mutator_probs"); - } - for (PerThreadData& data : this->per_thread_data_) { - data.mod = DeepCopyIRModule(context->mod.value()); - data.rand_state = ForkSeed(&this->rand_state_); - } - this->state_.reset(); - } - - void PreTuning(const Array& design_spaces) final { - ICHECK(!design_spaces.empty()); - ICHECK(this->state_ == nullptr); - // Change to traces - Array design_space_traces; - design_space_traces.reserve(design_spaces.size()); - for (const Schedule& space : design_spaces) { - design_space_traces.push_back(space->trace().value()->Simplified(true)); - } - this->state_ = std::make_unique(this, design_space_traces); - } - - void PostTuning() final { - ICHECK(this->state_ != nullptr); - this->state_.reset(); - } - - Optional> GenerateMeasureCandidates() final { - ICHECK(this->state_ != nullptr); - return this->state_->GenerateMeasureCandidates(); - } - - void NotifyRunnerResults(const TuneContext& context, - const Array& measure_candidates, - const Array& results) final { - ICHECK(this->state_ != nullptr); - this->state_->NotifyRunnerResults(context, measure_candidates, results); - } -}; - -std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int num) { - std::vector measured_traces; - measured_traces.reserve(num); - Array top_records = self->database_->GetTopK(self->token_, num); - for (TuningRecord record : top_records) { - measured_traces.push_back(record->trace); - } - int actual_num = measured_traces.size(); - ThreadedTraceApply pp(self->postprocs_); - std::vector results(actual_num, Schedule{nullptr}); - auto f_proc_measured = [this, &measured_traces, &results, &pp](int thread_id, - int trace_id) -> void { - PerThreadData& data = 
self->per_thread_data_.at(thread_id); - TRandState* rand_state = &data.rand_state; - const IRModule& mod = data.mod; - tir::Trace trace = measured_traces.at(trace_id); - Schedule& result = results.at(trace_id); - ICHECK(!result.defined()); - if (Optional sch = pp.Apply(mod, trace, rand_state)) { - result = sch.value(); - } else { - LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; - throw; - } - }; - support::parallel_for_dynamic(0, actual_num, self->num_threads_, f_proc_measured); - return results; -} - -std::vector EvolutionarySearchNode::State::SampleInitPopulation(int num) { - ThreadedTraceApply pp(self->postprocs_); - std::vector results(num, Schedule{nullptr}); - auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void { - PerThreadData& data = self->per_thread_data_.at(thread_id); - TRandState* rand_state = &data.rand_state; - const IRModule& mod = data.mod; - Schedule& result = results.at(trace_id); - ICHECK(!result.defined()); - for (int fail_count = 0; fail_count <= self->init_max_fail_count; ++fail_count) { - int design_space_index = tir::SampleInt(rand_state, 0, design_spaces.size()); - tir::Trace trace(design_spaces[design_space_index]->insts, {}); - if (Optional sch = pp.Apply(mod, trace, rand_state)) { - result = sch.value(); - break; - } - } - if (!result.defined()) { - LOG(FATAL) << "Sample-Init-Population failed over the maximum limit! 
Summary:\n" - << pp.SummarizeFailures(); - } - }; - support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured); - LOG(INFO) << "Sample-Init-Population summary:\n" << pp.SummarizeFailures(); - return results; -} - -std::vector EvolutionarySearchNode::State::EvolveWithCostModel( - std::vector population, int num) { - ICHECK_GT(num, 0); - // The heap to record best schedule, we do not consider schedules that are already measured - // Also we use `in_heap` to make sure items in the heap are de-duplicated - SizedHeap heap(num); - for (int iter = 0;; ++iter) { - // Predict normalized score with the cost model, - std::vector scores = PredictNormalizedScore(population, // - GetRef(self->context_), // - self->cost_model_, // - self->args_info_); - ICHECK_EQ(scores.size(), population.size()); - for (int i = 0, n = population.size(); i < n; ++i) { - Schedule sch = population.at(i); - IRModule mod = sch->mod(); - double score = scores.at(i); - if (!self->database_->HasWorkload(mod)) { - heap.Push(sch, mod, score); - } - } - // Discontinue once it reaches end of search - if (iter == self->genetic_num_iters) { - break; - } - // Set threaded samplers, with probability from predicated normalized throughputs - for (PerThreadData& data : self->per_thread_data_) { - data.Set(scores, self->genetic_mutate_prob, self->mutator_probs_); - } - ThreadedTraceApply pp(self->postprocs_); - ConcurrentBitmask cbmask(self->population_size); - std::vector next_population(self->population_size, Schedule{nullptr}); - // The worker function - auto f_find_candidate = [&cbmask, &population, &next_population, &pp, this](int thread_id, - int trace_id) { - // Prepare samplers - PerThreadData& data = self->per_thread_data_.at(thread_id); - TRandState* rand_state = &data.rand_state; - const IRModule& mod = data.mod; - std::function& trace_sampler = data.trace_sampler; - std::function()>& mutator_sampler = data.mutator_sampler; - Schedule& result = next_population.at(trace_id); - int 
sampled_trace_id = -1; - // Loop until success - for (int fail_count = 0; fail_count <= self->genetic_max_fail_count; ++fail_count) { - sampled_trace_id = trace_sampler(); - tir::Trace trace = population.at(sampled_trace_id)->trace().value(); - if (Optional opt_mutator = mutator_sampler()) { - // Decision: mutate - Mutator mutator = opt_mutator.value(); - if (Optional new_trace = mutator->Apply(trace, rand_state)) { - if (Optional sch = pp.Apply(mod, new_trace.value(), rand_state)) { - // note that sch's trace is different from new_trace - // because it contains post-processing information - result = sch.value(); - break; - } - } - } else if (cbmask.QueryAndMark(sampled_trace_id)) { - // Decision: do not mutate - break; - } - } - // if retry count exceeds the limit, reuse an old sample - if (!result.defined()) { - result = population.at(sampled_trace_id); - } - }; - support::parallel_for_dynamic(0, self->population_size, self->num_threads_, f_find_candidate); - population.swap(next_population); - LOG(INFO) << "Evolve iter #" << iter << " done. 
Summary:\n" << pp.SummarizeFailures(); - } - // Return the best states from the heap, sorting from higher score to lower ones - std::sort(heap.heap.begin(), heap.heap.end()); - std::vector results; - results.reserve(num); - for (const SizedHeap::Item& item : heap.heap) { - results.push_back(item.sch); - } - - constexpr int kNumScoresPerLine = 16; - std::ostringstream os; - int n = heap.heap.size(); - for (int st = 0; st < n; st += kNumScoresPerLine) { - os << std::endl; - int ed = std::min(st + kNumScoresPerLine, n); - os << "[" << (st + 1) << " : " << ed << "]:\t"; - for (int i = st; i < ed; ++i) { - if (i != st) { - os << " "; - } - os << std::fixed << std::setprecision(4) << heap.heap.at(i).score; - } - } - LOG(INFO) << "Scores of the best " << n << " candidates:" << os.str(); - return results; -} - -std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( - const std::vector& unmeasured, const std::vector& bests, int num) { - int num_rands = num * self->eps_greedy; - int num_bests = num - num_rands; - std::vector rands = - tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size()); - std::vector results; - results.reserve(num); - for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) { - bool has_best = i_bests < static_cast(bests.size()); - bool has_rand = i_rands < static_cast(rands.size()); - // Pick a schedule - Schedule sch{nullptr}; - // If needs `bests`, then prefer `bests` - if (i < num_bests) { - if (has_best) { - sch = bests[i_bests++]; - } else if (has_rand) { - sch = unmeasured[rands[i_rands++]]; - } else { - break; - } - } else { - // Else prefer `rands` - if (has_rand) { - sch = unmeasured[rands[i_rands++]]; - } else if (has_best) { - sch = bests[i_bests++]; - } else { - break; - } - } - results.push_back(sch); - } - return results; -} - -Optional> EvolutionarySearchNode::State::GenerateMeasureCandidates() { - if (st >= self->num_trials_total) { - return NullOpt; - } - int sample_num = self->num_trials_per_iter; 
- if (ed > self->num_trials_total) { - sample_num = self->num_trials_total - st; - ed = self->num_trials_total; - } - ICHECK_LT(st, ed); - int pop = self->population_size; - std::vector inits; - inits.reserve(pop); - - LOG(INFO) << "Generating candidates......"; - std::vector measured = PickBestFromDatabase(pop * self->init_measured_ratio); - LOG(INFO) << "Picked top " << measured.size() << " candidate(s) from database"; - std::vector unmeasured = SampleInitPopulation(pop - measured.size()); - LOG(INFO) << "Sampled " << unmeasured.size() << " candidate(s)"; - inits.insert(inits.end(), measured.begin(), measured.end()); - inits.insert(inits.end(), unmeasured.begin(), unmeasured.end()); - ICHECK_EQ(inits.size(), self->population_size); - std::vector bests = EvolveWithCostModel(inits, sample_num); - LOG(INFO) << "Got " << bests.size() << " candidate(s) with evolutionary search"; - std::vector picks = PickWithEpsGreedy(unmeasured, bests, sample_num); - LOG(INFO) << "Sending " << picks.size() << " candidates(s) for measurement"; - return AssembleCandidates(picks, self->args_info_); -} - -void EvolutionarySearchNode::State::NotifyRunnerResults( - const TuneContext& context, const Array& measure_candidates, - const Array& results) { - st += results.size(); - ed += results.size(); -} - -SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, // - int num_trials_total, // - int population_size, // - double init_measured_ratio, // - int init_max_fail_count, // - int genetic_num_iters, // - double genetic_mutate_prob, // - int genetic_max_fail_count, // - double eps_greedy) { - TVM_META_SCHEDULE_CHECK_PROB_RANGE(init_measured_ratio, "Initial measured ratio"); - TVM_META_SCHEDULE_CHECK_PROB_RANGE(genetic_mutate_prob, "Mutation probability"); - TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick probability"); - ObjectPtr n = make_object(); - n->num_trials_per_iter = num_trials_per_iter; - n->num_trials_total = num_trials_total; - n->population_size = 
population_size; - n->init_measured_ratio = init_measured_ratio; - n->init_max_fail_count = init_max_fail_count; - n->genetic_num_iters = genetic_num_iters; - n->genetic_max_fail_count = genetic_max_fail_count; - n->genetic_mutate_prob = genetic_mutate_prob; - n->eps_greedy = eps_greedy; - return SearchStrategy(n); -} - -TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") - .set_body_typed(SearchStrategy::EvolutionarySearch); - -} // namespace meta_schedule -} // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 53b27ff375da..a4d32175eb0b 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -19,20 +19,14 @@ import sys import pytest import tvm -from tvm.ir import IRModule from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.builder import LocalBuilder -from tvm.meta_schedule.cost_model import PyCostModel -from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload -from tvm.meta_schedule.mutator.mutator import PyMutator -from tvm.meta_schedule.runner import LocalRunner, RunnerResult +from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.search_strategy import ( ReplayFunc, ReplayTrace, SearchStrategy, ) from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.task_scheduler import RoundRobin from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace @@ -111,157 +105,5 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl assert num_trials_each_iter == [7, 7, 6] -def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name - class DummyMutator(PyMutator): - """Dummy Mutator for testing""" - - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: - pass - - 
def apply(self, trace: Trace) -> Optional[Trace]: - return Trace(trace.insts, {}) - - class DummyDatabase(PyDatabase): - """Dummy Database for testing""" - - def __init__(self): - super().__init__() - self.records = [] - self.workload_reg = [] - - def has_workload(self, mod: IRModule) -> bool: - for workload in self.workload_reg: - if tvm.ir.structural_equal(workload.mod, mod): - return True - return False - - def commit_tuning_record(self, record: TuningRecord) -> None: - self.records.append(record) - - def commit_workload(self, mod: IRModule) -> Workload: - for workload in self.workload_reg: - if tvm.ir.structural_equal(workload.mod, mod): - return workload - workload = Workload(mod) - self.workload_reg.append(workload) - return workload - - def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: - return list( - filter( - lambda x: x.workload == workload, - sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), - ) - )[: int(top_k)] - - def __len__(self) -> int: - return len(self.records) - - def print_results(self) -> None: - print("\n".join([str(r) for r in self.records])) - - class RandomModel(PyCostModel): - """Random cost model for testing""" - - random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] - path: Optional[str] - - def __init__( - self, - *, - seed: Optional[int] = None, - path: Optional[str] = None, - max_range: Optional[int] = 100, - ): - super().__init__() - if path is not None: - self.load(path) - else: - np.random.seed(seed) - self.random_state = np.random.get_state() - self.max_range = max_range - - def load(self, path: str) -> None: - self.random_state = tuple(np.load(path, allow_pickle=True)) - - def save(self, path: str) -> None: - np.save(path, np.array(self.random_state, dtype=object), allow_pickle=True) - - def update( - self, - tune_context: TuneContext, - candidates: List[MeasureCandidate], - results: List[RunnerResult], - ) -> None: - pass - - def predict( - self, tune_context: 
TuneContext, candidates: List[MeasureCandidate] - ) -> np.ndarray: - np.random.set_state(self.random_state) - result = np.random.rand(len(candidates)) * self.max_range - self.random_state = np.random.get_state() - return result - - num_trials_per_iter = 10 - num_trials_total = 100 - - strategy = EvolutionarySearch( - num_trials_per_iter=num_trials_per_iter, - num_trials_total=num_trials_total, - population_size=5, - init_measured_ratio=0.1, - init_max_fail_count=10, - genetic_num_iters=3, - genetic_mutate_prob=0.5, - genetic_max_fail_count=10, - eps_greedy=0.9, - ) - tune_context = TuneContext( - mod=Matmul, - space_generator=ScheduleFn(sch_fn=_schedule_matmul), - mutator_probs={ - DummyMutator(): 1.0, - }, - target=tvm.target.Target("llvm"), - num_threads=1, # because we are using a mutator from the python side - ) - _scheduler = RoundRobin( - tasks=[tune_context], - builder=LocalBuilder(), - runner=LocalRunner(), - database=DummyDatabase(), - cost_model=RandomModel(), - measure_callbacks=[], - ) - tune_context.space_generator.initialize_with_tune_context(tune_context) - spaces = tune_context.space_generator.generate_design_space(tune_context.mod) - - strategy.initialize_with_tune_context(tune_context) - strategy.pre_tuning(spaces) - (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) - num_trials_each_iter: List[int] = [] - candidates = strategy.generate_measure_candidates() - while candidates is not None: - num_trials_each_iter.append(len(candidates)) - runner_results: List[RunnerResult] = [] - for candidate in candidates: - _is_trace_equal( - candidate.sch, - correct_sch, - remove_decisions=(isinstance(strategy, ReplayTrace)), - ) - runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) - strategy.notify_runner_results(tune_context, candidates, runner_results) - candidates = strategy.generate_measure_candidates() - strategy.post_tuning() - print(num_trials_each_iter) - correct_count = 10 # For each 
iteration except the last one - assert num_trials_each_iter == [correct_count] * (num_trials_total // correct_count) + ( - [num_trials_total % correct_count] if num_trials_total % correct_count != 0 else [] - ) - del _scheduler - - if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 9742e459a499d1da2216b61f1d494c8157bb9c04 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 4 Jan 2022 14:57:01 -0800 Subject: [PATCH 6/7] Fix things. --- .../search_strategy/replay_trace.cc | 1 - src/tir/schedule/primitive.h | 17 -------------- src/tir/schedule/primitive/sampling.cc | 22 ------------------- 3 files changed, 40 deletions(-) diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 8c9e2d8949e9..1eac10d1ad82 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -17,7 +17,6 @@ * under the License. */ #include "../utils.h" -#include "tvm/tir/schedule/schedule.h" namespace tvm { namespace meta_schedule { diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 45efd9f76cef..212e53aa500f 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -36,15 +36,6 @@ namespace tir { */ TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, int32_t max_exclusive); -/*! - * \brief Sample k random integers from given range without replacement, i.e, no duplication. - * \param rand_state The pointer to schedule's random state - * \param n The range is defined as 0 to n-1. - * \param k The total number of samples. - * \return The randomly selected samples from the n candidates. - */ -std::vector SampleWithoutReplacement( - support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k); /*! * \brief Sample once category from candidates according to the probability weights. 
* \param rand_state The pointer to schedule's random state @@ -56,14 +47,6 @@ std::vector SampleWithoutReplacement( TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); -/*! - * \brief Create a sampling function that does multinomial sampling. - * \param rand_state The random state. - * \param weights The weights for multinomial sampling. - * \return The multinomial sampling function. - */ -TVM_DLL std::function MakeMultinomialSampler( - support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights); /*! * \brief Sample the factors to perfect tile a specific loop * \param rand_state The random state diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 83ef1e20be60..171838572dbb 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -187,28 +187,6 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st return candidates[i]; } -std::function MakeMultinomialSampler( - support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights) { - ICHECK(!weights.empty()); - std::vector sums; - sums.reserve(weights.size()); - double sum = 0.0; - for (double w : weights) { - sums.push_back(sum += w); - } - return [rng = support::LinearCongruentialEngine(rand_state).ForkSeed(), - dist = std::uniform_real_distribution(0.0, sum), - sums = std::move(sums)]() mutable -> int32_t { - support::LinearCongruentialEngine rand_(&rng); - double p = dist(rand_); - int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); - int32_t n = sums.size(); - CHECK_LE(0, idx); - CHECK_LE(idx, n); - return (idx == n) ? 
(n - 1) : idx; - }; -} - std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, int32_t extent, int32_t n_splits) { CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; From d3fa6ac384c45d7e21353239ec27813e1b955c87 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 4 Jan 2022 20:44:18 -0800 Subject: [PATCH 7/7] Add evolutionary search. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- .../meta_schedule/search_strategy/__init__.py | 1 + .../search_strategy/evolutionary_search.py | 117 +++ .../search_strategy/evolutionary_search.cc | 673 ++++++++++++++++++ src/tir/schedule/primitive.h | 17 + src/tir/schedule/primitive/sampling.cc | 22 + .../test_meta_schedule_search_strategy.py | 173 ++++- 6 files changed, 997 insertions(+), 6 deletions(-) create mode 100644 python/tvm/meta_schedule/search_strategy/evolutionary_search.py create mode 100644 src/meta_schedule/search_strategy/evolutionary_search.cc diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py index f385b72db46d..174672235b42 100644 --- a/python/tvm/meta_schedule/search_strategy/__init__.py +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -23,3 +23,4 @@ from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate from .replay_trace import ReplayTrace, ReplayTraceConfig from .replay_func import ReplayFunc, ReplayFuncConfig +from .evolutionary_search import EvolutionarySearch, EvolutionarySearchConfig diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py new file mode 100644 index 000000000000..a679c1970951 --- /dev/null +++ 
b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Evolutionary Search Strategy""" + +from typing import NamedTuple + +from tvm._ffi import register_object + +from .. import _ffi_api +from .search_strategy import SearchStrategy + + +@register_object("meta_schedule.EvolutionarySearch") +class EvolutionarySearch(SearchStrategy): + """ + Replay Trace Search Strategy is a search strategy that always replays the trace by removing its + decisions so that the decisions would be randomly re-generated. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. + population_size : int + The initial population of traces from measured samples and randomly generated samples. + init_measured_ratio : int + The ratio of measured samples in the initial population. + init_max_fail_count : int + The maximum number to fail trace replaying. + genetic_num_iters : int + The number of iterations for genetic algorithm. + genetic_mutate_prob : float + The probability of mutation. + genetic_max_fail_count : int + The maximum number to retry mutation. 
+ eps_greedy : float + The ratio of greedy selected samples in the final picks. + """ + + num_trials_per_iter: int + num_trials_total: int + population_size: int + init_measured_ratio: int + init_max_fail_count: int + genetic_num_iters: int + genetic_mutate_prob: float + genetic_max_fail_count: int + eps_greedy: float + + def __init__( + self, + *, + num_trials_per_iter: int, + num_trials_total: int, + population_size: int, + init_measured_ratio: float, + init_max_fail_count: int, + genetic_num_iters: int, + genetic_mutate_prob: float, + genetic_max_fail_count: int, + eps_greedy: float, + ) -> None: + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + population_size, + init_measured_ratio, + init_max_fail_count, + genetic_num_iters, + genetic_mutate_prob, + genetic_max_fail_count, + eps_greedy, + ) + + +class EvolutionarySearchConfig(NamedTuple): + """Configuration for EvolutionarySearch""" + + num_trials_per_iter: int + num_trials_total: int + population_size: int = 2048 + init_measured_ratio: float = 0.2 + init_max_fail_count: int = 64 + genetic_num_iters: int = 4 + genetic_mutate_prob: float = 0.85 + genetic_max_fail_count: int = 10 + eps_greedy: float = 0.05 + + def create_strategy(self) -> EvolutionarySearch: + return EvolutionarySearch( + num_trials_per_iter=self.num_trials_per_iter, + num_trials_total=self.num_trials_total, + population_size=self.population_size, + init_measured_ratio=self.init_measured_ratio, + init_max_fail_count=self.init_max_fail_count, + genetic_num_iters=self.genetic_num_iters, + genetic_mutate_prob=self.genetic_mutate_prob, + genetic_max_fail_count=self.genetic_max_fail_count, + eps_greedy=self.eps_greedy, + ) diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc new file mode 100644 index 000000000000..cb35406c1d8f --- 
/dev/null +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -0,0 +1,673 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../utils.h" + +#define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name) \ + CHECK(0.0 <= (p) && (p) <= 1.0) << "ValueError: name should be within [0, 1], " \ + << "but get `" << #p << " = " << (p) << '\''; + +namespace tvm { +namespace meta_schedule { + +using tir::Schedule; + +/**************** Data Structure ****************/ + +/*! + * \brief A heap with a size up-limit. If overflow happens, it evicted the worst items. + * \note It maintains a min heap in terms of `Item::score`. Therefore, when + * overflow happens, the element evicted is the one with the min `Item::score`. + * As time goes, the elements in the heap are going to be larger. 
+ */ +class SizedHeap { + public: + struct Item { + Schedule sch; + IRModule mod; + size_t shash; + double score; + bool operator<(const Item& other) const { return score > other.score; } + }; + + struct ItemHash { + size_t operator()(const Item& hash) const { return hash.shash; } + }; + + struct ItemEqual { + bool operator()(const Item& lhs, const Item& rhs) const { + return lhs.shash == rhs.shash && StructuralEqual()(lhs.mod, rhs.mod); + } + }; + /*! + * \brief Constructor + * \param size_limit The up-limit of the heap size + */ + explicit SizedHeap(int size_limit) : size_limit(size_limit) { heap.reserve(size_limit); } + + /*! + * \brief Push the specific item to the heap if its key did not appears in the heap + * \param item The item to be pushed + */ + void Push(Schedule sch, IRModule mod, double score) { + Item item{sch, mod, StructuralHash()(mod), score}; + if (!in_heap.insert(item).second) { + return; + } + int size = heap.size(); + if (size < size_limit) { + // Heap is not full, just push + heap.emplace_back(item); + std::push_heap(heap.begin(), heap.end()); + } else if (item.score > heap.front().score) { + // if the item is better than the worst one in the heap, we can safely kick it out + std::pop_heap(heap.begin(), heap.end()); + heap.back() = item; + std::push_heap(heap.begin(), heap.end()); + } + // Otherwise, the item is worse than any other element in the heap + } + + /*! \brief Up-limit of the heap size */ + int size_limit; + /*! \brief The heap, the worse the topper */ + std::vector heap; + /*! \brief The traces that are in the heap */ + std::unordered_set in_heap; +}; + +struct PerThreadData { + IRModule mod{nullptr}; + TRandState rand_state{-1}; + std::function trace_sampler = nullptr; + std::function()> mutator_sampler = nullptr; + + /*! + * \brief Set the value for the trace and mutator samplers per thread. + * \param scores The predicted score for the given samples. + * \param genetic_mutate_prob The probability of mutation. 
+ * \param mutator_probs The probability of each mutator as a dict. + */ + void Set(const std::vector& scores, double genetic_mutate_prob, + const Map& mutator_probs) { + trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores); + mutator_sampler = MakeMutatorSampler(genetic_mutate_prob, mutator_probs, &rand_state); + } + + private: + /*! + * \brief Create a sampler function that picks mutators according to the mass function + * \param rand_state The random state for sampling + * \return The sampler created + */ + static std::function()> MakeMutatorSampler( + double genetic_mutate_prob, // + const Map& mutator_probs, // + TRandState* rand_state) { + std::vector> mutators; + std::vector masses; + mutators.push_back(NullOpt); + masses.push_back(1.0 - genetic_mutate_prob); + double total_mass_mutator = 0.0; + if (genetic_mutate_prob > 0) { + for (const auto& kv : mutator_probs) { + Mutator mutator = kv.first; + double mass = kv.second->value; + total_mass_mutator += mass; + mutators.push_back(mutator); + masses.push_back(mass * genetic_mutate_prob); + } + } + // Normalize the sum to 1.0 + if (total_mass_mutator == 0.0) { + masses[0] = 1.0; + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] = 0.0; + } + } else if (total_mass_mutator != 1.0) { + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] /= total_mass_mutator; + } + } + return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses), + mutators = std::move(mutators)]() -> Optional { + int i = idx_sampler(); + return mutators[i]; + }; + } +}; + +struct ConcurrentBitmask { + /*! The bit width. */ + static constexpr const int kBitWidth = 64; + /*! \brief The size of the concurrent bitmask. */ + int size; + /*! \brief The bitmasks. */ + std::vector bitmask; + /*! \brief The mutexes, one per kBitWidth(64 here) bitmasks. */ + std::vector mutexes; + + /*! + * \brief Constructor + * \param n The total slots managed by the concurrent bitmask. 
+ */ + explicit ConcurrentBitmask(int n) + : size((n + kBitWidth - 1) / kBitWidth), bitmask(size, 0), mutexes(size) {} + /*! + * \brief Query and mark the given index if not visited before. + * \param x The index to concurrently check if used. If not, mark as used. + * \return Whether the index has been used before. + */ + bool QueryAndMark(int x) { + constexpr uint64_t one = 1; + std::unique_lock lock(mutexes[x / kBitWidth]); + if (bitmask[x / kBitWidth] & (one << (x % kBitWidth))) { + return false; + } else { + bitmask[x / kBitWidth] |= one << (x % kBitWidth); + return true; + } + } +}; + +/**************** Util Functions ****************/ + +/*! + * \brief Assemble measure candidates from the given candidate traces. + * \param traces The picked candidate traces. + * \return The assembled measure candidates. + */ +Array AssembleCandidates(const std::vector& picks, + const Array& args_info) { + Array measure_inputs; + measure_inputs.reserve(picks.size()); + for (const Schedule& sch : picks) { + measure_inputs.push_back(MeasureCandidate(sch, args_info)); + } + return measure_inputs; +} + +/*! + * \brief Predict the normalized score of each candidate. + * \param candidates The candidates for prediction + * \param task The search task + * \param space The search space + * \return The normalized score in the prediction + */ +std::vector PredictNormalizedScore(const std::vector& candidates, + const TuneContext& context, const CostModel& cost_model, + const Array& args_info) { + ICHECK(!candidates.empty()) << "Candidates given for score prediction can not be empty list!"; + std::vector scores = + cost_model->Predict(context, AssembleCandidates(candidates, args_info)); + for (double& score : scores) { + score = std::max(0.0, score); + } + return scores; +} + +/**************** Evolutionary Search ****************/ + +/*!\brief A search strategy that generates measure candidates using evolutionary search. 
*/ +class EvolutionarySearchNode : public SearchStrategyNode { + public: + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + EvolutionarySearchNode* self; + /*! \brief The design spaces. Decisions are not used so traces only. */ + Array design_spaces; + /*! \brief The start (inclusive) of `[st, ed)`, the indices of the next batch of candidates. */ + int st; + /*! \brief The end (exclusive) of `[st, ed)`, the indices of the next batch of candidates. */ + int ed; + + explicit State(EvolutionarySearchNode* self, Array design_spaces) + : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} + + /*! + * \brief Pick up best candidates from database. + * \param num The maximum number of top records to fetch from the database. + * \return The picked best candidates, replayed through the postprocessors. + */ + inline std::vector PickBestFromDatabase(int num); + /*! + * \brief Sample part of the initial population by replaying randomly chosen design space + * traces, without their decisions, through the postprocessors. + * \param num The number of schedules to produce. + * \return The sampled schedules. + */ + inline std::vector SampleInitPopulation(int num); + /*! + * \brief Evolve the initial population using mutators and samplers. + * \param population The initial population of schedules to evolve. + * \param num The maximum number of best-scored schedules to return. + * \return The best-scored schedules found, excluding those already in the database. + */ + inline std::vector EvolveWithCostModel(std::vector population, int num); + /*! + * \brief Pick final candidates from the given initial population and bests of evolved ones. + * \param inits The randomly sampled candidates that fill the eps-greedy share of the picks. + * \param bests The best candidates predicted from evolved traces. + * \param num The number of traces to produce. + * \return The final picked candidates with a ratio of both. + */ + inline std::vector PickWithEpsGreedy(const std::vector& inits, + const std::vector& bests, int num); + /*!
\brief An interface method to be called by its counterpart in EvolutionarySearchNode */ + inline Optional> GenerateMeasureCandidates(); + /*! \brief An interface method to be called by its counterpart in EvolutionarySearchNode */ + inline void NotifyRunnerResults(const TuneContext& context, + const Array& measure_candidates, + const Array& results); + }; + + /*! \brief The tuning context of the evolutionary search strategy. */ + const TuneContextNode* context_{nullptr}; + /*! \brief The target for the workload. */ + Target target_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief A Database for selecting useful candidates. */ + Database database_{nullptr}; + /*! \brief A cost model helping to explore the search space */ + CostModel cost_model_{nullptr}; + /*! \brief The postprocessors. */ + Array postprocs_{nullptr}; + /*! \brief Mutators and their probability mass */ + Map mutator_probs_{nullptr}; + /*! \brief The number of threads to use. To be initialized with TuneContext. */ + int num_threads_; + /*! \brief The random state. To be initialized with TuneContext. */ + TRandState rand_state_; + /*! \brief Per-thread data including module to be tuned and random state. */ + std::vector per_thread_data_; + /*! \brief The state of the search strategy; created in PreTuning and reset in PostTuning. */ + std::unique_ptr state_ = nullptr; + /*! \brief The token registered for the given workload in database. */ + Workload token_{nullptr}; + + /*** Configuration: global ***/ + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + /*! \brief The population size in the evolutionary search. */ + int population_size; + /*** Configuration: the initial population ***/ + /*! \brief The ratio of measured states used in the initial population */ + double init_measured_ratio; + /*! \brief The maximum number of failures tolerated when replaying a trace for one sample.
*/ + int init_max_fail_count; + /*** Configuration: evolution ***/ + /*! \brief The number of iterations performed by genetic algorithm. */ + int genetic_num_iters; + /*! \brief The probability to perform mutation */ + double genetic_mutate_prob; + /*! \brief The maximum number of failed attempts tolerated when evolving one trace. */ + int genetic_max_fail_count; + /*** Configuration: pick states for measurement ***/ + /*! \brief The ratio of measurements to use randomly sampled states. */ + double eps_greedy; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `context_` is not visited + // `target_` is not visited + // `args_info_` is not visited + // `database_` is not visited + // `cost_model_` is not visited + // `postprocs_` is not visited + // `mutator_probs_` is not visited + // `num_threads_` is not visited + // `rand_state_` is not visited + // `per_thread_data_` is not visited + // `state_` and `token_` are not visited + + /*** Configuration: global ***/ + v->Visit("num_trials_total", &num_trials_total); + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("population_size", &population_size); + /*** Configuration: the initial population ***/ + v->Visit("init_measured_ratio", &init_measured_ratio); + v->Visit("init_max_fail_count", &init_max_fail_count); + /*** Configuration: evolution ***/ + v->Visit("genetic_num_iters", &genetic_num_iters); + v->Visit("genetic_mutate_prob", &genetic_mutate_prob); + v->Visit("genetic_max_fail_count", &genetic_max_fail_count); + /*** Configuration: pick states for measurement ***/ + v->Visit("eps_greedy", &eps_greedy); + } + + static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch"; + TVM_DECLARE_FINAL_OBJECT_INFO(EvolutionarySearchNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& context) final { + CHECK(context.defined()) << "TuneContext must be defined!"; + CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0."; + CHECK(context->target.defined()) << "Target must be
defined!"; + this->context_ = context.get(); + this->target_ = context->target.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value())); + this->mutator_probs_ = context->mutator_probs; + this->postprocs_ = context->postprocs; + this->num_threads_ = context->num_threads; + this->rand_state_ = ForkSeed(&context->rand_state); + this->cost_model_ = context->task_scheduler->cost_model.value(); + this->database_ = context->task_scheduler->database; + this->token_ = this->database_->CommitWorkload(context->mod.value()); + this->per_thread_data_.resize(this->num_threads_); + for (const auto& kv : this->mutator_probs_) { + double mass = kv.second->value; + TVM_META_SCHEDULE_CHECK_PROB_RANGE(mass, "mutator_probs"); + } + for (PerThreadData& data : this->per_thread_data_) { + data.mod = DeepCopyIRModule(context->mod.value()); + data.rand_state = ForkSeed(&this->rand_state_); + } + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(!design_spaces.empty()); + ICHECK(this->state_ == nullptr); + // Change to traces + Array design_space_traces; + design_space_traces.reserve(design_spaces.size()); + for (const Schedule& space : design_spaces) { + design_space_traces.push_back(space->trace().value()->Simplified(true)); + } + this->state_ = std::make_unique(this, design_space_traces); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const TuneContext& context, + const Array& measure_candidates, + const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(context, measure_candidates, results); + } +}; + +std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int num) { + std::vector measured_traces; + measured_traces.reserve(num); + Array top_records
= self->database_->GetTopK(self->token_, num); + for (TuningRecord record : top_records) { + measured_traces.push_back(record->trace); + } + int actual_num = measured_traces.size(); + ThreadedTraceApply pp(self->postprocs_); + std::vector results(actual_num, Schedule{nullptr}); + auto f_proc_measured = [this, &measured_traces, &results, &pp](int thread_id, + int trace_id) -> void { + PerThreadData& data = self->per_thread_data_.at(thread_id); + TRandState* rand_state = &data.rand_state; + const IRModule& mod = data.mod; + tir::Trace trace = measured_traces.at(trace_id); + Schedule& result = results.at(trace_id); + ICHECK(!result.defined()); + if (Optional sch = pp.Apply(mod, trace, rand_state)) { + result = sch.value(); + } else { + LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; + throw; /* NOTE(review): presumably unreachable because LOG(FATAL) does not return; confirm */ + } + }; + support::parallel_for_dynamic(0, actual_num, self->num_threads_, f_proc_measured); + return results; +} + +std::vector EvolutionarySearchNode::State::SampleInitPopulation(int num) { + ThreadedTraceApply pp(self->postprocs_); + std::vector results(num, Schedule{nullptr}); + auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void { + PerThreadData& data = self->per_thread_data_.at(thread_id); + TRandState* rand_state = &data.rand_state; + const IRModule& mod = data.mod; + Schedule& result = results.at(trace_id); + ICHECK(!result.defined()); + for (int fail_count = 0; fail_count <= self->init_max_fail_count; ++fail_count) { + int design_space_index = tir::SampleInt(rand_state, 0, design_spaces.size()); + tir::Trace trace(design_spaces[design_space_index]->insts, {}); + if (Optional sch = pp.Apply(mod, trace, rand_state)) { + result = sch.value(); + break; + } + } + if (!result.defined()) { + LOG(FATAL) << "Sample-Init-Population failed over the maximum limit!
Summary:\n" + << pp.SummarizeFailures(); + } + }; + support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured); + LOG(INFO) << "Sample-Init-Population summary:\n" << pp.SummarizeFailures(); + return results; +} + +std::vector EvolutionarySearchNode::State::EvolveWithCostModel( + std::vector population, int num) { + ICHECK_GT(num, 0); + // The heap to record best schedule, we do not consider schedules that are already measured + // Also we use `in_heap` to make sure items in the heap are de-duplicated + SizedHeap heap(num); + for (int iter = 0;; ++iter) { + // Predict normalized score with the cost model. + std::vector scores = PredictNormalizedScore(population, // + GetRef(self->context_), // + self->cost_model_, // + self->args_info_); + ICHECK_EQ(scores.size(), population.size()); + for (int i = 0, n = population.size(); i < n; ++i) { + Schedule sch = population.at(i); + IRModule mod = sch->mod(); + double score = scores.at(i); + if (!self->database_->HasWorkload(mod)) { + heap.Push(sch, mod, score); + } + } + // Discontinue once it reaches end of search + if (iter == self->genetic_num_iters) { + break; + } + // Set threaded samplers, with probability from predicted normalized throughputs + for (PerThreadData& data : self->per_thread_data_) { + data.Set(scores, self->genetic_mutate_prob, self->mutator_probs_); + } + ThreadedTraceApply pp(self->postprocs_); + ConcurrentBitmask cbmask(self->population_size); + std::vector next_population(self->population_size, Schedule{nullptr}); + // The worker function + auto f_find_candidate = [&cbmask, &population, &next_population, &pp, this](int thread_id, + int trace_id) { + // Prepare samplers + PerThreadData& data = self->per_thread_data_.at(thread_id); + TRandState* rand_state = &data.rand_state; + const IRModule& mod = data.mod; + std::function& trace_sampler = data.trace_sampler; + std::function()>& mutator_sampler = data.mutator_sampler; + Schedule& result = next_population.at(trace_id); + int
sampled_trace_id = -1; + // Retry until success or until the failure limit is reached + for (int fail_count = 0; fail_count <= self->genetic_max_fail_count; ++fail_count) { + sampled_trace_id = trace_sampler(); + tir::Trace trace = population.at(sampled_trace_id)->trace().value(); + if (Optional opt_mutator = mutator_sampler()) { + // Decision: mutate + Mutator mutator = opt_mutator.value(); + if (Optional new_trace = mutator->Apply(trace, rand_state)) { + if (Optional sch = pp.Apply(mod, new_trace.value(), rand_state)) { + // note that sch's trace is different from new_trace + // because it contains post-processing information + result = sch.value(); + break; + } + } + } else if (cbmask.QueryAndMark(sampled_trace_id)) { + // Decision: do not mutate; keep this sample only if it was not already kept + break; + } + } + // no new schedule (kept unmutated, or retries exhausted): reuse the last sampled one + if (!result.defined()) { + result = population.at(sampled_trace_id); + } + }; + support::parallel_for_dynamic(0, self->population_size, self->num_threads_, f_find_candidate); + population.swap(next_population); + LOG(INFO) << "Evolve iter #" << iter << " done.
Summary:\n" << pp.SummarizeFailures(); + } + // Return the best states from the heap, sorting from higher score to lower ones + std::sort(heap.heap.begin(), heap.heap.end()); + std::vector results; + results.reserve(num); + for (const SizedHeap::Item& item : heap.heap) { + results.push_back(item.sch); + } + + constexpr int kNumScoresPerLine = 16; + std::ostringstream os; + int n = heap.heap.size(); + for (int st = 0; st < n; st += kNumScoresPerLine) { + os << std::endl; + int ed = std::min(st + kNumScoresPerLine, n); + os << "[" << (st + 1) << " : " << ed << "]:\t"; + for (int i = st; i < ed; ++i) { + if (i != st) { + os << " "; + } + os << std::fixed << std::setprecision(4) << heap.heap.at(i).score; + } + } + LOG(INFO) << "Scores of the best " << n << " candidates:" << os.str(); + return results; +} + +std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( + const std::vector& unmeasured, const std::vector& bests, int num) { + int num_rands = num * self->eps_greedy; + int num_bests = num - num_rands; + std::vector rands = + tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size()); + std::vector results; + results.reserve(num); + for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) { + bool has_best = i_bests < static_cast(bests.size()); + bool has_rand = i_rands < static_cast(rands.size()); + // Pick a schedule + Schedule sch{nullptr}; + // The first `num_bests` picks prefer `bests`, falling back to `rands` + if (i < num_bests) { + if (has_best) { + sch = bests[i_bests++]; + } else if (has_rand) { + sch = unmeasured[rands[i_rands++]]; + } else { + break; + } + } else { + // The remaining picks prefer `rands`, falling back to `bests` + if (has_rand) { + sch = unmeasured[rands[i_rands++]]; + } else if (has_best) { + sch = bests[i_bests++]; + } else { + break; + } + } + results.push_back(sch); + } + return results; +} + +Optional> EvolutionarySearchNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + int sample_num = self->num_trials_per_iter;
+ if (ed > self->num_trials_total) { + sample_num = self->num_trials_total - st; + ed = self->num_trials_total; + } + ICHECK_LT(st, ed); + int pop = self->population_size; + std::vector inits; + inits.reserve(pop); + + LOG(INFO) << "Generating candidates......"; + std::vector measured = PickBestFromDatabase(pop * self->init_measured_ratio); + LOG(INFO) << "Picked top " << measured.size() << " candidate(s) from database"; + std::vector unmeasured = SampleInitPopulation(pop - measured.size()); + LOG(INFO) << "Sampled " << unmeasured.size() << " candidate(s)"; + inits.insert(inits.end(), measured.begin(), measured.end()); + inits.insert(inits.end(), unmeasured.begin(), unmeasured.end()); + ICHECK_EQ(inits.size(), self->population_size); + std::vector bests = EvolveWithCostModel(inits, sample_num); + LOG(INFO) << "Got " << bests.size() << " candidate(s) with evolutionary search"; + std::vector picks = PickWithEpsGreedy(unmeasured, bests, sample_num); + LOG(INFO) << "Sending " << picks.size() << " candidates(s) for measurement"; + return AssembleCandidates(picks, self->args_info_); +} + +void EvolutionarySearchNode::State::NotifyRunnerResults( + const TuneContext& context, const Array& measure_candidates, + const Array& results) { + st += results.size(); + ed += results.size(); +} + +SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, // + int num_trials_total, // + int population_size, // + double init_measured_ratio, // + int init_max_fail_count, // + int genetic_num_iters, // + double genetic_mutate_prob, // + int genetic_max_fail_count, // + double eps_greedy) { + TVM_META_SCHEDULE_CHECK_PROB_RANGE(init_measured_ratio, "Initial measured ratio"); + TVM_META_SCHEDULE_CHECK_PROB_RANGE(genetic_mutate_prob, "Mutation probability"); + TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick probability"); + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + n->population_size = 
population_size; + n->init_measured_ratio = init_measured_ratio; + n->init_max_fail_count = init_max_fail_count; + n->genetic_num_iters = genetic_num_iters; + n->genetic_max_fail_count = genetic_max_fail_count; + n->genetic_mutate_prob = genetic_mutate_prob; + n->eps_greedy = eps_greedy; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") + .set_body_typed(SearchStrategy::EvolutionarySearch); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 212e53aa500f..45efd9f76cef 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -36,6 +36,15 @@ namespace tir { */ TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, int32_t max_exclusive); +/*! + * \brief Sample k random integers from given range without replacement, i.e, no duplication. + * \param rand_state The pointer to schedule's random state + * \param n The range is defined as 0 to n-1. + * \param k The total number of samples. + * \return The randomly selected samples from the n candidates. + */ +std::vector SampleWithoutReplacement( + support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k); /*! * \brief Sample once category from candidates according to the probability weights. * \param rand_state The pointer to schedule's random state @@ -47,6 +56,14 @@ TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_st TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); +/*! + * \brief Create a sampling function that does multinomial sampling. + * \param rand_state The random state. + * \param weights The weights for multinomial sampling. + * \return The multinomial sampling function. 
+ */ +TVM_DLL std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights); /*! * \brief Sample the factors to perfect tile a specific loop * \param rand_state The random state diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 171838572dbb..83ef1e20be60 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -187,6 +187,28 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st return candidates[i]; } +std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights) { + ICHECK(!weights.empty()); + std::vector sums; + sums.reserve(weights.size()); + double sum = 0.0; + for (double w : weights) { + sums.push_back(sum += w); + } + return [rng = support::LinearCongruentialEngine(rand_state).ForkSeed(), + dist = std::uniform_real_distribution(0.0, sum), + sums = std::move(sums)]() mutable -> int32_t { + support::LinearCongruentialEngine rand_(&rng); + double p = dist(rand_); + int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); + int32_t n = sums.size(); + CHECK_LE(0, idx); + CHECK_LE(idx, n); + return (idx == n) ? 
(n - 1) : idx; + }; +} + std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, int32_t extent, int32_t n_splits) { CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index a4d32175eb0b..b16eab712375 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -17,16 +17,27 @@ """ Test Meta Schedule SearchStrategy """ # pylint: disable=missing-function-docstring import sys +from typing import List, Optional, Tuple, Union + +import numpy as np import pytest import tvm +from tvm.ir import IRModule from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.builder import LocalBuilder +from tvm.meta_schedule.cost_model import PyCostModel +from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload +from tvm.meta_schedule.mutator.mutator import PyMutator +from tvm.meta_schedule.runner import LocalRunner, RunnerResult from tvm.meta_schedule.search_strategy import ( + EvolutionarySearch, + MeasureCandidate, ReplayFunc, ReplayTrace, SearchStrategy, ) from tvm.meta_schedule.space_generator import ScheduleFn +from tvm.meta_schedule.task_scheduler import RoundRobin from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace @@ -80,11 +91,11 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl num_trials_total = 20 strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) - tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) - tune_context.space_generator.initialize_with_tune_context(tune_context) - spaces = tune_context.space_generator.generate_design_space(tune_context.mod) + context = 
TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) + context.space_generator.initialize_with_tune_context(context) + spaces = context.space_generator.generate_design_space(context.mod) - strategy.initialize_with_tune_context(tune_context) + strategy.initialize_with_tune_context(context) strategy.pre_tuning(spaces) (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) num_trials_each_iter: List[int] = [] @@ -99,11 +110,161 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl remove_decisions=(isinstance(strategy, ReplayTrace)), ) runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) - strategy.notify_runner_results(tune_context, candidates, runner_results) + strategy.notify_runner_results(context, candidates, runner_results) candidates = strategy.generate_measure_candidates() strategy.post_tuning() assert num_trials_each_iter == [7, 7, 6] +def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name + class DummyMutator(PyMutator): + """Dummy Mutator for testing""" + + def initialize_with_tune_context(self, context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + return Trace(trace.insts, {}) + + class DummyDatabase(PyDatabase): + """Dummy Database for testing""" + + def __init__(self): + super().__init__() + self.records = [] + self.workload_reg = [] + + def has_workload(self, mod: IRModule) -> bool: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True + return False + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.records.append(record) + + def commit_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return workload + workload = Workload(mod) + self.workload_reg.append(workload) + return workload + + def get_top_k(self, workload: Workload, top_k: 
int) -> List[TuningRecord]: + return list( + filter( + lambda x: x.workload == workload, + sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), + ) + )[: int(top_k)] + + def __len__(self) -> int: + return len(self.records) + + def print_results(self) -> None: + print("\n".join([str(r) for r in self.records])) + + class RandomModel(PyCostModel): + """Random cost model for testing""" + + random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] + path: Optional[str] + + def __init__( + self, + *, + seed: Optional[int] = None, + path: Optional[str] = None, + max_range: Optional[int] = 100, + ): + super().__init__() + if path is not None: + self.load(path) + else: + np.random.seed(seed) + self.random_state = np.random.get_state() + self.max_range = max_range + + def load(self, path: str) -> None: + self.random_state = tuple(np.load(path, allow_pickle=True)) + + def save(self, path: str) -> None: + np.save(path, np.array(self.random_state, dtype=object), allow_pickle=True) + + def update( + self, + context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + pass + + def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: + np.random.set_state(self.random_state) + result = np.random.rand(len(candidates)) * self.max_range + self.random_state = np.random.get_state() + return result + + num_trials_per_iter = 10 + num_trials_total = 100 + + strategy = EvolutionarySearch( + num_trials_per_iter=num_trials_per_iter, + num_trials_total=num_trials_total, + population_size=5, + init_measured_ratio=0.1, + init_max_fail_count=10, + genetic_num_iters=3, + genetic_mutate_prob=0.5, + genetic_max_fail_count=10, + eps_greedy=0.9, + ) + context = TuneContext( + mod=Matmul, + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + mutator_probs={ + DummyMutator(): 1.0, + }, + target=tvm.target.Target("llvm"), + num_threads=1, # because we are using a mutator from the python side + ) + 
_scheduler = RoundRobin( + tasks=[context], + builder=LocalBuilder(), + runner=LocalRunner(), + database=DummyDatabase(), + cost_model=RandomModel(), + measure_callbacks=[], + ) + context.space_generator.initialize_with_tune_context(context) + spaces = context.space_generator.generate_design_space(context.mod) + + strategy.initialize_with_tune_context(context) + strategy.pre_tuning(spaces) + (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + num_trials_each_iter: List[int] = [] + candidates = strategy.generate_measure_candidates() + while candidates is not None: + num_trials_each_iter.append(len(candidates)) + runner_results: List[RunnerResult] = [] + for candidate in candidates: + _is_trace_equal( + candidate.sch, + correct_sch, + remove_decisions=(isinstance(strategy, ReplayTrace)), + ) + runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) + strategy.notify_runner_results(context, candidates, runner_results) + candidates = strategy.generate_measure_candidates() + strategy.post_tuning() + print(num_trials_each_iter) + correct_count = 10 # For each iteration except the last one + assert num_trials_each_iter == [correct_count] * (num_trials_total // correct_count) + ( + [num_trials_total % correct_count] if num_trials_total % correct_count != 0 else [] + ) + del _scheduler + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))