Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/tvm/auto_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@
is_auto_scheduler_enabled,
)
from .search_task import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .search_policy import (
EmptyPolicy,
SketchPolicy,
PreloadMeasuredStates,
PreloadCustomSketchRule,
)
from .task_scheduler import TaskScheduler
from .workload_registry import register_workload, make_workload_key
35 changes: 33 additions & 2 deletions python/tvm/auto_scheduler/search_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,39 @@ def __init__(self, filename):
self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)


@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
class PreloadCustomSketchRule(SearchCallback):
"""
A SearchCallback for SketchSearchPolicy that allows users to add
custom sketch rule.

Notes
-----
This is an advanced feature. Make sure you're clear how it works and this should only be used
in SketchSearchPolicy.

Parameters
----------
meet_condition_func: Callable
A function with `(policy, state, stage_id) -> int`. Should return one of the result
enumeration.
apply_func: Callable
A function with `(policy, state, stage_id) -> [[State, int], ...]`.
rule_name: str = "CustomSketchRule"
The name of this custom sketch rule.
"""

# Result enumeration of the condition function.
PASS = 0 # Skip this rule and continue to try the next rules.
APPLY = 1 # Apply this rule and continue to try the next rules.
APPLY_AND_SKIP_REST = 2 # Apply this rule and skip the rest rules.

def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From your test case I think we should not provide a default rule name. Otherwise it's easy to get the rule name conflict if users call PreloadCustomSketchRule twice.

self.__init_handle_by_constructor__(
_ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name
)


@tvm._ffi.register_object("auto_scheduler.SearchPolicy")
class SearchPolicy(Object):
""" The base class of search policies. """
Expand Down Expand Up @@ -141,8 +174,6 @@ class SketchPolicy(SearchPolicy):

- auto_scheduler.PreloadMeasuredStates
- auto_scheduler.PreloadCustomSketchRule

TODO(jcf94): Add these search callback implementations.
"""

DEFAULT_PARAMS = {
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/auto_scheduler/search_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,17 @@ def __init__(
if isinstance(target_host, str):
target_host = Target(target_host)

if layout_rewrite_option is None:
layout_rewrite_option = LayoutRewriteOption.get_target_default(target)

self.__init_handle_by_constructor__(
_ffi_api.SearchTask,
compute_dag,
workload_key,
target,
target_host,
hardware_params,
layout_rewrite_option or LayoutRewriteOption.get_target_default(target),
layout_rewrite_option,
)

def tune(self, tuning_options, search_policy=None):
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/auto_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def make_search_policies(
Load measurement records from this file. If it is not None, the status of the
task scheduler, search policies and cost models will be restored according to this file.
adapative_training: bool = False
Option used for XGBModel, which will reduce the model training frequency when there're too
Option used by XGBModel to reduce the model training frequency when there're too
many logs.

Returns
Expand Down Expand Up @@ -275,7 +275,13 @@ def __init__(
self.group_task_ids.append([])
self.group_task_ids[self.tag_to_group_id[tag]].append(i)

def tune(self, tune_option, search_policy="default", search_policy_params=None):
def tune(
self,
tune_option,
search_policy="default",
search_policy_params=None,
adapative_training=False,
):
"""Tune a batch of tasks together.

