Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_
#define TVM_META_SCHEDULE_SCHEDULE_RULE_H_

#include <tvm/tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {

class TuneContext;

/*! \brief Rules to modify a block in a schedule. */
class ScheduleRuleNode : public runtime::Object {
public:
/*! \brief Virtual destructor. */
virtual ~ScheduleRuleNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {}

/*!
* \brief Initialize the design space generator with tuning context.
* \param context The tuning context for initialization.
* \note This method is supposed to be called only once before every other method.
*/
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;

/*!
* \brief Apply a schedule rule to the specific block in the given schedule.
* \param sch The schedule to be modified.
* \param block The specific block to apply the schedule rule.
* \return The list of schedules generated by applying the schedule rule.
*/
virtual runtime::Array<tir::Schedule> Apply(const tir::Schedule& sch,
const tir::BlockRV& block) = 0;

static constexpr const char* _type_key = "meta_schedule.ScheduleRule";
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object);
};

/*! \brief The schedule rule with customized methods on the python-side. */
class PyScheduleRuleNode : public ScheduleRuleNode {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param context The tuning context for initialization.
*/
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
/*!
* \brief The function type of `Apply` method.
* \param sch The schedule to be modified.
* \param block The specific block to apply the schedule rule.
* \return The list of schedules generated by applying the schedule rule.
*/
using FApply =
runtime::TypedPackedFunc<Array<tir::Schedule>(const tir::Schedule&, const tir::BlockRV&)>;
/*!
* \brief Get the schedule rule as string with name.
* \return The string of the schedule rule.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyScheduleRule's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final {
ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!";
return this->f_apply(sch, block);
}

static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode);
};

/*!
* \brief Managed reference to ScheduleRuleNode
* \sa ScheduleRuleNode
*/
class ScheduleRule : public runtime::ObjectRef {
public:
/*!
* \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
* \param into_producer If allows to inline a block into its producer
* \param into_consumer If allows to inline a block into its consumer
* \param into_cache_only If it only allows to inline into a block generated by cache_read/write
* \param inline_const_tensor Always inline constant tensors
* \param disallow_if_then_else Always disallow if-then-else-like constructs
* \param require_ordered Always require the read-to-write mapping to be ordered
* \param require_injective Always require the read-to-write mapping to be injective
* \param disallow_op The operators that are disallowed in auto inline
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule AutoInline(bool into_producer, //
bool into_consumer, //
bool into_cache_only, //
bool inline_const_tensor, //
bool disallow_if_then_else, //
bool require_injective, //
bool require_ordered, //
Optional<Array<String>> disallow_op);
/*!
* \brief Create a mega rule: multi-level tiling with data reuse
* \param structure The tiling structure. Recommended:
* - 'SSRSRS' on CPU
* - 'SSSRRSRS' on GPU
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
* - NullOpt on CPU
* - [blockIdx.x, vthread.x, threadIdx.x] on GPU
* \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \param vector_load_max_len The length of vector lane in vectorized cooperative fetching.
* NullOpt means disable vectorization
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
Optional<Array<String>> tile_binds, //
bool use_tensor_core, //
Optional<Integer> max_innermost_factor, //
Optional<Integer> vector_load_max_len, //
Optional<Map<String, ObjectRef>> reuse_read, //
Optional<Map<String, ObjectRef>> reuse_write);
/*!
* \brief A rule that randomly select a compute-at location for a free block
* \return The rule created
*/
TVM_DLL static ScheduleRule RandomComputeLocation();
/*!
* \brief Mark parallelize, vectorize and unroll to each block correspondingly
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
* uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
* parallelism.
* \param max_vectorize_extent The maximum extent to be vectorized.
* It sets the uplimit of the CPU vectorization. Use -1 to disable vectorization.
* \param unroll_max_steps The maximum number of unroll steps to be done.
* Use an empty array to disable unroll.
* \param unroll_explicit Whether to explicitly unroll the loop, or just add a unroll pragma.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
int max_vectorize_extent, //
Array<Integer> unroll_max_steps, //
bool unroll_explicit);
/*!
* \brief Create a schedule rule with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \param f_as_string The packed function of `AsString`.
* \return The schedule rule created.
*/
TVM_DLL static ScheduleRule PyScheduleRule(
PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyScheduleRuleNode::FApply f_apply, //
PyScheduleRuleNode::FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_H_
6 changes: 3 additions & 3 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ class SearchStrategyNode : public runtime::Object {

/*!
* \brief Initialize the search strategy with tuning context.
* \param tune_context The tuning context for initialization.
* \param context The tuning context for initialization.
* \note This method is supposed to be called only once before every other method.
*/
virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0;
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;

/*!
* \brief Pre-tuning for the search strategy.
Expand Down Expand Up @@ -146,7 +146,7 @@ class PySearchStrategyNode : public SearchStrategyNode {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param tune_context The tuning context for initialization.
* \param context The tuning context for initialization.
*/
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
/*!
Expand Down
16 changes: 11 additions & 5 deletions include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ class SpaceGeneratorNode : public Object {

/*!
* \brief Initialize the design space generator with tuning context.
* \param tune_context The tuning context for initialization.
* \param context The tuning context for initialization.
* \note This method is supposed to be called only once before every other method.
*/
virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0;
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;

/*!
* \brief Generate design spaces given a module.
Expand All @@ -92,7 +92,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param tune_context The tuning context for initialization.
* \param context The tuning context for initialization.
*/
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
/*!
Expand All @@ -112,10 +112,10 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
// `f_generate_design_space` is not visited
}

void InitializeWithTuneContext(const TuneContext& tune_context) final {
void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PySpaceGenerator's InitializeWithTuneContext !";
f_initialize_with_tune_context(tune_context);
f_initialize_with_tune_context(context);
}

Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final {
Expand Down Expand Up @@ -153,6 +153,12 @@ class SpaceGenerator : public ObjectRef {
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array<SpaceGenerator, void> space_generators);
/*!
* \brief Create a design space generator that generates design spaces by applying schedule rules
* to blocks in post-DFS order.
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator PostOrderApply();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode);
};

Expand Down
5 changes: 5 additions & 0 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class TuneContextNode : public runtime::Object {
Optional<SpaceGenerator> space_generator;
/*! \brief The search strategy. */
Optional<SearchStrategy> search_strategy;
/*! \brief The schedule rules. */
Array<ScheduleRule> sch_rules;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The random state. */
Expand All @@ -57,6 +59,7 @@ class TuneContextNode : public runtime::Object {
v->Visit("target", &target);
v->Visit("space_generator", &space_generator);
v->Visit("search_strategy", &search_strategy);
v->Visit("sch_rules", &sch_rules);
v->Visit("task_name", &task_name);
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
Expand All @@ -81,6 +84,7 @@ class TuneContext : public runtime::ObjectRef {
* \param target The target to be tuned for.
* \param space_generator The design space generator.
* \param search_strategy The search strategy.
* \param sch_rules The schedule rules.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
Expand All @@ -89,6 +93,7 @@ class TuneContext : public runtime::ObjectRef {
Optional<Target> target, //
Optional<SpaceGenerator> space_generator, //
Optional<SearchStrategy> search_strategy, //
Optional<Array<ScheduleRule>> sch_rules, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ class ScheduleNode : public runtime::Object {
* \return The corresponding loop sref
*/
virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0;
/*!
* \brief Check the existance of a specific BlockRV
* \param block_rv The BlockRV to be looked up
* \return Whether the corresponding block exists
*/
virtual bool HasBlock(const BlockRV& block_rv) const = 0;
/*!
* \brief Get the block/loop sref corresponding to the specific statement
* \param stmt The statement to be looked up
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@
from . import runner
from . import space_generator
from . import search_strategy
from . import schedule_rule
from . import integration
from .tune_context import TuneContext
19 changes: 19 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
The tvm.meta_schedule.schedule_rule package.
Meta Schedule schedule rules are used for modification of
blocks in a schedule. See also PostOrderApply.
"""
from .schedule_rule import PyScheduleRule, ScheduleRule
Loading