Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions include/tvm/meta_schedule/measure_callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,6 @@ class MeasureCallback : public runtime::ObjectRef {
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback RemoveBuildArtifact();
/*!
* \brief Create a measure callback that echos the statistics of the tuning process to the console
* \return The measure callback created.
*/
TVM_DLL static MeasureCallback EchoStatistics();
/*!
* \brief Create a measure callback that updates the cost model with measurement result.
* \return The measure callback created.
Expand All @@ -140,6 +135,8 @@ class MeasureCallback : public runtime::ObjectRef {
*/
TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply,
PyMeasureCallbackNode::FAsString f_as_string);
/*! \brief The default list of measure callbacks. */
TVM_DLL static Array<MeasureCallback, void> Default();
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
};

Expand Down
15 changes: 11 additions & 4 deletions include/tvm/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,17 @@ class Mutator : public runtime::ObjectRef {
* \param f_as_string The packed function of `AsString`.
* \return The mutator created.
*/
TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context, //
FApply f_apply, //
FClone f_clone, //
FAsString f_as_string);
TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context,
FApply f_apply, FClone f_clone, FAsString f_as_string);
/*! \brief Create default mutators for LLVM */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultLLVM();
/*! \brief Create default mutators for CUDA */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultCUDA();
/*! \brief Create default mutators for CUDA with TensorCore */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultCUDATensorCore();
/*! \brief Create default mutators for Hexagon */
TVM_DLL static Map<Mutator, FloatImm, void> DefaultHexagon();

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode);
};

Expand Down
9 changes: 9 additions & 0 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ class Postproc : public runtime::ObjectRef {
* \return The postprocessor created
*/
TVM_DLL static Postproc RewriteLayout();
/*! \brief Create default postprocessors for LLVM */
TVM_DLL static Array<Postproc, void> DefaultLLVM();
/*! \brief Create default postprocessors for CUDA */
TVM_DLL static Array<Postproc, void> DefaultCUDA();
/*! \brief Create default postprocessors for CUDA with TensorCore */
TVM_DLL static Array<Postproc, void> DefaultCUDATensorCore();
/*! \brief Create default postprocessors for Hexagon */
TVM_DLL static Array<Postproc, void> DefaultHexagon();

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
};

Expand Down
18 changes: 14 additions & 4 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.
* \param intrin_name The name of a tensor intrinsic, must be registerd via
* \brief Extension of MultiLevelTiling for auto-tensorization with a single intrinsic.
* \param intrin_name The name of a tensor intrinsic, must be registered via
* TensorIntrin.register(...) beforehand
* \param structure The tiling structure. Recommended:
* - 'SSRSRS' on CPU
Expand All @@ -162,12 +162,12 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate
* \brief Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate
* tensor core intrinsics
* \param intrin_groups A list of groups of tensor core intrinsics. The map should contains key
* "init", "load_a", "load_b", "compute", "store", which represent the tensor intrin for
* initialization, loading operand A, loading operand B, tensor core computation, storing the
* result. The value of the map should be names of tensor intrinsics, must be registerd via
* result. The value of the map should be names of tensor intrinsics, must be registered via
* TensorIntrin.register(...) beforehand
* \param structure The tiling structure. Recommended:
* - 'SSSRRSRS' on GPU
Expand Down Expand Up @@ -261,6 +261,16 @@ class ScheduleRule : public runtime::ObjectRef {
FApply f_apply, //
FClone f_clone, //
FAsString f_as_string);

/*! \brief Create default schedule rules for LLVM */
TVM_DLL static Array<ScheduleRule, void> DefaultLLVM();
/*! \brief Create default schedule rules for CUDA */
TVM_DLL static Array<ScheduleRule, void> DefaultCUDA();
/*! \brief Create default postprocessors for CUDA with TensorCore */
TVM_DLL static Array<ScheduleRule, void> DefaultCUDATensorCore();
/*! \brief Create default schedule rules for Hexagon */
TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
};