Parameters
Expand All @@ -290,6 +296,9 @@ def tune(self, tune_option, search_policy="default", search_policy_params=None):
"sketch.random" for SketchPolicy + RandomModel.
search_policy_params : Optional[Dict[str, Any]]
The parameters of the search policy
adapative_training : bool = False
Option used by XGBModel to reduce the model training frequency when there're
too many logs.
"""
# init members
self.tune_option = tune_option
Expand Down Expand Up @@ -324,6 +333,7 @@ def tune(self, tune_option, search_policy="default", search_policy_params=None):
tune_option.verbose,
self.load_model_file,
self.load_log_file,
adapative_training,
)

# do a round robin first to warm up
Expand Down
25 changes: 25 additions & 0 deletions src/auto_scheduler/search_policy/sketch_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,26 @@ Array<MeasureInput> SketchPolicyNode::PickStatesWithEpsGreedy(const Array<State>
return inputs;
}

/********** PreloadCustomSketchRule **********/
TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode);

PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func,
PackedFunc apply_func, String rule_name) {
auto node = make_object<PreloadCustomSketchRuleNode>();
node->meet_condition_func = std::move(meet_condition_func);
node->apply_func = std::move(apply_func);
node->rule_name = std::move(rule_name);
data_ = std::move(node);
}

void PreloadCustomSketchRuleNode::Callback(SearchPolicyNode* policy) {
CHECK(policy->IsInstance<SketchPolicyNode>());
auto sketch_policy = dynamic_cast<SketchPolicyNode*>(policy);
sketch_policy->sketch_rules.push_back(
new RuleCustomSketch(meet_condition_func, apply_func, rule_name));
StdCout(policy->verbose) << "Custom sketch rule \"" << rule_name << "\" added." << std::endl;
}

TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy")
.set_body_typed([](SearchTask task, CostModel program_cost_model, Map<String, ObjectRef> params,
int seed, int verbose,
Expand Down Expand Up @@ -699,5 +719,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.PrintTitle").set_body_typed([](std::string t
PrintTitle(title, 1);
});

TVM_REGISTER_GLOBAL("auto_scheduler.PreloadCustomSketchRule")
.set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name) {
return PreloadCustomSketchRule(meet_condition_func, apply_func, rule_name);
});

} // namespace auto_scheduler
} // namespace tvm
34 changes: 34 additions & 0 deletions src/auto_scheduler/search_policy/sketch_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,40 @@ class SketchPolicy : public SearchPolicy {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchPolicy, SearchPolicy, SketchPolicyNode);
};

/*! \brief Pre-search callback function to load custom rules for sketch generation */
class PreloadCustomSketchRuleNode : public SearchCallbackNode {
public:
/*! \brief The condition check function of this rule. */
PackedFunc meet_condition_func;
/*! \brief The apply function of this rule. */
PackedFunc apply_func;
/*! \brief The name of this rule. */
String rule_name;

void Callback(SearchPolicyNode* policy) final;

static constexpr const char* _type_key = "auto_scheduler.PreloadCustomSketchRule";
TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode);
};

/*!
* \brief Managed reference to PreloadCustomSketchRuleNode.
* \sa PreloadCustomSketchRuleNode
*/
class PreloadCustomSketchRule : public SearchCallback {
public:
/*!
* \brief The constructor.
* \param meet_condition_func The condition check function of this rule.
* \param apply_func The apply function of this rule.
* \param rule_name The name of this rule.
*/
PreloadCustomSketchRule(PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadCustomSketchRule, SearchCallback,
PreloadCustomSketchRuleNode);
};

} // namespace auto_scheduler
} // namespace tvm

Expand Down
27 changes: 27 additions & 0 deletions src/auto_scheduler/search_policy/sketch_policy_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,33 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}

/********** RuleCustomSketch **********/

SketchGenerationRule::ConditionKind RuleCustomSketch::MeetCondition(const SketchPolicyNode& policy,
const State& state,
int stage_id) const {
auto ret = meet_condition_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id);
if (ret.type_code() == 0) {
return ConditionKind(static_cast<int>(ret));
} else {
LOG(WARNING) << "Wrong rule condition value. Apply the rule and skip the rest";
return ConditionKind::kApplyAndSkipRest;
}
}

std::vector<std::pair<State, int>> RuleCustomSketch::Apply(const SketchPolicyNode& policy,
const State& state, int stage_id) const {
Array<Array<ObjectRef>> apply_ret =
apply_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id);
std::vector<std::pair<State, int>> ret;
for (const auto& item : apply_ret) {
CHECK_EQ(item.size(), 2);
auto next = item[1].as<IntImmNode>();
ret.emplace_back(Downcast<State>(item[0]), next->value);
}
return ret;
}

/********** Init Population **********/

PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
Expand Down
23 changes: 23 additions & 0 deletions src/auto_scheduler/search_policy/sketch_policy_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,29 @@ DEFINE_SKETCH_GENERATION_RULE(RuleCrossThreadReduction);
* location of the producers of compute ops that perform "fake reduction" with const tensors. */
DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU);

/*! \brief The rule that allows users to generate custom sketches. */
class RuleCustomSketch : public SketchGenerationRule {
public:
RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func,
String rule_name = "CustomSketchRule")
: meet_condition_func_(std::move(meet_condition_func)),
apply_func_(std::move(apply_func)),
rule_name_(std::move(rule_name)) {}

ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
int stage_id) const final;

std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
int stage_id) const final;

std::string GetRuleName() const final { return rule_name_; }

private:
PackedFunc meet_condition_func_;
PackedFunc apply_func_;
String rule_name_;
};

/********** Init Population **********/

/*! \brief The base class for rules used to annotate the sketches to get the initial population. */
Expand Down
27 changes: 27 additions & 0 deletions tests/python/unittest/test_auto_scheduler_search_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,32 @@ def test_sketch_search_policy_zero_rank():
search_common(task, runner=measure_ctx.runner)


@tvm.testing.requires_llvm
def test_sketch_search_policy_custom_sketch():
def meet_condition_func(search_policy, state, stage_id):
return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST

def apply_func(search_policy, state, stage_id):
ret = []
state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
C = state.stage_ops[2]

ret.append([state.state_object, -1])

s1 = state.copy()
i, _, _ = s1[C].iters
s1.split(C, i, [8])
ret.append([s1.state_object, -1])
return ret

search_common(
cost_model=auto_scheduler.XGBModel(),
init_search_callbacks=[
auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func)
],
)


if __name__ == "__main__":
test_workload_registry_empty_policy()
test_sketch_search_policy_basic()
Expand All @@ -191,3 +217,4 @@ def test_sketch_search_policy_zero_rank():
test_sketch_search_policy_cuda_rpc_runner()
test_sketch_search_policy_cuda_xgbmodel_rpc_runner()
test_sketch_search_policy_zero_rank()
test_sketch_search_policy_custom_sketch()
45 changes: 43 additions & 2 deletions tests/python/unittest/test_auto_scheduler_sketch_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,13 @@
)


def generate_sketches(workload_func, args, target, print_for_debug=False):
def generate_sketches(
workload_func, args, target, print_for_debug=False, init_search_callbacks=None
):
task = auto_scheduler.SearchTask(func=workload_func, args=args, target=target)
policy = auto_scheduler.SketchPolicy(task, verbose=0)
policy = auto_scheduler.SketchPolicy(
task, verbose=0, init_search_callbacks=init_search_callbacks
)
return policy.generate_sketches(print_for_debug)


Expand Down Expand Up @@ -259,6 +263,42 @@ def test_cpu_zero_rank_sketch():
assert len(sketches) == 3


def test_cpu_custom_sketch():
def meet_condition_func(search_policy, state, stage_id):
return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST

def apply_func(search_policy, state, stage_id):
ret = []
state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
C = state.stage_ops[2]

ret.append([state.state_object, -1])

s1 = state.copy()
i, _, _ = s1[C].iters
s1.split(C, i, [8, 2])
ret.append([s1.state_object, -1])
return ret

sketches = generate_sketches(
matmul_auto_scheduler_test,
(512, 512, 512),
"llvm",
init_search_callbacks=[
auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func)
],
)
assert len(sketches) == 2
assert sketches[0].stages[2].iters[0].range.extent == 512
assert sketches[0].stages[2].iters[1].range.extent == 512
assert sketches[0].stages[2].iters[2].range.extent == 512
assert sketches[1].stages[2].iters[0].range.extent == 32
assert sketches[1].stages[2].iters[1].range.extent == 8
assert sketches[1].stages[2].iters[2].range.extent == 2
assert sketches[1].stages[2].iters[3].range.extent == 512
assert sketches[1].stages[2].iters[4].range.extent == 512


@tvm.testing.requires_cuda
def test_cuda_matmul_sketch():
sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), "cuda")
Expand Down Expand Up @@ -407,6 +447,7 @@ def test_cuda_zero_rank_sketch():
test_cpu_softmax_sketch()
test_cpu_conv2d_winograd_sketch()
test_cpu_zero_rank_sketch()
test_cpu_custom_sketch()
test_cuda_matmul_sketch()
test_cuda_conv2d_bn_relu_sketch()
test_cuda_max_pool2d_sketch()
Expand Down