From 3c4e9c179f1354ace88fc5d976828e8bdd68da8c Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 21 Dec 2021 22:27:30 -0800 Subject: [PATCH 1/4] Modify TuneContext, TaskScheduler & SearchStrategy functions. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- include/tvm/meta_schedule/cost_model.h | 20 +-- include/tvm/meta_schedule/feature_extractor.h | 12 +- include/tvm/meta_schedule/mutator.h | 146 +++++++++++++++ include/tvm/meta_schedule/postproc.h | 167 ++++++++++++++++++ include/tvm/meta_schedule/search_strategy.h | 46 ++++- include/tvm/meta_schedule/task_scheduler.h | 16 +- include/tvm/meta_schedule/tune_context.h | 31 +++- include/tvm/support/random_engine.h | 8 + python/tvm/auto_scheduler/search_task.py | 3 +- .../tvm/auto_scheduler/workload_registry.py | 5 +- .../meta_schedule/cost_model/cost_model.py | 22 ++- .../meta_schedule/cost_model/random_model.py | 8 +- .../feature_extractor/feature_extractor.py | 10 +- .../random_feature_extractor.py | 2 +- python/tvm/meta_schedule/mutator/__init__.py | 22 +++ python/tvm/meta_schedule/mutator/mutator.py | 88 +++++++++ python/tvm/meta_schedule/postproc/__init__.py | 18 ++ python/tvm/meta_schedule/postproc/postproc.py | 90 ++++++++++ .../schedule_rule/schedule_rule.py | 10 +- .../search_strategy/replay_trace.py | 13 +- .../search_strategy/search_strategy.py | 41 +++-- .../space_generator/schedule_fn.py | 4 +- .../space_generator/space_generator.py | 10 +- .../task_scheduler/round_robin.py | 26 ++- .../task_scheduler/task_scheduler.py | 20 ++- python/tvm/meta_schedule/tune_context.py | 27 ++- src/meta_schedule/cost_model/cost_model.cc | 4 +- .../search_strategy/replay_trace.cc | 64 ++++--- .../task_scheduler/round_robin.cc | 15 +- .../task_scheduler/task_scheduler.cc | 126 ++++++------- src/meta_schedule/tune_context.cc | 45 +++-- src/meta_schedule/utils.h | 78 +++++++- src/tir/schedule/concrete_schedule.cc | 2 +- src/tir/schedule/traced_schedule.cc | 2 +- .../unittest/test_meta_schedule_cost_model.py | 12 +- .../test_meta_schedule_feature_extractor.py | 4 +- .../test_meta_schedule_post_order_apply.py | 10 +- .../test_meta_schedule_search_strategy.py | 73 ++++---- .../test_meta_schedule_task_scheduler.py | 77 +++++--- 39 files changed, 1118 insertions(+), 259 deletions(-) create mode 100644 include/tvm/meta_schedule/mutator.h create mode 100644 include/tvm/meta_schedule/postproc.h create mode 100644 python/tvm/meta_schedule/mutator/__init__.py create mode 100644 python/tvm/meta_schedule/mutator/mutator.py create mode 100644 python/tvm/meta_schedule/postproc/__init__.py create mode 100644 python/tvm/meta_schedule/postproc/postproc.py diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h index b05dc3c11802..6fadc2fb9c13 100644 --- a/include/tvm/meta_schedule/cost_model.h +++ b/include/tvm/meta_schedule/cost_model.h @@ -51,20 +51,20 @@ class CostModelNode : public runtime::Object { /*! * \brief Update the cost model given running results. - * \param tune_context The tuning context. + * \param context The tuning context. * \param candidates The measure candidates. * \param results The running results of the measure candidates. */ - virtual void Update(const TuneContext& tune_context, const Array& candidates, + virtual void Update(const TuneContext& context, const Array& candidates, const Array& results) = 0; /*! 
   * \brief Predict the normalized score (the larger the better) of given measure candidates.
-   * \param tune_context The tuning context.
+   * \param context The tuning context.
   * \param candidates The measure candidates.
   * \return The predicted normalized score.
   */
-  virtual std::vector<double> Predict(const TuneContext& tune_context,
+  virtual std::vector<double> Predict(const TuneContext& context,
                                      const Array<MeasureCandidate>& candidates) = 0;
 
   static constexpr const char* _type_key = "meta_schedule.CostModel";
@@ -86,7 +86,7 @@ class PyCostModelNode : public CostModelNode {
   using FSave = runtime::TypedPackedFunc<void(String)>;
   /*!
    * \brief Update the cost model given running results.
-   * \param tune_context The tuning context.
+   * \param context The tuning context.
    * \param candidates The measure candidates.
    * \param results The running results of the measure candidates.
    * \return Whether cost model was updated successfully.
@@ -95,7 +95,7 @@ class PyCostModelNode : public CostModelNode {
                                               const Array<RunnerResult>&)>;
   /*!
    * \brief Predict the running results of given measure candidates.
-   * \param tune_context The tuning context.
+   * \param context The tuning context.
    * \param candidates The measure candidates.
-   * \param p_addr The address to save the the estimated running results.
+   * \param p_addr The address to save the estimated running results.
    */
@@ -135,17 +135,17 @@ class PyCostModelNode : public CostModelNode {
     ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!";
     f_save(path);
   }
-  void Update(const TuneContext& tune_context, const Array<MeasureCandidate>& candidates,
+  void Update(const TuneContext& context, const Array<MeasureCandidate>& candidates,
               const Array<RunnerResult>& results) {
     ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!";
-    f_update(tune_context, candidates, results);
+    f_update(context, candidates, results);
   }
-  std::vector<double> Predict(const TuneContext& tune_context,
+  std::vector<double> Predict(const TuneContext& context,
                              const Array<MeasureCandidate>& candidates) {
     ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!";
     std::vector<double> result(candidates.size(), 0.0);
-    f_predict(tune_context, candidates, result.data());
+    f_predict(context, candidates, result.data());
     return result;
   }
 
diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h
index ee5d94c13c98..c2ca2beb9b68 100644
--- a/include/tvm/meta_schedule/feature_extractor.h
+++ b/include/tvm/meta_schedule/feature_extractor.h
@@ -37,11 +37,11 @@ class FeatureExtractorNode : public runtime::Object {
 
   /*!
    * \brief Extract features from the given measure candidate.
-   * \param tune_context The tuning context for feature extraction.
+   * \param context The tuning context for feature extraction.
    * \param candidates The measure candidates to extract features from.
    * \return The feature ndarray extracted.
    */
-  virtual Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
+  virtual Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& context,
                                                    const Array<MeasureCandidate>& candidates) = 0;
 
   static constexpr const char* _type_key = "meta_schedule.FeatureExtractor";
@@ -53,12 +53,12 @@ class PyFeatureExtractorNode : public FeatureExtractorNode {
  public:
   /*!
    * \brief Extract features from the given measure candidate.
-   * \param tune_context The tuning context for feature extraction.
+   * \param context The tuning context for feature extraction.
    * \param candidates The measure candidates to extract features from.
    * \return The feature ndarray extracted.
    */
   using FExtractFrom = runtime::TypedPackedFunc<Array<tvm::runtime::NDArray>(
-      const TuneContext& tune_context, const Array<MeasureCandidate>& candidates)>;
+      const TuneContext& context, const Array<MeasureCandidate>& candidates)>;
   /*!
* \brief Get the feature extractor as string with name. * \return The string of the feature extractor. @@ -75,10 +75,10 @@ class PyFeatureExtractorNode : public FeatureExtractorNode { // `f_as_string` is not visited } - Array ExtractFrom(const TuneContext& tune_context, + Array ExtractFrom(const TuneContext& context, const Array& candidates) { ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!"; - return f_extract_from(tune_context, candidates); + return f_extract_from(context, candidates); } static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor"; diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h new file mode 100644 index 000000000000..e3fa847c3748 --- /dev/null +++ b/include/tvm/meta_schedule/mutator.h @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_MUTATOR_H_ +#define TVM_META_SCHEDULE_MUTATOR_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Mutator is designed to mutate the trace to explore the design space. */ +class MutatorNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~MutatorNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Initialize the design space generator with tuning context. + * \param context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply the mutator function to the given trace. + * \param trace The given trace for mutation. + * \param rand_state The random state for mutation. + * \return None if mutator failed, otherwise return the mutated trace. + */ + virtual Optional Apply(const tir::Trace& trace, + support::LinearCongruentialEngine::TRandState* rand_state) = 0; + + static constexpr const char* _type_key = "meta_schedule.Mutator"; + TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); +}; + +/*! \brief The mutator with customized methods on the python-side. */ +class PyMutatorNode : public MutatorNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief Apply the mutator function to the given trace. + * \param trace The given trace for mutation. + * \return None if mutator failed, otherwise return the mutated trace. + */ + using FApply = runtime::TypedPackedFunc( + const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>; + /*! + * \brief Get the mutator as string with name. 
+ * \return The string of the mutator. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyMutator's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + Optional Apply(const tir::Trace& trace, + support::LinearCongruentialEngine::TRandState* rand_state) final { + ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; + return this->f_apply(trace, *rand_state); + } + + static constexpr const char* _type_key = "meta_schedule.PyMutator"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode); +}; + +/*! + * \brief Managed reference to MutatorNode + * \sa MutatorNode + */ +class Mutator : public runtime::ObjectRef { + public: + /*! \brief Create a Mutator that mutates the tile size. */ + TVM_DLL static Mutator MutateTileSize(); + /*! + * \brief Create a Mutator that mutates the parallel extent + * \param max_jobs_per_core The maximum number of parallel jobs per core. + * \return The created mutator. + */ + TVM_DLL static Mutator MutateParallel(int64_t max_jobs_per_core); + /*! \brief Create a Mutator that mutates auto unroll step */ + TVM_DLL static Mutator MutateUnroll(); + /*! + * \brief Create a Mutator that mutates the outcome of SampleComputeLocation + * \return The mutator created + */ + TVM_DLL static Mutator MutateComputeLocation(); + /*! + * \brief Create a mutator with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. + * \return The mutator created. + */ + TVM_DLL static Mutator PyMutator( + PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyMutatorNode::FApply f_apply, // + PyMutatorNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_MUTATOR_H_ diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h new file mode 100644 index 000000000000..93e8be0cd129 --- /dev/null +++ b/include/tvm/meta_schedule/postproc.h @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_POSTPROC_H_ +#define TVM_META_SCHEDULE_POSTPROC_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! + * \brief Rules to apply a postprocessor to a schedule. + */ +class PostprocNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~PostprocNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Initialize the design space generator with tuning context. + * \param context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply a postprocessor to the given schedule. + * \param sch The schedule to be post processed. + * \return Whether the postprocessor was successfully applied. + */ + virtual bool Apply(const tir::Schedule& sch) = 0; + + static constexpr const char* _type_key = "meta_schedule.Postproc"; + TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); +}; + +/*! \brief The postprocessor with customized methods on the python-side. */ +class PyPostprocNode : public PostprocNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief Apply a postprocessor to the given schedule. + * \param sch The schedule to be post processed. + * \return Whether the postprocessor was successfully applied. + */ + using FApply = runtime::TypedPackedFunc; + /*! + * \brief Get the postprocessor function as string with name. + * \return The string of the postprocessor function. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyPostproc's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + bool Apply(const tir::Schedule& sch) final { + ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; + return this->f_apply(sch); + } + + static constexpr const char* _type_key = "meta_schedule.PyPostproc"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); +}; + +/*! + * \brief Managed reference to PostprocNode + * \sa PostprocNode + */ +class Postproc : public runtime::ObjectRef { + public: + /*! + * \brief Create a postprocessor with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. + * \return The postprocessor created. 
+ */ + TVM_DLL static Postproc PyPostproc( + PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyPostprocNode::FApply f_apply, // + PyPostprocNode::FAsString f_as_string); + /*! + * \brief Create a postprocessor that checks if all loops are static + * \return The postprocessor created + */ + TVM_DLL static Postproc DisallowDynamicLoop(); + /*! + * \brief Create a postprocessor that rewrites the cooperative fetch annotation to + * actual vectorized cooperative fetching in loop bindings. + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteCooperativeFetch(); + /*! + * \brief Creates a postprocessor that applies parallelization, vectorization and auto unrolling + * according to the annotation of each block + * \return The postprocessor created + */ + TVM_DLL static Postproc RewriteParallelVectorizeUnroll(); + /*! + * \brief Create a postprocessor that rewrites reduction block by moving the init block out. + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteReductionBlock(); + /*! + * \brief Create a postprocessor that adds thread binding to unbound blocks + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteUnboundBlock(); + /*! + * \brief Create a postprocessor that tensorize Tensor Core related components + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteTensorCore(); + + /*! + * \brief Creates a postprocessor that verifies if the GPU code is correct + * \return The postprocessor created + */ + TVM_DLL static Postproc VerifyGPUCode(); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_POSTPROC_H_ diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index e1c68c8a1a11..e8d2cbd70382 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -28,6 +28,8 @@ namespace meta_schedule { // Forward declaration class TuneContext; +class CostModel; +class Database; /*! \brief The schedule (with input shapes) to be measured. */ class MeasureCandidateNode : public runtime::Object { @@ -133,9 +135,13 @@ class SearchStrategyNode : public runtime::Object { /*! * \brief Update the search strategy with measurement results. + * \param context The tuning context. + * \param measure_candidates The candidates to be measured. * \param results The measurement results from the runner. */ - virtual void NotifyRunnerResults(const Array& results) = 0; + virtual void NotifyRunnerResults(const TuneContext& context, + const Array& measure_candidates, + const Array& results) = 0; static constexpr const char* _type_key = "meta_schedule.SearchStrategy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object); @@ -165,7 +171,8 @@ class PySearchStrategyNode : public SearchStrategyNode { * \brief The function type of `NotifyRunnerResults` method. * \param results The measurement results from the runner. */ - using FNotifyRunnerResults = runtime::TypedPackedFunc&)>; + using FNotifyRunnerResults = runtime::TypedPackedFunc&, const Array&)>; /*! \brief The packed function to the `InitializeWithTuneContext` method. 
   */
  FInitializeWithTuneContext f_initialize_with_tune_context;
@@ -208,10 +215,12 @@ class PySearchStrategyNode : public SearchStrategyNode {
     return this->f_generate_measure_candidates();
   }
 
-  void NotifyRunnerResults(const Array<RunnerResult>& results) final {
+  void NotifyRunnerResults(const TuneContext& context,
+                           const Array<MeasureCandidate>& measure_candidates,
+                           const Array<RunnerResult>& results) final {
     ICHECK(f_notify_runner_results != nullptr)
         << "PySearchStrategy's NotifyRunnerResults method not implemented!";
-    this->f_notify_runner_results(results);
+    this->f_notify_runner_results(context, measure_candidates, results);
   }
 
   static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
@@ -247,6 +256,35 @@ class SearchStrategy : public runtime::ObjectRef {
    */
   TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total);
 
+  /*!
+   * \brief Constructor of replay func search strategy.
+   * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
+   * \param num_trials_total The total number of trials for func replaying.
+   */
+  TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total);
+
+  /*!
+   * \brief Constructor of evolutionary search strategy.
+   * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
+   * \param num_trials_total The total number of trials for evolutionary search.
+   * \param population_size The size of the initial population.
+   * \param init_measured_ratio The ratio of measured samples in the initial population.
+   * \param init_max_fail_count The maximum number of failures allowed when replaying traces
+   *  for the initial population.
+   * \param genetic_num_iters The number of iterations to run the genetic algorithm.
+   * \param genetic_mutate_prob The probability of mutation.
+   * \param genetic_max_fail_count The maximum number of failures allowed when evolving a trace.
+   * \param eps_greedy The ratio of samples selected greedily by their predicted score.
+   */
+  TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter,     //
+                                                   int num_trials_total,        //
+                                                   int population_size,         //
+                                                   double init_measured_ratio,  //
+                                                   int init_max_fail_count,     //
+                                                   int genetic_num_iters,       //
+                                                   double genetic_mutate_prob,  //
+                                                   int genetic_max_fail_count,  //
+                                                   double eps_greedy);
+
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
 };
 
diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h
index f28c33dc4fe4..0284a55e0d03 100644
--- a/include/tvm/meta_schedule/task_scheduler.h
+++ b/include/tvm/meta_schedule/task_scheduler.h
@@ -20,7 +20,9 @@
 #define TVM_META_SCHEDULE_TASK_SCHEDULER_H_
 
 #include <tvm/meta_schedule/builder.h>
+#include <tvm/meta_schedule/cost_model.h>
 #include <tvm/meta_schedule/database.h>
+#include <tvm/meta_schedule/measure_callback.h>
 #include <tvm/meta_schedule/runner.h>
 #include <tvm/meta_schedule/tune_context.h>
 
@@ -78,7 +80,7 @@ class TaskSchedulerNode : public runtime::Object {
   /*! \brief The list of measure callbacks of the scheduler. */
   Array<MeasureCallback> measure_callbacks;
 
-  /*! \brief The default desctructor. */
+  /*! \brief The default destructor. */
   virtual ~TaskSchedulerNode() = default;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
@@ -248,15 +250,19 @@ class TaskScheduler : public runtime::ObjectRef {
   * \param runner The runner of the scheduler.
   * \param database The database of the scheduler.
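+  * \param cost_model The cost model of the scheduler.
+  * \param measure_callbacks The list of measure callbacks of the scheduler.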
*/ - TVM_DLL static TaskScheduler RoundRobin(Array tasks, // - Builder builder, // - Runner runner, // - Database database); // + TVM_DLL static TaskScheduler RoundRobin(Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Optional cost_model, // + Optional> measure_callbacks); TVM_DLL static TaskScheduler PyTaskScheduler( Array tasks, // Builder builder, // Runner runner, // Database database, // + Optional cost_model, // + Optional> measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 6eacd4d4f12a..ff3a14c076e4 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -20,7 +20,12 @@ #define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ #include +#include +#include +#include +#include #include +#include #include #include #include @@ -28,6 +33,8 @@ namespace tvm { namespace meta_schedule { +class TaskSchedulerNode; + /*! \brief The auto tuning context. */ class TuneContextNode : public runtime::Object { public: @@ -41,19 +48,27 @@ class TuneContextNode : public runtime::Object { Optional search_strategy; /*! \brief The schedule rules. */ Array sch_rules; + /*! \brief The postprocessors. */ + Array postprocs; + /*! \brief The probability of using certain mutator. */ + Map mutator_probs; /*! \brief The name of the tuning task. */ - Optional task_name; + String task_name; /*! \brief The random state. */ support::LinearCongruentialEngine::TRandState rand_state; /*! \brief The number of threads to be used. */ int num_threads; + /*! \brief The task scheduler that owns the tune context */ + const TaskSchedulerNode* task_scheduler; /*! \brief Whether the tuning task has been stopped or finished. */ bool is_stopped; - /*! \brief Packed functions to fetch the runner results asynchronously. */ - Optional> runner_futures; /*! \brief The measure candidates. */ Optional> measure_candidates; + /*! \brief The building results. */ + Optional> builder_results; + /*! \brief Packed functions to fetch the runner results asynchronously. */ + Optional> runner_futures; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("mod", &mod); @@ -61,14 +76,20 @@ class TuneContextNode : public runtime::Object { v->Visit("space_generator", &space_generator); v->Visit("search_strategy", &search_strategy); v->Visit("sch_rules", &sch_rules); + v->Visit("postprocs", &postprocs); + v->Visit("mutator_probs", &mutator_probs); v->Visit("task_name", &task_name); v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); v->Visit("is_stopped", &is_stopped); + v->Visit("builder_results", &builder_results); v->Visit("runner_futures", &runner_futures); v->Visit("measure_candidates", &measure_candidates); } + /*! \brief Initialize members that needs initialization with tune context. */ + void Initialize(); + static constexpr const char* _type_key = "meta_schedule.TuneContext"; TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object); }; @@ -86,6 +107,8 @@ class TuneContext : public runtime::ObjectRef { * \param space_generator The design space generator. * \param search_strategy The search strategy. * \param sch_rules The schedule rules. + * \param postprocs The postprocessors. + * \param mutator_probs The probability of using certain mutator. * \param task_name The name of the tuning task. * \param rand_state The random state. 
* \param num_threads The number of threads to be used. @@ -95,6 +118,8 @@ class TuneContext : public runtime::ObjectRef { Optional space_generator, // Optional search_strategy, // Optional> sch_rules, // + Optional> postprocs, // + Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads); diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index fcd2326050ed..89b1e9117ff4 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -29,6 +29,7 @@ #include #include // for uint64_t +#include namespace tvm { namespace support { @@ -73,6 +74,12 @@ class LinearCongruentialEngine { */ static constexpr result_type max() { return modulus - 1; } + /*! + * \brief Get a device random state + * \return The random state + */ + static TRandState DeviceRandom() { return (std::random_device()()) % modulus; } + /*! * \brief Operator to move the random state to the next and return the new random state. According * to definition of linear congruential engine, the new random state value is computed as @@ -93,6 +100,7 @@ class LinearCongruentialEngine { * \param rand_state The random state given in result_type. */ void Seed(TRandState rand_state = 1) { + ICHECK(rand_state != -1) << "The seed can't be -1 which should be changed to random seed!"; rand_state %= modulus; // Make sure the seed is within the range of modulus. if (rand_state == 0) rand_state = 1; // Avoid getting all 0 given the current parameter set. diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index f1156998bdac..0e9c4abebbe1 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -543,7 +543,8 @@ def print_best(self, log_file, print_mode="schedule"): code: str The best schedule code in python API or CUDA source code """ - inp, _ = load_best_record(log_file, self.workload_key) + inp, res = load_best_record(log_file, self.workload_key) + print("Best codes (ms):", [float(c) * 1000.0 for c in res.costs]) if inp is None: raise RuntimeError( "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file) diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 885eb0d1d0f8..75702b0a21af 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -194,7 +194,10 @@ def workload_key_to_tensors(workload_key): assert callable(value) args = deserialize_args(workload[1:]) - return value(*args) + result = value(*args) + if isinstance(result, tuple): + result = list(result) + return result def serialize_workload_registry_entry(workload_key): diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index f5bd60162ec5..f794b11471d9 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -55,7 +55,7 @@ def save(self, path: str) -> None: def update( self, - tune_context: TuneContext, + context: TuneContext, candidates: List[MeasureCandidate], results: List[RunnerResult], ) -> None: @@ -63,21 +63,21 @@ def update( Parameters ---------- - tune_context : TuneContext, + context : TuneContext, The tuning context. candidates : List[MeasureCandidate] The measure candidates. results : List[RunnerResult] The running results of the measure candidates. 
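+
+        Examples
+        --------
+        Typically invoked once per measurement round (editor's sketch; the
+        three arguments are assumed to come from the running tuning loop)::
+
+            cost_model.update(context, candidates, runner_results)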
""" - _ffi_api.CostModelUpdate(self, tune_context, candidates, results) # type: ignore # pylint: disable=no-member + _ffi_api.CostModelUpdate(self, context, candidates, results) # type: ignore # pylint: disable=no-member - def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: + def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: """Update the cost model given running results. Parameters ---------- - tune_context : TuneContext, + context : TuneContext, The tuning context. candidates : List[MeasureCandidate] The measure candidates. @@ -91,7 +91,7 @@ def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) results = np.zeros(shape=(n,), dtype="float64") _ffi_api.CostModelPredict( # type: ignore # pylint: disable=no-member self, - tune_context, + context, candidates, results.ctypes.data_as(ctypes.c_void_p), ) @@ -115,20 +115,18 @@ def f_save(path: str) -> None: @check_override(self.__class__, CostModel) def f_update( - tune_context: TuneContext, + context: TuneContext, candidates: List[MeasureCandidate], results: List[RunnerResult], ) -> None: - self.update(tune_context, candidates, results) + self.update(context, candidates, results) @check_override(self.__class__, CostModel) - def f_predict( - tune_context: TuneContext, candidates: List[MeasureCandidate], return_ptr - ) -> None: + def f_predict(context: TuneContext, candidates: List[MeasureCandidate], return_ptr) -> None: n = len(candidates) return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_double)) array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) - array_wrapper[:] = self.predict(tune_context, candidates) + array_wrapper[:] = self.predict(context, candidates) assert ( array_wrapper.dtype == "float64" ), "ValueError: Invalid data type returned from CostModel Predict!" diff --git a/python/tvm/meta_schedule/cost_model/random_model.py b/python/tvm/meta_schedule/cost_model/random_model.py index 23238d25797c..8808476aba15 100644 --- a/python/tvm/meta_schedule/cost_model/random_model.py +++ b/python/tvm/meta_schedule/cost_model/random_model.py @@ -84,7 +84,7 @@ def save(self, path: str) -> None: def update( self, - tune_context: TuneContext, + context: TuneContext, candidates: List[MeasureCandidate], results: List[RunnerResult], ) -> None: @@ -92,7 +92,7 @@ def update( Parameters ---------- - tune_context : TuneContext, + context : TuneContext, The tuning context. candidates : List[MeasureCandidate] The measure candidates. @@ -100,12 +100,12 @@ def update( The running results of the measure candidates. """ - def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: + def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: """Update the cost model given running results. Parameters ---------- - tune_context : TuneContext, + context : TuneContext, The tuning context. candidates : List[MeasureCandidate] The measure candidates. 
diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py index bd7656e5bef1..5043d4beca4f 100644 --- a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py +++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py @@ -32,13 +32,13 @@ class FeatureExtractor(Object): """Extractor for features from measure candidates for use in cost model.""" def extract_from( - self, tune_context: TuneContext, candidates: List[MeasureCandidate] + self, context: TuneContext, candidates: List[MeasureCandidate] ) -> List[NDArray]: """Extract features from the given measure candidate. Parameters ---------- - tune_context : TuneContext + context : TuneContext The tuning context for feature extraction. candidates : List[MeasureCandidate] The measure candidates to extract features from. @@ -49,7 +49,7 @@ def extract_from( The feature numpy ndarray extracted. """ result = _ffi_api.FeatureExtractorExtractFrom( # type: ignore # pylint: disable=no-member - self, tune_context, candidates + self, context, candidates ) return result @@ -63,9 +63,9 @@ def __init__(self): @check_override(self.__class__, FeatureExtractor) def f_extract_from( - tune_context: TuneContext, candidates: List[MeasureCandidate] + context: TuneContext, candidates: List[MeasureCandidate] ) -> List[NDArray]: - features = self.extract_from(tune_context, candidates) + features = self.extract_from(context, candidates) return features def f_as_string() -> str: diff --git a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py index 7c72a25b2378..d805648bfbfd 100644 --- a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py +++ b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py @@ -51,7 +51,7 @@ def __init__(self, *, feature_size: int = 30, max_block_num: int = 5, seed=0): self.random_state = np.random.get_state() def extract_from( - self, tune_context: TuneContext, candidates: List[MeasureCandidate] + self, context: TuneContext, candidates: List[MeasureCandidate] ) -> List[NDArray]: np.random.set_state(self.random_state) result = [ diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py new file mode 100644 index 000000000000..f88043b4b4fd --- /dev/null +++ b/python/tvm/meta_schedule/mutator/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.mutator package. +Meta Schedule mutator that mutates the trace to explore the +design space. 
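+
+A minimal user-defined mutator can be sketched as follows (illustrative
+only; a real mutator would perturb the trace rather than return it
+unchanged):
+
+.. code-block:: python
+
+    from typing import Optional
+    from tvm.tir.schedule import Trace
+    from tvm.meta_schedule.mutator import PyMutator
+
+    class IdentityMutator(PyMutator):
+        def initialize_with_tune_context(self, context) -> None:
+            pass
+
+        def apply(self, trace: Trace) -> Optional[Trace]:
+            return trace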
+""" +from .mutator import Mutator, PyMutator diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py new file mode 100644 index 000000000000..80e0f6620816 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule Mutator.""" +from typing import Optional, TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Trace + +from .. import _ffi_api +from ..utils import _get_hex_address, check_override + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +class Mutator(Object): + """Mutator is designed to mutate the trace to explore the design space.""" + + def initialize_with_tune_context(self, context: "TuneContext") -> None: + """Initialize the mutator with a tune context. + + Parameters + ---------- + context : TuneContext + The tuning context for initializing the mutator. + """ + _ffi_api.MutatorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, context + ) + + def apply(self, trace: Trace) -> Optional[Trace]: + """Apply the mutator function to the given trace. + + Parameters + ---------- + trace : Trace + The given trace for mutation. + + Returns + ------- + trace : Optional[Trace] + None if mutator failed, otherwise return the mutated trace. + """ + return _ffi_api.MutatorApply(self, trace, -1) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PyMutator") +class PyMutator(Mutator): + """An abstract mutator with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, Mutator) + def f_initialize_with_tune_context(context: "TuneContext") -> None: + self.initialize_with_tune_context(context) + + @check_override(self.__class__, Mutator) + def f_apply(trace: Trace, _) -> Optional[Trace]: + return self.apply(trace) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.MutatorPyMutator, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py new file mode 100644 index 000000000000..6ee052ecba68 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The tvm.meta_schedule.postproc package.""" +from .postproc import Postproc, PyPostproc diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py new file mode 100644 index 000000000000..5f1180c27c5a --- /dev/null +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule Postproc.""" + +from typing import TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule + +from .. import _ffi_api +from ..utils import _get_hex_address, check_override + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.Postproc") +class Postproc(Object): + """Rules to apply a postprocessor to a schedule.""" + + def initialize_with_tune_context(self, context: "TuneContext") -> None: + """Initialize the postprocessor with a tune context. + + Parameters + ---------- + context : TuneContext + The tuning context for initializing the postprocessor. + """ + _ffi_api.PostprocInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, context + ) + + def apply(self, sch: Schedule) -> bool: + """Apply a postprocessor to the given schedule. + + Parameters + ---------- + sch : Schedule + The schedule to be post processed. + + Returns + ------- + result : bool + Whether the postprocessor was successfully applied. 
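+
+        Examples
+        --------
+        A sketch of using postprocessors as a filter (editor's example;
+        ``postprocs`` and ``sch`` are assumed to come from an existing
+        tuning setup)::
+
+            ok = all(postproc.apply(sch) for postproc in postprocs)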
+ """ + return _ffi_api.PostprocApply(self, sch) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PyPostproc") +class PyPostproc(Postproc): + """An abstract Postproc with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, Postproc) + def f_initialize_with_tune_context(context: "TuneContext") -> None: + self.initialize_with_tune_context(context) + + @check_override(self.__class__, Postproc) + def f_apply(sch: Schedule) -> bool: + return self.apply(sch) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.PostprocPyPostproc, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index b995e5acb6fc..ab142c03cf85 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -35,16 +35,16 @@ class ScheduleRule(Object): """Rules to modify a block in a schedule.""" - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + def initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the schedule rule with a tune context. Parameters ---------- - tune_context : TuneContext + context : TuneContext The tuning context for initializing the schedule rule. """ _ffi_api.ScheduleRuleInitializeWithTuneContext( # type: ignore # pylint: disable=no-member - self, tune_context + self, context ) def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @@ -75,8 +75,8 @@ def __init__(self): """Constructor.""" @check_override(self.__class__, ScheduleRule) - def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: - self.initialize_with_tune_context(tune_context) + def f_initialize_with_tune_context(context: "TuneContext") -> None: + self.initialize_with_tune_context(context) @check_override(self.__class__, ScheduleRule) def f_apply(sch: Schedule, block: BlockRV) -> List[Schedule]: diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index 15f8295f2524..f55013546021 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
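+# A sketch of the intended usage of ReplayTraceConfig (added at the bottom of
+# this file): build the config first, then materialize the strategy from it.
+#
+#   config = ReplayTraceConfig(num_trials_per_iter=32, num_trials_total=256)
+#   strategy = config.create_strategy()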
"""Replay Trace Search Strategy""" +from typing import NamedTuple from tvm._ffi import register_object from .search_strategy import SearchStrategy @@ -41,7 +42,17 @@ class ReplayTrace(SearchStrategy): def __init__(self, num_trials_per_iter: int, num_trials_total: int): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.ReplayTrace, # type: ignore # pylint: disable=no-member + _ffi_api.SearchStrategyReplayTrace, # pylint: disable=no-member num_trials_per_iter, num_trials_total, ) + + +class ReplayTraceConfig(NamedTuple): + """Configuration for ReplayTrace""" + + num_trials_per_iter: int + num_trials_total: int + + def create_strategy(self) -> ReplayTrace: + return ReplayTrace(self.num_trials_per_iter, self.num_trials_total) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 6cee09edd4fc..411fecb2b698 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -48,7 +48,11 @@ class MeasureCandidate(Object): sch: Schedule args_info: List[ArgInfo] - def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None: + def __init__( + self, + sch: Schedule, + args_info: List[ArgInfo], + ) -> None: """Constructor. Parameters @@ -72,19 +76,16 @@ class SearchStrategy(Object): before usage and post-tuned after usage. """ - def initialize_with_tune_context( - self, - tune_context: "TuneContext", - ) -> None: + def initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the search strategy with tuning context. Parameters ---------- - tune_context : TuneContext + context : TuneContext The tuning context for initialization. """ _ffi_api.SearchStrategyInitializeWithTuneContext( # type: ignore # pylint: disable=no-member - self, tune_context + self, context ) def pre_tuning(self, design_spaces: List[Schedule]) -> None: @@ -111,15 +112,29 @@ def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]: """ return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # type: ignore # pylint: disable=no-member - def notify_runner_results(self, results: List[RunnerResult]) -> None: + def notify_runner_results( + self, + context: "TuneContext", + measure_candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: """Update the search strategy with profiling results. Parameters ---------- + context : TuneContext + The tuning context for update. + measure_candidates : List[MeasureCandidate] + The measure candidates for update. results : List[RunnerResult] The profiling results from the runner. 
""" - _ffi_api.SearchStrategyNotifyRunnerResults(self, results) # type: ignore # pylint: disable=no-member + _ffi_api.SearchStrategyNotifyRunnerResults( # type: ignore # pylint: disable=no-member + self, + context, + measure_candidates, + results, + ) @register_object("meta_schedule.PySearchStrategy") @@ -146,8 +161,12 @@ def f_generate_measure_candidates() -> List[MeasureCandidate]: return self.generate_measure_candidates() @check_override(self.__class__, SearchStrategy) - def f_notify_runner_results(results: List["RunnerResult"]) -> None: - self.notify_runner_results(results) + def f_notify_runner_results( + context: "TuneContext", + measure_candidates: List[MeasureCandidate], + results: List["RunnerResult"], + ) -> None: + self.notify_runner_results(context, measure_candidates, results) self.__init_handle_by_constructor__( _ffi_api.SearchStrategyPySearchStrategy, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/space_generator/schedule_fn.py b/python/tvm/meta_schedule/space_generator/schedule_fn.py index 64edd9e0bf8c..8a57a8417247 100644 --- a/python/tvm/meta_schedule/space_generator/schedule_fn.py +++ b/python/tvm/meta_schedule/space_generator/schedule_fn.py @@ -51,12 +51,12 @@ def __init__(self, sch_fn: SCH_FN_TYPE): super().__init__() self.sch_fn = sch_fn - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + def initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the design space generator with tuning context. Parameters ---------- - tune_context : TuneContext + context : TuneContext The tuning context for initializing the design space generator. """ diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index 2172613ce1e6..e0b0ab2fc16e 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -36,16 +36,16 @@ class SpaceGenerator(Object): """The abstract design space generator interface.""" - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + def initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the design space generator with tuning context. Parameters ---------- - tune_context : TuneContext + context : TuneContext The tuning context for initializing the design space generator. """ _ffi_api.SpaceGeneratorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member - self, tune_context + self, context ) def generate_design_space(self, mod: IRModule) -> List[Schedule]: @@ -72,8 +72,8 @@ def __init__(self): """Constructor.""" @check_override(self.__class__, SpaceGenerator) - def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: - self.initialize_with_tune_context(tune_context) + def f_initialize_with_tune_context(context: "TuneContext") -> None: + self.initialize_with_tune_context(context) @check_override(self.__class__, SpaceGenerator) def f_generate_design_space(mod: IRModule) -> List[Schedule]: diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 391011b4f53f..a63d9a3f2183 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -16,13 +16,15 @@ # under the License. 
"""Round Robin Task Scheduler""" -from typing import List, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from ..builder import Builder from ..runner import Runner from ..database import Database +from ..cost_model import CostModel from .task_scheduler import TaskScheduler from .. import _ffi_api @@ -33,7 +35,21 @@ @register_object("meta_schedule.RoundRobin") class RoundRobin(TaskScheduler): - """Round Robin Task Scheduler""" + """Round Robin Task Scheduler + + Parameters + ---------- + tasks: List[TuneContext] + The list of tune context to process. + builder: Builder + The builder of the scheduler. + runner: Runner + The runner of the scheduler. + database: Database + The database of the scheduler. + measure_callbacks: Optional[List[MeasureCallback]] = None + The list of measure callbacks of the scheduler. + """ def __init__( self, @@ -41,6 +57,8 @@ def __init__( builder: Builder, runner: Runner, database: Database, + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, ) -> None: """Constructor. @@ -54,6 +72,8 @@ def __init__( The runner. database : Database The database. + measure_callbacks: Optional[List[MeasureCallback]] + The list of measure callbacks of the scheduler. """ self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member @@ -61,4 +81,6 @@ def __init__( builder, runner, database, + cost_model, + measure_callbacks, ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index aeea154cfe02..dd8e3fe89b63 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -16,14 +16,16 @@ # under the License. """Auto-tuning Task Scheduler""" -from typing import List +from typing import List, Optional from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from tvm.runtime import Object from ..runner import Runner from ..builder import Builder from ..database import Database +from ..cost_model import CostModel from ..tune_context import TuneContext from .. import _ffi_api from ..utils import check_override @@ -43,12 +45,16 @@ class TaskScheduler(Object): The runner of the scheduler. database: Database The database of the scheduler. + measure_callbacks: List[MeasureCallback] = None + The list of measure callbacks of the scheduler. """ tasks: List[TuneContext] builder: Builder runner: Runner database: Database + cost_model: Optional[CostModel] + measure_callbacks: List[MeasureCallback] def tune(self) -> None: """Auto-tuning.""" @@ -59,7 +65,7 @@ def next_task_id(self) -> int: Returns ------- - int + next_task_id : int The next task id. """ return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member @@ -94,7 +100,7 @@ def _is_task_running(self, task_id: int) -> bool: Returns ------- - bool + running : bool Whether the task is running. """ return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member @@ -120,6 +126,8 @@ def __init__( builder: Builder, runner: Runner, database: Database, + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, ): """Constructor. @@ -133,6 +141,10 @@ def __init__( The runner of the scheduler. 
database: Database The database of the scheduler. + cost_model: Optional[CostModel] + The cost model of the scheduler. + measure_callbacks: List[MeasureCallback] + The list of measure callbacks of the scheduler. """ @check_override(self.__class__, TaskScheduler, required=False) @@ -173,6 +185,8 @@ def f_join_running_task(task_id: int) -> None: builder, runner, database, + cost_model, + measure_callbacks, f_tune, f_initialize_task, f_set_task_stopped, diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 99b8c7e869cd..196b1c16b6f2 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -16,13 +16,14 @@ # under the License. """Meta Schedule tuning context.""" -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, Dict, TYPE_CHECKING from tvm import IRModule from tvm._ffi import register_object from tvm.meta_schedule.utils import cpu_count from tvm.runtime import Object from tvm.target import Target +from tvm.tir import PrimFunc from . import _ffi_api @@ -30,6 +31,8 @@ from .space_generator import SpaceGenerator from .search_strategy import SearchStrategy from .schedule_rule import ScheduleRule + from .postproc import Postproc + from .mutator import Mutator @register_object("meta_schedule.TuneContext") @@ -53,6 +56,10 @@ class TuneContext(Object): The search strategy. sch_rules: Optional[List[ScheduleRule]] = None, The schedule rules. + postprocs: Optional[List[Postproc"]] = None, + The postprocessors. + mutator_probs: Optional[Dict[Mutator, float]] + Mutators and their probability mass. task_name : Optional[str] = None The name of the tuning task. rand_state : int = -1 @@ -71,23 +78,31 @@ class TuneContext(Object): mod: Optional[IRModule] target: Optional[Target] - space_generator: "SpaceGenerator" - search_strategy: "SearchStrategy" - task_name: Optional[str] + space_generator: Optional["SpaceGenerator"] + search_strategy: Optional["SearchStrategy"] + sch_rules: List["ScheduleRule"] + postprocs: List["Postproc"] + mutator_probs: Optional[Dict["Mutator", float]] + task_name: str rand_state: int num_threads: int def __init__( self, mod: Optional[IRModule] = None, + *, target: Optional[Target] = None, space_generator: Optional["SpaceGenerator"] = None, search_strategy: Optional["SearchStrategy"] = None, sch_rules: Optional[List["ScheduleRule"]] = None, - task_name: Optional[str] = None, + postprocs: Optional[List["Postproc"]] = None, + mutator_probs: Optional[Dict["Mutator", float]] = None, + task_name: str = "main", rand_state: int = -1, num_threads: Optional[int] = None, ): + if isinstance(mod, PrimFunc): + mod = IRModule.from_expr(mod) if num_threads is None: num_threads = cpu_count() @@ -98,6 +113,8 @@ def __init__( space_generator, search_strategy, sch_rules, + postprocs, + mutator_probs, task_name, rand_state, num_threads, diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc index 5cd32b097caa..c6efb5430336 100644 --- a/src/meta_schedule/cost_model/cost_model.cc +++ b/src/meta_schedule/cost_model/cost_model.cc @@ -53,10 +53,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate") .set_body_method(&CostModelNode::Update); TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict") .set_body_typed([](CostModel model, // - const TuneContext& tune_context, // + const TuneContext& context, // Array candidates, // void* p_addr) -> void { - std::vector result = model->Predict(tune_context, candidates); + std::vector result = 
model->Predict(context, candidates); std::copy(result.begin(), result.end(), static_cast(p_addr)); }); TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel").set_body_typed(CostModel::PyCostModel); diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 200eca34133d..8c9e2d8949e9 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include "tvm/tir/schedule/schedule.h" namespace tvm { namespace meta_schedule { @@ -24,20 +25,18 @@ namespace meta_schedule { /*! \brief A search strategy that generates measure candidates using trace and random decisions. */ class ReplayTraceNode : public SearchStrategyNode { public: - using TRandState = support::LinearCongruentialEngine::TRandState; - /*! \brief The state of the search strategy. */ struct State { /*! \brief The search strategy itself */ ReplayTraceNode* self; /*! \brief The design spaces. */ - Array design_spaces; + Array design_spaces; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int st; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int ed; - explicit State(ReplayTraceNode* self, Array design_spaces) + explicit State(ReplayTraceNode* self, Array design_spaces) : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} inline Optional> GenerateMeasureCandidates(); @@ -50,9 +49,11 @@ class ReplayTraceNode : public SearchStrategyNode { int num_trials_total; /*! \brief The module to be tuned. */ - IRModule mod_{nullptr}; + Array per_thread_mod_{nullptr}; /*! \brief The metadata of the function arguments. */ Array args_info_{nullptr}; + /*! \brief The post processors */ + Array postprocs_{nullptr}; /*! \brief The number of threads to use. -1 means using logical cpu number. */ int num_threads_ = -1; /*! \brief The random state. -1 means using random number. 
*/ @@ -63,8 +64,9 @@ class ReplayTraceNode : public SearchStrategyNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("num_trials_per_iter", &num_trials_per_iter); v->Visit("num_trials_total", &num_trials_total); - // `mod_` is not visited + // `per_thread_mod_` is not visited // `args_info_` is not visited + // `postprocs_` is not visited // `num_threads_` is not visited // `rand_state_` is not visited // `state_` is not visited @@ -74,9 +76,16 @@ class ReplayTraceNode : public SearchStrategyNode { TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); void InitializeWithTuneContext(const TuneContext& context) final { - this->mod_ = context->mod.value(); - this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_)); + CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0."; this->num_threads_ = context->num_threads; + + this->per_thread_mod_.reserve(this->num_threads_); + for (int i = 0; i < this->num_threads_; i++) { + this->per_thread_mod_.push_back(DeepCopyIRModule(context->mod.value())); + } + + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(context->mod.value())); + this->postprocs_ = context->postprocs; this->rand_state_ = ForkSeed(&context->rand_state); this->state_.reset(); } @@ -84,7 +93,12 @@ class ReplayTraceNode : public SearchStrategyNode { void PreTuning(const Array& design_spaces) final { ICHECK(!design_spaces.empty()); ICHECK(this->state_ == nullptr); - this->state_ = std::make_unique(this, design_spaces); + Array design_space_traces; + design_space_traces.reserve(design_spaces.size()); + for (const tir::Schedule& space : design_spaces) { + design_space_traces.push_back(space->trace().value()->Simplified(true)); + } + this->state_ = std::make_unique(this, design_space_traces); } void PostTuning() final { @@ -97,7 +111,9 @@ class ReplayTraceNode : public SearchStrategyNode { return this->state_->GenerateMeasureCandidates(); } - void NotifyRunnerResults(const Array& results) final { + void NotifyRunnerResults(const TuneContext& context, + const Array& measure_candidates, + const Array& results) final { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); } @@ -111,19 +127,20 @@ inline Optional> ReplayTraceNode::State::GenerateMeasure ICHECK_LT(st, ed); std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); Array per_task_result(ed - st, MeasureCandidate{nullptr}); - auto f_worker = [this, &per_thread_rand_state, &per_task_result](int thread_id, - int task_id) -> void { + ThreadedTraceApply pp(self->postprocs_); + auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id, + int task_id) -> void { TRandState& rand_state = per_thread_rand_state[thread_id]; - int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); - tir::Trace trace = design_spaces[design_space_index]->trace().value(); - tir::Trace new_trace = tir::Trace(trace->insts, {}); - tir::Schedule sch = tir::Schedule::Traced( // - self->mod_, // - /*rand_state=*/ForkSeed(&rand_state), // - /*debug_mode=*/0, // - /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); - new_trace->ApplyToSchedule(sch, /*remove_postproc=*/true); - per_task_result.Set(task_id, MeasureCandidate(sch, self->args_info_)); + IRModule mod = self->per_thread_mod_[thread_id]; + for (;;) { + int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); + tir::Trace trace = design_spaces[design_space_index]; + tir::Trace new_trace = tir::Trace(trace->insts, {}); + if 
(Optional sch = pp.Apply(mod, new_trace, &rand_state)) { + per_task_result.Set(task_id, MeasureCandidate(sch.value(), self->args_info_)); + break; + } + } }; support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker); return per_task_result; @@ -142,7 +159,8 @@ SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int num_tria } TVM_REGISTER_NODE_TYPE(ReplayTraceNode); -TVM_REGISTER_GLOBAL("meta_schedule.ReplayTrace").set_body_typed(SearchStrategy::ReplayTrace); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace") + .set_body_typed(SearchStrategy::ReplayTrace); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index 3ef5026cae98..72989a20bcd5 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -52,16 +52,23 @@ class RoundRobinNode final : public TaskSchedulerNode { } }; -TaskScheduler TaskScheduler::RoundRobin(Array tasks, // - Builder builder, // - Runner runner, // - Database database) { +TaskScheduler TaskScheduler::RoundRobin(Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Optional cost_model, // + Optional> measure_callbacks) { ObjectPtr n = make_object(); n->tasks = tasks; n->builder = builder; n->runner = runner; n->database = database; + n->cost_model = cost_model; + n->measure_callbacks = measure_callbacks.value_or({}); n->task_id = -1; + for (const TuneContext& task : tasks) { + task->task_scheduler = n.get(); + } return TaskScheduler(n); } diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 08f2b4f451bd..1f3943dc14dc 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - #include "../utils.h" namespace tvm { @@ -29,9 +28,9 @@ namespace meta_schedule { * \param candidates The measure candidates. * \return An array of the builder results. */ -Array SendToBuilder(const Builder& builder, // - const TuneContext& context, +Array SendToBuilder(const Builder& builder, const TuneContext& context, const Array& candidates) { + LOG(INFO) << "Sending " << candidates.size() << " sample(s) to builder"; Target target = context->target.value(); Array inputs; inputs.reserve(candidates.size()); @@ -45,14 +44,14 @@ Array SendToBuilder(const Builder& builder, // * \brief Send the built measure candidates to runner. * \param runner The runner to send the candidates to. * \param context The tuning context. - * \param candidates The mesure candidates. + * \param candidates The measure candidates. * \param builder_results The builder results. * \return An array of the runner results. */ -Array SendToRunner(const Runner& runner, // - const TuneContext& context, +Array SendToRunner(const Runner& runner, const TuneContext& context, const Array& candidates, const Array& builder_results) { + LOG(INFO) << "Sending " << candidates.size() << " sample(s) to runner"; Target target = context->target.value(); ICHECK_EQ(candidates.size(), builder_results.size()); int n = candidates.size(); @@ -94,54 +93,60 @@ Array SendToRunner(const Runner& runner, // void TaskSchedulerNode::InitializeTask(int task_id) { TuneContext task = this->tasks[task_id]; - // Derive the values. 
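For reference, the two new `RoundRobin` arguments surface on the Python side roughly as follows. This is a minimal sketch, assuming stand-in builder, runner, and database objects such as the `DummyBuilder`, `DummyRunner`, and `DummyDatabase` test doubles defined in the unit tests later in this series:

```python
from typing import List

from tvm.meta_schedule import TuneContext, measure_callback
from tvm.meta_schedule.task_scheduler import RoundRobin


def make_round_robin(tasks: List[TuneContext], builder, runner, database) -> RoundRobin:
    """Build a RoundRobin scheduler wired with the new optional knobs."""
    return RoundRobin(
        tasks,
        builder,
        runner,
        database,
        # cost_model stays unset here; AddToDatabase takes over the tuning-record
        # committing that JoinRunningTask used to do inline.
        measure_callbacks=[measure_callback.AddToDatabase()],
    )
```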
- IRModule mod = task->mod.value(); - SpaceGenerator space = task->space_generator.value(); - SearchStrategy strategy = task->search_strategy.value(); - // Initialize Modules. - space->InitializeWithTuneContext(task); - strategy->InitializeWithTuneContext(task); + LOG(INFO) << "Initializing task " << task_id << ": " << task->task_name << ", mod =\n" + << tir::AsTVMScript(task->mod); + this->tasks[task_id]->Initialize(); } void TaskSchedulerNode::Tune() { for (int i = 0; i < static_cast(this->tasks.size()); i++) { + TuneContext task = tasks[i]; // Check Optional value validity. - CHECK(tasks[i]->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; - CHECK(tasks[i]->space_generator.defined()) + CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; + CHECK(task->space_generator.defined()) << "ValueError: Require `context.space_generator`, but it is not defined"; - CHECK(tasks[i]->search_strategy.defined()) + CHECK(task->search_strategy.defined()) << "ValueError: Require `context.search_strategy`, but it is not defined"; - InitializeTask(i); - - tasks[i]->search_strategy.value()->PreTuning( - tasks[i]->space_generator.value()->GenerateDesignSpace(tasks[i]->mod.value())); + Array design_spaces = + task->space_generator.value()->GenerateDesignSpace(task->mod.value()); + LOG(INFO) << "Total " << design_spaces.size() << " design space(s) generated"; + for (int i = 0, n = design_spaces.size(); i < n; ++i) { + tir::Schedule sch = design_spaces[i]; + tir::Trace trace = sch->trace().value(); + trace = trace->Simplified(true); + LOG(INFO) << "Design space #" << i << ":\n" + << tir::AsTVMScript(sch->mod()) << "\n" + << Concat(trace->AsPython(false), "\n"); + } + task->search_strategy.value()->PreTuning(design_spaces); } int running_tasks = tasks.size(); - while (running_tasks > 0) { - for (int task_id; (task_id = NextTaskId()) != -1;) { - TuneContext task = tasks[task_id]; - ICHECK(!task->is_stopped); - ICHECK(!task->runner_futures.defined()); - SearchStrategy strategy = task->search_strategy.value(); - if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) { - Array builder_results = - SendToBuilder(this->builder, task, task->measure_candidates.value()); - task->runner_futures = - SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results); - } else { - SetTaskStopped(task_id); - --running_tasks; - } + for (int task_id; (task_id = NextTaskId()) != -1;) { + LOG(INFO) << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name; + TuneContext task = tasks[task_id]; + ICHECK(!task->is_stopped); + ICHECK(!task->runner_futures.defined()); + SearchStrategy strategy = task->search_strategy.value(); + if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) { + Array builder_results = + SendToBuilder(this->builder, task, task->measure_candidates.value()); + task->builder_results = builder_results; + task->runner_futures = + SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results); + } else { + SetTaskStopped(task_id); + --running_tasks; + LOG(INFO) << "Task #" << task_id << " has finished. 
Remaining task(s): " << running_tasks; } - int n_tasks = this->tasks.size(); - for (int task_id = 0; task_id < n_tasks; ++task_id) - if (IsTaskRunning(task_id)) { - TuneContext task = tasks[task_id]; - this->JoinRunningTask(task_id); - task->search_strategy.value()->PostTuning(); - } + } + ICHECK_EQ(running_tasks, 0) << "Not all tasks are finished"; + int n_tasks = this->tasks.size(); + for (int task_id = 0; task_id < n_tasks; ++task_id) { + ICHECK(!IsTaskRunning(task_id)) << "Task #" << task_id << " is still running"; + TuneContext task = tasks[task_id]; + task->search_strategy.value()->PostTuning(); } } @@ -175,25 +180,20 @@ void TaskSchedulerNode::JoinRunningTask(int task_id) { for (const RunnerFuture future : task->runner_futures.value()) { results.push_back(future->Result()); } - task->search_strategy.value()->NotifyRunnerResults(results); - task->runner_futures = NullOpt; - // Add to database + task->search_strategy.value()->NotifyRunnerResults(task, task->measure_candidates.value(), + results); + // Invoke the callbacks ICHECK(task->measure_candidates.defined()); - ICHECK(results.size() == task->measure_candidates.value().size()); - int index = 0; - for (const RunnerResult& result : results) { - if (!result->error_msg.defined() && result->run_secs.defined()) { - Optional trace = task->measure_candidates.value()[index]->sch->trace(); - ICHECK(trace.defined()); - this->database->CommitTuningRecord(TuningRecord( - /*trace=*/trace.value(), - /*run_secs=*/result->run_secs.value(), - /*workload=*/this->database->CommitWorkload(task->mod.value()), - /*target=*/task->target.value(), - /*args_info=*/task->measure_candidates.value()[index]->args_info)); - } - index++; + ICHECK(task->builder_results.defined()); + ICHECK_EQ(results.size(), task->measure_candidates.value().size()); + ICHECK_EQ(results.size(), task->builder_results.value().size()); + for (const MeasureCallback& callback : this->measure_callbacks) { + callback->Apply(GetRef(this), task_id, task->measure_candidates.value(), + task->builder_results.value(), results); } + task->measure_candidates = NullOpt; + task->builder_results = NullOpt; + task->runner_futures = NullOpt; } TaskScheduler TaskScheduler::PyTaskScheduler( @@ -201,6 +201,8 @@ TaskScheduler TaskScheduler::PyTaskScheduler( Builder builder, // Runner runner, // Database database, // + Optional cost_model, // + Optional> measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // @@ -212,6 +214,12 @@ TaskScheduler TaskScheduler::PyTaskScheduler( n->builder = builder; n->runner = runner; n->database = database; + n->cost_model = cost_model; + if (measure_callbacks.defined()) { + n->measure_callbacks = measure_callbacks.value(); + } else { + n->measure_callbacks = {}; + } n->f_tune = f_tune; n->f_initialize_task = f_initialize_task; n->f_set_task_stopped = f_set_task_stopped; diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index ac85d43e7987..c06cb9adc8ff 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include #include #include "./utils.h" @@ -24,21 +23,13 @@ namespace tvm { namespace meta_schedule { -/*! - * \brief Constructor function of TuneContext class. - * \param mod The mod to be optimized. - * \param target The target to be optimized for. 
- * \param space_generator The design space generator. - * \param task_name The name of the tuning task. - * \param rand_state The random state. - * \param num_threads The number of threads to be used. - * \param verbose The verbosity level. - */ TuneContext::TuneContext(Optional mod, // Optional target, // Optional space_generator, // Optional search_strategy, // Optional> sch_rules, // + Optional> postprocs, // + Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) { @@ -48,9 +39,11 @@ TuneContext::TuneContext(Optional mod, n->space_generator = space_generator; n->search_strategy = search_strategy; n->sch_rules = sch_rules.value_or({}); - n->task_name = task_name; + n->postprocs = postprocs.value_or({}); + n->mutator_probs = mutator_probs.value_or({}); + n->task_name = task_name.value_or("main"); if (rand_state == -1) { - rand_state = std::random_device()(); + rand_state = support::LinearCongruentialEngine::DeviceRandom(); } support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state); n->num_threads = num_threads; @@ -60,6 +53,26 @@ TuneContext::TuneContext(Optional mod, data_ = std::move(n); } +void TuneContextNode::Initialize() { + if (this->space_generator.defined()) { + this->space_generator.value()->InitializeWithTuneContext(GetRef(this)); + } + if (this->search_strategy.defined()) { + this->search_strategy.value()->InitializeWithTuneContext(GetRef(this)); + } + for (const ScheduleRule& sch_rule : sch_rules) { + sch_rule->InitializeWithTuneContext(GetRef(this)); + } + for (const Postproc& postproc : postprocs) { + postproc->InitializeWithTuneContext(GetRef(this)); + } + if (mutator_probs.defined()) { + for (const auto& kv : mutator_probs) { + kv.first->InitializeWithTuneContext(GetRef(this)); + } + } +} + TVM_REGISTER_NODE_TYPE(TuneContextNode); TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") @@ -68,11 +81,13 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") Optional space_generator, // Optional search_strategy, // Optional> sch_rules, // + Optional> postprocs, // + Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) -> TuneContext { - return TuneContext(mod, target, space_generator, search_strategy, sch_rules, task_name, - rand_state, num_threads); + return TuneContext(mod, target, space_generator, search_strategy, sch_rules, postprocs, + mutator_probs, task_name, rand_state, num_threads); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 0a9ce4a1aed9..3b4eb32105ea 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -200,7 +200,7 @@ inline support::LinearCongruentialEngine::TRandState ForkSeed( /*! * \brief Fork a random state into another ones, i.e. PRNG splitting. - * The given random state is also mutated. + * The given random state is also mutated. * \param rand_state The random state to be forked * \param n The number of forks * \return The forked random states @@ -215,6 +215,15 @@ inline std::vector ForkSeed( return results; } +/*! + * \brief Get deep copy of an IRModule. + * \param mod The IRModule to make a deep copy. + * \return The deep copy of the IRModule. + */ +inline IRModule DeepCopyIRModule(IRModule mod) { + return Downcast(LoadJSON(SaveJSON(mod))); +} + /*! 
* \brief Concatenate strings
 * \param strs The strings to concatenate
@@ -233,6 +242,73 @@ inline std::string Concat(const Array<String>& strs, const std::string& delim) {
   return os.str();
 }
 
+/*!
+ * \brief A helper data structure that replays a trace and collects failure counts
+ * for each postprocessor
+ */
+struct ThreadedTraceApply {
+  /*! \brief Constructor */
+  explicit ThreadedTraceApply(const Array<Postproc>& postprocs)
+      : n_(postprocs.size()), items_(new Item[n_]) {
+    for (int i = 0; i < n_; ++i) {
+      items_[i].postproc = postprocs[i];
+      items_[i].fail_counter = 0;
+    }
+  }
+
+  /*! \brief Destructor */
+  ~ThreadedTraceApply() { delete[] items_; }
+
+  /*!
+   * \brief Apply the trace and postprocessors to an IRModule
+   * \param mod The IRModule to apply the trace to
+   * \param trace The trace to apply to the IRModule
+   * \param rand_state The random state
+   * \return The schedule created, or NullOpt if any postprocessor fails
+   */
+  Optional<tir::Schedule> Apply(const IRModule& mod, const tir::Trace& trace,
+                                TRandState* rand_state) {
+    tir::Schedule sch =
+        tir::Schedule::Traced(mod,
+                              /*rand_state=*/ForkSeed(rand_state),
+                              /*debug_mode=*/0,
+                              /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
+    trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
+    sch->EnterPostproc();
+    for (int i = 0; i < n_; ++i) {
+      Item& item = items_[i];
+      if (!item.postproc->Apply(sch)) {
+        ++item.fail_counter;
+        return NullOpt;
+      }
+    }
+    return sch;
+  }
+
+  /*! \brief Returns a string summarizing the failures on each postprocessor */
+  std::string SummarizeFailures() const {
+    std::ostringstream os;
+    for (int i = 0; i < n_; ++i) {
+      const Item& item = items_[i];
+      os << "Postproc #" << i << " [" << item.postproc  //
+         << "]: " << item.fail_counter.load() << " failure(s)";
+      if (i != n_ - 1) {
+        os << "\n";
+      }
+    }
+    return os.str();
+  }
+
+ private:
+  struct Item {
+    Postproc postproc{nullptr};
+    std::atomic<int> fail_counter{0};
+  };
+
+  int n_;
+  Item* items_;
+};
+
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 81d91b70a0b3..d0e272064c7a 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -30,7 +30,7 @@ Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRa
   n->error_render_level_ = error_render_level;
   n->symbol_table_ = {};
   n->analyzer_ = std::make_unique<arith::Analyzer>();
-  support::LinearCongruentialEngine(&n->rand_state_).Seed(seed);
+  n->Seed(seed);
   return Schedule(std::move(n));
 }
 
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 61283668f85d..b4d1ba01e93e 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -29,7 +29,7 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand
   n->symbol_table_ = {};
   n->analyzer_ = std::make_unique<arith::Analyzer>();
   n->trace_ = Trace();
-  support::LinearCongruentialEngine(&n->rand_state_).Seed(seed);
+  n->Seed(seed);
   return Schedule(std::move(n));
 }
 
diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py
index 3f98d711ea61..5b409be026ee 100644
--- a/tests/python/unittest/test_meta_schedule_cost_model.py
+++ b/tests/python/unittest/test_meta_schedule_cost_model.py
@@ -62,15 +62,13 @@ def save(self, path: str) -> None:
 
     def update(
         self,
-        tune_context: TuneContext,
+        context: TuneContext,
         candidates: List[MeasureCandidate],
         results: List[RunnerResult],
     ) -> None:
         pass
 
-    def predict(
-
self, tune_context: TuneContext, candidates: List[MeasureCandidate] - ) -> np.ndarray: + def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: return np.random.rand(10) model = FancyCostModel() @@ -91,15 +89,13 @@ def save(self, path: str) -> None: def update( self, - tune_context: TuneContext, + context: TuneContext, candidates: List[MeasureCandidate], results: List[RunnerResult], ) -> None: pass - def predict( - self, tune_context: TuneContext, candidates: List[MeasureCandidate] - ) -> np.ndarray: + def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: return np.random.rand(10) cost_model = NotSoFancyCostModel() diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor.py b/tests/python/unittest/test_meta_schedule_feature_extractor.py index 143d446f48fd..d95397b42c77 100644 --- a/tests/python/unittest/test_meta_schedule_feature_extractor.py +++ b/tests/python/unittest/test_meta_schedule_feature_extractor.py @@ -28,7 +28,7 @@ def test_meta_schedule_feature_extractor(): class FancyFeatureExtractor(PyFeatureExtractor): def extract_from( self, - tune_context: TuneContext, # pylint: disable = unused-argument + context: TuneContext, # pylint: disable = unused-argument candidates: List[MeasureCandidate], # pylint: disable = unused-argument ) -> List[np.ndarray]: return [np.random.rand(4, 5)] @@ -43,7 +43,7 @@ def test_meta_schedule_feature_extractor_as_string(): class NotSoFancyFeatureExtractor(PyFeatureExtractor): def extract_from( self, - tune_context: TuneContext, # pylint: disable = unused-argument + context: TuneContext, # pylint: disable = unused-argument candidates: List[MeasureCandidate], # pylint: disable = unused-argument ) -> List[np.ndarray]: return [] diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index b78e67817ebf..78477e6acdd6 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -135,7 +135,7 @@ def _check_correct(schedule: Schedule): class WowSoFancyScheduleRule(PyScheduleRule): - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + def initialize_with_tune_context(self, context: "TuneContext") -> None: pass def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @@ -151,7 +151,7 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: class DoubleScheduleRule(PyScheduleRule): - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + def initialize_with_tune_context(self, context: "TuneContext") -> None: pass def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @@ -175,7 +175,7 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: class ReorderScheduleRule(PyScheduleRule): - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + def initialize_with_tune_context(self, context: "TuneContext") -> None: pass def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @@ -262,7 +262,7 @@ def test_meta_schedule_post_order_apply_duplicate_matmul(): def test_meta_schedule_post_order_apply_remove_block(): class TrinityDouble(PyScheduleRule): - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + def initialize_with_tune_context(self, context: "TuneContext") -> None: pass def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @@ -283,7 +283,7 @@ def apply(self, 
sch: Schedule, block: BlockRV) -> List[Schedule]: return result class RemoveBlock(PyScheduleRule): - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + def initialize_with_tune_context(self, context: "TuneContext") -> None: pass def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 9b3ddfd7c789..6ef3771fb783 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -16,25 +16,25 @@ # under the License. """ Test Meta Schedule SearchStrategy """ # pylint: disable=missing-function-docstring -from typing import List - import sys - import pytest +from typing import List import tvm from tvm.meta_schedule import TuneContext from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.search_strategy import ( + ReplayTrace, + SearchStrategy, +) from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.search_strategy import ReplayTrace - from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace MATMUL_M = 32 -# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking +# pylint: disable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking # fmt: off @tvm.script.ir_module @@ -53,48 +53,57 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on -# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# pylint: enable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument -def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool: - trace_1 = Trace(sch_1.trace.insts, {}) - trace_2 = Trace(sch_2.trace.insts, {}) +def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> bool: + if remove_decisions: + trace_1 = Trace(sch_1.trace.insts, {}) + trace_2 = Trace(sch_2.trace.insts, {}) + else: + trace_1 = sch_1.trace + trace_2 = sch_2.trace return str(trace_1) == str(trace_2) def _schedule_matmul(sch: Schedule): block = sch.get_block("matmul") i, j, k = sch.get_loops(block=block) - # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming - i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) - j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) - k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + i_0, i_1, i_2, i_3 = sch.split(i, sch.sample_perfect_tile(i, n=4)) + j_0, j_1, j_2, j_3 = sch.split(j, sch.sample_perfect_tile(j, n=4)) + k_0, k_1 = sch.split(k, sch.sample_perfect_tile(k, n=2)) sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) -def test_meta_schedule_replay_trace(): +@pytest.mark.parametrize("TestClass", [ReplayTrace]) +def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disable = invalid-name num_trials_per_iter = 7 num_trials_total = 20 - (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) - replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) - tune_context = TuneContext(mod=Matmul) - replay.initialize_with_tune_context(tune_context) - - num_trials_each_round: List[int] = [] - replay.pre_tuning([example_sch]) - while True: - candidates = 
replay.generate_measure_candidates()
-        if candidates is None:
-            break
-        num_trials_each_round.append(len(candidates))
+    strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total)
+    tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul))
+    tune_context.space_generator.initialize_with_tune_context(tune_context)
+    spaces = tune_context.space_generator.generate_design_space(tune_context.mod)
+
+    strategy.initialize_with_tune_context(tune_context)
+    strategy.pre_tuning(spaces)
+    (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul)
+    num_trials_each_iter: List[int] = []
+    candidates = strategy.generate_measure_candidates()
+    while candidates is not None:
+        num_trials_each_iter.append(len(candidates))
         runner_results: List[RunnerResult] = []
         for candidate in candidates:
-            assert _is_trace_equal(candidate.sch, example_sch)
-            runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None))
-    replay.notify_runner_results(runner_results)
-    replay.post_tuning()
-    assert num_trials_each_round == [7, 7, 6]
+            assert _is_trace_equal(
+                candidate.sch,
+                correct_sch,
+                remove_decisions=(isinstance(strategy, ReplayTrace)),
+            )
+            runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None))
+        strategy.notify_runner_results(tune_context, candidates, runner_results)
+        candidates = strategy.generate_measure_candidates()
+    strategy.post_tuning()
+    assert num_trials_each_iter == [7, 7, 6]
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
index 7eb61ad2c7cf..d3c4dbca826f 100644
--- a/tests/python/unittest/test_meta_schedule_task_scheduler.py
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -16,24 +16,22 @@
 # under the License.
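Distilled from the rewritten test above, the lifecycle that any SearchStrategy is now driven through looks like the sketch below; `measure` is a hypothetical stand-in for the builder/runner round trip that the task scheduler normally performs:

```python
from typing import Callable, List

from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.runner import RunnerResult
from tvm.meta_schedule.search_strategy import SearchStrategy


def drive(strategy: SearchStrategy, context: TuneContext, spaces, measure: Callable) -> None:
    """Run one strategy to exhaustion; `measure` maps candidates to RunnerResults."""
    strategy.initialize_with_tune_context(context)
    strategy.pre_tuning(spaces)
    candidates = strategy.generate_measure_candidates()
    while candidates is not None:  # None means the trial budget is spent
        results: List[RunnerResult] = measure(candidates)
        # New in this series: the context and candidates are passed along too.
        strategy.notify_runner_results(context, candidates, results)
        candidates = strategy.generate_measure_candidates()
    strategy.post_tuning()
```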
""" Test Meta Schedule Task Scheduler """ -from typing import List - -import sys import random +import sys +from typing import List import pytest - import tvm -from tvm.script import tir as T from tvm.ir import IRModule -from tvm.tir import Schedule -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.search_strategy import ReplayTrace -from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult -from tvm.meta_schedule.runner import PyRunner, RunnerInput, RunnerFuture, RunnerResult +from tvm.meta_schedule import TuneContext, measure_callback +from tvm.meta_schedule.builder import BuilderInput, BuilderResult, PyBuilder from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload -from tvm.meta_schedule.task_scheduler import RoundRobin, PyTaskScheduler +from tvm.meta_schedule.runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult +from tvm.meta_schedule.search_strategy import ReplayTrace +from tvm.meta_schedule.space_generator import ScheduleFn +from tvm.meta_schedule.task_scheduler import PyTaskScheduler, RoundRobin +from tvm.script import tir as T +from tvm.tir import Schedule # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @@ -140,7 +138,10 @@ def __init__(self): self.records = [] self.workload_reg = [] - def has_workload(self, mod: IRModule) -> bool: + def has_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True return False def commit_tuning_record(self, record: TuningRecord) -> None: @@ -183,7 +184,13 @@ def test_meta_schedule_task_scheduler_single(): rand_state=42, ) database = DummyDatabase() - round_robin = RoundRobin([task], DummyBuilder(), DummyRunner(), database) + round_robin = RoundRobin( + [task], + DummyBuilder(), + DummyRunner(), + database, + measure_callbacks=[measure_callback.AddToDatabase()], + ) round_robin.tune() assert len(database) == num_trials_total @@ -218,15 +225,29 @@ def test_meta_schedule_task_scheduler_multiple(): ), ] database = DummyDatabase() - round_robin = RoundRobin(tasks, DummyBuilder(), DummyRunner(), database) + round_robin = RoundRobin( + tasks, + DummyBuilder(), + DummyRunner(), + database, + measure_callbacks=[measure_callback.AddToDatabase()], + ) round_robin.tune() assert len(database) == num_trials_total * len(tasks) print(database.workload_reg) for task in tasks: - assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total + assert ( + len( + database.get_top_k( + database.commit_workload(task.mod), + 100000, + ) + ) + == num_trials_total + ) -def test_meta_schedule_task_scheduler_NIE(): +def test_meta_schedule_task_scheduler_not_implemented_error(): # pylint: disable=invalid-name class MyTaskScheduler(PyTaskScheduler): pass @@ -234,7 +255,7 @@ class MyTaskScheduler(PyTaskScheduler): MyTaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase()) -def test_meta_schedule_task_scheduler_override_next_task_id_only(): +def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: disable=invalid-name class MyTaskScheduler(PyTaskScheduler): done = set() @@ -291,11 +312,27 @@ def next_task_id(self) -> int: ), ] database = DummyDatabase() - scheduler = MyTaskScheduler(tasks, DummyBuilder(), DummyRunner(), database) + scheduler = MyTaskScheduler( + tasks, + DummyBuilder(), + DummyRunner(), + database, + measure_callbacks=[ + 
measure_callback.AddToDatabase(), + ], + ) scheduler.tune() assert len(database) == num_trials_total * len(tasks) for task in tasks: - assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total + assert ( + len( + database.get_top_k( + database.commit_workload(task.mod), + 100000, + ) + ) + == num_trials_total + ) if __name__ == "__main__": From 959959ef2e01f3d6e7de677033886990fb002f83 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 21 Dec 2021 22:52:02 -0800 Subject: [PATCH 2/4] Minor fix. Fix mypy. Fix mypy. --- include/tvm/meta_schedule/space_generator.h | 8 ++++---- include/tvm/meta_schedule/task_scheduler.h | 19 +++++++++++++++++++ .../search_strategy/replay_trace.py | 2 +- .../test_meta_schedule_search_strategy.py | 10 +++++----- 4 files changed, 29 insertions(+), 10 deletions(-) diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 7aff6839dc55..3611870c7c9b 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -139,13 +139,13 @@ class SpaceGenerator : public ObjectRef { public: /*! * \brief Create a design space generator with customized methods on the python-side. - * \param initialize_with_tune_context_func The packed function of `InitializeWithTuneContext`. - * \param generate_design_space_func The packed function of `GenerateDesignSpace`. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_generate_design_space The packed function of `GenerateDesignSpace`. * \return The design space generator created. */ TVM_DLL static SpaceGenerator PySpaceGenerator( - PySpaceGeneratorNode::FInitializeWithTuneContext initialize_with_tune_context_func, - PySpaceGeneratorNode::FGenerateDesignSpace generate_design_space_func); + PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context, + PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space); /*! * \brief Create a design space generator that is union of multiple design space generators. diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 0284a55e0d03..ddd6f4c4815f 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -249,6 +249,9 @@ class TaskScheduler : public runtime::ObjectRef { * \param builder The builder of the scheduler. * \param runner The runner of the scheduler. * \param database The database of the scheduler. + * \param cost_model The cost model of the scheduler. + * \param measure_callbacks The measure callbacks of the scheduler. + * \return The task scheduler created. */ TVM_DLL static TaskScheduler RoundRobin(Array tasks, // Builder builder, // @@ -256,6 +259,22 @@ class TaskScheduler : public runtime::ObjectRef { Database database, // Optional cost_model, // Optional> measure_callbacks); + /*! + * \brief Create a task scheduler with customized methods on the python-side. + * \param tasks The tasks to be tuned. + * \param builder The builder of the scheduler. + * \param runner The runner of the scheduler. + * \param database The database of the scheduler. + * \param cost_model The cost model of the scheduler. + * \param measure_callbacks The measure callbacks of the scheduler. + * \param f_tune The packed function of `Tune`. + * \param f_initialize_task The packed function of `InitializeTask`. + * \param f_set_task_stopped The packed function of `SetTaskStopped`. 
+ * \param f_is_task_running The packed function of `IsTaskRunning`. + * \param f_join_running_task The packed function of `JoinRunningTask`. + * \param f_next_task_id The packed function of `NextTaskId`. + * \return The task scheduler created. + */ TVM_DLL static TaskScheduler PyTaskScheduler( Array tasks, // Builder builder, // diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index f55013546021..5655038d2ead 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -42,7 +42,7 @@ class ReplayTrace(SearchStrategy): def __init__(self, num_trials_per_iter: int, num_trials_total: int): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.SearchStrategyReplayTrace, # pylint: disable=no-member + _ffi_api.SearchStrategyReplayTrace, # type: ignore # pylint: disable=no-member num_trials_per_iter, num_trials_total, ) diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 6ef3771fb783..668fca9ecbbf 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -81,11 +81,11 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl num_trials_total = 20 strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) - tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) - tune_context.space_generator.initialize_with_tune_context(tune_context) - spaces = tune_context.space_generator.generate_design_space(tune_context.mod) + context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) + context.space_generator.initialize_with_tune_context(context) + spaces = context.space_generator.generate_design_space(context.mod) - strategy.initialize_with_tune_context(tune_context) + strategy.initialize_with_tune_context(context) strategy.pre_tuning(spaces) (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) num_trials_each_iter: List[int] = [] @@ -100,7 +100,7 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl remove_decisions=(isinstance(strategy, ReplayTrace)), ) runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) - strategy.notify_runner_results(tune_context, candidates, runner_results) + strategy.notify_runner_results(context, candidates, runner_results) candidates = strategy.generate_measure_candidates() strategy.post_tuning() assert num_trials_each_iter == [7, 7, 6] From 00474be16c35b361d757384a20b1666c61a4f8cf Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 22 Dec 2021 01:33:19 -0800 Subject: [PATCH 3/4] Retrigger CI. From 96aca71fd537f473a086a867c9445642cdec19e4 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 31 Dec 2021 22:36:55 -0800 Subject: [PATCH 4/4] Minor fixes. 
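With the series applied, a fully specified TuneContext is constructed as sketched below, based on the Python signature from patch 1; the rule, postprocessor, and mutator arguments are placeholders for whatever concrete instances the caller supplies:

```python
from typing import Dict, List

from tvm import IRModule
from tvm.meta_schedule import TuneContext
from tvm.target import Target


def make_context(mod: IRModule, sch_rules: List, postprocs: List, mutator_probs: Dict) -> TuneContext:
    # Everything after `mod` is keyword-only under the new signature.
    return TuneContext(
        mod,  # a tir.PrimFunc is also accepted and wrapped via IRModule.from_expr
        target=Target("llvm"),
        sch_rules=sch_rules,
        postprocs=postprocs,
        mutator_probs=mutator_probs,  # maps each Mutator to its probability mass
        task_name="main",
        rand_state=42,  # -1 would seed from device randomness
        num_threads=None,  # None falls back to cpu_count()
    )
```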
--- include/tvm/meta_schedule/tune_context.h | 4 +--- python/tvm/auto_scheduler/search_task.py | 3 +-- python/tvm/auto_scheduler/workload_registry.py | 5 +---- src/meta_schedule/search_strategy/replay_trace.cc | 1 - src/meta_schedule/tune_context.cc | 2 +- src/meta_schedule/utils.h | 5 +++++ 6 files changed, 9 insertions(+), 11 deletions(-) diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index ff3a14c076e4..428a2e80f4dd 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -53,7 +53,7 @@ class TuneContextNode : public runtime::Object { /*! \brief The probability of using certain mutator. */ Map mutator_probs; /*! \brief The name of the tuning task. */ - String task_name; + Optional task_name; /*! \brief The random state. */ support::LinearCongruentialEngine::TRandState rand_state; /*! \brief The number of threads to be used. */ @@ -82,8 +82,6 @@ class TuneContextNode : public runtime::Object { v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); v->Visit("is_stopped", &is_stopped); - v->Visit("builder_results", &builder_results); - v->Visit("runner_futures", &runner_futures); v->Visit("measure_candidates", &measure_candidates); } diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 0e9c4abebbe1..f1156998bdac 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -543,8 +543,7 @@ def print_best(self, log_file, print_mode="schedule"): code: str The best schedule code in python API or CUDA source code """ - inp, res = load_best_record(log_file, self.workload_key) - print("Best codes (ms):", [float(c) * 1000.0 for c in res.costs]) + inp, _ = load_best_record(log_file, self.workload_key) if inp is None: raise RuntimeError( "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file) diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 75702b0a21af..885eb0d1d0f8 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -194,10 +194,7 @@ def workload_key_to_tensors(workload_key): assert callable(value) args = deserialize_args(workload[1:]) - result = value(*args) - if isinstance(result, tuple): - result = list(result) - return result + return value(*args) def serialize_workload_registry_entry(workload_key): diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 8c9e2d8949e9..1eac10d1ad82 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -17,7 +17,6 @@ * under the License. 
*/
 #include "../utils.h"
-#include "tvm/tir/schedule/schedule.h"
 
 namespace tvm {
 namespace meta_schedule {
diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc
index c06cb9adc8ff..f4595d3b524c 100644
--- a/src/meta_schedule/tune_context.cc
+++ b/src/meta_schedule/tune_context.cc
@@ -41,7 +41,7 @@ TuneContext::TuneContext(Optional<IRModule> mod,
     n->sch_rules = sch_rules.value_or({});
     n->postprocs = postprocs.value_or({});
     n->mutator_probs = mutator_probs.value_or({});
-    n->task_name = task_name.value_or("main");
+    n->task_name = task_name;
     if (rand_state == -1) {
       rand_state = support::LinearCongruentialEngine::DeviceRandom();
     }
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 3b4eb32105ea..3e989e4bb46c 100644
@@ -300,12 +300,17 @@ struct ThreadedTraceApply {
   }
 
  private:
+  /*! \brief A helper data structure that stores the fail count for each postprocessor. */
   struct Item {
+    /*! \brief The postprocessor. */
     Postproc postproc{nullptr};
+    /*! \brief The thread-safe postprocessor failure counter. */
     std::atomic<int> fail_counter{0};
   };
 
+  /*! \brief The number of total postprocessors. */
   int n_;
+  /*! \brief The pointer to the list of postprocessor items. */
   Item* items_;
 };
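For readers who want the gist of ThreadedTraceApply without the C++: the replay-and-postprocess loop it implements has roughly the following shape in Python. This is an illustrative sketch only, without the per-postprocessor failure counters, and it assumes the Postproc.apply binding added in this series plus the standard tvm.tir Schedule/Trace API:

```python
from typing import List, Optional

from tvm import IRModule
from tvm.meta_schedule.postproc import Postproc
from tvm.tir.schedule import Schedule, Trace


def apply_trace(mod: IRModule, trace: Trace, postprocs: List[Postproc]) -> Optional[Schedule]:
    """Replay `trace` on a fresh schedule of `mod`, then run all postprocessors."""
    sch = Schedule(mod, error_render_level="none")  # fresh schedule per attempt
    trace.apply_to_schedule(sch, remove_postproc=True)  # replay without stale postprocs
    sch.enter_postproc()
    for postproc in postprocs:
        if not postproc.apply(sch):  # one failing postproc rejects the candidate
            return None
    return sch
```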