diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index 151582d4c9ce..30d1c2cd3ee0 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -122,11 +122,6 @@ class MeasureCallback : public runtime::ObjectRef { * \return The measure callback created. */ TVM_DLL static MeasureCallback RemoveBuildArtifact(); - /*! - * \brief Create a measure callback that echos the statistics of the tuning process to the console - * \return The measure callback created. - */ - TVM_DLL static MeasureCallback EchoStatistics(); /*! * \brief Create a measure callback that updates the cost model with measurement result. * \return The measure callback created. */ @@ -140,6 +135,8 @@ class MeasureCallback : public runtime::ObjectRef { */ TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, PyMeasureCallbackNode::FAsString f_as_string); + /*! \brief The default list of measure callbacks. */ + TVM_DLL static Array<MeasureCallback> Default(); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); }; diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 2b580e75e019..08a8248dfdbc 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -127,10 +127,17 @@ class Mutator : public runtime::ObjectRef { * \param f_as_string The packed function of `AsString`. * \return The mutator created. */ - TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context, // - FApply f_apply, // - FClone f_clone, // - FAsString f_as_string); + TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context, + FApply f_apply, FClone f_clone, FAsString f_as_string); + /*! \brief Create default mutators for LLVM */ + TVM_DLL static Map<Mutator, FloatImm> DefaultLLVM(); + /*! \brief Create default mutators for CUDA */ + TVM_DLL static Map<Mutator, FloatImm> DefaultCUDA(); + /*! \brief Create default mutators for CUDA with TensorCore */ + TVM_DLL static Map<Mutator, FloatImm> DefaultCUDATensorCore(); + /*! \brief Create default mutators for Hexagon */ + TVM_DLL static Map<Mutator, FloatImm> DefaultHexagon(); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode); }; diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 4fafb9557631..a680a647956c 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -150,6 +150,15 @@ class Postproc : public runtime::ObjectRef { * \return The postprocessor created */ TVM_DLL static Postproc RewriteLayout(); + /*! \brief Create default postprocessors for LLVM */ + TVM_DLL static Array<Postproc> DefaultLLVM(); + /*! \brief Create default postprocessors for CUDA */ + TVM_DLL static Array<Postproc> DefaultCUDA(); + /*! \brief Create default postprocessors for CUDA with TensorCore */ + TVM_DLL static Array<Postproc> DefaultCUDATensorCore(); + /*! \brief Create default postprocessors for Hexagon */ + TVM_DLL static Array<Postproc> DefaultHexagon(); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); }; diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 2c9da1df9dae..3bc30e09c74a 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -140,8 +140,8 @@ class ScheduleRule : public runtime::ObjectRef { Optional<Map<String, ObjectRef>> reuse_write); /*! - * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.
- * \param intrin_name The name of a tensor intrinsic, must be registerd via + * \brief Extension of MultiLevelTiling for auto-tensorization with a single intrinsic. + * \param intrin_name The name of a tensor intrinsic, must be registered via * TensorIntrin.register(...) beforehand * \param structure The tiling structure. Recommended: * - 'SSRSRS' on CPU @@ -162,12 +162,12 @@ class ScheduleRule : public runtime::ObjectRef { Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write); /*! - * \brief Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate + * \brief Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate * tensor core intrinsics * \param intrin_groups A list of groups of tensor core intrinsics. The map should contain keys * "init", "load_a", "load_b", "compute", "store", which represent the tensor intrin for * initialization, loading operand A, loading operand B, tensor core computation, storing the - * result. The value of the map should be names of tensor intrinsics, must be registerd via + * result. The value of the map should be names of tensor intrinsics, must be registered via * TensorIntrin.register(...) beforehand * \param structure The tiling structure. Recommended: * - 'SSSRRSRS' on GPU @@ -261,6 +261,16 @@ class ScheduleRule : public runtime::ObjectRef { FApply f_apply, // FClone f_clone, // FAsString f_as_string); + + /*! \brief Create default schedule rules for LLVM */ + TVM_DLL static Array<ScheduleRule> DefaultLLVM(); + /*! \brief Create default schedule rules for CUDA */ + TVM_DLL static Array<ScheduleRule> DefaultCUDA(); + /*! \brief Create default schedule rules for CUDA with TensorCore */ + TVM_DLL static Array<ScheduleRule> DefaultCUDATensorCore(); + /*! \brief Create default schedule rules for Hexagon */ + TVM_DLL static Array<ScheduleRule> DefaultHexagon(); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode); }; diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index efd3dc24524a..c2399eef0824 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -88,6 +88,8 @@ class SearchStrategyNode : public runtime::Object { /*! * \brief Pre-tuning for the search strategy. + * \param max_trials The maximum number of trials. + * \param num_trials_per_iter The number of trials per iteration. * \param design_spaces The design spaces used during tuning process. * \param database The database used during tuning process. * \param cost_model The cost model used during tuning process. * \note Pre-tuning is supposed to be called before the tuning process and after the * initialization. Because the search strategy is stateful, we can always call pretuning * and reset the search strategy. */ - virtual void PreTuning(const Array<tir::Schedule>& design_spaces, + virtual void PreTuning(int max_trials, int num_trials_per_iter, + const Array<tir::Schedule>& design_spaces, const Optional<Database>& database, const Optional<CostModel>& cost_model) = 0; @@ -143,10 +146,10 @@ class SearchStrategy : public runtime::ObjectRef { using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>; /*! * \brief The function type of `PreTuning` method. - * \param design_spaces The design spaces for pre-tuning. */ using FPreTuning = runtime::TypedPackedFunc<void( - const Array<tir::Schedule>&, const Optional<Database>&, const Optional<CostModel>&)>; + int max_trials, int num_trials_per_iter, const Array<tir::Schedule>&, + const Optional<Database>&, const Optional<CostModel>&)>; /*! \brief The function type of `PostTuning` method. */ using FPostTuning = runtime::TypedPackedFunc<void()>; /*!
@@ -185,24 +188,15 @@ class SearchStrategy : public runtime::ObjectRef { /*! * \brief Constructor of replay trace search strategy. - * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. - * \param max_trials_per_task The total number of trials for trace replaying. * \param max_fail_count The max number of failures during trace replaying. */ - TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task, - int max_fail_count); + TVM_DLL static SearchStrategy ReplayTrace(int max_fail_count); - /*! - * \brief Constructor of replay func search strategy. - * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. - * \param max_trials_per_task The total number of trials for func replaying. - */ - TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int max_trials_per_task); + /*! \brief Constructor of replay func search strategy. */ + TVM_DLL static SearchStrategy ReplayFunc(); /*! * \brief Constructor of evolutionary search strategy. - * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. - * \param max_trials_per_task The total number of trials for evolutionary search. * \param population_size The initial sample population. * \param init_measured_ratio The ratio of measures samples in initial population. * \param init_min_unmeasured The minimal size of unmeasured population in the initial sampling. @@ -211,9 +205,7 @@ * \param genetic_max_fail_count The maximum number to try evolving the given trace. * \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score. */ - TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, // - int max_trials_per_task, // - int population_size, // + TVM_DLL static SearchStrategy EvolutionarySearch(int population_size, // double init_measured_ratio, // int init_min_unmeasured, // int genetic_num_iters, // @@ -257,8 +249,8 @@ class PySearchStrategyNode : public SearchStrategyNode { } void InitializeWithTuneContext(const TuneContext& context) final; - void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database, - const Optional<CostModel>& cost_model) final; + void PreTuning(int max_trials, int num_trials_per_iter, const Array<tir::Schedule>& design_spaces, + const Optional<Database>& database, const Optional<CostModel>& cost_model) final; void PostTuning() final; Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final; void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates, diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 1e29e757a15c..f746eb809194 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -20,10 +20,14 @@ #define TVM_META_SCHEDULE_SPACE_GENERATOR_H_ #include +#include <tvm/meta_schedule/mutator.h> +#include <tvm/meta_schedule/postproc.h> +#include <tvm/meta_schedule/schedule_rule.h> #include #include #include #include +#include #include namespace tvm { namespace meta_schedule { @@ -71,6 +75,19 @@ class SpaceGenerator; */ class SpaceGeneratorNode : public runtime::Object { public: + /*! \brief The schedule rules. */ + Optional<Array<ScheduleRule>> sch_rules; + /*! \brief The postprocessors. */ + Optional<Array<Postproc>> postprocs; + /*! \brief The probability of using certain mutator. */ + Optional<Map<Mutator, FloatImm>> mutator_probs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("sch_rules", &sch_rules); + v->Visit("postprocs", &postprocs); + v->Visit("mutator_probs", &mutator_probs); + } + /*!
\brief Default destructor */ virtual ~SpaceGeneratorNode() = default; @@ -79,7 +96,7 @@ class SpaceGeneratorNode : public runtime::Object { * \param context The tuning context for initialization. * \note This method is supposed to be called only once before every other method. */ - virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + virtual void InitializeWithTuneContext(const TuneContext& context); /*! * \brief Generate design spaces given a module. @@ -127,12 +144,17 @@ class SpaceGenerator : public runtime::ObjectRef { public: /*! * \brief Create a design space generator with customized methods on the python-side. + * \param sch_rules The schedule rules. + * \param postprocs The postprocessors. + * \param mutator_probs The probability of using certain mutator. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. * \param f_generate_design_space The packed function of `GenerateDesignSpace`. * \param f_clone The packed function of `Clone`. * \return The design space generator created. */ TVM_DLL static SpaceGenerator PySpaceGenerator( + Optional<Array<ScheduleRule>> sch_rules, Optional<Array<Postproc>> postprocs, + Optional<Map<Mutator, FloatImm>> mutator_probs, FInitializeWithTuneContext f_initialize_with_tune_context, FGenerateDesignSpace f_generate_design_space, FClone f_clone); /*! @@ -141,19 +163,39 @@ * 1) void(Schedule) * 2) Schedule(Schedule) * 3) Array<Schedule>(Schedule) + * \param sch_rules The schedule rules. + * \param postprocs The postprocessors. + * \param mutator_probs The probability of using certain mutator. */ - TVM_DLL static SpaceGenerator ScheduleFn(PackedFunc schedule_fn); + TVM_DLL static SpaceGenerator ScheduleFn(PackedFunc schedule_fn, + Optional<Array<ScheduleRule>> sch_rules, + Optional<Array<Postproc>> postprocs, + Optional<Map<Mutator, FloatImm>> mutator_probs); /*! * \brief Create a design space generator that is union of multiple design space generators. * \param space_generators An array of design space generators to be unioned. + * \param sch_rules The schedule rules. + * \param postprocs The postprocessors. + * \param mutator_probs The probability of using certain mutator. * \return The design space generator created. */ - TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array<SpaceGenerator> space_generators); + TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array<SpaceGenerator> space_generators, + Optional<Array<ScheduleRule>> sch_rules, + Optional<Array<Postproc>> postprocs, + Optional<Map<Mutator, FloatImm>> mutator_probs); /*! * \brief Create a design space generator that generates design spaces by applying schedule - rules to blocks in post-DFS order. \return The design space generator created. + rules to blocks in post-DFS order. + \param f_block_filter The filter function to filter blocks to be applied with schedule rules. + \param sch_rules The schedule rules. + \param postprocs The postprocessors. + \param mutator_probs The probability of using certain mutator. + \return The design space generator created.
*/ - TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter = nullptr); + TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter, + Optional<Array<ScheduleRule>> sch_rules, + Optional<Array<Postproc>> postprocs, + Optional<Map<Mutator, FloatImm>> mutator_probs); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); }; @@ -171,6 +213,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { FClone f_clone; void VisitAttrs(tvm::AttrVisitor* v) { + SpaceGeneratorNode::VisitAttrs(v); // `f_initialize_with_tune_context` is not visited // `f_generate_design_space` is not visited // `f_clone` is not visited diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 385816e790e2..17d82558fb82 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -21,7 +21,6 @@ #include #include -#include #include #include #include @@ -32,9 +31,64 @@ #include #include +#include <string> +#include <vector> namespace tvm { namespace meta_schedule { +class TaskRecordNode : public runtime::Object { + public: + /*! \brief The tune context of the task. */ + TuneContext ctx{nullptr}; + /*! \brief The weight of the task */ + double task_weight{1.0}; + /*! \brief The FLOP count of the task */ + double flop{1.0}; + /*! \brief Whether the tuning task has been stopped or finished. */ + bool is_terminated = false; + /*! \brief Builder errors that happened in the task */ + int build_error_count = 0; + /*! \brief Runner errors that happened in the task */ + int run_error_count = 0; + /*! \brief The latency of each run, in milliseconds. */ + std::vector<double> latency_ms = {}; + /*! \brief The measure candidates. */ + Optional<Array<MeasureCandidate>> measure_candidates = NullOpt; + /*! \brief The building results. */ + Optional<Array<BuilderResult>> builder_results = NullOpt; + /*! \brief Packed functions to fetch the runner results asynchronously. */ + Optional<Array<RunnerFuture>> runner_futures = NullOpt; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("ctx", &ctx); + v->Visit("task_weight", &task_weight); + v->Visit("flop", &flop); + v->Visit("is_terminated", &is_terminated); + v->Visit("build_error_count", &build_error_count); + v->Visit("run_error_count", &run_error_count); + // `latency_ms` is not visited + v->Visit("measure_candidates", &measure_candidates); + v->Visit("builder_results", &builder_results); + v->Visit("runner_futures", &runner_futures); + } + + static constexpr const char* _type_key = "meta_schedule.TaskRecord"; + TVM_DECLARE_FINAL_OBJECT_INFO(TaskRecordNode, Object); +}; + +/*! + * \brief Managed reference to TaskRecordNode. + * \sa TaskRecordNode + */ +class TaskRecord : public runtime::ObjectRef { + public: + /*! \brief Constructor */ + explicit TaskRecord(TuneContext task, double task_weight); + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskRecord, ObjectRef, TaskRecordNode); +}; /*! * \brief The abstract interface of task schedulers. * \note The relationship between SpaceGenerator and other classes are as follows: @@ -73,66 +127,77 @@ */ class TaskSchedulerNode : public runtime::Object { public: - /*! \brief The tasks to be tuned */ - Array<TuneContext> tasks; - /*! \brief The builder of the scheduler. */ - Builder builder{nullptr}; - /*! \brief The runner of the scheduler. */ - Runner runner{nullptr}; - /*! \brief The database of the scheduler. */ - Optional<Database> database; - /*! \brief The cost model of the scheduler. */ - Optional<CostModel> cost_model; + /*! \brief The tuning task's logging function. */ + PackedFunc logger; + /*!
\brief Records for each task */ + Array<TaskRecord> tasks_; /*! \brief The list of measure callbacks of the scheduler. */ - Array<MeasureCallback> measure_callbacks; - /*! \brief The maximum number of trials allowed. */ - int max_trials; - /*! \brief The number of trials already conducted. */ - int num_trials_already; - /*! \brief The tuning task's logging function. t*/ - PackedFunc logging_func; + Array<MeasureCallback> measure_callbacks_; + /*! \brief The database used in tuning */ + Optional<Database> database_; + /*! \brief The cost model used in tuning */ + Optional<CostModel> cost_model_; + /*! \brief The number of remaining tasks to be tuned. */ + int remaining_tasks_; /*! \brief The default destructor. */ virtual ~TaskSchedulerNode() = default; void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("tasks", &tasks); - v->Visit("builder", &builder); - v->Visit("runner", &runner); - v->Visit("database", &database); - v->Visit("cost_model", &cost_model); - v->Visit("measure_callbacks", &measure_callbacks); - v->Visit("max_trials", &max_trials); - v->Visit("num_trials_already", &num_trials_already); - // `logging_func` is not visited + // `logger` is not visited + v->Visit("tasks_", &tasks_); + v->Visit("measure_callbacks_", &measure_callbacks_); + v->Visit("database_", &database_); + v->Visit("cost_model_", &cost_model_); + v->Visit("remaining_tasks_", &remaining_tasks_); } - /*! \brief Auto-tuning. */ - virtual void Tune(); - - /*! - * \brief Initialize modules of the given task. - * \param task_id The task id to be initialized. - */ - virtual void InitializeTask(int task_id); - /*! - * \brief Touch the task and update its status - * \param task_id The task id to be checked. + * \brief Fetch the next task id. + * \return The next task id. */ - virtual void TouchTask(int task_id); - + virtual int NextTaskId() = 0; /*! * \brief Wait until the task is finished. * \param task_id The task id to be joined. + * \return The results from the runner. */ virtual Array<RunnerResult> JoinRunningTask(int task_id); - /*! - * \brief Fetch the next task id. - * \return The next task id. + * \brief Jointly tune a given list of tasks. + * \param tasks The tasks to be tuned + * \param task_weights The weight of each task + * \param max_trials_global The maximum number of trials to be performed globally + * \param max_trials_per_task The maximum number of trials to be performed for each task + * \param num_trials_per_iter The number of trials to be performed in each iteration + * \param builder The MetaSchedule builder + * \param runner The MetaSchedule runner + * \param measure_callbacks The callbacks to be called after each measurement + * \param database The database used in tuning + * \param cost_model The cost model used in tuning */ - virtual int NextTaskId() = 0; + virtual void Tune(Array<TuneContext> tasks, // + Array<FloatImm> task_weights, // + int max_trials_global, // + int max_trials_per_task, // + int num_trials_per_iter, // + Builder builder, // + Runner runner, // + Array<MeasureCallback> measure_callbacks, // + Optional<Database> database, // + Optional<CostModel> cost_model); + /*! + * \brief Terminate a task + * \param task_id The id of the task to be terminated + */ + void TerminateTask(int task_id); + /*! + * \brief Touch the task and update its status + * \param task_id The task id to be checked. + */ + void TouchTask(int task_id); + /*! \brief Returns a human-readable string of the tuning statistics. */ + std::string TuningStatistics() const; static constexpr const char* _type_key = "meta_schedule.TaskScheduler"; TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object); @@ -143,55 +208,48 @@ class TaskScheduler; /*!
\brief The task scheduler with customized methods on the python-side. */ class PyTaskSchedulerNode : public TaskSchedulerNode { public: - /*! \brief The function type of `Tune` method. */ - using FTune = runtime::TypedPackedFunc<void()>; - - /*! \brief The function type of `InitializeTask` method. */ - using FInitializeTask = runtime::TypedPackedFunc<void(int)>; - /*! - * \brief The function type of `TouchTask` method. - * \param task_id The task id to be checked. - * \return Whether the task is running. + /*! + * \brief The function type of `NextTaskId` method. + * \return The next task id. */ - using FTouchTask = runtime::TypedPackedFunc<void(int)>; - + using FNextTaskId = runtime::TypedPackedFunc<int()>; /*! * \brief The function type of `JoinRunningTask` method. * \param task_id The task id to be joined. */ using FJoinRunningTask = runtime::TypedPackedFunc<Array<RunnerResult>(int)>; + /*! \brief The function type of `Tune` method. */ + using FTune = runtime::TypedPackedFunc<void(Array<TuneContext> tasks, // + Array<FloatImm> task_weights, // + int max_trials_global, // + int max_trials_per_task, // + int num_trials_per_iter, // + Builder builder, // + Runner runner, // + Array<MeasureCallback> measure_callbacks, // + Optional<Database> database, // + Optional<CostModel> cost_model)>; - /*! - * \brief The function type of `NextTaskId` method. - * \return The next task id. - */ - using FNextTaskId = runtime::TypedPackedFunc<int()>; - - /*! \brief The packed function to the `Tune` function. */ - FTune f_tune; - /*! \brief The packed function to the `InitializeTask` function. */ - FInitializeTask f_initialize_task; - /*! \brief The packed function to the `TouchTask` function. */ - FTouchTask f_touch_task; - /*! \brief The packed function to the `JoinRunningTask` function. */ - FJoinRunningTask f_join_running_task; /*! \brief The packed function to the `NextTaskId` function. */ FNextTaskId f_next_task_id; + /*! \brief The packed function to the `JoinRunningTask` function. */ + FJoinRunningTask f_join_running_task; + /*! \brief The packed function to the `Tune` function. */ + FTune f_tune; void VisitAttrs(tvm::AttrVisitor* v) { - // `f_tune` is not visited - // `f_initialize_task` is not visited - // `f_touch_task` is not visited - // `f_join_running_task` is not visited + TaskSchedulerNode::VisitAttrs(v); // `f_next_task_id` is not visited + // `f_join_running_task` is not visited + // `f_tune` is not visited } - void Tune() final; - void InitializeTask(int task_id) final; - void TouchTask(int task_id) final; - Array<RunnerResult> JoinRunningTask(int task_id) final; int NextTaskId() final; + Array<RunnerResult> JoinRunningTask(int task_id) final; + void Tune(Array<TuneContext> tasks, Array<FloatImm> task_weights, int max_trials_global, + int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, + Array<MeasureCallback> measure_callbacks, Optional<Database> database, + Optional<CostModel> cost_model) final; static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler"; TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode); @@ -205,83 +263,31 @@ class TaskScheduler : public runtime::ObjectRef { public: /*! * \brief Create a task scheduler that fetches tasks in a round-robin fashion. - * \param tasks The tasks to be tuned. - * \param builder The builder of the scheduler. - * \param runner The runner of the scheduler. - * \param database The database of the scheduler. - * \param max_trials The maximum number of trials. - * \param cost_model The cost model of the scheduler. - * \param measure_callbacks The measure callbacks of the scheduler. - * \param logging_func The tuning task's logging function. + * \param logger The tuning task's logging function.
* \return The task scheduler created. */ - TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, // - Builder builder, // - Runner runner, // - Optional<Database> database, // - Optional<CostModel> cost_model, // - Optional<Array<MeasureCallback>> measure_callbacks, // - int max_trials, // - PackedFunc logging_func); + TVM_DLL static TaskScheduler RoundRobin(PackedFunc logger); /*! * \brief Create a task scheduler that fetches tasks in a gradient based fashion. - * \param tasks The tasks to be tuned. - * \param task_weights The weights of each task. - * \param builder The builder of the scheduler. - * \param runner The runner of the scheduler. - * \param database The database of the scheduler. - * \param max_trials The maximum number of trials. - * \param cost_model The cost model of the scheduler. - * \param measure_callbacks The measure callbacks of the scheduler. - * \param logging_func The tuning task's logging function. + * \param logger The tuning task's logging function. * \param alpha The parameter alpha to control gradient computation. * \param window_size The parameter to control backward window size. * \param seed The random seed. * \return The task scheduler created. */ - TVM_DLL static TaskScheduler GradientBased(Array<TuneContext> tasks, - Array<FloatImm> task_weights, // - Builder builder, // - Runner runner, // - Optional<Database> database, // - Optional<CostModel> cost_model, // - Optional<Array<MeasureCallback>> measure_callbacks, // - int max_trials, // - PackedFunc logging_func, // - double alpha, // - int window_size, // + TVM_DLL static TaskScheduler GradientBased(PackedFunc logger, double alpha, int window_size, support::LinearCongruentialEngine::TRandState seed); /*! * \brief Create a task scheduler with customized methods on the python-side. - * \param tasks The tasks to be tuned. - * \param builder The builder of the scheduler. - * \param runner The runner of the scheduler. - * \param database The database of the scheduler. - * \param max_trials The maximum number of trials. - * \param cost_model The cost model of the scheduler. - * \param measure_callbacks The measure callbacks of the scheduler. - * \param logging_func The tuning task's logging function. - * \param f_tune The packed function of `Tune`. - * \param f_initialize_task The packed function of `InitializeTask`. - * \param f_touch_task The packed function of `TouchTask`. - * \param f_join_running_task The packed function of `JoinRunningTask`. + * \param logger The tuning task's logging function. * \param f_next_task_id The packed function of `NextTaskId`. + * \param f_join_running_task The packed function of `JoinRunningTask`. + * \param f_tune The packed function of `Tune`. * \return The task scheduler created.
*/ TVM_DLL static TaskScheduler PyTaskScheduler( - Array<TuneContext> tasks, // - Builder builder, // - Runner runner, // - Optional<Database> database, // - Optional<CostModel> cost_model, // - Optional<Array<MeasureCallback>> measure_callbacks, // - int max_trials, // - PackedFunc logging_func, // - PyTaskSchedulerNode::FTune f_tune, // - PyTaskSchedulerNode::FInitializeTask f_initialize_task, // - PyTaskSchedulerNode::FTouchTask f_touch_task, // - PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, // - PyTaskSchedulerNode::FNextTaskId f_next_task_id); + PackedFunc logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, + PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode); }; diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 4e2f00fb5a0c..15f3cba30b95 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -22,10 +22,7 @@ #include #include #include -#include -#include #include -#include #include #include #include @@ -48,6 +45,8 @@ class TuneContext; /*! \brief The auto tuning context. */ class TuneContextNode : public runtime::Object { public: + using TRandState = support::LinearCongruentialEngine::TRandState; + /*! \brief The workload to be tuned. */ Optional<IRModule> mod; /*! \brief The target to be tuned for. */ Optional<Target> target; /*! \brief The design space generator. */ Optional<SpaceGenerator> space_generator; /*! \brief The search strategy. */ Optional<SearchStrategy> search_strategy; - /*! \brief The schedule rules. */ - Array<ScheduleRule> sch_rules; - /*! \brief The postprocessors. */ - Array<Postproc> postprocs; - /*! \brief The probability of using certain mutator. */ - Map<Mutator, FloatImm> mutator_probs; /*! \brief The name of the tuning task. */ Optional<String> task_name; - /*! \brief The tuning task's logging function. t*/ - PackedFunc logging_func; - /*! \brief The random state. */ - support::LinearCongruentialEngine::TRandState rand_state; /*! \brief The number of threads to be used. */ int num_threads; - - /*! \brief Whether the tuning task has been stopped or finished. */ - bool is_terminated; // TODO(@junrushao1994): move to TaskScheduler - /*! \brief The measure candidates. */ - Optional<Array<MeasureCandidate>> measure_candidates; - /*! \brief The building results. */ - Optional<Array<BuilderResult>> builder_results; - /*! \brief Packed functions to fetch the runner results asynchronously. */ - Optional<Array<RunnerFuture>> runner_futures; + /*! \brief The random state. */ + TRandState rand_state; + /*! \brief The tuning task's logging function. */ + PackedFunc logger; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("mod", &mod); v->Visit("target", &target); v->Visit("space_generator", &space_generator); v->Visit("search_strategy", &search_strategy); - v->Visit("sch_rules", &sch_rules); - v->Visit("postprocs", &postprocs); - v->Visit("mutator_probs", &mutator_probs); v->Visit("task_name", &task_name); - // `logging_func` is not visited - v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); - v->Visit("is_terminated", &is_terminated); - v->Visit("measure_candidates", &measure_candidates); - v->Visit("builder_results", &builder_results); - v->Visit("runner_futures", &runner_futures); + v->Visit("rand_state", &rand_state); + // `logger` is not visited } - /*! \brief Initialize members that needs initialization with tune context. */ + /*! + * \brief Initialize members that need initialization with tune context. + */ void Initialize(); /*! * \brief Clone the tune context.
* \return The cloned tune context. */ TuneContext Clone() const; - /*! \brief Set the measure candidates from the SearchStrategy */ - void _SetMeasureCandidates(const Array<MeasureCandidate>& candidates); - /*! - * \brief Send the measure candidates to builder. - * \param builder The builder to send the candidates to. - */ - void _SendToBuilder(const Builder& builder); - /*! - * \brief Send the built measure candidates to runner. - * \param runner The runner to send the candidates to. - */ - void _SendToRunner(const Runner& runner); - /*! - * \brief Join the running tasks. - * \returns The results from the runner - */ - Array<RunnerResult> _Join(); - /*! \brief Set `measure_candidates`, `builder_results` and `runner_futures` to null. */ - void _ClearMeasureState(); + static constexpr const char* _type_key = "meta_schedule.TuneContext"; TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object); }; @@ -134,31 +94,22 @@ */ class TuneContext : public runtime::ObjectRef { public: + using TRandState = support::LinearCongruentialEngine::TRandState; /*! * \brief Constructor. * \param mod The workload to be tuned. * \param target The target to be tuned for. * \param space_generator The design space generator. * \param search_strategy The search strategy. - * \param sch_rules The schedule rules. - * \param postprocs The postprocessors. - * \param mutator_probs The probability of using certain mutator. * \param task_name The name of the tuning task. - * \param logging_func The tuning task's logging function. - * \param rand_state The random state. * \param num_threads The number of threads to be used. + * \param rand_state The random state. + * \param logger The tuning task's logging function. */ - TVM_DLL explicit TuneContext(Optional<IRModule> mod, // - Optional<Target> target, // - Optional<SpaceGenerator> space_generator, // - Optional<SearchStrategy> search_strategy, // - Optional<Array<ScheduleRule>> sch_rules, // - Optional<Array<Postproc>> postprocs, // - Optional<Map<Mutator, FloatImm>> mutator_probs, // - Optional<String> task_name, // - PackedFunc logging_func, // - support::LinearCongruentialEngine::TRandState rand_state, // - int num_threads); + TVM_DLL explicit TuneContext(Optional<IRModule> mod, Optional<Target> target, + Optional<SpaceGenerator> space_generator, + Optional<SearchStrategy> search_strategy, Optional<String> task_name, + int num_threads, TRandState rand_state, PackedFunc logger); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode); }; diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index d9a8a583ce9c..109a98b3d14a 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -16,19 +16,16 @@ * specific language governing permissions and limitations * under the License. */ - /*! * \file random_engine.h * \brief Random number generator. It provides a generic interface consistent with * `std::uniform_random_bit_generator` */ - #ifndef TVM_SUPPORT_RANDOM_ENGINE_H_ #define TVM_SUPPORT_RANDOM_ENGINE_H_ - #include -#include <cstdint> // for uint64_t +#include <cstdint> #include namespace tvm { namespace support { class LinearCongruentialEngine { public: - /*! - * \brief The result type is defined as uint64_t here to avoid overflow. - * \note The type name is not in Google style because it is used in STL's distribution inferface. - */ - using result_type = uint64_t; using TRandState = int64_t; - + /*! \brief The result type. */ + using result_type = uint64_t; /*! \brief The multiplier */ static constexpr TRandState multiplier = 48271; - /*! \brief The increment */ static constexpr TRandState increment = 0; - /*!
\brief The modulus */ static constexpr TRandState modulus = 2147483647; - - /*! - * \brief The minimum possible value of random state here. - * \note The function name is uncapilized because it is used in STL's distribution inferface. - */ + /*! \brief The minimum possible value of random state here. */ static constexpr result_type min() { return 0; } - - /*! - * \brief The maximum possible value of random state here. - * \note The function name is uncapilized because it is used in STL's distribution inferface. - */ + /*! \brief The maximum possible value of random state here. */ static constexpr result_type max() { return modulus - 1; } /*! @@ -94,20 +77,32 @@ class LinearCongruentialEngine { (*rand_state_ptr_) = ((*rand_state_ptr_) * multiplier + increment) % modulus; return *rand_state_ptr_; } - /*! - * \brief Change the start random state of RNG with the seed of a new random state value. - * \param rand_state The random state given in result_type. + * \brief Normalize the random seed to the range of [1, modulus - 1]. + * \param rand_state The random seed. + * \return The normalized random seed. */ - void Seed(TRandState rand_state) { + static TRandState NormalizeSeed(TRandState rand_state) { if (rand_state == -1) { rand_state = DeviceRandom(); - } else if (rand_state == 0) { + } else { + rand_state %= modulus; + } + if (rand_state == 0) { rand_state = 1; } - ICHECK(rand_state >= 0) << "The random state should be nonnegative"; + if (rand_state < 0) { + LOG(FATAL) << "ValueError: Random seed must be non-negative"; + } + return rand_state; + } + /*! + * \brief Change the start random state of RNG with the seed of a new random state value. + * \param rand_state The random state given in result_type. + */ + void Seed(TRandState rand_state) { ICHECK(rand_state_ptr_ != nullptr); - *rand_state_ptr_ = rand_state % modulus; + *rand_state_ptr_ = NormalizeSeed(rand_state); } /*! diff --git a/python/tvm/contrib/torch/as_torch.py b/python/tvm/contrib/torch/as_torch.py index a8cd895a6c5e..2412519ea9c5 100644 --- a/python/tvm/contrib/torch/as_torch.py +++ b/python/tvm/contrib/torch/as_torch.py @@ -24,15 +24,19 @@ as_torch: a decorator, which is used to wrap the TVMScript code to `torch.nn.module`. 
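A minimal usage sketch (illustrative only; the TVMScript kernel below is an invented example rather than part of this module):

    import torch
    from tvm.script import tir as T
    from tvm.contrib.torch import as_torch

    @as_torch
    @T.prim_func
    def plus_one(A: T.Buffer[(8,), "float32"], B: T.Buffer[(8,), "float32"]) -> None:
        for i in range(8):
            with T.block("plus_one"):
                vi = T.axis.spatial(8, i)
                B[vi] = A[vi] + T.float32(1)

    x = torch.rand(8)
    y = torch.zeros(8)
    plus_one(x, y)  # builds on first call, then runs the kernel on torch tensors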
""" import tempfile -from typing import Callable, List, Union +from typing import Callable, List, Optional, Union + +# isort: off +from typing_extensions import Literal + +# isort: on import torch import torch.utils.dlpack - import tvm -from tvm.meta_schedule.tune import TuneConfig, tune_tir +from tvm import meta_schedule as ms from tvm.target.target import Target -from tvm.tir.schedule.schedule import Schedule +from tvm.tir import PrimFunc # python wrapper for OperatorModule @@ -48,7 +52,24 @@ def __init__( self.rt_module = None # runtime module self.ir_module = module # IR modules - def tune(self, config: TuneConfig = None, target: Union[str, Target] = None): + def tune( + self, + target: Union[str, Target] = "cpu", + max_trials_global: int = 32, + *, + num_trials_per_iter: int = 32, + builder: ms.Builder.BuilderType = "local", + runner: ms.Runner.RunnerType = "local", + database: ms.Database.DatabaseType = "json", + cost_model: ms.CostModel.CostModelType = "xgb", + measure_callbacks: ms.MeasureCallback.CallbackListType = "default", + task_scheduler: ms.TaskScheduler.TaskSchedulerType = "round-robin", + space: ms.SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: ms.SearchStrategy.SearchStrategyType = "replay_trace", + task_name: str = "main", + num_threads: Union[Literal["physical", "logical"], int] = "physical", + seed: Optional[int] = None, + ) -> None: """ Tune the TVMScript code. @@ -60,23 +81,29 @@ def tune(self, config: TuneConfig = None, target: Union[str, Target] = None): target : Optional[str, Target] The target to tune for. """ - if config is None: - config = TuneConfig( - # Default setting - strategy="replay_trace", - num_trials_per_iter=32, - max_trials_per_task=32, - max_trials_global=32, - ) - if target is None: - target = Target("llvm --num-cores=16") + if target == "cpu": + target = f"llvm --num-cores {ms.utils.cpu_count(logical=False)}" + with tempfile.TemporaryDirectory() as work_dir: - sch: Schedule = tune_tir( + database = ms.tir_integration.tune_tir( mod=self.ir_module, target=target, - config=config, work_dir=work_dir, + max_trials_global=max_trials_global, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + space=space, + strategy=strategy, + task_name=task_name, + num_threads=num_threads, + seed=seed, ) + sch = ms.tir_integration.compile_tir(database, self.ir_module, target) self.ir_module = sch.mod self.build(target) @@ -117,11 +144,11 @@ def as_torch(func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Call which is the subclass of the original nn.Module. 
""" - if isinstance(func, (tvm.ir.module.IRModule, tvm.tir.function.PrimFunc)): + if isinstance(func, (tvm.ir.module.IRModule, PrimFunc)): return OperatorModuleWrapper(func) - if isinstance(func, Callable): + if callable(func): - def func_get_param(*args, **kargs): - return OperatorModuleWrapper(func(*args, **kargs)) + def func_get_param(*args, **kwargs): + return OperatorModuleWrapper(func(*args, **kwargs)) return func_get_param diff --git a/python/tvm/contrib/torch/optimize_torch.py b/python/tvm/contrib/torch/optimize_torch.py index 821a3b1f71d5..347ea89f92ee 100644 --- a/python/tvm/contrib/torch/optimize_torch.py +++ b/python/tvm/contrib/torch/optimize_torch.py @@ -28,28 +28,17 @@ import base64 import contextlib import tempfile -from typing import Dict, Optional, Tuple, Union -import warnings +from typing import Optional, Tuple, Union import torch import torch.utils.dlpack - import tvm +from tvm import meta_schedule as ms from tvm import relay from tvm._ffi import get_global_func, register_func -from tvm.ir.module import IRModule -from tvm.ir.transform import PassContext -from tvm.meta_schedule import TuneConfig, default_config -from tvm.meta_schedule.relay_integration import extract_task_from_relay -from tvm.meta_schedule.tune import tune_extracted_tasks -from tvm.meta_schedule.utils import autotvm_silencer -from tvm.runtime import vm -from tvm.runtime.module import Module -from tvm.runtime.ndarray import NDArray -from tvm.target.target import Target - - -# The python wrapper for GraphExecutorFactory +from tvm.target import Target + + class GraphExecutorFactoryWrapper(torch.nn.Module): def __init__(self, module: tvm.runtime.Module): super().__init__() @@ -62,75 +51,32 @@ def forward(self, *torch_inputs: Tuple[torch.Tensor]): return ret -def llvm_target(): - return "llvm -num-cores" - - @register_func("script_torch.save_to_base64") def save_to_base64(obj) -> bytes: with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile: obj.export_library(tmpfile.name) - with open(tmpfile.name, "rb") as tfile: - return base64.b64encode(tfile.read()) - - -def tune_relay_auto( - mod: IRModule, - target: Union[str, Target], - config: TuneConfig, - work_dir: str, - backend: str = "graph", - params: Optional[Dict[str, NDArray]] = None, -) -> Union[Module, vm.Executable]: - """A wrapper of `tune_relay` but provide a default setting for the config. - - Parameters - ---------- - mod : IRModule - The module to tune. - target : Union[str, Target] - The target to tune for. - config : TuneConfig - The search strategy config. - params : Optional[Dict[str, tvm.runtime.NDArray]] - The associated parameters of the program - work_dir : Optional[str] - The working directory to save intermediate results. - backend : str = "graph" - The backend to use for relay compilation(graph / vm). - - Returns - ------- - lib : Union[Module, tvm.runtime.vm.Executable] - The built runtime module or vm Executable for the given relay workload. 
- """ - target = default_config.target(target) - extracted_tasks = extract_task_from_relay(mod, target, params) - if config is None: - config = TuneConfig( - num_trials_per_iter=16, - max_trials_global=16 * len(extracted_tasks), - ) - database = tune_extracted_tasks(extracted_tasks, config, work_dir) - relay_build = {"graph": relay.build, "vm": relay.vm.compile}[backend] - with target, autotvm_silencer(), database: - with PassContext( - opt_level=3, - config={ - "relay.backend.use_meta_schedule": True, - "relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda", - "relay.backend.tir_converter": "default", - }, - ): - return relay_build(mod, target=target, params=params) + with open(tmpfile.name, "rb") as temp_file: + return base64.b64encode(temp_file.read()) def optimize_torch( func, example_inputs, - tuning_config=None, - target=None, + *, + max_trials_global: int, work_dir=None, + target: Union[str, Target] = "cpu", + max_trials_per_task: Optional[int] = None, + num_trials_per_iter: int = 64, + builder: ms.Builder.BuilderType = "local", + runner: ms.Runner.RunnerType = "local", + database: ms.Database.DatabaseType = "json", + cost_model: ms.CostModel.CostModelType = "xgb", + measure_callbacks: ms.MeasureCallback.CallbackListType = "default", + task_scheduler: ms.TaskScheduler.TaskSchedulerType = "gradient", + space: ms.SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: ms.SearchStrategy.SearchStrategyType = "evolutionary", + seed: Optional[int] = None, ): """Load PyTorch model that could be traced by TorchScript, then optimize it via MetaSchedule. @@ -139,22 +85,37 @@ def optimize_torch( func : callable or torch.nn.Module A Python function or nn.Module that could run by TorchScript's trace. (ie: torch.jit.trace(model, input)) - example_inputs : tuple or torch.Tensor Inputs to `torch.jit.trace`. - - tuning_config : tvm.meta_schedule.TuneConfig - The configuration for tuning by MetaSchedule. - If user doesn't set the config, the tuning will run with a default setting. - Here, the total number of trials is proportional - to the number of tunable tasks in the input module. - + max_trials_global : int + The maximum number of trials to run globally. + work_dir : Optional[str] + The working directory to save intermediate results. target : Optional[Union[str, Target]] The target of the compilation. If user doesn't set the target, the module will be built for the CPU target. - - work_dir : Optional[str] - The working directory to save intermediate results. + max_trials_per_task : Optional[int] + The maximum number of trials to run per task. + num_trials_per_iter : int + The number of trials to run per iteration + builder : Builder.BuilderType + The builder. + runner : Runner.RunnerType + The runner. + database : Database.DatabaseType + The database. + cost_model : CostModel.CostModelType + The cost model. + measure_callbacks : MeasureCallback.CallbackListType + The measure callbacks. + task_scheduler : TaskScheduler.TaskSchedulerType + The task scheduler. + space : SpaceGenerator.SpaceGeneratorType + The space generator to use. + strategy : SearchStrategy.SearchStrategyType + The search strategy to use. + seed : Optional[int] + The random seed to use. Returns ------- @@ -163,33 +124,47 @@ def optimize_torch( which is the subclass of the original nn.Module. 
""" - if target is None: - target = llvm_target() - - if tuning_config is None: - warning_msg = ( - "Using the default tuning parameters.", - "The default number of trials is set to a small value to let tuning finish quickly.", - "For optimal performance, it is recommended to provide", - "the `tuning_config` argument with a bigger number of trials.", - ) - warnings.warn(" ".join(warning_msg), stacklevel=2) + if target == "cpu": + target = f"llvm --num-cores {ms.utils.cpu_count(logical=False)}" + if not isinstance(target, Target): + target = Target(target) # If `func` is already a traced module this statement makes no effect jit_mod = torch.jit.trace(func, example_inputs) - if isinstance(example_inputs, torch.Tensor): example_inputs = [example_inputs] - shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) # IRmodule + if work_dir: context_manager = contextlib.nullcontext(work_dir) else: context_manager = tempfile.TemporaryDirectory() - with context_manager as work_dir_path: - executor_factory = tune_relay_auto( - mod=mod, params=params, config=tuning_config, target=target, work_dir=work_dir_path + with context_manager as work_dir: # pylint: disable=redefined-argument-from-local + database = ms.relay_integration.tune_relay( + mod=mod, + params=params, + target=target, + work_dir=work_dir, + max_trials_global=max_trials_global, + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + space=space, + strategy=strategy, + seed=seed, + ) + executor_factory = ms.relay_integration.compile_relay( + database=database, + mod=mod, + target=target, + params=params, + backend="graph", ) save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod") diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index cf348d49f4e2..c92ed47d8a2a 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -20,24 +20,35 @@ builder, cost_model, database, - default_config, feature_extractor, measure_callback, mutator, postproc, + relay_integration, runner, schedule_rule, search_strategy, space_generator, + tir_integration, ) +from .builder import Builder +from .cost_model import CostModel +from .database import Database from .extracted_task import ExtractedTask +from .feature_extractor import FeatureExtractor +from .measure_callback import MeasureCallback +from .mutator import Mutator +from .postproc import Postproc from .profiler import Profiler from .relay_integration import ( - extract_task_from_relay, is_meta_schedule_dispatch_enabled, is_meta_schedule_enabled, ) -from .search_strategy import MeasureCandidate -from .tune import TuneConfig, tune_extracted_tasks, tune_relay, tune_te, tune_tir +from .runner import Runner +from .schedule_rule import ScheduleRule +from .search_strategy import MeasureCandidate, SearchStrategy +from .space_generator import SpaceGenerator +from .tir_integration import tune_tir +from .tune import tune_tasks from .tune_context import TuneContext from .utils import derived_object diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index a2254f243380..fcab906e6207 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -15,7 +15,7 @@ # specific 
language governing permissions and limitations # under the License. """Meta Schedule builders that translate IRModule to runtime.Module, and then export""" -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Union # isort: off from typing_extensions import Literal @@ -112,6 +112,8 @@ def __init__( class Builder(Object): """The abstract builder interface.""" + BuilderType = Union["Builder", Literal["local"]] + def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: """Build the given inputs. @@ -126,6 +128,33 @@ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: """ return _ffi_api.BuilderBuild(self, build_inputs) # type: ignore # pylint: disable=no-member + @staticmethod + def create( # pylint: disable=keyword-arg-before-vararg + kind: Literal["local"] = "local", + *args, + **kwargs, + ) -> "Builder": + """Create a Builder. + + Parameters + ---------- + kind : Literal["local"] + The kind of the builder. For now, only "local" is supported. + + Returns + ------- + builder : Builder + The builder created. + """ + from . import LocalBuilder # pylint: disable=import-outside-toplevel + + if kind == "local": + return LocalBuilder(*args, **kwargs) # type: ignore + raise ValueError(f"Unknown Builder: {kind}") + + +create = Builder.create # pylint: disable=invalid-name + @register_object("meta_schedule.PyBuilder") class _PyBuilder(Builder): @@ -168,16 +197,3 @@ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: The results of building the given inputs. """ raise NotImplementedError - - -def create( # pylint: disable=keyword-arg-before-vararg - kind: Literal["local"] = "local", - *args, - **kwargs, -) -> Builder: - """Create a Builder.""" - from . import LocalBuilder # pylint: disable=import-outside-toplevel - - if kind == "local": - return LocalBuilder(*args, **kwargs) # type: ignore - raise ValueError(f"Unknown Builder: {kind}") diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index e81ccfe808ff..6e282d8cb62d 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """Local builder that compile on the local host""" -import logging import os import tempfile from typing import Callable, Dict, List, Optional, Union @@ -26,10 +25,11 @@ from tvm.target import Target from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind +from ..logging import get_logger from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker from .builder import BuilderInput, BuilderResult, PyBuilder -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = get_logger(__name__) # pylint: disable=invalid-name T_BUILD = Callable[ # pylint: disable=invalid-name @@ -137,7 +137,7 @@ def __init__( super().__init__() if max_workers is None: - max_workers = cpu_count(logical=False) + max_workers = cpu_count(logical=True) logger.info("LocalBuilder: max_workers = %d", max_workers) self.max_workers = max_workers diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index d3b660d837dd..54a4d7a34391 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -16,23 +16,30 @@ # under the License. 
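# A usage sketch for the `Builder.create` factory defined in builder.py above
# (illustrative; "local" is the only kind supported there, and extra arguments
# are forwarded to `LocalBuilder`, e.g. its `max_workers` parameter):
#
#     from tvm import meta_schedule as ms
#
#     builder = ms.Builder.create("local", max_workers=4)
#     # equivalent to: ms.builder.LocalBuilder(max_workers=4)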
"""Meta Schedule CostModel.""" import ctypes -from typing import Callable, List +from typing import Callable, List, Union + +# isort: off +from typing_extensions import Literal + +# isort: on import numpy as np # type: ignore from tvm._ffi import register_object -from tvm.meta_schedule.utils import _get_default_str from tvm.runtime import Object from .. import _ffi_api from ..runner import RunnerResult from ..search_strategy import MeasureCandidate from ..tune_context import TuneContext +from ..utils import _get_default_str @register_object("meta_schedule.CostModel") class CostModel(Object): """Cost model.""" + CostModelType = Union["CostModel", Literal["xgb", "mlp", "random"]] + def load(self, path: str) -> None: """Load the cost model from given file location. @@ -97,6 +104,41 @@ def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> n ) return results + @staticmethod + def create( + kind: Literal["xgb", "mlp", "random"], + *args, + **kwargs, + ) -> "CostModel": + """Create a CostModel. + + Parameters + ---------- + kind : Literal["xgb", "mlp", "random"] + The kind of the cost model. Can be "xgb", "mlp", or "random". + + Returns + ------- + cost_model : CostModel + The created cost model. + """ + from . import RandomModel, XGBModel # pylint: disable=import-outside-toplevel + + if kind == "xgb": + return XGBModel(*args, **kwargs) # type: ignore + if kind == "random": + return RandomModel(*args, **kwargs) # type: ignore + if kind == "mlp": + from .mlp_model import ( # type: ignore # pylint: disable=import-outside-toplevel + MLPModel, + ) + + return MLPModel(*args, **kwargs) # type: ignore + raise ValueError(f"Unknown CostModel: {kind}") + + +create = CostModel.create # pylint: disable=invalid-name + @register_object("meta_schedule.PyCostModel") class _PyCostModel(CostModel): diff --git a/python/tvm/meta_schedule/cost_model/mlp_model.py b/python/tvm/meta_schedule/cost_model/mlp_model.py index e7f07f0a4542..8bd050b689bf 100644 --- a/python/tvm/meta_schedule/cost_model/mlp_model.py +++ b/python/tvm/meta_schedule/cost_model/mlp_model.py @@ -19,7 +19,6 @@ Segment Sum MLP cost model """ import glob -import logging import math import os import random @@ -38,14 +37,13 @@ from ..cost_model import PyCostModel from ..database import JSONDatabase from ..feature_extractor import FeatureExtractor, PerStoreFeature +from ..logging import get_logger from ..runner import RunnerResult from ..search_strategy import MeasureCandidate from ..tune_context import TuneContext from ..utils import derived_object, shash2hex -logging.basicConfig() -logger = logging.getLogger("mlp_model") # pylint: disable=invalid-name -logger.setLevel(logging.INFO) +logger = get_logger("mlp_model") # pylint: disable=invalid-name # pylint: disable=no-member,import-outside-toplevel diff --git a/python/tvm/meta_schedule/cost_model/random_model.py b/python/tvm/meta_schedule/cost_model/random_model.py index bc178f76ac90..19516bee0d4f 100644 --- a/python/tvm/meta_schedule/cost_model/random_model.py +++ b/python/tvm/meta_schedule/cost_model/random_model.py @@ -19,12 +19,11 @@ """ from typing import List, Optional, Tuple, Union -from tvm.meta_schedule.utils import derived_object # type: ignore - from ..cost_model import PyCostModel from ..runner import RunnerResult from ..search_strategy import MeasureCandidate from ..tune_context import TuneContext +from ..utils import derived_object # type: ignore @derived_object diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py 
b/python/tvm/meta_schedule/cost_model/xgb_model.py index 59774b534e55..0a2786c6abe0 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -14,15 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -XGBoost-based cost model -""" -import logging +"""XGBoost-based cost model""" import os import tempfile from collections import OrderedDict from itertools import chain as itertools_chain -from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Callable +from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple import numpy as np # type: ignore @@ -30,21 +27,20 @@ from ...runtime import NDArray from ..cost_model import PyCostModel from ..feature_extractor import FeatureExtractor +from ..logging import get_logger from ..runner import RunnerResult from ..search_strategy import MeasureCandidate from ..utils import cpu_count, derived_object, shash2hex from .metric import max_curve - if TYPE_CHECKING: - import xgboost as xgb # type: ignore from xgboost.callback import TrainingCallback # type: ignore from ..tune_context import TuneContext -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = get_logger(__name__) # pylint: disable=invalid-name def make_metric_sorter(focused_metric): @@ -302,7 +298,7 @@ class XGBModel(PyCostModel): average_peak_n : int The number to calculate average peak score. adaptive_training : bool - Whether use adpative training to reduce tuning time. + Whether use adaptive training to reduce tuning time. """ # feature extractor @@ -327,7 +323,7 @@ def __init__( self, *, # feature extractor - extractor: FeatureExtractor, + extractor: FeatureExtractor.FeatureExtractorType = "per-store-feature", # xgboost model config config: XGBConfig = XGBConfig(), # random result before enough samples @@ -339,6 +335,8 @@ def __init__( adaptive_training: bool = True, ): super().__init__() + if not isinstance(extractor, FeatureExtractor): + extractor = FeatureExtractor.create(extractor) # feature extractor self.extractor = extractor # model-related @@ -652,7 +650,7 @@ def _get_custom_call_back( """Get a customized callback function for XGBoost. 
Work around xgboost import.""" def optional_xgboost_callback(cls): - """Decorator for importing TraningCallback from xgboost""" + """Decorator for importing TrainingCallback from xgboost""" # pylint:disable = import-outside-toplevel try: from xgboost.callback import TrainingCallback # type: ignore @@ -696,7 +694,7 @@ def __call__(self, env: "xgb.core.CallbackEnv"): return self.after_iteration(env.model, env.iteration, env.evaluation_result_list) def init(self, model: "xgb.Booster"): - """Internal function for intialization""" + """Internal function for initialization""" booster: "xgb.Booster" = model self.state["best_iteration"] = 0 self.state["best_score"] = float("inf") diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 75b78b118eea..e21ce29ed699 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -164,6 +164,8 @@ def from_json(json_obj: Any, workload: Workload) -> "TuningRecord": class Database(Object): """The abstract database interface.""" + DatabaseType = Union["Database", Literal["json", "memory"]] + def has_workload(self, mod: IRModule) -> bool: """Check if the database has the given workload. Parameters @@ -361,6 +363,56 @@ def current() -> Optional["Database"]: """Get the current database under scope.""" return _ffi_api.DatabaseCurrent() # type: ignore # pylint: disable=no-member + @staticmethod + def create( # pylint: disable=keyword-arg-before-vararg + kind: Union[ + Literal[ + "json", + "memory", + "union", + "ordered_union", + ], + Callable[[Schedule], bool], + ] = "json", + *args, + **kwargs, + ) -> "Database": + """Create a Database. + + Parameters + ---------- + kind : str = "json" | "memory" | "union" | "ordered_union" | Callable[[Schedule], bool] + The kind of the database to be created. The following kinds are supported: + "json", "memory", "union", "ordered_union", and a custom schedule function. + + Returns + ------- + database : Database + The created database. + """ + from . import ( # pylint: disable=import-outside-toplevel + JSONDatabase, + MemoryDatabase, + OrderedUnionDatabase, + ScheduleFnDatabase, + UnionDatabase, + ) + + if callable(kind): + return ScheduleFnDatabase(kind, *args, **kwargs) # type: ignore + if kind == "json": + return JSONDatabase(*args, **kwargs) + if kind == "memory": + return MemoryDatabase(*args, **kwargs) # type: ignore + if kind == "union": + return UnionDatabase(*args, **kwargs) # type: ignore + if kind == "ordered_union": + return OrderedUnionDatabase(*args, **kwargs) # type: ignore + raise ValueError(f"Unknown Database: {kind}") + + +create = Database.create # pylint: disable=invalid-name + @register_object("meta_schedule.PyDatabase") class _PyDatabase(Database): @@ -568,38 +620,3 @@ def __len__(self) -> int: The number of records in the database """ raise NotImplementedError - - -def create( # pylint: disable=keyword-arg-before-vararg - kind: Union[ - Literal[ - "json", - "memory", - "union", - "ordered_union", - ], - Callable[[Schedule], bool], - ] = "json", - *args, - **kwargs, -) -> Database: - """Create a Database.""" - from . 
import ( # pylint: disable=import-outside-toplevel - JSONDatabase, - MemoryDatabase, - OrderedUnionDatabase, - ScheduleFnDatabase, - UnionDatabase, - ) - - if callable(kind): - return ScheduleFnDatabase(kind, *args, **kwargs) # type: ignore - if kind == "json": - return JSONDatabase(*args, **kwargs) - if kind == "memory": - return MemoryDatabase(*args, **kwargs) # type: ignore - if kind == "union": - return UnionDatabase(*args, **kwargs) # type: ignore - if kind == "ordered_union": - return OrderedUnionDatabase(*args, **kwargs) # type: ignore - raise ValueError(f"Unknown Database: {kind}") diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py deleted file mode 100644 index c701fd6568e0..000000000000 --- a/python/tvm/meta_schedule/default_config.py +++ /dev/null @@ -1,454 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=import-outside-toplevel -"""Pre-configured Defaults for MetaSchedule search rules""" -import logging -from os import path as osp -from typing import Callable, Dict, List, Optional, Union - -from tvm.ir import IRModule -from tvm.target import Target -from tvm.tir import PrimFunc - -from .builder import Builder, LocalBuilder -from .cost_model import CostModel, XGBModel -from .database import Database, JSONDatabase -from .feature_extractor import PerStoreFeature -from .measure_callback import MeasureCallback -from .mutator import Mutator -from .postproc import Postproc -from .runner import LocalRunner, Runner -from .schedule_rule import ScheduleRule -from .space_generator import PostOrderApply, SpaceGenerator - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - -FnSpaceGenerator = Callable[[], SpaceGenerator] -FnScheduleRule = Callable[[], List[ScheduleRule]] -FnPostproc = Callable[[], List[Postproc]] -FnMutatorProb = Callable[[], Dict[Mutator, float]] - - -def mod(mod: Union[PrimFunc, IRModule]) -> IRModule: # pylint: disable=redefined-outer-name - """Normalize the input to an IRModule""" - if isinstance(mod, PrimFunc): - mod = mod.with_attr("global_symbol", "main") - mod = mod.with_attr("tir.noalias", True) - mod = IRModule({"main": mod}) - if not isinstance(mod, IRModule): - raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") - func_names = mod.get_global_vars() - (func_name,) = func_names - if len(func_names) == 1 and func_name.name_hint != "main": - mod = IRModule({"main": mod[func_name]}) - return mod - - -def target(target: Union[str, Target]) -> Target: # pylint: disable=redefined-outer-name - """Normalize the input to tvm.target.Target""" - if isinstance(target, str): - target = Target(target) - if not isinstance(target, Target): - raise TypeError(f"Expected `target` to be str or Target, but gets: {target}") - return 
target - - -def builder(builder: Optional[Builder]) -> Builder: # pylint: disable=redefined-outer-name - """Normalize the input to tvm.meta_schedule.Builder""" - if builder is None: - builder = LocalBuilder() # type: ignore - if not isinstance(builder, Builder): - raise TypeError(f"Expected `builder` to be Builder, but gets: {builder}") - return builder - - -def runner(runner: Optional[Runner]) -> Runner: # pylint: disable=redefined-outer-name - """Normalize the input to tvm.meta_schedule.Runner""" - if runner is None: - runner = LocalRunner() # type: ignore - if not isinstance(runner, Runner): - raise TypeError(f"Expected `runner` to be Runner, but gets: {runner}") - return runner - - -def database( - database: Union[None, Database], # pylint: disable=redefined-outer-name - path: str, -) -> Database: - """Normalize the input to tvm.meta_schedule.Database""" - if database is None: - path_workload = osp.join(path, "database_workload.json") - path_tuning_record = osp.join(path, "database_tuning_record.json") - logger.info( - "Creating JSONDatabase. Workload at: %s. Tuning records at: %s", - path_workload, - path_tuning_record, - ) - database = JSONDatabase( - path_workload=path_workload, - path_tuning_record=path_tuning_record, - ) - if not isinstance(database, Database): - raise TypeError(f"Expected `database` to be Database, but gets: {database}") - return database - - -def callbacks( # pylint: disable=redefined-outer-name - measure_callbacks: Optional[List[MeasureCallback]], -) -> List[MeasureCallback]: - """Normalize the input to List[tvm.meta_schedule.MeasureCallback]""" - if measure_callbacks is None: - from tvm.meta_schedule import measure_callback as M - - return [ - M.AddToDatabase(), - M.RemoveBuildArtifact(), - M.EchoStatistics(), - M.UpdateCostModel(), - ] - if not isinstance(measure_callbacks, (list, tuple)): - raise TypeError( - f"Expected `measure_callbacks` to be List[MeasureCallback], " - f"but gets: {measure_callbacks}" - ) - measure_callbacks = list(measure_callbacks) - for i, callback in enumerate(measure_callbacks): - if not isinstance(callback, MeasureCallback): - raise TypeError( - f"Expected `measure_callbacks` to be List[MeasureCallback], " - f"but measure_callbacks[{i}] is: {callback}" - ) - return measure_callbacks - - -def cost_model( - cost_model: Optional[CostModel], # pylint: disable=redefined-outer-name - adpative_training: Optional[bool], -) -> CostModel: - """Normalize the input to tvm.meta_schedule.CostModel""" - if cost_model is None: - return XGBModel( # type: ignore - extractor=PerStoreFeature(), - adaptive_training=adpative_training is None or adpative_training, - ) - if not isinstance(cost_model, CostModel): - raise TypeError(f"Expected `cost_model` to be CostModel, but gets: {cost_model}") - return cost_model - - -def space_generator( - space_generator: Optional[FnSpaceGenerator], # pylint: disable=redefined-outer-name -) -> SpaceGenerator: - """Normalize the input to tvm.meta_schedule.SpaceGenerator""" - if space_generator is None: - return PostOrderApply() - if callable(space_generator): - space_generator = space_generator() - if not isinstance(space_generator, SpaceGenerator): - raise TypeError( - f"Expected `space_generator` to return SpaceGenerator, " f"but gets: {space_generator}" - ) - return space_generator - - -def schedule_rules( # pylint: disable=redefined-outer-name - sch_rules: Optional[FnScheduleRule], - target: Target, -) -> List[ScheduleRule]: - """Normalize the input to List[tvm.meta_schedule.ScheduleRule]""" - if callable(sch_rules): 
- return sch_rules() - if sch_rules is not None: - raise TypeError(f"Expected `sch_rules` to be None or callable, but gets: {sch_rules}") - if target.kind.name == "llvm": - return _DefaultLLVM.schedule_rules() - if target.kind.name in ["cuda", "rocm", "vulkan"]: - return _DefaultCUDA.schedule_rules() - if target.kind.name == "hexagon": - return _DefaultHexagon.schedule_rules() - raise ValueError(f"Unsupported target: {target}") - - -def postproc( # pylint: disable=redefined-outer-name - postproc: Optional[FnPostproc], - target: Target, -) -> List[Postproc]: - """Normalize the input to List[tvm.meta_schedule.Postproc]""" - if callable(postproc): - return postproc() - if postproc is not None: - raise TypeError(f"Expected `postproc` to be None or callable, but gets: {postproc}") - if target.kind.name == "llvm": - return _DefaultLLVM.postprocs() - if target.kind.name in ["cuda", "rocm", "vulkan"]: - return _DefaultCUDA.postprocs() - if target.kind.name == "hexagon": - return _DefaultHexagon.postprocs() - raise ValueError(f"Unsupported target: {target}") - - -def mutator_probs( # pylint: disable=redefined-outer-name - mutator_probs: Optional[FnMutatorProb], - target: Target, -) -> Dict[Mutator, float]: - """Normalize the input to Dict[tvm.meta_schedule.Mutator, float]""" - if callable(mutator_probs): - return mutator_probs() - if mutator_probs is not None: - raise TypeError( - f"Expected `mutator_probs` to be None or callable, but gets: {mutator_probs}" - ) - if target.kind.name in ["llvm", "hexagon"]: - return _DefaultLLVM.mutator_probs() - if target.kind.name in ["cuda", "rocm", "vulkan"]: - return _DefaultCUDA.mutator_probs() - raise ValueError(f"Unsupported target: {target}") - - -class _DefaultLLVM: - """Default tuning configuration for LLVM.""" - - @staticmethod - def schedule_rules() -> List[ScheduleRule]: - from tvm.meta_schedule import schedule_rule as M - - return [ - M.AutoInline( - into_producer=False, - into_consumer=True, - inline_const_tensor=True, - disallow_if_then_else=True, - require_injective=True, - require_ordered=True, - disallow_op=["tir.exp"], - ), - M.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64), - M.MultiLevelTiling( - structure="SSRSRS", - tile_binds=None, - max_innermost_factor=64, - vector_load_lens=None, - reuse_read=None, - reuse_write=M.ReuseType( - req="may", - levels=[1, 2], - scope="global", - ), - ), - M.ParallelizeVectorizeUnroll( - max_jobs_per_core=16, - max_vectorize_extent=64, - unroll_max_steps=[0, 16, 64, 512], - unroll_explicit=True, - ), - M.RandomComputeLocation(), - ] - - @staticmethod - def postprocs() -> List[Postproc]: - from tvm.meta_schedule import postproc as M - - return [ - M.DisallowDynamicLoop(), - M.RewriteParallelVectorizeUnroll(), - M.RewriteReductionBlock(), - M.RewriteLayout(), - ] - - @staticmethod - def mutator_probs() -> Dict[Mutator, float]: - from tvm.meta_schedule import mutator as M - - return { - M.MutateTileSize(): 0.9, - M.MutateComputeLocation(): 0.05, - M.MutateUnroll(): 0.03, - M.MutateParallel(max_jobs_per_core=16): 0.02, - } - - -class _DefaultHexagon: - """Default tuning configuration for Hexagon.""" - - @staticmethod - def schedule_rules() -> List[ScheduleRule]: - from tvm.meta_schedule import schedule_rule as M - - return [ - M.AutoInline( - into_producer=False, - into_consumer=True, - inline_const_tensor=True, - disallow_if_then_else=True, - require_injective=True, - require_ordered=True, - disallow_op=["tir.exp"], - ), - M.MultiLevelTilingWideVector( - structure="SRSRS", - vector_length_in_bits=1024, 
- max_innermost_factor=128, - reuse_read=None, - reuse_write=M.ReuseType( - req="may", - levels=[1, 2], - scope="global", - ), - ), - M.ParallelizeVectorizeUnroll( - max_jobs_per_core=16, - max_vectorize_extent=128, - unroll_max_steps=[0, 16, 64, 512], - unroll_explicit=True, - ), - ] - - @staticmethod - def postprocs() -> List[Postproc]: - from tvm.meta_schedule import postproc as M - - return [ - M.DisallowDynamicLoop(), - M.RewriteParallelVectorizeUnroll(), - M.RewriteReductionBlock(), - # TODO(masahi): Fix RewriteLayout for link-params=True case - # M.RewriteLayout(), - ] - - -class _DefaultCUDA: - """Default tuning configuration for CUDA.""" - - @staticmethod - def schedule_rules() -> List[ScheduleRule]: - from tvm.meta_schedule import schedule_rule as M - - return [ - M.MultiLevelTiling( - structure="SSSRRSRS", - tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], - max_innermost_factor=64, - vector_load_lens=[1, 2, 3, 4, 8, 16], - reuse_read=M.ReuseType( - req="must", - levels=[4], - scope="shared", - ), - reuse_write=M.ReuseType( - req="must", - levels=[3], - scope="local", - ), - ), - M.AutoInline( - into_producer=True, - into_consumer=True, - inline_const_tensor=True, - disallow_if_then_else=False, - require_injective=False, - require_ordered=False, - disallow_op=None, - ), - M.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]), - M.ParallelizeVectorizeUnroll( - max_jobs_per_core=-1, # disable parallelize - max_vectorize_extent=-1, # disable vectorize - unroll_max_steps=[0, 16, 64, 512, 1024], - unroll_explicit=True, - ), - M.AutoBind( - max_threadblocks=256, - thread_extents=[32, 64, 128, 256, 512, 1024], - ), - ] - - @staticmethod - def postprocs() -> List[Postproc]: - from tvm.meta_schedule import postproc as M - - return [ - M.DisallowDynamicLoop(), - M.RewriteCooperativeFetch(), - M.RewriteUnboundBlock(), - M.RewriteParallelVectorizeUnroll(), - M.RewriteReductionBlock(), - M.VerifyGPUCode(), - ] - - @staticmethod - def mutator_probs() -> Dict[Mutator, float]: - from tvm.meta_schedule import mutator as M - - return { - M.MutateTileSize(): 0.9, - M.MutateUnroll(): 0.08, - M.MutateThreadBinding(): 0.02, - } - - -class _DefaultCUDATensorCore: - """Default tuning configuration for CUDA TensorCore.""" - - @staticmethod - def schedule_rules(): - from tvm.meta_schedule import schedule_rule as M - from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group - - return [ - M.MultiLevelTilingTensorCore( - intrin_groups=[ - get_wmma_intrin_group( - store_scope="shared", - in_dtype=in_dtype, - out_dtype=out_dtype, - trans_b=trans_b, - ) - for (in_dtype, out_dtype) in [("float16", "float16"), ("int8", "int32")] - for trans_b in [False, True] - ], - structure="SSSRRSRS", - tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"], - max_innermost_factor=4, - vector_load_lens=[1, 2, 3, 4, 8, 16], - reuse_read=M.ReuseType(req="must", levels=[4], scope="shared"), - reuse_write=M.ReuseType( - req="must", - levels=[2], - scope="shared", - ), - use_software_pipeline=False, - ), - *_DefaultCUDA.schedule_rules(), - ] - - @staticmethod - def postprocs() -> List[Postproc]: - from tvm.meta_schedule import postproc as M - - return [ - M.DisallowDynamicLoop(), - M.RewriteCooperativeFetch(), - M.RewriteUnboundBlock(), - M.RewriteParallelVectorizeUnroll(), - M.RewriteReductionBlock(), - M.RewriteTensorize(), - M.VerifyGPUCode(), - ] - - @staticmethod - def mutator_probs() -> Dict[Mutator, float]: - return _DefaultCUDA.mutator_probs() diff --git 
a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py
index 04064b1cce35..c14c97e0f526 100644
--- a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py
+++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py
@@ -15,22 +15,29 @@
 # specific language governing permissions and limitations
 # under the License.
 """Meta Schedule FeatureExtractor."""
-from typing import Callable, List
+from typing import Callable, List, Union
+
+# isort: off
+from typing_extensions import Literal
+
+# isort: on
 
 from tvm._ffi import register_object
 from tvm.runtime import Object
 from tvm.runtime.ndarray import NDArray
 
 from .. import _ffi_api
-from ..utils import _get_default_str
-from ..tune_context import TuneContext
 from ..search_strategy import MeasureCandidate
+from ..tune_context import TuneContext
+from ..utils import _get_default_str
 
 
 @register_object("meta_schedule.FeatureExtractor")
 class FeatureExtractor(Object):
     """Extractor for features from measure candidates for use in cost model."""
 
+    FeatureExtractorType = Union[Literal["per-store-feature"], "FeatureExtractor"]
+
     def extract_from(
         self, context: TuneContext, candidates: List[MeasureCandidate]
     ) -> List[NDArray]:
@@ -53,6 +60,19 @@ def extract_from(
         )
         return result
 
+    @staticmethod
+    def create(
+        kind: Literal["per-store-feature"],
+        *args,
+        **kwargs,
+    ) -> "FeatureExtractor":
+        """Create a FeatureExtractor."""
+        from . import PerStoreFeature  # pylint: disable=import-outside-toplevel
+
+        if kind == "per-store-feature":
+            return PerStoreFeature(*args, **kwargs)  # type: ignore
+        raise ValueError(f"Unknown FeatureExtractor: {kind}")
+
 
 @register_object("meta_schedule.PyFeatureExtractor")
 class _PyFeatureExtractor(FeatureExtractor):
diff --git a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py
index d805648bfbfd..18b84c364ad4 100644
--- a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py
+++ b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py
@@ -15,16 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 """Random Feature Extractor."""
-from typing import List, Union, Tuple
+from typing import List, Tuple, Union
 
 import numpy as np  # type: ignore
 from tvm.runtime.ndarray import NDArray, array
 
-from ..tune_context import TuneContext
-from ..search_strategy import MeasureCandidate
 from ..feature_extractor import PyFeatureExtractor
+from ..search_strategy import MeasureCandidate
+from ..tune_context import TuneContext
+from ..utils import derived_object
 
 
+@derived_object
 class RandomFeatureExtractor(PyFeatureExtractor):
     """Random Feature Extractor
 
diff --git a/python/tvm/meta_schedule/logging.py b/python/tvm/meta_schedule/logging.py
new file mode 100644
index 000000000000..9d673266a3f2
--- /dev/null
+++ b/python/tvm/meta_schedule/logging.py
@@ -0,0 +1,259 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Logging interface in MetaSchedule""" +import logging +import logging.config +import os +import os.path as osp +from logging import Logger +from typing import Any, Callable, Dict, List, Optional + + +def get_logger(name: str) -> Logger: + """Create or get a logger by its name. This is essentially a wrapper of python's native logger. + + Parameters + ---------- + name : str + The name of the logger. + + Returns + ------- + logger : Logger + The logger instance. + """ + return logging.getLogger(name) + + +def get_logging_func(logger: Logger) -> Optional[Callable[[int, str], None]]: + """Get the logging function. + + Parameters + ---------- + logger : Logger + The logger instance. + Returns + ------- + result : Optional[Callable] + The function to do the specified level of logging. + """ + if logger is None: + return None + + level2log = { + logging.DEBUG: logger.debug, + logging.INFO: logger.info, + logging.WARNING: logger.warning, + logging.ERROR: logger.error, + # logging.FATAL not included + } + + def logging_func(level: int, msg: str): + if level < 0: + from IPython.display import ( # type: ignore # pylint: disable=import-outside-toplevel + clear_output, + ) + + clear_output(wait=True) + else: + level2log[level](msg) + + return logging_func + + +def create_loggers( + log_dir: str, + params: List[Dict[str, Any]], + logger_config: Optional[Dict[str, Any]] = None, + disable_existing_loggers: bool = False, +): + """Create loggers from configuration""" + if logger_config is None: + config = {} + else: + config = logger_config + + config.setdefault("loggers", {}) + config.setdefault("handlers", {}) + config.setdefault("formatters", {}) + + global_logger_name = "tvm.meta_schedule" + global_logger = logging.getLogger(global_logger_name) + if global_logger.level is logging.NOTSET: + global_logger.setLevel(logging.INFO) + + config["loggers"].setdefault( + global_logger_name, + { + "level": logging._levelToName[global_logger.level], # pylint: disable=protected-access + "handlers": [handler.get_name() for handler in global_logger.handlers] + + [global_logger_name + ".console", global_logger_name + ".file"], + "propagate": False, + }, + ) + config["loggers"].setdefault( + "{logger_name}", + { + "level": "INFO", + "handlers": [ + "{logger_name}.file", + ], + "propagate": False, + }, + ) + config["handlers"].setdefault( + global_logger_name + ".console", + { + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + "formatter": "tvm.meta_schedule.standard_formatter", + }, + ) + config["handlers"].setdefault( + global_logger_name + ".file", + { + "class": "logging.FileHandler", + "filename": "{log_dir}/" + __name__ + ".task_scheduler.log", + "mode": "a", + "level": "INFO", + "formatter": "tvm.meta_schedule.standard_formatter", + }, + ) + config["handlers"].setdefault( + "{logger_name}.file", + { + "class": "logging.FileHandler", + "filename": "{log_dir}/{logger_name}.log", + "mode": "a", + "level": "INFO", + "formatter": "tvm.meta_schedule.standard_formatter", + }, + ) + config["formatters"].setdefault( + "tvm.meta_schedule.standard_formatter", + { + "format": 
"%(asctime)s.%(msecs)03d %(levelname)s %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S", + }, + ) + + # set up dictConfig loggers + p_config = {"version": 1, "disable_existing_loggers": disable_existing_loggers} + for k, v in config.items(): + if k in ["formatters", "handlers", "loggers"]: + p_config[k] = _batch_parameterize_config(v, params) # type: ignore + else: + p_config[k] = v + logging.config.dictConfig(p_config) + + # check global logger + if global_logger.level not in [logging.DEBUG, logging.INFO]: + global_logger.warning( + "Logging level set to %s, please set to logging.INFO" + " or logging.DEBUG to view full log.", + logging._levelToName[global_logger.level], # pylint: disable=protected-access + ) + global_logger.info("Logging directory: %s", log_dir) + + +def _batch_parameterize_config( + config: Dict[str, Any], + params: List[Dict[str, str]], +) -> Dict[str, Any]: + """Parameterize the given configuration with multiple parameters sets. + + Parameters + ---------- + config : Dict[str, Any] + The given config dict. + Params : List[Dict[str, str]] + List of the given multiple parameters sets. + + Returns + ------- + result : Dict[str, Any] + The parameterized configuration. + """ + results = {} + for name, cfg in config.items(): + for p in params: + p_name = name.format(**p) + if p_name not in results: + p_cfg = _parameterize_config(cfg, p) + results[p_name] = p_cfg + return results + + +def _parameterize_config( + config: Dict[str, Any], + params: Dict[str, str], +) -> Dict[str, Any]: + """Parameterize the given configuration. + + Parameters + ---------- + config : Dict[str, Any] + The given config dict. + Params : Dict[str, str] + The given parameters. + + Returns + ------- + result : Dict[str, Any] + The parameterized configuration. + """ + result = {} + for k, v in config.items(): + if isinstance(k, str): + k = k.format(**params) + if isinstance(v, str): + v = v.format(**params) + elif isinstance(v, dict): + v = _parameterize_config(v, params) + elif isinstance(v, list): + v = [t.format(**params) for t in v] + result[k] = v + return result + + +def get_loggers_from_work_dir( + work_dir: str, + task_names: List[str], +) -> List[Logger]: + """Create loggers from work directory + + Parameters + ---------- + work_dir : str + The work directory. + task_names : List[str] + The list of task names. + + Returns + ------- + loggers : List[Logger] + The list of loggers. + """ + log_dir = osp.join(work_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + pattern = __name__ + ".task_{i:0" + f"{len(str(len(task_names) - 1))}" + "d}_{name}" + loggers = [pattern.format(i=i, name=name) for i, name in enumerate(task_names)] + create_loggers( + log_dir=log_dir, + params=[{"log_dir": log_dir, "logger_name": logger} for logger in loggers], + ) + return [get_logger(logger) for logger in loggers] diff --git a/python/tvm/meta_schedule/measure_callback/__init__.py b/python/tvm/meta_schedule/measure_callback/__init__.py index f697e7733e7e..f43aee7d875c 100644 --- a/python/tvm/meta_schedule/measure_callback/__init__.py +++ b/python/tvm/meta_schedule/measure_callback/__init__.py @@ -14,11 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -The tvm.meta_schedule.measure_callback package. 
-""" -from .measure_callback import MeasureCallback, PyMeasureCallback +"""The tvm.meta_schedule.measure_callback package.""" from .add_to_database import AddToDatabase -from .echo_statistics import EchoStatistics +from .measure_callback import MeasureCallback, PyMeasureCallback from .remove_build_artifact import RemoveBuildArtifact from .update_cost_model import UpdateCostModel diff --git a/python/tvm/meta_schedule/measure_callback/echo_statistics.py b/python/tvm/meta_schedule/measure_callback/echo_statistics.py deleted file mode 100644 index 867409f88174..000000000000 --- a/python/tvm/meta_schedule/measure_callback/echo_statistics.py +++ /dev/null @@ -1,30 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""A callback that echos the statistics of the tuning process to the console""" -from tvm._ffi import register_object - -from .. import _ffi_api -from .measure_callback import MeasureCallback - - -@register_object("meta_schedule.EchoStatistics") -class EchoStatistics(MeasureCallback): - def __init__(self) -> None: - """A callback that echos the statistics of the tuning process to the console""" - self.__init_handle_by_constructor__( - _ffi_api.MeasureCallbackEchoStatistics, # type: ignore # pylint: disable=no-member - ) diff --git a/python/tvm/meta_schedule/measure_callback/measure_callback.py b/python/tvm/meta_schedule/measure_callback/measure_callback.py index d9e412ed5605..d4a10c1e4009 100644 --- a/python/tvm/meta_schedule/measure_callback/measure_callback.py +++ b/python/tvm/meta_schedule/measure_callback/measure_callback.py @@ -16,7 +16,12 @@ # under the License. 
"""Meta Schedule MeasureCallback.""" -from typing import Callable, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, List, Union + +# isort: off +from typing_extensions import Literal + +# isort: on from tvm._ffi import register_object from tvm.runtime import Object @@ -35,6 +40,8 @@ class MeasureCallback(Object): """Rules to apply after measure results is available.""" + CallbackListType = Union[List["MeasureCallback"], "MeasureCallback", Literal["default"]] + def apply( self, task_scheduler: "TaskScheduler", @@ -67,6 +74,13 @@ def apply( runner_results, ) + @staticmethod + def create(kind: Literal["default"]) -> List["MeasureCallback"]: + """Create a list of measure callbacks.""" + if kind == "default": + return _ffi_api.MeasureCallbackDefault() # type: ignore # pylint: disable=no-member + raise ValueError(f"Unknown kind of MeasureCallback list: {kind}") + @register_object("meta_schedule.PyMeasureCallback") class _PyMeasureCallback(MeasureCallback): diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py index c5286aced7d8..188cb30c5b69 100644 --- a/python/tvm/meta_schedule/mutator/mutator.py +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -15,7 +15,12 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule Mutator.""" -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Dict, Optional + +# isort: off +from typing_extensions import Literal + +# isort: on from tvm._ffi import register_object from tvm.runtime import Object @@ -68,6 +73,43 @@ def clone(self) -> "Mutator": """ return _ffi_api.MutatorClone(self) # type: ignore # pylint: disable=no-member + @staticmethod + def create( + kind: Literal[ + "llvm", + "cuda", + "cuda-tensorcore", + "hexagon", + ] + ) -> Dict["Mutator", float]: + """Create a list of default mutators. + + Parameters + ---------- + kind : Literal["llvm", "cuda", "cuda-tensorcore", "hexagon"] + The kind of mutators. + + Returns + ------- + mutators : List[Mutator] + The list of mutators. + """ + funcs = { + # pylint: disable=no-member + "llvm": _ffi_api.MutatorDefaultLLVM, # type: ignore + "cuda": _ffi_api.MutatorDefaultCUDA, # type: ignore + "cuda-tensorcore": _ffi_api.MutatorDefaultCUDATensorCore, # type: ignore + "hexagon": _ffi_api.MutatorDefaultHexagon, # type: ignore + # pylint: enable=no-member + } + for k, v in funcs.items(): + if k == kind: + return v() + raise ValueError(f"Unsupported kind {kind} for mutator creation.") + + +create = Mutator.create # pylint: disable=invalid-name + @register_object("meta_schedule.PyMutator") class _PyMutator(Mutator): diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index 6eec2965ceeb..67a0d27e8261 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -15,8 +15,12 @@ # specific language governing permissions and limitations # under the License. 
"""Meta Schedule Postproc.""" +from typing import TYPE_CHECKING, Callable, List -from typing import TYPE_CHECKING, Callable +# isort: off +from typing_extensions import Literal + +# isort: on from tvm._ffi import register_object from tvm.runtime import Object @@ -70,6 +74,36 @@ def clone(self) -> "Postproc": """ return _ffi_api.PostprocClone(self) # type: ignore # pylint: disable=no-member + @staticmethod + def create(kind: Literal["llvm", "cuda", "cuda-tensorcore", "hexagon"]) -> List["Postproc"]: + """Create a list of default postprocessors. + + Parameters + ---------- + kind : Literal["llvm", "cuda", "cuda-tensorcore", "hexagon"] + The kind of the postprocessors. + + Returns + ------- + postprocs : List[Mutator] + The list of postprocessors. + """ + funcs = { + # pylint: disable=no-member + "llvm": _ffi_api.PostprocDefaultLLVM, # type: ignore + "cuda": _ffi_api.PostprocDefaultCUDA, # type: ignore + "cuda-tensorcore": _ffi_api.PostprocDefaultCUDATensorCore, # type: ignore + "hexagon": _ffi_api.PostprocDefaultHexagon, # type: ignore + # pylint: enable=no-member + } + for k, v in funcs.items(): + if k == kind: + return v() + raise ValueError(f"Unsupported kind {kind} for postproc creation.") + + +create = Postproc.create # pylint: disable=invalid-name + @register_object("meta_schedule.PyPostproc") class _PyPostproc(Postproc): diff --git a/python/tvm/meta_schedule/profiler.py b/python/tvm/meta_schedule/profiler.py index 206c2429d802..7446578a38d7 100644 --- a/python/tvm/meta_schedule/profiler.py +++ b/python/tvm/meta_schedule/profiler.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. """A context manager that profiles tuning time cost for different parts.""" - -import logging from contextlib import contextmanager from typing import Dict, Optional @@ -25,8 +23,6 @@ from . import _ffi_api -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - @register_object("meta_schedule.Profiler") class Profiler(Object): diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index 24009ab07fcf..af992dd4bc8b 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -15,8 +15,14 @@ # specific language governing permissions and limitations # under the License. 
"""MetaSchedule-Relay integration""" -from typing import Any, Dict, List, Optional +from contextlib import contextmanager +from types import MappingProxyType +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union +# isort: off +from typing_extensions import Literal + +# isort: on import numpy as np # type: ignore from tvm import nd from tvm._ffi import get_global_func @@ -24,19 +30,88 @@ from tvm.runtime import NDArray from tvm.target import Target +from .builder import Builder +from .cost_model import CostModel +from .database import Database from .extracted_task import ExtractedTask -from .utils import autotvm_silencer +from .logging import get_loggers_from_work_dir +from .measure_callback import MeasureCallback +from .profiler import Profiler +from .runner import Runner +from .search_strategy import SearchStrategy +from .space_generator import SpaceGenerator +from .task_scheduler import TaskScheduler +from .tune import tune_tasks +from .tune_context import TuneContext +from .utils import fork_seed + +if TYPE_CHECKING: + from tvm import relay + +_extract_task = get_global_func( # pylint: disable=invalid-name + "relay.backend.MetaScheduleExtractTask", + allow_missing=True, +) + + +@contextmanager +def _autotvm_silencer(): + """A context manager that silences autotvm warnings.""" + from tvm import autotvm # pylint: disable=import-outside-toplevel + + silent = autotvm.GLOBAL_SCOPE.silent + autotvm.GLOBAL_SCOPE.silent = True + try: + yield + finally: + autotvm.GLOBAL_SCOPE.silent = silent -def extract_task_from_relay( +def _normalize_params( mod: IRModule, - target: Target, - params: Optional[Dict[str, NDArray]] = None, + target: Union[Target, str], + params: Optional[Dict[str, NDArray]], + pass_config: Mapping[str, Any], + executor: Optional["relay.backend.Executor"], +) -> Tuple[ + IRModule, + Target, + Dict[str, NDArray], + Dict[str, Any], + Optional["relay.backend.Executor"], +]: + from tvm import relay # pylint: disable=import-outside-toplevel + + if isinstance(mod, relay.Function): + mod = IRModule.from_expr(mod) + if not isinstance(target, Target): + target = Target(target) + if params is None: + params = {} + relay_params = {} + for name, param in params.items(): + if isinstance(param, np.ndarray): + param = nd.array(param) + relay_params[name] = param + if executor is not None: + mod = mod.with_attr("executor", executor) + pass_config = dict(pass_config) + return mod, target, relay_params, pass_config, executor + + +def extract_tasks( + mod: IRModule, + target: Union[Target, str], + params: Optional[Dict[str, NDArray]], *, opt_level: int = 3, - pass_config: Optional[Dict[str, Any]] = None, - disabled_pass: Optional[List[str]] = None, - tir_converter: str = "default", + pass_config: Mapping[str, Any] = MappingProxyType( + { + "relay.backend.use_meta_schedule": True, + "relay.backend.tir_converter": "default", + } + ), + executor: Optional["relay.backend.Executor"] = None, ) -> List[ExtractedTask]: """Extract tuning tasks from a relay program. @@ -49,18 +124,11 @@ def extract_task_from_relay( params : Optional[Dict[str, tvm.runtime.NDArray]] The associated parameters of the program opt_level : int - The optimization level of the compiler - pass_config : Optional[Dict[str, Any]] - The pass config of the compiler - disabled_pass : Optional[List[str]] - The list of disabled passes of the compiler - tir_converter : str - The filter function to filter out the extracted tasks. 
Builtin filters:
-          - "default"
-          - "allow_extern"
-        The converter is a PackedFunc registered as f"relay.backend.tir_converter.{tir_converter}",
-        with the signature below:
-        (args: List[te.Tensor], constants: List[NDArray]) -> Optional[tir.PrimFunc]
+        The optimization level of the compilation
+    pass_config : Mapping[str, Any]
+        The pass configuration
+    executor : Optional[relay.backend.Executor]
+        The executor to use
 
     Returns
     -------
@@ -69,47 +137,229 @@
     """
     # pylint: disable=import-outside-toplevel
     from tvm import autotvm
-    from tvm.relay import Function as RelayFunc
 
     # pylint: enable=import-outside-toplevel
+    mod, target, params, pass_config, _ = _normalize_params(
+        mod, target, params, pass_config, executor
+    )
+    if target.kind.name != "cuda" and isinstance(
+        autotvm.DispatchContext.current, autotvm.FallbackContext
+    ):
+        tophub_context = autotvm.tophub.context(target)
+    else:
+        tophub_context = autotvm.utils.EmptyContext()
+    with Profiler.timeit("TaskExtraction"):
+        with target, _autotvm_silencer(), tophub_context:
+            with transform.PassContext(
+                opt_level=opt_level,
+                config=pass_config,
+            ):
+                return list(_extract_task(mod, target, params))
+
+
+def extracted_tasks_to_tune_contexts(
+    extracted_tasks: List[ExtractedTask],
+    work_dir: str,
+    space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
+    strategy: SearchStrategy.SearchStrategyType = "evolutionary",
+    num_threads: Union[Literal["physical", "logical"], int] = "physical",
+    seed: Optional[int] = None,
+) -> Tuple[List[TuneContext], List[float]]:
+    """Convert ExtractedTask to TuneContext.
+
+    Parameters
+    ----------
+    extracted_tasks : List[ExtractedTask]
+        The tasks to be converted
+    work_dir : str
+        The working directory to store logs and databases
+    space : SpaceGenerator.SpaceGeneratorType
+        The space generator to use.
+    strategy : SearchStrategy.SearchStrategyType
+        The search strategy to use.
+    num_threads : Union[Literal["physical", "logical"], int]
+        The number of threads to use in multi-threaded search algorithm.
+    seed : Optional[int]
+        The random seed to use.
+
+    Returns
+    -------
+    tasks : List[TuneContext]
+        The converted tasks
+    task_weights : List[float]
+        The weights of the tasks
+    """
+    tasks: List[TuneContext] = []
+    task_weights: List[float] = []
+    for task, logger, rand_state in zip(
+        extracted_tasks,
+        get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]),
+        fork_seed(seed, n=len(extracted_tasks)),
+    ):
+        tasks.append(
+            TuneContext(
+                mod=task.dispatched[0],
+                target=task.target,
+                space_generator=space,
+                search_strategy=strategy,
+                task_name=task.task_name,
+                logger=logger,
+                rand_state=rand_state,
+                num_threads=num_threads,
+            ).clone()
+        )
+        task_weights.append(task.weight)
+    return tasks, task_weights
+
+
+def tune_relay(
+    mod: IRModule,
+    params: Dict[str, NDArray],
+    target: Union[str, Target],
+    work_dir: str,
+    max_trials_global: int,
+    *,
+    max_trials_per_task: Optional[int] = None,
+    num_trials_per_iter: int = 64,
+    builder: Builder.BuilderType = "local",
+    runner: Runner.RunnerType = "local",
+    database: Database.DatabaseType = "json",
+    cost_model: CostModel.CostModelType = "xgb",
+    measure_callbacks: MeasureCallback.CallbackListType = "default",
+    task_scheduler: TaskScheduler.TaskSchedulerType = "gradient",
+    space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
+    strategy: SearchStrategy.SearchStrategyType = "evolutionary",
+    seed: Optional[int] = None,
+) -> Database:
+    """Tune a Relay program.
+ + Parameters + ---------- + mod : Union[IRModule, tir.PrimFunc] + The module or function to tune + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + target : Union[Target, str] + The compilation target + work_dir : str + The working directory to store the tuning records + max_trials_global : int + The maximum number of trials to run + max_trials_per_task : Optional[int] + The maximum number of trials to run for each task + num_trials_per_iter : int + The number of trials to run per iteration + builder : BuilderType + The builder to use + runner : RunnerType + The runner to use + database : DatabaseType + The database to use + cost_model : CostModelType + The cost model to use + measure_callbacks : CallbackListType + The measure callbacks to use + task_scheduler : TaskSchedulerType + The task scheduler to use + space : SpaceGeneratorType + The space generator to use + strategy : SearchStrategyType + The search strategy to use + seed : Optional[int] + The random seed - extract_task_func = get_global_func( - "relay.backend.MetaScheduleExtractTask", - allow_missing=False, + Returns + ------- + database : Database + The database that contains the tuning records + """ + tasks, task_weights = extracted_tasks_to_tune_contexts( + extracted_tasks=extract_tasks(mod, target, params), + work_dir=work_dir, + space=space, + strategy=strategy, + seed=seed, + ) + return tune_tasks( + tasks=tasks, + task_weights=task_weights, + work_dir=work_dir, + max_trials_global=max_trials_global, + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, ) - if isinstance(mod, RelayFunc): - mod = IRModule.from_expr(mod) - if not isinstance(target, Target): - target = Target(target) - if disabled_pass is None: - disabled_pass = [] - if pass_config is None: - pass_config = { + +def compile_relay( + database: Database, + mod: IRModule, + target: Union[Target, str], + params: Optional[Dict[str, NDArray]], + *, + backend: Literal["graph", "vm"] = "graph", + opt_level: int = 3, + pass_config: Mapping[str, Any] = MappingProxyType( + { "relay.backend.use_meta_schedule": True, - "relay.backend.tir_converter": tir_converter, + "relay.backend.tir_converter": "default", } - if params is None: - params = {} - relay_params = {} - for name, param in params.items(): - if isinstance(param, np.ndarray): - param = nd.array(param) - relay_params[name] = param + ), + executor: Optional["relay.backend.Executor"] = None, +): + """Compile a relay program with a MetaSchedule database. - with target, autotvm_silencer(), transform.PassContext( - opt_level=opt_level, - config=pass_config, - disabled_pass=disabled_pass, - ): - if target.kind.name != "cuda" and isinstance( - autotvm.DispatchContext.current, autotvm.FallbackContext - ): - tophub_context = autotvm.tophub.context(target) - else: - tophub_context = autotvm.utils.EmptyContext() - with tophub_context: - return list(extract_task_func(mod, target, relay_params)) + Parameters + ---------- + database : Database + The database to use + mod : IRModule + The Relay program to be compiled + target : tvm.target.Target + The compilation target + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + backend : str + The backend to use. 
Builtin backends: + - "graph" + - "vm" + opt_level : int + The optimization level of the compilation + pass_config : Mapping[str, Any] + The pass configuration + executor : Optional[relay.backend.Executor] + The executor to use in relay.build. It is not supported by RelayVM. + + Returns + ------- + lib : Union[Module, tvm.runtime.vm.Executable] + The built runtime module or vm Executable for the given relay workload. + """ + # pylint: disable=import-outside-toplevel + from tvm import relay + + # pylint: enable=import-outside-toplevel + mod, target, params, pass_config, executor = _normalize_params( + mod, target, params, pass_config, executor + ) + pass_config.setdefault("relay.backend.use_meta_schedule_dispatch", target.kind.name != "cuda") + with Profiler.timeit("PostTuningCompilation"): + with target, _autotvm_silencer(), database: + with transform.PassContext( + opt_level=opt_level, + config=pass_config, + ): + if backend == "graph": + return relay.build(mod, target=target, params=params, executor=executor) + elif backend == "vm": + return relay.vm.compile(mod, target=target, params=params) + else: + raise ValueError(f"Unknown backend: {backend}") def is_meta_schedule_enabled() -> bool: @@ -134,7 +384,8 @@ def is_meta_schedule_dispatch_enabled() -> bool: enabled: bool Whether the meta schedule is enabled """ - return transform.PassContext.current().config.get( + result = transform.PassContext.current().config.get( "relay.backend.use_meta_schedule_dispatch", - False, + 0, ) + return bool(result & 1) diff --git a/python/tvm/meta_schedule/runner/local_runner.py b/python/tvm/meta_schedule/runner/local_runner.py index 2d3214f53b6b..dfd4764607fb 100644 --- a/python/tvm/meta_schedule/runner/local_runner.py +++ b/python/tvm/meta_schedule/runner/local_runner.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """Local Runner""" -import logging from contextlib import contextmanager from typing import Callable, List, Optional, Union @@ -23,6 +22,7 @@ from ...contrib.popen_pool import PopenPoolExecutor from ...runtime import Device, Module +from ..logging import get_logger from ..profiler import Profiler from ..utils import derived_object, get_global_func_with_default_on_worker from .config import EvaluatorConfig @@ -34,7 +34,7 @@ run_evaluator_common, ) -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = get_logger(__name__) # pylint: disable=invalid-name T_ALLOC_ARGUMENT = Callable[ # pylint: disable=invalid-name diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py index aa6f3daaac60..9bdf715756cc 100644 --- a/python/tvm/meta_schedule/runner/rpc_runner.py +++ b/python/tvm/meta_schedule/runner/rpc_runner.py @@ -16,7 +16,6 @@ # under the License. 
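The new entry points compose into one flow: `extract_tasks` feeds `extracted_tasks_to_tune_contexts`, `tune_relay` wraps both and drives the tuning loop, and `compile_relay` consumes the resulting database. A sketch of the intended usage, assuming `mod` and `params` hold a Relay workload and using a hypothetical work directory; everything else follows this diff's signatures:

    from tvm.meta_schedule import relay_integration

    database = relay_integration.tune_relay(
        mod=mod,
        params=params,
        target="llvm",
        work_dir="./tune_tmp",  # hypothetical path
        max_trials_global=64,
    )
    lib = relay_integration.compile_relay(database, mod, "llvm", params)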
"""RPC Runner""" import concurrent.futures -import logging import os.path as osp from contextlib import contextmanager from typing import Callable, List, Optional, Union @@ -25,6 +24,7 @@ from tvm.rpc import RPCSession from tvm.runtime import Device, Module +from ..logging import get_logger from ..profiler import Profiler from ..utils import ( cpu_count, @@ -41,7 +41,7 @@ run_evaluator_common, ) -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = get_logger(__name__) # pylint: disable=invalid-name T_CREATE_SESSION = Callable[ # pylint: disable=invalid-name diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index 539e47f15c41..1753d8b4abf9 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Runners""" -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Union # isort: off from typing_extensions import Literal @@ -167,6 +167,8 @@ def result(self) -> RunnerResult: class Runner(Object): """The abstract runner interface""" + RunnerType = Union["Runner", Literal["local", "rpc"]] + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: """Run the built artifact and get runner futures. @@ -182,6 +184,24 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: """ return _ffi_api.RunnerRun(self, runner_inputs) # type: ignore # pylint: disable=no-member + @staticmethod + def create( # pylint: disable=keyword-arg-before-vararg + kind: Literal["local", "rpc"] = "local", + *args, + **kwargs, + ) -> "Runner": + """Create a Runner.""" + from . import LocalRunner, RPCRunner # pylint: disable=import-outside-toplevel + + if kind == "local": + return LocalRunner(*args, **kwargs) # type: ignore + elif kind == "rpc": + return RPCRunner(*args, **kwargs) # type: ignore + raise ValueError(f"Unknown Runner: {kind}") + + +create = Runner.create # pylint: disable=invalid-name + @register_object("meta_schedule.PyRunner") class _PyRunner(Runner): @@ -228,18 +248,3 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: The runner futures. """ raise NotImplementedError - - -def create( # pylint: disable=keyword-arg-before-vararg - kind: Literal["local", "rpc"] = "local", - *args, - **kwargs, -) -> Runner: - """Create a Runner.""" - from . 
import LocalRunner, RPCRunner # pylint: disable=import-outside-toplevel - - if kind == "local": - return LocalRunner(*args, **kwargs) # type: ignore - elif kind == "rpc": - return RPCRunner(*args, **kwargs) # type: ignore - raise ValueError(f"Unknown Runner: {kind}") diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index 2c8e223611aa..19cb1d8a55ec 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -20,6 +20,11 @@ """ from typing import TYPE_CHECKING, Callable, List +# isort: off +from typing_extensions import Literal + +# isort: on + from tvm._ffi import register_object from tvm.runtime import Object from tvm.tir.schedule import BlockRV, Schedule @@ -76,6 +81,36 @@ def clone(self) -> "ScheduleRule": """ return _ffi_api.ScheduleRuleClone(self) # type: ignore # pylint: disable=no-member + @staticmethod + def create(kind: Literal["llvm", "cuda", "cuda-tensorcore", "hexagon"]) -> List["ScheduleRule"]: + """Create a list of schedule rules for the given kind. + + Parameters + ---------- + kind : Literal["llvm", "cuda", "cuda-tensorcore", "hexagon"] + The kind of the schedule rules. + + Returns + ------- + rules : List[ScheduleRule] + The list of schedule rules. + """ + funcs = { + # pylint: disable=no-member + "llvm": _ffi_api.ScheduleRuleDefaultLLVM, # type: ignore + "cuda": _ffi_api.ScheduleRuleDefaultCUDA, # type: ignore + "cuda-tensorcore": _ffi_api.ScheduleRuleDefaultCUDATensorCore, # type: ignore + "hexagon": _ffi_api.ScheduleRuleDefaultHexagon, # type: ignore + # pylint: enable=no-member + } + for k, v in funcs.items(): + if k == kind: + return v() + raise ValueError(f"Unsupported kind {kind} for schedule rule creation.") + + +create = ScheduleRule.create # pylint: disable=invalid-name + @register_object("meta_schedule.PyScheduleRule") class _PyScheduleRule(ScheduleRule): diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py index f54fc53935f0..2851ebe7b1d1 100644 --- a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -29,10 +29,6 @@ class EvolutionarySearch(SearchStrategy): Parameters ---------- - num_trials_per_iter : int - Number of trials per iteration. - max_trials_per_task : int - Total number of trials. population_size : int The initial population of traces from measured samples and randomly generated samples. init_measured_ratio : int @@ -49,8 +45,6 @@ class EvolutionarySearch(SearchStrategy): The ratio of greedy selected samples in the final picks. 
""" - num_trials_per_iter: int - max_trials_per_task: int population_size: int init_measured_ratio: int init_min_unmeasured: int @@ -62,8 +56,6 @@ class EvolutionarySearch(SearchStrategy): def __init__( self, *, - num_trials_per_iter: int, - max_trials_per_task: int, population_size: int = 2048, init_measured_ratio: float = 0.2, init_min_unmeasured: int = 50, @@ -75,8 +67,6 @@ def __init__( """Constructor""" self.__init_handle_by_constructor__( _ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member - num_trials_per_iter, - max_trials_per_task, population_size, init_measured_ratio, init_min_unmeasured, diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py index d89e2b133cde..f4660014241a 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_func.py +++ b/python/tvm/meta_schedule/search_strategy/replay_func.py @@ -35,17 +35,8 @@ class ReplayFunc(SearchStrategy): Total number of trials for one task """ - num_trials_per_iter: int - max_trials_per_task: int - - def __init__( - self, - num_trials_per_iter: int, - max_trials_per_task: int, - ): + def __init__(self): """Constructor""" self.__init_handle_by_constructor__( _ffi_api.SearchStrategyReplayFunc, # type: ignore # pylint: disable=no-member - num_trials_per_iter, - max_trials_per_task, ) diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index 36dbb8734e57..e24ad5a5219a 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -29,25 +29,15 @@ class ReplayTrace(SearchStrategy): Parameters ---------- - num_trials_per_iter : int - Number of trials per iteration. - max_trials_per_task : int - Total number of trials for one task max_fail_count : int Max number of failures during trace replaying. """ - num_trials_per_iter: int - max_trials_per_task: int max_fail_count: int - def __init__( - self, num_trials_per_iter: int, max_trials_per_task: int, max_fail_count: int = 100 - ): + def __init__(self, max_fail_count: int = 100): """Constructor""" self.__init_handle_by_constructor__( _ffi_api.SearchStrategyReplayTrace, # type: ignore # pylint: disable=no-member - num_trials_per_iter, - max_trials_per_task, max_fail_count, ) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 276e65713325..3b72cc8d1ac6 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -18,7 +18,7 @@ Meta Schedule search strategy that generates the measure candidates for measurement. """ -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional, Union # isort: off from typing_extensions import Literal @@ -76,10 +76,16 @@ def __init__( @register_object("meta_schedule.SearchStrategy") class SearchStrategy(Object): - """ - Search strategy is the class that generates the measure candidates. It has to be pre-tuned - before usage and post-tuned after usage. - """ + """Search strategy is the class that generates the measure candidates.""" + + SearchStrategyType = Union[ + "SearchStrategy", + Literal[ + "replay-func", + "replay-trace", + "evolutionary", + ], + ] def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the search strategy with tuning context. 
@@ -95,6 +101,8 @@ def _initialize_with_tune_context(self, context: "TuneContext") -> None: def pre_tuning( self, + max_trials: int, + num_trials_per_iter: int, design_spaces: List[Schedule], database: Optional["Database"] = None, cost_model: Optional["CostModel"] = None, @@ -103,6 +111,10 @@ def pre_tuning( Parameters ---------- + max_trials : int + The maximum number of trials. + num_trials_per_iter : int + The number of trials per iteration. design_spaces : List[Schedule] The design spaces used during tuning process. database : Optional[Database] = None @@ -112,6 +124,8 @@ def pre_tuning( """ _ffi_api.SearchStrategyPreTuning( # type: ignore # pylint: disable=no-member self, + max_trials, + num_trials_per_iter, design_spaces, database, cost_model, @@ -161,6 +175,34 @@ def clone(self) -> "SearchStrategy": """ return _ffi_api.SearchStrategyClone(self) # type: ignore # pylint: disable=no-member + @staticmethod + def create( # pylint: disable=keyword-arg-before-vararg + kind: Literal[ + "evolutionary", + "replay-trace", + "replay-func", + ] = "evolutionary", + *args, + **kwargs, + ) -> "SearchStrategy": + """Create a search strategy.""" + from . import ( # pylint: disable=import-outside-toplevel + EvolutionarySearch, + ReplayFunc, + ReplayTrace, + ) + + if kind == "evolutionary": + return EvolutionarySearch(*args, **kwargs) + if kind == "replay-trace": + return ReplayTrace(*args, **kwargs) + if kind == "replay-func": + return ReplayFunc(*args, **kwargs) # type: ignore + raise ValueError(f"Unknown SearchStrategy: {kind}") + + +create = SearchStrategy.create # pylint: disable=invalid-name + @register_object("meta_schedule.PySearchStrategy") class _PySearchStrategy(SearchStrategy): @@ -223,7 +265,14 @@ def _initialize_with_tune_context(self, context: "TuneContext") -> None: """ raise NotImplementedError - def pre_tuning(self, design_spaces: List[Schedule]) -> None: + def pre_tuning( + self, + max_trials: int, + num_trials_per_iter: int, + design_spaces: List[Schedule], + database: Optional["Database"] = None, + cost_model: Optional["CostModel"] = None, + ) -> None: """Pre-tuning for the search strategy. Parameters @@ -272,28 +321,3 @@ def clone(self) -> SearchStrategy: The cloned search strategy. """ raise NotImplementedError - - -def create( # pylint: disable=keyword-arg-before-vararg - kind: Literal[ - "evolutionary", - "replay_trace", - "replay_func", - ] = "evolutionary", - *args, - **kwargs, -) -> SearchStrategy: - """Create a search strategy.""" - from . import ( # pylint: disable=import-outside-toplevel - EvolutionarySearch, - ReplayFunc, - ReplayTrace, - ) - - if kind == "evolutionary": - return EvolutionarySearch(*args, **kwargs) - if kind == "replay_trace": - return ReplayTrace(*args, **kwargs) - if kind == "replay_func": - return ReplayFunc(*args, **kwargs) - raise ValueError(f"Unknown SearchStrategy: {kind}") diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py index 6e2a2c52b1a1..930e8a51dc61 100644 --- a/python/tvm/meta_schedule/space_generator/post_order_apply.py +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -15,11 +15,16 @@ # specific language governing permissions and limitations # under the License. """Post Order Apply Space Generator.""" - - from tvm._ffi import register_object -from .space_generator import SpaceGenerator + from .. 
import _ffi_api +from .space_generator import ( + MutatorProbType, + PostprocType, + ScheduleRuleType, + SpaceGenerator, + _normalize_rules, +) @register_object("meta_schedule.PostOrderApply") @@ -37,8 +42,19 @@ class PostOrderApply(SpaceGenerator): all blocks will have schedules generated. """ - def __init__(self, f_block_filter=None): + def __init__( + self, + f_block_filter=None, + sch_rules: ScheduleRuleType = "from-target", + postprocs: PostprocType = "from-target", + mutator_probs: MutatorProbType = "from-target", + ): """Constructor""" + sch_rules, postprocs, mutator_probs = _normalize_rules(sch_rules, postprocs, mutator_probs) self.__init_handle_by_constructor__( - _ffi_api.SpaceGeneratorPostOrderApply, f_block_filter # type: ignore # pylint: disable=no-member + _ffi_api.SpaceGeneratorPostOrderApply, # type: ignore # pylint: disable=no-member + f_block_filter, + sch_rules, + postprocs, + mutator_probs, ) diff --git a/python/tvm/meta_schedule/space_generator/schedule_fn.py b/python/tvm/meta_schedule/space_generator/schedule_fn.py index d6b063dcb263..65956e843679 100644 --- a/python/tvm/meta_schedule/space_generator/schedule_fn.py +++ b/python/tvm/meta_schedule/space_generator/schedule_fn.py @@ -18,7 +18,13 @@ from tvm._ffi import register_object from .. import _ffi_api -from .space_generator import SpaceGenerator +from .space_generator import ( + MutatorProbType, + PostprocType, + ScheduleRuleType, + SpaceGenerator, + _normalize_rules, +) @register_object("meta_schedule.ScheduleFn") @@ -30,7 +36,13 @@ class ScheduleFn(SpaceGenerator): - 3) [Schedule] -> List[Schedule] """ - def __init__(self, sch_fn: SpaceGenerator.ScheduleFnType): + def __init__( + self, + sch_fn: SpaceGenerator.ScheduleFnType, + sch_rules: ScheduleRuleType = "from-target", + postprocs: PostprocType = "from-target", + mutator_probs: MutatorProbType = "from-target", + ): """Constructor. Parameters @@ -41,7 +53,11 @@ def __init__(self, sch_fn: SpaceGenerator.ScheduleFnType): - 2) [Schedule] -> Schedule - 3) [Schedule] -> List[Schedule] """ + sch_rules, postprocs, mutator_probs = _normalize_rules(sch_rules, postprocs, mutator_probs) self.__init_handle_by_constructor__( _ffi_api.SpaceGeneratorScheduleFn, # type: ignore # pylint: disable=no-member sch_fn, + sch_rules, + postprocs, + mutator_probs, ) diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index 23c0361645b5..f6212a360a87 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -18,7 +18,7 @@ Meta Schedule design space generators that generates design space for generation of measure candidates. """ -from typing import TYPE_CHECKING, Callable, List, Optional, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union # isort: off from typing_extensions import Literal @@ -32,6 +32,9 @@ from .. 
import _ffi_api if TYPE_CHECKING: + from ..mutator import Mutator + from ..postproc import Postproc + from ..schedule_rule import ScheduleRule from ..tune_context import TuneContext @@ -45,6 +48,16 @@ class SpaceGenerator(Object): Callable[[Schedule], List[Schedule]], # Multiple outputs ] + SpaceGeneratorType = Union[ + "SpaceGenerator", + ScheduleFnType, + Literal["post-order-apply", "union"], + ] + + sch_rules: Optional[List["ScheduleRule"]] + postprocs: Optional[List["Postproc"]] + mutator_probs: Optional[Dict["Mutator", float]] + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the design space generator with tuning context. @@ -82,8 +95,91 @@ def clone(self) -> "SpaceGenerator": """ return _ffi_api.SpaceGeneratorClone(self) # type: ignore # pylint: disable=no-member + @staticmethod + def create( # pylint: disable=keyword-arg-before-vararg + kind: Union[ + Literal["post-order-apply", "union"], + ScheduleFnType, + ] = "post-order-apply", + *args, + **kwargs, + ) -> "SpaceGenerator": + """Create a design space generator.""" + from . import ( # pylint: disable=import-outside-toplevel + PostOrderApply, + ScheduleFn, + SpaceGeneratorUnion, + ) + + if callable(kind): + + def create_schedule_fn( + func, + sch_rules=[], + postprocs=[], + mutator_probs={}, + ): # pylint: disable=dangerous-default-value + return ScheduleFn(func, sch_rules, postprocs, mutator_probs) + + return create_schedule_fn(kind, *args, **kwargs) # type: ignore + if kind == "post-order-apply": + return PostOrderApply(*args, **kwargs) + if kind == "union": + return SpaceGeneratorUnion(*args, **kwargs) + raise ValueError(f"Unknown SpaceGenerator: {kind}") + ScheduleFnType = SpaceGenerator.ScheduleFnType +ScheduleRuleType = Union[ + List["ScheduleRule"], + Literal["llvm", "cuda", "cuda-tensorcore", "hexagon", "from-target"], +] +PostprocType = Union[ + List["Postproc"], + Literal["llvm", "cuda", "cuda-tensorcore", "hexagon", "from-target"], +] +MutatorProbType = Union[ + Dict["Mutator", float], + Literal["llvm", "cuda", "cuda-tensorcore", "hexagon", "from-target"], +] +create = SpaceGenerator.create # pylint: disable=invalid-name + + +def _normalize_rules( + sch_rules: ScheduleRuleType, + postprocs: PostprocType, + mutator_probs: MutatorProbType, +) -> Tuple[ + Optional[List["ScheduleRule"]], + Optional[List["Postproc"]], + Optional[Dict["Mutator", float]], +]: + # pylint: disable=import-outside-toplevel + from ..mutator import Mutator + from ..postproc import Postproc + from ..schedule_rule import ScheduleRule + + # pylint: enable=import-outside-toplevel + assert sch_rules is not None + assert postprocs is not None + assert mutator_probs is not None + + if isinstance(sch_rules, str): + if sch_rules == "from-target": + sch_rules = None + else: + sch_rules = ScheduleRule.create(sch_rules) + if isinstance(postprocs, str): + if postprocs == "from-target": + postprocs = None + else: + postprocs = Postproc.create(postprocs) + if isinstance(mutator_probs, str): + if mutator_probs == "from-target": + mutator_probs = None + else: + mutator_probs = Mutator.create(mutator_probs) + return sch_rules, postprocs, mutator_probs # type: ignore @register_object("meta_schedule.PySpaceGenerator") @@ -97,14 +193,21 @@ class _PySpaceGenerator(SpaceGenerator): def __init__( self, + sch_rules: ScheduleRuleType = "from-target", + postprocs: PostprocType = "from-target", + mutator_probs: MutatorProbType = "from-target", f_initialize_with_tune_context: Optional[Callable] = None, f_generate_design_space: 
Optional[Callable] = None, f_clone: Optional[Callable] = None, ): """Constructor.""" + sch_rules, postprocs, mutator_probs = _normalize_rules(sch_rules, postprocs, mutator_probs) self.__init_handle_by_constructor__( _ffi_api.SpaceGeneratorPySpaceGenerator, # type: ignore # pylint: disable=no-member + sch_rules, + postprocs, + mutator_probs, f_initialize_with_tune_context, f_generate_design_space, f_clone, @@ -121,6 +224,7 @@ class PySpaceGenerator: _tvm_metadata = { "cls": _PySpaceGenerator, + "fields": ["sch_rules", "postprocs", "mutator_probs"], "methods": ["_initialize_with_tune_context", "generate_design_space", "clone"], } @@ -158,27 +262,3 @@ def clone(self) -> SpaceGenerator: The cloned design space generator. """ raise NotImplementedError - - -def create( # pylint: disable=keyword-arg-before-vararg - kind: Union[ - Literal["post_order_apply", "union"], - ScheduleFnType, - ] = "post_order_apply", - *args, - **kwargs, -) -> SpaceGenerator: - """Create a design space generator.""" - from . import ( # pylint: disable=import-outside-toplevel - PostOrderApply, - ScheduleFn, - SpaceGeneratorUnion, - ) - - if callable(kind): - return ScheduleFn(kind, *args, **kwargs) # type: ignore - if kind == "post_order_apply": - return PostOrderApply(*args, **kwargs) - if kind == "union": - return SpaceGeneratorUnion(*args, **kwargs) - raise ValueError(f"Unknown SpaceGenerator: {kind}") diff --git a/python/tvm/meta_schedule/space_generator/space_generator_union.py b/python/tvm/meta_schedule/space_generator/space_generator_union.py index 5541ab0b5026..e3d8f441d1ef 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator_union.py +++ b/python/tvm/meta_schedule/space_generator/space_generator_union.py @@ -20,14 +20,26 @@ from tvm._ffi import register_object from .. import _ffi_api -from .space_generator import SpaceGenerator +from .space_generator import ( + MutatorProbType, + PostprocType, + ScheduleRuleType, + SpaceGenerator, + _normalize_rules, +) @register_object("meta_schedule.SpaceGeneratorUnion") class SpaceGeneratorUnion(SpaceGenerator): """Union of design space generators.""" - def __init__(self, space_generators: List[SpaceGenerator]): + def __init__( + self, + space_generators: List[SpaceGenerator], + sch_rules: ScheduleRuleType = "from-target", + postprocs: PostprocType = "from-target", + mutator_probs: MutatorProbType = "from-target", + ): """Constructor. Parameters @@ -35,7 +47,11 @@ def __init__(self, space_generators: List[SpaceGenerator]): space_generators : List[SpaceGenerator] The list of design space generators to be unioned. """ + sch_rules, postprocs, mutator_probs = _normalize_rules(sch_rules, postprocs, mutator_probs) self.__init_handle_by_constructor__( _ffi_api.SpaceGeneratorSpaceGeneratorUnion, # type: ignore # pylint: disable=no-member space_generators, + sch_rules, + postprocs, + mutator_probs, ) diff --git a/python/tvm/meta_schedule/task_scheduler/gradient_based.py b/python/tvm/meta_schedule/task_scheduler/gradient_based.py index 20d32dd1c59f..963de8711e10 100644 --- a/python/tvm/meta_schedule/task_scheduler/gradient_based.py +++ b/python/tvm/meta_schedule/task_scheduler/gradient_based.py @@ -15,24 +15,13 @@ # specific language governing permissions and limitations # under the License. """Gradient Based Task Scheduler""" -import logging -from typing import TYPE_CHECKING, List, Optional - from tvm._ffi import register_object from .. 
import _ffi_api -from ..builder import Builder -from ..cost_model import CostModel -from ..database import Database -from ..measure_callback import MeasureCallback -from ..runner import Runner -from ..utils import make_logging_func +from ..logging import get_logger, get_logging_func from .task_scheduler import TaskScheduler -if TYPE_CHECKING: - from ..tune_context import TuneContext - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = get_logger(__name__) # pylint: disable=invalid-name @register_object("meta_schedule.GradientBased") @@ -41,15 +30,7 @@ class GradientBased(TaskScheduler): def __init__( self, - tasks: List["TuneContext"], - task_weights: List[float], - builder: Builder, - runner: Runner, *, - database: Database, - cost_model: Optional[CostModel] = None, - measure_callbacks: Optional[List[MeasureCallback]] = None, - max_trials: int, alpha: float = 0.2, window_size: int = 3, seed: int = -1, @@ -58,22 +39,6 @@ def __init__( Parameters ---------- - tasks : List[TuneContext] - List of tasks to schedule. - task_weights : List[float] - The weights of each task. - builder : Builder - The builder. - runner : Runner - The runner. - database : Database - The database. - cost_model : CostModel, default None. - The cost model of the scheduler. - measure_callbacks : Optional[List[MeasureCallback]] = None - The list of measure callbacks of the scheduler. - max_trials : int - The maximum number of trials to run. alpha : float = 0.2 The parameter alpha in gradient computation. window_size : int = 3 @@ -83,15 +48,7 @@ def __init__( """ self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerGradientBased, # type: ignore # pylint: disable=no-member - tasks, - task_weights, - builder, - runner, - database, - cost_model, - measure_callbacks, - max_trials, - make_logging_func(logger), + get_logging_func(logger), alpha, window_size, seed, diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index ed395643bbaa..e5c7f14af424 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -15,87 +15,22 @@ # specific language governing permissions and limitations # under the License. """Round Robin Task Scheduler""" - -import logging -from typing import TYPE_CHECKING, List, Optional - from tvm._ffi import register_object -from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from .. import _ffi_api -from ..builder import Builder -from ..cost_model import CostModel -from ..database import Database -from ..runner import Runner -from ..utils import make_logging_func +from ..logging import get_logger, get_logging_func from .task_scheduler import TaskScheduler -if TYPE_CHECKING: - from ..tune_context import TuneContext - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = get_logger(__name__) # pylint: disable=invalid-name @register_object("meta_schedule.RoundRobin") class RoundRobin(TaskScheduler): - """Round Robin Task Scheduler - - Parameters - ---------- - tasks: List[TuneContext] - The list of tune context to process. - builder: Builder - The builder of the scheduler. - runner: Runner - The runner of the scheduler. - database: Database - The database of the scheduler. - measure_callbacks: Optional[List[MeasureCallback]] = None - The list of measure callbacks of the scheduler. 
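The same constructor slimming applies to the task schedulers: `GradientBased` keeps only its gradient hyper-parameters, and `RoundRobin` (below) takes no arguments at all, since tasks, builder, runner, database and callbacks now arrive via `tune`. A minimal sketch:

```
from tvm import meta_schedule as ms

# Keyword-only hyper-parameters; everything else moved to tune().
scheduler = ms.task_scheduler.GradientBased(alpha=0.2, window_size=3, seed=-1)
fallback = ms.task_scheduler.RoundRobin()
```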
- """ - - def __init__( - self, - tasks: List["TuneContext"], - task_weights: List[float], - builder: Builder, - runner: Runner, - *, - database: Database, - cost_model: Optional[CostModel] = None, - measure_callbacks: Optional[List[MeasureCallback]] = None, - max_trials: int, - ) -> None: - """Constructor. + """Round Robin Task Scheduler""" - Parameters - ---------- - tasks : List[TuneContext] - List of tasks to schedule. - task_weights : List[float] - List of weights for each task. Not used in round robin. - builder : Builder - The builder. - runner : Runner - The runner. - database : Database - The database. - cost_model : Optional[CostModel] - The cost model. - measure_callbacks: Optional[List[MeasureCallback]] - The list of measure callbacks of the scheduler. - max_trials : int - The maximum number of trials. - """ - del task_weights + def __init__(self) -> None: + """Constructor.""" self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member - tasks, - builder, - runner, - database, - cost_model, - measure_callbacks, - max_trials, - make_logging_func(logger), + get_logging_func(logger), ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index 29a5f18dfb8a..f06f4d911fa8 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -15,9 +15,7 @@ # specific language governing permissions and limitations # under the License. """Auto-tuning Task Scheduler""" - -import logging -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Union # isort: off from typing_extensions import Literal @@ -28,53 +26,44 @@ from tvm.runtime import Object from .. import _ffi_api -from ..builder import Builder +from ..builder import Builder, BuilderResult from ..cost_model import CostModel from ..database import Database +from ..logging import get_logger, get_logging_func from ..measure_callback import MeasureCallback from ..runner import Runner, RunnerResult +from ..search_strategy import MeasureCandidate from ..tune_context import TuneContext -from ..utils import make_logging_func -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = get_logger(__name__) # pylint: disable=invalid-name + + +@register_object("meta_schedule.TaskRecord") +class TaskRecord(Object): + """The running record of a task.""" + + ctx: TuneContext + task_weight: float + flop: float + is_terminated: bool + build_error_count: int + run_error_count: int + measure_candidates: List[MeasureCandidate] + builder_results: List[BuilderResult] + runner_results: List[RunnerResult] @register_object("meta_schedule.TaskScheduler") class TaskScheduler(Object): - """The abstract task scheduler interface. - - Parameters - ---------- - tasks: List[TuneContext] - The list of tune context to process. - builder: Builder - The builder of the scheduler. - runner: Runner - The runner of the scheduler. - database: Database - The database of the scheduler. - max_trials : int - The maximum number of trials allowed. - cost_model : Optional[CostModel] - The cost model used for search. - measure_callbacks: List[MeasureCallback] = None - The list of measure callbacks of the scheduler. - num_trials_already : int - The number of trials already conducted. 
- """ + """The abstract task scheduler interface.""" - tasks: List[TuneContext] - builder: Builder - runner: Runner - database: Database - max_trials: int - cost_model: Optional[CostModel] - measure_callbacks: List[MeasureCallback] - num_trials_already: int + tasks_: List[TaskRecord] + measure_callbacks_: List[MeasureCallback] + database_: Optional[Database] + cost_model_: Optional[CostModel] + remaining_tasks_: int - def tune(self) -> None: - """Auto-tuning.""" - _ffi_api.TaskSchedulerTune(self) # type: ignore # pylint: disable=no-member + TaskSchedulerType = Union["TaskScheduler", Literal["gradient", "round-robin"]] def next_task_id(self) -> int: """Fetch the next task id. @@ -101,15 +90,68 @@ def join_running_task(self, task_id: int) -> List[RunnerResult]: """ return _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # type: ignore # pylint: disable=no-member - def initialize_task(self, task_id: int) -> None: - """Initialize modules of the given task. + def tune( + self, + tasks: List[TuneContext], + task_weights: List[float], + max_trials_global: int, + max_trials_per_task: int, + num_trials_per_iter: int, + builder: Builder, + runner: Runner, + measure_callbacks: List[MeasureCallback], + database: Optional[Database], + cost_model: Optional[CostModel], + ) -> None: + """Auto-tuning. + + Parameters + ---------- + tasks : List[TuneContext] + The list of tuning contexts as tasks. + task_weights : List[float] + The list of task weights. + max_trials_global : int + The maximum number of trials globally. + max_trials_per_task : int + The maximum number of trials per task. + num_trials_per_iter : int + The number of trials per iteration. + builder : Builder + The builder. + runner : Runner + The runner. + measure_callbacks : List[MeasureCallback] + The list of measure callbacks. + database : Optional[Database] + The database. + cost_model : Optional[CostModel] + The cost model. + """ + task_weights = [float(w) for w in task_weights] + _ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member + self, + tasks, + task_weights, + max_trials_global, + max_trials_per_task, + num_trials_per_iter, + builder, + runner, + measure_callbacks, + database, + cost_model, + ) + + def terminate_task(self, task_id: int) -> None: + """Terminate the task Parameters ---------- task_id : int - The task id to be initialized. + The task id to be terminated. """ - _ffi_api.TaskSchedulerInitializeTask(self, task_id) # type: ignore # pylint: disable=no-member + _ffi_api.TaskSchedulerTerminateTask(self, task_id) # type: ignore # pylint: disable=no-member def touch_task(self, task_id: int) -> None: """Touch the task and update its status @@ -121,6 +163,37 @@ def touch_task(self, task_id: int) -> None: """ _ffi_api.TaskSchedulerTouchTask(self, task_id) # type: ignore # pylint: disable=no-member + def tuning_statistics(self) -> str: + """Returns a human-readable string of the tuning statistics. + + Returns + ------- + tuning_statistics : str + The tuning statistics. + """ + return _ffi_api.TaskSchedulerTuningStatistics(self) # type: ignore # pylint: disable=no-member + + @staticmethod + def create( # pylint: disable=keyword-arg-before-vararg + kind: Literal["round-robin", "gradient"] = "gradient", + *args, + **kwargs, + ) -> "TaskScheduler": + """Create a task scheduler.""" + from . 
import ( # pylint: disable=import-outside-toplevel + GradientBased, + RoundRobin, + ) + + if kind == "round-robin": + return RoundRobin(*args, **kwargs) # type: ignore + if kind == "gradient": + return GradientBased(*args, **kwargs) + raise ValueError(f"Unknown TaskScheduler name: {kind}") + + +create = TaskScheduler.create # pylint: disable=invalid-name + @register_object("meta_schedule.PyTaskScheduler") class _PyTaskScheduler(TaskScheduler): @@ -133,36 +206,18 @@ class _PyTaskScheduler(TaskScheduler): def __init__( self, - tasks: List[TuneContext], - builder: Builder, - runner: Runner, - database: Database, - max_trials: int, - cost_model: Optional[CostModel] = None, - measure_callbacks: Optional[List[MeasureCallback]] = None, - f_tune: Callable = None, - f_initialize_task: Callable = None, - f_touch_task: Callable = None, - f_join_running_task: Callable = None, - f_next_task_id: Callable = None, + f_next_task_id: Callable, + f_join_running_task: Callable, + f_tune: Callable, ): """Constructor.""" self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerPyTaskScheduler, # type: ignore # pylint: disable=no-member - tasks, - builder, - runner, - database, - max_trials, - cost_model, - measure_callbacks, - make_logging_func(logger), - f_tune, - f_initialize_task, - f_touch_task, - f_join_running_task, + get_logging_func(logger), f_next_task_id, + f_join_running_task, + f_tune, ) @@ -176,47 +231,39 @@ class PyTaskScheduler: _tvm_metadata = { "cls": _PyTaskScheduler, - "fields": [ - "tasks", - "builder", - "runner", - "database", - "cost_model", - "measure_callbacks", - "max_trials", - ], - "methods": [ - "tune", - "initialize_task", - "touch_task", - "join_running_task", - "next_task_id", - ], + "fields": [], + "methods": ["next_task_id", "join_running_task", "tune"], } - def __init__( + def __init__(self): + ... + + def tune( self, tasks: List[TuneContext], + task_weights: List[float], + max_trials_global: int, + max_trials_per_task: int, builder: Builder, runner: Runner, - *, - database: Optional[Database] = None, - cost_model: Optional[CostModel] = None, - measure_callbacks: Optional[List[MeasureCallback]] = None, - max_trials: int, - ): - self.tasks = tasks - self.builder = builder - self.runner = runner - self.database = database - self.cost_model = cost_model - self.measure_callbacks = measure_callbacks - self.max_trials = max_trials - - def tune(self) -> None: + measure_callbacks: List[MeasureCallback], + database: Optional[Database], + cost_model: Optional[CostModel], + ) -> None: """Auto-tuning.""" # Using self._outer to replace the self pointer - _ffi_api.TaskSchedulerTune(self._outer()) # type: ignore # pylint: disable=no-member + _ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member + self._outer(), # type: ignore # pylint: disable=no-member + tasks, + task_weights, + max_trials_global, + max_trials_per_task, + builder, + runner, + measure_callbacks, + database, + cost_model, + ) def next_task_id(self) -> int: """Fetch the next task id. @@ -238,40 +285,3 @@ def join_running_task(self, task_id: int) -> List[RunnerResult]: """ # Using self._outer to replace the self pointer return _ffi_api.TaskSchedulerJoinRunningTask(self._outer(), task_id) # type: ignore # pylint: disable=no-member - - def initialize_task(self, task_id: int) -> None: - """Initialize modules of the given task. - - Parameters - ---------- - task_id : int - The task id to be initialized. 
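Putting the relocated arguments together: everything that used to be constructor state on the scheduler is now passed to the new `tune` method. A structural sketch only; `tasks` is assumed to be a list of `TuneContext` objects built elsewhere, and the empty callback list and `None` database/cost model are illustrative placeholders, not recommended defaults:

```
from tvm import meta_schedule as ms

scheduler = ms.task_scheduler.TaskScheduler.create("round-robin")
scheduler.tune(
    tasks=tasks,                       # hypothetical List[TuneContext]
    task_weights=[1.0] * len(tasks),
    max_trials_global=64,
    max_trials_per_task=64,
    num_trials_per_iter=16,
    builder=ms.builder.LocalBuilder(),
    runner=ms.runner.LocalRunner(),
    measure_callbacks=[],
    database=None,
    cost_model=None,
)
```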
- """ - # Using self._outer to replace the self pointer - _ffi_api.TaskSchedulerInitializeTask(self._outer(), task_id) # type: ignore # pylint: disable=no-member - - def touch_task(self, task_id: int) -> None: - """Touch the task and update its status - - Parameters - ---------- - task_id : int - The task id to be checked. - """ - # Using self._outer to replace the self pointer - _ffi_api.TaskSchedulerTouchTask(self._outer(), task_id) # type: ignore # pylint: disable=no-member - - -def create( # pylint: disable=keyword-arg-before-vararg - kind: Literal["round-robin", "gradient"] = "gradient", - *args, - **kwargs, -) -> "TaskScheduler": - """Create a task scheduler.""" - from . import GradientBased, RoundRobin # pylint: disable=import-outside-toplevel - - if kind == "round-robin": - return RoundRobin(*args, **kwargs) - if kind == "gradient": - return GradientBased(*args, **kwargs) - raise ValueError(f"Unknown TaskScheduler name: {kind}") diff --git a/python/tvm/meta_schedule/testing/dataset_extract_tasks.py b/python/tvm/meta_schedule/testing/dataset_extract_tasks.py index 1795996a3717..5d71d088a379 100644 --- a/python/tvm/meta_schedule/testing/dataset_extract_tasks.py +++ b/python/tvm/meta_schedule/testing/dataset_extract_tasks.py @@ -21,8 +21,8 @@ import json import os -from tqdm import tqdm # type: ignore import tvm +from tqdm import tqdm # type: ignore from tvm import meta_schedule as ms from tvm.ir import save_json from tvm.meta_schedule.testing.relay_workload import _load_cache @@ -60,7 +60,7 @@ def extract_and_save_tasks(cache_file): mod, params_bytearray, _ = _load_cache(args.model_cache_dir, cache_file) params = load_param_dict(params_bytearray) try: - extracted_tasks = ms.extract_task_from_relay(mod, target=args.target, params=params) + extracted_tasks = ms.relay_integration.extract_tasks(mod, target=args.target, params=params) except tvm.error.TVMError as error: print(str(error)) return diff --git a/python/tvm/meta_schedule/testing/dataset_sample_candidates.py b/python/tvm/meta_schedule/testing/dataset_sample_candidates.py index 35b872e7351e..39a12b494108 100644 --- a/python/tvm/meta_schedule/testing/dataset_sample_candidates.py +++ b/python/tvm/meta_schedule/testing/dataset_sample_candidates.py @@ -22,8 +22,8 @@ import os from typing import List -from tqdm import tqdm # type: ignore import tvm +from tqdm import tqdm # type: ignore from tvm import meta_schedule as ms from tvm.ir import load_json from tvm.target import Target @@ -117,25 +117,20 @@ def sample_candidates(task, task_name, model_name): evolve_with_cost_model = tvm.get_global_func( "meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel" ) - strategy = ms.search_strategy.EvolutionarySearch( - num_trials_per_iter=args.num_trials_per_iter, - max_trials_per_task=args.max_trials_per_task, - init_measured_ratio=0.0, - ) + strategy = ms.search_strategy.EvolutionarySearch(init_measured_ratio=0.0) target = Target(args.target) context = ms.TuneContext( mod=task, target=target, - space_generator=ms.space_generator.PostOrderApply(), + space_generator="post-order-apply", search_strategy=strategy, - sch_rules=ms.default_config.schedule_rules(None, target), - postprocs=ms.default_config.postproc(None, target), - mutator_probs=ms.default_config.mutator_probs(None, target), task_name=task_name, ) context.initialize() context.pre_tuning( - context.generate_design_space(), + max_trials=args.max_trials_per_task, + num_trials_per_iter=args.num_trials_per_iter, + design_spaces=context.generate_design_space(), database=database, 
cost_model=ms.cost_model.RandomModel(), # type: ignore ) diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py index 9dcff2ace583..6d1cd7f1604c 100644 --- a/python/tvm/meta_schedule/testing/relay_workload.py +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -24,9 +24,9 @@ import tvm import tvm.relay.testing +from tvm import meta_schedule as ms from tvm import relay from tvm.ir import IRModule -from tvm.meta_schedule import ExtractedTask, extract_task_from_relay from tvm.runtime import NDArray, load_param_dict, save_param_dict from tvm.target import Target @@ -34,15 +34,17 @@ def _get_network( - args: Tuple[str, List[int], str] + args: Tuple[str, List[int], Optional[str]] ) -> Tuple[IRModule, bytearray, Tuple[str, List[int], str]]: name: str input_shape: List[int] - layout: str + layout: Optional[str] name, input_shape, layout = args - mod: IRModule + if layout == "None": + layout = None + mod: IRModule if name in [ "resnet_18", "resnet_50", @@ -60,24 +62,30 @@ def _get_network( assert layout is None or layout in ["NCHW", "NHWC"] + params: Dict[str, Any] = {} if name in ["resnet_18", "resnet_50"]: - model = getattr(models, name.replace("_", ""))(weights=None) + model = getattr(models, name.replace("_", "")) elif name == "wide_resnet_50": - model = getattr(models, "wide_resnet50_2")(weights=None) + model = getattr(models, "wide_resnet50_2") elif name == "resnext_50": - model = getattr(models, "resnext50_32x4d")(weights=None) + model = getattr(models, "resnext50_32x4d") elif name == "mobilenet_v2": - model = getattr(models, name)(weights=None) + model = getattr(models, name) elif name == "mobilenet_v3": - model = getattr(models, name + "_large")(weights=None) + model = getattr(models, name + "_large") elif name == "inception_v3": - model = getattr(models, name)(weights=None, aux_logits=False) + model = getattr(models, name) + params["aux_logits"] = False elif name == "densenet_121": - model = getattr(models, name.replace("_", ""))(weights=None) + model = getattr(models, name.replace("_", "")) elif name == "resnet3d_18": - model = models.video.r3d_18(weights=None) + model = models.video.r3d_18 elif name == "vgg_16": - model = getattr(models, name.replace("_", ""))(weights=None) + model = getattr(models, name.replace("_", "")) + try: + model = model(**params, weights=None) + except TypeError: + model = model(**params, pretrained=False) dtype = "float32" input_data = torch.randn(input_shape).type( # pylint: disable=no-member @@ -90,7 +98,7 @@ def _get_network( shape_list = [(input_name, input_shape)] mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) passes = [relay.transform.RemoveUnusedFunctions()] - if layout == "NHWC": + if layout is None or layout == "NHWC": # PyTorch is imported as NCHW by default passes.append( relay.transform.ConvertLayout( @@ -251,10 +259,7 @@ def extract_from_relay( input_shape: List[int], *, cache_dir: Optional[str] = None, - opt_level: int = 3, - pass_config: Optional[Dict[str, Any]] = None, - disabled_pass: Optional[List[str]] = None, -) -> List[ExtractedTask]: +) -> List[ms.ExtractedTask]: """Extract the tasks from a network. Parameters @@ -272,12 +277,6 @@ def extract_from_relay( cache_dir : Optional[str] The directory to cache the generated network. If not specified, the cache will be disabled. - opt_level : int - The optimization level of the compiler. - pass_config : Optional[Dict[str, Any]] - The pass config of the compiler. 
- disabled_pass : Optional[List[str]] - The disabled pass of the compiler. Returns ------- @@ -287,13 +286,10 @@ def extract_from_relay( filename = f'tasks-{target.kind.name}-{name}-{",".join(str(i) for i in input_shape)}.json' extracted_tasks = _load_cache(cache_dir, filename) if extracted_tasks is None: - extracted_tasks = extract_task_from_relay( + extracted_tasks = ms.relay_integration.extract_tasks( mod=mod, target=target, params=params, - opt_level=opt_level, - pass_config=pass_config, - disabled_pass=disabled_pass, ) extracted_tasks = list(extracted_tasks) _save_cache(cache_dir, filename, extracted_tasks) diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py deleted file mode 100644 index f14e90b6f0b2..000000000000 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ /dev/null @@ -1,36 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Default schedule rules""" -from typing import List, Tuple, Union - -from tvm.meta_schedule import default_config -from tvm.meta_schedule.schedule_rule import ScheduleRule - - -def get_rules(kind: str, types: Union[type, Tuple[type, ...]]) -> List[ScheduleRule]: - """Get default schedule rules""" - # pylint: disable=protected-access - if kind == "llvm": - rules = default_config._DefaultLLVM.schedule_rules() - elif kind == "cuda": - rules = default_config._DefaultCUDA.schedule_rules() - elif kind == "tensor_core": - rules = default_config._DefaultCUDATensorCore.schedule_rules() - else: - raise NotImplementedError(f"{kind} is not supported") - # pylint: enable=protected-access - return [rule for rule in rules if isinstance(rule, types)] diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py index f85faca13f7a..5ac20f8fdf2f 100644 --- a/python/tvm/meta_schedule/testing/space_generation.py +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -15,24 +15,51 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union +# isort: off +from typing_extensions import Literal + +# isort: on + +from tvm import meta_schedule as ms from tvm.ir import IRModule, structural_equal +from tvm.target import Target from tvm.tir import Schedule from tvm.tir.schedule import Trace from tvm.tir.schedule.testing import verify_trace_roundtrip -def check_trace(spaces: List[Schedule], expected: List[List[str]]): - expected_traces = {"\n".join(t) for t in expected} - actual_traces = set() - for space in spaces: - trace = Trace(space.trace.insts, {}) - trace = trace.simplified(remove_postproc=True) - str_trace = "\n".join(t[2:] for t in str(trace).strip().splitlines()[2:] if t != " pass") - actual_traces.add(str_trace) - assert str_trace in expected_traces, "\n" + str_trace - assert len(expected_traces) == len(actual_traces) +def get_rules( + kind: Literal["llvm", "cuda", "cuda-tensorcore", "hexagon"], + types: Union[type, Tuple[type, ...]], +) -> List[ms.ScheduleRule]: + """Get default schedule rules""" + rules = ms.ScheduleRule.create(kind) + return [rule for rule in rules if isinstance(rule, types)] + + +def generate_design_space( + kind: Literal["llvm", "cuda", "cuda-tensorcore", "hexagon"], + mod: IRModule, + target: Target, + types: Union[type, Tuple[type, ...]], + sch_rules: Optional[List[ms.ScheduleRule]] = None, +) -> List[Schedule]: + if sch_rules is None: + sch_rules = get_rules(kind, types) + else: + assert types is None + return ms.TuneContext( + mod=mod, + target=target, + space_generator=ms.space_generator.PostOrderApply( + sch_rules=sch_rules, + postprocs=[], + mutator_probs={}, + ), + task_name="test", + ).generate_design_space() def _find_match_sketch_id( diff --git a/python/tvm/meta_schedule/testing/tlcbench.py b/python/tvm/meta_schedule/testing/tlcbench.py index 108d83ba9de9..2e9f9f52b1fc 100644 --- a/python/tvm/meta_schedule/testing/tlcbench.py +++ b/python/tvm/meta_schedule/testing/tlcbench.py @@ -17,14 +17,14 @@ # pylint: disable=invalid-name,import-outside-toplevel # type: ignore """Model loader for TLCBench.""" +import logging import multiprocessing import os -import logging + import tvm from tvm import relay from tvm.contrib.download import download_testdata - log = logging.getLogger(__name__) @@ -64,7 +64,6 @@ def deserialize_relay(json_path, params_path): with open(params_path, "rb") as fi: params = relay.load_param_dict(fi.read()) - return mod, params diff --git a/python/tvm/meta_schedule/testing/torchbench/run.py b/python/tvm/meta_schedule/testing/torchbench/run.py index f6984d1c9d10..fe939b2c9ba9 100644 --- a/python/tvm/meta_schedule/testing/torchbench/run.py +++ b/python/tvm/meta_schedule/testing/torchbench/run.py @@ -54,7 +54,7 @@ --mode tune \ --model resnet50 \ --target "nvidia/geforce-rtx-3070" \ - --work-dir ../workdir \ + --work-dir /path/to/work/dir/ \ --num-trials 20000 \ --rpc-host \ --rpc-port \ @@ -73,7 +73,7 @@ --mode eval \ --model resnet50 \ --target "nvidia/geforce-rtx-3070" \ - --work-dir ../workdir \ + --work-dir /path/to/work/dir/ \ --num-trials 0 ``` @@ -84,13 +84,11 @@ --mode all \ --model resnet50 \ --target "llvm -num-cores 6" \ - --work-dir ../workdir \ + --work-dir /path/to/work/dir/ \ --num-trials 0 ``` """ - # pylint: disable=logging-format-interpolation - import argparse import functools import logging @@ -100,10 +98,9 @@ import numpy as np # type: ignore import torch # type: ignore -from 
scipy.stats import ttest_ind # type: ignore - import tvm import tvm.relay +from scipy.stats import ttest_ind # type: ignore from tvm import meta_schedule as ms from tvm.contrib.graph_executor import GraphModule from tvm.meta_schedule.testing.torchbench.utils import ( @@ -147,10 +144,10 @@ def should_eval(self): class ResultComparisonMetric(Enum): """ - This changes how it compares the resultl with the expected value during + This changes how it compares the results with the expected value during accuracy check. - cosine: Use the cosine similarity. It should be greater than 0.99. - - allclose-1e-4: Use the max element-wise absolute difference. It should be less than 1e-4. + - allclose-1e-4: Use the max elementwise absolute difference. It should be less than 1e-4. """ COSINE = "cosine" @@ -220,15 +217,6 @@ def parse_args(): The working directory to save intermediate results and store databases for compilation. """, ) - args.add_argument( - "--cache-dir", - type=str, - default=None, - help=""" - The directory to cache the generated network. - If not specified, the cache will be disabled. - """, - ) args.add_argument( "--num-trials", type=int, @@ -279,7 +267,7 @@ def parse_args(): args.add_argument( "--adaptive-training", action="store_true", - help="Whether to use adpative training for cost model.", + help="Whether to use adaptive training for cost model.", ) args.add_argument( "--cpu-flush", @@ -309,7 +297,8 @@ def parse_args(): logging.basicConfig( - format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", ) logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) ARGS = parse_args() @@ -320,11 +309,12 @@ def parse_args(): runner = load_torchdynamo_benchmark_runner( # pylint: disable=invalid-name - IS_CUDA, cosine_similarity=ARGS.result_metric == ResultComparisonMetric.COSINE + IS_CUDA, + cosine_similarity=ARGS.result_metric == ResultComparisonMetric.COSINE, ) -def get_metaschedule_runner() -> ms.runner.PyRunner: +def get_meta_schedule_runner() -> ms.runner.PyRunner: """ Get the Runner for MetaSchedule. @@ -349,33 +339,10 @@ def get_metaschedule_runner() -> ms.runner.PyRunner: alloc_repeat=1, ) else: - warnings.warn("Falling back to Metaschedule LocalRunner because --rpc-host isn't provided.") + warnings.warn("Falling back to MetaSchedule LocalRunner because --rpc-host isn't provided.") return ms.runner.LocalRunner() -def get_tune_config() -> ms.TuneConfig: - """ - Get the TuneConfig. - """ - if ARGS.mode.should_tune: - max_trials_per_task = ARGS.max_trials_per_task - max_trials_global = ARGS.num_trials - else: - max_trials_per_task = 0 - max_trials_global = 0 - - if max_trials_per_task is None: - max_trials_per_task = max_trials_global - - return ms.TuneConfig( - strategy="evolutionary", - num_trials_per_iter=64, - max_trials_per_task=max_trials_per_task, - max_trials_global=max_trials_global, - adaptive_training=ARGS.adaptive_training, - ) - - def get_graph_executor_forward(mod: GraphModule, device: tvm.runtime.Device) -> Callable: """ Get the forward function for graph executor, in order to integrate with TorchDynamo. @@ -419,7 +386,7 @@ def forward(*args): def create_tvm_task_collection_backend(tasks: List[ms.ExtractedTask]) -> Callable: """ - This torchdynamo backend only collects the extracted tasks from Metaschedule. + This torchdynamo backend only collects the extracted tasks from MetaSchedule. It doesn't tune the model. 
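The two torchdynamo backends in this file are meant to be used in sequence: one dry pass through the model to harvest tasks, then compilation against the tuned database (see `create_tvm_compilation_backend` below and `main` later in this file). A condensed sketch of that flow; `model`, `example_inputs`, `model_iter_fn` and `database` are assumed to exist:

```
from typing import List

import torchdynamo  # type: ignore
from tvm import meta_schedule as ms

tasks: List[ms.ExtractedTask] = []
collect = torchdynamo.optimize(create_tvm_task_collection_backend(tasks))
collect(model_iter_fn)(model, example_inputs)  # populates `tasks`; no tuning

# ... tune `tasks` into `database`, then compile from its best records:
compiled = torchdynamo.optimize(create_tvm_compilation_backend(database))
```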
""" @@ -428,7 +395,11 @@ def backend(graph_module, example_inputs): shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] ir_mod, params = tvm.relay.frontend.from_pytorch(jit_mod, shape_list) - extracted_tasks = ms.extract_task_from_relay(ir_mod, ARGS.target, params) + extracted_tasks = ms.relay_integration.extract_tasks( + mod=ir_mod, + target=ARGS.target, + params=params, + ) logger.info("Extracted %d tasks", len(extracted_tasks)) tasks.extend(extracted_tasks) @@ -440,31 +411,21 @@ def backend(graph_module, example_inputs): def create_tvm_compilation_backend(database: ms.database.Database) -> Callable: """ This torchdynamo backend compiles the model using history best record from the - Metaschedule database. + MetaSchedule database. """ def backend(graph_module, example_inputs): - # pylint: disable=import-outside-toplevel - from tvm.ir.transform import PassContext - - # pylint: enable=import-outside-toplevel - jit_mod = torch.jit.trace(graph_module, example_inputs) shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] ir_mod, params = tvm.relay.frontend.from_pytorch(jit_mod, shape_list) - relay_build = {"graph": tvm.relay.build, "vm": tvm.relay.vm.compile}[ARGS.backend] - with ARGS.target, ms.utils.autotvm_silencer(), database: - with PassContext( - opt_level=3, - config={ - "relay.backend.use_meta_schedule": True, - "relay.backend.use_meta_schedule_dispatch": not IS_CUDA, - "relay.backend.tir_converter": "default", - }, - ): - lib = relay_build(ir_mod, target=ARGS.target, params=params) - + lib = ms.relay_integration.compile_relay( + database=database, + mod=ir_mod, + target=ARGS.target, + params=params, + backend=ARGS.backend, + ) device = tvm.cuda(0) if IS_CUDA else tvm.cpu(0) if ARGS.backend == "graph": @@ -503,7 +464,9 @@ def is_output_correct(output: torch.Tensor, expected: torch.Tensor) -> bool: def performance_experiment( - model_iter_fn: Callable, model: torch.nn.Module, example_inputs: Tuple[torch.Tensor] + model_iter_fn: Callable, + model: torch.nn.Module, + example_inputs: Tuple[torch.Tensor], ) -> str: """ Performs the actual benchmarking @@ -560,11 +523,11 @@ def main(): """ describe() + database = ms.database.JSONDatabase(work_dir=ARGS.work_dir) if not ARGS.mode.should_tune: - ms_database = ms.default_config.database(None, ARGS.work_dir) - if len(ms_database) == 0: + if len(database) == 0: raise RuntimeError( - "Script is runnig in eval mode while the tuning database is empty. " + "Script is running in eval mode while the tuning database is empty. " "Please tune the model first." ) @@ -573,6 +536,7 @@ def main(): "Benchmark is running on CUDA, while --cpu-flush is turned on. " "This flag will have no effect on CUDA." 
) + ARGS.cpu_flush = False try: _, name, model, example_inputs, batch_size = runner.load_model( @@ -587,16 +551,27 @@ def main(): logging.exception(f"{ARGS.model} failed to load") return - tuning_tasks: List[ms.ExtractedTask] = [] - task_collect_ctx = torchdynamo.optimize(create_tvm_task_collection_backend(tuning_tasks)) - task_collect_ctx(runner.model_iter_fn)(model, example_inputs) - - database = ms.tune_extracted_tasks( - extracted_tasks=tuning_tasks, - config=get_tune_config(), - work_dir=ARGS.work_dir, - runner=get_metaschedule_runner(), # type: ignore - ) + if ARGS.mode.should_tune: + extracted_tasks: List[ms.ExtractedTask] = [] + task_collect_ctx = torchdynamo.optimize(create_tvm_task_collection_backend(extracted_tasks)) + task_collect_ctx(runner.model_iter_fn)(model, example_inputs) + tasks, task_weights = ms.relay_integration.extracted_tasks_to_tune_contexts( + extracted_tasks=extracted_tasks, + work_dir=ARGS.work_dir, + ) + database = ms.tune.tune_tasks( + tasks=tasks, + task_weights=task_weights, + work_dir=ARGS.work_dir, + max_trials_global=ARGS.num_trials, + max_trials_per_task=ARGS.num_trials_per_task, + runner=get_meta_schedule_runner(), # type: ignore + database=database, + cost_model=ms.cost_model.XGBModel( # type: ignore + extractor=ms.feature_extractor.PerStoreFeature(), + adaptive_training=ARGS.adaptive_training, + ), + ) if ARGS.mode.should_eval: torchdynamo.reset() diff --git a/python/tvm/meta_schedule/testing/tune_onnx.py b/python/tvm/meta_schedule/testing/tune_onnx.py index 6d473ed3237c..a7c177afdca4 100644 --- a/python/tvm/meta_schedule/testing/tune_onnx.py +++ b/python/tvm/meta_schedule/testing/tune_onnx.py @@ -15,18 +15,19 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring -from distutils.util import strtobool import argparse import json import logging -import onnx # type: ignore +from distutils.util import strtobool +import onnx # type: ignore import tvm from tvm import meta_schedule as ms from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.relay.frontend import from_onnx from tvm.support import describe -from .tune_utils import generate_input_data, create_timer + +from .tune_utils import create_timer, generate_input_data def _parse_args(): @@ -126,7 +127,7 @@ def _parse_args(): logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) -logging.getLogger("tvm.meta_schedule").setLevel(logging.INFO) +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) ARGS = _parse_args() @@ -146,33 +147,38 @@ def main(): item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in ARGS.input_shape } - runner = ms.runner.RPCRunner( - rpc_config=ARGS.rpc_config, - evaluator_config=ms.runner.EvaluatorConfig( - number=ARGS.number, - repeat=ARGS.repeat, - min_repeat_ms=ARGS.min_repeat_ms, - enable_cpu_cache_flush=ARGS.cpu_flush, - ), - alloc_repeat=1, - ) - with ms.Profiler() as profiler: - lib = ms.tune_relay( + database = ms.relay_integration.tune_relay( mod=mod, target=ARGS.target, - config=ms.TuneConfig( - strategy="evolutionary", - num_trials_per_iter=64, - max_trials_per_task=ARGS.num_trials, - max_trials_global=ARGS.num_trials, + params=params, + work_dir=ARGS.work_dir, + max_trials_global=ARGS.num_trials, + num_trials_per_iter=64, + runner=ms.runner.RPCRunner( # type: ignore + rpc_config=ARGS.rpc_config, + evaluator_config=ms.runner.EvaluatorConfig( + number=ARGS.number, + repeat=ARGS.repeat, + 
min_repeat_ms=ARGS.min_repeat_ms, + enable_cpu_cache_flush=ARGS.cpu_flush, + ), + alloc_repeat=1, + ), + cost_model=ms.cost_model.XGBModel( # type: ignore + extractor=ms.feature_extractor.PerStoreFeature(), adaptive_training=ARGS.adaptive_training, ), - runner=runner, # type: ignore - work_dir=ARGS.work_dir, + strategy=ms.search_strategy.EvolutionarySearch(), + ) + lib = ms.relay_integration.compile_relay( + database=database, + mod=mod, + target=ARGS.target, params=params, backend=ARGS.backend, ) + print("Tuning Time:") print(profiler.table()) diff --git a/python/tvm/meta_schedule/testing/tune_relay.py b/python/tvm/meta_schedule/testing/tune_relay.py index 7c5977495db5..de1668c1dd16 100644 --- a/python/tvm/meta_schedule/testing/tune_relay.py +++ b/python/tvm/meta_schedule/testing/tune_relay.py @@ -131,7 +131,7 @@ def _parse_args(): logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) -logging.getLogger("tvm.meta_schedule").setLevel(logging.INFO) +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) ARGS = _parse_args() @@ -164,30 +164,34 @@ def main(): print(f" input_shape: {item['shape']}") print(f" input_dtype: {item['dtype']}") - runner = ms.runner.RPCRunner( - rpc_config=ARGS.rpc_config, - evaluator_config=ms.runner.EvaluatorConfig( - number=ARGS.number, - repeat=ARGS.repeat, - min_repeat_ms=ARGS.min_repeat_ms, - enable_cpu_cache_flush=ARGS.cpu_flush, - ), - alloc_repeat=1, - ) - with ms.Profiler() as profiler: - lib = ms.tune_relay( + database = ms.relay_integration.tune_relay( mod=mod, target=ARGS.target, - config=ms.TuneConfig( - strategy="evolutionary", - num_trials_per_iter=64, - max_trials_per_task=ARGS.num_trials, - max_trials_global=ARGS.num_trials, + work_dir=ARGS.work_dir, + max_trials_global=ARGS.num_trials, + num_trials_per_iter=64, + params=params, + runner=ms.runner.RPCRunner( # type: ignore + rpc_config=ARGS.rpc_config, + evaluator_config=ms.runner.EvaluatorConfig( + number=ARGS.number, + repeat=ARGS.repeat, + min_repeat_ms=ARGS.min_repeat_ms, + enable_cpu_cache_flush=ARGS.cpu_flush, + ), + alloc_repeat=1, + ), + cost_model=ms.cost_model.XGBModel( # type: ignore + extractor=ms.feature_extractor.PerStoreFeature(), adaptive_training=ARGS.adaptive_training, ), - runner=runner, # type: ignore - work_dir=ARGS.work_dir, + strategy=ms.search_strategy.EvolutionarySearch(), + ) + lib = ms.relay_integration.compile_relay( + database=database, + mod=mod, + target=ARGS.target, params=params, backend=ARGS.backend, ) diff --git a/python/tvm/meta_schedule/testing/tune_te.py b/python/tvm/meta_schedule/testing/tune_te.py index d54d92048ee6..16f9be674f39 100644 --- a/python/tvm/meta_schedule/testing/tune_te.py +++ b/python/tvm/meta_schedule/testing/tune_te.py @@ -15,14 +15,14 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=missing-docstring -from distutils.util import strtobool import argparse import logging +from distutils.util import strtobool from typing import Optional import tvm -from tvm import tir from tvm import meta_schedule as ms +from tvm import tir from tvm.meta_schedule.testing.te_workload import create_te_workload from tvm.support import describe @@ -106,37 +106,36 @@ def _parse_args(): logging.basicConfig( format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) -logging.getLogger("tvm.meta_schedule").setLevel(logging.INFO) +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) ARGS = _parse_args() def main(): describe() print(f"Workload: {ARGS.workload}") - runner = ms.runner.RPCRunner( - rpc_config=ARGS.rpc_config, - evaluator_config=ms.runner.EvaluatorConfig( - number=ARGS.number, - repeat=ARGS.repeat, - min_repeat_ms=ARGS.min_repeat_ms, - enable_cpu_cache_flush=ARGS.cpu_flush, - ), - alloc_repeat=1, - ) with ms.Profiler() as profiler: - sch: Optional[tir.Schedule] = ms.tune_tir( + sch: Optional[tir.Schedule] = ms.tir_integration.tune_tir( mod=create_te_workload(ARGS.workload, 0), target=ARGS.target, - config=ms.TuneConfig( - strategy="evolutionary", - num_trials_per_iter=64, - max_trials_per_task=ARGS.num_trials, - max_trials_global=ARGS.num_trials, + work_dir=ARGS.work_dir, + max_trials_global=ARGS.num_trials, + num_trials_per_iter=64, + runner=ms.runner.RPCRunner( # type: ignore + rpc_config=ARGS.rpc_config, + evaluator_config=ms.runner.EvaluatorConfig( + number=ARGS.number, + repeat=ARGS.repeat, + min_repeat_ms=ARGS.min_repeat_ms, + enable_cpu_cache_flush=ARGS.cpu_flush, + ), + alloc_repeat=1, + ), + cost_model=ms.cost_model.XGBModel( # type: ignore + extractor=ms.feature_extractor.PerStoreFeature(), adaptive_training=ARGS.adaptive_training, ), - runner=runner, # type: ignore + strategy=ms.search_strategy.EvolutionarySearch(), task_name=ARGS.workload, - work_dir=ARGS.work_dir, ) print("Tuning Time:") diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py new file mode 100644 index 000000000000..975987ebcb67 --- /dev/null +++ b/python/tvm/meta_schedule/tir_integration.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""MetaSchedule-TIR integration""" +from typing import Optional, Union + +# isort: off +from typing_extensions import Literal + +# isort: on +from tvm import ir, tir +from tvm.target import Target + +from .builder import Builder +from .cost_model import CostModel +from .database import Database +from .logging import get_loggers_from_work_dir +from .measure_callback import MeasureCallback +from .runner import Runner +from .search_strategy import SearchStrategy +from .space_generator import SpaceGenerator +from .task_scheduler import TaskScheduler +from .tune import tune_tasks +from .tune_context import TuneContext, _normalize_mod +from .utils import fork_seed + + +def tune_tir( + mod: Union[ir.IRModule, tir.PrimFunc], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "round-robin", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + task_name: str = "main", + num_threads: Union[Literal["physical", "logical"], int] = "physical", + seed: Optional[int] = None, +) -> Database: + """Tune a TIR function. + + Parameters + ---------- + mod : Union[ir.IRModule, tir.PrimFunc] + The TIR function to tune. + target : Union[str, Target] + The target to tune for. + work_dir : str + The working directory. + max_trials_global : int + The maximum number of trials to run globally. + num_trials_per_iter : int + The number of trials to run per iteration + builder : Builder.BuilderType + The builder. + runner : Runner.RunnerType + The runner. + database : Database.DatabaseType + The database. + cost_model : CostModel.CostModelType + The cost model. + measure_callbacks : MeasureCallback.CallbackListType + The measure callbacks. + task_scheduler : TaskScheduler.TaskSchedulerType + The task scheduler. + space : SpaceGenerator.SpaceGeneratorType + The space generator. + strategy : SearchStrategy.SearchStrategyType + The search strategy. + task_name : str + The name of the task. + num_threads : Union[Literal["physical", "logical"], int] + The number of threads to use. + seed : Optional[int] + The seed for the random number generator. + + Returns + ------- + database : Database + The database with all tuning records + """ + (logger,) = get_loggers_from_work_dir(work_dir, [task_name]) + (seed,) = fork_seed(seed, n=1) + return tune_tasks( + tasks=[ + TuneContext( + mod=mod, + target=target, + space_generator=space, + search_strategy=strategy, + task_name=task_name, + logger=logger, + rand_state=seed, + num_threads=num_threads, + ).clone() + ], + task_weights=[1.0], + work_dir=work_dir, + max_trials_global=max_trials_global, + max_trials_per_task=max_trials_global, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + ) + + +def compile_tir( + database: Database, + mod: Union[ir.IRModule, tir.PrimFunc], + target: Union[Target, str], +) -> tir.Schedule: + """Compile a TIR to tir.Schedule, according to the records in the database. + + Parameters + ---------- + database : Database + The database of tuning records. 
+ mod : Union[ir.IRModule, tir.PrimFunc] + The TIR function to tune. + target : Union[str, Target] + The target to tune for. + + Returns + ------- + sch : tir.Schedule + The best schedule found in the database. + """ + mod = _normalize_mod(mod) + if not isinstance(target, Target): + target = Target(target) + return database.query_schedule(mod, target, workload_name="main") diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 96b554d4e659..f7a2d4dc376f 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -14,637 +14,99 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""User-facing Tuning API""" -# pylint: disable=import-outside-toplevel -import logging -import logging.config -import os -from os import path as osp -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union +"""The core tuning API""" +from typing import List, Optional -from tvm.ir import IRModule -from tvm.ir.transform import PassContext -from tvm.runtime import Module, NDArray, vm -from tvm.target import Target -from tvm.te import Tensor, create_prim_func -from tvm.tir import PrimFunc, Schedule - -from . import default_config from .builder import Builder from .cost_model import CostModel -from .database import Database, TuningRecord -from .extracted_task import ExtractedTask +from .database import Database from .measure_callback import MeasureCallback -from .mutator import Mutator -from .postproc import Postproc -from .profiler import Profiler from .runner import Runner -from .schedule_rule import ScheduleRule -from .search_strategy import EvolutionarySearch, ReplayFunc, ReplayTrace -from .space_generator import PostOrderApply, SpaceGenerator -from .task_scheduler import GradientBased, RoundRobin +from .task_scheduler import TaskScheduler from .tune_context import TuneContext -from .utils import autotvm_silencer, batch_parameterize_config - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - -FnSpaceGenerator = Callable[[], SpaceGenerator] -FnScheduleRule = Callable[[], List[ScheduleRule]] -FnPostproc = Callable[[], List[Postproc]] -FnMutatorProb = Callable[[], Dict[Mutator, float]] - - -class TuneConfig(NamedTuple): - """Configuration for tuning - - Parameters - ---------- - max_trials_global: int - Maximum number of trials to run. - num_trials_per_iter: int - Number of trials to run per iteration. - max_trials_per_task: Optional[int] - Maximum number of trials to run per task. If None, use `max_trials_global`. - task_scheduler: str = "gradient" - Task scheduler to use. - Valid options are: round_robin, gradient. - strategy: str = "evolutionary" - Search strategy to use. - Valid options are: evolutionary, replay_func, replay_trace. - task_scheduler_config: Optional[Dict[str, Any]] = None - Configuration for task scheduler. - search_strategy_config: Optional[Dict[str, Any]] = None - Configuration for search strategy. - logger_config: Optional[Dict[str, Any]] = None - Configuration for logger. - adaptive_training: Optional[bool] = None - Whether adpative training is enabled for cost model. 
- """ - - max_trials_global: int - num_trials_per_iter: int - max_trials_per_task: Optional[int] = None - task_scheduler: str = "gradient" - strategy: str = "evolutionary" - task_scheduler_config: Optional[Dict[str, Any]] = None - search_strategy_config: Optional[Dict[str, Any]] = None - logger_config: Optional[Dict[str, Any]] = None - adaptive_training: Optional[bool] = None - - def create_strategy(self): - """Create search strategy from configuration""" - cls_tbl = { - "evolutionary": EvolutionarySearch, - "replay_func": ReplayFunc, - "replay_trace": ReplayTrace, - } - if self.strategy not in cls_tbl: - raise ValueError( - f"Invalid search strategy: {self.strategy}. " - "Valid options are: {}".format(", ".join(cls_tbl.keys())) - ) - # `max_trials_per_task` defaults to `max_trials_global` - max_trials_per_task = self.max_trials_per_task - if max_trials_per_task is None: - max_trials_per_task = self.max_trials_global - # `search_strategy_config` defaults to empty dict - config = self.search_strategy_config - if config is None: - config = {} - return cls_tbl[self.strategy]( - num_trials_per_iter=self.num_trials_per_iter, - max_trials_per_task=max_trials_per_task, - **config, - ) - - def create_task_scheduler(self, **kwargs): - """Create task scheduler from configuration""" - cls_tbl = { - "round_robin": RoundRobin, - "gradient": GradientBased, - } - if self.task_scheduler not in cls_tbl: - raise ValueError( - f"Invalid task scheduler: {self.task_scheduler}. " - "Valid options are: {}".format(", ".join(cls_tbl.keys())) - ) - # `task_scheduler_config` defaults to empty dict - config = self.task_scheduler_config - if config is None: - config = {} - return cls_tbl[self.task_scheduler]( - max_trials=self.max_trials_global, - **kwargs, - **config, - ) - - def create_loggers( - self, - log_dir: str, - params: List[Dict[str, Any]], - disable_existing_loggers: bool = False, - ): - """Create loggers from configuration""" - if self.logger_config is None: - config = {} - else: - config = self.logger_config - - config.setdefault("loggers", {}) - config.setdefault("handlers", {}) - config.setdefault("formatters", {}) - - global_logger_name = "tvm.meta_schedule" - global_logger = logging.getLogger(global_logger_name) - if global_logger.level is logging.NOTSET: - global_logger.setLevel(logging.INFO) - - config["loggers"].setdefault( - global_logger_name, - { - "level": logging._levelToName[ # pylint: disable=protected-access - global_logger.level - ], - "handlers": [handler.get_name() for handler in global_logger.handlers] - + [global_logger_name + ".console", global_logger_name + ".file"], - "propagate": False, - }, - ) - config["loggers"].setdefault( - "{logger_name}", - { - "level": "INFO", - "handlers": [ - "{logger_name}.file", - ], - "propagate": False, - }, - ) - config["handlers"].setdefault( - global_logger_name + ".console", - { - "class": "logging.StreamHandler", - "stream": "ext://sys.stdout", - "formatter": "tvm.meta_schedule.standard_formatter", - }, - ) - config["handlers"].setdefault( - global_logger_name + ".file", - { - "class": "logging.FileHandler", - "filename": "{log_dir}/" + __name__ + ".task_scheduler.log", - "mode": "a", - "level": "INFO", - "formatter": "tvm.meta_schedule.standard_formatter", - }, - ) - config["handlers"].setdefault( - "{logger_name}.file", - { - "class": "logging.FileHandler", - "filename": "{log_dir}/{logger_name}.log", - "mode": "a", - "level": "INFO", - "formatter": "tvm.meta_schedule.standard_formatter", - }, - ) - config["formatters"].setdefault( - 
"tvm.meta_schedule.standard_formatter", - { - "format": "%(asctime)s.%(msecs)03d %(levelname)s %(message)s", - "datefmt": "%Y-%m-%d %H:%M:%S", - }, - ) - - # set up dictConfig loggers - p_config = {"version": 1, "disable_existing_loggers": disable_existing_loggers} - for k, v in config.items(): - if k in ["formatters", "handlers", "loggers"]: - p_config[k] = batch_parameterize_config(v, params) # type: ignore - else: - p_config[k] = v - logging.config.dictConfig(p_config) - - # check global logger - if global_logger.level not in [logging.DEBUG, logging.INFO]: - global_logger.warning( - "Logging level set to %s, please set to logging.INFO" - " or logging.DEBUG to view full log.", - logging._levelToName[global_logger.level], # pylint: disable=protected-access - ) - global_logger.info("Logging directory: %s", log_dir) -def tune_extracted_tasks( - extracted_tasks: List[ExtractedTask], - config: TuneConfig, - work_dir: str, +def tune_tasks( *, - builder: Optional[Builder] = None, - runner: Optional[Runner] = None, - database: Optional[Database] = None, - cost_model: Optional[CostModel] = None, - measure_callbacks: Optional[List[MeasureCallback]] = None, - space: Optional[FnSpaceGenerator] = None, - sch_rules: Optional[FnScheduleRule] = None, - postprocs: Optional[FnPostproc] = None, - mutator_probs: Optional[FnMutatorProb] = None, - num_threads: Optional[int] = None, + tasks: List[TuneContext], + task_weights: List[float], + work_dir: str, + max_trials_global: int, + max_trials_per_task: Optional[int] = None, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", ) -> Database: - """Tune extracted tasks with a given target. + """Tune a list of tasks. Using a task scheduler. Parameters ---------- - extracted_tasks : List[ExtractedTask] - The list of extracted tasks. - config : TuneConfig - The search strategy config. + tasks : List[TuneContext] + The list of tasks to tune. + task_weights : List[float] + The weight of each task. work_dir : str - The working directory to save intermediate results. - builder : Optional[Builder] - The builder to use. - runner : Optional[Runner] - The runner to use. - database : Optional[Database] - The database to use. - cost_model : Optional[CostModel] - The cost model to use. - measure_callbacks : Optional[List[MeasureCallback]] - The callbacks used during tuning. - task_scheduler : Optional[TaskScheduler] - The task scheduler to use. - space : Optional[FnSpaceGenerator] - The space generator to use. - sch_rules : Optional[FnScheduleRule] - The search rules to use. - postprocs : Optional[FnPostproc] - The postprocessors to use. - mutator_probs : Optional[FnMutatorProb] - The probability distribution to use different mutators. - num_threads : Optional[int] - The number of threads to use. + The working directory. + max_trials_global : int + The maximum number of trials to run globally. + max_trials_per_task : Optional[int] + The maximum number of trials to run per task. + num_trials_per_iter : int + The number of trials to run per iteration + builder : Builder.BuilderType + The builder. + runner : Runner.RunnerType + The runner. + database : Database.DatabaseType + The database. + cost_model : CostModel.CostModelType + The cost model. 
+    measure_callbacks : MeasureCallback.CallbackListType
+        The measure callbacks.
+    task_scheduler : TaskScheduler.TaskSchedulerType
+        The task scheduler.
 
     Returns
     -------
     database : Database
-        The database containing all the tuning results.
-
+        The database with all tuning records.
     """
-    # pylint: disable=protected-access
-    # logging directory is set to `work_dir/logs` by default
-    log_dir = osp.join(work_dir, "logs")
-    os.makedirs(log_dir, exist_ok=True)
-    max_width = len(str(len(extracted_tasks) - 1))
-    logger_name_pattern = __name__ + ".task_{task_id:0" + f"{max_width}" + "d}_{task_name}"
-
-    config.create_loggers(
-        log_dir=log_dir,
-        params=[
-            {
-                "log_dir": log_dir,
-                "logger_name": logger_name_pattern.format(task_id=i, task_name=task.task_name),
-            }
-            for i, task in enumerate(extracted_tasks)
-        ],
-    )
-
-    logger.info("Working directory: %s", work_dir)
-    database = default_config.database(database, work_dir)
-    builder = default_config.builder(builder)
-    runner = default_config.runner(runner)
-    cost_model = default_config.cost_model(cost_model, config.adaptive_training)
-    measure_callbacks = default_config.callbacks(measure_callbacks)
-    # parse the tuning contexts
-    tune_contexts = []
-    for i, task in enumerate(extracted_tasks):
-        assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now"
-        tune_contexts.append(
-            TuneContext(
-                mod=default_config.mod(task.dispatched[0]),
-                target=task.target,
-                space_generator=default_config.space_generator(space),
-                search_strategy=config.create_strategy(),
-                sch_rules=default_config.schedule_rules(sch_rules, task.target),
-                postprocs=default_config.postproc(postprocs, task.target),
-                mutator_probs=default_config.mutator_probs(mutator_probs, task.target),
-                task_name=task.task_name,
-                logger=logging.getLogger(
-                    logger_name_pattern.format(task_id=i, task_name=task.task_name)
-                ),
-                num_threads=num_threads,
-            )
+    if len(tasks) != len(task_weights):
+        raise ValueError(
+            f"Length of tasks ({len(tasks)}) and task_weights ({len(task_weights)}) do not match."
) - # parse the task scheduler - # pylint: enable=protected-access - task_scheduler = config.create_task_scheduler( - tasks=tune_contexts, - task_weights=[float(t.weight) for t in extracted_tasks], + if max_trials_per_task is None: + max_trials_per_task = max_trials_global + if not isinstance(builder, Builder): + builder = Builder.create(builder) + if not isinstance(runner, Runner): + runner = Runner.create(runner) + if database == "json": + database = Database.create(database, work_dir=work_dir) + elif not isinstance(database, Database): + database = Database.create(database) + if not isinstance(cost_model, CostModel): + cost_model = CostModel.create(cost_model) + if isinstance(measure_callbacks, MeasureCallback): + measure_callbacks = [measure_callbacks] + elif measure_callbacks == "default": + measure_callbacks = MeasureCallback.create(measure_callbacks) + if not isinstance(task_scheduler, TaskScheduler): + task_scheduler = TaskScheduler.create(task_scheduler) + task_scheduler.tune( + tasks=tasks, + task_weights=task_weights, + max_trials_global=max_trials_global, + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, builder=builder, runner=runner, - database=database, - cost_model=cost_model, measure_callbacks=measure_callbacks, - ) - if config.max_trials_global > 0: - task_scheduler.tune() - cost_model.save(osp.join(work_dir, "cost_model.xgb")) - return database - - -def tune_tir( - mod: Union[IRModule, PrimFunc], - target: Union[str, Target], - config: TuneConfig, - work_dir: str, - *, - builder: Optional[Builder] = None, - runner: Optional[Runner] = None, - database: Optional[Database] = None, - cost_model: Optional[CostModel] = None, - measure_callbacks: Optional[List[MeasureCallback]] = None, - space: Optional[FnSpaceGenerator] = None, - blocks: Optional[List[str]] = None, - sch_rules: Optional[FnScheduleRule] = None, - postprocs: Optional[FnPostproc] = None, - mutator_probs: Optional[FnMutatorProb] = None, - task_name: str = "main", - num_threads: Optional[int] = None, -) -> Optional[Schedule]: - """Tune a TIR IRModule with a given target. - - Parameters - ---------- - mod : Union[IRModule, PrimFunc] - The module to tune. - target : Union[str, Target] - The target to tune for. - config : TuneConfig - The search strategy config. - work_dir : str - The working directory to save intermediate results. - builder : Optional[Builder] - The builder to use. - runner : Optional[Runner] - The runner to use. - database : Optional[Database] - The database to use. - cost_model : Optional[CostModel] - The cost model to use. - measure_callbacks : Optional[List[MeasureCallback]] - The callbacks used during tuning. - space : Optional[FnSpaceGenerator] - The space generator to use. - blocks : Optional[List[str]] - A list of block names specifying blocks to be tuned. Note that if - the list is not None, blocks outside this list will not be tuned. - Only one of this argument and space may be provided. - sch_rules : Optional[FnScheduleRule] - The search rules to use. - postprocs : Optional[FnPostproc] - The postprocessors to use. - mutator_probs : Optional[FnMutatorProb] - The probability distribution to use different mutators. - task_name : str - The name of the function to extract schedules from. - num_threads : Optional[int] - The number of threads to use - - Returns - ------- - sch : Optional[Schedule] - The tuned schedule. 
- """ - # logging directory is set to `work_dir/logs` by default - log_dir = osp.join(work_dir, "logs") - os.makedirs(log_dir, exist_ok=True) - - config.create_loggers( - log_dir=log_dir, - params=[{"log_dir": log_dir, "logger_name": __name__ + f".task_{task_name}"}], - ) - - if blocks is not None: - assert space is None, "Can not specify blocks to tune when a search space is given." - # Create a filter function to identify named blocks. - def _f_block_filter(block, target_names) -> bool: - return block.name_hint in target_names - - # Create a space generator that targets specific blocks. - space = PostOrderApply(f_block_filter=lambda block: _f_block_filter(block, blocks)) - - # pylint: disable=protected-access - mod = default_config.mod(mod) - target = default_config.target(target) - # pylint: enable=protected-access - database = tune_extracted_tasks( - extracted_tasks=[ - ExtractedTask( - task_name=task_name, - mod=mod, - dispatched=[mod], - target=target, - weight=1, - ), - ], - config=config, - work_dir=work_dir, - builder=builder, - runner=runner, - database=database, - cost_model=cost_model, - measure_callbacks=measure_callbacks, - space=space, - sch_rules=sch_rules, - postprocs=postprocs, - mutator_probs=mutator_probs, - num_threads=num_threads, - ) - with Profiler.timeit("PostTuningCompilation"): - bests: List[TuningRecord] = database.get_top_k(database.commit_workload(mod), top_k=1) - if not bests: - return None - assert len(bests) == 1 - sch = Schedule(mod) - bests[0].trace.apply_to_schedule(sch, remove_postproc=False) - return sch - - -def tune_te( - tensors: List[Tensor], - target: Union[str, Target], - config: TuneConfig, - work_dir: str, - *, - task_name: str = "main", - builder: Optional[Builder] = None, - runner: Optional[Runner] = None, - database: Optional[Database] = None, - cost_model: Optional[CostModel] = None, - measure_callbacks: Optional[List[MeasureCallback]] = None, - space: Optional[FnSpaceGenerator] = None, - sch_rules: Optional[FnScheduleRule] = None, - postprocs: Optional[FnPostproc] = None, - mutator_probs: Optional[FnMutatorProb] = None, - num_threads: Optional[int] = None, -) -> Optional[Schedule]: - """Tune a TE compute DAG with a given target. - - Parameters - ---------- - tensor : List[Tensor] - The list of input/output tensors of the TE compute DAG. - target : Union[str, Target] - The target to tune for. - config : TuneConfig - The search strategy config. - task_name : str - The name of the task. - work_dir : str - The working directory to save intermediate results. - builder : Optional[Builder] - The builder to use. - runner : Optional[Runner] - The runner to use. - database : Optional[Database] - The database to use. - measure_callbacks : Optional[List[MeasureCallback]] - The callbacks used during tuning. - - Returns - ------- - sch : Optional[Schedule] - The tuned schedule. 
- """ - with Profiler.timeit("CreatePrimFunc"): - func = create_prim_func(tensors) - return tune_tir( - mod=func, - target=target, - config=config, - work_dir=work_dir, - task_name=task_name, - builder=builder, - runner=runner, - database=database, - cost_model=cost_model, - measure_callbacks=measure_callbacks, - space=space, - sch_rules=sch_rules, - postprocs=postprocs, - mutator_probs=mutator_probs, - num_threads=num_threads, - ) - - -def tune_relay( - mod: IRModule, - target: Union[str, Target], - config: TuneConfig, - work_dir: str, - *, - backend: str = "graph", - params: Optional[Dict[str, NDArray]] = None, - builder: Optional[Builder] = None, - runner: Optional[Runner] = None, - database: Optional[Database] = None, - cost_model: Optional[CostModel] = None, - measure_callbacks: Optional[List[MeasureCallback]] = None, - space: Optional[FnSpaceGenerator] = None, - sch_rules: Optional[FnScheduleRule] = None, - postprocs: Optional[FnPostproc] = None, - mutator_probs: Optional[FnMutatorProb] = None, - num_threads: Optional[int] = None, - executor=None, -) -> Union[Module, vm.Executable]: - """Tune a Relay IRModule with a given target. - - Parameters - ---------- - mod : IRModule - The module to tune. - target : Union[str, Target] - The target to tune for. - config : TuneConfig - The search strategy config. - params : Optional[Dict[str, tvm.runtime.NDArray]] - The associated parameters of the program - task_name : str - The name of the task. - work_dir : str - The working directory to save intermediate results. - builder : Optional[Builder] - The builder to use. - runner : Optional[Runner] - The runner to use. - database : Optional[Database] - The database to use. - measure_callbacks : Optional[List[MeasureCallback]] - The callbacks used during tuning. - backend : str = "graph" - The backend to use for relay compilation(graph / vm). - executor : relay.backend.Executor - The executor to be passed to relay.build(...). In particular, its link-params - attribute affects task extration and workload database look up. - - Returns - ------- - lib : Union[Module, tvm.runtime.vm.Executable] - The built runtime module or vm Executable for the given relay workload. 
- """ - # pylint: disable=import-outside-toplevel - from tvm import relay - - from .relay_integration import extract_task_from_relay - - # pylint: disable=protected-access, enable=import-outside-toplevel - target = default_config.target(target) - # pylint: enable=protected-access, - # parse the tuning contexts - - if executor is None: - executor = relay.backend.Executor("graph") - - if "link-params" in executor.attrs: - link_params = executor.attrs["link-params"] - else: - link_params = False - - with Profiler.timeit("TaskExtraction"): - pass_config = { - "relay.FuseOps.link_params": link_params, - "relay.backend.use_meta_schedule": True, - "relay.backend.tir_converter": "default", - } - extracted_tasks = extract_task_from_relay(mod, target, params, pass_config=pass_config) - - database = tune_extracted_tasks( - extracted_tasks, - config, - work_dir, - builder=builder, - runner=runner, database=database, cost_model=cost_model, - measure_callbacks=measure_callbacks, - space=space, - sch_rules=sch_rules, - postprocs=postprocs, - mutator_probs=mutator_probs, - num_threads=num_threads, ) - - with Profiler.timeit("PostTuningCompilation"): - with target, autotvm_silencer(), database: - with PassContext( - opt_level=3, - config={ - "relay.backend.use_meta_schedule": True, - "relay.backend.use_meta_schedule_dispatch": target.kind.name != "cuda", - "relay.backend.tir_converter": "default", - }, - ): - if backend == "graph": - return relay.build(mod, target=target, params=params, executor=executor) - - # Executor is not supported by VM - return relay.vm.compile(mod, target=target, params=params) + return database diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 29cd94110c0c..38a46ebe757e 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -16,38 +16,49 @@ # under the License. """Meta Schedule tuning context.""" -import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union + +# isort: off +from typing_extensions import Literal + +# isort: on from tvm import IRModule from tvm._ffi import register_object -from tvm.meta_schedule.utils import cpu_count, make_logging_func from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule from . 
import _ffi_api +from .logging import Logger, get_logger, get_logging_func +from .utils import cpu_count if TYPE_CHECKING: from .cost_model import CostModel from .database import Database - from .mutator import Mutator - from .postproc import Postproc from .runner import RunnerResult - from .schedule_rule import ScheduleRule from .search_strategy import MeasureCandidate, SearchStrategy - from .space_generator import ScheduleFn, ScheduleFnType, SpaceGenerator - from .tune import TuneConfig + from .space_generator import SpaceGenerator + + +def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: + """Normalize the input to an IRModule""" + if isinstance(mod, PrimFunc): + mod = mod.with_attr("global_symbol", "main") + mod = mod.with_attr("tir.noalias", True) + mod = IRModule({"main": mod}) + if not isinstance(mod, IRModule): + raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") + func_names = mod.get_global_vars() + (func_name,) = func_names + if len(func_names) == 1 and func_name.name_hint != "main": + mod = IRModule({"main": mod[func_name]}) + return mod @register_object("meta_schedule.TuneContext") class TuneContext(Object): - """ - The tune context class is designed to contain all resources for a tuning task. - - Different tuning tasks are separated in different TuneContext classes, but different classes in - the same task can interact with each other through tune context. Most classes have a function - to initialize with a tune context. + """The tune context class is designed to contain all resources for a tuning task. Parameters ---------- @@ -57,22 +68,9 @@ class TuneContext(Object): The target to be optimized for. space_generator : Union[None, ScheduleFnType, SpaceGenerator] = None The design space generator. - search_strategy : Union[None, TuneConfig, SearchStrategy] = None + search_strategy : Union[None, SearchStrategy] = None The search strategy. if None, the strategy is left blank. - If TuneConfig, the strategy is initialized with the TuneConfig.create_strategy(). - sch_rules: Union[None, str, List[ScheduleRule]] = None, - The schedule rules. - If None, use an empty list of rules. - if "default", use target-default rules. - postprocs: Union[None, str, List[Postproc"]] = None, - The postprocessors. - If None, use an empty list of rules. - if "default", use target-default rules. - mutator_probs: Union[None, str, Dict[Mutator, float]] - Mutators and their probability mass. - If None, use an empty list of rules. - if "default", use target-default rules. task_name : Optional[str] = None The name of the tuning task. logger : logging.Logger @@ -82,24 +80,14 @@ class TuneContext(Object): Need to be in integer in [1, 2^31-1], -1 means using random number. num_threads : int = None The number of threads to be used, None means using the logical cpu count. - - Note - ---- - In most cases, mod and target should be available in the tuning context. They are "Optional" - because we allow the user to customize the tuning context, along with other classes, sometimes - without mod and target. E.g., we can have a stand alone search strategy that generates measure - candidates without initializing with the tune context. 
""" mod: Optional[IRModule] target: Optional[Target] space_generator: Optional["SpaceGenerator"] search_strategy: Optional["SearchStrategy"] - sch_rules: List["ScheduleRule"] - postprocs: List["Postproc"] - mutator_probs: Optional[Dict["Mutator", float]] task_name: str - logger: Optional[logging.Logger] + logger: Optional[Logger] rand_state: int num_threads: int @@ -107,114 +95,57 @@ def __init__( self, mod: Optional[IRModule] = None, *, - target: Optional[Target] = None, - space_generator: Union[None, "ScheduleFnType", "ScheduleFn", "SpaceGenerator"] = None, - search_strategy: Union[None, "SearchStrategy", "TuneConfig"] = None, - sch_rules: Union[None, str, List["ScheduleRule"]] = None, - postprocs: Union[None, str, List["Postproc"]] = None, - mutator_probs: Union[None, str, Dict["Mutator", float]] = None, + target: Union[Target, str, None] = None, + space_generator: Union["SpaceGenerator.SpaceGeneratorType", None] = None, + search_strategy: Union["SearchStrategy.SearchStrategyType", None] = None, task_name: str = "main", - logger: Optional[logging.Logger] = None, rand_state: int = -1, - num_threads: Optional[int] = None, + num_threads: Union[int, Literal["physical", "logical"]] = "physical", + logger: Optional[Logger] = None, ): # pylint: disable=import-outside-toplevel - from . import default_config - from .space_generator import ScheduleFn - from .tune import TuneConfig + import tvm.tir.tensor_intrin # pylint: disable=unused-import + + from .search_strategy import SearchStrategy + from .space_generator import SpaceGenerator # pylint: enable=import-outside-toplevel if isinstance(mod, PrimFunc): - mod = IRModule.from_expr(mod) - if callable(space_generator): - space_generator = ScheduleFn(space_generator) - if isinstance(search_strategy, TuneConfig): - search_strategy = search_strategy.create_strategy() - if isinstance(sch_rules, str): - if sch_rules == "default": - if target is None: - raise ValueError("target is required when sch_rules is 'default'") - sch_rules = default_config.schedule_rules(None, target) - else: - raise ValueError("sch_rules should be a list of ScheduleRule or 'default'") - if isinstance(postprocs, str): - if postprocs == "default": - if target is None: - raise ValueError("target is required when postprocs is 'default'") - postprocs = default_config.postproc(None, target) - else: - raise ValueError("postprocs should be a list of Postproc or 'default'") - if isinstance(mutator_probs, str): - if mutator_probs == "default": - if target is None: - raise ValueError("target is required when mutator_probs is 'default'") - mutator_probs = default_config.mutator_probs(None, target) + mod = _normalize_mod(mod) + if target is not None: + if not isinstance(target, Target): + target = Target(target) + if space_generator is not None: + if not isinstance(space_generator, SpaceGenerator): + space_generator = SpaceGenerator.create(space_generator) + if search_strategy is not None: + if not isinstance(search_strategy, SearchStrategy): + search_strategy = SearchStrategy.create(search_strategy) if logger is None: - self.logger = logging.getLogger(__name__) - else: - self.logger = None - if num_threads is None: - num_threads = cpu_count(logical=False) + logger = get_logger(__name__) + if not isinstance(num_threads, int): + if num_threads == "physical": + num_threads = cpu_count(logical=False) + elif num_threads == "logical": + num_threads = cpu_count(logical=True) + else: + raise ValueError( + f"Invalid num_threads: {num_threads}, " + "should be either an integer, 'physical', or 
'logical'" + ) self.__init_handle_by_constructor__( _ffi_api.TuneContext, # type: ignore # pylint: disable=no-member mod, target, space_generator, search_strategy, - sch_rules, - postprocs, - mutator_probs, task_name, - make_logging_func(logger), - rand_state, num_threads, + rand_state, + get_logging_func(logger), ) _ffi_api.TuneContextInitialize(self) # type: ignore # pylint: disable=no-member - def _set_measure_candidates(self, candidates): - """Set candidates in a tuning context. - - Parameters - ---------- - candidates : List[MeasureCandidate] - A list of measure candidates for the tuning context. - """ - _ffi_api.TuneContextSetMeasureCandidates(self, candidates) # type: ignore # pylint: disable=no-member - - def _send_to_builder(self, builder): - """Send candidates to builder. - - Parameters - ---------- - builder : Builder - The builder for building the candidates. - """ - _ffi_api.TuneContextSendToBuilder(self, builder) # type: ignore # pylint: disable=no-member - - def _send_to_runner(self, runner): - """Send candidates to runner. - - Parameters - ---------- - runner : Runner - The runner for running the candidates. - """ - _ffi_api.TuneContextSendToRunner(self, runner) # type: ignore # pylint: disable=no-member - - def _join(self): - """Join the runner processes. - - Returns - ------- - result : List[RunnerResult] - The runner results. - """ - return _ffi_api.TuneContextJoin(self) # type: ignore # pylint: disable=no-member - - def _clear_measure_state(self): - """Clear the measure states.""" - _ffi_api.TuneContextClearMeasureState(self) # type: ignore # pylint: disable=no-member - def generate_design_space(self) -> List[Schedule]: """Generate design spaces given a module. @@ -236,6 +167,8 @@ def generate_design_space(self) -> List[Schedule]: def pre_tuning( self, + max_trials: int, + num_trials_per_iter: int = 64, design_spaces: Optional[List[Schedule]] = None, database: Optional["Database"] = None, cost_model: Optional["CostModel"] = None, @@ -246,6 +179,10 @@ def pre_tuning( Parameters ---------- + max_trials : int + The maximum number of trials to be executed. + num_trials_per_iter : int = 64 + The number of trials to be executed per iteration. design_spaces : Optional[List[Schedule]] The design spaces used during tuning process. If None, use the outcome of `self.generate_design_space()`. @@ -278,7 +215,13 @@ def pre_tuning( if cost_model is None: if isinstance(self.search_strategy, EvolutionarySearch): cost_model = RandomModel() # type: ignore - return self.search_strategy.pre_tuning(design_spaces, database, cost_model) + return self.search_strategy.pre_tuning( + max_trials, + num_trials_per_iter, + design_spaces, + database, + cost_model, + ) def post_tuning(self) -> None: """A method to be called for SearchStrategy to do necessary cleanup after tuning. diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 7b7c4a68653d..eb3c6437603c 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -16,12 +16,11 @@ # under the License. 
"""Utilities for meta schedule""" import ctypes -import logging import os import shutil -from contextlib import contextmanager -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, List, Optional, Union +import numpy as np # type: ignore import psutil # type: ignore from tvm._ffi import get_global_func, register_func from tvm.error import TVMError @@ -86,7 +85,7 @@ def method(*args, **kwargs): assert isinstance(cls.__base__, type) assert hasattr( cls, "_tvm_metadata" - ), "Please use the user-facing method overiding class, i.e., PyRunner." + ), "Please use the user-facing method overriding class, i.e., PyRunner." base = cls.__base__ metadata = getattr(base, "_tvm_metadata") @@ -114,7 +113,10 @@ def __init__(self, *args, **kwargs): def __getattr__(self, name: str): """Bridge the attribute function.""" - return self._inst.__getattribute__(name) + try: + return self._inst.__getattribute__(name) + except AttributeError: + return super(TVMDerivedObject, self).__getattr__(name) def __setattr__(self, name, value): if name not in ["_inst", "key", "handle"]: @@ -157,14 +159,6 @@ def _cpu_count_impl(logical: bool = True) -> int: return psutil.cpu_count(logical=logical) or 1 -@register_func("meta_schedule._process_error_message") -def _process_error_message(error_msg: str) -> str: - error_msg_lines = str(error_msg).splitlines() - if len(error_msg_lines) >= 50: - return "\n".join(error_msg_lines[:25] + ["..."] + error_msg_lines[-25:]) - return error_msg - - def cpu_count(logical: bool = True) -> int: """Return the number of logical or physical CPUs in the system @@ -193,6 +187,14 @@ def cpu_count(logical: bool = True) -> int: return _cpu_count_impl(logical) +@register_func("meta_schedule.using_ipython") +def _using_ipython(): + try: + return get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore + except NameError: + return False + + def get_global_func_with_default_on_worker( name: Union[None, str, Callable], default: Callable, @@ -335,114 +337,7 @@ def _to_hex_address(handle: ctypes.c_void_p) -> str: return hex(ctypes.cast(handle, ctypes.c_void_p).value) -@contextmanager -def autotvm_silencer(): - """A context manager that silences autotvm warnings.""" - from tvm import autotvm # pylint: disable=import-outside-toplevel - - silent = autotvm.GLOBAL_SCOPE.silent - autotvm.GLOBAL_SCOPE.silent = True - try: - yield - finally: - autotvm.GLOBAL_SCOPE.silent = silent - - -def make_logging_func(logger: logging.Logger) -> Optional[Callable]: - """Get the logging function. - Parameters - ---------- - logger : logging.Logger - The logger instance. - Returns - ------- - result : Optional[Callable] - The function to do the specified level of logging. 
- """ - if logger is None: - return None - - level2log = { - logging.DEBUG: logger.debug, - logging.INFO: logger.info, - logging.WARNING: logger.warning, - logging.ERROR: logger.error, - # logging.FATAL not included - } - - def logging_func(level: int, msg: str): - def clear_notebook_output(): - from IPython.display import clear_output # type: ignore # pylint: disable=import-outside-toplevel - - clear_output(wait=True) - - if level < 0: - clear_notebook_output() - else: - level2log[level](msg) - - return logging_func - - -@register_func("meta_schedule.using_ipython") -def _check_ipython_env(): - try: - return get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore - except NameError: - return False - - -def parameterize_config(config: Dict[str, Any], params: Dict[str, str]) -> Dict[str, Any]: - """Parameterize the given configuration. - - Parameters - ---------- - config : Dict[str, Any] - The given config dict. - Params : Dict[str, str] - The given parameters. - - Returns - ------- - result : Dict[str, Any] - The parameterized configuration. - """ - result = {} - for k, v in config.items(): - if isinstance(k, str): - k = k.format(**params) - if isinstance(v, str): - v = v.format(**params) - elif isinstance(v, dict): - v = parameterize_config(v, params) - elif isinstance(v, list): - v = [t.format(**params) for t in v] - result[k] = v - return result - - -def batch_parameterize_config( - config: Dict[str, Any], params: List[Dict[str, str]] -) -> Dict[str, Any]: - """Parameterize the given configuration with multiple parameters sets. - - Parameters - ---------- - config : Dict[str, Any] - The given config dict. - Params : List[Dict[str, str]] - List of the given multiple parameters sets. - - Returns - ------- - result : Dict[str, Any] - The parameterized configuration. - """ - results = {} - for name, cfg in config.items(): - for p in params: - p_name = name.format(**p) - if p_name not in results: - p_cfg = parameterize_config(cfg, p) - results[p_name] = p_cfg - return results +def fork_seed(seed: Optional[int], n: int) -> List[int]: + # fmt: off + return np.random.RandomState(seed=seed).randint(1, 2 ** 30, size=n).tolist() + # fmt: on diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 3374d18dff80..86dd2eee5cd7 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -16,13 +16,14 @@ # under the License. # pylint: disable=invalid-name,missing-function-docstring """Intrinsics for tensorization on NVIDIA GPU.""" -from typing import Tuple, Dict +from typing import Dict, Tuple + from tvm.script import tir as T from tvm.tir.function import PrimFunc -from .. import IntImm, Cast + from ..._ffi import register_func from ...runtime import convert -from .. import TensorIntrin +from .. 
import Cast, IntImm, TensorIntrin def shared_16x16_to_ldmatrix_32x8_layout(i, j): diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc index 26399276c933..68a4b93ea96f 100644 --- a/src/meta_schedule/measure_callback/add_to_database.cc +++ b/src/meta_schedule/measure_callback/add_to_database.cc @@ -27,12 +27,12 @@ class AddToDatabaseNode : public MeasureCallbackNode { const Array& measure_candidates, const Array& builder_results, const Array& runner_results) final { - if (!task_scheduler->database.defined()) { + if (!task_scheduler->database_.defined()) { return; } auto _ = Profiler::TimedScope("MeasureCallback/AddToDatabase"); - TuneContext task = task_scheduler->tasks[task_id]; - Database database = task_scheduler->database.value(); + TuneContext task = task_scheduler->tasks_[task_id]->ctx; + Database database = task_scheduler->database_.value(); Workload workload = database->CommitWorkload(task->mod.value()); Target target = task->target.value(); ICHECK_EQ(runner_results.size(), measure_candidates.size()); diff --git a/src/meta_schedule/measure_callback/echo_statistics.cc b/src/meta_schedule/measure_callback/echo_statistics.cc deleted file mode 100644 index fb1064266566..000000000000 --- a/src/meta_schedule/measure_callback/echo_statistics.cc +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include - -#include "../utils.h" - -namespace tvm { -namespace meta_schedule { - -constexpr const double kMaxTime = 1e10; - -std::string GetTaskName(const TuneContext& task, int task_id) { - std::ostringstream os; - os << "Task #" << task_id << ": " << task->task_name; - return os.str(); -} - -struct TaskInfo { - std::string name; - double flop = 0.0; - int trials = -1; - int best_round = -1; - double best_ms = kMaxTime; - double best_gflops = 0.0; - int error_count = 0; - PackedFunc logging_func; - - explicit TaskInfo(const String& name, PackedFunc logging_func) - : name(name), logging_func(logging_func) {} - - void Update(double run_ms) { - ++trials; - if (run_ms < best_ms) { - best_ms = run_ms; - best_round = trials; - best_gflops = flop / run_ms / 1e6; - } - TVM_PY_LOG(INFO, logging_func) << "[" << name << "] Trial #" << trials // - << std::fixed << std::setprecision(4) // - << ": GFLOPs: " << (flop / run_ms / 1e6) // - << ". Time: " << run_ms << " ms" // - << ". 
Best GFLOPs: " << best_gflops; - } - - void UpdateError(std::string err, const MeasureCandidate& candidate) { - static const auto* f_proc = runtime::Registry::Get("meta_schedule._process_error_message"); - ICHECK(f_proc != nullptr); - err = (*f_proc)(err).operator std::string(); - ++error_count; - ++trials; - TVM_PY_LOG(INFO, logging_func) - << "[" << name << "] Trial #" << trials // - << std::fixed << std::setprecision(4) // - << ": Error in building: " << err << "\n" - << tir::AsTVMScript(candidate->sch->mod()) << "\n" - << Concat(candidate->sch->trace().value()->AsPython(false), "\n"); - } -}; - -class EchoStatisticsNode : public MeasureCallbackNode { - public: - void Apply(const TaskScheduler& task_scheduler, int task_id, - const Array& measure_candidates, - const Array& builder_results, - const Array& runner_results) final { - auto _ = Profiler::TimedScope("MeasureCallback/EchoStatistics"); - if (this->task_info.empty()) { - SetupTaskInfo(task_scheduler->tasks); - } - ICHECK_EQ(measure_candidates.size(), builder_results.size()); - ICHECK_EQ(measure_candidates.size(), runner_results.size()); - int n = measure_candidates.size(); - TuneContext task = task_scheduler->tasks[task_id]; - TaskInfo& info = this->task_info[task_id]; - std::string task_name = GetTaskName(task, task_id); - for (int i = 0; i < n; ++i) { - MeasureCandidate candidate = measure_candidates[i]; - BuilderResult builder_result = builder_results[i]; - RunnerResult runner_result = runner_results[i]; - if (Optional err = builder_result->error_msg) { - info.UpdateError(err.value(), candidate); - } else if (Optional err = runner_result->error_msg) { - info.UpdateError(err.value(), candidate); - } else { - ICHECK(runner_result->run_secs.defined()); - info.Update(GetRunMsMedian(runner_result)); - } - } - } - - void SetupTaskInfo(const Array& tasks) { - task_info.reserve(tasks.size()); - int task_id = 0; - for (const TuneContext& task : tasks) { - task_info.push_back(TaskInfo(GetTaskName(task, task_id), task->logging_func)); - TaskInfo& info = task_info.back(); - info.flop = tir::EstimateTIRFlops(task->mod.value()); - ++task_id; - } - } - - std::vector task_info; - - static constexpr const char* _type_key = "meta_schedule.EchoStatistics"; - TVM_DECLARE_FINAL_OBJECT_INFO(EchoStatisticsNode, MeasureCallbackNode); -}; - -MeasureCallback MeasureCallback::EchoStatistics() { - ObjectPtr n = make_object(); - return MeasureCallback(n); -} - -TVM_REGISTER_NODE_TYPE(EchoStatisticsNode); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackEchoStatistics") - .set_body_typed(MeasureCallback::EchoStatistics); - -} // namespace meta_schedule -} // namespace tvm diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc index ebe63e7b76f1..f16fb73c520c 100644 --- a/src/meta_schedule/measure_callback/measure_callback.cc +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -39,6 +39,14 @@ MeasureCallback MeasureCallback::PyMeasureCallback(PyMeasureCallbackNode::FApply return MeasureCallback(n); } +Array MeasureCallback::Default() { + return { + MeasureCallback::AddToDatabase(), + MeasureCallback::RemoveBuildArtifact(), + MeasureCallback::UpdateCostModel(), + }; +} + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { const auto* self = n.as(); @@ -55,6 +63,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackApply") .set_body_method(&MeasureCallbackNode::Apply); 
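With EchoStatistics gone (its per-trial console echo is superseded by the logger plumbing elsewhere in this patch), the "default" shorthand accepted by tune_tir/tune_tasks expands, via MeasureCallback::Default() above, to exactly three callbacks. A Python-side sketch of spelling that list out explicitly, assuming the usual Python mirrors of these constructors:

    from tvm import meta_schedule as ms

    # Equivalent to passing measure_callbacks="default":
    callbacks = [
        ms.measure_callback.AddToDatabase(),
        ms.measure_callback.RemoveBuildArtifact(),
        ms.measure_callback.UpdateCostModel(),
    ]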
TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackPyMeasureCallback") .set_body_typed(MeasureCallback::PyMeasureCallback); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackDefault") + .set_body_typed(MeasureCallback::Default); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc index 8851345c43b0..0563699ba6b9 100644 --- a/src/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -28,11 +28,12 @@ class UpdateCostModelNode : public MeasureCallbackNode { const Array& builder_results, const Array& runner_results) final { auto _ = Profiler::TimedScope("MeasureCallback/UpdateCostModel"); - TuneContext task = task_scheduler->tasks[task_id]; - ICHECK(task_scheduler->cost_model.defined()) - << "Cost model must be defined for the task scheduler!"; + const TaskRecord& task = task_scheduler->tasks_[task_id]; + if (!task_scheduler->cost_model_.defined()) { + return; + } + CostModel cost_model = task_scheduler->cost_model_.value(); ICHECK(task->measure_candidates.defined()) << "Task's measure candidates must be present!"; - CostModel cost_model = task_scheduler->cost_model.value(); ICHECK_EQ(measure_candidates.size(), builder_results.size()); ICHECK_EQ(runner_results.size(), builder_results.size()); int n = builder_results.size(); @@ -46,7 +47,7 @@ class UpdateCostModelNode : public MeasureCallbackNode { pruned_runner_result.push_back(runner_results[i]); } } - cost_model->Update(task, pruned_candidate, pruned_runner_result); + cost_model->Update(task->ctx, pruned_candidate, pruned_runner_result); } static constexpr const char* _type_key = "meta_schedule.UpdateCostModel"; diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index 25312ab61f99..8e9bfc8bde4b 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -51,6 +51,31 @@ Mutator Mutator::PyMutator( return Mutator(n); } +Map Mutator::DefaultLLVM() { + return Map{ + {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)}, + {Mutator::MutateComputeLocation(), FloatImm(DataType::Float(64), 0.05)}, + {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.03)}, + {Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}}; +} + +Map Mutator::DefaultCUDA() { + return Map{ + {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)}, + {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.08)}, + {Mutator::MutateThreadBinding(), FloatImm(DataType::Float(64), 0.02)}}; +} + +Map Mutator::DefaultCUDATensorCore() { return Mutator::DefaultCUDA(); } + +Map Mutator::DefaultHexagon() { + return Map{ + {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)}, + {Mutator::MutateComputeLocation(), FloatImm(DataType::Float(64), 0.05)}, + {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.03)}, + {Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}}; +} + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { const auto* self = n.as(); @@ -72,6 +97,11 @@ TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply") }); TVM_REGISTER_GLOBAL("meta_schedule.MutatorClone").set_body_method(&MutatorNode::Clone); TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); 
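Note that each default mutator map above is a proper probability mass: the LLVM and Hexagon weights (0.9 + 0.05 + 0.03 + 0.02) and the CUDA weights (0.9 + 0.08 + 0.02) both sum to 1.0, so tile-size mutation dominates the mutation step of the evolutionary search on every backend. The registrations below make these maps reachable from Python; a quick, illustrative way to inspect them:

    import tvm

    f = tvm.get_global_func("meta_schedule.MutatorDefaultLLVM")
    for mutator, prob in f().items():  # Map<Mutator, FloatImm>
        print(mutator, float(prob.value))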
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultLLVM").set_body_typed(Mutator::DefaultLLVM); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDA").set_body_typed(Mutator::DefaultCUDA); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDATensorCore") + .set_body_typed(Mutator::DefaultCUDATensorCore); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultHexagon").set_body_typed(Mutator::DefaultHexagon); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index 957d6e7364e4..acc157e36e94 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -50,6 +50,48 @@ Postproc Postproc::PyPostproc( return Postproc(n); } +Array Postproc::DefaultLLVM() { + return Array{ + Postproc::DisallowDynamicLoop(), + Postproc::RewriteParallelVectorizeUnroll(), + Postproc::RewriteReductionBlock(), + Postproc::RewriteLayout(), + }; +} + +Array Postproc::DefaultCUDA() { + return Array{ + Postproc::DisallowDynamicLoop(), + Postproc::RewriteCooperativeFetch(), + Postproc::RewriteUnboundBlock(/*max_threadblocks=*/256), + Postproc::RewriteParallelVectorizeUnroll(), + Postproc::RewriteReductionBlock(), + Postproc::VerifyGPUCode(), + }; +} + +Array Postproc::DefaultCUDATensorCore() { + return Array{ + Postproc::DisallowDynamicLoop(), + Postproc::RewriteCooperativeFetch(), + Postproc::RewriteUnboundBlock(/*max_threadblocks=*/256), + Postproc::RewriteParallelVectorizeUnroll(), + Postproc::RewriteReductionBlock(), + Postproc::RewriteTensorize(/*vectorize_init_loop=*/false), + Postproc::VerifyGPUCode(), + }; +} + +Array Postproc::DefaultHexagon() { + return Array{ + Postproc::DisallowDynamicLoop(), + Postproc::RewriteParallelVectorizeUnroll(), // + Postproc::RewriteReductionBlock(), + // TODO(masahi): Fix RewriteLayout for link-params=True case + // Postproc::RewriteLayout(), + }; +} + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { const auto* self = n.as(); @@ -67,6 +109,12 @@ TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); TVM_REGISTER_GLOBAL("meta_schedule.PostprocClone").set_body_method(&PostprocNode::Clone); TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultLLVM").set_body_typed(Postproc::DefaultLLVM); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDA").set_body_typed(Postproc::DefaultCUDA); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDATensorCore") + .set_body_typed(Postproc::DefaultCUDATensorCore); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultHexagon") + .set_body_typed(Postproc::DefaultHexagon); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index ac9f45ca8ef4..427653b06c2a 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -97,7 +97,7 @@ class RewriteCooperativeFetchNode : public PostprocNode { if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { this->thread_warp_size_ = v.value()->value; } else { - TVM_PY_LOG(INFO, context->logging_func) << "'thread_warp_size' is not defined in the target"; + TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target"; } 
} diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index f2fc67f74cc7..e8d821636fd3 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -32,12 +32,12 @@ class CrossThreadReductionNode : public ScheduleRuleNode { Optional opt_warp_size = target->GetAttr("thread_warp_size"); if (!opt_max_threads_per_block.defined()) { - TVM_PY_LOG(WARNING, context->logging_func) + TVM_PY_LOG(WARNING, context->logger) << "Target does not have attribute \"max_threads_per_block\", therefore the " "rule CrossThreadReduction will not be applied"; } if (!opt_warp_size.defined()) { - TVM_PY_LOG(WARNING, context->logging_func) + TVM_PY_LOG(WARNING, context->logger) << "Target does not have attribute \"thread_warp_size\", therefore the rule " "CrossThreadReduction will not be applied"; } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 2ae6714f55d8..d9c46015eac3 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -84,10 +84,10 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { this->thread_warp_size_ = v.value()->value; } else { - TVM_PY_LOG(INFO, context->logging_func) << "'thread_warp_size' is not defined in the target"; + TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target"; } } - logging_func = context->logging_func; + logger = context->logger; } // Entry of the mega rule; Inherited from ScheduleRuleNode diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 8f55e8e7e4e4..98b4634af106 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -193,7 +193,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode { /*! \brief The maximum number of threads to be used size of a thread warp */ int max_threads_per_block_; /*! 
\brief The logging function */ - PackedFunc logging_func; + PackedFunc logger; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("structure", &structure); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 2ec78c1918e9..e8a03c722656 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -209,12 +209,12 @@ Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, } Array results; for (auto&& state : ApplySubRules(initial_states)) { - TVM_PY_LOG(INFO, logging_func) << "Sketch " << results.size() << ": tensorizing with " - << state.as()->intrin_group.compute_intrin; + TVM_PY_LOG(INFO, logger) << "Sketch " << results.size() << ": tensorizing with " + << state.as()->intrin_group.compute_intrin; results.push_back(std::move(state->sch)); } if (results.empty()) { - TVM_PY_LOG(INFO, logging_func) << "The workload cannot be tensorized."; + TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized."; return {original_sch}; } return results; @@ -293,8 +293,8 @@ std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( } else if (dtype.is_int() && dtype.bits() == 8) { sch->StorageAlign(cache_read, 0, -2, 32, 16); } else { - TVM_PY_LOG(WARNING, logging_func) << "StorageAlign is not applied for data type " << dtype - << ", shared memory accesses might be inefficient."; + TVM_PY_LOG(WARNING, logger) << "StorageAlign is not applied for data type " << dtype + << ", shared memory accesses might be inefficient."; } } return {state}; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index 8485e697eb24..428a1206a4ca 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -49,17 +49,17 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { auto desc_func = tir::TensorIntrin::Get(intrin_name).value()->desc; if (!CheckAutoTensorizeApplicable(sch, block_rv, desc_func)) { - TVM_PY_LOG(INFO, logging_func) << "The workload cannot be tensorized."; + TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized."; return {sch}; } auto res = MultiLevelTilingNode::Apply(sch->Copy(), block_rv); if (res.empty()) { - TVM_PY_LOG(INFO, logging_func) << "The workload cannot be tensorized."; + TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized."; return {sch}; } - TVM_PY_LOG(INFO, logging_func) << "Tensorizing with " << intrin_name; + TVM_PY_LOG(INFO, logger) << "Tensorizing with " << intrin_name; return res; } diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 416b43f46d56..8333833bfafa 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -51,6 +51,152 @@ ScheduleRule ScheduleRule::PyScheduleRule( return ScheduleRule(n); } +Array ScheduleRule::DefaultLLVM() { + return { + ScheduleRule::AutoInline( + /*into_producer=*/false, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/true, + /*require_injective=*/true, + /*require_ordered=*/true, + /*disallow_op=*/Array{"tir.exp"}), + ScheduleRule::AddRFactor( + /*max_jobs_per_core=*/16, + 
/*max_innermost_factor=*/Integer(64)), + ScheduleRule::MultiLevelTiling( + /*structure=*/"SSRSRS", + /*tile_binds=*/NullOpt, + /*max_innermost_factor=*/Integer(64), + /*vector_load_lens=*/NullOpt, + /*reuse_read=*/NullOpt, + /*reuse_write=*/ + Map{{"req", String("may")}, + {"levels", Array{1, 2}}, + {"scope", String("global")}}), + ScheduleRule::ParallelizeVectorizeUnroll( + /*max_jobs_per_core=*/16, + /*max_vectorize_extent=*/64, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_explicit=*/true), + ScheduleRule::RandomComputeLocation(), + }; +} + +Array ScheduleRule::DefaultCUDA() { + return { + ScheduleRule::MultiLevelTiling( + /*structure=*/"SSSRRSRS", + /*tile_binds=*/Array{"blockIdx.x", "vthread.x", "threadIdx.x"}, + /*max_innermost_factor=*/Integer(64), + /*vector_load_lens=*/Array{1, 2, 3, 4, 8, 16}, + /*reuse_read=*/ + Map{{"req", String("must")}, + {"levels", Array{4}}, // + {"scope", String("shared")}}, + /*reuse_write=*/ + Map{{"req", String("must")}, + {"levels", Array{3}}, // + {"scope", String("local")}}), + ScheduleRule::AutoInline( + /*into_producer=*/true, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/false, + /*require_injective=*/false, + /*require_ordered=*/false, + /*disallow_op=*/Array{}), + ScheduleRule::CrossThreadReduction( + /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), + ScheduleRule::ParallelizeVectorizeUnroll( + /*max_jobs_per_core=*/-1, + /*max_vectorize_extent=*/-1, + /*unroll_max_steps=*/Array{0, 16, 64, 512, 1024}, + /*unroll_explicit=*/true), + ScheduleRule::AutoBind( + /*max_threadblocks=*/256, + /*thread_extents*/ Array{32, 64, 128, 256, 512, 1024}), + }; +} + +Array ScheduleRule::DefaultCUDATensorCore() { + Array> intrin_groups = { + { + {"init", "wmma_fill_16x16x16_f16"}, + {"load_a", "wmma_load_16x16x16_f16_a"}, + {"load_b", "wmma_load_16x16x16_f16_b"}, + {"compute", "wmma_sync_16x16x16_f16f16f16"}, + {"store", "wmma_store_16x16x16_f16_shared"}, + }, + { + {"init", "wmma_fill_16x16x16_f16"}, + {"load_a", "wmma_load_16x16x16_f16_a"}, + {"load_b", "wmma_load_16x16x16_f16_b_trans"}, + {"compute", "wmma_sync_16x16x16_f16f16f16_trans"}, + {"store", "wmma_store_16x16x16_f16_shared"}, + }, + { + {"init", "wmma_fill_16x16x16_s32"}, + {"load_a", "wmma_load_16x16x16_s8_a"}, + {"load_b", "wmma_load_16x16x16_s8_b"}, + {"compute", "wmma_sync_16x16x16_s8s8s32"}, + {"store", "wmma_store_16x16x16_s32_shared"}, + }, + { + {"init", "wmma_fill_16x16x16_s32"}, + {"load_a", "wmma_load_16x16x16_s8_a"}, + {"load_b", "wmma_load_16x16x16_s8_b_trans"}, + {"compute", "wmma_sync_16x16x16_s8s8s32_trans"}, + {"store", "wmma_store_16x16x16_s32_shared"}, + }, + }; + Array results{ScheduleRule::MultiLevelTilingTensorCore( + /*intrin_groups=*/intrin_groups, + /*structure=*/"SSSRRSRS", + /*tile_binds=*/Array{"blockIdx.x", "vthread.x", "threadIdx.x"}, + /*max_innermost_factor=*/Integer(4), + /*vector_load_lens=*/Array{1, 2, 3, 4, 8, 16}, + /*reuse_read=*/ + Map{{"req", String("must")}, + {"levels", Array{4}}, // + {"scope", String("shared")}}, + /*reuse_write=*/ + Map{{"req", String("must")}, + {"levels", Array{2}}, // + {"scope", String("shared")}}, + /*use_software_pipeline=*/false)}; + Array append = ScheduleRule::DefaultCUDA(); + results.insert(results.end(), append.begin(), append.end()); + return results; +} + +Array ScheduleRule::DefaultHexagon() { + return { + ScheduleRule::AutoInline( + /*into_producer=*/false, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/true, + 
/*require_injective=*/true, + /*require_ordered=*/true, + /*disallow_op=*/Array{"tir.exp"}), + ScheduleRule::MultiLevelTilingWideVector( + /*structure=*/"SRSRS", + /*vector_length_in_bits=*/1024, + /*max_innermost_factor=*/Integer(128), + /*reuse_read=*/NullOpt, + /*reuse_write=*/ + Map{{"req", String("may")}, + {"levels", Array{1, 2}}, + {"scope", String("global")}}), + ScheduleRule::ParallelizeVectorizeUnroll( + /*max_jobs_per_core=*/16, + /*max_vectorize_extent=*/128, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_explicit=*/true), + }; +} + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { const auto* self = n.as(); @@ -71,6 +217,14 @@ TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleClone") .set_body_method(&ScheduleRuleNode::Clone); TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule") .set_body_typed(ScheduleRule::PyScheduleRule); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultLLVM") + .set_body_typed(ScheduleRule::DefaultLLVM); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDA") + .set_body_typed(ScheduleRule::DefaultCUDA); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDATensorCore") + .set_body_typed(ScheduleRule::DefaultCUDATensorCore); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultHexagon") + .set_body_typed(ScheduleRule::DefaultHexagon); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 5930704eb0d1..df67d371929b 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -238,14 +238,18 @@ class EvolutionarySearchNode : public SearchStrategyNode { struct State { /*! \brief The search strategy itself */ EvolutionarySearchNode* self; - /*! \brief The design spaces. Decisions are not used so traces only. */ - Array design_spaces; + /*! \brief The number of total trials. */ + int max_trials; + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int st; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int ed; /*! \brief The counter of returning empty results. */ int num_empty_iters; + /*! \brief The design spaces. Decisions are not used so traces only. */ + Array design_spaces; /*! \brief Pre thread data including module to be tuned and random state. */ std::vector per_thread_data_; /*! @@ -260,14 +264,19 @@ class EvolutionarySearchNode : public SearchStrategyNode { /*! \brief The token registered for the given workload in database. 
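
With DefaultLLVM/DefaultCUDA/DefaultCUDATensorCore/DefaultHexagon exposed as plain static constructors (and as the FFI globals registered above), downstream code can treat a bundle as a starting point rather than a fixed menu. A minimal sketch, assuming a TVM dev environment; `MakeCustomLLVMRules` is an illustrative name, not part of this patch:

```cpp
// Sketch only: compose a custom rule list on top of the stock defaults
// instead of re-declaring every rule by hand.
#include <tvm/meta_schedule/schedule_rule.h>

using namespace tvm;
using namespace tvm::meta_schedule;

Array<ScheduleRule> MakeCustomLLVMRules() {
  // Start from the default LLVM bundle defined above...
  Array<ScheduleRule> rules = ScheduleRule::DefaultLLVM();
  // ...and append a variant, e.g. a more generous rfactor budget.
  rules.push_back(ScheduleRule::AddRFactor(
      /*max_jobs_per_core=*/32,
      /*max_innermost_factor=*/Integer(64)));
  return rules;
}
```

The same pattern applies to the `Postproc::Default*` and `Mutator::Default*` bundles used later in this patch.
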
*/ Workload token_{nullptr}; - explicit State(EvolutionarySearchNode* self, Array<Schedule> design_spaces, Database database, - CostModel cost_model) + explicit State(EvolutionarySearchNode* self, int max_trials, int num_trials_per_iter, + Array<Schedule> design_space_schedules, Database database, CostModel cost_model) : self(self), - design_spaces(design_spaces), + max_trials(max_trials), + num_trials_per_iter(num_trials_per_iter), st(0), - ed(self->num_trials_per_iter), + ed(num_trials_per_iter), num_empty_iters(0) { - const TuneContextNode* ctx = self->context_; + design_spaces.reserve(design_space_schedules.size()); + for (const Schedule& space : design_space_schedules) { + design_spaces.push_back(space->trace().value()->Simplified(true)); + } + const TuneContextNode* ctx = self->ctx_; IRModule mod = ctx->mod.value(); this->per_thread_data_.resize(ctx->num_threads); for (PerThreadData& data : this->per_thread_data_) { @@ -316,17 +325,17 @@ class EvolutionarySearchNode : public SearchStrategyNode { }; /*! \brief The tuning context of the evolutionary search strategy. */ - const TuneContextNode* context_{nullptr}; + const TuneContextNode* ctx_{nullptr}; + /*! \brief The postprocessors */ + Array<Postproc> postprocs_; + /*! \brief The mutators and their probability. */ + Map<Mutator, FloatImm> mutator_probs_; /*! \brief The random state. To be initialized with TuneContext. */ TRandState rand_state_; /*! \brief The state of the search strategy. */ std::unique_ptr<State> state_ = nullptr; /*** Configuration: global ***/ - /*! \brief The number of trials per iteration. */ - int num_trials_per_iter; - /*! \brief The number of total trials. */ - int max_trials_per_task; /*! \brief The population size in the evolutionary search. */ int population_size; /*! @@ -356,8 +365,6 @@ class EvolutionarySearchNode : public SearchStrategyNode { // `state_` is not visited /*** Configuration: global ***/ - v->Visit("max_trials_per_task", &max_trials_per_task); - v->Visit("num_trials_per_iter", &num_trials_per_iter); v->Visit("population_size", &population_size); v->Visit("num_empty_iters_before_early_stop", &num_empty_iters_before_early_stop); /*** Configuration: the initial population ***/ @@ -374,23 +381,25 @@ class EvolutionarySearchNode : public SearchStrategyNode { static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch"; TVM_DECLARE_FINAL_OBJECT_INFO(EvolutionarySearchNode, SearchStrategyNode); - void InitializeWithTuneContext(const TuneContext& context) final { - CHECK(context.defined()) << "TuneContext must be defined!"; - CHECK(context->num_threads > 0) << "Number of threads has to be larger than 0."; - CHECK(context->target.defined()) << "Target must be defined!"; - this->context_ = context.get(); - this->rand_state_ = ForkSeed(&context->rand_state); - for (const auto& kv : context->mutator_probs) { - double mass = kv.second->value; - TVM_META_SCHEDULE_CHECK_PROB_RANGE(mass, "mutator_probs"); - } + void InitializeWithTuneContext(const TuneContext& ctx) final { + CHECK(ctx->num_threads > 0) << "ValueError: `TuneContext.num_threads` must be > 0"; + CHECK(ctx->space_generator.defined()) + << "ValueError: `TuneContext.space_generator` must be defined"; + CHECK(ctx->space_generator.value()->postprocs.defined()) + << "ValueError: `TuneContext.space_generator.postprocs` must be defined"; + CHECK(ctx->space_generator.value()->mutator_probs.defined()) + << "ValueError: `TuneContext.space_generator.mutator_probs` must be defined"; + this->ctx_ = ctx.get(); + this->postprocs_ = ctx->space_generator.value()->postprocs.value(); + this->mutator_probs_ =
ctx->space_generator.value()->mutator_probs.value(); + this->rand_state_ = ForkSeed(&ctx->rand_state); this->state_.reset(); } - void PreTuning(const Array& design_spaces, const Optional& database, - const Optional& cost_model) final { + void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, + const Optional& database, const Optional& cost_model) final { ICHECK(!design_spaces.empty()); - CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?"; + CHECK(this->ctx_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?"; CHECK(database.defined()) << "ValueError: Database is not supplied in PreTuning. Evolutionary" "search algorithm requires a database to be present, so that it " @@ -401,23 +410,15 @@ class EvolutionarySearchNode : public SearchStrategyNode { "algorithm expects a cost model to filter out potentially less efficient kernels. If " "you do not expect a cost model to help, please use " "`tvm.meta_schedule.cost_model.RandomModel`"; - if (this->state_ != nullptr) { - TVM_PY_LOG(WARNING, this->context_->logging_func) - << "EvolutionarySearch is already initialized."; - this->state_.reset(); - } - ICHECK(this->state_ == nullptr); - Array design_space_traces; - design_space_traces.reserve(design_spaces.size()); - for (const Schedule& space : design_spaces) { - design_space_traces.push_back(space->trace().value()->Simplified(true)); - } - this->state_ = - std::make_unique(this, design_space_traces, database.value(), cost_model.value()); + CHECK(this->state_ == nullptr) + << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`."; + this->state_ = std::make_unique(this, max_trials, num_trials_per_iter, design_spaces, + database.value(), cost_model.value()); } void PostTuning() final { - ICHECK(this->state_ != nullptr); + CHECK(this->state_ != nullptr) << "ValueError: `PostTuning` is invoked without corresponding " + "`PreTuning`, or `PostTuning` is already invoked."; this->state_.reset(); } @@ -434,8 +435,6 @@ class EvolutionarySearchNode : public SearchStrategyNode { SearchStrategy Clone() const final { ObjectPtr n = make_object(); - n->max_trials_per_task = this->max_trials_per_task; - n->num_trials_per_iter = this->num_trials_per_iter; n->population_size = this->population_size; n->num_empty_iters_before_early_stop = this->num_empty_iters_before_early_stop; n->init_measured_ratio = this->init_measured_ratio; @@ -444,7 +443,7 @@ class EvolutionarySearchNode : public SearchStrategyNode { n->genetic_mutate_prob = this->genetic_mutate_prob; n->genetic_max_fail_count = this->genetic_max_fail_count; n->eps_greedy = this->eps_greedy; - n->context_ = this->context_; + n->ctx_ = this->ctx_; n->rand_state_ = this->rand_state_; n->state_ = nullptr; // cleared the state return SearchStrategy(n); @@ -460,7 +459,7 @@ std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int nu measured_traces.push_back(record->trace); } int actual_num = measured_traces.size(); - ThreadedTraceApply pp(self->context_->postprocs); + ThreadedTraceApply pp(self->postprocs_); std::vector results(actual_num, Schedule{nullptr}); auto f_proc_measured = [this, &measured_traces, &results, &pp](int thread_id, int trace_id) -> void { @@ -477,13 +476,13 @@ std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int nu throw; } }; - support::parallel_for_dynamic(0, actual_num, self->context_->num_threads, f_proc_measured); + support::parallel_for_dynamic(0, actual_num, self->ctx_->num_threads, 
f_proc_measured); return results; } std::vector EvolutionarySearchNode::State::SampleInitPopulation(int num) { auto _ = Profiler::TimedScope("EvoSearch/SampleInitPopulation"); - ThreadedTraceApply pp(self->context_->postprocs); + ThreadedTraceApply pp(self->postprocs_); std::vector out_schs; while (static_cast(out_schs.size()) < self->init_min_unmeasured) { std::vector results(num, Schedule{nullptr}); @@ -499,14 +498,14 @@ std::vector EvolutionarySearchNode::State::SampleInitPopulation(int nu result = sch.value(); } }; - support::parallel_for_dynamic(0, num, self->context_->num_threads, f_proc_unmeasured); + support::parallel_for_dynamic(0, num, self->ctx_->num_threads, f_proc_unmeasured); for (int i = 0; i < num; i++) { if (results[i].defined()) { out_schs.push_back(results[i]); } } - TVM_PY_LOG(INFO, self->context_->logging_func) << "Sample-Init-Population summary:\n" - << pp.SummarizeFailures(); + TVM_PY_LOG(INFO, self->ctx_->logger) << "Sample-Init-Population summary:\n" + << pp.SummarizeFailures(); } return out_schs; } @@ -524,7 +523,7 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( for (int iter = 0;; ++iter) { // Predict normalized score with the cost model, std::vector scores = - PredictNormalizedScore(population, GetRef(self->context_), this->cost_model_); + PredictNormalizedScore(population, GetRef(self->ctx_), this->cost_model_); { auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc"); @@ -545,12 +544,12 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( } // Set threaded samplers, with probability from predicated normalized throughput for (PerThreadData& data : this->per_thread_data_) { - data.Set(scores, self->genetic_mutate_prob, self->context_->mutator_probs); + data.Set(scores, self->genetic_mutate_prob, self->mutator_probs_); } } { auto _ = Profiler::TimedScope("EvoSearch/Evolve/Mutation"); - ThreadedTraceApply pp(self->context_->postprocs); + ThreadedTraceApply pp(self->postprocs_); ConcurrentBitmask cbmask(self->population_size); std::vector next_population(self->population_size, Schedule{nullptr}); // The worker function @@ -589,13 +588,12 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( result = population.at(sampled_trace_id); } }; - support::parallel_for_dynamic(0, self->population_size, self->context_->num_threads, + support::parallel_for_dynamic(0, self->population_size, self->ctx_->num_threads, f_find_candidate); population.swap(next_population); - TVM_PY_LOG(INFO, self->context_->logging_func) - << "Evolve iter #" << iter << " done. Summary:\n" - << pp.SummarizeFailures(); + TVM_PY_LOG(INFO, self->ctx_->logger) << "Evolve iter #" << iter << " done. 
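
The evolve loop above re-seeds each per-thread sampler from the predicted scores and from `mutator_probs_` (see `data.Set(...)`). A standalone sketch of the probability split it relies on, in plain C++17 with no TVM dependency: mutate with probability `genetic_mutate_prob`, choosing the mutator by its weight, otherwise keep the trace unchanged. All names below are illustrative.

```cpp
#include <cstdio>
#include <random>
#include <vector>

int main() {
  double genetic_mutate_prob = 0.85;                    // chance of mutating at all
  std::vector<double> mutator_probs = {0.5, 0.3, 0.2};  // per-mutator weights, sum to 1

  // Scale the weights so they sum to genetic_mutate_prob; the leftover
  // (1 - genetic_mutate_prob) mass maps to "keep the trace as-is" (index -1).
  std::vector<double> weights;
  for (double p : mutator_probs) weights.push_back(p * genetic_mutate_prob);
  weights.push_back(1.0 - genetic_mutate_prob);

  std::mt19937 rng(42);
  std::discrete_distribution<int> pick(weights.begin(), weights.end());
  for (int i = 0; i < 5; ++i) {
    int k = pick(rng);
    int mutator_id = (k == static_cast<int>(mutator_probs.size())) ? -1 : k;
    std::printf("draw %d -> mutator %d\n", i, mutator_id);
  }
}
```
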
Summary:\n" + << pp.SummarizeFailures(); } } // Return the best states from the heap, sorting from higher score to lower ones @@ -622,7 +620,7 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( os << std::fixed << std::setprecision(4) << heap.heap.at(i).score; } } - TVM_PY_LOG(INFO, self->context_->logging_func) + TVM_PY_LOG(INFO, self->ctx_->logger) << "Scores of the best " << n << " candidates:" << os.str(); return results; } @@ -673,33 +671,32 @@ std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( } Optional> EvolutionarySearchNode::State::GenerateMeasureCandidates() { - if (st >= self->max_trials_per_task) { + if (st >= max_trials) { return NullOpt; } - int sample_num = self->num_trials_per_iter; - if (ed > self->max_trials_per_task) { - sample_num = self->max_trials_per_task - st; - ed = self->max_trials_per_task; + int sample_num = num_trials_per_iter; + if (ed > max_trials) { + sample_num = max_trials - st; + ed = max_trials; } ICHECK_LT(st, ed); int pop = self->population_size; std::vector inits; inits.reserve(pop); - TVM_PY_LOG(INFO, self->context_->logging_func) << "Generating candidates......"; + TVM_PY_LOG(INFO, self->ctx_->logger) << "Generating candidates......"; std::vector measured = PickBestFromDatabase(pop * self->init_measured_ratio); - TVM_PY_LOG(INFO, self->context_->logging_func) + TVM_PY_LOG(INFO, self->ctx_->logger) << "Picked top " << measured.size() << " candidate(s) from database"; std::vector unmeasured = SampleInitPopulation(pop - measured.size()); - TVM_PY_LOG(INFO, self->context_->logging_func) - << "Sampled " << unmeasured.size() << " candidate(s)"; + TVM_PY_LOG(INFO, self->ctx_->logger) << "Sampled " << unmeasured.size() << " candidate(s)"; inits.insert(inits.end(), measured.begin(), measured.end()); inits.insert(inits.end(), unmeasured.begin(), unmeasured.end()); std::vector bests = EvolveWithCostModel(inits, sample_num); - TVM_PY_LOG(INFO, self->context_->logging_func) + TVM_PY_LOG(INFO, self->ctx_->logger) << "Got " << bests.size() << " candidate(s) with evolutionary search"; std::vector picks = PickWithEpsGreedy(unmeasured, bests, sample_num); - TVM_PY_LOG(INFO, self->context_->logging_func) + TVM_PY_LOG(INFO, self->ctx_->logger) << "Sending " << picks.size() << " candidates(s) for measurement"; if (picks.empty()) { ++this->num_empty_iters; @@ -716,9 +713,7 @@ void EvolutionarySearchNode::State::NotifyRunnerResults( ed += results.size(); } -SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, // - int max_trials_per_task, // - int population_size, // +SearchStrategy SearchStrategy::EvolutionarySearch(int population_size, // double init_measured_ratio, // int init_min_unmeasured, // int genetic_num_iters, // @@ -729,8 +724,6 @@ SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, / TVM_META_SCHEDULE_CHECK_PROB_RANGE(genetic_mutate_prob, "Mutation probability"); TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick probability"); ObjectPtr n = make_object(); - n->num_trials_per_iter = num_trials_per_iter; - n->max_trials_per_task = max_trials_per_task; n->population_size = population_size; n->num_empty_iters_before_early_stop = 5; n->init_measured_ratio = init_measured_ratio; diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 6914ab2f0f0a..7bb4a02ab299 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -28,65 +28,69 @@ class ReplayFuncNode : 
public SearchStrategyNode { struct State { /*! \brief The search strategy itself */ ReplayFuncNode* self; + /*! \brief The number of total trials. */ + int max_trials; + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int st; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int ed; - explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) { - const TuneContextNode* ctx = self->context_; - ICHECK(ctx); + explicit State(ReplayFuncNode* self, int max_trials, int num_trials_per_iter) + : self(self), + max_trials(max_trials), + num_trials_per_iter(num_trials_per_iter), + st(0), + ed(num_trials_per_iter) { + CHECK(self->mod_.defined() && self->space_generator_.defined()) + << "ValueError: The search strategy has not been initialized."; } inline Optional> GenerateMeasureCandidates(); inline void NotifyRunnerResults(const Array& results); }; - /*! \brief The number of trials per iteration. */ - int num_trials_per_iter; - /*! \brief The number of total trials. */ - int max_trials_per_task; - - /*! \brief The tuning context of the search strategy. */ - const TuneContextNode* context_{nullptr}; /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; + /*! \brief The IRModule to be scheduled from TuneContext. */ + Optional mod_ = NullOpt; + /*! \brief The space generator from TuneContext. */ + Optional space_generator_ = NullOpt; /*! \brief The state of the search strategy. */ std::unique_ptr state_ = nullptr; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("num_trials_per_iter", &num_trials_per_iter); - v->Visit("max_trials_per_task", &max_trials_per_task); - // `context_` is not visited. - // `rand_state_` is not visited - // `state_` is not visited - } + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "meta_schedule.ReplayFunc"; TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode); - void InitializeWithTuneContext(const TuneContext& context) final { - CHECK(context->space_generator.defined()) + void InitializeWithTuneContext(const TuneContext& ctx) final { + CHECK(ctx->mod.defined()) << "ValueError: TuneContext.mod is not defined"; + CHECK(ctx->space_generator.defined()) << "ValueError: TuneContext.space_generator is not defined"; - CHECK(context->mod.defined()) << "ValueError: TuneContext.mod is not defined"; - this->context_ = context.get(); - this->rand_state_ = ForkSeed(&context->rand_state); + if (!ctx->space_generator.value()->postprocs.defined()) { + TVM_PY_LOG(WARNING, ctx->logger) + << "`postprocs` is not defined in " << ctx->space_generator.value() + << ". 
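
Here and in the evolutionary strategy, the new `CHECK`s turn the old "warn and reset" behavior into a hard pairing contract: `PreTuning` requires a null state, `PostTuning` requires a live one. A standalone sketch of that contract, assuming nothing beyond the standard library:

```cpp
#include <cassert>
#include <memory>

struct SearchState { int st = 0, ed = 0; };

struct StrategySketch {
  std::unique_ptr<SearchState> state_;

  void PreTuning(int num_trials_per_iter) {
    assert(state_ == nullptr && "PreTuning called twice without PostTuning");
    state_ = std::make_unique<SearchState>();
    state_->ed = num_trials_per_iter;
  }
  void PostTuning() {
    assert(state_ != nullptr && "PostTuning without a matching PreTuning");
    state_.reset();
  }
};

int main() {
  StrategySketch s;
  s.PreTuning(/*num_trials_per_iter=*/64);
  s.PostTuning();   // OK: matched pair
  s.PreTuning(64);  // OK again: the state was cleared
  s.PostTuning();
}
```
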
Please explicitly set `postprocs` to an empty list if you don't want to " + "apply any post-processing."; + } + this->rand_state_ = ForkSeed(&ctx->rand_state); + this->mod_ = ctx->mod; + this->space_generator_ = ctx->space_generator; this->state_.reset(); } - void PreTuning(const Array& design_spaces, const Optional& database, - const Optional& cost_model) final { - CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?"; - if (this->state_ != nullptr) { - TVM_PY_LOG(WARNING, this->context_->logging_func) << "ReplayFunc is already initialized."; - this->state_.reset(); - } - ICHECK(this->state_ == nullptr); - this->state_ = std::make_unique(this); + void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, + const Optional& database, const Optional& cost_model) final { + CHECK(this->state_ == nullptr) + << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`."; + this->state_ = std::make_unique(this, max_trials, num_trials_per_iter); } void PostTuning() final { - ICHECK(this->state_ != nullptr); + CHECK(this->state_ != nullptr) << "ValueError: `PostTuning` is invoked without corresponding " + "`PreTuning`, or `PostTuning` is already invoked."; this->state_.reset(); } @@ -103,32 +107,30 @@ class ReplayFuncNode : public SearchStrategyNode { SearchStrategy Clone() const final { ObjectPtr n = make_object(); - n->num_trials_per_iter = this->num_trials_per_iter; - n->max_trials_per_task = this->max_trials_per_task; - n->context_ = this->context_; - n->rand_state_ = this->rand_state_; - n->state_ = nullptr; // cleared the state + n->rand_state_ = -1; + n->mod_ = NullOpt; + n->space_generator_ = NullOpt; + n->state_ = nullptr; return SearchStrategy(n); } }; inline Optional> ReplayFuncNode::State::GenerateMeasureCandidates() { - if (st >= self->max_trials_per_task) { + if (st >= max_trials) { return NullOpt; } - ed = std::min(ed, self->max_trials_per_task); + ed = std::min(ed, max_trials); Array result; - const TuneContextNode* ctx = self->context_; - ICHECK(ctx); - IRModule mod = ctx->mod.value(); + IRModule mod = self->mod_.value(); + Array postprocs = self->space_generator_.value()->postprocs.value_or({}); for (int i = st; i < ed; i++) { for (;;) { - Array schs = ctx->space_generator.value()->GenerateDesignSpace(mod); + Array schs = self->space_generator_.value()->GenerateDesignSpace(mod); int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size()); tir::Schedule sch = schs[design_space_index]; sch->EnterPostproc(); bool failed = false; - for (const Postproc& proc : ctx->postprocs) { + for (const Postproc& proc : postprocs) { if (!proc->Apply(sch)) { failed = true; break; @@ -145,14 +147,12 @@ inline Optional> ReplayFuncNode::State::GenerateMeasureC } inline void ReplayFuncNode::State::NotifyRunnerResults(const Array& results) { - st += self->num_trials_per_iter; - ed += self->num_trials_per_iter; + st += num_trials_per_iter; + ed += num_trials_per_iter; } -SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int max_trials_per_task) { +SearchStrategy SearchStrategy::ReplayFunc() { ObjectPtr n = make_object(); - n->num_trials_per_iter = num_trials_per_iter; - n->max_trials_per_task = max_trials_per_task; return SearchStrategy(n); } diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index bd553bf037d1..d76ee220a858 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ 
b/src/meta_schedule/search_strategy/replay_trace.cc @@ -30,6 +30,10 @@ class ReplayTraceNode : public SearchStrategyNode { ReplayTraceNode* self; /*! \brief The design spaces. */ Array design_spaces; + /*! \brief The number of total trials. */ + int max_trials; + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int st; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ @@ -38,13 +42,17 @@ class ReplayTraceNode : public SearchStrategyNode { /*! \brief The module to be tuned. */ Array per_thread_mod_{nullptr}; - explicit State(ReplayTraceNode* self, Array design_spaces) - : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) { - const TuneContextNode* ctx = self->context_; - ICHECK(ctx); - IRModule mod = ctx->mod.value(); - this->per_thread_mod_.reserve(ctx->num_threads); - for (int i = 0; i < ctx->num_threads; i++) { + explicit State(ReplayTraceNode* self, Array design_spaces, int max_trials, + int num_trials_per_iter) + : self(self), + design_spaces(design_spaces), + max_trials(max_trials), + num_trials_per_iter(num_trials_per_iter), + st(0), + ed(num_trials_per_iter) { + IRModule mod = self->mod_.value(); + this->per_thread_mod_.reserve(self->num_threads_); + for (int i = 0; i < self->num_threads_; i++) { this->per_thread_mod_.push_back(DeepCopyIRModule(mod)); } } @@ -53,54 +61,61 @@ class ReplayTraceNode : public SearchStrategyNode { inline void NotifyRunnerResults(const Array& results); }; - /*! \brief The number of trials per iteration. */ - int num_trials_per_iter; - /*! \brief The number of total trials. */ - int max_trials_per_task; /*! \brief The max number of failures during trace replaying. */ int max_fail_count; - /*! \brief The tuning context of the search strategy. */ - const TuneContextNode* context_{nullptr}; /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; + /*! \brief The IRModule to be scheduled from TuneContext. */ + Optional mod_ = NullOpt; + /*! \brief The number of threads to be used. */ + int num_threads_ = -1; + /*! \brief The postprocessors. */ + Array postprocs_ = {}; /*! \brief The state of the search strategy. */ std::unique_ptr state_ = nullptr; void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("num_trials_per_iter", &num_trials_per_iter); - v->Visit("max_trials_per_task", &max_trials_per_task); v->Visit("max_fail_count", &max_fail_count); - // `context_` is not visited. // `rand_state_` is not visited + // `mod_` is not visited + // `num_threads_` is not visited + // `postprocs_` is not visited // `state_` is not visited } static constexpr const char* _type_key = "meta_schedule.ReplayTrace"; TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); - void InitializeWithTuneContext(const TuneContext& context) final { - CHECK(context->mod.defined()) << "ValueError: TuneContext.mod is not defined"; - this->context_ = context.get(); - this->rand_state_ = ForkSeed(&context->rand_state); + void InitializeWithTuneContext(const TuneContext& ctx) final { + CHECK(ctx->mod.defined()) << "ValueError: TuneContext.mod is not defined"; + CHECK(ctx->space_generator.defined()) + << "ValueError: TuneContext.space_generator is not defined"; + if (!ctx->space_generator.value()->postprocs.defined()) { + TVM_PY_LOG(WARNING, ctx->logger) + << "`postprocs` is not defined in " << ctx->space_generator.value() + << ". 
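
All three strategies now keep `max_trials` and `num_trials_per_iter` in their per-run `State` instead of on the node, with identical `[st, ed)` window bookkeeping. A standalone sketch of that arithmetic, including the clamp on the final batch:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  const int max_trials = 100, num_trials_per_iter = 32;
  int st = 0, ed = num_trials_per_iter;
  while (st < max_trials) {
    ed = std::min(ed, max_trials);  // clamp the last batch to the budget
    std::printf("measure trials [%d, %d)\n", st, ed);
    st += num_trials_per_iter;      // what NotifyRunnerResults does
    ed += num_trials_per_iter;
  }
  // prints [0, 32) [32, 64) [64, 96) [96, 100)
}
```
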
Please explicitly set `postprocs` to an empty list if you don't want to " + "apply any post-processing."; + } + this->rand_state_ = ForkSeed(&ctx->rand_state); + this->mod_ = ctx->mod; + this->num_threads_ = ctx->num_threads; + this->postprocs_ = ctx->space_generator.value()->postprocs.value_or({}); this->state_.reset(); } - void PreTuning(const Array& design_spaces, const Optional& database, - const Optional& cost_model) final { + void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, + const Optional& database, const Optional& cost_model) final { ICHECK(!design_spaces.empty()); - CHECK(this->context_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?"; - if (this->state_ != nullptr) { - TVM_PY_LOG(WARNING, this->context_->logging_func) << "RelayTrace is already initialized."; - this->state_.reset(); - } - ICHECK(this->state_ == nullptr); + CHECK(this->state_ == nullptr) + << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`."; Array design_space_traces; design_space_traces.reserve(design_spaces.size()); for (const tir::Schedule& space : design_spaces) { design_space_traces.push_back(space->trace().value()->Simplified(true)); } - this->state_ = std::make_unique(this, design_space_traces); + this->state_ = + std::make_unique(this, design_space_traces, max_trials, num_trials_per_iter); } void PostTuning() final { @@ -121,10 +136,7 @@ class ReplayTraceNode : public SearchStrategyNode { SearchStrategy Clone() const final { ObjectPtr n = make_object(); - n->num_trials_per_iter = this->num_trials_per_iter; - n->max_trials_per_task = this->max_trials_per_task; n->max_fail_count = this->max_fail_count; - n->context_ = this->context_; n->rand_state_ = this->rand_state_; n->state_ = nullptr; // cleared the state return SearchStrategy(n); @@ -132,16 +144,14 @@ class ReplayTraceNode : public SearchStrategyNode { }; inline Optional> ReplayTraceNode::State::GenerateMeasureCandidates() { - if (st >= self->max_trials_per_task) { + if (st >= max_trials) { return NullOpt; } - ed = std::min(ed, self->max_trials_per_task); + ed = std::min(ed, max_trials); ICHECK_LT(st, ed); - const TuneContextNode* ctx = self->context_; - ICHECK(ctx); - std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, ctx->num_threads); + std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); Array per_task_result(ed - st, MeasureCandidate{nullptr}); - ThreadedTraceApply pp(ctx->postprocs); + ThreadedTraceApply pp(self->postprocs_); auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id, int task_id) -> void { TRandState& rand_state = per_thread_rand_state[thread_id]; @@ -159,7 +169,7 @@ inline Optional> ReplayTraceNode::State::GenerateMeasure } } }; - support::parallel_for_dynamic(0, ed - st, ctx->num_threads, f_worker); + support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker); Array filtered; filtered.reserve(ed - st); for (MeasureCandidate result : per_task_result) @@ -170,15 +180,12 @@ inline Optional> ReplayTraceNode::State::GenerateMeasure } inline void ReplayTraceNode::State::NotifyRunnerResults(const Array& results) { - st += self->num_trials_per_iter; - ed += self->num_trials_per_iter; + st += num_trials_per_iter; + ed += num_trials_per_iter; } -SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int max_trials_per_task, - int max_fail_count) { +SearchStrategy SearchStrategy::ReplayTrace(int max_fail_count) { ObjectPtr n = make_object(); - 
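
`GenerateMeasureCandidates` above replays traces across worker threads via `support::parallel_for_dynamic`, giving each thread a forked random state so no locking is needed. A standalone sketch of that shape in plain C++17 `std::thread`; this is not TVM's implementation, just the dynamic index hand-out it implies:

```cpp
#include <atomic>
#include <functional>
#include <random>
#include <thread>
#include <vector>

void ParallelForDynamicSketch(int begin, int end, int num_threads,
                              const std::function<void(int thread_id, int task_id)>& f) {
  std::atomic<int> next{begin};  // shared counter hands out indices dynamically
  std::vector<std::thread> workers;
  for (int t = 0; t < num_threads; ++t) {
    workers.emplace_back([&, t] {
      for (int i = next.fetch_add(1); i < end; i = next.fetch_add(1)) f(t, i);
    });
  }
  for (auto& w : workers) w.join();
}

int main() {
  const int num_threads = 4;
  std::vector<std::mt19937> per_thread_rng;  // one "forked" seed per thread
  for (int t = 0; t < num_threads; ++t) per_thread_rng.emplace_back(1000 + t);
  std::vector<int> results(16);
  ParallelForDynamicSketch(0, 16, num_threads, [&](int tid, int i) {
    results[i] = static_cast<int>(per_thread_rng[tid]() % 100);  // thread-local RNG
  });
}
```
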
n->num_trials_per_iter = num_trials_per_iter; - n->max_trials_per_task = max_trials_per_task; n->max_fail_count = max_fail_count; return SearchStrategy(n); } diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index 81c7fda315b4..641457226d11 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -34,11 +34,12 @@ void PySearchStrategyNode::InitializeWithTuneContext(const TuneContext& context) f_initialize_with_tune_context(context); } -void PySearchStrategyNode::PreTuning(const Array& design_spaces, +void PySearchStrategyNode::PreTuning(int max_trials, int num_trials_per_iter, + const Array& design_spaces, const Optional& database, const Optional& cost_model) { ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!"; - f_pre_tuning(design_spaces, database, cost_model); + f_pre_tuning(max_trials, num_trials_per_iter, design_spaces, database, cost_model); } void PySearchStrategyNode::PostTuning() { diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 991e4fa08047..8eb2760dc791 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -89,31 +89,27 @@ class BlockCollector : public tir::StmtVisitor { * */ class PostOrderApplyNode : public SpaceGeneratorNode { public: - /*! \brief The random state. -1 means using random number. */ - TRandState rand_state_ = -1; - /*! \brief The schedule rules to be applied in order. */ - Array sch_rules_{nullptr}; - /*! \brief The logging function to use. */ - PackedFunc logging_func; - /*! \brief Optional block names to target. If not specified all blocks will have spaces generated. + /*! + * \brief Optional block names to target. If not specified all blocks will have spaces generated. */ runtime::PackedFunc f_block_filter_ = nullptr; + /*! \brief The random state. -1 means using random number. 
*/ + TRandState rand_state_ = -1; void VisitAttrs(tvm::AttrVisitor* v) { + SpaceGeneratorNode::VisitAttrs(v); // `rand_state_` is not visited // `sch_rules_` is not visited } void InitializeWithTuneContext(const TuneContext& context) final { + SpaceGeneratorNode::InitializeWithTuneContext(context); this->rand_state_ = ForkSeed(&context->rand_state); - CHECK(context->sch_rules.defined()) - << "ValueError: Schedules rules not given in PostOrderApply!"; - this->sch_rules_ = context->sch_rules; - this->logging_func = context->logging_func; } Array GenerateDesignSpace(const IRModule& mod) final { using ScheduleAndUnvisitedBlocks = std::pair>; + CHECK(sch_rules.defined()) << "ValueError: `sch_rules` is not set in PostOrderApply"; tir::Schedule sch = tir::Schedule::Traced( /*mod=*/mod, /*rand_state=*/ForkSeed(&this->rand_state_), @@ -126,7 +122,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { // always concat multiple schedule rules as one Array all_blocks = BlockCollector::Collect(sch, f_block_filter_); Array> rules{NullOpt}; - rules.insert(rules.end(), sch_rules_.begin(), sch_rules_.end()); + rules.insert(rules.end(), sch_rules.value().begin(), sch_rules.value().end()); for (Optional sch_rule : rules) { if (sch_rule.defined()) { for (const tir::Schedule& sch : result) { @@ -191,19 +187,22 @@ class PostOrderApplyNode : public SpaceGeneratorNode { SpaceGenerator Clone() const final { ObjectPtr n = make_object(*this); - n->sch_rules_ = Array(); - for (const ScheduleRule& sch_rule : this->sch_rules_) { - n->sch_rules_.push_back(sch_rule->Clone()); - } + CloneRules(this, n.get()); return SpaceGenerator(n); } static constexpr const char* _type_key = "meta_schedule.PostOrderApply"; TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); }; -SpaceGenerator SpaceGenerator::PostOrderApply(runtime::PackedFunc f_block_filter) { +SpaceGenerator SpaceGenerator::PostOrderApply(runtime::PackedFunc f_block_filter, + Optional> sch_rules, + Optional> postprocs, + Optional> mutator_probs) { ObjectPtr n = make_object(); - n->f_block_filter_ = f_block_filter; + n->sch_rules = std::move(sch_rules); + n->postprocs = std::move(postprocs); + n->mutator_probs = std::move(mutator_probs); + n->f_block_filter_ = std::move(f_block_filter); return SpaceGenerator(n); } diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index adea139b1cd4..48fbc82aba02 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -30,10 +30,12 @@ class ScheduleFnNode : public SpaceGeneratorNode { runtime::PackedFunc schedule_fn_; void VisitAttrs(tvm::AttrVisitor* v) { + SpaceGeneratorNode::VisitAttrs(v); // `schedule_fn_` is not visited. 
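
`PostOrderApply` now pulls `sch_rules` from the shared `SpaceGeneratorNode` base instead of a private copy, but the construction itself is unchanged: each rule is applied block by block, and one application may return several schedules, so the design space fans out rule by rule. A heavily simplified standalone sketch of that fan-out (the real pass re-collects blocks per schedule and threads a stack of unvisited blocks):

```cpp
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

using Sched = std::string;  // stand-in for tir::Schedule
using Rule = std::function<std::vector<Sched>(const Sched&, const std::string& block)>;

int main() {
  std::vector<std::string> blocks = {"A", "B"};
  std::vector<Rule> rules = {
      [](const Sched& s, const std::string& b) {  // a rule producing 2 variants
        return std::vector<Sched>{s + "+tile(" + b + ")", s + "+notile(" + b + ")"};
      },
      [](const Sched& s, const std::string& b) {  // a rule producing 1 variant
        return std::vector<Sched>{s + "+inline(" + b + ")"};
      },
  };
  std::vector<Sched> result = {"init"};
  for (const Rule& rule : rules) {
    for (const std::string& block : blocks) {
      std::vector<Sched> next;
      for (const Sched& sch : result)
        for (const Sched& out : rule(sch, block)) next.push_back(out);
      result.swap(next);
    }
  }
  std::printf("%zu design spaces\n", result.size());  // 2 (A) * 2 (B) * 1 * 1 = 4
}
```
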
} void InitializeWithTuneContext(const TuneContext& context) final { + SpaceGeneratorNode::InitializeWithTuneContext(context); this->rand_state_ = ForkSeed(&context->rand_state); } @@ -74,6 +76,7 @@ class ScheduleFnNode : public SpaceGeneratorNode { SpaceGenerator Clone() const final { ObjectPtr n = make_object(*this); + CloneRules(this, n.get()); return SpaceGenerator(n); } @@ -81,8 +84,14 @@ class ScheduleFnNode : public SpaceGeneratorNode { TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnNode, SpaceGeneratorNode); }; -SpaceGenerator SpaceGenerator::ScheduleFn(PackedFunc schedule_fn) { +SpaceGenerator SpaceGenerator::ScheduleFn(PackedFunc schedule_fn, + Optional> sch_rules, + Optional> postprocs, + Optional> mutator_probs) { ObjectPtr n = make_object(); + n->sch_rules = std::move(sch_rules); + n->postprocs = std::move(postprocs); + n->mutator_probs = std::move(mutator_probs); n->schedule_fn_ = std::move(schedule_fn); return SpaceGenerator(n); } diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 6fc31ed896f2..53107bafb2c0 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -21,6 +21,97 @@ namespace tvm { namespace meta_schedule { +String GetRuleKindFromTarget(const Target& target) { + if (target->kind->name == "llvm") { + return "llvm"; + } + if (target->kind->name == "hexagon") { + return "hexagon"; + } + if (target->kind->name == "cuda") { + if (Optional opt_sm = target->GetAttr("arch")) { + std::string sm = opt_sm.value(); + if (support::StartsWith(sm, "sm_")) { + sm = sm.substr(3); + try { + if (std::stoi(sm) >= 75) { + return "cuda_tensorcore"; + } + } catch (const std::invalid_argument& e) { + LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm + << ". 
Details: " << e.what(); + } + } + } + return "cuda"; + } + if (target->kind->name == "rocm") { + return "cuda"; + } + if (target->kind->name == "vulkan") { + return "cuda"; + } + LOG(FATAL) << "Unsupported target: " << target; + throw; +} + +void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { + if (context->target.defined() && // + !(sch_rules.defined() && // + postprocs.defined() && // + mutator_probs.defined())) { + String kind = GetRuleKindFromTarget(context->target.value()); + Array default_sch_rules; + Array default_postprocs; + Map default_mutator_probs; + if (kind == "llvm") { + default_sch_rules = ScheduleRule::DefaultLLVM(); + default_postprocs = Postproc::DefaultLLVM(); + default_mutator_probs = Mutator::DefaultLLVM(); + } else if (kind == "cuda") { + default_sch_rules = ScheduleRule::DefaultCUDA(); + default_postprocs = Postproc::DefaultCUDA(); + default_mutator_probs = Mutator::DefaultCUDA(); + } else if (kind == "cuda_tensorcore") { + default_sch_rules = ScheduleRule::DefaultCUDATensorCore(); + default_postprocs = Postproc::DefaultCUDATensorCore(); + default_mutator_probs = Mutator::DefaultCUDATensorCore(); + } else if (kind == "hexagon") { + default_sch_rules = ScheduleRule::DefaultHexagon(); + default_postprocs = Postproc::DefaultHexagon(); + default_mutator_probs = Mutator::DefaultHexagon(); + } else { + LOG(FATAL) << "Unsupported kind: " << kind; + throw; + } + if (!sch_rules.defined()) { + sch_rules = default_sch_rules; + } + if (!postprocs.defined()) { + postprocs = default_postprocs; + } + if (!mutator_probs.defined()) { + mutator_probs = default_mutator_probs; + } + } + if (sch_rules.defined()) { + for (ScheduleRule i : sch_rules.value()) { + i->InitializeWithTuneContext(context); + } + } + if (postprocs.defined()) { + for (Postproc i : postprocs.value()) { + i->InitializeWithTuneContext(context); + } + } + if (mutator_probs.defined()) { + for (const auto& kv : mutator_probs.value()) { + Mutator mutator = kv.first; + mutator->InitializeWithTuneContext(context); + } + } +} + void PySpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { ICHECK(f_initialize_with_tune_context != nullptr) << "PySpaceGenerator's InitializeWithTuneContext method not implemented!"; @@ -39,9 +130,14 @@ SpaceGenerator PySpaceGeneratorNode::Clone() const { } SpaceGenerator SpaceGenerator::PySpaceGenerator( + Optional> sch_rules, Optional> postprocs, + Optional> mutator_probs, FInitializeWithTuneContext f_initialize_with_tune_context, FGenerateDesignSpace f_generate_design_space, FClone f_clone) { ObjectPtr n = make_object(); + n->sch_rules = sch_rules; + n->postprocs = postprocs; + n->mutator_probs = mutator_probs; n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); n->f_generate_design_space = std::move(f_generate_design_space); n->f_clone = std::move(f_clone); diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index 771d0c187f97..819a4ee5f795 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -27,10 +27,13 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { /*! \brief The array of design space generators unioned, could be recursive. 
*/ Array space_generators; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("space_generators", &space_generators); } + void VisitAttrs(tvm::AttrVisitor* v) { + SpaceGeneratorNode::VisitAttrs(v); + v->Visit("space_generators", &space_generators); + } void InitializeWithTuneContext(const TuneContext& context) final { - // Initialize each space generator. + SpaceGeneratorNode::InitializeWithTuneContext(context); for (const SpaceGenerator& space_generator : space_generators) { space_generator->InitializeWithTuneContext(context); } @@ -53,6 +56,7 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { for (const SpaceGenerator& space_generator : this->space_generators) { n->space_generators.push_back(space_generator->Clone()); } + CloneRules(this, n.get()); return SpaceGenerator(n); } @@ -65,8 +69,14 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { * \param space_generators Array of the design space generators to be unioned. * \return The design space generator created. */ -SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array space_generators) { +SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array space_generators, + Optional> sch_rules, + Optional> postprocs, + Optional> mutator_probs) { ObjectPtr n = make_object(); + n->sch_rules = std::move(sch_rules); + n->postprocs = std::move(postprocs); + n->mutator_probs = std::move(mutator_probs); n->space_generators = std::move(space_generators); return SpaceGenerator(n); } diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index 506bb620e1d8..bae52573a0f9 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -21,236 +21,122 @@ namespace tvm { namespace meta_schedule { -struct TaskRecord { - TuneContext task; - double weight; - double flop; - std::vector best_time_cost_history; // in ms - int trials; -}; - /*! \brief The gradient based task scheduler. */ class GradientBasedNode final : public TaskSchedulerNode { public: - // Parameters used in gradient computation double alpha; int window_size; + support::LinearCongruentialEngine::TRandState rand_state; - std::vector task_records_; - std::vector best_time_cost_per_task_; // in ms - int num_rounds_already_; - support::LinearCongruentialEngine::TRandState rand_state_; + int round_robin_rounds_; + std::vector> best_latency_history_; void VisitAttrs(tvm::AttrVisitor* v) { TaskSchedulerNode::VisitAttrs(v); v->Visit("alpha", &alpha); v->Visit("window_size", &window_size); - // `task_records_` is not visited. - // `best_time_cost_per_task_` is not visited. + // `rand_state` is not visited. // `num_rounds_already_` is not visited. - // `rand_state_` is not visited. + // `best_latency_history_` is not visited. } static constexpr const char* _type_key = "meta_schedule.GradientBased"; TVM_DECLARE_FINAL_OBJECT_INFO(GradientBasedNode, TaskSchedulerNode); public: - std::string TuningStatistics() const { - std::ostringstream os; - int n_tasks = task_records_.size(); - int total_trials = 0; - double total_latency = 0.0; - support::TablePrinter p; - - if (using_ipython()) { - p.Row() << "ID" - << "Name" - << "FLOP" - << "Weight" - << "GFLOPS" - << "Latency (us)" - << "Wtd. 
Latency" - << "Trials" - << "Terminated"; - } else { - p.Row() << "ID" - << "Name" - << "FLOP" - << "Weight" - << "Speed (GFLOPS)" - << "Latency (us)" - << "Weighted Latency (us)" - << "Trials" - << "Terminated"; - } - - p.Separator(); - - for (int i = 0; i < n_tasks; ++i) { - const TaskRecord& record = task_records_[i]; - auto row = p.Row(); - int trials = record.trials; - String task_name = record.task->task_name.value(); - if (using_ipython() && task_name.length() > 23) { - std::string temp = task_name.c_str(); - temp = temp.substr(0, 20) + "..."; - task_name = String(temp); - } - row << /*id=*/i // - << /*name=*/task_name // - << /*flops=*/static_cast(record.flop) // - << /*weight=*/static_cast(record.weight); - double latency = 1e9; - if (trials > 0) { - latency = record.best_time_cost_history.back(); - } - if (latency >= 1e9) { - row << /*speed=*/"N/A" << /*latency=*/"N/A" << /*weighted_latency=*/"N/A"; - } else { - latency *= 1000.0; - double speed = record.flop / latency / 1000.0; - double weighted_latency = latency * record.weight; - row << /*speed=*/speed << /*latency=*/latency << /*weighted_latency=*/weighted_latency; - total_latency += weighted_latency; - total_trials += trials; - } - row << trials; - if (tasks[i]->is_terminated) { - row << "Y"; - } else { - row << ""; - } - } - p.Separator(); - os << p.AsStr() // - << "\nProgress: " << total_trials / (max_trials * 0.01) << "%" // - << "\nTotal Trials: " << total_trials << " / " << max_trials // - << "\nTotal latency (us): " << total_latency // - << "\n"; - return os.str(); + void Tune(Array tasks, Array task_weights, int max_trials_global, + int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, + Array measure_callbacks, Optional database, + Optional cost_model) final { + int n_tasks = tasks.size(); + round_robin_rounds_ = 0; + best_latency_history_.resize(n_tasks, std::vector()); + TaskSchedulerNode::Tune(tasks, task_weights, max_trials_global, max_trials_per_task, + num_trials_per_iter, builder, runner, measure_callbacks, database, + cost_model); } int NextTaskId() final { - int n_tasks = task_records_.size(); - // Round robin - if (num_rounds_already_ == 0) { - TVM_PY_LOG_CLEAR_SCREEN(this->logging_func); - TVM_PY_LOG(INFO, this->logging_func) << "\n" << this->TuningStatistics(); + int n_tasks = this->tasks_.size(); + // Step 1. Check if it's in round robin mode. + if (round_robin_rounds_ == 0) { + TVM_PY_LOG(INFO, this->logger) << "\n" << this->TuningStatistics(); } - if (num_rounds_already_ < n_tasks) { - return num_rounds_already_++; + if (round_robin_rounds_ < n_tasks) { + return round_robin_rounds_++; } - if (num_rounds_already_ == n_tasks) { + if (round_robin_rounds_ == n_tasks) { for (int i = 0; i < n_tasks; ++i) { this->JoinRunningTask(i); } + ++round_robin_rounds_; } - ++num_rounds_already_; - // Check running tasks + // Step 2. Collect the tasks that are not terminated yet std::vector tasks_alive; - tasks_alive.reserve(n_tasks); - for (int i = 0; i < n_tasks; ++i) { - this->TouchTask(i); - if (!tasks[i]->is_terminated) { - tasks_alive.push_back(i); + { + tasks_alive.reserve(n_tasks); + for (int i = 0; i < n_tasks; ++i) { + this->TouchTask(i); + if (!this->tasks_[i]->is_terminated) { + tasks_alive.push_back(i); + } + } + if (tasks_alive.empty()) { + return -1; } } - if (tasks_alive.empty()) { - return -1; - } + // Step 3. 
Calculate the gradient of each task alive std::vector grad; grad.reserve(n_tasks); for (int task_id : tasks_alive) { - const TaskRecord& record = task_records_[task_id]; - const int w = this->window_size; - int n = record.best_time_cost_history.size(); + const std::vector& best_latency = this->best_latency_history_.at(task_id); + int n = best_latency.size(); ICHECK_GE(n, 1); - double best = record.best_time_cost_history[n - 1]; + double task_weight = this->tasks_[task_id]->task_weight; + int w = this->window_size; + double best = best_latency[n - 1]; if (best < 1e9) { - double g1 = (n >= 1 + w) ? (record.best_time_cost_history[n - 1 - w] - best) / w : 0.0; + double g1 = (n >= 1 + w) ? (best_latency[n - 1 - w] - best) / w : 0.0; double g2 = best / n; double g = alpha * g1 + (1 - alpha) * g2; - grad.push_back(g * record.weight); + grad.push_back(g * task_weight); } else { // If the best time cost is unavailable, it means some task is not valid. Skip it. grad.push_back(-1e9); } } + // Step 4. Select the task with the largest gradient auto max_grad = std::max_element(grad.begin(), grad.end()); auto min_grad = std::min_element(grad.begin(), grad.end()); int task_id = -1; if (*max_grad == *min_grad) { - task_id = tasks_alive[tir::SampleInt(&rand_state_, 0, tasks_alive.size())]; + task_id = tasks_alive[tir::SampleInt(&this->rand_state, 0, tasks_alive.size())]; } else { task_id = tasks_alive[std::distance(grad.begin(), max_grad)]; } - if (tasks[task_id]->runner_futures.defined()) { + if (this->tasks_[task_id]->runner_futures.defined()) { JoinRunningTask(task_id); } return task_id; } Array JoinRunningTask(int task_id) final { - TaskRecord& record = task_records_[task_id]; Array results = TaskSchedulerNode::JoinRunningTask(task_id); - double& best_time_cost = this->best_time_cost_per_task_[task_id]; - for (const RunnerResult& result : results) { - if (!result->error_msg.defined()) { - best_time_cost = std::min(best_time_cost, GetRunMsMedian(result)); - } - } - record.best_time_cost_history.push_back(best_time_cost); - record.trials += results.size(); - TVM_PY_LOG_CLEAR_SCREEN(this->logging_func); - TVM_PY_LOG(INFO, this->logging_func) - << "[Updated] Task #" << task_id << ": " << record.task->task_name << "\n" - << this->TuningStatistics(); + TaskRecordNode* task = this->tasks_[task_id].get(); + this->best_latency_history_.at(task_id).push_back( + *std::min_element(task->latency_ms.begin(), // + task->latency_ms.end())); return results; } }; -TaskScheduler TaskScheduler::GradientBased(Array tasks, // - Array task_weights, // - Builder builder, // - Runner runner, // - Optional database, // - Optional cost_model, // - Optional> measure_callbacks, // - int max_trials, // - PackedFunc logging_func, // - double alpha, // - int window_size, // +TaskScheduler TaskScheduler::GradientBased(PackedFunc logger, double alpha, int window_size, support::LinearCongruentialEngine::TRandState seed) { - CHECK_EQ(tasks.size(), task_weights.size()) - << "The size of `tasks` should have the same as `task_weights`."; - int n_tasks = tasks.size(); - std::vector task_records; - task_records.reserve(n_tasks); - for (int i = 0; i < n_tasks; ++i) { - task_records.push_back(TaskRecord{ - /*task=*/tasks[i], - /*weights=*/task_weights[i]->value, - /*flop=*/std::max(1.0, tir::EstimateTIRFlops(tasks[i]->mod.value())), - /*best_time_cost_history=*/{}, - /*trials=*/0, - }); - } ObjectPtr n = make_object(); - n->tasks = tasks; - n->builder = builder; - n->runner = runner; - n->database = database; - n->max_trials = max_trials; - 
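
Step 3's gradient is the heart of the scheduler: per live task it blends the windowed recent improvement (`g1`) with an optimistic long-run rate (`g2`), scales by the task weight, and the task with the largest gradient is tuned next. A standalone sketch with the same formula and a worked comparison:

```cpp
#include <cstdio>
#include <vector>

double TaskGradient(const std::vector<double>& best_latency, double alpha, int w,
                    double task_weight) {
  int n = static_cast<int>(best_latency.size());
  double best = best_latency[n - 1];
  if (best >= 1e9) return -1e9;  // no valid measurement yet: de-prioritize
  double g1 = (n >= 1 + w) ? (best_latency[n - 1 - w] - best) / w : 0.0;  // recent slope
  double g2 = best / n;                                                   // long-run rate
  return (alpha * g1 + (1 - alpha) * g2) * task_weight;
}

int main() {
  // A task that improved from 10 ms to 6 ms over the last three rounds...
  std::vector<double> improving = {10.0, 9.0, 7.0, 6.0};
  // ...versus one that has been flat at 5 ms.
  std::vector<double> flat = {5.0, 5.0, 5.0, 5.0};
  std::printf("improving: %.3f, flat: %.3f\n",
              TaskGradient(improving, /*alpha=*/0.2, /*w=*/3, /*task_weight=*/1.0),
              TaskGradient(flat, 0.2, 3, 1.0));
  // improving: 1.467 > flat: 1.000, so the improving task is scheduled next
}
```
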
n->cost_model = cost_model; - n->measure_callbacks = measure_callbacks.value_or({}); - n->logging_func = logging_func; - n->num_trials_already = 0; + n->logger = logger; n->alpha = alpha; n->window_size = window_size; - n->task_records_ = std::move(task_records); - n->best_time_cost_per_task_ = std::vector(n_tasks, 1e100); - n->num_rounds_already_ = 0; - support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); + n->rand_state = support::LinearCongruentialEngine::NormalizeSeed(seed); return TaskScheduler(n); } diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index ea22878840af..d09f2c2ba791 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -37,13 +37,13 @@ class RoundRobinNode final : public TaskSchedulerNode { protected: int NextTaskId() final { - int n_tasks = this->tasks.size(); + int n_tasks = this->tasks_.size(); for (int i = 0; i < n_tasks; ++i) { this->TouchTask(i); } for (int i = 0; i < n_tasks; ++i) { task_id = (task_id + 1) % n_tasks; - TuneContext task = tasks[task_id]; + TaskRecordNode* task = this->tasks_[task_id].get(); if (!task->is_terminated) { if (task->runner_futures.defined()) { JoinRunningTask(task_id); @@ -55,24 +55,9 @@ class RoundRobinNode final : public TaskSchedulerNode { } }; -TaskScheduler TaskScheduler::RoundRobin(Array tasks, // - Builder builder, // - Runner runner, // - Optional database, // - Optional cost_model, // - Optional> measure_callbacks, // - int max_trials, // - PackedFunc logging_func) { +TaskScheduler TaskScheduler::RoundRobin(PackedFunc logger) { ObjectPtr n = make_object(); - n->tasks = tasks; - n->builder = builder; - n->runner = runner; - n->database = database; - n->max_trials = max_trials; - n->cost_model = cost_model; - n->measure_callbacks = measure_callbacks.value_or({}); - n->logging_func = logging_func; - n->num_trials_already = 0; + n->logger = logger; n->task_id = -1; return TaskScheduler(n); } diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index ea233648f4f5..21efde26d993 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -21,83 +21,225 @@ namespace tvm { namespace meta_schedule { -void TaskSchedulerNode::InitializeTask(int task_id) { +TaskRecord::TaskRecord(TuneContext ctx, double task_weight) { + ObjectPtr n = runtime::make_object(); + n->ctx = ctx; + n->task_weight = task_weight; + n->flop = 1.0; auto _ = Profiler::TimedScope("InitializeTask"); - TuneContext task = this->tasks[task_id]; - TVM_PY_LOG(INFO, this->logging_func) - << "Initializing Task #" << task_id << ": " << task->task_name; - TVM_PY_LOG(INFO, task->logging_func) - << "Initializing Task #" << task_id << ": " << task->task_name; - CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; - CHECK(task->space_generator.defined()) + CHECK(ctx->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; + CHECK(ctx->space_generator.defined()) << "ValueError: Require `context.space_generator`, but it is not defined"; - CHECK(task->search_strategy.defined()) + CHECK(ctx->search_strategy.defined()) << "ValueError: Require `context.search_strategy`, but it is not defined"; - TVM_PY_LOG(INFO, task->logging_func) << "\n" << tir::AsTVMScript(task->mod); - task->Initialize(); - Array design_spaces = - 
task->space_generator.value()->GenerateDesignSpace(task->mod.value()); - TVM_PY_LOG(INFO, task->logging_func) - << "Total " << design_spaces.size() << " design space(s) generated"; - for (int i = 0, n = design_spaces.size(); i < n; ++i) { - tir::Schedule sch = design_spaces[i]; - tir::Trace trace = sch->trace().value(); - trace = trace->Simplified(true); - TVM_PY_LOG(INFO, task->logging_func) << "Design space #" << i << ":\n" - << tir::AsTVMScript(sch->mod()) << "\n" - << Concat(trace->AsPython(false), "\n"); + TVM_PY_LOG(INFO, ctx->logger) << "\n" << tir::AsTVMScript(ctx->mod); + ctx->Initialize(); + n->flop = std::max(1.0, tir::EstimateTIRFlops(ctx->mod.value())); + this->data_ = std::move(n); +} + +void SendToBuilder(TaskRecordNode* self, const Builder& builder) { + auto _ = Profiler::TimedScope("SendToBuilder"); + Array candidates = self->measure_candidates.value(); + Target target = self->ctx->target.value(); + Array inputs; + inputs.reserve(candidates.size()); + for (const MeasureCandidate& candidate : candidates) { + inputs.push_back(BuilderInput(candidate->sch->mod(), target)); } - task->search_strategy.value()->PreTuning(design_spaces, database, cost_model); + self->builder_results = builder->Build(inputs); } -void TaskSchedulerNode::Tune() { - int n_tasks = this->tasks.size(); - for (int task_id = 0; task_id < n_tasks; ++task_id) { - InitializeTask(task_id); +void SendToRunner(TaskRecordNode* self, const Runner& runner) { + auto _ = Profiler::TimedScope("SendToRunner"); + Array candidates = self->measure_candidates.value(); + Array builder_results = self->builder_results.value(); + Target target = self->ctx->target.value(); + ICHECK_EQ(candidates.size(), builder_results.size()); + int n = candidates.size(); + int n_build_errors = 0; + Array inputs; + inputs.reserve(n); + for (int i = 0; i < n; ++i) { + const MeasureCandidate& candidate = candidates[i]; + const BuilderResult& builder_result = builder_results[i]; + if (builder_result->error_msg.defined()) { + ++n_build_errors; + continue; + } + inputs.push_back(RunnerInput(/*artifact_path=*/builder_result->artifact_path.value(), + /*device_type=*/target->kind->name, + /*args_info=*/candidate->args_info)); + } + Array futures = runner->Run(inputs); + if (n_build_errors == 0) { + self->runner_futures = futures; + return; + } + Array results; + results.reserve(n); + for (int i = 0, j = 0; i < n; ++i) { + const BuilderResult& builder_result = builder_results[i]; + if (builder_result->error_msg.defined()) { + results.push_back(RunnerFuture( + /*f_done=*/[]() -> bool { return true; }, + /*f_result=*/ + [msg = builder_result->error_msg]() -> RunnerResult { + return RunnerResult(NullOpt, msg); + })); + } else { + results.push_back(futures[j++]); + } + } + self->runner_futures = results; +} + +void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& results) { + ICHECK_EQ(self->builder_results.value().size(), results.size()); + ICHECK_EQ(self->runner_futures.value().size(), results.size()); + int n = results.size(); + std::string name = self->ctx->task_name.value(); + const PackedFunc& logger = self->ctx->logger; + for (int i = 0; i < n; ++i) { + const BuilderResult& builder_result = self->builder_results.value()[i]; + const MeasureCandidate& candidate = self->measure_candidates.value()[i]; + const RunnerResult& runner_result = results[i]; + Optional error_msg = NullOpt; + int trials = self->latency_ms.size() + 1; + double run_ms = 1e9; + if ((error_msg = builder_result->error_msg)) { + ++self->build_error_count; + } else if 
((error_msg = runner_result->error_msg)) { + ++self->run_error_count; + } else { + run_ms = GetRunMsMedian(runner_result); + } + self->latency_ms.push_back(run_ms); + if (error_msg) { + const tir::Schedule& sch = candidate->sch; + std::string err = error_msg.value(); + TVM_PY_LOG(INFO, logger) << std::fixed << std::setprecision(4) // + << "[Task #" << task_id << ": " << name << "] Trial #" << trials + << ": Error in building:\n" + << err << "\n" + << tir::AsTVMScript(sch->mod()) << "\n" + << Concat(sch->trace().value()->AsPython(false), "\n"); + } else { + double best_ms = *std::min_element(self->latency_ms.begin(), self->latency_ms.end()); + TVM_PY_LOG(INFO, logger) << std::fixed << std::setprecision(4) // + << "[Task #" << task_id << ": " << name << "] Trial #" << trials + << ": GFLOPs: " << (self->flop / run_ms / 1e6) + << ". Time: " << (run_ms * 1e3) << " us" + << ". Best GFLOPs: " << (self->flop / best_ms / 1e6); + } } - int running_tasks = tasks.size(); - for (int task_id; num_trials_already < max_trials && (task_id = NextTaskId()) != -1;) { - TVM_PY_LOG(INFO, this->logging_func) - << "Scheduler picks Task #" << task_id << ": " << tasks[task_id]->task_name; - TuneContext task = tasks[task_id]; + self->measure_candidates = NullOpt; + self->builder_results = NullOpt; + self->runner_futures = NullOpt; +} + +void TaskSchedulerNode::Tune(Array ctxs, Array task_weights, + int max_trials_global, int max_trials_per_task, + int num_trials_per_iter, Builder builder, Runner runner, + Array measure_callbacks, Optional database, + Optional cost_model) { + CHECK_EQ(ctxs.size(), task_weights.size()) << "ValueError: `task_weights` must have the same " + "length as `ctxs`"; + int n_tasks = this->remaining_tasks_ = ctxs.size(); + this->measure_callbacks_ = measure_callbacks; + this->database_ = database; + this->cost_model_ = cost_model; + this->tasks_.clear(); + this->tasks_.reserve(n_tasks); + for (int i = 0; i < n_tasks; ++i) { + const TuneContext& ctx = ctxs[i]; + double weight = task_weights[i]->value; + TVM_PY_LOG(INFO, this->logger) << "Initializing Task #" << i << ": " << ctx->task_name; + TVM_PY_LOG(INFO, ctx->logger) << "Initializing Task #" << i << ": " << ctx->task_name; + this->tasks_.push_back(TaskRecord(ctx, weight)); + Array design_spaces = + ctx->space_generator.value()->GenerateDesignSpace(ctx->mod.value()); + TVM_PY_LOG(INFO, ctx->logger) << "Total " << design_spaces.size() + << " design space(s) generated"; + for (int i = 0, n = design_spaces.size(); i < n; ++i) { + tir::Schedule sch = design_spaces[i]; + tir::Trace trace = sch->trace().value(); + trace = trace->Simplified(true); + TVM_PY_LOG(INFO, ctx->logger) << "Design space #" << i << ":\n" + << tir::AsTVMScript(sch->mod()) << "\n" + << Concat(trace->AsPython(false), "\n"); + } + ctx->search_strategy.value()->PreTuning(max_trials_per_task, num_trials_per_iter, design_spaces, + database, cost_model); + } + + int num_trials_already = 0; + for (int task_id; num_trials_already < max_trials_global && (task_id = NextTaskId()) != -1;) { + TVM_PY_LOG(INFO, this->logger) + << "TaskScheduler picks Task #" << task_id << ": " << tasks_[task_id]->ctx->task_name; + TaskRecordNode* task = tasks_[task_id].get(); ICHECK(!task->is_terminated); ICHECK(!task->runner_futures.defined()); - if (Optional> candidates = - task->search_strategy.value()->GenerateMeasureCandidates()) { + if (static_cast(task->latency_ms.size()) >= max_trials_per_task) { + TerminateTask(task_id); + continue; + } + if (Optional> candidates = task->measure_candidates = + 
task->ctx->search_strategy.value()->GenerateMeasureCandidates()) { int num_candidates = candidates.value().size(); - task->_SetMeasureCandidates(candidates.value()); num_trials_already += num_candidates; - TVM_PY_LOG(INFO, this->logging_func) - << "Sending " << num_candidates << " sample(s) to builder"; - task->_SendToBuilder(this->builder); - TVM_PY_LOG(INFO, this->logging_func) - << "Sending " << num_candidates << " sample(s) to runner"; - task->_SendToRunner(this->runner); + TVM_PY_LOG(INFO, this->logger) << "Sending " << num_candidates << " sample(s) to builder"; + SendToBuilder(task, builder); + TVM_PY_LOG(INFO, this->logger) << "Sending " << num_candidates << " sample(s) to runner"; + SendToRunner(task, runner); } else { - ICHECK(!task->is_terminated); - task->is_terminated = true; - --running_tasks; - TVM_PY_LOG(INFO, this->logging_func) - << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks; + TerminateTask(task_id); } } for (int task_id = 0; task_id < n_tasks; ++task_id) { - TuneContext task = tasks[task_id]; + TaskRecordNode* task = this->tasks_[task_id].get(); if (!task->is_terminated) { if (task->runner_futures.defined()) { JoinRunningTask(task_id); } - task->is_terminated = true; - --running_tasks; - TVM_PY_LOG(INFO, this->logging_func) - << "Task #" << task_id << " has finished. Remaining task(s): " << running_tasks; + TerminateTask(task_id); } - task->search_strategy.value()->PostTuning(); + task->ctx->search_strategy.value()->PostTuning(); } } +Array TaskSchedulerNode::JoinRunningTask(int task_id) { + TaskRecordNode* task = this->tasks_[task_id].get(); + ICHECK(task->runner_futures.defined()); + Array results; + { + auto _ = Profiler::TimedScope("JoinRunnerFutures"); + Array futures = task->runner_futures.value(); + results.reserve(futures.size()); + for (RunnerFuture future : futures) { + results.push_back(future->Result()); + } + } + ICHECK(task->measure_candidates.defined()); + task->ctx->search_strategy.value()->NotifyRunnerResults(task->measure_candidates.value(), + results); + ICHECK(task->builder_results.defined()); + ICHECK_EQ(results.size(), task->measure_candidates.value().size()); + ICHECK_EQ(results.size(), task->builder_results.value().size()); + for (const MeasureCallback& callback : this->measure_callbacks_) { + callback->Apply(GetRef(this), task_id, task->measure_candidates.value(), + task->builder_results.value(), results); + } + TaskCleanUp(task, task_id, results); + TVM_PY_LOG_CLEAR_SCREEN(this->logger); + TVM_PY_LOG(INFO, this->logger) << "[Updated] Task #" << task_id << ": " << task->ctx->task_name + << "\n" + << this->TuningStatistics(); + return results; +} + void TaskSchedulerNode::TouchTask(int task_id) { - TuneContext task = tasks[task_id]; + TaskRecordNode* task = this->tasks_[task_id].get(); if (!task->is_terminated && task->runner_futures.defined()) { for (const RunnerFuture future : task->runner_futures.value()) { if (!future->Done()) { @@ -108,39 +250,85 @@ void TaskSchedulerNode::TouchTask(int task_id) { } } -Array TaskSchedulerNode::JoinRunningTask(int task_id) { - TuneContext task = tasks[task_id]; - Array results = task->_Join(); - for (const MeasureCallback& callback : this->measure_callbacks) { - callback->Apply(GetRef(this), task_id, task->measure_candidates.value(), - task->builder_results.value(), results); - } - task->_ClearMeasureState(); - return results; +void TaskSchedulerNode::TerminateTask(int task_id) { + TaskRecordNode* task = this->tasks_[task_id].get(); + ICHECK(!task->is_terminated); + 
task->is_terminated = true; + --this->remaining_tasks_; + TVM_PY_LOG_CLEAR_SCREEN(this->logger); + TVM_PY_LOG(INFO, this->logger) << "Task #" << task_id + << " has finished. Remaining task(s): " << this->remaining_tasks_ + << "\n" + << this->TuningStatistics(); } -void PyTaskSchedulerNode::Tune() { - if (f_tune == nullptr) { - TaskSchedulerNode::Tune(); - } else { - f_tune(); +std::string TaskSchedulerNode::TuningStatistics() const { + std::ostringstream os; + int n_tasks = this->tasks_.size(); + int total_trials = 0; + double total_latency = 0.0; + support::TablePrinter p; + p.Row() << "ID" + << "Name" + << "FLOP" + << "Weight" + << "Speed (GFLOPS)" + << "Latency (us)" + << "Weighted Latency (us)" + << "Trials" + << "Done"; + p.Separator(); + for (int i = 0; i < n_tasks; ++i) { + const TaskRecordNode* task = this->tasks_[i].get(); + auto row = p.Row(); + int trials = task->latency_ms.size(); + row << /*id=*/i << /*name=*/task->ctx->task_name.value() // + << /*flops=*/static_cast(task->flop) + << /*weight=*/static_cast(task->task_weight); + double latency_ms = 1e9; + if (!task->latency_ms.empty()) { + latency_ms = *std::min_element(task->latency_ms.begin(), task->latency_ms.end()); + } + if (latency_ms >= 1e9) { + row << /*speed=*/"N/A" << /*latency=*/"N/A" << /*weighted_latency=*/"N/A"; + } else { + latency_ms *= 1000.0; + double speed = task->flop / latency_ms / 1000.0; + double weighted_latency = latency_ms * task->task_weight; + row << /*speed=*/speed << /*latency=*/latency_ms << /*weighted_latency=*/weighted_latency; + total_latency += weighted_latency; + total_trials += trials; + } + row << trials; + if (task->is_terminated) { + row << "Y"; + } else { + row << ""; + } } + p.Separator(); + os << p.AsStr() // + << "\nTotal trials: " << total_trials // + << "\nTotal latency (us): " << total_latency // + << "\n"; + return os.str(); } -void PyTaskSchedulerNode::InitializeTask(int task_id) { - if (f_initialize_task == nullptr) { - TaskSchedulerNode::InitializeTask(task_id); - } else { - f_initialize_task(task_id); - } +TaskScheduler TaskScheduler::PyTaskScheduler( + PackedFunc logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, + PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune) { + CHECK(f_next_task_id != nullptr) << "ValueError: next_task_id is not defined"; + ObjectPtr n = make_object(); + n->logger = logger; + n->f_next_task_id = f_next_task_id; + n->f_join_running_task = f_join_running_task; + n->f_tune = f_tune; + return TaskScheduler(n); } -void PyTaskSchedulerNode::TouchTask(int task_id) { - if (f_touch_task == nullptr) { - return TaskSchedulerNode::TouchTask(task_id); - } else { - return f_touch_task(task_id); - } +int PyTaskSchedulerNode::NextTaskId() { + CHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!"; + return f_next_task_id(); } Array PyTaskSchedulerNode::JoinRunningTask(int task_id) { @@ -151,61 +339,38 @@ Array PyTaskSchedulerNode::JoinRunningTask(int task_id) { } } -int PyTaskSchedulerNode::NextTaskId() { - ICHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!"; - return f_next_task_id(); -} - -TaskScheduler TaskScheduler::PyTaskScheduler( - Array tasks, // - Builder builder, // - Runner runner, // - Optional database, // - Optional cost_model, // - Optional> measure_callbacks, // - int max_trials, // - PackedFunc logging_func, // - PyTaskSchedulerNode::FTune f_tune, // - PyTaskSchedulerNode::FInitializeTask f_initialize_task, // - 
PyTaskSchedulerNode::FTouchTask f_touch_task, // - PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, // - PyTaskSchedulerNode::FNextTaskId f_next_task_id) { - ObjectPtr n = make_object(); - n->tasks = tasks; - n->builder = builder; - n->runner = runner; - n->database = database; - n->max_trials = max_trials; - n->cost_model = cost_model; - if (measure_callbacks.defined()) { - n->measure_callbacks = measure_callbacks.value(); +void PyTaskSchedulerNode::Tune(Array tasks, Array task_weights, + int max_trials_global, int max_trials_per_task, + int num_trials_per_iter, Builder builder, Runner runner, + Array measure_callbacks, + Optional database, Optional cost_model) { + if (f_tune == nullptr) { + TaskSchedulerNode::Tune(tasks, task_weights, max_trials_global, max_trials_per_task, + num_trials_per_iter, builder, runner, measure_callbacks, database, + cost_model); } else { - n->measure_callbacks = {}; + f_tune(tasks, task_weights, max_trials_global, max_trials_per_task, num_trials_per_iter, + builder, runner, measure_callbacks, database, cost_model); } - n->logging_func = logging_func; - n->num_trials_already = 0; - n->f_tune = f_tune; - n->f_initialize_task = f_initialize_task; - n->f_touch_task = f_touch_task; - n->f_join_running_task = f_join_running_task; - n->f_next_task_id = f_next_task_id; - return TaskScheduler(n); } +TVM_REGISTER_NODE_TYPE(TaskRecordNode); TVM_REGISTER_OBJECT_TYPE(TaskSchedulerNode); TVM_REGISTER_NODE_TYPE(PyTaskSchedulerNode); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPyTaskScheduler") .set_body_typed(TaskScheduler::PyTaskScheduler); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune") .set_body_method(&TaskSchedulerNode::Tune); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerInitializeTask") - .set_body_method(&TaskSchedulerNode::InitializeTask); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask") - .set_body_method(&TaskSchedulerNode::TouchTask); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask") .set_body_method(&TaskSchedulerNode::JoinRunningTask); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId") .set_body_method(&TaskSchedulerNode::NextTaskId); +TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTerminateTask") + .set_body_method(&TaskSchedulerNode::TerminateTask); +TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask") + .set_body_method(&TaskSchedulerNode::TouchTask); +TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTuningStatistics") + .set_body_method(&TaskSchedulerNode::TuningStatistics); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index ee24624fe9e4..768c95857184 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -23,58 +23,32 @@ namespace tvm { namespace meta_schedule { -TuneContext::TuneContext(Optional mod, // - Optional target, // - Optional space_generator, // - Optional search_strategy, // - Optional> sch_rules, // - Optional> postprocs, // - Optional> mutator_probs, // - Optional task_name, // - PackedFunc logging_func, // - support::LinearCongruentialEngine::TRandState rand_state, // - int num_threads) { +TuneContext::TuneContext(Optional mod, Optional target, + Optional space_generator, + Optional search_strategy, Optional task_name, + int num_threads, TRandState rand_state, PackedFunc logger) { + CHECK(rand_state == -1 || rand_state >= 0) << "ValueError: Invalid random state: " << rand_state; ObjectPtr n = make_object(); n->mod = mod; n->target = target; 
n->space_generator = space_generator; n->search_strategy = search_strategy; - n->sch_rules = sch_rules.value_or({}); - n->postprocs = postprocs.value_or({}); - n->mutator_probs = mutator_probs.value_or({}); n->task_name = task_name; - n->logging_func = logging_func; - support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state); n->num_threads = num_threads; - n->is_terminated = false; - n->runner_futures = NullOpt; - n->measure_candidates = NullOpt; + n->rand_state = support::LinearCongruentialEngine::NormalizeSeed(rand_state); + n->logger = logger; data_ = std::move(n); } TuneContext TuneContextNode::Clone() const { ObjectPtr n = make_object(*this); - if (this->sch_rules.defined()) { - n->sch_rules = Array(); - for (const ScheduleRule& sch_rule : this->sch_rules) { - n->sch_rules.push_back(sch_rule->Clone()); - } - } - if (this->postprocs.defined()) { - n->postprocs = Array(); - for (const Postproc& postproc : this->postprocs) { - n->postprocs.push_back(postproc->Clone()); - } + if (this->space_generator.defined()) { + n->space_generator = this->space_generator.value()->Clone(); } - if (this->mutator_probs.defined()) { - n->mutator_probs = Map(); - for (const auto& kv : this->mutator_probs) { - n->mutator_probs.Set(kv.first->Clone(), kv.second); - } + if (this->search_strategy.defined()) { + n->search_strategy = this->search_strategy.value()->Clone(); } - if (this->space_generator.defined()) n->space_generator = this->space_generator.value()->Clone(); - if (this->search_strategy.defined()) n->search_strategy = this->search_strategy.value()->Clone(); - n->rand_state = support::LinearCongruentialEngine(&n->rand_state).ForkSeed(); + n->rand_state = ForkSeed(&n->rand_state); n->Initialize(); return TuneContext(n); } @@ -86,136 +60,22 @@ void TuneContextNode::Initialize() { if (this->search_strategy.defined()) { this->search_strategy.value()->InitializeWithTuneContext(GetRef(this)); } - for (const ScheduleRule& sch_rule : sch_rules) { - sch_rule->InitializeWithTuneContext(GetRef(this)); - } - for (const Postproc& postproc : postprocs) { - postproc->InitializeWithTuneContext(GetRef(this)); - } - for (const auto& kv : mutator_probs) { - kv.first->InitializeWithTuneContext(GetRef(this)); - } -} - -void TuneContextNode::_SetMeasureCandidates(const Array& candidates) { - this->measure_candidates = candidates; -} - -void TuneContextNode::_SendToBuilder(const Builder& builder) { - auto _ = Profiler::TimedScope("SendToBuilder"); - Array candidates = this->measure_candidates.value(); - Target target = this->target.value(); - Array inputs; - inputs.reserve(candidates.size()); - for (const MeasureCandidate& candidate : candidates) { - inputs.push_back(BuilderInput(candidate->sch->mod(), target)); - } - this->builder_results = builder->Build(inputs); -} - -void TuneContextNode::_SendToRunner(const Runner& runner) { - auto _ = Profiler::TimedScope("SendToRunner"); - Array candidates = this->measure_candidates.value(); - Array builder_results = this->builder_results.value(); - Target target = this->target.value(); - ICHECK_EQ(candidates.size(), builder_results.size()); - int n = candidates.size(); - int n_build_errors = 0; - Array inputs; - inputs.reserve(n); - for (int i = 0; i < n; ++i) { - const MeasureCandidate& candidate = candidates[i]; - const BuilderResult& builder_result = builder_results[i]; - if (builder_result->error_msg.defined()) { - ++n_build_errors; - continue; - } - inputs.push_back(RunnerInput(/*artifact_path=*/builder_result->artifact_path.value(), - 
/*device_type=*/target->kind->name, - /*args_info=*/candidate->args_info)); - } - Array futures = runner->Run(inputs); - if (n_build_errors == 0) { - this->runner_futures = futures; - return; - } - Array results; - results.reserve(n); - for (int i = 0, j = 0; i < n; ++i) { - const BuilderResult& builder_result = builder_results[i]; - if (builder_result->error_msg.defined()) { - results.push_back(RunnerFuture( - /*f_done=*/[]() -> bool { return true; }, - /*f_result=*/ - [msg = builder_result->error_msg]() -> RunnerResult { - return RunnerResult(NullOpt, msg); - })); - } else { - results.push_back(futures[j++]); - } - } - this->runner_futures = results; -} - -Array TuneContextNode::_Join() { - ICHECK(this->runner_futures.defined()); - Array futures = this->runner_futures.value(); - int n = futures.size(); - Array results; - { - auto _ = Profiler::TimedScope("JoinRunnerFutures"); - results.reserve(n); - for (RunnerFuture future : futures) { - results.push_back(future->Result()); - } - } - if (this->search_strategy.defined()) { - this->search_strategy.value()->NotifyRunnerResults(this->measure_candidates.value(), results); - } - ICHECK(this->measure_candidates.defined()); - ICHECK(this->builder_results.defined()); - ICHECK_EQ(results.size(), this->measure_candidates.value().size()); - ICHECK_EQ(results.size(), this->builder_results.value().size()); - return results; -} - -void TuneContextNode::_ClearMeasureState() { - this->measure_candidates = NullOpt; - this->builder_results = NullOpt; - this->runner_futures = NullOpt; } TVM_REGISTER_NODE_TYPE(TuneContextNode); - TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") - .set_body_typed([](Optional mod, // - Optional target, // - Optional space_generator, // - Optional search_strategy, // - Optional> sch_rules, // - Optional> postprocs, // - Optional> mutator_probs, // - Optional task_name, // - PackedFunc logging_func, // - support::LinearCongruentialEngine::TRandState rand_state, // - int num_threads) -> TuneContext { - return TuneContext(mod, target, space_generator, search_strategy, sch_rules, postprocs, - mutator_probs, task_name, logging_func, rand_state, num_threads); + .set_body_typed([](Optional mod, Optional target, + Optional space_generator, + Optional search_strategy, Optional task_name, + int num_threads, TRandState rand_state, PackedFunc logger) -> TuneContext { + return TuneContext(mod, target, space_generator, search_strategy, task_name, num_threads, + rand_state, logger); }); - TVM_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex); TVM_REGISTER_GLOBAL("meta_schedule.TuneContextInitialize") .set_body_method(&TuneContextNode::Initialize); -TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSetMeasureCandidates") - .set_body_method(&TuneContextNode::_SetMeasureCandidates); -TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSendToBuilder") - .set_body_method(&TuneContextNode::_SendToBuilder); -TVM_REGISTER_GLOBAL("meta_schedule.TuneContextSendToRunner") - .set_body_method(&TuneContextNode::_SendToRunner); -TVM_REGISTER_GLOBAL("meta_schedule.TuneContextJoin") - .set_body_method(&TuneContextNode::_Join); -TVM_REGISTER_GLOBAL("meta_schedule.TuneContextClearMeasureState") - .set_body_method(&TuneContextNode::_ClearMeasureState); +TVM_REGISTER_GLOBAL("meta_schedule.TuneContextClone") + .set_body_method(&TuneContextNode::Clone); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index f0b736081670..41d8ffde558c 100644 --- a/src/meta_schedule/utils.h +++ 
b/src/meta_schedule/utils.h @@ -44,6 +44,7 @@ #include #include +#include #include #include "../printer/text_printer.h" @@ -55,8 +56,8 @@ #include "../tir/schedule/primitive.h" #include "../tir/schedule/utils.h" -#define TVM_PY_LOG(logging_level, logging_func) \ - ::tvm::meta_schedule::PyLogMessage(__FILE__, __LINE__, logging_func, \ +#define TVM_PY_LOG(logging_level, logger) \ + ::tvm::meta_schedule::PyLogMessage(__FILE__, __LINE__, logger, \ PyLogMessage::Level::logging_level) \ .stream() #define TVM_PY_LOG_CLEAR_SCREEN(logging_func) clear_logging(__FILE__, __LINE__, logging_func) @@ -81,14 +82,18 @@ class PyLogMessage { // FATAL not included }; - explicit PyLogMessage(const char* file, int lineno, PackedFunc logging_func, Level logging_level) - : file_(file), lineno_(lineno), logging_func_(logging_func), logging_level_(logging_level) {} + explicit PyLogMessage(const char* file, int lineno, PackedFunc logger, Level logging_level) + : file_(file), lineno_(lineno), logger_(logger), logging_level_(logging_level) { + if (this->logger_ != nullptr) { + stream_ << "" << file_ << ":" << lineno_ << " "; + } + } TVM_NO_INLINE ~PyLogMessage() { ICHECK(logging_level_ != Level::CLEAR) << "Cannot use CLEAR as logging level in TVM_PY_LOG, please use TVM_PY_LOG_CLEAR_SCREEN."; - if (this->logging_func_.defined()) { - logging_func_(static_cast(logging_level_), stream_.str()); + if (this->logger_ != nullptr) { + logger_(static_cast(logging_level_), stream_.str()); } else { if (logging_level_ == Level::INFO) { runtime::detail::LogMessage(file_, lineno_).stream() << stream_.str(); @@ -109,7 +114,7 @@ class PyLogMessage { const char* file_; int lineno_; std::ostringstream stream_; - PackedFunc logging_func_; + PackedFunc logger_; Level logging_level_; }; @@ -120,7 +125,9 @@ class PyLogMessage { inline bool using_ipython() { bool flag = false; const auto* f_using_ipython = runtime::Registry::Get("meta_schedule.using_ipython"); - if (f_using_ipython->defined()) flag = (*f_using_ipython)(); + if (f_using_ipython) { + flag = (*f_using_ipython)(); + } return flag; } @@ -459,6 +466,40 @@ struct SortTuningRecordByMeanRunSecs { } }; +/*! + * \brief The helper function to clone schedule rules, postprocessors, and mutators. + * \param src The source space generator. + * \param dst The destination space generator. 
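+ * \note Each defined rule set is deep-copied via its Clone() method, so the
+ *       destination space generator owns state independent of `src`.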
+ */
+inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) {
+  if (src->sch_rules.defined()) {
+    Array<ScheduleRule> original = src->sch_rules.value();
+    Array<ScheduleRule> sch_rules;
+    sch_rules.reserve(original.size());
+    for (const ScheduleRule& sch_rule : original) {
+      sch_rules.push_back(sch_rule->Clone());
+    }
+    dst->sch_rules = std::move(sch_rules);
+  }
+  if (src->postprocs.defined()) {
+    Array<Postproc> original = src->postprocs.value();
+    Array<Postproc> postprocs;
+    postprocs.reserve(original.size());
+    for (const Postproc& postproc : original) {
+      postprocs.push_back(postproc->Clone());
+    }
+    dst->postprocs = std::move(postprocs);
+  }
+  if (src->mutator_probs.defined()) {
+    Map<Mutator, FloatImm> original = src->mutator_probs.value();
+    Map<Mutator, FloatImm> mutator_probs;
+    for (const auto& kv : original) {
+      mutator_probs.Set(kv.first->Clone(), kv.second);
+    }
+    dst->mutator_probs = std::move(mutator_probs);
+  }
+}
+
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 8fa8610c0fca..b4373c6f5f1e 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -547,7 +547,7 @@ TECompiler& TECompiler::Global() {
 }
 TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule_dispatch", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule_dispatch", Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.tir_converter", String);

 TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() {
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 6f55402baded..27738615c7eb 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -48,6 +48,7 @@
 #include
 #include

+#include "../../printer/text_printer.h"
 #include "../../te/operation/create_primfunc.h"
 #include "../op/memory/memory.h"
 #include "../transforms/meta_schedule_layout_rewrite.h"
@@ -387,7 +388,18 @@ class ScheduleBuilder : public ExprVisitor {
       mod = tir::transform::RemoveWeightLayoutRewriteBlock()(std::move(mod));
       prim_func = Downcast<tir::PrimFunc>(mod->Lookup("main"));
     } else {
-      LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint;
+      int dispatch = backend::UseMetaScheduleDispatch();
+      // (dispatch & 2): controls whether to print TVMScript for missing TIR
+      // (dispatch & 4): controls whether to raise fatal errors for missing TIR
+      if (dispatch & 2) {
+        LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint << "\n"
+                     << tir::AsTVMScript(f.value());
+      } else {
+        LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint;
+      }
+      if (dispatch & 4) {
+        LOG(FATAL);
+      }
     }
   }
 }
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index 6c65a081f156..00c75921f2f2 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -632,6 +632,13 @@ inline bool IsMetaScheduleEnabled() {
       .value();
 }

+/*! \brief Consider MetaSchedule's dispatch option. */
+inline int UseMetaScheduleDispatch() {
+  return transform::PassContext::Current()
+      ->GetConfig<Integer>("relay.backend.use_meta_schedule_dispatch", Integer(0))
+      .value()
+      ->value;
+}
 /*!
* \brief Method in TECompiler to convert TE compute to scheduleable TIR * \param args The arguments of the TE compute diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 8cfbadf65012..1d9272cf2dd5 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -225,7 +225,7 @@ Schedule ConcreteScheduleNode::Copy() { /******** Schedule: Schedule: Sampling ********/ void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) { - support::LinearCongruentialEngine(&rand_state_).Seed(seed); + this->rand_state_ = support::LinearCongruentialEngine::NormalizeSeed(seed); } support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py b/tests/python/contrib/test_hexagon/test_meta_schedule.py index 8b07122c2a17..e8caa9f04e87 100644 --- a/tests/python/contrib/test_hexagon/test_meta_schedule.py +++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py @@ -16,23 +16,25 @@ # under the License. """ Test rpc based launcher for hexagon """ -import pytest -import numpy as np import tempfile +import numpy as np +import pytest import tvm.testing import tvm.topi.testing -from tvm import te, relay from tvm import meta_schedule as ms +from tvm import relay, te +from tvm.contrib.hexagon.meta_schedule import ( + get_hexagon_local_builder, + get_hexagon_rpc_runner, +) +from tvm.meta_schedule import postproc, schedule_rule from tvm.meta_schedule.arg_info import TensorInfo from tvm.meta_schedule.builder import BuilderInput -from tvm.meta_schedule import postproc, schedule_rule +from tvm.meta_schedule.runner import RunnerInput from tvm.script import tir as T from tvm.tir import FloatImm from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN -from tvm.meta_schedule.runner import RunnerInput -from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder, get_hexagon_rpc_runner -from tvm.relay.backend import Executor from .infrastructure import get_hexagon_target @@ -43,7 +45,9 @@ @tvm.script.ir_module class MatmulModule: @T.prim_func - def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + def main( # type: ignore # pylint: disable=no-self-argument + a: T.handle, b: T.handle, c: T.handle + ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (16, 16), "float32") @@ -52,7 +56,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s with T.block("matmul"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): - C[vi, vj] = 0.0 + C[vi, vj] = 0.0 # type: ignore C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @@ -186,26 +190,28 @@ def test_vrmpy_dense(hexagon_launcher): schedule_dense(sch, block, M, do_tune) else: with tempfile.TemporaryDirectory() as work_dir: - config = ms.TuneConfig( - strategy="replay_trace", - num_trials_per_iter=8, - max_trials_per_task=8, - max_trials_global=8, - ) def schedule_dense_for_tune(sch): block = sch.get_block("compute") return schedule_dense(sch, block, None, True) - sch = ms.tune_tir( + target = get_hexagon_target("v69") + database = ms.tir_integration.tune_tir( mod=workload, target=target, - config=config, work_dir=work_dir, - space=ms.space_generator.ScheduleFn(schedule_dense_for_tune), + max_trials_global=8, + space=ms.space_generator.ScheduleFn( + schedule_dense_for_tune, + sch_rules=[], + postprocs=[], + mutator_probs=[], + ), 
+ strategy="replay-trace", builder=get_hexagon_local_builder(), runner=get_hexagon_rpc_runner(hexagon_launcher, number=10), ) + sch = ms.tir_integration.compile_tir(database, workload, target) with hexagon_launcher.start_session() as session: verify_dense(sch, get_hexagon_target("v68"), M, N, K, session) @@ -216,10 +222,10 @@ def schedule_dense_for_tune(sch): @tvm.script.ir_module class Module_vrmpy_auto_tensorize: @T.prim_func - def main( - X: T.Buffer[(128, 768), "uint8"], - packedW: T.Buffer[(24, 192, 32, 4), "uint8"], - compute: T.Buffer[(128, 768), "int32"], + def main( # type: ignore + X: T.Buffer[(128, 768), "uint8"], # type: ignore + packedW: T.Buffer[(24, 192, 32, 4), "uint8"], # type: ignore + compute: T.Buffer[(128, 768), "int32"], # type: ignore ) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i0_0_i1_0_0_fused in T.parallel( @@ -230,37 +236,37 @@ def main( i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1_init + i0_2_init) j_o = T.axis.spatial(24, i1_0_2_init + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1_init) T.reads() - T.writes(compute[i, j_o * 32 : j_o * 32 + 32]) + T.writes(compute[i, j_o * 32 : j_o * 32 + 32]) # type: ignore for i1_1 in T.vectorized(32): with T.block("compute_init"): j_i_init = T.axis.spatial(32, i1_1) T.reads() T.writes(compute[i, j_o * 32 + j_i_init]) - compute[i, j_o * 32 + j_i_init] = 0 + compute[i, j_o * 32 + j_i_init] = 0 # type: ignore for i2_0_0, i0_1, i1_0_1, i2_0_1, i0_2, i1_0_2 in T.grid(32, 2, 3, 6, 1, 1): with T.block("compute_o_update"): i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1 + i0_2) j_o = T.axis.spatial(24, i1_0_2 + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1) k_o = T.axis.reduce(192, i2_0_0 * 6 + i2_0_1) T.reads( - compute[i, j_o * 32 : j_o * 32 + 32], - X[i, k_o * 4 : k_o * 4 + 4], - packedW[j_o, k_o, 0:32, 0:4], + compute[i, j_o * 32 : j_o * 32 + 32], # type: ignore + X[i, k_o * 4 : k_o * 4 + 4], # type: ignore + packedW[j_o, k_o, 0:32, 0:4], # type: ignore ) - T.writes(compute[i, j_o * 32 : j_o * 32 + 32]) + T.writes(compute[i, j_o * 32 : j_o * 32 + 32]) # type: ignore A = T.match_buffer( - X[i, k_o * 4 : k_o * 4 + 4], [4], dtype="uint8", offset_factor=1 + X[i, k_o * 4 : k_o * 4 + 4], [4], dtype="uint8", offset_factor=1 # type: ignore ) B = T.match_buffer( packedW[j_o, k_o, 0:32, 0:4], [32, 4], dtype="uint8", offset_factor=1 ) C = T.match_buffer( - compute[i, j_o * 32 : j_o * 32 + 32], [32], dtype="int32", offset_factor=1 + compute[i, j_o * 32 : j_o * 32 + 32], [32], dtype="int32", offset_factor=1 # type: ignore ) - A_u8x4: T.uint8x4 = A[0:4] - A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32") - B_i32x32: T.int32x32 = T.reinterpret(B[0, 0:128], dtype="int32x32") - C[0:32] = T.call_llvm_pure_intrin( + A_u8x4: T.uint8x4 = A[0:4] # type: ignore + A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32") # type: ignore + B_i32x32: T.int32x32 = T.reinterpret(B[0, 0:128], dtype="int32x32") # type: ignore + C[0:32] = T.call_llvm_pure_intrin( # type: ignore 4390, T.uint32(3), C[0:32], B_i32x32, A_i32, dtype="int32x32" ) @@ -303,23 +309,20 @@ def test_vrmpy_dense_auto_tensorize(hexagon_launcher): if True: with tempfile.TemporaryDirectory() as work_dir: - config = ms.TuneConfig( - strategy="replay_trace", + target = get_hexagon_target("v68") + database = ms.tir_integration.tune_tir( + mod=workload, + target=target, + max_trials_global=8, num_trials_per_iter=8, max_trials_per_task=8, - max_trials_global=8, - ) - - sch = ms.tune_tir( - mod=workload, - target=get_hexagon_target("v68"), - config=config, 
work_dir=work_dir, sch_rules=lambda: sch_rules, postprocs=lambda: postprocs, builder=get_hexagon_local_builder(), runner=get_hexagon_rpc_runner(hexagon_launcher, number=10), ) + sch = ms.tir_integration.compile_tir(database, workload, target) else: sch = tvm.tir.Schedule(Module_vrmpy_auto_tensorize, debug_mask="all") @@ -358,6 +361,7 @@ def test_conv2d_relay_auto_schedule(hexagon_launcher): kernel_layout="HWIO", ) mod = tvm.IRModule.from_expr(conv2d + bias) + mod = mod.with_attr("executor", relay.backend.Executor("graph", {"link-params": True})) data_np = np.random.randn(*d_shape).astype("float16") weight_np = np.random.randn(*w_shape).astype("float16") @@ -379,24 +383,25 @@ def test_conv2d_relay_auto_schedule(hexagon_launcher): ref = rt_mod_ref.get_output(0).numpy() - config = ms.TuneConfig( - strategy="replay_trace", - num_trials_per_iter=8, - max_trials_per_task=8, - max_trials_global=8, - ) - with tempfile.TemporaryDirectory() as work_dir: - executor = Executor("graph", {"link-params": True}) - lib = ms.tune_relay( + target = get_hexagon_target("v69") + database = ms.relay_integration.tune_relay( mod=mod, params=params, - target=get_hexagon_target("v69"), - config=config, + target=target, + max_trials_global=8, + max_trials_per_task=8, + num_trials_per_iter=8, + strategy=ms.search_strategy.ReplayTrace(), work_dir=work_dir, builder=get_hexagon_local_builder(), runner=get_hexagon_rpc_runner(hexagon_launcher, number=20), - executor=executor, + ) + lib = ms.relay_integration.compile_relay( + database=database, + mod=mod, + params=params, + target=target, ) with hexagon_launcher.start_session() as session: diff --git a/tests/python/integration/test_meta_schedule_auto_tensorize.py b/tests/python/integration/test_auto_tensorize.py similarity index 73% rename from tests/python/integration/test_meta_schedule_auto_tensorize.py rename to tests/python/integration/test_auto_tensorize.py index fd28f7928301..3fdf027a490d 100644 --- a/tests/python/integration/test_meta_schedule_auto_tensorize.py +++ b/tests/python/integration/test_auto_tensorize.py @@ -24,23 +24,14 @@ import tvm.topi.testing from tvm import meta_schedule as ms from tvm import relay -from tvm.meta_schedule import postproc, schedule_rule -from tvm.meta_schedule.relay_integration import extract_task_from_relay +from tvm.meta_schedule.testing import relay_workload from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base -from tvm.meta_schedule.tune import tune_extracted_tasks from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN -CONFIG = ms.TuneConfig( - strategy="evolutionary", - num_trials_per_iter=32, - max_trials_per_task=32, - max_trials_global=20000, -) - SCH_RULES_FOR_VNNI = [ - schedule_rule.AutoInline( + ms.schedule_rule.AutoInline( into_producer=False, into_consumer=True, inline_const_tensor=True, @@ -49,62 +40,62 @@ require_ordered=True, disallow_op=["tir.exp"], ), - schedule_rule.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64), - schedule_rule.MultiLevelTilingWithIntrin( + ms.schedule_rule.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64), + ms.schedule_rule.MultiLevelTilingWithIntrin( VNNI_INTRIN, structure="SSRSRS", tile_binds=None, max_innermost_factor=64, vector_load_lens=None, reuse_read=None, - reuse_write=schedule_rule.ReuseType( + reuse_write=ms.schedule_rule.ReuseType( req="may", levels=[1, 2], scope="global", ), ), - schedule_rule.MultiLevelTiling( + 
ms.schedule_rule.MultiLevelTiling(
         structure="SSRSRS",
         tile_binds=None,
         max_innermost_factor=64,
         vector_load_lens=None,
         reuse_read=None,
-        reuse_write=schedule_rule.ReuseType(
+        reuse_write=ms.schedule_rule.ReuseType(
             req="may",
             levels=[1, 2],
             scope="global",
         ),
     ),
-    schedule_rule.ParallelizeVectorizeUnroll(
+    ms.schedule_rule.ParallelizeVectorizeUnroll(
         max_jobs_per_core=16,
         max_vectorize_extent=64,
         unroll_max_steps=[0, 16, 64, 512],
         unroll_explicit=True,
     ),
-    schedule_rule.RandomComputeLocation(),
+    ms.schedule_rule.RandomComputeLocation(),
 ]


-def get_sch_rules_for_dp4a(intrin):
+def _get_sch_rules_for_dp4a(intrin):
     return [
-        schedule_rule.MultiLevelTilingWithIntrin(
+        ms.schedule_rule.MultiLevelTilingWithIntrin(
             intrin,
             structure="SSSRRSRS",
             tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
             max_innermost_factor=64,
             vector_load_lens=[1, 2, 3, 4],
-            reuse_read=schedule_rule.ReuseType(
+            reuse_read=ms.schedule_rule.ReuseType(
                 req="must",
                 levels=[4],
                 scope="shared",
             ),
-            reuse_write=schedule_rule.ReuseType(
+            reuse_write=ms.schedule_rule.ReuseType(
                 req="must",
                 levels=[3],
                 scope="local",
             ),
         ),
-        schedule_rule.AutoInline(
+        ms.schedule_rule.AutoInline(
             into_producer=True,
             into_consumer=True,
             inline_const_tensor=True,
@@ -113,8 +104,8 @@ def get_sch_rules_for_dp4a(intrin):
             require_ordered=False,
             disallow_op=None,
         ),
-        schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]),
-        schedule_rule.ParallelizeVectorizeUnroll(
+        ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]),
+        ms.schedule_rule.ParallelizeVectorizeUnroll(
             max_jobs_per_core=-1,  # disable parallelize
             max_vectorize_extent=-1,  # disable vectorize
             unroll_max_steps=[0, 16, 64, 512, 1024],
@@ -123,24 +114,24 @@
     ]


-SCH_RULES_FOR_DP4A = get_sch_rules_for_dp4a(DP4A_INTRIN)
-SCH_RULES_FOR_SDOT4 = get_sch_rules_for_dp4a(AMDGPU_SDOT4_INTRIN)
+SCH_RULES_FOR_DP4A = _get_sch_rules_for_dp4a(DP4A_INTRIN)
+SCH_RULES_FOR_SDOT4 = _get_sch_rules_for_dp4a(AMDGPU_SDOT4_INTRIN)

 POSTPROCS_FOR_VNNI = [
-    postproc.DisallowDynamicLoop(),
-    postproc.RewriteParallelVectorizeUnroll(),
-    postproc.RewriteReductionBlock(),
-    postproc.RewriteTensorize(vectorize_init_loop=True),
+    ms.postproc.DisallowDynamicLoop(),
+    ms.postproc.RewriteParallelVectorizeUnroll(),
+    ms.postproc.RewriteReductionBlock(),
+    ms.postproc.RewriteTensorize(vectorize_init_loop=True),
 ]

 POSTPROCS_FOR_DP4A = [
-    postproc.DisallowDynamicLoop(),
-    postproc.RewriteCooperativeFetch(),
-    postproc.RewriteUnboundBlock(),
-    postproc.RewriteParallelVectorizeUnroll(),
-    postproc.RewriteReductionBlock(),
-    postproc.RewriteTensorize(),
-    postproc.VerifyGPUCode(),
+    ms.postproc.DisallowDynamicLoop(),
+    ms.postproc.RewriteCooperativeFetch(),
+    ms.postproc.RewriteUnboundBlock(),
+    ms.postproc.RewriteParallelVectorizeUnroll(),
+    ms.postproc.RewriteReductionBlock(),
+    ms.postproc.RewriteTensorize(),
+    ms.postproc.VerifyGPUCode(),
 ]


@@ -148,33 +139,33 @@ def tune_and_test(relay_mod, data_np, weight_np, op_name, target, sch_rules, pos
     """Test tuning."""
     tgt = "cuda" if "nvidia" in target else target
     dev = tvm.device(tgt, 0)
-
     ref = (
         relay.create_executor("vm", mod=relay_mod, device=dev, target=tgt)
         .evaluate()(*[data_np, weight_np])
         .numpy()
     )
-
     params = {"weight": weight_np}
-
-    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
-
     tune_tasks = list(
         filter(
             lambda task: op_name in task.task_name,
-            extracted_tasks,
+            ms.relay_integration.extract_tasks(relay_mod, target, params),
) ) - with tempfile.TemporaryDirectory() as work_dir: - database = tune_extracted_tasks( - tune_tasks, - CONFIG, + tasks, task_weights = ms.relay_integration.extracted_tasks_to_tune_contexts( + extracted_tasks=tune_tasks, work_dir=work_dir, - sch_rules=lambda: sch_rules, - postprocs=lambda: postprocs, + space=ms.space_generator.PostOrderApply( + sch_rules=sch_rules, + postprocs=postprocs, + ), + ) + database = ms.tune.tune_tasks( + tasks=tasks, + task_weights=task_weights, + work_dir=work_dir, + max_trials_global=20000, ) - with database, tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_meta_schedule": True}, @@ -186,12 +177,9 @@ def tune_and_test(relay_mod, data_np, weight_np, op_name, target, sch_rules, pos assert "vpdpbusd" in asm runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) - runtime.set_input("data", data_np) runtime.run() - out = runtime.get_output(0).numpy() - np.testing.assert_equal(out, ref) @@ -243,28 +231,28 @@ def _test_conv2d(data_dtype, sch_rules, postprocs, target): tune_and_test(relay_mod, data_np, weight_np, "conv2d", target, sch_rules, postprocs) -def _test_bert_int8(target, sch_rules, postprocs): - relay_mod, params, input_info = load_quantized_bert_base() - +def _test_bert_int8(relay_mod, params, input_info, target, sch_rules, postprocs): relay_mod = relay.transform.FastMath()(relay_mod) - - extracted_tasks = extract_task_from_relay(relay_mod, target, params) - tune_tasks = [ task - for task in extracted_tasks + for task in ms.relay_integration.extract_tasks(relay_mod, target, params) if "dense" in task.task_name or "batch_matmul" in task.task_name ] - with tempfile.TemporaryDirectory() as work_dir: - database = tune_extracted_tasks( - tune_tasks, - CONFIG, + tasks, task_weights = ms.relay_integration.extracted_tasks_to_tune_contexts( + extracted_tasks=tune_tasks, work_dir=work_dir, - sch_rules=lambda: sch_rules, - postprocs=lambda: postprocs, + space=ms.space_generator.PostOrderApply( + sch_rules=sch_rules, + postprocs=postprocs, + ), + ) + database = ms.tune.tune_tasks( + tasks=tasks, + task_weights=task_weights, + work_dir=work_dir, + max_trials_global=20000, ) - with database, tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_meta_schedule": True}, @@ -273,14 +261,11 @@ def _test_bert_int8(target, sch_rules, postprocs): dev = tvm.device("cuda" if "nvidia" in target else target, 0) runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) - inputs = [] - for name, shape in input_info: arr = np.random.uniform(1, 10, size=shape).astype("int64") runtime.set_input(name, arr) inputs.append(arr) - print(runtime.benchmark(dev, number=1, repeat=50).mean) @@ -295,7 +280,6 @@ def test_vnni_dense(): @tvm.testing.requires_gpu def test_dp4a_dense(): _test_dense("int8", SCH_RULES_FOR_DP4A, POSTPROCS_FOR_DP4A, "nvidia/geforce-rtx-3070") - # Uncomment to test on vulkan or rocm target # _test_dense( # "int8", sch_rules_for_dp4a, postprocs_for_dp4a, "vulkan -from_device=0" @@ -316,7 +300,6 @@ def test_vnni_conv2d(): @tvm.testing.requires_gpu def test_dp4a_conv2d(): _test_conv2d("int8", SCH_RULES_FOR_DP4A, POSTPROCS_FOR_DP4A, "nvidia/geforce-rtx-3070") - # Uncomment to test on vulkan or rocm target # _test_conv2d( # "int8", sch_rules_for_dp4a, postprocs_for_dp4a, "vulkan -from_device=0" @@ -329,17 +312,46 @@ def test_dp4a_conv2d(): @tvm.testing.requires_cascadelake @pytest.mark.skip_if(tvm.testing.IS_IN_CI, reason="Slow on CI") def test_vnni_bert_int8(): - _test_bert_int8("llvm -mcpu=cascadelake -num-cores 4", 
SCH_RULES_FOR_VNNI, POSTPROCS_FOR_VNNI) + relay_mod, params, input_info = load_quantized_bert_base() + _test_bert_int8( + relay_mod, + params, + input_info, + "llvm -mcpu=cascadelake -num-cores 4", + SCH_RULES_FOR_VNNI, + POSTPROCS_FOR_VNNI, + ) @tvm.testing.requires_gpu @pytest.mark.skip("Slow on CI") def test_dp4a_bert_int8(): - _test_bert_int8("nvidia/geforce-rtx-3070", SCH_RULES_FOR_DP4A, POSTPROCS_FOR_DP4A) - + relay_mod, params, input_info = load_quantized_bert_base() + _test_bert_int8( + relay_mod, + params, + input_info, + "nvidia/geforce-rtx-3070", + SCH_RULES_FOR_DP4A, + POSTPROCS_FOR_DP4A, + ) # Uncomment to test on vulkan or rocm target - # _test_bert_int8("vulkan -from_device=0", sch_rules_for_dp4a, postprocs_for_dp4a) - # _test_bert_int8("rocm", sch_rules_for_sdot4, postprocs_for_dp4a) + # _test_bert_int8( + # relay_mod, + # params, + # input_info, + # "vulkan -from_device=0", + # sch_rules_for_dp4a, + # postprocs_for_dp4a, + # ) + # _test_bert_int8( + # relay_mod, + # params, + # input_info, + # "rocm", + # sch_rules_for_sdot4, + # postprocs_for_dp4a, + # ) @tvm.testing.requires_gpu @@ -356,14 +368,12 @@ def test_cuda_tensor_core(model_name, input_shape): data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev) # embedding size else: data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev) - mod, params, (input_name, _, _) = relay_workload.get_network(model_name, input_shape) seq = tvm.transform.Sequential( [ relay.transform.ToMixedPrecision(), ] ) - with tvm.transform.PassContext(opt_level=3): mod = seq(mod) @@ -377,18 +387,19 @@ def convert_layout(mod): with tempfile.TemporaryDirectory() as work_dir: with ms.Profiler() as profiler: - rt_mod1: tvm.runtime.Module = ms.tune_relay( - mod=convert_layout(mod), - params=params, + converted_mod = convert_layout(mod) + database = ms.relay_integration.tune_relay( + mod=converted_mod, target=target, - config=ms.TuneConfig( - num_trials_per_iter=32, - max_trials_per_task=200, - max_trials_global=3000, - ), - sch_rules=ms.default_config._DefaultCUDATensorCore.schedule_rules, - postprocs=ms.default_config._DefaultCUDATensorCore.postprocs, work_dir=work_dir, + max_trials_global=3000, + params=params, + ) + rt_mod1 = ms.relay_integration.compile_relay( + database=database, + mod=converted_mod, + target=target, + params=params, ) print(profiler.table()) diff --git a/tests/python/integration/test_legacy_tuning.py b/tests/python/integration/test_legacy_tuning.py new file mode 100644 index 000000000000..04c5f85ce5d4 --- /dev/null +++ b/tests/python/integration/test_legacy_tuning.py @@ -0,0 +1,380 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
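+#
+# NOTE: these tests were moved, largely verbatim, from
+# tests/python/integration/test_tuning.py so that the legacy autotvm tuner
+# keeps its coverage.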
+""" +Test the tuner +""" +import logging +import multiprocessing as mp +import textwrap + +import tvm +import tvm.relay +import tvm.testing +from tvm import autotvm, te +from tvm.autotvm.measure import measure_methods +from tvm.autotvm.tuner import RandomTuner +from tvm.contrib import tar +from tvm.ir.instrument import pass_instrument +from tvm.ir.transform import PassContext +from tvm.target import Target +from tvm.tir.analysis import _ffi_api as _analysis_ffi_api + + +def setup_module(): + """Setup the module used for testing.""" + + @autotvm.template("testing/conv2d_no_batching") + def conv2d_no_batching( # pylint: disable=unused-variable + batch_size, input_h, input_w, channels_in, channels_out, kernel_h, kernel_w + ): + """An example template for testing""" + assert batch_size == 1, "Only consider batch_size = 1 in this template" + + data = te.placeholder((batch_size, channels_in, input_h, input_w), name="data") + kernel = te.placeholder((channels_out, channels_in, kernel_h, kernel_w), name="kernel") + + axis_rc = te.reduce_axis((0, channels_in), name="rc") + axis_ry = te.reduce_axis((0, kernel_h), name="ry") + axis_rx = te.reduce_axis((0, kernel_w), name="rx") + + conv = te.compute( + (batch_size, channels_out, input_h - kernel_h + 1, input_w - kernel_w + 1), + lambda nn, ff, yy, xx: te.sum( + data[nn, axis_rc, yy + axis_ry, xx + axis_rx] + * kernel[ff, axis_rc, axis_ry, axis_rx], + axis=[axis_rc, axis_ry, axis_rx], + ), + tag="conv2d_nchw", + ) + + schedule = te.create_schedule([conv.op]) + + output = conv + cache_write_ol = schedule.cache_write(conv, "local") + + # create cache stage + cache_read_aa = schedule.cache_read(data, "shared", [cache_write_ol]) + cache_read_ww = schedule.cache_read(kernel, "shared", [cache_write_ol]) + cache_read_al = schedule.cache_read(cache_read_aa, "local", [cache_write_ol]) + cache_read_wl = schedule.cache_read(cache_read_ww, "local", [cache_write_ol]) + + # tile and bind spatial axes + axis_n, axis_f, axis_y, axis_x = schedule[output].op.axis + cfg = autotvm.get_config() + cfg.define_split("tile_f", cfg.axis(axis_f), num_outputs=4) + cfg.define_split("tile_y", cfg.axis(axis_y), num_outputs=4) + cfg.define_split("tile_x", cfg.axis(axis_x), num_outputs=4) + axis_bf, axis_vf, axis_tf, axis_fi = cfg["tile_f"].apply(schedule, output, axis_f) + axis_by, axis_vy, axis_ty, axis_yi = cfg["tile_y"].apply(schedule, output, axis_y) + axis_bx, axis_vx, axis_tx, axis_xi = cfg["tile_x"].apply(schedule, output, axis_x) + kernel_scope = axis_n # this is the scope to attach global config inside this kernel + + schedule[output].bind(axis_bf, te.thread_axis("blockIdx.z")) + schedule[output].bind(axis_by, te.thread_axis("blockIdx.y")) + schedule[output].bind(axis_bx, te.thread_axis("blockIdx.x")) + schedule[output].bind(axis_vf, te.thread_axis("vthread")) + schedule[output].bind(axis_vy, te.thread_axis("vthread")) + schedule[output].bind(axis_vx, te.thread_axis("vthread")) + schedule[output].bind(axis_tf, te.thread_axis("threadIdx.z")) + schedule[output].bind(axis_ty, te.thread_axis("threadIdx.y")) + schedule[output].bind(axis_tx, te.thread_axis("threadIdx.x")) + schedule[output].reorder( + axis_n, + axis_bf, + axis_by, + axis_bx, + axis_vf, + axis_vy, + axis_vx, + axis_tf, + axis_ty, + axis_tx, + axis_fi, + axis_yi, + axis_xi, + ) + schedule[cache_write_ol].compute_at(schedule[output], axis_tx) + + # tile and bind reduction axes + axis_n, axis_f, axis_y, axis_x = schedule[cache_write_ol].op.axis + axis_rc, axis_ry, axis_rx = schedule[cache_write_ol].op.reduce_axis + 
cfg.define_split("tile_rc", cfg.axis(axis_rc), num_outputs=3) + cfg.define_split("tile_ry", cfg.axis(axis_ry), num_outputs=3) + cfg.define_split("tile_rx", cfg.axis(axis_rx), num_outputs=3) + axis_rco, axis_rcm, axis_rci = cfg["tile_rc"].apply(schedule, cache_write_ol, axis_rc) + axis_ryo, axis_rym, axis_ryi = cfg["tile_rx"].apply(schedule, cache_write_ol, axis_ry) + axis_rxo, axis_rxm, axis_rxi = cfg["tile_ry"].apply(schedule, cache_write_ol, axis_rx) + schedule[cache_write_ol].reorder( + axis_rco, + axis_ryo, + axis_rxo, + axis_rcm, + axis_rym, + axis_rxm, + axis_rci, + axis_ryi, + axis_rxi, + axis_n, + axis_f, + axis_y, + axis_x, + ) + + schedule[cache_read_aa].compute_at(schedule[cache_write_ol], axis_rxo) + schedule[cache_read_ww].compute_at(schedule[cache_write_ol], axis_rxo) + schedule[cache_read_al].compute_at(schedule[cache_write_ol], axis_rxm) + schedule[cache_read_wl].compute_at(schedule[cache_write_ol], axis_rxm) + + # cooperative fetching + for load in [cache_read_aa, cache_read_ww]: + axis_n, axis_f, axis_y, axis_x = schedule[load].op.axis + fused = schedule[load].fuse(axis_n, axis_f, axis_y, axis_x) + axis_tz, fused = schedule[load].split(fused, nparts=cfg["tile_f"].size[2]) + axis_ty, fused = schedule[load].split(fused, nparts=cfg["tile_y"].size[2]) + axis_tx, fused = schedule[load].split(fused, nparts=cfg["tile_x"].size[2]) + schedule[load].bind(axis_tz, te.thread_axis("threadIdx.z")) + schedule[load].bind(axis_ty, te.thread_axis("threadIdx.y")) + schedule[load].bind(axis_tx, te.thread_axis("threadIdx.x")) + + # tune unroll + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + cfg.define_knob("unroll_explicit", [0, 1]) + schedule[output].pragma( + kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val + ) + schedule[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) + + return schedule, [data, kernel, conv] + + +def teardown_module(): + """Remove the module from the autotvm task tables.""" + # TODO(areusch): Tasks should not be registered into a global. + del autotvm.task.task.TASK_TABLE["testing/conv2d_no_batching"] + + +def get_sample_task(target=tvm.target.cuda(), target_host=None): + """return a sample task for testing""" + target, target_host = Target.canon_target_and_host(target, target_host) + task = autotvm.task.create( + "testing/conv2d_no_batching", args=(1, 7, 7, 512, 512, 3, 3), target=target + ) + return task, target + + +def run_test_with_all_multiprocessing(func, *args, **kwargs): + """Check all multiprocessing methods work for the tuning test. + + In the past fork() had the most support at detriment to spawn() and forkserver(). + As fork() is unavailable or unsafe on some platforms it is good to check all + available methods. 
+ """ + for multiprocessing_method in mp.get_all_start_methods(): + old_start_method = mp.get_start_method() + try: + mp.set_start_method(multiprocessing_method, force=True) + func(*args, **kwargs) + finally: + mp.set_start_method(old_start_method, force=True) + + +@tvm.testing.parametrize_targets("cuda", "opencl") +def test_tuning_gpu(target): + """Test gpu tuning.""" + + def runner(target): + # init task + task, target = get_sample_task(target, None) + logging.info("task config space: %s", task.config_space) + + measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner()) + + results = [] + + tuner = RandomTuner(task) + tuner.tune( + n_trial=20, + measure_option=measure_option, + callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),), + ) + + assert len(results) == 20 + + successful_results = [ + r + for r in results + if r.error_no == autotvm.MeasureErrorNo.NO_ERROR + # We filter records before building if we know they won't work ahead of time. + # We can't guarantee we get one good record so we count these as success too + or r.error_no == autotvm.MeasureErrorNo.INSTANTIATION_ERROR + ] + assert len(successful_results) > 0, f"No successful tuning runs: {results!r}" + + run_test_with_all_multiprocessing(runner, target) + + +@tvm.testing.parametrize_targets("cuda", "opencl") +def test_tuning_gpu_inherits_pass_context(target): + """Autotvm tuner inherits PassContexts but also adds a gpu verification pass by default. + + Test that using PassContext inherits passes properly but also runs gpu verification pass. + """ + + @pass_instrument + class PassInstrumentChecker: + """Pass Instrument that simply sees if it's been run.""" + + def __init__(self): + self.has_been_run = False + + def run_after_pass(self, *_): + self.has_been_run = True + + class GPUVerifyPassMocked: + """Context manager that mocks tir.analysis.verify_gpu_code meant + to verify the pass has been run. 
This is done by patching the ffi func handles.""" + + FFI_FUNC_HANDLE = "tir.analysis.verify_gpu_code" + FUNC_NAME = "verify_gpu_code" + + def __init__(self) -> None: + self.old_impl = tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE) + self.has_been_run = False + + def gpu_verify_pass_mocked(self): + """Get the replacement for the gpu verification pass.""" + + def _gpu_verify_pass_mocked(*args, **kwargs): + self.has_been_run = True + return self.old_impl(*args, **kwargs) + + return _gpu_verify_pass_mocked + + def __enter__(self): + tvm._ffi.register_func( + self.FFI_FUNC_HANDLE, self.gpu_verify_pass_mocked(), override=True + ) + + # Also overwrite the python bindings + setattr( + _analysis_ffi_api, self.FUNC_NAME, tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE) + ) + + def __exit__(self, *args, **kwargs): + # Restore FFI status back to normal + tvm._ffi.register_func(self.FFI_FUNC_HANDLE, self.old_impl, override=True) + setattr(_analysis_ffi_api, self.FUNC_NAME, self.old_impl) + + class OverwrittenBuildFunc(measure_methods._WrappedBuildFunc): + """BuildFunc that mocks and patches as necessary to test proper passes are run.""" + + def __call__(self, measure_input, tmp_dir, **kwargs): + instrument = PassInstrumentChecker() + mocked_pass_checker = GPUVerifyPassMocked() + with mocked_pass_checker: + with PassContext(instruments=[instrument]): + regular_result = super().__call__(measure_input, tmp_dir, **kwargs) + + # Check instrument has been run, meaning context was inherited by builder + assert instrument.has_been_run + + # But also check the gpu verification pass has been run + # (which was not in the inherited ctx) + assert mocked_pass_checker.has_been_run + + return regular_result + + class MockedLocalBuilder(measure_methods.LocalBuilder): + """As measure_methods.LocalBuilder but overwrites the PassContext for testing.""" + + def __init__( + self, + timeout=10, + n_parallel=None, + build_kwargs=None, + build_func="default", + do_fork=False, + runtime=None, + ): + # pylint: disable=too-many-function-args + super().__init__(timeout, n_parallel, build_kwargs, build_func, do_fork, runtime) + + self.build_func = OverwrittenBuildFunc(tar.tar, runtime) + + def runner(target): + task, target = get_sample_task(target, None) + logging.info("task config space: %s", task.config_space) + + # Note: we use the MockedLocalBuilder here instead of autotvm.LocalBuilder() + measure_option = autotvm.measure_option(MockedLocalBuilder(), autotvm.LocalRunner()) + + results = [] + + tuner = RandomTuner(task) + tuner.tune( + n_trial=1, + measure_option=measure_option, + callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),), + ) + + assert len(results) == 1 + + run_test_with_all_multiprocessing(runner, target) + + +def test_tuning_cpu(): + """Test tuning on cpu.""" + + def runner(): + ir_mod = tvm.parser.fromtext( + textwrap.dedent( + """ + #[version = "0.0.5"] + def @main(%a : Tensor[(1, 3, 32, 32), float32], %b : Tensor[(3, 3, 5, 5), float32]) { + nn.conv2d(%a, %b, data_layout="NCHW", kernel_layout="OIHW") + } + """ + ) + ) + tasks = autotvm.task.relay_integration.extract_from_program( + ir_mod, {}, tvm.target.create("llvm") + ) + assert len(tasks) == 1, f"Extracted != 1 task from program: {tasks!r}" + + task = tasks[0] + + measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner()) + + results = [] + + tuner = RandomTuner(task) + tuner.tune( + n_trial=20, + measure_option=measure_option, + callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),), + ) + + assert len(results) == 20 
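+        # Unlike the GPU test above, INSTANTIATION_ERROR results do not count as
+        # successes here; only NO_ERROR runs pass the filter below.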
+ + successful_results = [r for r in results if r.error_no == autotvm.MeasureErrorNo.NO_ERROR] + assert len(successful_results) > 0, f"No successful tuning runs: {results!r}" + + run_test_with_all_multiprocessing(runner) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py index 04c5f85ce5d4..af5143908108 100644 --- a/tests/python/integration/test_tuning.py +++ b/tests/python/integration/test_tuning.py @@ -14,367 +14,86 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -Test the tuner -""" +# pylint: disable=missing-docstring import logging -import multiprocessing as mp -import textwrap +import tempfile +from typing import List, Optional +import numpy as np # type: ignore +import pytest import tvm -import tvm.relay -import tvm.testing -from tvm import autotvm, te -from tvm.autotvm.measure import measure_methods -from tvm.autotvm.tuner import RandomTuner -from tvm.contrib import tar -from tvm.ir.instrument import pass_instrument -from tvm.ir.transform import PassContext -from tvm.target import Target -from tvm.tir.analysis import _ffi_api as _analysis_ffi_api - - -def setup_module(): - """Setup the module used for testing.""" - - @autotvm.template("testing/conv2d_no_batching") - def conv2d_no_batching( # pylint: disable=unused-variable - batch_size, input_h, input_w, channels_in, channels_out, kernel_h, kernel_w - ): - """An example template for testing""" - assert batch_size == 1, "Only consider batch_size = 1 in this template" - - data = te.placeholder((batch_size, channels_in, input_h, input_w), name="data") - kernel = te.placeholder((channels_out, channels_in, kernel_h, kernel_w), name="kernel") - - axis_rc = te.reduce_axis((0, channels_in), name="rc") - axis_ry = te.reduce_axis((0, kernel_h), name="ry") - axis_rx = te.reduce_axis((0, kernel_w), name="rx") - - conv = te.compute( - (batch_size, channels_out, input_h - kernel_h + 1, input_w - kernel_w + 1), - lambda nn, ff, yy, xx: te.sum( - data[nn, axis_rc, yy + axis_ry, xx + axis_rx] - * kernel[ff, axis_rc, axis_ry, axis_rx], - axis=[axis_rc, axis_ry, axis_rx], - ), - tag="conv2d_nchw", - ) - - schedule = te.create_schedule([conv.op]) - - output = conv - cache_write_ol = schedule.cache_write(conv, "local") - - # create cache stage - cache_read_aa = schedule.cache_read(data, "shared", [cache_write_ol]) - cache_read_ww = schedule.cache_read(kernel, "shared", [cache_write_ol]) - cache_read_al = schedule.cache_read(cache_read_aa, "local", [cache_write_ol]) - cache_read_wl = schedule.cache_read(cache_read_ww, "local", [cache_write_ol]) - - # tile and bind spatial axes - axis_n, axis_f, axis_y, axis_x = schedule[output].op.axis - cfg = autotvm.get_config() - cfg.define_split("tile_f", cfg.axis(axis_f), num_outputs=4) - cfg.define_split("tile_y", cfg.axis(axis_y), num_outputs=4) - cfg.define_split("tile_x", cfg.axis(axis_x), num_outputs=4) - axis_bf, axis_vf, axis_tf, axis_fi = cfg["tile_f"].apply(schedule, output, axis_f) - axis_by, axis_vy, axis_ty, axis_yi = cfg["tile_y"].apply(schedule, output, axis_y) - axis_bx, axis_vx, axis_tx, axis_xi = cfg["tile_x"].apply(schedule, output, axis_x) - kernel_scope = axis_n # this is the scope to attach global config inside this kernel - - schedule[output].bind(axis_bf, te.thread_axis("blockIdx.z")) - schedule[output].bind(axis_by, te.thread_axis("blockIdx.y")) - schedule[output].bind(axis_bx, 
te.thread_axis("blockIdx.x")) - schedule[output].bind(axis_vf, te.thread_axis("vthread")) - schedule[output].bind(axis_vy, te.thread_axis("vthread")) - schedule[output].bind(axis_vx, te.thread_axis("vthread")) - schedule[output].bind(axis_tf, te.thread_axis("threadIdx.z")) - schedule[output].bind(axis_ty, te.thread_axis("threadIdx.y")) - schedule[output].bind(axis_tx, te.thread_axis("threadIdx.x")) - schedule[output].reorder( - axis_n, - axis_bf, - axis_by, - axis_bx, - axis_vf, - axis_vy, - axis_vx, - axis_tf, - axis_ty, - axis_tx, - axis_fi, - axis_yi, - axis_xi, - ) - schedule[cache_write_ol].compute_at(schedule[output], axis_tx) - - # tile and bind reduction axes - axis_n, axis_f, axis_y, axis_x = schedule[cache_write_ol].op.axis - axis_rc, axis_ry, axis_rx = schedule[cache_write_ol].op.reduce_axis - cfg.define_split("tile_rc", cfg.axis(axis_rc), num_outputs=3) - cfg.define_split("tile_ry", cfg.axis(axis_ry), num_outputs=3) - cfg.define_split("tile_rx", cfg.axis(axis_rx), num_outputs=3) - axis_rco, axis_rcm, axis_rci = cfg["tile_rc"].apply(schedule, cache_write_ol, axis_rc) - axis_ryo, axis_rym, axis_ryi = cfg["tile_rx"].apply(schedule, cache_write_ol, axis_ry) - axis_rxo, axis_rxm, axis_rxi = cfg["tile_ry"].apply(schedule, cache_write_ol, axis_rx) - schedule[cache_write_ol].reorder( - axis_rco, - axis_ryo, - axis_rxo, - axis_rcm, - axis_rym, - axis_rxm, - axis_rci, - axis_ryi, - axis_rxi, - axis_n, - axis_f, - axis_y, - axis_x, - ) - - schedule[cache_read_aa].compute_at(schedule[cache_write_ol], axis_rxo) - schedule[cache_read_ww].compute_at(schedule[cache_write_ol], axis_rxo) - schedule[cache_read_al].compute_at(schedule[cache_write_ol], axis_rxm) - schedule[cache_read_wl].compute_at(schedule[cache_write_ol], axis_rxm) - - # cooperative fetching - for load in [cache_read_aa, cache_read_ww]: - axis_n, axis_f, axis_y, axis_x = schedule[load].op.axis - fused = schedule[load].fuse(axis_n, axis_f, axis_y, axis_x) - axis_tz, fused = schedule[load].split(fused, nparts=cfg["tile_f"].size[2]) - axis_ty, fused = schedule[load].split(fused, nparts=cfg["tile_y"].size[2]) - axis_tx, fused = schedule[load].split(fused, nparts=cfg["tile_x"].size[2]) - schedule[load].bind(axis_tz, te.thread_axis("threadIdx.z")) - schedule[load].bind(axis_ty, te.thread_axis("threadIdx.y")) - schedule[load].bind(axis_tx, te.thread_axis("threadIdx.x")) - - # tune unroll - cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) - cfg.define_knob("unroll_explicit", [0, 1]) - schedule[output].pragma( - kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val - ) - schedule[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) - - return schedule, [data, kernel, conv] - - -def teardown_module(): - """Remove the module from the autotvm task tables.""" - # TODO(areusch): Tasks should not be registered into a global. 
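
# Illustrative sketch, not part of the patch: the define_split/apply idiom that the
# conv2d template removed above is built around, shown on a deliberately tiny
# compute. The template name "testing/vecadd_example" is hypothetical.
from tvm import autotvm, te

@autotvm.template("testing/vecadd_example")
def vecadd_example(n):
    a = te.placeholder((n,), name="a")
    b = te.compute((n,), lambda i: a[i] + 1.0, name="b")
    sch = te.create_schedule(b.op)
    cfg = autotvm.get_config()
    # Declare a tunable split; the tuner picks concrete factors for each trial.
    cfg.define_split("tile_i", cfg.axis(b.op.axis[0]), num_outputs=2)
    # apply() replays whichever factors the current config selected.
    _, axis_inner = cfg["tile_i"].apply(sch, b, b.op.axis[0])
    sch[b].vectorize(axis_inner)
    return sch, [a, b]
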
- del autotvm.task.task.TASK_TABLE["testing/conv2d_no_batching"] - - -def get_sample_task(target=tvm.target.cuda(), target_host=None): - """return a sample task for testing""" - target, target_host = Target.canon_target_and_host(target, target_host) - task = autotvm.task.create( - "testing/conv2d_no_batching", args=(1, 7, 7, 512, 512, 3, 3), target=target +from tvm import meta_schedule as ms +from tvm import relay +from tvm.contrib import graph_executor +from tvm.meta_schedule.testing.relay_workload import get_network +from tvm.target.target import Target + +logging.basicConfig( + format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) + + +@pytest.mark.skip("Integration test") +@pytest.mark.parametrize( + "model_name, input_shape, target, layout", + [ + ("resnet_18", [1, 3, 224, 224], "llvm --num-cores=16", "NHWC"), + ("resnet_18", [1, 3, 224, 224], "nvidia/geforce-rtx-3090-ti", "NHWC"), + ], +) +def test_meta_schedule_tune_relay( + model_name: str, + input_shape: List[int], + target: str, + layout: Optional[str], +): + dev = tvm.cpu() if str(target).startswith("llvm") else tvm.cuda() + if model_name.startswith("bert"): + data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev) # embedding size + else: + data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev) + + mod, params, (input_name, _, _) = get_network( + name=model_name, + input_shape=input_shape, + layout=layout, ) - return task, target - - -def run_test_with_all_multiprocessing(func, *args, **kwargs): - """Check all multiprocessing methods work for the tuning test. - - In the past fork() had the most support at detriment to spawn() and forkserver(). - As fork() is unavailable or unsafe on some platforms it is good to check all - available methods. - """ - for multiprocessing_method in mp.get_all_start_methods(): - old_start_method = mp.get_start_method() - try: - mp.set_start_method(multiprocessing_method, force=True) - func(*args, **kwargs) - finally: - mp.set_start_method(old_start_method, force=True) - - -@tvm.testing.parametrize_targets("cuda", "opencl") -def test_tuning_gpu(target): - """Test gpu tuning.""" - - def runner(target): - # init task - task, target = get_sample_task(target, None) - logging.info("task config space: %s", task.config_space) - - measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner()) - - results = [] - - tuner = RandomTuner(task) - tuner.tune( - n_trial=20, - measure_option=measure_option, - callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),), - ) - - assert len(results) == 20 - - successful_results = [ - r - for r in results - if r.error_no == autotvm.MeasureErrorNo.NO_ERROR - # We filter records before building if we know they won't work ahead of time. - # We can't guarantee we get one good record so we count these as success too - or r.error_no == autotvm.MeasureErrorNo.INSTANTIATION_ERROR - ] - assert len(successful_results) > 0, f"No successful tuning runs: {results!r}" - - run_test_with_all_multiprocessing(runner, target) - - -@tvm.testing.parametrize_targets("cuda", "opencl") -def test_tuning_gpu_inherits_pass_context(target): - """Autotvm tuner inherits PassContexts but also adds a gpu verification pass by default. - - Test that using PassContext inherits passes properly but also runs gpu verification pass. 
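
# Illustrative sketch, not part of the patch: the @pass_instrument mechanism the
# test above is probing -- an instrument object observes every pass that runs
# inside a PassContext, including a context inherited by a builder.
import tvm
from tvm import relay
from tvm.ir.instrument import pass_instrument

@pass_instrument
class PassCounter:
    def __init__(self):
        self.count = 0

    def run_after_pass(self, mod, info):
        self.count += 1

counter = PassCounter()
x = relay.var("x", shape=(4,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x + x))
with tvm.transform.PassContext(opt_level=3, instruments=[counter]):
    relay.build(mod, target="llvm")
assert counter.count > 0  # the instrument saw every pass that ran
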
- """ - - @pass_instrument - class PassInstrumentChecker: - """Pass Instrument that simply sees if it's been run.""" - - def __init__(self): - self.has_been_run = False - def run_after_pass(self, *_): - self.has_been_run = True - - class GPUVerifyPassMocked: - """Context manager that mocks tir.analysis.verify_gpu_code meant - to verify the pass has been run. This is done by patching the ffi func handles.""" - - FFI_FUNC_HANDLE = "tir.analysis.verify_gpu_code" - FUNC_NAME = "verify_gpu_code" - - def __init__(self) -> None: - self.old_impl = tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE) - self.has_been_run = False - - def gpu_verify_pass_mocked(self): - """Get the replacement for the gpu verification pass.""" - - def _gpu_verify_pass_mocked(*args, **kwargs): - self.has_been_run = True - return self.old_impl(*args, **kwargs) - - return _gpu_verify_pass_mocked - - def __enter__(self): - tvm._ffi.register_func( - self.FFI_FUNC_HANDLE, self.gpu_verify_pass_mocked(), override=True - ) - - # Also overwrite the python bindings - setattr( - _analysis_ffi_api, self.FUNC_NAME, tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE) + target = Target(target) + with tempfile.TemporaryDirectory() as work_dir: + with ms.Profiler() as profiler: + database = ms.relay_integration.tune_relay( + mod=mod, + target=target, + params=params, + work_dir=work_dir, + max_trials_global=2048, ) - - def __exit__(self, *args, **kwargs): - # Restore FFI status back to normal - tvm._ffi.register_func(self.FFI_FUNC_HANDLE, self.old_impl, override=True) - setattr(_analysis_ffi_api, self.FUNC_NAME, self.old_impl) - - class OverwrittenBuildFunc(measure_methods._WrappedBuildFunc): - """BuildFunc that mocks and patches as necessary to test proper passes are run.""" - - def __call__(self, measure_input, tmp_dir, **kwargs): - instrument = PassInstrumentChecker() - mocked_pass_checker = GPUVerifyPassMocked() - with mocked_pass_checker: - with PassContext(instruments=[instrument]): - regular_result = super().__call__(measure_input, tmp_dir, **kwargs) - - # Check instrument has been run, meaning context was inherited by builder - assert instrument.has_been_run - - # But also check the gpu verification pass has been run - # (which was not in the inherited ctx) - assert mocked_pass_checker.has_been_run - - return regular_result - - class MockedLocalBuilder(measure_methods.LocalBuilder): - """As measure_methods.LocalBuilder but overwrites the PassContext for testing.""" - - def __init__( - self, - timeout=10, - n_parallel=None, - build_kwargs=None, - build_func="default", - do_fork=False, - runtime=None, - ): - # pylint: disable=too-many-function-args - super().__init__(timeout, n_parallel, build_kwargs, build_func, do_fork, runtime) - - self.build_func = OverwrittenBuildFunc(tar.tar, runtime) - - def runner(target): - task, target = get_sample_task(target, None) - logging.info("task config space: %s", task.config_space) - - # Note: we use the MockedLocalBuilder here instead of autotvm.LocalBuilder() - measure_option = autotvm.measure_option(MockedLocalBuilder(), autotvm.LocalRunner()) - - results = [] - - tuner = RandomTuner(task) - tuner.tune( - n_trial=1, - measure_option=measure_option, - callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),), - ) - - assert len(results) == 1 - - run_test_with_all_multiprocessing(runner, target) - - -def test_tuning_cpu(): - """Test tuning on cpu.""" - - def runner(): - ir_mod = tvm.parser.fromtext( - textwrap.dedent( - """ - #[version = "0.0.5"] - def @main(%a : Tensor[(1, 3, 32, 32), float32], %b : 
Tensor[(3, 3, 5, 5), float32]) { - nn.conv2d(%a, %b, data_layout="NCHW", kernel_layout="OIHW") - } - """ + rt_mod1 = ms.relay_integration.compile_relay( + database=database, + mod=mod, + target=target, + params=params, ) - ) - tasks = autotvm.task.relay_integration.extract_from_program( - ir_mod, {}, tvm.target.create("llvm") - ) - assert len(tasks) == 1, f"Extracted != 1 task from program: {tasks!r}" - - task = tasks[0] - - measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner()) - - results = [] - - tuner = RandomTuner(task) - tuner.tune( - n_trial=20, - measure_option=measure_option, - callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),), - ) - - assert len(results) == 20 - - successful_results = [r for r in results if r.error_no == autotvm.MeasureErrorNo.NO_ERROR] - assert len(successful_results) > 0, f"No successful tuning runs: {results!r}" - - run_test_with_all_multiprocessing(runner) - - -if __name__ == "__main__": - tvm.testing.main() + print(profiler.table()) + # Compile without meta-schedule for correctness check + with tvm.transform.PassContext(opt_level=0): + rt_mod2 = relay.build(mod, target=target, params=params) + + def get_output(data, lib): + module = graph_executor.GraphModule(lib["default"](dev)) + module.set_input(input_name, data) + module.run() + return module.get_output(0).numpy() + + # Check correctness + actual_output = get_output(data, rt_mod1) + expected_output = get_output(data, rt_mod2) + assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) + + +if __name__ == """__main__""": + test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "llvm --num-cores=16", "NHWC") + test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "nvidia/geforce-rtx-3090-ti", None) diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py index c47897eabb3e..ed5229a20af5 100644 --- a/tests/python/unittest/test_meta_schedule_cost_model.py +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -15,27 +15,27 @@ # specific language governing permissions and limitations # under the License. 
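
# Illustrative sketch, not part of the patch: the tune-then-compile flow that the
# rewritten test_tuning.py above exercises, condensed into one helper. `mod` and
# `params` are any Relay module and its parameter dict.
import tempfile

from tvm import meta_schedule as ms
from tvm.target import Target

def tune_and_compile(mod, params, target="llvm --num-cores=16", trials=64):
    target = Target(target)
    with tempfile.TemporaryDirectory() as work_dir:
        # Search for schedules and record the results in a database ...
        database = ms.relay_integration.tune_relay(
            mod=mod,
            target=target,
            params=params,
            work_dir=work_dir,
            max_trials_global=trials,
        )
        # ... then build the model using the best records found.
        return ms.relay_integration.compile_relay(
            database=database,
            mod=mod,
            target=target,
            params=params,
        )
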
# pylint: disable=missing-docstring -from typing import List - import os import re import shutil import tempfile -from functools import partial import unittest -import numpy as np +from functools import partial +from typing import List +import numpy as np import tvm import tvm.testing -from tvm.script import tir as T -from tvm.tir.schedule.schedule import Schedule from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel -from tvm.meta_schedule.cost_model.xgb_model import _get_custom_call_back, PackSum +from tvm.meta_schedule.cost_model.xgb_model import PackSum, _get_custom_call_back from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.search_strategy import MeasureCandidate from tvm.meta_schedule.tune_context import TuneContext from tvm.meta_schedule.utils import derived_object +from tvm.script import tir as T +from tvm.tir.schedule.schedule import Schedule + # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @tvm.script.ir_module @@ -244,9 +244,10 @@ def xgb_version_check(): @unittest.skipIf(xgb_version_check(), "test not supported for xgboost version after 1.6.0") def test_meta_schedule_xgb_model_callback_as_function(): # pylint: disable=import-outside-toplevel - import xgboost as xgb from itertools import chain as itertools_chain + import xgboost as xgb + # pylint: enable=import-outside-toplevel extractor = RandomFeatureExtractor() diff --git a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py index 69408a2e901a..ac18bab81006 100644 --- a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py +++ b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py @@ -168,10 +168,6 @@ def test_conv2d_winograd_cpu(): target=target, task_name="Custom Search Space Task", space_generator=ms.space_generator.PostOrderApply(), - sch_rules=ms.default_config.schedule_rules( - None, - target, - ), ) post_order_apply = context.space_generator (sch,) = post_order_apply.generate_design_space(mod) diff --git a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py index 958baabedb6d..89a04a9464ce 100644 --- a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py +++ b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py @@ -286,9 +286,6 @@ def test_conv2d_winograd_cuda(): target=Target("nvidia/geforce-rtx-3090", host="llvm"), task_name="Custom Search Space Task", space_generator=ms.space_generator.PostOrderApply(), - sch_rules=ms.default_config.schedule_rules( # pylint: disable=protected-access - None, Target("cuda") - ), ) post_order_apply = context.space_generator (sch,) = post_order_apply.generate_design_space(mod) diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py index fba8c883e501..20596e8e8c4d 100644 --- a/tests/python/unittest/test_meta_schedule_measure_callback.py +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -73,14 +73,7 @@ def apply( measure_callback = FancyMeasureCallback() measure_callback.apply( - ms.task_scheduler.RoundRobin( - tasks=[], - task_weights=[], - builder=DummyBuilder(), - runner=DummyRunner(), - database=ms.database.MemoryDatabase(), - max_trials=1, - ), + ms.task_scheduler.RoundRobin(), 0, 
[ms.MeasureCandidate(Schedule(Matmul), None)], [ms.builder.BuilderResult("test_build", None)], @@ -104,14 +97,7 @@ def apply( measure_callback = FailingMeasureCallback() with pytest.raises(ValueError, match="test"): measure_callback.apply( - ms.task_scheduler.RoundRobin( - tasks=[], - task_weights=[], - builder=DummyBuilder(), - runner=DummyRunner(), - database=ms.database.MemoryDatabase(), - max_trials=1, - ), + ms.task_scheduler.RoundRobin(), 0, [ms.MeasureCandidate(Schedule(Matmul), None)], [ms.builder.BuilderResult("test_build", None)], diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py index 3d4a9966cb90..4147a9fbab86 100644 --- a/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py @@ -15,8 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.mutator import MutateComputeLocation, Mutator +from tvm import meta_schedule as ms from tvm.script import tir as T from tvm.target import Target from tvm.tir import Schedule @@ -61,15 +60,17 @@ def _sch(decision: int) -> Schedule: return sch -def _make_mutator(target: Target) -> Mutator: - ctx = TuneContext( +def _make_mutator(target: Target) -> ms.Mutator: + ctx = ms.TuneContext( mod=add, target=target, - mutator_probs={ - MutateComputeLocation(): 1.0, - }, + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=[], + mutator_probs={ms.mutator.MutateComputeLocation(): 1.0}, + ), ) - return list(ctx.mutator_probs.keys())[0] + return list(ctx.space_generator.mutator_probs.keys())[0] def test_mutate_compute_location_add(): diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py index b517c3ed490a..728f522335bf 100644 --- a/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py @@ -17,8 +17,7 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from typing import List -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.mutator import MutateParallel, Mutator +from tvm import meta_schedule as ms from tvm.script import tir as T from tvm.target import Target from tvm.tir import Schedule @@ -79,15 +78,17 @@ def _sch(decisions: List[List[int]], ann_val: int) -> Schedule: return sch -def _make_mutator(target: Target, max_jobs_per_core: int) -> Mutator: - ctx = TuneContext( +def _make_mutator(target: Target, max_jobs_per_core: int) -> ms.Mutator: + ctx = ms.TuneContext( mod=matmul, target=target, - mutator_probs={ - MutateParallel(max_jobs_per_core): 1.0, - }, + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=[], + mutator_probs={ms.mutator.MutateParallel(max_jobs_per_core): 1.0}, + ), ) - return list(ctx.mutator_probs.keys())[0] + return list(ctx.space_generator.mutator_probs.keys())[0] def test_mutate_parallel_matmul(): diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_thread_binding.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_thread_binding.py index 1dc7588edd7d..d3a431af0687 100644 --- 
a/tests/python/unittest/test_meta_schedule_mutator_mutate_thread_binding.py +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_thread_binding.py @@ -15,8 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.mutator import MutateThreadBinding, Mutator +from tvm import meta_schedule as ms from tvm.script import tir as T from tvm.target import Target from tvm.tir import Schedule @@ -62,15 +61,17 @@ def _sch() -> Schedule: return sch -def _make_mutator(target: Target) -> Mutator: - ctx = TuneContext( +def _make_mutator(target: Target) -> ms.Mutator: + ctx = ms.TuneContext( mod=element_wise, target=target, - mutator_probs={ - MutateThreadBinding(): 1.0, - }, + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=[], + mutator_probs={ms.mutator.MutateThreadBinding(): 1.0}, + ), ) - return list(ctx.mutator_probs.keys())[0] + return list(ctx.space_generator.mutator_probs.keys())[0] def test_mutate_thread_binding(): diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py index 00b190a75de7..0600c0b79194 100644 --- a/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py @@ -19,8 +19,7 @@ from functools import reduce from typing import List -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.mutator import MutateTileSize, Mutator +from tvm import meta_schedule as ms from tvm.script import tir as T from tvm.target import Target from tvm.tir import Schedule @@ -67,13 +66,17 @@ def _sch(decisions: List[List[int]]) -> Schedule: return sch -def _make_mutator(target: Target) -> Mutator: - ctx = TuneContext( +def _make_mutator(target: Target) -> ms.Mutator: + ctx = ms.TuneContext( mod=matmul, target=target, - mutator_probs={MutateTileSize(): 1.0}, + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=[], + mutator_probs={ms.mutator.MutateTileSize(): 1.0}, + ), ) - return list(ctx.mutator_probs.keys())[0] + return list(ctx.space_generator.mutator_probs.keys())[0] def test_mutate_tile_size_matmul(): diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py index 7bed83f52232..a59a7e655b09 100644 --- a/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py @@ -17,8 +17,7 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from typing import List -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.mutator import MutateUnroll, Mutator +from tvm import meta_schedule as ms from tvm.script import tir as T from tvm.target import Target from tvm.tir import Schedule @@ -84,15 +83,17 @@ def _sch(decisions: List[List[int]]) -> Schedule: return sch -def _make_mutator(target: Target) -> Mutator: - ctx = TuneContext( +def _make_mutator(target: Target) -> ms.Mutator: + ctx = ms.TuneContext( mod=matmul, target=target, - mutator_probs={ - MutateUnroll(): 1.0, - }, + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=[], + mutator_probs={ms.mutator.MutateUnroll(): 1.0}, + ), ) - return list(ctx.mutator_probs.keys())[0] 
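
# Illustrative sketch, not part of the patch: the shape these mutator tests share
# after the refactor -- mutator_probs moves off TuneContext and onto the space
# generator, and is read back through ctx.space_generator. `Copy` is a toy module.
import tvm
from tvm import meta_schedule as ms
from tvm.script import tir as T
from tvm.target import Target

@tvm.script.ir_module
class Copy:
    @T.prim_func
    def main(A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]) -> None:
        for i in T.serial(16):
            with T.block("copy"):
                vi = T.axis.spatial(16, i)
                B[vi] = A[vi]

ctx = ms.TuneContext(
    mod=Copy,
    target=Target("llvm --num-cores=4"),
    space_generator=ms.space_generator.PostOrderApply(
        sch_rules=[],
        postprocs=[],
        mutator_probs={ms.mutator.MutateTileSize(): 1.0},
    ),
)
mutator = list(ctx.space_generator.mutator_probs.keys())[0]
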
+ return list(ctx.space_generator.mutator_probs.keys())[0] def test_mutate_unroll_matmul(): diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index b40ba2869d1c..9026feb9e08e 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -243,8 +243,11 @@ def test_meta_schedule_post_order_apply(): mod=mod, target=Target("llvm"), task_name="Test Task", - space_generator=PostOrderApply(), - sch_rules=[WowSoFancyScheduleRule()], + space_generator=PostOrderApply( + sch_rules=[WowSoFancyScheduleRule()], + postprocs=[], + mutator_probs={}, + ), ) post_order_apply = context.space_generator schs = post_order_apply.generate_design_space(mod) @@ -259,8 +262,11 @@ def test_meta_schedule_post_order_apply_double(): mod=mod, target=Target("llvm"), task_name="Double Rules Task", - space_generator=PostOrderApply(), - sch_rules=[DoubleScheduleRule()], + space_generator=PostOrderApply( + sch_rules=[DoubleScheduleRule()], + postprocs=[], + mutator_probs={}, + ), ) post_order_apply = context.space_generator schs = post_order_apply.generate_design_space(mod) @@ -276,8 +282,11 @@ def test_meta_schedule_post_order_apply_multiple(): mod=mod, target=Target("llvm"), task_name="Double Rules Task", - space_generator=PostOrderApply(), - sch_rules=[DoubleScheduleRule(), ReorderScheduleRule()], + space_generator=PostOrderApply( + sch_rules=[DoubleScheduleRule(), ReorderScheduleRule()], + postprocs=[], + mutator_probs={}, + ), ) post_order_apply = context.space_generator schs = post_order_apply.generate_design_space(mod) @@ -293,8 +302,11 @@ def test_meta_schedule_post_order_apply_duplicate_matmul(): mod=mod, target=Target("llvm"), task_name="Duplicate Matmul Task", - space_generator=PostOrderApply(), - sch_rules=[WowSoFancyScheduleRule()], + space_generator=PostOrderApply( + sch_rules=[WowSoFancyScheduleRule()], + postprocs=[], + mutator_probs={}, + ), ) post_order_apply = context.space_generator with pytest.raises( @@ -346,8 +358,11 @@ def correct_trace(a, b, c, d): mod=mod, target=Target("llvm"), task_name="Remove Block Task", - space_generator=PostOrderApply(), - sch_rules=[RemoveBlock(), TrinityDoubleRule()], + space_generator=PostOrderApply( + sch_rules=[RemoveBlock(), TrinityDoubleRule()], + postprocs=[], + mutator_probs={}, + ), ) post_order_apply = context.space_generator schs = post_order_apply.generate_design_space(mod) @@ -373,8 +388,11 @@ def test_meta_schedule_custom_search_space(): mod=mod, target=Target("llvm"), task_name="Custom Search Space Task", - space_generator=PostOrderApply(), - sch_rules=[], + space_generator=PostOrderApply( + sch_rules=[], + postprocs=[], + mutator_probs={}, + ), ) post_order_apply = context.space_generator post_order_apply.generate_design_space(mod) @@ -401,8 +419,12 @@ def _get_sch(filter_fn): mod=mod, target=Target("llvm"), task_name="Custom Search Space Task", - space_generator=PostOrderApply(f_block_filter=filter_fn), - sch_rules=[TrinityDoubleRule()], + space_generator=PostOrderApply( + f_block_filter=filter_fn, + sch_rules=[TrinityDoubleRule()], + postprocs=[], + mutator_probs={}, + ), ) post_order_apply = context.space_generator schs = post_order_apply.generate_design_space(mod) diff --git a/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py b/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py index 92c669ca1feb..5dc2500d1b2d 100644 --- 
a/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py +++ b/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py @@ -17,9 +17,8 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm +from tvm import meta_schedule as ms from tvm import tir -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.postproc import DisallowDynamicLoop from tvm.script import tir as T from tvm.target import Target @@ -28,13 +27,17 @@ def _target() -> Target: return Target("cuda", host="llvm") -def _create_context(mod, target) -> TuneContext: - ctx = TuneContext( +def _create_context(mod, target) -> ms.TuneContext: + ctx = ms.TuneContext( mod=mod, target=target, - postprocs=[ - DisallowDynamicLoop(), - ], + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=[ + ms.postproc.DisallowDynamicLoop(), + ], + mutator_probs={}, + ), task_name="test", ) return ctx @@ -83,14 +86,14 @@ def test_postproc_disallow_dynamic_loops(): mod = Matmul ctx = _create_context(mod, target=_target()) sch = tir.Schedule(mod, debug_mask="all") - assert ctx.postprocs[0].apply(sch) + assert ctx.space_generator.postprocs[0].apply(sch) def test_postproc_disallow_dynamic_loops_fail(): mod = DynamicLoop ctx = _create_context(mod, target=_target()) sch = tir.Schedule(mod, debug_mask="all") - assert not ctx.postprocs[0].apply(sch) + assert not ctx.space_generator.postprocs[0].apply(sch) if __name__ == "__main__": diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py index e55f693e72d3..c82bc697c993 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py @@ -18,9 +18,8 @@ import tvm import tvm.testing +from tvm import meta_schedule as ms from tvm import tir -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.postproc import RewriteCooperativeFetch from tvm.meta_schedule.testing import te_workload from tvm.script import tir as T from tvm.target import Target @@ -31,13 +30,17 @@ def _target() -> Target: return Target("cuda", host="llvm") -def _create_context(mod, target) -> TuneContext: - ctx = TuneContext( +def _create_context(mod, target) -> ms.TuneContext: + ctx = ms.TuneContext( mod=mod, target=target, - postprocs=[ - RewriteCooperativeFetch(), - ], + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=[ + ms.postproc.RewriteCooperativeFetch(), + ], + mutator_probs={}, + ), task_name="test", ) return ctx @@ -246,7 +249,7 @@ def test_rewrite_cooperative_fetch(): # pylint: enable=line-too-long,invalid-name # fmt: on sch.enter_postproc() - assert ctx.postprocs[0].apply(sch) + assert ctx.space_generator.postprocs[0].apply(sch) tvm.ir.assert_structural_equal(sch.mod, AfterRewrite0) @@ -291,8 +294,7 @@ def test_rewrite_warp_execution(): # pylint: enable=line-too-long,invalid-name # fmt: on sch.enter_postproc() - assert ctx.postprocs[0].apply(sch) - print(sch.mod["main"].script()) + assert ctx.space_generator.postprocs[0].apply(sch) tvm.ir.assert_structural_equal(sch.mod, WarpExecutionAfterRewrite) diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py index e0ed68b69ce0..91a51c8e9033 100644 --- 
a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py @@ -15,10 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring - import tvm -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.postproc import RewriteLayout +import tvm.testing +from tvm import meta_schedule as ms from tvm.script import tir as T from tvm.target import Target @@ -27,32 +26,41 @@ def _target() -> Target: return Target("cuda", host="llvm") -def _create_context(mod, target) -> TuneContext: - return TuneContext( +def _create_context(mod, target) -> ms.TuneContext: + ctx = ms.TuneContext( mod=mod, target=target, - postprocs=[ - RewriteLayout(), - ], + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=[ + ms.postproc.RewriteLayout(), + ], + mutator_probs={}, + ), task_name="test", ) + return ctx class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): def transform(self): def inner(mod): target = Target("cuda", host="llvm") - ctx = TuneContext( + ctx = ms.TuneContext( mod=mod, target=target, - postprocs=[ - RewriteLayout(), - ], + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=[ + ms.postproc.RewriteLayout(), + ], + mutator_probs={}, + ), task_name="test", ) sch = tvm.tir.Schedule(mod, debug_mask="all") sch.enter_postproc() - assert ctx.postprocs[0].apply(sch) + assert ctx.space_generator.postprocs[0].apply(sch) return sch.mod return inner @@ -147,5 +155,54 @@ def expected(A: T.Buffer[(16, 1), "float32"]): T.evaluate(A_global[vi]) +@T.prim_func +def tir_matmul( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], +) -> None: + T.func_attr({"layout_free_buffers": [1]}) + for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): + with T.block("matmul"): + vi = T.axis.S(16, i0 * 4 + i1) + vj = T.axis.S(16, j) + vk = T.axis.R(16, k0 * 4 + k1) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@T.prim_func +def rewritten_tir_matmul( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], +) -> None: + T.func_attr({"layout_free_buffers": [1]}) + B_reindex = T.alloc_buffer([16, 4, 4], dtype="float32") + for ax0, ax1 in T.grid(16, 16): + with T.block("layout_rewrite"): + i0, i1 = T.axis.remap("SS", [ax0, ax1]) + T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + B_reindex[i1, i0 // 4, i0 % 4] = B[i0, i1] + for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): + with T.block("matmul"): + vi = T.axis.spatial(16, i0 * 4 + i1) + vj = T.axis.spatial(16, j) + vk = T.axis.reduce(16, k0 * 4 + k1) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B_reindex[vj, vk // 4, vk % 4] + + +def test_layout_rewrite(): + target = _target() + ctx = _create_context(tir_matmul, target) + sch = tvm.tir.Schedule(tir_matmul, debug_mask="all") + sch.enter_postproc() + assert ctx.space_generator.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod["main"], rewritten_tir_matmul) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py index 24d1229b3ac6..7e499424058d 100644 --- 
a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py @@ -17,9 +17,8 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm +from tvm import meta_schedule as ms from tvm import tir -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.postproc import RewriteReductionBlock from tvm.script import tir as T from tvm.target import Target @@ -28,13 +27,17 @@ def _target() -> Target: return Target("cuda", host="llvm") -def _create_context(mod, target) -> TuneContext: - ctx = TuneContext( +def _create_context(mod, target) -> ms.TuneContext: + ctx = ms.TuneContext( mod=mod, target=target, - postprocs=[ - RewriteReductionBlock(), - ], + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=[ + ms.postproc.RewriteReductionBlock(), + ], + mutator_probs={}, + ), task_name="test", ) return ctx @@ -200,7 +203,7 @@ def test_rewrite_tiled_matmul(): ctx = _create_context(mod, target) sch = tir.Schedule(mod, debug_mask="all") sch.enter_postproc() - assert ctx.postprocs[0].apply(sch) + assert ctx.space_generator.postprocs[0].apply(sch) tvm.ir.assert_structural_equal(sch.mod, Matmul_after_rewrite) @@ -210,7 +213,7 @@ def test_rewrite_softmax(): ctx = _create_context(mod, target) sch = tir.Schedule(mod, debug_mask="all") sch.enter_postproc() - assert ctx.postprocs[0].apply(sch) + assert ctx.space_generator.postprocs[0].apply(sch) # The module should not be rewritten tvm.ir.assert_structural_equal(sch.mod, Softmax_cross_thread_reduction) diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py index fc624cd5a68f..8f9d287621e2 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py @@ -16,7 +16,7 @@ # under the License. 
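
# Illustrative sketch, not part of the patch: the pattern shared by the postproc
# tests in these files -- postprocs are now configured on the space generator and
# fetched back via ctx.space_generator.postprocs before being applied.
import tvm
from tvm import meta_schedule as ms

def run_first_postproc(mod, target):
    ctx = ms.TuneContext(
        mod=mod,
        target=target,
        space_generator=ms.space_generator.PostOrderApply(
            sch_rules=[],
            postprocs=[ms.postproc.DisallowDynamicLoop()],
            mutator_probs={},
        ),
        task_name="example",
    )
    sch = tvm.tir.Schedule(mod, debug_mask="all")
    sch.enter_postproc()
    # Returns False when the postprocessor rejects the schedule.
    return ctx.space_generator.postprocs[0].apply(sch)
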
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm -from tvm.meta_schedule import TuneContext, postproc +from tvm import meta_schedule as ms from tvm.script import tir as T from tvm.tir.tensor_intrin import arm_cpu, cuda, rocm, x86 @@ -450,11 +450,15 @@ def main( compute[v0, v1] = compute_local[v0, v1] -def _create_context(mod, target, postprocs): - ctx = TuneContext( +def _create_context(mod, target, postprocs) -> ms.TuneContext: + ctx = ms.TuneContext( mod=mod, target=target, - postprocs=postprocs, + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=postprocs, + mutator_probs={}, + ), task_name="test", ) return ctx @@ -467,14 +471,14 @@ def test_rewrite_tensorize_conv2d_nchwc_vnni(): mod, target, [ - postproc.RewriteReductionBlock(), - postproc.RewriteTensorize(True), + ms.postproc.RewriteReductionBlock(), + ms.postproc.RewriteTensorize(True), ], ) sch = tvm.tir.Schedule(mod, debug_mask="all") sch.enter_postproc() - for proc in ctx.postprocs: + for proc in ctx.space_generator.postprocs: proc.apply(sch) tvm.ir.assert_structural_equal(sch.mod, Conv2dNCHWcVNNIModuleTensorized) @@ -487,15 +491,15 @@ def test_rewrite_tensorize_dense_dp4a(): mod, target, [ - postproc.RewriteCooperativeFetch(), - postproc.RewriteReductionBlock(), - postproc.RewriteTensorize(), + ms.postproc.RewriteCooperativeFetch(), + ms.postproc.RewriteReductionBlock(), + ms.postproc.RewriteTensorize(), ], ) sch = tvm.tir.Schedule(mod, debug_mask="all") sch.enter_postproc() - for proc in ctx.postprocs: + for proc in ctx.space_generator.postprocs: proc.apply(sch) tvm.ir.assert_structural_equal(sch.mod, DenseDP4ATensorized) diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py index ebc435a02e8b..b01447ad4a9e 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py @@ -17,25 +17,25 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm +from tvm import meta_schedule as ms from tvm import tir -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.postproc import RewriteUnboundBlock from tvm.script import tir as T from tvm.target import Target -from tvm.tir.schedule.schedule import Schedule def _target() -> Target: return Target("cuda --max_threads_per_block=1024", host="llvm") -def _create_context(mod, target) -> TuneContext: - ctx = TuneContext( +def _create_context(mod, target) -> ms.TuneContext: + ctx = ms.TuneContext( mod=mod, target=target, - postprocs=[ - RewriteUnboundBlock(), - ], + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=[ms.postproc.RewriteUnboundBlock()], + mutator_probs={}, + ), task_name="test", ) return ctx @@ -363,7 +363,7 @@ def test_rewrite_cooperative_fetch(): ctx = _create_context(mod, target) sch = tir.Schedule(mod, debug_mask="all") sch.enter_postproc() - assert ctx.postprocs[0].apply(sch) + assert ctx.space_generator.postprocs[0].apply(sch) tvm.ir.assert_structural_equal(sch.mod, After_cooperative_fetch) @@ -373,7 +373,7 @@ def test_rewrite_norm_bmn(): ctx = _create_context(mod, target) sch = tir.Schedule(mod, debug_mask="all") sch.enter_postproc() - assert ctx.postprocs[0].apply(sch) + assert ctx.space_generator.postprocs[0].apply(sch) tvm.ir.assert_structural_equal(sch.mod, After_norm_bmn) @@ -383,7 
+383,7 @@ def test_rewrite_cuda_loop_split_no_reduction(): ctx = _create_context(mod, target) sch = tir.Schedule(mod, debug_mask="all") sch.enter_postproc() - assert ctx.postprocs[0].apply(sch) + assert ctx.space_generator.postprocs[0].apply(sch) tvm.ir.assert_structural_equal(sch.mod, Bert_fused_reshape_transpose_reshape_after_rub) @@ -393,7 +393,7 @@ def test_rewrite_cuda_loop_split_no_reduction_large(): ctx = _create_context(mod, target) sch = tir.Schedule(mod, debug_mask="all") sch.enter_postproc() - assert ctx.postprocs[0].apply(sch) + assert ctx.space_generator.postprocs[0].apply(sch) tvm.ir.assert_structural_equal(sch.mod, Bert_fused_reshape_transpose_reshape_after_rub_large) @@ -403,7 +403,7 @@ def test_rewrite_cuda_loop_split_for_kind(): ctx = _create_context(mod, target) sch = tir.Schedule(mod, debug_mask="all") sch.enter_postproc() - assert ctx.postprocs[0].apply(sch) + assert ctx.space_generator.postprocs[0].apply(sch) tvm.ir.assert_structural_equal(sch.mod["main"], after_unrolled_loop) diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index e7632561c05c..86a88af40309 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -15,15 +15,11 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring - -import sys - import pytest import tvm import tvm.testing +from tvm import meta_schedule as ms from tvm import tir -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.postproc import VerifyGPUCode from tvm.script import tir as T from tvm.target import Target @@ -32,16 +28,17 @@ def _target() -> Target: return Target("nvidia/geforce-rtx-3080") -def _create_context(mod, target) -> TuneContext: - ctx = TuneContext( +def _create_context(mod, target) -> ms.TuneContext: + return ms.TuneContext( mod=mod, target=target, - postprocs=[ - VerifyGPUCode(), - ], + space_generator=ms.space_generator.PostOrderApply( + sch_rules=[], + postprocs=[ms.postproc.VerifyGPUCode()], + mutator_probs={}, + ), task_name="test", ) - return ctx # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant @@ -786,7 +783,7 @@ def GMMCUDATensorCore( def test_postproc_check_pass(mod): ctx = _create_context(mod, target=_target()) sch = tir.Schedule(mod, debug_mask="all") - assert ctx.postprocs[0].apply(sch) + assert ctx.space_generator.postprocs[0].apply(sch) @pytest.mark.parametrize( @@ -801,7 +798,7 @@ def test_postproc_check_pass(mod): def test_postproc_check_fail(mod): ctx = _create_context(mod, target=_target()) sch = tir.Schedule(mod, debug_mask="all") - assert not ctx.postprocs[0].apply(sch) + assert not ctx.space_generator.postprocs[0].apply(sch) if __name__ == "__main__": diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py similarity index 55% rename from tests/python/unittest/test_meta_schedule_integration.py rename to tests/python/unittest/test_meta_schedule_relay_integration.py index 366a2e4887ed..cf61df0c6ba8 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -23,9 +23,13 @@ from tvm import meta_schedule as ms from tvm import 
relay, te, tir from tvm._ffi import register_func +from tvm.contrib import graph_executor +from tvm.ir.transform import PassContext from tvm.meta_schedule.testing.relay_workload import get_network from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base +from tvm.meta_schedule.tune_context import _normalize_mod from tvm.script import tir as T +from tvm.target import Target # pylint: disable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument,missing-docstring,invalid-name @@ -60,15 +64,14 @@ def test_meta_schedule_dynamic_loop_extent(): a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32") b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC") mod = IRModule({"main": relay.Function([a], b)}) - extracted_tasks = ms.extract_task_from_relay(mod, target="llvm", params={}) + extracted_tasks = ms.relay_integration.extract_tasks(mod, target="llvm", params={}) assert not extracted_tasks -@pytest.mark.xfail(strict=True, reason="See https://github.com/apache/tvm/issues/12732") @requires_torch def test_meta_schedule_integration_extract_from_resnet(): mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) - extracted_tasks = ms.extract_task_from_relay(mod, target="llvm", params=params) + extracted_tasks = ms.relay_integration.extract_tasks(mod, target="llvm", params=params) expected_task_names = [ "fused_" + s for s in [ @@ -186,7 +189,7 @@ def test_meta_schedule_integration_extract_from_bert_base(): ), } mod, params, _ = get_network(name="bert_base", input_shape=[1, 64]) - extracted_tasks = ms.extract_task_from_relay(mod, target="llvm", params=params) + extracted_tasks = ms.relay_integration.extract_tasks(mod, target="llvm", params=params) assert len(extracted_tasks) == len(expected) for t in extracted_tasks: prim_func = None @@ -199,7 +202,6 @@ def test_meta_schedule_integration_extract_from_bert_base(): assert expected_shape == shape, t.task_name -@pytest.mark.xfail(strict=True, reason="See https://github.com/apache/tvm/issues/12732") @requires_torch def test_meta_schedule_integration_extract_from_resnet_with_filter_func(): @register_func("relay.backend.tir_converter.remove_purely_spatial", override=True) @@ -229,11 +231,14 @@ def traverse(t): return create_prim_func(args) mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) - extracted_tasks = ms.extract_task_from_relay( + extracted_tasks = ms.relay_integration.extract_tasks( mod, target="llvm", params=params, - tir_converter="remove_purely_spatial", + pass_config={ + "relay.backend.use_meta_schedule": True, + "relay.backend.tir_converter": "remove_purely_spatial", + }, ) expected_task_names = [ "fused_" + s @@ -266,32 +271,34 @@ def traverse(t): @pytest.mark.skip("Too slow on CI") def extract_task_qbert(): - mod, params, _ = load_quantized_bert_base(batch_size=1, seq_len=128) - target = "llvm -mcpu=cascadelake" - extracted_tasks = ms.extract_task_from_relay(mod, target, params) - tune_tasks = list( - filter( - lambda task: "dense" in task.task_name or "batch_matmul" in task.task_name, - extracted_tasks, + def _test(mod, params, target): + extracted_tasks = ms.relay_integration.extract_tasks(mod, target, params) + tune_tasks = list( + filter( + lambda task: "dense" in task.task_name or "batch_matmul" in task.task_name, + extracted_tasks, + ) ) - ) - # three int8 dense, two int8 bmm, and one fp32 dense - assert len(tune_tasks) == 6 + # three int8 dense, two int8 bmm, and one fp32 dense + assert len(tune_tasks) == 6 + + for task in tune_tasks: + 
relay_func = list(task.mod.functions.values())[0] + out_type = relay_func.body.checked_type - for task in tune_tasks: - relay_func = list(task.mod.functions.values())[0] - out_type = relay_func.body.checked_type + if out_type.dtype == "float32": + continue - if out_type.dtype == "float32": - continue + sch = tvm.tir.Schedule(_normalize_mod(task.dispatched[0])) + block = sch.get_block("compute") + annotations = sch.get(block).annotations - mod = ms.default_config.mod(task.dispatched[0]) - sch = tvm.tir.Schedule(mod) - block = sch.get_block("compute") - annotations = sch.get(block).annotations + assert "schedule_rule" in annotations + assert "vnni" in annotations["schedule_rule"] + ... - assert "schedule_rule" in annotations - assert "vnni" in annotations["schedule_rule"] + mod, params, _ = load_quantized_bert_base(batch_size=1, seq_len=128) + _test(mod, params, target="llvm -mcpu=cascadelake") @tvm.testing.skip_if_32bit(reason="Apparently the LLVM version on i386 image is too old") @@ -322,7 +329,7 @@ def test_extract_task_arm_conv2d_nchwc(): params = {"weight": weight_np, "bias": bias_np} target = "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon" - extracted_tasks = ms.extract_task_from_relay(relay_mod, target, params) + extracted_tasks = ms.relay_integration.extract_tasks(relay_mod, target, params) tune_tasks = list( filter( lambda task: "conv2d" in task.task_name, @@ -339,5 +346,148 @@ def test_extract_task_arm_conv2d_nchwc(): assert list(out_type.shape) == [1, 8, 130, 130, 4] +def test_meta_schedule_te2primfunc_argument_order(): + # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + # fmt: off + @tvm.script.ir_module + class _fused_layout_transform: + @T.prim_func + def main( # type: ignore + placeholder: T.Buffer[(1, 3, 16, 16), "float32"], # type: ignore + T_layout_trans: T.Buffer[(1, 1, 16, 16, 3), "float32"], # type: ignore + ) -> None: # type: ignore + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3): + with T.block("T_layout_trans"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3]) + T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4]) + T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else( + ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, # type: ignore + placeholder[ax0, ax1 * 3 + ax4, ax2, ax3], + T.float32(0), + dtype="float32", + ) + + @tvm.script.ir_module + class _fused_layout_transform_1: + @T.prim_func + def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.Buffer[(1, 8, 16, 16), "float32"]) -> None: # type: ignore + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1, i2, i3 in T.grid(1, 8, 16, 16): + with T.block("T_layout_trans"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore + T.writes(T_layout_trans[ax0, ax1, ax2, ax3]) + T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") # type: ignore + + @tvm.script.ir_module + class _fused_nn_contrib_conv2d_NCHWc: + @T.prim_func + def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"], conv2d_NCHWc: 
T.Buffer[(1, 2, 16, 16, 4), "float32"]) -> None: # type: ignore + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + data_pad = T.alloc_buffer([1, 1, 20, 20, 3], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3): + with T.block("data_pad"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1]) + T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1]) + data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716 + for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5): + with T.block("conv2d_NCHWc"): + n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) + T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) # type: ignore + T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block]) + T.block_attr({"workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 1, 16, 16, 3], "float32"], ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"]}) + with T.init(): + conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0) + conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] # type: ignore + + # fmt: on + # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + def _create_database(): + database = ms.database.create("memory") + + def _commit(mod): + workload = database.commit_workload(mod) + database.commit_tuning_record( + ms.database.TuningRecord( + tir.schedule.Trace([], {}), + workload=workload, + run_secs=[0.1], + ) + ) + + _commit(_fused_layout_transform) + _commit(_fused_layout_transform_1) + _commit(_fused_nn_contrib_conv2d_NCHWc) + return database + + data_shape = (1, 3, 16, 16) + weight_shape = (8, 3, 5, 5) + + def _create_relay_mod(): + data = relay.var("data", relay.TensorType(data_shape, "float32")) + weight = relay.var("weight", relay.TensorType(weight_shape, "float32")) + y = relay.nn.conv2d( + data, + weight, + padding=(2, 2), + kernel_size=(5, 5), + kernel_layout="OIHW", + out_dtype="float32", + ) + f = relay.Function([data, weight], y) + mod = tvm.IRModule.from_expr(f) + mod = relay.transform.InferType()(mod) + return mod + + mod = _create_relay_mod() + dev = tvm.cpu() + target = Target("llvm --num-cores=16") + params = { + "weight": np.random.rand(*weight_shape).astype("float32"), + } + data = tvm.nd.array( + np.random.rand(*data_shape).astype("float32"), + dev, + ) + + with target, _create_database(), PassContext( + opt_level=3, + config={ + "relay.backend.use_meta_schedule": True, + "relay.backend.use_meta_schedule_dispatch": 7, + "relay.backend.tir_converter": "default", + }, + ): + rt_mod1 = relay.build(mod, target=target, params=params) + + # Compile without meta-schedule for correctness check + with tvm.transform.PassContext(opt_level=0): + rt_mod2 = relay.build(mod, target=target, params=params) + + def get_output(data, lib): + module = graph_executor.GraphModule(lib["default"](dev)) + module.set_input("data", data) + module.run() + return module.get_output(0).numpy() + + # Check correctness + 
actual_output = get_output(data, rt_mod1) + expected_output = get_output(data, rt_mod2) + assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py index 70b49944ba0f..7f56683588ba 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py @@ -17,7 +17,10 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from tvm import meta_schedule as ms from tvm.meta_schedule.testing import te_workload -from tvm.meta_schedule.testing.space_generation import check_sketches +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, +) from tvm.script import tir as T from tvm.target import Target from tvm.te import create_prim_func @@ -104,13 +107,12 @@ def cpu_matmul_2( ("SamplePerfectTile", [4, 128]), ] mod = create_prim_func(te_workload.matmul(n=4, m=4, k=512)) - actual = ms.TuneContext( + actual = generate_design_space( + kind="llvm", mod=mod, target=Target("llvm --num-cores=32"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[ms.schedule_rule.AddRFactor()], - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.AddRFactor, + ) check_sketches( mod, sketches=actual, @@ -269,13 +271,12 @@ def argmax_2( ("SamplePerfectTile", [8, 16]), ] mod = argmax - actual = ms.TuneContext( + actual = generate_design_space( + kind="llvm", mod=mod, target=Target("llvm --num-cores=32"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[ms.schedule_rule.AddRFactor()], - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.AddRFactor, + ) check_sketches( mod, sketches=actual, diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py index a50292df7ae3..f0eee4138daa 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py @@ -16,8 +16,10 @@ # under the License. 
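
# Illustrative sketch, not part of the patch: how the generate_design_space helper
# replaces the per-test TuneContext boilerplate in these schedule-rule tests,
# assuming (as its use in the surrounding hunks suggests) that `kind` selects the
# default rule set for a backend and `types` narrows it to the rule under test.
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing import te_workload
from tvm.meta_schedule.testing.space_generation import generate_design_space
from tvm.target import Target
from tvm.te import create_prim_func

mod = create_prim_func(te_workload.matmul(n=4, m=4, k=512))
spaces = generate_design_space(
    kind="llvm",
    mod=mod,
    target=Target("llvm --num-cores=32"),
    types=ms.schedule_rule.AddRFactor,
)
for sch in spaces:
    print(sch.trace)
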
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from tvm import meta_schedule as ms -from tvm.meta_schedule.testing.schedule_rule import get_rules -from tvm.meta_schedule.testing.space_generation import check_sketches +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, +) from tvm.script import tir as T from tvm.target import Target @@ -80,13 +82,12 @@ def elementwise_0( ("SampleCategorical", 5), ] mod = element_wise - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=Target("nvidia/geforce-rtx-3080", host="llvm"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("cuda", ms.schedule_rule.AutoBind), - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.AutoBind, + ) check_sketches( mod, sketches=actual, @@ -114,13 +115,12 @@ def reduction_loop_only_0( C[()] = T.min(C[()], A[k0] / B[k0]) mod = reduction_loop_only - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=Target("nvidia/geforce-rtx-3080", host="llvm"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("cuda", ms.schedule_rule.AutoBind), - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.AutoBind, + ) check_sketches( mod, sketches=actual, @@ -145,13 +145,12 @@ def zero_dim_add_0( C[()] = A[()] + B[()] mod = zero_dim_add - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=Target("nvidia/geforce-rtx-3080", host="llvm"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("cuda", ms.schedule_rule.AutoBind), - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.AutoBind, + ) check_sketches( mod, sketches=actual, diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py index c0801c9d7b5e..c17209e2cb77 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -17,7 +17,7 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm from tvm import meta_schedule as ms -from tvm.meta_schedule.testing.schedule_rule import get_rules +from tvm.meta_schedule.testing.space_generation import generate_design_space from tvm.script import tir as T from tvm.target import Target @@ -338,74 +338,63 @@ def main(T_full: T.Buffer[(1, 12, 4096), "int64"]) -> None: # fmt: on -def _create_context(mod, target, rule): - ctx = ms.TuneContext( - mod=mod, - target=target, - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[rule], - task_name="test", - ) - return ctx - - def test_inline_consumer_chain(): mod = Conv2DBiasBnReLU target = Target("llvm") - ctx = _create_context( + (space,) = generate_design_space( + kind="llvm", mod=mod, target=target, - rule=get_rules("llvm", ms.schedule_rule.AutoInline)[0], + types=ms.schedule_rule.AutoInline, ) - (space,) = ctx.space_generator.generate_design_space(mod=mod) tvm.ir.assert_structural_equal(lhs=space.mod, rhs=Conv2DBiasBnReLUInlined) def test_inline_into_cache(): mod = MultiLevelTiledConv2D target = Target("cuda", host="llvm") - ctx = _create_context( + (space,) = generate_design_space( + kind="cuda", mod=mod, target=target, - rule=get_rules("cuda", ms.schedule_rule.AutoInline)[0], + 
types=ms.schedule_rule.AutoInline, ) - (space,) = ctx.space_generator.generate_design_space(mod=mod) tvm.ir.assert_structural_equal(lhs=space.mod, rhs=MultiLevelTiledConv2DAfterInline) def test_inline_into_multiple_consumers(): mod = SoftmaxBeforeInline target = Target("cuda", host="llvm") - ctx = _create_context( + (space,) = generate_design_space( + kind="cuda", mod=mod, target=target, - rule=get_rules("cuda", ms.schedule_rule.AutoInline)[0], + types=ms.schedule_rule.AutoInline, ) - (space,) = ctx.space_generator.generate_design_space(mod=mod) tvm.ir.assert_structural_equal(lhs=space.mod, rhs=SoftmaxAfterInline) def test_inline_pure_spatial(): mod = BeforePureSpatial target = Target("llvm") - ctx = _create_context( + (space,) = generate_design_space( + kind="llvm", mod=mod, target=target, - rule=get_rules("llvm", ms.schedule_rule.AutoInline)[0], + types=ms.schedule_rule.AutoInline, ) - (space,) = ctx.space_generator.generate_design_space(mod=mod) tvm.ir.assert_structural_equal(lhs=space.mod, rhs=AfterPureSpatial) def test_inline_constant_tensor(): mod = ConstConsumer target = Target("cuda", host="llvm") - ctx = _create_context( + (space,) = generate_design_space( + kind="cuda", mod=mod, target=target, - rule=get_rules("cuda", ms.schedule_rule.AutoInline)[0], + types=ms.schedule_rule.AutoInline, ) - (space,) = ctx.space_generator.generate_design_space(mod=mod) tvm.ir.assert_structural_equal(lhs=space.mod, rhs=ConstConsumer) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py index 718b264bddd2..c851c9bec3b5 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -19,8 +19,10 @@ import tvm from tvm import meta_schedule as ms from tvm.meta_schedule.testing import te_workload -from tvm.meta_schedule.testing.schedule_rule import get_rules -from tvm.meta_schedule.testing.space_generation import check_sketches +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, +) from tvm.script import tir as T from tvm.target import Target from tvm.te import create_prim_func @@ -280,13 +282,12 @@ def softmax_mn_3( ("SampleCategorical", 7), ] mod = create_prim_func(te_workload.softmax_mn(n=256, m=256)) - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=Target("nvidia/geforce-rtx-3090", host="llvm"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction), - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.CrossThreadReduction, + ) check_sketches( mod, sketches=actual, @@ -476,13 +477,12 @@ def softmax_mn_after_inline_3( ] mod = Softmax_mn_after_inline - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=Target("nvidia/geforce-rtx-3090", host="llvm"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction), - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.CrossThreadReduction, + ) check_sketches( mod, sketches=actual, @@ -552,13 +552,12 @@ def batch_norm_bmn_1(A: T.Buffer[(1, 512, 512), "float32"], D: T.Buffer[1, "floa ] mod = create_prim_func(te_workload.norm_bmn(B=1, M=512, N=512)) - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", 
mod=mod, target=Target("nvidia/geforce-rtx-3090", host="llvm"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction), - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.CrossThreadReduction, + ) check_sketches( mod, sketches=actual, @@ -670,13 +669,12 @@ def argmax_1( ] mod = argmax - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=Target("nvidia/geforce-rtx-3090", host="llvm"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction), - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.CrossThreadReduction, + ) check_sketches( mod, sketches=actual, @@ -745,13 +743,12 @@ def argmax_1( ] mod = argmax_32 - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=Target("nvidia/geforce-rtx-3090", host="llvm"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction), - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.CrossThreadReduction, + ) check_sketches( mod, sketches=actual, diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py index d9d078106333..28e6f295e78f 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py @@ -16,10 +16,12 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from tvm import meta_schedule as ms -from tvm import te, target +from tvm import target, te from tvm.meta_schedule.testing import te_workload -from tvm.meta_schedule.testing.schedule_rule import get_rules -from tvm.meta_schedule.testing.space_generation import check_sketches +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, +) from tvm.script import tir as T from tvm.target import Target @@ -128,13 +130,12 @@ def cpu_matmul_2( ] mod = te.create_prim_func(te_workload.matmul(512, 512, 512)) - actual = ms.TuneContext( + actual = generate_design_space( + kind="llvm", mod=mod, target=Target("llvm"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("llvm", ms.schedule_rule.MultiLevelTiling), - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.MultiLevelTiling, + ) check_sketches( mod, sketches=actual, @@ -253,13 +254,12 @@ def cpu_matmul_relu_2( ("SamplePerfectTile", [64, 8]), ] mod = te.create_prim_func(te_workload.matmul_relu(512, 512, 512)) - actual = ms.TuneContext( + actual = generate_design_space( + kind="llvm", mod=mod, target=Target("llvm"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("llvm", ms.schedule_rule.MultiLevelTiling), - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.MultiLevelTiling, + ) check_sketches( mod, sketches=actual, @@ -360,13 +360,12 @@ def cuda_matmul_0( ("SampleCategorical", 0), ] mod = te.create_prim_func(te_workload.matmul(512, 512, 512)) - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=Target("nvidia/geforce-rtx-3080"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("cuda", ms.schedule_rule.MultiLevelTiling), - task_name="test", - ).generate_design_space() + 
types=ms.schedule_rule.MultiLevelTiling, + ) check_sketches( mod, sketches=actual, @@ -479,13 +478,12 @@ def cuda_matmul_relu_0( ("SampleCategorical", 3), ] mod = te.create_prim_func(te_workload.matmul_relu(512, 512, 512)) - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=Target("nvidia/geforce-rtx-3080"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("cuda", ms.schedule_rule.MultiLevelTiling), - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.MultiLevelTiling, + ) check_sketches( mod, sketches=actual, @@ -511,13 +509,12 @@ def sum_with_trivial_block_iter( # Expect nothing to happen - the rule is not supposed to be applied in this case mod = sum_with_trivial_block_iter - (sch,) = ms.TuneContext( + (sch,) = generate_design_space( + kind="cuda", mod=mod, target=Target("nvidia/geforce-rtx-3080"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=get_rules("cuda", ms.schedule_rule.MultiLevelTiling), - task_name="test", - ).generate_design_space() + types=ms.schedule_rule.MultiLevelTiling, + ) assert not sch.trace.simplified(remove_postproc=True).insts @@ -593,10 +590,11 @@ def cpu_conv2d_nhwc( te_workload.conv2d_nhwc(1, H, W, I, O, 3, 1, 1, 1, in_dtype="float16", out_dtype="float16") ) - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=Target(target_hexagon, host=target_hexagon), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[ ms.schedule_rule.MultiLevelTilingWideVector( structure="SRSRS", @@ -606,8 +604,7 @@ def cpu_conv2d_nhwc( reuse_write=None, ) ], - task_name="test", - ).generate_design_space() + ) decision_0 = [ ("SamplePerfectTile", [1, 1, 1]), diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py index 38ddb137e108..e70f7cb2c618 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py @@ -18,7 +18,10 @@ from tvm import meta_schedule as ms from tvm import te from tvm.ir import assert_structural_equal -from tvm.meta_schedule.testing.space_generation import check_sketches +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, +) from tvm.script import tir as T from tvm.target import Target from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN @@ -226,10 +229,11 @@ def vnni_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac mod = conv2d_nchwc target = Target("llvm -mcpu=cascadelake -num-cores=4") - actual = ms.TuneContext( + actual = generate_design_space( + kind="llvm", mod=mod, target=Target(target), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[ ms.schedule_rule.MultiLevelTilingWithIntrin( VNNI_INTRIN, @@ -241,7 +245,7 @@ def vnni_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], plac reuse_write=ms.schedule_rule.ReuseType(req="may", levels=[1, 2], scope="global"), ), ], - ).generate_design_space() + ) check_sketches( mod, sketches=actual, @@ -266,10 +270,11 @@ def _dense(m, n, k, in_dtype, out_dtype): return te.create_prim_func([X, W, matmul]) mod = _dense(m, n, k, in_dtype, out_dtype) - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=Target("cuda"), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[ 
ms.schedule_rule.MultiLevelTilingWithIntrin( DP4A_INTRIN, @@ -281,7 +286,7 @@ def _dense(m, n, k, in_dtype, out_dtype): reuse_write=ms.schedule_rule.ReuseType(req="must", levels=[3], scope="local"), ) ], - ).generate_design_space() + ) if expected_mods is None: assert expected_decisions is None assert len(actual) == 1 diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index a53c1062b98d..0e4bd6bf302a 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -20,8 +20,11 @@ from tvm import meta_schedule as ms from tvm import te from tvm.meta_schedule.testing import te_workload -from tvm.meta_schedule.testing.schedule_rule import get_rules -from tvm.meta_schedule.testing.space_generation import check_sketches +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, + get_rules, +) from tvm.script import tir as T from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group @@ -186,13 +189,16 @@ def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "f out_dtype="float32", ) ) - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=tvm.target.Target("cuda"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[multi_level_tiling_tensor_core()] - + get_rules("cuda", ms.schedule_rule.AutoInline), - ).generate_design_space() + types=None, + sch_rules=[ + multi_level_tiling_tensor_core(), + ] + + get_rules(kind="cuda", types=ms.schedule_rule.AutoInline), + ) check_sketches( mod, sketches=actual, @@ -324,10 +330,11 @@ def matmul_relu_fallback_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, out_dtype="float32", ) ) - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=tvm.target.Target("cuda"), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[ multi_level_tiling_tensor_core(), ] @@ -338,7 +345,7 @@ def matmul_relu_fallback_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, ms.schedule_rule.AutoInline, ), ), - ).generate_design_space() + ) check_sketches( mod, sketches=actual, @@ -475,12 +482,15 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, out_dtype="float32", ) ) - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=tvm.target.Target("cuda"), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[multi_level_tiling_tensor_core()], - ).generate_design_space() + types=None, + sch_rules=[ + multi_level_tiling_tensor_core(), + ], + ) check_sketches( mod, sketches=actual, @@ -490,17 +500,18 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, # Test adding inapplicable tensor intrinsics doesn't change the search space # This test case uses the same workload, decision and the expected sketch as above - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=tvm.target.Target("cuda"), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[ multi_level_tiling_tensor_core( in_dtype="float16", out_dtype=["float16", "float32"], ), ], - ).generate_design_space() + ) check_sketches( mod, sketches=actual, @@ -638,16 +649,17 @@ def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, out_dtype="float32", ) ) - actual = ms.TuneContext( 
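Two calling conventions show up in the intrinsic-specific tests: `types=None` with an explicit `sch_rules` list when a single bespoke rule is under test, and a mix of a hand-built rule with defaults fetched via `get_rules`. A trimmed example of the mixing pattern used throughout the tensor-core tests in this file; `multi_level_tiling_tensor_core` is the factory defined locally in that test module, and `mod` stands for a workload such as the float16 matmul_relu used there:

```python
import tvm
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.space_generation import (
    generate_design_space,
    get_rules,
)

actual = generate_design_space(
    kind="cuda",
    mod=mod,  # e.g. a float16 matmul_relu workload, as in the tests in this file
    target=tvm.target.Target("cuda"),
    types=None,  # do not look up the defaults wholesale ...
    sch_rules=[
        multi_level_tiling_tensor_core(),  # the one rule under test
    ]
    # ... but still pull in the default AutoInline rules for CUDA:
    + get_rules(kind="cuda", types=ms.schedule_rule.AutoInline),
)
```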
+ actual = generate_design_space( + kind="cuda", mod=mod, target=tvm.target.Target("cuda"), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[ multi_level_tiling_tensor_core( use_software_pipeline=True, ), ], - ).generate_design_space() + ) check_sketches( mod, sketches=actual, @@ -775,13 +787,14 @@ def matmul_relu_global_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 1 out_dtype="float32", ) ) - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=tvm.target.Target("cuda"), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")] + get_rules("cuda", ms.schedule_rule.AutoInline), - ).generate_design_space() + ) check_sketches( mod, sketches=actual, @@ -799,13 +812,14 @@ def test_matmul_relu_non_tensorizable(): k=128, ) ) - (sch,) = ms.TuneContext( + (sch,) = generate_design_space( + kind="cuda", mod=mod, target=tvm.target.Target("cuda"), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")] + get_rules("cuda", ms.schedule_rule.AutoInline), - ).generate_design_space() + ) tvm.ir.assert_structural_equal(mod, sch.mod["main"]) @@ -934,13 +948,14 @@ def padded_matmul_relu_0(A: T.Buffer[(127, 127), "float16"], B: T.Buffer[(127, 1 out_dtype="float32", ) ) - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=tvm.target.Target("cuda"), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + get_rules("cuda", ms.schedule_rule.AutoInline), - ).generate_design_space() + ) check_sketches( mod, sketches=actual, @@ -1083,13 +1098,14 @@ def conv2d_1x1_0(inputs: T.Buffer[(1, 16, 16, 64), "float16"], weight: T.Buffer[ out_dtype="float32", ) ) - actual = ms.TuneContext( + actual = generate_design_space( + kind="cuda", mod=mod, target=tvm.target.Target("cuda"), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + get_rules("cuda", ms.schedule_rule.AutoInline), - ).generate_design_space() + ) check_sketches( mod, sketches=actual, diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py index 8076fcaa8bd4..520dfbfb1cc5 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py @@ -17,7 +17,10 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm from tvm import meta_schedule as ms -from tvm.meta_schedule.testing.space_generation import check_sketches +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, +) from tvm.script import tir as T from tvm.target import Target @@ -252,10 +255,11 @@ def Matmul_0( ] mod = Matmul - actual = ms.TuneContext( + actual = generate_design_space( + kind="llvm", mod=mod, target=Target("llvm --num-cores=32"), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[ ms.schedule_rule.ParallelizeVectorizeUnroll( max_jobs_per_core=16, @@ -264,8 +268,7 @@ def Matmul_0( unroll_explicit=True, ), ], - task_name="test", - ).generate_design_space() + ) check_sketches( 
mod, sketches=actual, @@ -276,10 +279,11 @@ def Matmul_0( def test_parallel_vectorize_unroll_spatial(): mod = PureSpatial - actual = ms.TuneContext( + actual = generate_design_space( + kind="llvm", mod=mod, target=Target("llvm --num-cores=32"), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[ ms.schedule_rule.ParallelizeVectorizeUnroll( max_jobs_per_core=-1, @@ -288,8 +292,7 @@ def test_parallel_vectorize_unroll_spatial(): unroll_explicit=True, ), ], - task_name="test", - ).generate_design_space() + ) assert len(actual) == 1 trace = actual[0].trace.simplified(remove_postproc=True) assert not trace.insts diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py index fc52aa199cc1..7c9433cedf50 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py @@ -17,7 +17,10 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm from tvm import meta_schedule as ms -from tvm.meta_schedule.testing.space_generation import check_sketches +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, +) from tvm.script import tir as T from tvm.target import Target @@ -87,13 +90,13 @@ def add_0( ] mod = Add - actual = ms.TuneContext( + actual = generate_design_space( + kind="llvm", mod=mod, target=Target("llvm"), - space_generator=ms.space_generator.PostOrderApply(), + types=None, sch_rules=[ms.schedule_rule.RandomComputeLocation()], - task_name="test", - ).generate_design_space() + ) check_sketches( mod, sketches=actual, diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 7433f001c0eb..e34554420600 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -84,14 +84,16 @@ def test_meta_schedule_replay_func( context = ms.TuneContext( mod=Matmul, - space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=TestClass( - num_trials_per_iter=num_trials_per_iter, max_trials_per_task=max_trials_per_task - ), + space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul, postprocs=[]), + search_strategy=TestClass(), ) strategy = context.search_strategy spaces = context.space_generator.generate_design_space(context.mod) - strategy.pre_tuning(spaces) + strategy.pre_tuning( + max_trials=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + design_spaces=spaces, + ) (correct_sch,) = ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul).generate_design_space( Matmul ) @@ -135,10 +137,13 @@ def _schedule_matmul_small(sch: Schedule): mod=Matmul, space_generator=ms.space_generator.ScheduleFn( sch_fn=_schedule_matmul_small, + sch_rules=[], + postprocs=[], + mutator_probs={ + DummyMutator(): 1.0, + }, ), search_strategy=ms.search_strategy.EvolutionarySearch( - num_trials_per_iter=num_trials_per_iter, - max_trials_per_task=max_trials_per_task, population_size=5, init_measured_ratio=0.1, init_min_unmeasured=50, @@ -147,15 +152,14 @@ def _schedule_matmul_small(sch: Schedule): genetic_max_fail_count=10, eps_greedy=0.9, ), - mutator_probs={ - DummyMutator(): 1.0, - }, target=tvm.target.Target("llvm"), num_threads=1, # because we are using a 
mutator from the python side ) strategy = context.search_strategy strategy.pre_tuning( - context.space_generator.generate_design_space(context.mod), + max_trials=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + design_spaces=context.space_generator.generate_design_space(context.mod), database=ms.database.MemoryDatabase(), cost_model=ms.cost_model.RandomModel(), ) @@ -197,8 +201,6 @@ def _schedule_matmul_empty(sch: Schedule): context = ms.TuneContext( mod=Matmul, search_strategy=ms.search_strategy.EvolutionarySearch( - num_trials_per_iter=num_trials_per_iter, - max_trials_per_task=max_trials_per_task, population_size=5, init_measured_ratio=0.1, init_min_unmeasured=50, @@ -209,16 +211,20 @@ def _schedule_matmul_empty(sch: Schedule): ), space_generator=ms.space_generator.ScheduleFn( sch_fn=_schedule_matmul_empty, + sch_rules=[], + postprocs=[], + mutator_probs={ + DummyMutator(): 1.0, + }, ), - mutator_probs={ - DummyMutator(): 1.0, - }, target=tvm.target.Target("llvm"), num_threads=1, ) strategy = context.search_strategy strategy.pre_tuning( - context.space_generator.generate_design_space(context.mod), + max_trials=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + design_spaces=context.space_generator.generate_design_space(context.mod), database=ms.database.MemoryDatabase(), cost_model=ms.cost_model.RandomModel(), ) @@ -246,4 +252,7 @@ def _schedule_matmul_empty(sch: Schedule): if __name__ == "__main__": - tvm.testing.main() + test_meta_schedule_replay_func(ms.search_strategy.ReplayFunc) + test_meta_schedule_replay_func(ms.search_strategy.ReplayTrace) + test_meta_schedule_evolutionary_search() + test_meta_schedule_evolutionary_search_early_stop() diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py index 25dc14fd5cb7..47f3e6d4cc51 100644 --- a/tests/python/unittest/test_meta_schedule_space_cpu.py +++ b/tests/python/unittest/test_meta_schedule_space_cpu.py @@ -16,7 +16,11 @@ # under the License. 
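With the trial budgets moved off the strategy constructors, a search strategy is now driven as below. This is a sketch assembled from the test changes above; `context` is a `ms.TuneContext` as constructed in those tests, and `_fake_measure` is a hypothetical stand-in for the builder/runner round trip:

```python
from tvm import meta_schedule as ms


def _fake_measure(candidates):
    # Pretend every candidate ran successfully in 0.1s.
    return [ms.runner.RunnerResult(run_secs=[0.1], error_msg=None) for _ in candidates]


strategy = context.search_strategy
strategy.pre_tuning(
    max_trials=32,
    num_trials_per_iter=8,
    design_spaces=context.space_generator.generate_design_space(context.mod),
    database=ms.database.MemoryDatabase(),
    cost_model=ms.cost_model.RandomModel(),
)
while True:
    candidates = strategy.generate_measure_candidates()
    if candidates is None:  # trial budget exhausted
        break
    strategy.notify_runner_results(candidates, _fake_measure(candidates))
strategy.post_tuning()
```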
"""Tests for MetaSchedule search space on CPU""" from tvm import meta_schedule as ms -from tvm.meta_schedule.testing.space_generation import check_sketches, print_sketches +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + print_sketches, + generate_design_space, +) from tvm.meta_schedule.testing.te_workload import create_te_workload from tvm.script import tir as T from tvm.target import Target @@ -26,6 +30,15 @@ def _target(): return Target("aws/cpu/c5.9xlarge") +def _design_space(mod): + return generate_design_space( + kind="llvm", + mod=mod, + target=_target(), + types=ms.ScheduleRule, + ) + + def test_cpu_c1d(): # fmt: off @T.prim_func @@ -161,12 +174,7 @@ def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12 ] mod = create_te_workload("C1D", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -337,12 +345,7 @@ def c2d_2(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, ] mod = create_te_workload("C2D", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -534,12 +537,7 @@ def c3d_2(inputs: T.Buffer[(1, 16, 224, 224, 3), "float32"], weight: T.Buffer[(7 ] mod = create_te_workload("C3D", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -727,12 +725,7 @@ def cap_2(inputs: T.Buffer[(1, 16, 16, 4, 4, 32), "float32"], weight: T.Buffer[( ("SampleComputeLocation", -1), ] mod = create_te_workload("CAP", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -887,12 +880,7 @@ def dep_2(placeholder: T.Buffer[(1, 112, 112, 32), "float32"], placeholder_1: T. 
("SampleComputeLocation", 5), ] mod = create_te_workload("DEP", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -1065,12 +1053,7 @@ def dil_2(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, ("SampleComputeLocation", 1), ] mod = create_te_workload("DIL", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -1187,12 +1170,7 @@ def gmm_2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "flo ("SampleCategorical", 1), ] mod = create_te_workload("GMM", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -1361,12 +1339,7 @@ def grp_2(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: T.Buffer[(3, 3, ("SampleComputeLocation", 9), ] mod = create_te_workload("GRP", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -1521,12 +1494,7 @@ def t2d_2(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 5 ("SampleComputeLocation", -2), ] mod = create_te_workload("T2D", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -1646,12 +1614,7 @@ def nrm_2(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[1, "float32"]) -> N ("SampleComputeLocation", -1), ] mod = create_te_workload("NRM", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -2220,12 +2183,7 @@ def sfm_8(A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256 ("SampleComputeLocation", 0), ] mod = create_te_workload("SFM", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -2404,12 +2362,7 @@ def cbr_2(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3 ("SampleComputeLocation", 1), ] mod = create_te_workload("CBR", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -2588,12 +2541,7 @@ def tbg_2(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128, ("SampleComputeLocation", -2), ] mod = create_te_workload("TBG", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, diff --git 
a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py index ffa2b57ba8ec..f0f6e91ea655 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda.py @@ -15,9 +15,14 @@ # specific language governing permissions and limitations # under the License. """Tests for MetaSchedule search space on CUDA""" -from tvm import te, topi, autotvm +from tvm import autotvm from tvm import meta_schedule as ms -from tvm.meta_schedule.testing.space_generation import check_sketches, print_sketches +from tvm import te, topi +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, + print_sketches, +) from tvm.meta_schedule.testing.te_workload import create_te_workload from tvm.script import tir as T from tvm.target import Target @@ -27,6 +32,15 @@ def _target(): return Target("nvidia/geforce-rtx-3070") +def _design_space(mod): + return generate_design_space( + kind="cuda", + mod=mod, + target=_target(), + types=ms.ScheduleRule, + ) + + def _conv2d_winograd_nchw(): data = te.placeholder((1, 64, 224, 224), name="data", dtype="float32") kernel = te.placeholder((6, 6, 64, 64), name="kernel", dtype="float32") @@ -119,12 +133,7 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12 ] mod = create_te_workload("C1D", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -208,12 +217,7 @@ def c2d_0(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, ] mod = create_te_workload("C2D", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -303,12 +307,7 @@ def c3d_0(inputs: T.Buffer[(1, 16, 224, 224, 3), "float32"], weight: T.Buffer[(7 ("SampleCategorical", 1), ] mod = create_te_workload("C3D", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -404,12 +403,7 @@ def cap_0(inputs: T.Buffer[(1, 16, 16, 4, 4, 32), "float32"], weight: T.Buffer[( ("SampleCategorical", 2), ] mod = create_te_workload("CAP", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -492,12 +486,7 @@ def dep_0(placeholder: T.Buffer[(1, 112, 112, 32), "float32"], placeholder_1: T. 
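The CUDA space tests get the same `_design_space` shorthand, differing only in the `kind` and target. When the expected sketches in these files need regenerating, the `print_sketches` utility imported above is handy for dumping each schedule for manual inspection, e.g. (a sketch; the workload name is one used in this file):

```python
mod = create_te_workload("GMM", 0)
actual = _design_space(mod)  # kind="cuda", target=Target("nvidia/geforce-rtx-3070")
print_sketches(actual)  # dump each generated sketch for inspection
```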
("SampleCategorical", 1), ] mod = create_te_workload("DEP", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -580,12 +569,7 @@ def dil_0(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, ("SampleCategorical", 3), ] mod = create_te_workload("DIL", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -661,12 +645,7 @@ def gmm_0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "flo ("SampleCategorical", 4), ] mod = create_te_workload("GMM", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -750,12 +729,7 @@ def grp_0(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: T.Buffer[(3, 3, ("SampleCategorical", 1), ] mod = create_te_workload("GRP", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -840,12 +814,7 @@ def t2d_0(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 5 ("SampleCategorical", 2), ] mod = create_te_workload("T2D", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -923,12 +892,7 @@ def nrm_1(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[1, "float32"]) -> N ("SampleCategorical", 4), ] mod = create_te_workload("NRM", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -1135,12 +1099,7 @@ def sfm_3(A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256 ("SampleCategorical", 0), ] mod = create_te_workload("SFM", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -1225,12 +1184,7 @@ def cbr_0(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3 ("SampleCategorical", 3), ] mod = create_te_workload("CBR", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -1309,12 +1263,7 @@ def tbg_0(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128, ("SampleCategorical", 4), ] mod = create_te_workload("TBG", 0) - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, @@ -1459,12 +1408,7 @@ def winograd_nchw_conv2d(data: T.Buffer[(1, 64, 224, 
224), "float32"], kernel: T ("SampleCategorical", 4), ] mod = _conv2d_winograd_nchw() - actual = ms.TuneContext( - mod=mod, - target=_target(), - space_generator=ms.space_generator.PostOrderApply(), - sch_rules="default", - ).generate_design_space() + actual = _design_space(mod) check_sketches( mod, sketches=actual, diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 9201fe16e849..ef2be381c694 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -18,7 +18,6 @@ # pylint: disable=missing-function-docstring import math -import sys import pytest import tvm @@ -94,7 +93,11 @@ def test_meta_schedule_design_space_generator_union(): def test_meta_schedule_design_space_generator_NIE(): @derived_object class TestPySpaceGenerator(PySpaceGenerator): - pass + def __init__(self): + super().__init__() + self.sch_rules = [] + self.postprocs = [] + self.mutator_probs = {} with pytest.raises( TVMError, match="PySpaceGenerator's InitializeWithTuneContext method not implemented!" diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index 3edd81ee9a11..33a019e3c555 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """ Test Meta Schedule Task Scheduler """ - import random import weakref from typing import Set @@ -23,19 +22,11 @@ import pytest import tvm import tvm.testing -from tvm.support import libinfo from tvm import meta_schedule as ms -from tvm._ffi.base import TVMError from tvm.meta_schedule.testing.dummy_object import DummyBuilder, DummyRunner from tvm.script import tir as T from tvm.tir import Schedule -# from tvm.meta_schedule import TuneContext, measure_callback -# from tvm.meta_schedule.search_strategy import ReplayTrace -# from tvm.meta_schedule.space_generator import ScheduleFn -# from tvm.meta_schedule.task_scheduler import GradientBased, PyTaskScheduler, RoundRobin -# from tvm.meta_schedule.utils import derived_object - # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @@ -131,9 +122,10 @@ class MyTaskScheduler(ms.task_scheduler.PyTaskScheduler): done: Set = set() def next_task_id(self) -> int: - while len(self.done) != len(self.tasks): - x = random.randint(0, len(self.tasks) - 1) - task = self.tasks[x] + tasks = self._outer().tasks_ + while len(self.done) != len(tasks): + x = random.randint(0, len(tasks) - 1) + task = tasks[x] if not task.is_terminated: """Calling base func via following route: Python side: @@ -157,28 +149,28 @@ def test_meta_schedule_task_scheduler_single(): num_trials_per_iter = 3 max_trials_per_task = 10 database = ms.database.MemoryDatabase() - round_robin = ms.task_scheduler.RoundRobin( + round_robin = ms.task_scheduler.RoundRobin() + round_robin.tune( [ ms.TuneContext( MatmulModule, target=tvm.target.Target("llvm"), - space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ms.search_strategy.ReplayTrace( - num_trials_per_iter, - max_trials_per_task, - ), + space_generator=_schedule_matmul, + search_strategy=ms.search_strategy.ReplayTrace(), task_name="Test", rand_state=42, ) ], [1.0], + max_trials_global=num_trials_per_iter, + 
max_trials_per_task=max_trials_per_task, + num_trials_per_iter=64, builder=DummyBuilder(), runner=DummyRunner(), database=database, measure_callbacks=[ms.measure_callback.AddToDatabase()], - max_trials=max_trials_per_task, + cost_model=None, ) - round_robin.tune() assert len(database) == max_trials_per_task @@ -189,48 +181,42 @@ def test_meta_schedule_task_scheduler_multiple(): ms.TuneContext( MatmulModule, target=tvm.target.Target("llvm"), - space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ms.search_strategy.ReplayTrace( - num_trials_per_iter, - max_trials_per_task, - ), + space_generator=_schedule_matmul, + search_strategy=ms.search_strategy.ReplayTrace(), task_name="Matmul", rand_state=42, ), ms.TuneContext( MatmulReluModule, target=tvm.target.Target("llvm"), - space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ms.search_strategy.ReplayTrace( - num_trials_per_iter, - max_trials_per_task, - ), + space_generator=_schedule_matmul, + search_strategy=ms.search_strategy.ReplayTrace(), task_name="MatmulRelu", rand_state=0xDEADBEEF, ), ms.TuneContext( BatchMatmulModule, target=tvm.target.Target("llvm"), - space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_batch_matmul), - search_strategy=ms.search_strategy.ReplayTrace( - num_trials_per_iter, - max_trials_per_task, - ), + space_generator=_schedule_batch_matmul, + search_strategy=ms.search_strategy.ReplayTrace(), task_name="BatchMatmul", rand_state=0x114514, ), ] database = ms.database.MemoryDatabase() - round_robin = ms.task_scheduler.RoundRobin( + round_robin = ms.task_scheduler.RoundRobin() + round_robin.tune( tasks, [1.0, 1.0, 1.0], builder=DummyBuilder(), runner=DummyRunner(), database=database, measure_callbacks=[ms.measure_callback.AddToDatabase()], - max_trials=max_trials_per_task * len(tasks), + max_trials_global=max_trials_per_task * len(tasks), + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + cost_model=None, ) - round_robin.tune() assert len(database) == max_trials_per_task * len(tasks) for task in tasks: assert ( @@ -249,82 +235,60 @@ def test_meta_schedule_task_scheduler_NIE(): # pylint: disable=invalid-name class NIETaskScheduler(ms.task_scheduler.PyTaskScheduler): pass - with pytest.raises(TVMError, match="PyTaskScheduler's NextTaskId method not implemented!"): - scheduler = NIETaskScheduler( - tasks=[], - builder=DummyBuilder(), - runner=DummyRunner(), - database=ms.database.MemoryDatabase(), - max_trials=1, - ) + with pytest.raises(ValueError, match="next_task_id is not defined"): + scheduler = NIETaskScheduler() scheduler.next_task_id() def test_meta_schedule_task_scheduler_avoid_cyclic(): # pylint: disable=invalid-name - database = ms.database.MemoryDatabase() - scheduler = MyTaskScheduler( - [], - builder=DummyBuilder(), - runner=DummyRunner(), - database=database, - measure_callbacks=[ - ms.measure_callback.AddToDatabase(), - ], - max_trials=10, - ) + scheduler = MyTaskScheduler() test = weakref.ref(scheduler) # test if it can be destructed successfully del scheduler assert test() is None def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: disable=invalid-name - num_trials_per_iter = 6 max_trials_per_task = 101 tasks = [ ms.TuneContext( MatmulModule, target=tvm.target.Target("llvm"), - space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ms.search_strategy.ReplayTrace( - num_trials_per_iter, - max_trials_per_task, - ), + 
space_generator=_schedule_matmul, + search_strategy=ms.search_strategy.ReplayTrace(), task_name="Matmul", rand_state=42, ), ms.TuneContext( MatmulReluModule, target=tvm.target.Target("llvm"), - space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ms.search_strategy.ReplayTrace( - num_trials_per_iter, - max_trials_per_task, - ), + space_generator=_schedule_matmul, + search_strategy=ms.search_strategy.ReplayTrace(), task_name="MatmulRelu", rand_state=0xDEADBEEF, ), ms.TuneContext( BatchMatmulModule, target=tvm.target.Target("llvm"), - space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_batch_matmul), - search_strategy=ms.search_strategy.ReplayTrace( - num_trials_per_iter, - max_trials_per_task, - ), + space_generator=_schedule_batch_matmul, + search_strategy=ms.search_strategy.ReplayTrace(), task_name="BatchMatmul", rand_state=0x114514, ), ] database = ms.database.MemoryDatabase() - scheduler = MyTaskScheduler( + scheduler = MyTaskScheduler() + scheduler.tune( tasks, + task_weights=[1.0] * len(tasks), builder=DummyBuilder(), runner=DummyRunner(), database=database, measure_callbacks=[ms.measure_callback.AddToDatabase()], - max_trials=max_trials_per_task * len(tasks), + max_trials_global=max_trials_per_task * len(tasks), + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=6, + cost_model=None, ) - scheduler.tune() assert len(database) == max_trials_per_task * len(tasks) for task in tasks: assert ( @@ -339,55 +303,47 @@ def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: d def test_meta_schedule_task_scheduler_multiple_gradient_based(): - num_trials_per_iter = 6 max_trials_per_task = 101 tasks = [ ms.TuneContext( MatmulModule, target=tvm.target.Target("llvm"), - space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ms.search_strategy.ReplayTrace( - num_trials_per_iter, - max_trials_per_task, - ), + space_generator=_schedule_matmul, + search_strategy=ms.search_strategy.ReplayTrace(), task_name="Matmul", rand_state=42, ), ms.TuneContext( MatmulReluModule, target=tvm.target.Target("llvm"), - space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ms.search_strategy.ReplayTrace( - num_trials_per_iter, - max_trials_per_task, - ), + space_generator=_schedule_matmul, + search_strategy=ms.search_strategy.ReplayTrace(), task_name="MatmulRelu", rand_state=0xDEADBEEF, ), ms.TuneContext( BatchMatmulModule, target=tvm.target.Target("llvm"), - space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_batch_matmul), - search_strategy=ms.search_strategy.ReplayTrace( - num_trials_per_iter, - max_trials_per_task, - ), + space_generator=_schedule_batch_matmul, + search_strategy=ms.search_strategy.ReplayTrace(), task_name="BatchMatmul", rand_state=0x114514, ), ] database = ms.database.MemoryDatabase() - gradient_based = ms.task_scheduler.GradientBased( + gradient_based = ms.task_scheduler.GradientBased() + gradient_based.tune( tasks, task_weights=[1.0, 1.0, 1.0], builder=DummyBuilder(), runner=DummyRunner(), database=database, measure_callbacks=[ms.measure_callback.AddToDatabase()], - seed=0x20220214, - max_trials=max_trials_per_task * len(tasks), + max_trials_global=max_trials_per_task * len(tasks), + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=6, + cost_model=None, ) - gradient_based.tune() assert len(database) == max_trials_per_task * len(tasks) for task in tasks: assert ( @@ -397,4 +353,9 @@ def 
test_meta_schedule_task_scheduler_multiple_gradient_based(): if __name__ == "__main__": - tvm.testing.main() + test_meta_schedule_task_scheduler_single() + test_meta_schedule_task_scheduler_multiple() + test_meta_schedule_task_scheduler_NIE() + test_meta_schedule_task_scheduler_avoid_cyclic() + test_meta_schedule_task_scheduler_override_next_task_id_only() + test_meta_schedule_task_scheduler_multiple_gradient_based() diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py deleted file mode 100644 index 91101dd6b6c0..000000000000 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ /dev/null @@ -1,554 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -import logging -import tempfile -from os import path as osp -from typing import List, Optional - -import numpy as np # type: ignore -import pytest -import tvm -import tvm.testing -from tvm import meta_schedule as ms -from tvm import relay -from tvm._ffi import register_func -from tvm.contrib import graph_executor -from tvm.ir import IRModule -from tvm.meta_schedule.testing.relay_workload import get_network -from tvm.script import tir as T -from tvm.target.target import Target -from tvm.tir.schedule import BlockRV, Schedule -from tvm.tir.schedule.trace import Trace -from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN - -logging.basicConfig( - format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" -) -logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) - -# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument -# fmt: off - -@tvm.script.ir_module -class tvmgen_default_fused_layout_transform: - @T.prim_func - def main( # type: ignore - placeholder: T.Buffer[(1, 3, 16, 16), "float32"], # type: ignore - T_layout_trans: T.Buffer[(1, 1, 16, 16, 3), "float32"], # type: ignore - ) -> None: # type: ignore - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3): - with T.block("T_layout_trans"): - ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) - T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3]) - T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4]) - T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else( - ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, # type: ignore - placeholder[ax0, ax1 * 3 + ax4, ax2, ax3], - T.float32(0), - dtype="float32", - ) - - -@tvm.script.ir_module -class tvmgen_default_fused_nn_contrib_conv2d_NCHWc: - @T.prim_func - def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), 
"float32"], conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"]) -> None: # type: ignore - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - data_pad = T.alloc_buffer([1, 1, 20, 20, 3], dtype="float32") - for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3): - with T.block("data_pad"): - i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) - T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1]) - T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1]) - data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716 - for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5): - with T.block("conv2d_NCHWc"): - n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) - T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) # type: ignore - T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block]) - T.block_attr({"workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 1, 16, 16, 3], "float32"], ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"]}) - with T.init(): - conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0) - conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] # type: ignore - -@tvm.script.ir_module -class tvmgen_default_fused_layout_transform_1: - @T.prim_func - def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.Buffer[(1, 8, 16, 16), "float32"]) -> None: # type: ignore - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - for i0, i1, i2, i3 in T.grid(1, 8, 16, 16): - with T.block("T_layout_trans"): - ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore - T.writes(T_layout_trans[ax0, ax1, ax2, ax3]) - T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") # type: ignore - -# fmt: on -# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument - - -@pytest.mark.skip("Integration test") -@pytest.mark.parametrize( - "model_name, input_shape, target, layout", - [ - ("resnet_18", [1, 3, 224, 224], "llvm --num-cores=12", "NHWC"), - ("resnet_18", [1, 3, 224, 224], "nvidia/geforce-rtx-3070", "NHWC"), - ("mobilenet_v2", [1, 3, 224, 224], "llvm --num-cores=12", "NHWC"), - ("mobilenet_v2", [1, 3, 224, 224], "nvidia/geforce-rtx-3070", "NHWC"), - ("bert_base", [1, 64], "llvm --num-cores=12", None), - ("bert_base", [1, 64], "nvidia/geforce-rtx-3070", None), - ], -) -def test_meta_schedule_tune_relay( - model_name: str, - input_shape: List[int], - target: str, - layout: Optional[str], -): - dev = tvm.cpu() if str(target).startswith("llvm") else tvm.cuda() - if model_name.startswith("bert"): - data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev) # embedding size - else: - data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev) - - mod, params, (input_name, _, _) = get_network( - 
name=model_name, - input_shape=input_shape, - layout=layout, - ) - - target = Target(target) - with tempfile.TemporaryDirectory() as work_dir: - with ms.Profiler() as profiler: - rt_mod1: tvm.runtime.Module = ms.tune_relay( - mod=mod, - params=params, - target=target, - config=ms.TuneConfig( - strategy="evolutionary", - num_trials_per_iter=32, - max_trials_per_task=20000, - max_trials_global=20000, - ), - work_dir=work_dir, - ) - print(profiler.table()) - # Compile without meta-schedule for correctness check - with tvm.transform.PassContext(opt_level=0): - rt_mod2 = relay.build(mod, target=target, params=params) - - def get_output(data, lib): - module = graph_executor.GraphModule(lib["default"](dev)) - module.set_input(input_name, data) - module.run() - return module.get_output(0).numpy() - - # Check correctness - actual_output = get_output(data, rt_mod1) - expected_output = get_output(data, rt_mod2) - assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) - - -def test_meta_schedule_te2primfunc_argument_order(): - @ms.derived_object - class TestDummyDatabase(ms.database.PyDatabase): - def __init__(self): - super().__init__() - self.records = [] - self.workload_reg = [] - - def has_workload(self, mod: IRModule) -> ms.database.Workload: - for workload in self.workload_reg: - if tvm.ir.structural_equal(workload.mod, mod): - return True - # The database has already put in all correct workloads - raise ValueError( - "The workload searched for is not in given database!" - + " Incorrect TIR was generated from TE subgraph." - ) - - def commit_tuning_record(self, record: ms.database.TuningRecord) -> None: - self.records.append(record) - - def commit_workload(self, mod: IRModule) -> ms.database.Workload: - for workload in self.workload_reg: - if tvm.ir.structural_equal(workload.mod, mod): - return workload - workload = ms.database.Workload(mod) - self.workload_reg.append(workload) - return workload - - def get_top_k( - self, - workload: ms.database.Workload, - top_k: int, - ) -> List[ms.database.TuningRecord]: - return list( - filter( - lambda x: x.workload == workload, - sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), - ) - )[: int(top_k)] - - def __len__(self) -> int: - return len(self.records) - - def print_results(self) -> None: - print("\n".join([str(r) for r in self.records])) - - data_shape = (1, 3, 16, 16) - weight_shape = (8, 3, 5, 5) - data = relay.var("data", relay.TensorType(data_shape, "float32")) - weight = relay.var("weight", relay.TensorType(weight_shape, "float32")) - y = relay.nn.conv2d( - data, - weight, - padding=(2, 2), - kernel_size=(5, 5), - kernel_layout="OIHW", - out_dtype="float32", - ) - f = relay.Function([data, weight], y) - mod = tvm.IRModule.from_expr(f) - mod = relay.transform.InferType()(mod) - - data_sample = np.random.rand(*data_shape).astype("float32") - weight_sample = np.random.rand(*weight_shape).astype("float32") - params = {mod["main"].params[1].name_hint: weight_sample} - - input_name = "data" - dev = tvm.cpu() - target = Target("llvm --num-cores=12") - data = tvm.nd.array(data_sample, dev) - - database = TestDummyDatabase() - database.commit_workload(tvmgen_default_fused_layout_transform) - database.commit_workload(tvmgen_default_fused_layout_transform_1) - database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc) - - with database, tvm.transform.PassContext( # pylint: disable=not-context-manager - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - rt_mod1 = relay.build(mod, 
target=target, params=params) - - # Compile without meta-schedule for correctness check - with tvm.transform.PassContext(opt_level=0): - rt_mod2 = relay.build(mod, target=target, params=params) - - def get_output(data, lib): - module = graph_executor.GraphModule(lib["default"](dev)) - module.set_input(input_name, data) - module.run() - return module.get_output(0).numpy() - - # Check correctness - actual_output = get_output(data, rt_mod1) - expected_output = get_output(data, rt_mod2) - assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) - - -def test_meta_schedule_relay_lowering(): - data_shape = (1, 3, 16, 16) - weight_shape = (8, 3, 5, 5) - data = relay.var("data", relay.TensorType(data_shape, "float32")) - weight = relay.var("weight", relay.TensorType(weight_shape, "float32")) - y = relay.nn.conv2d( - data, - weight, - padding=(2, 2), - kernel_size=(5, 5), - kernel_layout="OIHW", - out_dtype="float32", - ) - f = relay.Function([data, weight], y) - mod = tvm.IRModule.from_expr(f) - mod = relay.transform.InferType()(mod) - - data_sample = np.random.rand(*data_shape).astype("float32") - weight_sample = np.random.rand(*weight_shape).astype("float32") - params = {mod["main"].params[1].name_hint: weight_sample} - - input_name = "data" - dev = tvm.cpu() - target = Target("llvm --num-cores=12") - data = tvm.nd.array(data_sample, dev) - - with tempfile.TemporaryDirectory() as work_dir: - database = ms.database.JSONDatabase( - osp.join(work_dir, "workload.json"), osp.join(work_dir, "records.json") - ) - database.commit_tuning_record( - ms.database.TuningRecord( - Trace([], {}), - database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc), - [0.0], - target=target, - args_info=[], - ) - ) - with database, tvm.transform.PassContext( - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - rt_mod1 = relay.build(mod, target=target, params=params) - - # Compile without meta-schedule for correctness check - with tvm.transform.PassContext(opt_level=0): - rt_mod2 = relay.build(mod, target=target, params=params) - - def get_output(data, lib): - module = graph_executor.GraphModule(lib["default"](dev)) - module.set_input(input_name, data) - module.run() - return module.get_output(0).numpy() - - # Check correctness - actual_output = get_output(data, rt_mod1) - expected_output = get_output(data, rt_mod2) - assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) - - -def schedule_dense(dense_block, M, do_tune, sch): # pylint: disable=invalid-name - """ - Manually schedule a dense block, created from TE compute op via CreatePrimFunc, - using VNNI instruction. - """ - post_blocks = sch.get_consumers(dense_block) - - if len(post_blocks) > 0: - # Fuse all intermediate post ops into the last op. - # This is equivalent to the traverse_inline function used in TE schedules. 
- while True: - next_post_blocks = [] - for post_block in post_blocks: - next_consumers = sch.get_consumers(post_block) - - if len(next_consumers) > 0: - sch.compute_inline(post_block) - - next_post_blocks += next_consumers - - if len(next_post_blocks) == 0: - assert len(post_blocks) == 1 - outer_block = post_blocks[0] - a_y, a_x = sch.get_loops(outer_block)[-2:] - break - - post_blocks = next_post_blocks - else: - a_y, a_x, _ = sch.get_loops(dense_block)[-3:] - outer_block = dense_block - - if do_tune: - y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128) - a_yo, a_yi = sch.split(a_y, factors=y_factors) - else: - a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 64)]) - - a_xo, a_xi = sch.split(a_x, factors=[None, 16]) - sch.reorder(a_yo, a_xo, a_yi, a_xi) - fused = sch.fuse(a_yo, a_xo) - - if outer_block != dense_block: - # Handle the case when dense is fused with post ops. - sch.vectorize(a_xi) - sch.compute_at(dense_block, a_yi) - - a_xi, a_k = sch.get_loops(dense_block)[-2:] - a_ko, a_ki = sch.split(a_k, factors=[None, 4]) - sch.reorder(a_ko, a_xi, a_ki) - - # We need to parallelize before decompose_reduction, otherwise the so-called "Compact dataflow" - # condition is violated. - sch.parallel(fused) - dec = sch.decompose_reduction(dense_block, a_ko) - - init_loop = sch.get_loops(dec)[-1] - sch.vectorize(init_loop) - - sch.tensorize(a_xi, VNNI_INTRIN) - - -def manual_tir_common(do_tune=False): - M, N, K = 1024, 1024, 1024 # pylint: disable=invalid-name - data_shape = (M, K) - weight_shape = (N, K) - - data_dtype = "uint8" - data = relay.var("data", shape=data_shape, dtype=data_dtype) - weight = relay.var("weight", shape=weight_shape, dtype="int8") - bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32") - - # dense is tuned by the TIR schedule above, bmm is scheduled by TE (topi/x86/batch_matmul.py) - dense = relay.nn.dense(data, weight, out_dtype="int32") - bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32") - out = relay.nn.batch_matmul( - relay.cast(relay.expand_dims(bias_add, 0), "uint8"), - relay.cast(relay.expand_dims(bias_add, 0), "int8"), - out_dtype="int32", - ) - - relay_mod = tvm.IRModule.from_expr(out) - - target = "llvm -mcpu=cascadelake -num-cores 4" - dev = tvm.device(target, 0) - - data = np.random.uniform(1, 10, size=(M, K)).astype("uint8") - weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8") - bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32") - - ref = ( - relay.create_executor("vm", mod=relay_mod, device=dev, target=target) - .evaluate()(*[data, weight_np, bias_np]) - .numpy() - ) - - params = {"weight": weight_np, "bias": bias_np} - - if do_tune: - extracted_tasks = ms.extract_task_from_relay(relay_mod, target, params) - # Filter out tasks that we don't intend to schedule / tune with TIR. - tune_tasks = list( - filter( - lambda task: "dense" in task.task_name, - extracted_tasks, - ) - ) - config = ms.TuneConfig( - strategy="replay_trace", - num_trials_per_iter=8, - max_trials_per_task=8, - max_trials_global=8, - ) - - with tempfile.TemporaryDirectory() as work_dir: - # postprocs=lambda: [] is important to prevent default post processors from - # tampering with the manual schedule. 
- database = ms.tune_extracted_tasks( - tune_tasks, - config, - work_dir=work_dir, - postprocs=lambda: [], - ) - else: - - def schedule_fn(sch) -> bool: - if "dense" not in sch.mod.attrs["task_name"]: - return False - - block = sch.get_block("compute") - - # Looks up schedule_rule annotation. - # See the comment in test_tune_relay_manual_tir_vnni(). - schedule_rule = sch.get(block).annotations["schedule_rule"] - - assert "dense_vnni" in schedule_rule - - schedule_dense(block, M, False, sch) - - return True - - database = ms.database.ScheduleFnDatabase(schedule_fn) - - with database, tvm.transform.PassContext( - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - # pylint: disable=W0105 - """ - The log should say - Warning: Cannot find workload: tvmgen_default_fused_expand_dims - Warning: Cannot find workload: tvmgen_default_fused_cast - Warning: Cannot find workload: tvmgen_default_fused_cast_1 - Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul - - This means batch matmul and others are scheduled by TE, and dense (the one not warned) - is found in the meta schedule tuning database during compilation - """ - # pylint: enable=W0105 - lib = relay.build(relay_mod, target=target, params=params) - - runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) - - runtime.set_input("data", data) - runtime.run() - - out = runtime.get_output(0).numpy() - - np.testing.assert_equal(out, ref) - - -@tvm.testing.requires_cascadelake -def test_tune_relay_manual_tir_vnni(): - manual_tir_common(do_tune=False) - - # pylint: disable=W0105 - """ - We can inject and apply a custom TIR scheduling to a TE compute of interest, using - the "schedule_rule" annotation. For example, in topi/x86/dense.py we have the following - declaration for int8 dense targeting the VNNI instruction. - - C = te.compute( - ... - attrs={"schedule_rule": "meta_schedule.dense_vnni"}, - ) - - When the MetaSchedule encounters a TensorIR block with the "schedule_rule" annotation, - it looks up the packed func registry for a function that is associated with the given schedule - rule key ("meta_schedule.dense_vnni" in this example). The signature of such custom schedule - functions must be - - (tir.schedule.Schedule, tir.schedule.BlockRV) -> [tir.schedule.Schedule]. - - The BlockRV argument corresponds to the TE compute annotated with "schedule_rule". - - The relevant code is in meta_schedule/space_generator/post_order_apply.cc. 
- - """ - # pylint: enable=W0105 - - def schedule_rule_dense_vnni(sch: Schedule, dense_block: BlockRV): - schedule_dense(dense_block, None, True, sch) - return [sch] - - register_func("meta_schedule.dense_vnni", schedule_rule_dense_vnni) - - manual_tir_common(do_tune=True) - - -if __name__ == """__main__""": - test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "llvm --num-cores=12", None) - test_meta_schedule_tune_relay("resnet_18", [1, 3, 224, 224], "nvidia/geforce-rtx-3070", "NCHW") - test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "llvm --num-cores=12", None) - test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "nvidia/geforce-rtx-3070", None) - test_meta_schedule_tune_relay("bert_base", [1, 64], "llvm --num-cores=12", None) - test_meta_schedule_tune_relay("bert_base", [1, 64], "nvidia/geforce-rtx-3070", None) - test_meta_schedule_te2primfunc_argument_order() - test_meta_schedule_relay_lowering() - test_tune_relay_manual_tir_vnni() diff --git a/tests/python/unittest/test_meta_schedule_tune_te.py b/tests/python/unittest/test_meta_schedule_tune_te.py deleted file mode 100644 index d294b2ddd6e8..000000000000 --- a/tests/python/unittest/test_meta_schedule_tune_te.py +++ /dev/null @@ -1,52 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=missing-docstring -import logging -import tempfile - -import pytest -from tvm.meta_schedule import TuneConfig, tune_te -from tvm.meta_schedule.testing import te_workload -from tvm.target.target import Target -from tvm.tir import Schedule - -logging.basicConfig() -logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) - - -def test_tune_matmul(): - with tempfile.TemporaryDirectory() as work_dir: - sch: Schedule = tune_te( - tensors=te_workload.batch_matmul_nkkm(B=1, N=128, M=128, K=128), - target=Target("llvm --num-cores=16"), - config=TuneConfig( - strategy="replay_trace", - num_trials_per_iter=1, - max_trials_per_task=1, - max_trials_global=1, - ), - work_dir=work_dir, - ) - if sch is None: - print("No valid schedule found!") - else: - print(sch.mod.script()) - print(sch.trace) - - -if __name__ == """__main__""": - test_tune_matmul() diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index 6ab5f9b8c5c4..aa45120c2316 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -17,17 +17,14 @@ # pylint: disable=missing-docstring,no-member,invalid-name,unused-variable import logging import tempfile -import numpy as np +import numpy as np import pytest import tvm - +import tvm.testing from tvm import meta_schedule as ms -from tvm.meta_schedule import TuneContext, TuneConfig, tune_tir from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.meta_schedule.testing.local_rpc import LocalRPC -from tvm.meta_schedule.schedule_rule import PyScheduleRule -from tvm.meta_schedule.utils import derived_object from tvm.script import tir as T from tvm.target import Target from tvm.tir.schedule import BlockRV, Schedule @@ -64,77 +61,42 @@ def two_step(a: T.handle, c: T.handle) -> None: C[vi, vj] = B[vi, vj] + 3.0 -@pytest.mark.skip("Integration test") +@tvm.testing.requires_llvm def test_tune_matmul_cpu(): with tempfile.TemporaryDirectory() as work_dir: - sch: Schedule = tune_tir( + target = Target("llvm --num-cores=16") + database = ms.tir_integration.tune_tir( mod=matmul, - target=Target("llvm --num-cores=16"), - config=TuneConfig( - strategy="replay_trace", - num_trials_per_iter=32, - max_trials_per_task=32, - max_trials_global=32, - ), + target=target, work_dir=work_dir, + max_trials_global=32, + num_trials_per_iter=16, ) + sch = ms.tir_integration.compile_tir(database, matmul, target) if sch is None: print("No valid schedule found!") else: - print(sch.mod.script()) - print(sch.trace) - - -@pytest.mark.skip("Integration test") -def test_tune_block_cpu(): - @derived_object - class RemoveBlock(PyScheduleRule): - def _initialize_with_tune_context(self, context: TuneContext) -> None: - pass - - def apply(self, sch: Schedule, block: BlockRV): - if sch.get(block).name_hint == "root": - return [sch] - sch = sch.copy() - sch.compute_inline(block) - return [sch] - - with tempfile.TemporaryDirectory() as work_dir: - sch: Schedule = tune_tir( - mod=two_step, - target=Target("llvm --num-cores=16"), - config=TuneConfig( - strategy="replay_trace", - num_trials_per_iter=32, - max_trials_per_task=32, - max_trials_global=32, - ), - work_dir=work_dir, - blocks=["A"], - sch_rules=lambda *args: [RemoveBlock()], - ) - assert sch is not None + sch.mod.show() + sch.trace.show() -@pytest.mark.skip("Integration test") +@tvm.testing.requires_cuda def test_tune_matmul_cuda(): with tempfile.TemporaryDirectory() as work_dir: - sch: Schedule 
= tune_tir( + target = Target("nvidia/geforce-rtx-3070") + database = ms.tir_integration.tune_tir( mod=matmul, - target=Target("nvidia/geforce-rtx-3070"), - config=TuneConfig( - strategy="replay_trace", - num_trials_per_iter=32, - max_trials_per_task=32, - max_trials_global=32, - ), + target=target, work_dir=work_dir, + max_trials_global=32, + num_trials_per_iter=16, ) + sch = ms.tir_integration.compile_tir(database, matmul, target) if sch is None: print("No valid schedule found!") else: - print(sch.mod.script()) - print(sch.trace) + sch.mod.show() + sch.trace.show() def test_tune_run_module_via_rpc(): @@ -179,6 +141,43 @@ def f_timer(rt_mod, dev, input_data): tvm.testing.assert_allclose(result.numpy(), c_np, rtol=1e-3) +def test_tune_block_cpu(): + @ms.derived_object + class RemoveBlock(ms.schedule_rule.PyScheduleRule): + def _initialize_with_tune_context(self, context: ms.TuneContext) -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV): + if sch.get(block).name_hint == "root": + return [sch] + sch = sch.copy() + sch.compute_inline(block) + return [sch] + + def clone(self) -> "RemoveBlock": + return RemoveBlock() + + with tempfile.TemporaryDirectory() as work_dir: + target = Target("llvm --num-cores=16") + database = ms.tir_integration.tune_tir( + mod=two_step, + target=target, + work_dir=work_dir, + max_trials_global=32, + num_trials_per_iter=16, + space=ms.space_generator.PostOrderApply( + f_block_filter=lambda block: block.name_hint == "A", + sch_rules=[RemoveBlock()], + postprocs=[], + mutator_probs={}, + ), + ) + sch = ms.tir_integration.compile_tir(database, two_step, target) + assert sch is not None + sch.mod.show() + sch.trace.show() + + if __name__ == """__main__""": test_tune_matmul_cpu() test_tune_matmul_cuda() diff --git a/tests/python/unittest/test_meta_schedule_vnni_integration.py b/tests/python/unittest/test_meta_schedule_vnni_integration.py new file mode 100644 index 000000000000..2cd609863056 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_vnni_integration.py @@ -0,0 +1,249 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
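+#
+# Integration tests for scheduling an int8 dense block with the VNNI intrinsic:
+# either via a handwritten schedule function committed to a ScheduleFnDatabase,
+# or via MetaSchedule tuning with a custom "schedule_rule" hook registered below.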
+# pylint: disable=missing-docstring
+import logging
+import tempfile
+from typing import Optional
+
+import numpy as np  # type: ignore
+import pytest
+import tvm
+from tvm import meta_schedule as ms
+from tvm import relay
+from tvm._ffi import register_func
+from tvm.tir.schedule import BlockRV, Schedule
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+
+logging.basicConfig(
+    format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+
+
+def _schedule_dense(m: Optional[int], do_tune: bool):
+    """Return a schedule function that manually schedules a dense block, created
+    from a TE compute op via CreatePrimFunc, using the VNNI instruction.
+    """
+
+    def schedule_fn(sch, dense_block: Optional[BlockRV] = None) -> bool:
+        if "dense" not in sch.mod.attrs["task_name"]:
+            return False
+        if dense_block is None:
+            dense_block = sch.get_block("compute")
+        assert "dense_vnni" in sch.get(dense_block).annotations["schedule_rule"]
+
+        post_blocks = sch.get_consumers(dense_block)
+        if len(post_blocks) > 0:
+            # Fuse all intermediate post ops into the last op.
+            # This is equivalent to the traverse_inline function used in TE schedules.
+            while True:
+                next_post_blocks = []
+                for post_block in post_blocks:
+                    next_consumers = sch.get_consumers(post_block)
+                    if len(next_consumers) > 0:
+                        sch.compute_inline(post_block)
+                    next_post_blocks += next_consumers
+                if len(next_post_blocks) == 0:
+                    assert len(post_blocks) == 1
+                    outer_block = post_blocks[0]
+                    a_y, a_x = sch.get_loops(outer_block)[-2:]
+                    break
+                post_blocks = next_post_blocks
+        else:
+            a_y, a_x, _ = sch.get_loops(dense_block)[-3:]
+            outer_block = dense_block
+        if do_tune:
+            y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128)
+            a_yo, a_yi = sch.split(a_y, factors=y_factors)
+        else:
+            a_yo, a_yi = sch.split(a_y, factors=[None, min(m, 64)])
+        a_xo, a_xi = sch.split(a_x, factors=[None, 16])
+        sch.reorder(a_yo, a_xo, a_yi, a_xi)
+        fused = sch.fuse(a_yo, a_xo)
+        if outer_block != dense_block:
+            # Handle the case when dense is fused with post ops.
+            sch.vectorize(a_xi)
+            sch.compute_at(dense_block, a_yi)
+        a_xi, a_k = sch.get_loops(dense_block)[-2:]
+        a_ko, a_ki = sch.split(a_k, factors=[None, 4])
+        sch.reorder(a_ko, a_xi, a_ki)
+        # We need to parallelize before decompose_reduction, otherwise the so-called "Compact dataflow"
+        # condition is violated.
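+        # Parallelize the fused outer loop, split out the reduction init block so its
+        # innermost loop can be vectorized, then map the (a_xi, a_ki) dot-product loops
+        # onto the VNNI intrinsic.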
+        sch.parallel(fused)
+        dec = sch.decompose_reduction(dense_block, a_ko)
+        init_loop = sch.get_loops(dec)[-1]
+        sch.vectorize(init_loop)
+        sch.tensorize(a_xi, VNNI_INTRIN)
+        return True
+
+    return schedule_fn
+
+
+def _relay_dense(m, n, k):
+    data = relay.var("data", shape=(m, k), dtype="uint8")
+    weight = relay.var("weight", shape=(n, k), dtype="int8")
+    bias = relay.var("bias", shape=(n,), dtype="int32")
+    # dense is tuned by the TIR schedule above, bmm is scheduled by TE (topi/x86/batch_matmul.py)
+    dense = relay.nn.dense(data, weight, out_dtype="int32")
+    bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
+    out = relay.nn.batch_matmul(
+        relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
+        relay.cast(relay.expand_dims(bias_add, 0), "int8"),
+        out_dtype="int32",
+    )
+    relay_mod = tvm.IRModule.from_expr(out)
+    data = np.random.uniform(1, 10, size=(m, k)).astype("uint8")
+    params = {
+        "weight": np.random.uniform(1, 10, size=(n, k)).astype("int8"),
+        "bias": np.random.uniform(1, 10, size=(n,)).astype("int32"),
+    }
+
+    def f_check(lib, dev):
+        ref = (
+            relay.create_executor(
+                "vm",
+                mod=relay_mod,
+                device=dev,
+                target="llvm",
+            )
+            .evaluate()(data, params["weight"], params["bias"])
+            .numpy()
+        )
+        runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+        runtime.set_input("data", data)
+        runtime.run()
+        out = runtime.get_output(0).numpy()
+        np.testing.assert_equal(out, ref)
+
+    return relay_mod, params, f_check
+
+
+@pytest.mark.skip("Requires cascadelake")
+def test_vnni_schedule_fn_database():
+    m, n, k = 1024, 1024, 1024
+    target = tvm.target.Target("llvm -mcpu=cascadelake -num-cores 4")
+    dev = tvm.cpu(0)
+    relay_mod, params, f_check = _relay_dense(m, n, k)
+
+    with ms.database.ScheduleFnDatabase(
+        _schedule_dense(
+            m=m,
+            do_tune=False,
+        )
+    ), tvm.transform.PassContext(
+        opt_level=3,
+        config={"relay.backend.use_meta_schedule": True},
+    ):
+        # pylint: disable=W0105
+        """The log should say
+        Warning: Cannot find workload: tvmgen_default_fused_expand_dims
+        Warning: Cannot find workload: tvmgen_default_fused_cast
+        Warning: Cannot find workload: tvmgen_default_fused_cast_1
+        Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul
+
+        This means batch matmul and others are scheduled by TE, and dense (the one not warned)
+        is found in the meta schedule tuning database during compilation.
+        """
+        # pylint: enable=W0105
+        lib = relay.build(relay_mod, target=target, params=params)
+    f_check(lib, dev)
+
+
+@pytest.mark.skip("Requires cascadelake")
+def test_vnni_schedule_fn_tune():
+    # pylint: disable=W0105
+    """
+    We can inject and apply a custom TIR schedule to a TE compute of interest, using
+    the "schedule_rule" annotation. For example, in topi/x86/dense.py we have the following
+    declaration for int8 dense targeting the VNNI instruction.
+
+    C = te.compute(
+        ...
+        attrs={"schedule_rule": "meta_schedule.dense_vnni"},
+    )
+
+    When MetaSchedule encounters a TensorIR block with the "schedule_rule" annotation,
+    it looks up the packed func registry for a function that is associated with the given schedule
+    rule key ("meta_schedule.dense_vnni" in this example). The signature of such custom schedule
+    functions must be
+
+    (tir.schedule.Schedule, tir.schedule.BlockRV) -> [tir.schedule.Schedule].
+
+    The BlockRV argument corresponds to the TE compute annotated with "schedule_rule".
+
+    The relevant code is in meta_schedule/space_generator/post_order_apply.cc.
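+
+    Below, schedule_rule_dense_vnni implements this signature and is registered
+    under the key "meta_schedule.dense_vnni" via register_func before tuning.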
+    """
+
+    def schedule_rule_dense_vnni(sch: Schedule, dense_block: BlockRV):
+        _schedule_dense(m=None, do_tune=True)(sch, dense_block)
+        return [sch]
+
+    register_func("meta_schedule.dense_vnni", schedule_rule_dense_vnni)
+
+    m, n, k = 1024, 1024, 1024
+    target = tvm.target.Target("llvm -mcpu=cascadelake -num-cores 4")
+    dev = tvm.cpu(0)
+    relay_mod, params, f_check = _relay_dense(m, n, k)
+
+    extracted_tasks = ms.relay_integration.extract_tasks(relay_mod, target, params)
+    with tempfile.TemporaryDirectory() as work_dir:
+        # Passing postprocs=[] to PostOrderApply is important to prevent the default
+        # postprocessors from tampering with the manual schedule.
+        tasks = ms.relay_integration.extracted_tasks_to_tune_contexts(
+            list(
+                filter(
+                    lambda task: "dense" in task.task_name,
+                    extracted_tasks,
+                )
+            ),
+            work_dir=work_dir,
+            space=ms.space_generator.PostOrderApply(
+                f_block_filter=None,
+                sch_rules=None,
+                postprocs=[],
+                mutator_probs=None,
+            ),
+        )
+        database = ms.relay_integration.tune_tasks(
+            tasks=tasks,
+            task_weights=[1.0] * len(tasks),
+            work_dir=work_dir,
+            max_trials_global=20000,
+        )
+    with database, tvm.transform.PassContext(
+        opt_level=3,
+        config={"relay.backend.use_meta_schedule": True},
+    ):
+        # pylint: disable=W0105
+        """The log should say
+        Warning: Cannot find workload: tvmgen_default_fused_expand_dims
+        Warning: Cannot find workload: tvmgen_default_fused_cast
+        Warning: Cannot find workload: tvmgen_default_fused_cast_1
+        Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul
+
+        This means batch matmul and others are scheduled by TE, and dense (the one not warned)
+        is found in the meta schedule tuning database during compilation.
+        """
+        # pylint: enable=W0105
+        lib = relay.build(relay_mod, target=target, params=params)
+    f_check(lib, dev)
+
+
+if __name__ == """__main__""":
+    test_vnni_schedule_fn_database()
+    test_vnni_schedule_fn_tune()
diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py
index 8a5155bcba43..916db184e09b 100644
--- a/tests/python/unittest/test_tir_schedule_trace.py
+++ b/tests/python/unittest/test_tir_schedule_trace.py
@@ -282,7 +282,6 @@ def test_trace_simplified_2():
         )
     )
     trace = trace.simplified(remove_postproc=False)
-    print(trace.show())
    assert str(trace) == "\n".join(
        (
            "# from tvm import tir",