Expand Down
32 changes: 12 additions & 20 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,17 @@ class SearchStrategyNode : public runtime::Object {

/*!
* \brief Pre-tuning for the search strategy.
* \param max_trials The maximum number of trials.
* \param num_trials_per_iter The number of trials per iteration.
* \param design_spaces The design spaces used during tuning process.
* \param database The database used during tuning process.
* \param cost_model The cost model used during tuning process.
* \note Pre-tuning is supposed to be called before the tuning process and after the
* initialization. Because the search strategy is stateful, we can always call pretuning
* and reset the search strategy.
*/
virtual void PreTuning(const Array<tir::Schedule>& design_spaces,
virtual void PreTuning(int max_trials, int num_trials_per_iter,
const Array<tir::Schedule>& design_spaces,
const Optional<Database>& database,
const Optional<CostModel>& cost_model) = 0;

Expand Down Expand Up @@ -143,10 +146,10 @@ class SearchStrategy : public runtime::ObjectRef {
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
/*!
* \brief The function type of `PreTuning` method.
* \param design_spaces The design spaces for pre-tuning.
*/
using FPreTuning = runtime::TypedPackedFunc<void(
const Array<tir::Schedule>&, const Optional<Database>&, const Optional<CostModel>&)>;
int max_trials, int num_trials_per_iter, const Array<tir::Schedule>&,
const Optional<Database>&, const Optional<CostModel>&)>;
/*! \brief The function type of `PostTuning` method. */
using FPostTuning = runtime::TypedPackedFunc<void()>;
/*!
Expand Down Expand Up @@ -185,24 +188,15 @@ class SearchStrategy : public runtime::ObjectRef {

/*!
* \brief Constructor of replay trace search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
* \param max_trials_per_task The total number of trials for trace replaying.
* \param max_fail_count The max number of failures during trace replaying.
*/
TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task,
int max_fail_count);
TVM_DLL static SearchStrategy ReplayTrace(int max_fail_count);

/*!
* \brief Constructor of replay func search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
* \param max_trials_per_task The total number of trials for func replaying.
*/
TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int max_trials_per_task);
/*! \brief Constructor of replay func search strategy. */
TVM_DLL static SearchStrategy ReplayFunc();

/*!
* \brief Constructor of evolutionary search strategy.
* \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
* \param max_trials_per_task The total number of trials for evolutionary search.
* \param population_size The initial sample population.
* \param init_measured_ratio The ratio of measures samples in initial population.
* \param init_min_unmeasured The minimal size of unmeasured population in the initial sampling.
Expand All @@ -211,9 +205,7 @@ class SearchStrategy : public runtime::ObjectRef {
* \param genetic_max_fail_count The maximum number to try evolving the given trace.
* \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score.
*/
TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, //
int max_trials_per_task, //
int population_size, //
TVM_DLL static SearchStrategy EvolutionarySearch(int population_size, //
double init_measured_ratio, //
int init_min_unmeasured, //
int genetic_num_iters, //
Expand Down Expand Up @@ -257,8 +249,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
}

void InitializeWithTuneContext(const TuneContext& context) final;
void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
const Optional<CostModel>& cost_model) final;
void PreTuning(int max_trials, int num_trials_per_iter, const Array<tir::Schedule>& design_spaces,
const Optional<Database>& database, const Optional<CostModel>& cost_model) final;
void PostTuning() final;
Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final;
void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
Expand Down
53 changes: 48 additions & 5 deletions include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@
#define TVM_META_SCHEDULE_SPACE_GENERATOR_H_

#include <tvm/ir/module.h>
#include <tvm/meta_schedule/mutator.h>
#include <tvm/meta_schedule/postproc.h>
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>
#include <tvm/tir/schedule/schedule.h>

