diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index a03e156cc10f..57e58309525c 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -50,6 +50,11 @@
     is_auto_scheduler_enabled,
 )
 from .search_task import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
-from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
+from .search_policy import (
+    EmptyPolicy,
+    SketchPolicy,
+    PreloadMeasuredStates,
+    PreloadCustomSketchRule,
+)
 from .task_scheduler import TaskScheduler
 from .workload_registry import register_workload, make_workload_key
diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py
index 5b15a48943d2..f0388a886c5f 100644
--- a/python/tvm/auto_scheduler/search_policy.py
+++ b/python/tvm/auto_scheduler/search_policy.py
@@ -61,6 +61,39 @@ def __init__(self, filename):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchPolicy that allows users to add
+    custom sketch rules.
+
+    Notes
+    -----
+    This is an advanced feature. Make sure you understand how it works; it should only be
+    used with SketchPolicy.
+
+    Parameters
+    ----------
+    meet_condition_func: Callable
+        A function with `(policy, state, stage_id) -> int`. It should return one of the
+        result enumeration values defined below.
+    apply_func: Callable
+        A function with `(policy, state, stage_id) -> [[State, int], ...]`.
+    rule_name: str = "CustomSketchRule"
+        The name of this custom sketch rule.
+    """
+
+    # Result enumeration of the condition function.
+    PASS = 0  # Skip this rule and continue to try the next rules.
+    APPLY = 1  # Apply this rule and continue to try the next rules.
+    APPLY_AND_SKIP_REST = 2  # Apply this rule and skip the rest of the rules.
+
+    def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"):
+        self.__init_handle_by_constructor__(
+            _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name
+        )
+
+
 @tvm._ffi.register_object("auto_scheduler.SearchPolicy")
 class SearchPolicy(Object):
     """ The base class of search policies. """
@@ -141,8 +174,6 @@ class SketchPolicy(SearchPolicy):
 
         - auto_scheduler.PreloadMeasuredStates
         - auto_scheduler.PreloadCustomSketchRule
-
-        TODO(jcf94): Add these search callback implementations.
     """
 
     DEFAULT_PARAMS = {
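For orientation, here is a minimal sketch of how the new Python API is meant to be wired up. This is an illustration only, not part of the patch; `matmul_workload` is a placeholder for any function already registered with `@auto_scheduler.register_workload`.

```python
from tvm import auto_scheduler

# `matmul_workload` is a hypothetical workload function assumed to be registered
# with @auto_scheduler.register_workload; it is not defined in this patch.
task = auto_scheduler.SearchTask(func=matmul_workload, args=(512, 512, 512), target="llvm")

def meet_condition_func(search_policy, state, stage_id):
    # Apply the custom rule and skip the built-in sketch rules for this stage.
    return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST

def apply_func(search_policy, state, stage_id):
    # Wrap the raw StateObject so the loop_state API can be used on it.
    state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
    # Return candidate states as [StateObject, int] pairs (the tests below pass -1).
    return [[state.state_object, -1]]

policy = auto_scheduler.SketchPolicy(
    task,
    verbose=0,
    init_search_callbacks=[
        auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func)
    ],
)
sketches = policy.generate_sketches()
```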
""" DEFAULT_PARAMS = { diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index bfa596a1dc61..d985ed1341f5 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -228,6 +228,9 @@ def __init__( if isinstance(target_host, str): target_host = Target(target_host) + if layout_rewrite_option is None: + layout_rewrite_option = LayoutRewriteOption.get_target_default(target) + self.__init_handle_by_constructor__( _ffi_api.SearchTask, compute_dag, @@ -235,7 +238,7 @@ def __init__( target, target_host, hardware_params, - layout_rewrite_option or LayoutRewriteOption.get_target_default(target), + layout_rewrite_option, ) def tune(self, tuning_options, search_policy=None): diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index 975306f7be54..420b5f765a97 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -72,7 +72,7 @@ def make_search_policies( Load measurement records from this file. If it is not None, the status of the task scheduler, search policies and cost models will be restored according to this file. adapative_training: bool = False - Option used for XGBModel, which will reduce the model training frequency when there're too + Option used by XGBModel to reduce the model training frequency when there're too many logs. Returns @@ -275,7 +275,13 @@ def __init__( self.group_task_ids.append([]) self.group_task_ids[self.tag_to_group_id[tag]].append(i) - def tune(self, tune_option, search_policy="default", search_policy_params=None): + def tune( + self, + tune_option, + search_policy="default", + search_policy_params=None, + adapative_training=False, + ): """Tune a batch of tasks together. Parameters @@ -290,6 +296,9 @@ def tune(self, tune_option, search_policy="default", search_policy_params=None): "sketch.random" for SketchPolicy + RandomModel. search_policy_params : Optional[Dict[str, Any]] The parameters of the search policy + adapative_training : bool = False + Option used by XGBModel to reduce the model training frequency when there're + too many logs. 
""" # init members self.tune_option = tune_option @@ -324,6 +333,7 @@ def tune(self, tune_option, search_policy="default", search_policy_params=None): tune_option.verbose, self.load_model_file, self.load_log_file, + adapative_training, ) # do a round robin first to warm up diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 1e20b0fff6ea..91721afdba74 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -671,6 +671,26 @@ Array SketchPolicyNode::PickStatesWithEpsGreedy(const Array return inputs; } +/********** PreloadCustomSketchRule **********/ +TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode); + +PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func, + PackedFunc apply_func, String rule_name) { + auto node = make_object(); + node->meet_condition_func = std::move(meet_condition_func); + node->apply_func = std::move(apply_func); + node->rule_name = std::move(rule_name); + data_ = std::move(node); +} + +void PreloadCustomSketchRuleNode::Callback(SearchPolicyNode* policy) { + CHECK(policy->IsInstance()); + auto sketch_policy = dynamic_cast(policy); + sketch_policy->sketch_rules.push_back( + new RuleCustomSketch(meet_condition_func, apply_func, rule_name)); + StdCout(policy->verbose) << "Custom sketch rule \"" << rule_name << "\" added." << std::endl; +} + TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy") .set_body_typed([](SearchTask task, CostModel program_cost_model, Map params, int seed, int verbose, @@ -699,5 +719,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.PrintTitle").set_body_typed([](std::string t PrintTitle(title, 1); }); +TVM_REGISTER_GLOBAL("auto_scheduler.PreloadCustomSketchRule") + .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name) { + return PreloadCustomSketchRule(meet_condition_func, apply_func, rule_name); + }); + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h index 488634902a87..faf058b45b19 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.h +++ b/src/auto_scheduler/search_policy/sketch_policy.h @@ -197,6 +197,40 @@ class SketchPolicy : public SearchPolicy { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchPolicy, SearchPolicy, SketchPolicyNode); }; +/*! \brief Pre-search callback function to load custom rules for sketch generation */ +class PreloadCustomSketchRuleNode : public SearchCallbackNode { + public: + /*! \brief The condition check function of this rule. */ + PackedFunc meet_condition_func; + /*! \brief The apply function of this rule. */ + PackedFunc apply_func; + /*! \brief The name of this rule. */ + String rule_name; + + void Callback(SearchPolicyNode* policy) final; + + static constexpr const char* _type_key = "auto_scheduler.PreloadCustomSketchRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode); +}; + +/*! + * \brief Managed reference to PreloadCustomSketchRuleNode. + * \sa PreloadCustomSketchRuleNode + */ +class PreloadCustomSketchRule : public SearchCallback { + public: + /*! + * \brief The constructor. + * \param meet_condition_func The condition check function of this rule. + * \param apply_func The apply function of this rule. + * \param rule_name The name of this rule. 
diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc
index 1e20b0fff6ea..91721afdba74 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy.cc
@@ -671,6 +671,26 @@ Array<MeasureInput> SketchPolicyNode::PickStatesWithEpsGreedy(const Array<State>& best_states,
   return inputs;
 }
 
+/********** PreloadCustomSketchRule **********/
+TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode);
+
+PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func,
+                                                 PackedFunc apply_func, String rule_name) {
+  auto node = make_object<PreloadCustomSketchRuleNode>();
+  node->meet_condition_func = std::move(meet_condition_func);
+  node->apply_func = std::move(apply_func);
+  node->rule_name = std::move(rule_name);
+  data_ = std::move(node);
+}
+
+void PreloadCustomSketchRuleNode::Callback(SearchPolicyNode* policy) {
+  CHECK(policy->IsInstance<SketchPolicyNode>());
+  auto sketch_policy = dynamic_cast<SketchPolicyNode*>(policy);
+  sketch_policy->sketch_rules.push_back(
+      new RuleCustomSketch(meet_condition_func, apply_func, rule_name));
+  StdCout(policy->verbose) << "Custom sketch rule \"" << rule_name << "\" added." << std::endl;
+}
+
 TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy")
     .set_body_typed([](SearchTask task, CostModel program_cost_model, Map<String, ObjectRef> params,
                        int seed, int verbose,
@@ -699,5 +719,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.PrintTitle").set_body_typed([](std::string title) {
   PrintTitle(title, 1);
 });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.PreloadCustomSketchRule")
+    .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name) {
+      return PreloadCustomSketchRule(meet_condition_func, apply_func, rule_name);
+    });
+
 }  // namespace auto_scheduler
 }  // namespace tvm
diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h
index 488634902a87..faf058b45b19 100644
--- a/src/auto_scheduler/search_policy/sketch_policy.h
+++ b/src/auto_scheduler/search_policy/sketch_policy.h
@@ -197,6 +197,40 @@ class SketchPolicy : public SearchPolicy {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchPolicy, SearchPolicy, SketchPolicyNode);
 };
 
+/*! \brief A pre-search callback that loads custom rules for sketch generation. */
+class PreloadCustomSketchRuleNode : public SearchCallbackNode {
+ public:
+  /*! \brief The condition check function of this rule. */
+  PackedFunc meet_condition_func;
+  /*! \brief The apply function of this rule. */
+  PackedFunc apply_func;
+  /*! \brief The name of this rule. */
+  String rule_name;
+
+  void Callback(SearchPolicyNode* policy) final;
+
+  static constexpr const char* _type_key = "auto_scheduler.PreloadCustomSketchRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode);
+};
+
+/*!
+ * \brief Managed reference to PreloadCustomSketchRuleNode.
+ * \sa PreloadCustomSketchRuleNode
+ */
+class PreloadCustomSketchRule : public SearchCallback {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param meet_condition_func The condition check function of this rule.
+   * \param apply_func The apply function of this rule.
+   * \param rule_name The name of this rule.
+   */
+  PreloadCustomSketchRule(PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadCustomSketchRule, SearchCallback,
+                                        PreloadCustomSketchRuleNode);
+};
+
 }  // namespace auto_scheduler
 }  // namespace tvm
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
index f704fe9e82d5..110be6bd6f68 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc
@@ -461,6 +461,33 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
   return {std::make_pair(std::move(tmp_s), stage_id - 1)};
 }
 
+/********** RuleCustomSketch **********/
+
+SketchGenerationRule::ConditionKind RuleCustomSketch::MeetCondition(const SketchPolicyNode& policy,
+                                                                    const State& state,
+                                                                    int stage_id) const {
+  auto ret = meet_condition_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id);
+  if (ret.type_code() == 0) {
+    return ConditionKind(static_cast<int>(ret));
+  } else {
+    LOG(WARNING) << "Wrong rule condition value. Apply the rule and skip the rest";
+    return ConditionKind::kApplyAndSkipRest;
+  }
+}
+
+std::vector<std::pair<State, int>> RuleCustomSketch::Apply(const SketchPolicyNode& policy,
+                                                           const State& state, int stage_id) const {
+  Array<Array<ObjectRef>> apply_ret =
+      apply_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id);
+  std::vector<std::pair<State, int>> ret;
+  for (const auto& item : apply_ret) {
+    CHECK_EQ(item.size(), 2);
+    auto next = item[1].as<IntImmNode>();
+    ret.emplace_back(Downcast<State>(item[0]), next->value);
+  }
+  return ret;
+}
+
 /********** Init Population **********/
 
 PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h
index 046f036d59d9..fc1916b8c67d 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.h
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h
@@ -131,6 +131,29 @@ DEFINE_SKETCH_GENERATION_RULE(RuleCrossThreadReduction);
 * location of the producers of compute ops that perform "fake reduction" with const tensors. */
 DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU);
 
+/*! \brief The rule that allows users to generate custom sketches. */
+class RuleCustomSketch : public SketchGenerationRule {
+ public:
+  RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func,
+                   String rule_name = "CustomSketchRule")
+      : meet_condition_func_(std::move(meet_condition_func)),
+        apply_func_(std::move(apply_func)),
+        rule_name_(std::move(rule_name)) {}
+
+  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
+                              int stage_id) const final;
+
+  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
+                                           int stage_id) const final;
+
+  std::string GetRuleName() const final { return rule_name_; }
+
+ private:
+  PackedFunc meet_condition_func_;
+  PackedFunc apply_func_;
+  String rule_name_;
+};
+
 /********** Init Population **********/
 
 /*! \brief The base class for rules used to annotate the sketches to get the initial population
 */
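Before the tests, a brief recap of the data contract that `RuleCustomSketch::Apply` enforces on the Python side: `apply_func` must return a list of `[StateObject, int]` pairs (each item is length-checked and the integer is read as an `IntImm`), and `meet_condition_func` must return one of the `PreloadCustomSketchRule` integer constants, which is cast directly to `ConditionKind`. A hedged Python sketch of a conforming `apply_func` follows; the stage index and split factor are illustrative only.

```python
from tvm import auto_scheduler

def apply_func(search_policy, state, stage_id):
    # The state arrives as a raw StateObject; wrap it to use the loop_state API.
    state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
    ret = []

    # Candidate 1: keep the state as-is.
    ret.append([state.state_object, -1])

    # Candidate 2: split the first iterator of an (illustrative) stage index 2.
    s1 = state.copy()
    stage_op = s1.stage_ops[2]
    first_iter = s1[stage_op].iters[0]
    s1.split(stage_op, first_iter, [8])
    ret.append([s1.state_object, -1])

    # Each entry is a [StateObject, int] pair, matching the CHECK_EQ(item.size(), 2)
    # and IntImm handling above; the tests in this patch pass -1 as the integer.
    return ret
```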
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index c96dc63fec29..30aafbd22390 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -183,6 +183,32 @@ def test_sketch_search_policy_zero_rank():
         search_common(task, runner=measure_ctx.runner)
 
 
+@tvm.testing.requires_llvm
+def test_sketch_search_policy_custom_sketch():
+    def meet_condition_func(search_policy, state, stage_id):
+        return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST
+
+    def apply_func(search_policy, state, stage_id):
+        ret = []
+        state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
+        C = state.stage_ops[2]
+
+        ret.append([state.state_object, -1])
+
+        s1 = state.copy()
+        i, _, _ = s1[C].iters
+        s1.split(C, i, [8])
+        ret.append([s1.state_object, -1])
+        return ret
+
+    search_common(
+        cost_model=auto_scheduler.XGBModel(),
+        init_search_callbacks=[
+            auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func)
+        ],
+    )
+
+
 if __name__ == "__main__":
     test_workload_registry_empty_policy()
     test_sketch_search_policy_basic()
@@ -191,3 +217,4 @@ def test_sketch_search_policy_zero_rank():
     test_sketch_search_policy_cuda_rpc_runner()
     test_sketch_search_policy_cuda_xgbmodel_rpc_runner()
     test_sketch_search_policy_zero_rank()
+    test_sketch_search_policy_custom_sketch()
diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
index ddff6dd1a8d6..f3be6c0bc518 100644
--- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py
+++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py
@@ -36,9 +36,13 @@
 )
 
 
-def generate_sketches(workload_func, args, target, print_for_debug=False):
+def generate_sketches(
+    workload_func, args, target, print_for_debug=False, init_search_callbacks=None
+):
     task = auto_scheduler.SearchTask(func=workload_func, args=args, target=target)
-    policy = auto_scheduler.SketchPolicy(task, verbose=0)
+    policy = auto_scheduler.SketchPolicy(
+        task, verbose=0, init_search_callbacks=init_search_callbacks
+    )
     return policy.generate_sketches(print_for_debug)
 
 
@@ -259,6 +263,42 @@ def test_cpu_zero_rank_sketch():
     assert len(sketches) == 3
 
 
+def test_cpu_custom_sketch():
+    def meet_condition_func(search_policy, state, stage_id):
+        return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST
+
+    def apply_func(search_policy, state, stage_id):
+        ret = []
+        state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
+        C = state.stage_ops[2]
+
+        ret.append([state.state_object, -1])
+
+        s1 = state.copy()
+        i, _, _ = s1[C].iters
+        s1.split(C, i, [8, 2])
+        ret.append([s1.state_object, -1])
+        return ret
+
+    sketches = generate_sketches(
+        matmul_auto_scheduler_test,
+        (512, 512, 512),
+        "llvm",
+        init_search_callbacks=[
+            auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func)
+        ],
+    )
+    assert len(sketches) == 2
+    assert sketches[0].stages[2].iters[0].range.extent == 512
+    assert sketches[0].stages[2].iters[1].range.extent == 512
+    assert sketches[0].stages[2].iters[2].range.extent == 512
+    assert sketches[1].stages[2].iters[0].range.extent == 32
+    assert sketches[1].stages[2].iters[1].range.extent == 8
+    assert sketches[1].stages[2].iters[2].range.extent == 2
+    assert sketches[1].stages[2].iters[3].range.extent == 512
+    assert sketches[1].stages[2].iters[4].range.extent == 512
+
+
 @tvm.testing.requires_cuda
 def test_cuda_matmul_sketch():
     sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), "cuda")
@@ -407,6 +447,7 @@ def test_cuda_zero_rank_sketch():
     test_cpu_softmax_sketch()
     test_cpu_conv2d_winograd_sketch()
     test_cpu_zero_rank_sketch()
+    test_cpu_custom_sketch()
     test_cuda_matmul_sketch()
     test_cuda_conv2d_bn_relu_sketch()
     test_cuda_max_pool2d_sketch()