namespace tvm {
Expand Down Expand Up @@ -71,6 +75,19 @@ class SpaceGenerator;
*/
class SpaceGeneratorNode : public runtime::Object {
public:
/*! \brief The schedule rules. */
Optional<Array<ScheduleRule>> sch_rules;
/*! \brief The postprocessors. */
Optional<Array<Postproc>> postprocs;
/*! \brief The probability of using certain mutator. */
Optional<Map<Mutator, FloatImm>> mutator_probs;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("sch_rules", &sch_rules);
v->Visit("postprocs", &postprocs);
v->Visit("mutator_probs", &mutator_probs);
}

/*! \brief Default destructor */
virtual ~SpaceGeneratorNode() = default;

Expand All @@ -79,7 +96,7 @@ class SpaceGeneratorNode : public runtime::Object {
* \param context The tuning context for initialization.
* \note This method is supposed to be called only once before every other method.
*/
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
virtual void InitializeWithTuneContext(const TuneContext& context);

/*!
* \brief Generate design spaces given a module.
Expand Down Expand Up @@ -127,12 +144,17 @@ class SpaceGenerator : public runtime::ObjectRef {
public:
/*!
* \brief Create a design space generator with customized methods on the python-side.
* \param sch_rules The schedule rules.
* \param postprocs The postprocessors.
* \param mutator_probs The probability of using certain mutator.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_generate_design_space The packed function of `GenerateDesignSpace`.
* \param f_clone The packed function of `Clone`.
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator PySpaceGenerator(
Optional<Array<ScheduleRule>> sch_rules, Optional<Array<Postproc>> postprocs,
Optional<Map<Mutator, FloatImm>> mutator_probs,
FInitializeWithTuneContext f_initialize_with_tune_context,
FGenerateDesignSpace f_generate_design_space, FClone f_clone);
/*!
Expand All @@ -141,19 +163,39 @@ class SpaceGenerator : public runtime::ObjectRef {
* 1) void(Schedule)
* 2) Schedule(Schedule)
* 3) Array<Schedule>(Schedule)
* \param sch_rules The schedule rules.
* \param postprocs The postprocessors.
* \param mutator_probs The probability of using certain mutator.
*/
TVM_DLL static SpaceGenerator ScheduleFn(PackedFunc schedule_fn);
TVM_DLL static SpaceGenerator ScheduleFn(PackedFunc schedule_fn,
Optional<Array<ScheduleRule>> sch_rules,
Optional<Array<Postproc>> postprocs,
Optional<Map<Mutator, FloatImm>> mutator_probs);
/*!
* \brief Create a design space generator that is union of multiple design space generators.
* \param space_generators An array of design space generators to be unioned.
* \param sch_rules The schedule rules.
* \param postprocs The postprocessors.
* \param mutator_probs The probability of using certain mutator.
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array<SpaceGenerator, void> space_generators);
TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array<SpaceGenerator, void> space_generators,
Optional<Array<ScheduleRule>> sch_rules,
Optional<Array<Postproc>> postprocs,
Optional<Map<Mutator, FloatImm>> mutator_probs);
/*!
* \brief Create a design space generator that generates design spaces by applying schedule
* rules to blocks in post-DFS order. \return The design space generator created.
* rules to blocks in post-DFS order.
* \param f_block_filter The filter function to filter blocks to be applied with schedule rules.
* \param sch_rules The schedule rules.
* \param postprocs The postprocessors.
* \param mutator_probs The probability of using certain mutator.
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter = nullptr);
TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter,
Optional<Array<ScheduleRule>> sch_rules,
Optional<Array<Postproc>> postprocs,
Optional<Map<Mutator, FloatImm>> mutator_probs);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode);
};

Expand All @@ -171,6 +213,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
FClone f_clone;

void VisitAttrs(tvm::AttrVisitor* v) {
SpaceGeneratorNode::VisitAttrs(v);
// `f_initialize_with_tune_context` is not visited
// `f_generate_design_space` is not visited
// `f_clone` is not visited
Expand Down
Loading