From 7480fefb29e84e4db4db373a29f7b433bd473bbf Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 12 Jan 2021 16:52:04 +0800 Subject: [PATCH 01/13] Bug fix for costmodel --- python/tvm/auto_scheduler/search_task.py | 5 ++++- python/tvm/auto_scheduler/task_scheduler.py | 13 ++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index bfa596a1dc61..d985ed1341f5 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -228,6 +228,9 @@ def __init__( if isinstance(target_host, str): target_host = Target(target_host) + if layout_rewrite_option is None: + layout_rewrite_option = LayoutRewriteOption.get_target_default(target) + self.__init_handle_by_constructor__( _ffi_api.SearchTask, compute_dag, @@ -235,7 +238,7 @@ def __init__( target, target_host, hardware_params, - layout_rewrite_option or LayoutRewriteOption.get_target_default(target), + layout_rewrite_option, ) def tune(self, tuning_options, search_policy=None): diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index 975306f7be54..6d8daa2d6250 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -275,7 +275,13 @@ def __init__( self.group_task_ids.append([]) self.group_task_ids[self.tag_to_group_id[tag]].append(i) - def tune(self, tune_option, search_policy="default", search_policy_params=None): + def tune( + self, + tune_option, + search_policy="default", + search_policy_params=None, + adapative_training=False + ): """Tune a batch of tasks together. Parameters @@ -290,6 +296,10 @@ def tune(self, tune_option, search_policy="default", search_policy_params=None): "sketch.random" for SketchPolicy + RandomModel. search_policy_params : Optional[Dict[str, Any]] The parameters of the search policy + adapative_training : bool = False + Option used for XGBModel, which will reduce the model training frequency when there're too + many logs. + """ # init members self.tune_option = tune_option @@ -324,6 +334,7 @@ def tune(self, tune_option, search_policy="default", search_policy_params=None): tune_option.verbose, self.load_model_file, self.load_log_file, + adapative_training ) # do a round robin first to warm up From fd9fe5028fb521a914a864a659e2668f3548836a Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 12 Jan 2021 20:03:21 +0800 Subject: [PATCH 02/13] Add custom sketch rules --- python/tvm/auto_scheduler/__init__.py | 5 +- python/tvm/auto_scheduler/search_policy.py | 76 ++++++++++++++++++- .../search_policy/sketch_policy.cc | 26 +++++++ .../search_policy/sketch_policy.h | 28 +++++++ .../search_policy/sketch_policy_rules.cc | 28 +++++++ .../search_policy/sketch_policy_rules.h | 22 ++++++ .../test_auto_scheduler_sketch_generation.py | 72 +++++++++++++++++- 7 files changed, 251 insertions(+), 6 deletions(-) diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index a03e156cc10f..0e4f2f95739c 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -50,6 +50,9 @@ is_auto_scheduler_enabled, ) from .search_task import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule -from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates +from .search_policy import ( + EmptyPolicy, SketchPolicy, PreloadMeasuredStates, PreloadCustomSketchRule, + register_custom_sketch_func, get_custom_sketch_callbacks +) from .task_scheduler import TaskScheduler from .workload_registry import register_workload, make_workload_key diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index 5b15a48943d2..f64834b755f9 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -25,7 +25,7 @@ The above process is repeated until the auto-scheduler runs out of time budget. Reference: -L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor +L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "auto_scheduler : Generating High-Performance Tensor Programs for Deep Learning." (OSDI 2020). """ @@ -34,6 +34,7 @@ import tvm._ffi from tvm.runtime import Object from .cost_model import RandomModel +from .loop_state import State from . import _ffi_api @@ -61,6 +62,72 @@ def __init__(self, filename): self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename) +@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule") +class PreloadCustomSketchRule(SearchCallback): + """ + A SearchCallback for SketchSearchPolicy that allowing users to add + custom sketch rule. + + Notes + ----- + This is an advanced feature. Make sure you're clear how it + works and this should only be used in SketchSearchPolicy. + + Parameters + ---------- + meet_condition_func: Function + A function with `(policy, state, stage_id) -> int` + apply_func: Function + A function with `(policy, state, stage_id) -> [[State, int], ...]` + """ + + CONDITION_NUM = { + "pass": 0, + "apply": 1, + "apply_and_skip_rest": 2 + } + + def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"): + self.__init_handle_by_constructor__( + _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name) + + +CUSTOM_SKETCH_REGISTRY = {} + + +def register_custom_sketch_func(compute_name, func=None): + """ + """ + global CUSTOM_SKETCH_REGISTRY + + if callable(compute_name): + func = compute_name + compute_name = func.__name__ + + if not isinstance(compute_name, str): + raise ValueError("expect string function name") + + def register(myf): + if compute_name in CUSTOM_SKETCH_REGISTRY: + raise RuntimeError('Custom Sketch for %s has been registered for compute already' % compute_name) + def meet_condition_func(policy, state, stage_id): + state = State(state, policy.search_task.compute_dag) + if state.stages[stage_id].op.name == compute_name: + return PreloadCustomSketchRule.CONDITION_NUM["apply_and_skip_rest"] + else: + return PreloadCustomSketchRule.CONDITION_NUM["pass"] + CUSTOM_SKETCH_REGISTRY[compute_name] = PreloadCustomSketchRule(meet_condition_func, myf) + return myf + + if func: + return register(func) + return register + + +def get_custom_sketch_callbacks(): + return list(CUSTOM_SKETCH_REGISTRY.values()) + + @tvm._ffi.register_object("auto_scheduler.SearchPolicy") class SearchPolicy(Object): """ The base class of search policies. """ @@ -141,8 +208,6 @@ class SketchPolicy(SearchPolicy): - auto_scheduler.PreloadMeasuredStates - auto_scheduler.PreloadCustomSketchRule - - TODO(jcf94): Add these search callback implementations. """ DEFAULT_PARAMS = { @@ -178,6 +243,11 @@ def __init__( if key not in params: params[key] = value + # global CUSTOM_SKETCH_REGISTRY + # if not init_search_callbacks: + # init_search_callbacks = [] + # init_search_callbacks.extend(list(CUSTOM_SKETCH_REGISTRY.values())) + self.__init_handle_by_constructor__( _ffi_api.SketchPolicy, task, diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 1e20b0fff6ea..fae0d14578c4 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -671,6 +671,27 @@ Array SketchPolicyNode::PickStatesWithEpsGreedy(const Array return inputs; } +/********** PreloadCustomSketchRule **********/ +TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode); + +PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func, + PackedFunc apply_func, + String rule_name) { + auto node = make_object(); + node->meet_condition_func = std::move(meet_condition_func); + node->apply_func = std::move(apply_func); + node->rule_name = std::move(rule_name); + data_ = std::move(node); +} + +void PreloadCustomSketchRuleNode::Callback(SearchPolicyNode* policy) { + CHECK(policy->IsInstance()); + auto sketch_policy = dynamic_cast(policy); + sketch_policy->sketch_rules.push_back( + new RuleCustomSketch(meet_condition_func, apply_func, rule_name)); + StdCout(policy->verbose) << "Custom sketch rule added." << std::endl; +} + TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy") .set_body_typed([](SearchTask task, CostModel program_cost_model, Map params, int seed, int verbose, @@ -699,5 +720,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.PrintTitle").set_body_typed([](std::string t PrintTitle(title, 1); }); +TVM_REGISTER_GLOBAL("auto_scheduler.PreloadCustomSketchRule") +.set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name) { + return PreloadCustomSketchRule(meet_condition_func, apply_func, rule_name); +}); + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h index 488634902a87..a760cd1251c0 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.h +++ b/src/auto_scheduler/search_policy/sketch_policy.h @@ -197,6 +197,34 @@ class SketchPolicy : public SearchPolicy { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchPolicy, SearchPolicy, SketchPolicyNode); }; +class RuleCustomSketch; + +/*! \brief Pre-search callback function to load custom rules for sketch generation */ +class PreloadCustomSketchRuleNode : public SearchCallbackNode { + public: + // TODO(jcf94): Use tvm::runtime::TypedPackedFunc? + PackedFunc meet_condition_func; + PackedFunc apply_func; + String rule_name; + + void Callback(SearchPolicyNode* policy) final; + + static constexpr const char *_type_key = "auto_scheduler.PreloadCustomSketchRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode); +}; + +/*! + * \brief Managed reference to PreloadCustomSketchRuleNode. + * \sa PreloadCustomSketchRuleNode + */ +class PreloadCustomSketchRule : public SearchCallback { + public: + PreloadCustomSketchRule(PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadCustomSketchRule, SearchCallback, + PreloadCustomSketchRuleNode); +}; + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index f704fe9e82d5..50a5a1ce959c 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -461,6 +461,34 @@ std::vector> RuleSpecialComputeLocationGPU::Apply( return {std::make_pair(std::move(tmp_s), stage_id - 1)}; } +/********** RuleCustomSketch **********/ + +SketchGenerationRule::ConditionKind RuleCustomSketch::MeetCondition( + const SketchPolicyNode& policy, const State& state, int stage_id) const { + auto ret = meet_condition_func_( + tvm::runtime::GetRef(&policy), state, stage_id); + if (ret.type_code() == 0) { + return ConditionKind(static_cast(ret)); + } else { + LOG(WARNING) << "Wrong value returned from custom sketch, try apply and skip the rest"; + return ConditionKind::kApplyAndSkipRest; + } +} + +std::vector > RuleCustomSketch::Apply(const SketchPolicyNode& policy, + const State& state, + int stage_id) const { + Array> apply_ret = apply_func_( + tvm::runtime::GetRef(&policy), state, stage_id); + std::vector > ret; + for (const auto& item : apply_ret) { + CHECK_EQ(item.size(), 2); + auto next = item[1].as(); + ret.emplace_back(Downcast(item[0]), next->value); + } + return ret; +} + /********** Init Population **********/ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state, diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h index 046f036d59d9..a8d06d5d0874 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.h +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h @@ -131,6 +131,28 @@ DEFINE_SKETCH_GENERATION_RULE(RuleCrossThreadReduction); * location of the producers of compute ops that perform "fake reduction" with const tensors. */ DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU); +/*! \brief The rule that allows users to generate custom sketches. */ +class RuleCustomSketch : public SketchGenerationRule { + public: + RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func, + String rule_name = "CustomSketchRule"): + meet_condition_func_(std::move(meet_condition_func)), + apply_func_(std::move(apply_func)), rule_name_(std::move(rule_name)) {} + + ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state, + int stage_id) const final; + + std::vector > Apply(const SketchPolicyNode& policy, + const State& state, int stage_id) const final; + + std::string GetRuleName() const final { return rule_name_; } + + private: + PackedFunc meet_condition_func_; + PackedFunc apply_func_; + String rule_name_; +}; + /********** Init Population **********/ /*! \brief The base class for rules used to annotate the sketches to get the initial population. */ diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py index ddff6dd1a8d6..1507597ac8fd 100644 --- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py +++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py @@ -36,9 +36,11 @@ ) -def generate_sketches(workload_func, args, target, print_for_debug=False): +def generate_sketches(workload_func, args, target, print_for_debug=False, + init_search_callbacks=None): task = auto_scheduler.SearchTask(func=workload_func, args=args, target=target) - policy = auto_scheduler.SketchPolicy(task, verbose=0) + policy = auto_scheduler.SketchPolicy( + task, verbose=0, init_search_callbacks=init_search_callbacks) return policy.generate_sketches(print_for_debug) @@ -259,6 +261,70 @@ def test_cpu_zero_rank_sketch(): assert len(sketches) == 3 +def test_cpu_custom_sketch(): + def meet_condition_func(search_policy, state, stage_id): + return auto_scheduler.PreloadCustomSketchRule.CONDITION_NUM["apply_and_skip_rest"] + + def apply_func(search_policy, state, stage_id): + ret = [] + state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) + C = state.stage_ops[2] + + ret.append([state.state_object, -1]) + + s1 = state.copy() + i, _, _ = s1[C].iters + s1.split(C, i, [8, 2]) + ret.append([s1.state_object, -1]) + return ret + + sketches = generate_sketches( + matmul_auto_scheduler_test, (512, 512, 512), "llvm", + init_search_callbacks=[ + auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func) + ] + ) + assert len(sketches) == 2 + assert sketches[0].stages[2].iters[0].range.extent == 512 + assert sketches[0].stages[2].iters[1].range.extent == 512 + assert sketches[0].stages[2].iters[2].range.extent == 512 + assert sketches[1].stages[2].iters[0].range.extent == 32 + assert sketches[1].stages[2].iters[1].range.extent == 8 + assert sketches[1].stages[2].iters[2].range.extent == 2 + assert sketches[1].stages[2].iters[3].range.extent == 512 + assert sketches[1].stages[2].iters[4].range.extent == 512 + + +def test_cpu_custom_sketch_registry(): + @auto_scheduler.register_custom_sketch_func + def C(search_policy, state, stage_id): + ret = [] + state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) + C = state.stage_ops[2] + + ret.append([state.state_object, -1]) + + s1 = state.copy() + i, _, _ = s1[C].iters + s1.split(C, i, [8, 2]) + ret.append([s1.state_object, -1]) + return ret + + sketches = generate_sketches( + matmul_auto_scheduler_test, (512, 512, 512), "llvm", + init_search_callbacks=auto_scheduler.get_custom_sketch_callbacks() + ) + assert len(sketches) == 2 + assert sketches[0].stages[2].iters[0].range.extent == 512 + assert sketches[0].stages[2].iters[1].range.extent == 512 + assert sketches[0].stages[2].iters[2].range.extent == 512 + assert sketches[1].stages[2].iters[0].range.extent == 32 + assert sketches[1].stages[2].iters[1].range.extent == 8 + assert sketches[1].stages[2].iters[2].range.extent == 2 + assert sketches[1].stages[2].iters[3].range.extent == 512 + assert sketches[1].stages[2].iters[4].range.extent == 512 + + @tvm.testing.requires_cuda def test_cuda_matmul_sketch(): sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), "cuda") @@ -407,6 +473,8 @@ def test_cuda_zero_rank_sketch(): test_cpu_softmax_sketch() test_cpu_conv2d_winograd_sketch() test_cpu_zero_rank_sketch() + test_cpu_custom_sketch() + test_cpu_custom_sketch_registry() test_cuda_matmul_sketch() test_cuda_conv2d_bn_relu_sketch() test_cuda_max_pool2d_sketch() From 35307da2771def44cd1a0ca52918831405392656 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 12 Jan 2021 20:11:57 +0800 Subject: [PATCH 03/13] Lint fix --- python/tvm/auto_scheduler/__init__.py | 8 ++++++-- python/tvm/auto_scheduler/search_policy.py | 20 +++++++++---------- python/tvm/auto_scheduler/task_scheduler.py | 9 ++++----- .../search_policy/sketch_policy.cc | 9 ++++----- .../search_policy/sketch_policy.h | 2 +- .../search_policy/sketch_policy_rules.cc | 19 +++++++++--------- .../search_policy/sketch_policy_rules.h | 11 +++++----- 7 files changed, 40 insertions(+), 38 deletions(-) diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index 0e4f2f95739c..59148a39278d 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -51,8 +51,12 @@ ) from .search_task import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule from .search_policy import ( - EmptyPolicy, SketchPolicy, PreloadMeasuredStates, PreloadCustomSketchRule, - register_custom_sketch_func, get_custom_sketch_callbacks + EmptyPolicy, + SketchPolicy, + PreloadMeasuredStates, + PreloadCustomSketchRule, + register_custom_sketch_func, + get_custom_sketch_callbacks, ) from .task_scheduler import TaskScheduler from .workload_registry import register_workload, make_workload_key diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index f64834b755f9..1177bc7c9cd6 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -81,22 +81,19 @@ class PreloadCustomSketchRule(SearchCallback): A function with `(policy, state, stage_id) -> [[State, int], ...]` """ - CONDITION_NUM = { - "pass": 0, - "apply": 1, - "apply_and_skip_rest": 2 - } + CONDITION_NUM = {"pass": 0, "apply": 1, "apply_and_skip_rest": 2} def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"): self.__init_handle_by_constructor__( - _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name) + _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name + ) CUSTOM_SKETCH_REGISTRY = {} def register_custom_sketch_func(compute_name, func=None): - """ + """ Helper decorator to register custom sketch functions easily. """ global CUSTOM_SKETCH_REGISTRY @@ -109,13 +106,16 @@ def register_custom_sketch_func(compute_name, func=None): def register(myf): if compute_name in CUSTOM_SKETCH_REGISTRY: - raise RuntimeError('Custom Sketch for %s has been registered for compute already' % compute_name) + raise RuntimeError( + "Custom Sketch for %s has been registered for compute already" % compute_name + ) + def meet_condition_func(policy, state, stage_id): state = State(state, policy.search_task.compute_dag) if state.stages[stage_id].op.name == compute_name: return PreloadCustomSketchRule.CONDITION_NUM["apply_and_skip_rest"] - else: - return PreloadCustomSketchRule.CONDITION_NUM["pass"] + return PreloadCustomSketchRule.CONDITION_NUM["pass"] + CUSTOM_SKETCH_REGISTRY[compute_name] = PreloadCustomSketchRule(meet_condition_func, myf) return myf diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index 6d8daa2d6250..8aca025c2bb1 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -280,7 +280,7 @@ def tune( tune_option, search_policy="default", search_policy_params=None, - adapative_training=False + adapative_training=False, ): """Tune a batch of tasks together. @@ -297,9 +297,8 @@ def tune( search_policy_params : Optional[Dict[str, Any]] The parameters of the search policy adapative_training : bool = False - Option used for XGBModel, which will reduce the model training frequency when there're too - many logs. - + Option used for XGBModel, which will reduce the model training frequency when there're + too many logs. """ # init members self.tune_option = tune_option @@ -334,7 +333,7 @@ def tune( tune_option.verbose, self.load_model_file, self.load_log_file, - adapative_training + adapative_training, ) # do a round robin first to warm up diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index fae0d14578c4..fb27a1281750 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -675,8 +675,7 @@ Array SketchPolicyNode::PickStatesWithEpsGreedy(const Array TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode); PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func, - PackedFunc apply_func, - String rule_name) { + PackedFunc apply_func, String rule_name) { auto node = make_object(); node->meet_condition_func = std::move(meet_condition_func); node->apply_func = std::move(apply_func); @@ -721,9 +720,9 @@ TVM_REGISTER_GLOBAL("auto_scheduler.PrintTitle").set_body_typed([](std::string t }); TVM_REGISTER_GLOBAL("auto_scheduler.PreloadCustomSketchRule") -.set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name) { - return PreloadCustomSketchRule(meet_condition_func, apply_func, rule_name); -}); + .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name) { + return PreloadCustomSketchRule(meet_condition_func, apply_func, rule_name); + }); } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h index a760cd1251c0..792defc7ae4b 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.h +++ b/src/auto_scheduler/search_policy/sketch_policy.h @@ -209,7 +209,7 @@ class PreloadCustomSketchRuleNode : public SearchCallbackNode { void Callback(SearchPolicyNode* policy) final; - static constexpr const char *_type_key = "auto_scheduler.PreloadCustomSketchRule"; + static constexpr const char* _type_key = "auto_scheduler.PreloadCustomSketchRule"; TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode); }; diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 50a5a1ce959c..598c8860d663 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -463,10 +463,10 @@ std::vector> RuleSpecialComputeLocationGPU::Apply( /********** RuleCustomSketch **********/ -SketchGenerationRule::ConditionKind RuleCustomSketch::MeetCondition( - const SketchPolicyNode& policy, const State& state, int stage_id) const { - auto ret = meet_condition_func_( - tvm::runtime::GetRef(&policy), state, stage_id); +SketchGenerationRule::ConditionKind RuleCustomSketch::MeetCondition(const SketchPolicyNode& policy, + const State& state, + int stage_id) const { + auto ret = meet_condition_func_(tvm::runtime::GetRef(&policy), state, stage_id); if (ret.type_code() == 0) { return ConditionKind(static_cast(ret)); } else { @@ -475,12 +475,11 @@ SketchGenerationRule::ConditionKind RuleCustomSketch::MeetCondition( } } -std::vector > RuleCustomSketch::Apply(const SketchPolicyNode& policy, - const State& state, - int stage_id) const { - Array> apply_ret = apply_func_( - tvm::runtime::GetRef(&policy), state, stage_id); - std::vector > ret; +std::vector> RuleCustomSketch::Apply(const SketchPolicyNode& policy, + const State& state, int stage_id) const { + Array> apply_ret = + apply_func_(tvm::runtime::GetRef(&policy), state, stage_id); + std::vector> ret; for (const auto& item : apply_ret) { CHECK_EQ(item.size(), 2); auto next = item[1].as(); diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h index a8d06d5d0874..fc1916b8c67d 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.h +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h @@ -135,15 +135,16 @@ DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU); class RuleCustomSketch : public SketchGenerationRule { public: RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func, - String rule_name = "CustomSketchRule"): - meet_condition_func_(std::move(meet_condition_func)), - apply_func_(std::move(apply_func)), rule_name_(std::move(rule_name)) {} + String rule_name = "CustomSketchRule") + : meet_condition_func_(std::move(meet_condition_func)), + apply_func_(std::move(apply_func)), + rule_name_(std::move(rule_name)) {} ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state, int stage_id) const final; - std::vector > Apply(const SketchPolicyNode& policy, - const State& state, int stage_id) const final; + std::vector> Apply(const SketchPolicyNode& policy, const State& state, + int stage_id) const final; std::string GetRuleName() const final { return rule_name_; } From 50e0bfd7a62cec0de5dd2d43a154c60f3288bf97 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 12 Jan 2021 20:13:05 +0800 Subject: [PATCH 04/13] Update --- python/tvm/auto_scheduler/search_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index 1177bc7c9cd6..750edeb43930 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -25,7 +25,7 @@ The above process is repeated until the auto-scheduler runs out of time budget. Reference: -L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "auto_scheduler : Generating High-Performance Tensor +L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor Programs for Deep Learning." (OSDI 2020). """ From f0fa7a115979580ddf87aa85cd57407240b83f30 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 12 Jan 2021 20:14:35 +0800 Subject: [PATCH 05/13] Update --- python/tvm/auto_scheduler/search_policy.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index 750edeb43930..bede89a64b4f 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -243,11 +243,6 @@ def __init__( if key not in params: params[key] = value - # global CUSTOM_SKETCH_REGISTRY - # if not init_search_callbacks: - # init_search_callbacks = [] - # init_search_callbacks.extend(list(CUSTOM_SKETCH_REGISTRY.values())) - self.__init_handle_by_constructor__( _ffi_api.SketchPolicy, task, From a8e95892456777eb40806c88e3e5dbfa95b513ed Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 12 Jan 2021 20:45:03 +0800 Subject: [PATCH 06/13] Lint fix --- python/tvm/auto_scheduler/search_policy.py | 3 +-- .../test_auto_scheduler_sketch_generation.py | 20 ++++++++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index bede89a64b4f..91602c73ff60 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -93,8 +93,7 @@ def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule" def register_custom_sketch_func(compute_name, func=None): - """ Helper decorator to register custom sketch functions easily. - """ + """Helper decorator to register custom sketch functions easily.""" global CUSTOM_SKETCH_REGISTRY if callable(compute_name): diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py index 1507597ac8fd..c1616c5e2e27 100644 --- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py +++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py @@ -36,11 +36,13 @@ ) -def generate_sketches(workload_func, args, target, print_for_debug=False, - init_search_callbacks=None): +def generate_sketches( + workload_func, args, target, print_for_debug=False, init_search_callbacks=None +): task = auto_scheduler.SearchTask(func=workload_func, args=args, target=target) policy = auto_scheduler.SketchPolicy( - task, verbose=0, init_search_callbacks=init_search_callbacks) + task, verbose=0, init_search_callbacks=init_search_callbacks + ) return policy.generate_sketches(print_for_debug) @@ -279,10 +281,12 @@ def apply_func(search_policy, state, stage_id): return ret sketches = generate_sketches( - matmul_auto_scheduler_test, (512, 512, 512), "llvm", + matmul_auto_scheduler_test, + (512, 512, 512), + "llvm", init_search_callbacks=[ auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func) - ] + ], ) assert len(sketches) == 2 assert sketches[0].stages[2].iters[0].range.extent == 512 @@ -311,8 +315,10 @@ def C(search_policy, state, stage_id): return ret sketches = generate_sketches( - matmul_auto_scheduler_test, (512, 512, 512), "llvm", - init_search_callbacks=auto_scheduler.get_custom_sketch_callbacks() + matmul_auto_scheduler_test, + (512, 512, 512), + "llvm", + init_search_callbacks=auto_scheduler.get_custom_sketch_callbacks(), ) assert len(sketches) == 2 assert sketches[0].stages[2].iters[0].range.extent == 512 From 4967b5cdbc8684db3c3c41b0fb8bf6e60da651db Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 16 Jan 2021 17:43:32 +0800 Subject: [PATCH 07/13] Update --- python/tvm/auto_scheduler/search_policy.py | 49 ++++++++++++++++--- python/tvm/auto_scheduler/task_scheduler.py | 2 +- .../search_policy/sketch_policy.cc | 2 +- .../search_policy/sketch_policy_rules.cc | 2 +- 4 files changed, 46 insertions(+), 9 deletions(-) diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index 91602c73ff60..14d5ee3e1582 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -65,7 +65,7 @@ def __init__(self, filename): @tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule") class PreloadCustomSketchRule(SearchCallback): """ - A SearchCallback for SketchSearchPolicy that allowing users to add + A SearchCallback for SketchSearchPolicy that allows users to add custom sketch rule. Notes @@ -75,10 +75,12 @@ class PreloadCustomSketchRule(SearchCallback): Parameters ---------- - meet_condition_func: Function - A function with `(policy, state, stage_id) -> int` - apply_func: Function - A function with `(policy, state, stage_id) -> [[State, int], ...]` + meet_condition_func: Callable + A function with `(policy, state, stage_id) -> int`. + apply_func: Callable + A function with `(policy, state, stage_id) -> [[State, int], ...]`. + rule_name: str = "CustomSketchRule" + The name of this custom sketch rule. """ CONDITION_NUM = {"pass": 0, "apply": 1, "apply_and_skip_rest": 2} @@ -93,7 +95,42 @@ def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule" def register_custom_sketch_func(compute_name, func=None): - """Helper decorator to register custom sketch functions easily.""" + """ Helper decorator to register custom sketch functions easily. + The registered function will be used as the apply function of this custom sketch rule. + The meet condition of this custom sketch rule is to match the name of a specific stage. + + Example usage: + For a compute stage: + C = te.compute( + (N, M), + lambda ..., + name="C", + ) + We can register the custom rule by: + @auto_scheduler.register_custom_sketch_func + def C(search_policy, state, stage_id): + ret = [] + state = auto_scheduler.loop_state.State( + state, search_policy.search_task.compute_dag + ) + + ... Do any process with the state ... + + ret.append([state.state_object, -1]) + return ret + Or by: + @auto_scheduler.register_custom_sketch_func("C") + def func(search_policy, state, stage_id): + ret = [] + state = auto_scheduler.loop_state.State( + state, search_policy.search_task.compute_dag + ) + + ... Do any process with the state ... + + ret.append([state.state_object, -1]) + return ret + """ global CUSTOM_SKETCH_REGISTRY if callable(compute_name): diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index 8aca025c2bb1..63503a81e60a 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -72,7 +72,7 @@ def make_search_policies( Load measurement records from this file. If it is not None, the status of the task scheduler, search policies and cost models will be restored according to this file. adapative_training: bool = False - Option used for XGBModel, which will reduce the model training frequency when there're too + Option used by XGBModel to reduce the model training frequency when there're too many logs. Returns diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index fb27a1281750..91721afdba74 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -688,7 +688,7 @@ void PreloadCustomSketchRuleNode::Callback(SearchPolicyNode* policy) { auto sketch_policy = dynamic_cast(policy); sketch_policy->sketch_rules.push_back( new RuleCustomSketch(meet_condition_func, apply_func, rule_name)); - StdCout(policy->verbose) << "Custom sketch rule added." << std::endl; + StdCout(policy->verbose) << "Custom sketch rule \"" << rule_name << "\" added." << std::endl; } TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy") diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 598c8860d663..110be6bd6f68 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -470,7 +470,7 @@ SketchGenerationRule::ConditionKind RuleCustomSketch::MeetCondition(const Sketch if (ret.type_code() == 0) { return ConditionKind(static_cast(ret)); } else { - LOG(WARNING) << "Wrong value returned from custom sketch, try apply and skip the rest"; + LOG(WARNING) << "Wrong rule condition value. Apply the rule and skip the rest"; return ConditionKind::kApplyAndSkipRest; } } From e5d781e3f2ea5e5662ad425715d995b5e72eb0bc Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 16 Jan 2021 18:11:39 +0800 Subject: [PATCH 08/13] Update --- python/tvm/auto_scheduler/search_policy.py | 76 ++++++++++--------- .../test_auto_scheduler_search_policy.py | 26 +++++++ .../test_auto_scheduler_sketch_generation.py | 2 +- 3 files changed, 66 insertions(+), 38 deletions(-) diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index 14d5ee3e1582..d0f4362d89fc 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -83,7 +83,9 @@ class PreloadCustomSketchRule(SearchCallback): The name of this custom sketch rule. """ - CONDITION_NUM = {"pass": 0, "apply": 1, "apply_and_skip_rest": 2} + PASS = 0 + APPLY = 1 + APPLY_AND_SKIP_REST = 2 def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"): self.__init_handle_by_constructor__( @@ -95,41 +97,41 @@ def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule" def register_custom_sketch_func(compute_name, func=None): - """ Helper decorator to register custom sketch functions easily. - The registered function will be used as the apply function of this custom sketch rule. - The meet condition of this custom sketch rule is to match the name of a specific stage. - - Example usage: - For a compute stage: - C = te.compute( - (N, M), - lambda ..., - name="C", + """Helper decorator to register custom sketch functions easily. + The registered function will be used as the apply function of this custom sketch rule. + The meet condition of this custom sketch rule is to match the name of a specific stage. + + Example usage: + For a compute stage: + C = te.compute( + (N, M), + lambda ..., + name="C", + ) + We can register the custom rule by: + @auto_scheduler.register_custom_sketch_func + def C(search_policy, state, stage_id): + ret = [] + state = auto_scheduler.loop_state.State( + state, search_policy.search_task.compute_dag + ) + + ... Do any process with the state ... + + ret.append([state.state_object, -1]) + return ret + Or by: + @auto_scheduler.register_custom_sketch_func("C") + def func(search_policy, state, stage_id): + ret = [] + state = auto_scheduler.loop_state.State( + state, search_policy.search_task.compute_dag ) - We can register the custom rule by: - @auto_scheduler.register_custom_sketch_func - def C(search_policy, state, stage_id): - ret = [] - state = auto_scheduler.loop_state.State( - state, search_policy.search_task.compute_dag - ) - - ... Do any process with the state ... - - ret.append([state.state_object, -1]) - return ret - Or by: - @auto_scheduler.register_custom_sketch_func("C") - def func(search_policy, state, stage_id): - ret = [] - state = auto_scheduler.loop_state.State( - state, search_policy.search_task.compute_dag - ) - - ... Do any process with the state ... - - ret.append([state.state_object, -1]) - return ret + + ... Do any process with the state ... + + ret.append([state.state_object, -1]) + return ret """ global CUSTOM_SKETCH_REGISTRY @@ -149,8 +151,8 @@ def register(myf): def meet_condition_func(policy, state, stage_id): state = State(state, policy.search_task.compute_dag) if state.stages[stage_id].op.name == compute_name: - return PreloadCustomSketchRule.CONDITION_NUM["apply_and_skip_rest"] - return PreloadCustomSketchRule.CONDITION_NUM["pass"] + return PreloadCustomSketchRule.APPLY_AND_SKIP_REST + return PreloadCustomSketchRule.PASS CUSTOM_SKETCH_REGISTRY[compute_name] = PreloadCustomSketchRule(meet_condition_func, myf) return myf diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index c96dc63fec29..48c784ffd899 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -183,6 +183,31 @@ def test_sketch_search_policy_zero_rank(): search_common(task, runner=measure_ctx.runner) +@tvm.testing.requires_llvm +def test_sketch_search_policy_custom_sketch(): + def meet_condition_func(search_policy, state, stage_id): + return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST + + def apply_func(search_policy, state, stage_id): + ret = [] + state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) + C = state.stage_ops[2] + + ret.append([state.state_object, -1]) + + s1 = state.copy() + i, _, _ = s1[C].iters + s1.split(C, i, [8, 2]) + ret.append([s1.state_object, -1]) + return ret + + search_common( + cost_model=auto_scheduler.XGBModel(), + init_search_callbacks=[ + auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func) + ], + ) + if __name__ == "__main__": test_workload_registry_empty_policy() test_sketch_search_policy_basic() @@ -191,3 +216,4 @@ def test_sketch_search_policy_zero_rank(): test_sketch_search_policy_cuda_rpc_runner() test_sketch_search_policy_cuda_xgbmodel_rpc_runner() test_sketch_search_policy_zero_rank() + test_sketch_search_policy_custom_sketch() diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py index c1616c5e2e27..ed6a7e2224c9 100644 --- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py +++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py @@ -265,7 +265,7 @@ def test_cpu_zero_rank_sketch(): def test_cpu_custom_sketch(): def meet_condition_func(search_policy, state, stage_id): - return auto_scheduler.PreloadCustomSketchRule.CONDITION_NUM["apply_and_skip_rest"] + return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST def apply_func(search_policy, state, stage_id): ret = [] From 5dead1c77c36ca6e5fae96aaa5437d26981d146a Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sun, 17 Jan 2021 14:57:41 +0800 Subject: [PATCH 09/13] Update --- python/tvm/auto_scheduler/search_policy.py | 4 ++-- tests/python/unittest/test_auto_scheduler_search_policy.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index d0f4362d89fc..2c3e104d282b 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -115,7 +115,7 @@ def C(search_policy, state, stage_id): state = auto_scheduler.loop_state.State( state, search_policy.search_task.compute_dag ) - + ... Do any process with the state ... ret.append([state.state_object, -1]) @@ -127,7 +127,7 @@ def func(search_policy, state, stage_id): state = auto_scheduler.loop_state.State( state, search_policy.search_task.compute_dag ) - + ... Do any process with the state ... ret.append([state.state_object, -1]) diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index 48c784ffd899..30aafbd22390 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -197,7 +197,7 @@ def apply_func(search_policy, state, stage_id): s1 = state.copy() i, _, _ = s1[C].iters - s1.split(C, i, [8, 2]) + s1.split(C, i, [8]) ret.append([s1.state_object, -1]) return ret @@ -208,6 +208,7 @@ def apply_func(search_policy, state, stage_id): ], ) + if __name__ == "__main__": test_workload_registry_empty_policy() test_sketch_search_policy_basic() From 17b9efc694b394ffd390a42553e2b6b46647c9dc Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sun, 17 Jan 2021 17:02:54 +0800 Subject: [PATCH 10/13] Update --- python/tvm/auto_scheduler/task_scheduler.py | 2 +- src/auto_scheduler/search_policy/sketch_policy.h | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index 63503a81e60a..420b5f765a97 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -297,7 +297,7 @@ def tune( search_policy_params : Optional[Dict[str, Any]] The parameters of the search policy adapative_training : bool = False - Option used for XGBModel, which will reduce the model training frequency when there're + Option used by XGBModel to reduce the model training frequency when there're too many logs. """ # init members diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h index 792defc7ae4b..56986bf88d2d 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.h +++ b/src/auto_scheduler/search_policy/sketch_policy.h @@ -202,7 +202,6 @@ class RuleCustomSketch; /*! \brief Pre-search callback function to load custom rules for sketch generation */ class PreloadCustomSketchRuleNode : public SearchCallbackNode { public: - // TODO(jcf94): Use tvm::runtime::TypedPackedFunc? PackedFunc meet_condition_func; PackedFunc apply_func; String rule_name; From 6adf2e2c5a1d952e893cc49e0b93f5e30e245674 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Mon, 18 Jan 2021 16:41:28 +0800 Subject: [PATCH 11/13] Remove decorator --- python/tvm/auto_scheduler/__init__.py | 2 - python/tvm/auto_scheduler/search_policy.py | 73 ------------------- .../search_policy/sketch_policy.h | 2 - .../test_auto_scheduler_search_policy.py | 14 ++-- .../test_auto_scheduler_sketch_generation.py | 32 -------- 5 files changed, 7 insertions(+), 116 deletions(-) diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index 59148a39278d..57e58309525c 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -55,8 +55,6 @@ SketchPolicy, PreloadMeasuredStates, PreloadCustomSketchRule, - register_custom_sketch_func, - get_custom_sketch_callbacks, ) from .task_scheduler import TaskScheduler from .workload_registry import register_workload, make_workload_key diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index 2c3e104d282b..becc7e691c3f 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -93,79 +93,6 @@ def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule" ) -CUSTOM_SKETCH_REGISTRY = {} - - -def register_custom_sketch_func(compute_name, func=None): - """Helper decorator to register custom sketch functions easily. - The registered function will be used as the apply function of this custom sketch rule. - The meet condition of this custom sketch rule is to match the name of a specific stage. - - Example usage: - For a compute stage: - C = te.compute( - (N, M), - lambda ..., - name="C", - ) - We can register the custom rule by: - @auto_scheduler.register_custom_sketch_func - def C(search_policy, state, stage_id): - ret = [] - state = auto_scheduler.loop_state.State( - state, search_policy.search_task.compute_dag - ) - - ... Do any process with the state ... - - ret.append([state.state_object, -1]) - return ret - Or by: - @auto_scheduler.register_custom_sketch_func("C") - def func(search_policy, state, stage_id): - ret = [] - state = auto_scheduler.loop_state.State( - state, search_policy.search_task.compute_dag - ) - - ... Do any process with the state ... - - ret.append([state.state_object, -1]) - return ret - """ - global CUSTOM_SKETCH_REGISTRY - - if callable(compute_name): - func = compute_name - compute_name = func.__name__ - - if not isinstance(compute_name, str): - raise ValueError("expect string function name") - - def register(myf): - if compute_name in CUSTOM_SKETCH_REGISTRY: - raise RuntimeError( - "Custom Sketch for %s has been registered for compute already" % compute_name - ) - - def meet_condition_func(policy, state, stage_id): - state = State(state, policy.search_task.compute_dag) - if state.stages[stage_id].op.name == compute_name: - return PreloadCustomSketchRule.APPLY_AND_SKIP_REST - return PreloadCustomSketchRule.PASS - - CUSTOM_SKETCH_REGISTRY[compute_name] = PreloadCustomSketchRule(meet_condition_func, myf) - return myf - - if func: - return register(func) - return register - - -def get_custom_sketch_callbacks(): - return list(CUSTOM_SKETCH_REGISTRY.values()) - - @tvm._ffi.register_object("auto_scheduler.SearchPolicy") class SearchPolicy(Object): """ The base class of search policies. """ diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h index 56986bf88d2d..e409e83dd471 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.h +++ b/src/auto_scheduler/search_policy/sketch_policy.h @@ -197,8 +197,6 @@ class SketchPolicy : public SearchPolicy { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchPolicy, SearchPolicy, SketchPolicyNode); }; -class RuleCustomSketch; - /*! \brief Pre-search callback function to load custom rules for sketch generation */ class PreloadCustomSketchRuleNode : public SearchCallbackNode { public: diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index 30aafbd22390..4ae1492709ad 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -210,11 +210,11 @@ def apply_func(search_policy, state, stage_id): if __name__ == "__main__": - test_workload_registry_empty_policy() - test_sketch_search_policy_basic() - test_sketch_search_policy_basic_spawn() - test_sketch_search_policy_xgbmodel() - test_sketch_search_policy_cuda_rpc_runner() - test_sketch_search_policy_cuda_xgbmodel_rpc_runner() - test_sketch_search_policy_zero_rank() + # test_workload_registry_empty_policy() + # test_sketch_search_policy_basic() + # test_sketch_search_policy_basic_spawn() + # test_sketch_search_policy_xgbmodel() + # test_sketch_search_policy_cuda_rpc_runner() + # test_sketch_search_policy_cuda_xgbmodel_rpc_runner() + # test_sketch_search_policy_zero_rank() test_sketch_search_policy_custom_sketch() diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py index ed6a7e2224c9..5d9e064d551d 100644 --- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py +++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py @@ -299,38 +299,6 @@ def apply_func(search_policy, state, stage_id): assert sketches[1].stages[2].iters[4].range.extent == 512 -def test_cpu_custom_sketch_registry(): - @auto_scheduler.register_custom_sketch_func - def C(search_policy, state, stage_id): - ret = [] - state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) - C = state.stage_ops[2] - - ret.append([state.state_object, -1]) - - s1 = state.copy() - i, _, _ = s1[C].iters - s1.split(C, i, [8, 2]) - ret.append([s1.state_object, -1]) - return ret - - sketches = generate_sketches( - matmul_auto_scheduler_test, - (512, 512, 512), - "llvm", - init_search_callbacks=auto_scheduler.get_custom_sketch_callbacks(), - ) - assert len(sketches) == 2 - assert sketches[0].stages[2].iters[0].range.extent == 512 - assert sketches[0].stages[2].iters[1].range.extent == 512 - assert sketches[0].stages[2].iters[2].range.extent == 512 - assert sketches[1].stages[2].iters[0].range.extent == 32 - assert sketches[1].stages[2].iters[1].range.extent == 8 - assert sketches[1].stages[2].iters[2].range.extent == 2 - assert sketches[1].stages[2].iters[3].range.extent == 512 - assert sketches[1].stages[2].iters[4].range.extent == 512 - - @tvm.testing.requires_cuda def test_cuda_matmul_sketch(): sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), "cuda") From 883e3b0f0518126011fede38ee656ab3f30f015b Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Mon, 18 Jan 2021 16:50:43 +0800 Subject: [PATCH 12/13] Update --- python/tvm/auto_scheduler/search_policy.py | 14 ++++++++------ src/auto_scheduler/search_policy/sketch_policy.h | 9 +++++++++ .../unittest/test_auto_scheduler_search_policy.py | 14 +++++++------- .../test_auto_scheduler_sketch_generation.py | 1 - 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index becc7e691c3f..a23455a6b8ce 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -70,22 +70,24 @@ class PreloadCustomSketchRule(SearchCallback): Notes ----- - This is an advanced feature. Make sure you're clear how it - works and this should only be used in SketchSearchPolicy. + This is an advanced feature. Make sure you're clear how it works and this should only be used + in SketchSearchPolicy. Parameters ---------- meet_condition_func: Callable - A function with `(policy, state, stage_id) -> int`. + A function with `(policy, state, stage_id) -> int`. Should return one of the result + enumeration. apply_func: Callable A function with `(policy, state, stage_id) -> [[State, int], ...]`. rule_name: str = "CustomSketchRule" The name of this custom sketch rule. """ - PASS = 0 - APPLY = 1 - APPLY_AND_SKIP_REST = 2 + # Result enumeration of the condition function. + PASS = 0 # Skip this rule and continue to try the next rules. + APPLY = 1 # Apply this rule and continue to try the next rules. + APPLY_AND_SKIP_REST = 2 # Apply this rule and skip the rest rules. def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"): self.__init_handle_by_constructor__( diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h index e409e83dd471..faf058b45b19 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.h +++ b/src/auto_scheduler/search_policy/sketch_policy.h @@ -200,8 +200,11 @@ class SketchPolicy : public SearchPolicy { /*! \brief Pre-search callback function to load custom rules for sketch generation */ class PreloadCustomSketchRuleNode : public SearchCallbackNode { public: + /*! \brief The condition check function of this rule. */ PackedFunc meet_condition_func; + /*! \brief The apply function of this rule. */ PackedFunc apply_func; + /*! \brief The name of this rule. */ String rule_name; void Callback(SearchPolicyNode* policy) final; @@ -216,6 +219,12 @@ class PreloadCustomSketchRuleNode : public SearchCallbackNode { */ class PreloadCustomSketchRule : public SearchCallback { public: + /*! + * \brief The constructor. + * \param meet_condition_func The condition check function of this rule. + * \param apply_func The apply function of this rule. + * \param rule_name The name of this rule. + */ PreloadCustomSketchRule(PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadCustomSketchRule, SearchCallback, diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py index 4ae1492709ad..30aafbd22390 100644 --- a/tests/python/unittest/test_auto_scheduler_search_policy.py +++ b/tests/python/unittest/test_auto_scheduler_search_policy.py @@ -210,11 +210,11 @@ def apply_func(search_policy, state, stage_id): if __name__ == "__main__": - # test_workload_registry_empty_policy() - # test_sketch_search_policy_basic() - # test_sketch_search_policy_basic_spawn() - # test_sketch_search_policy_xgbmodel() - # test_sketch_search_policy_cuda_rpc_runner() - # test_sketch_search_policy_cuda_xgbmodel_rpc_runner() - # test_sketch_search_policy_zero_rank() + test_workload_registry_empty_policy() + test_sketch_search_policy_basic() + test_sketch_search_policy_basic_spawn() + test_sketch_search_policy_xgbmodel() + test_sketch_search_policy_cuda_rpc_runner() + test_sketch_search_policy_cuda_xgbmodel_rpc_runner() + test_sketch_search_policy_zero_rank() test_sketch_search_policy_custom_sketch() diff --git a/tests/python/unittest/test_auto_scheduler_sketch_generation.py b/tests/python/unittest/test_auto_scheduler_sketch_generation.py index 5d9e064d551d..f3be6c0bc518 100644 --- a/tests/python/unittest/test_auto_scheduler_sketch_generation.py +++ b/tests/python/unittest/test_auto_scheduler_sketch_generation.py @@ -448,7 +448,6 @@ def test_cuda_zero_rank_sketch(): test_cpu_conv2d_winograd_sketch() test_cpu_zero_rank_sketch() test_cpu_custom_sketch() - test_cpu_custom_sketch_registry() test_cuda_matmul_sketch() test_cuda_conv2d_bn_relu_sketch() test_cuda_max_pool2d_sketch() From 6b172f683016282d621e7c4b782eba427431bfc4 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Mon, 18 Jan 2021 16:59:45 +0800 Subject: [PATCH 13/13] Lint fix --- python/tvm/auto_scheduler/search_policy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index a23455a6b8ce..f0388a886c5f 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -34,7 +34,6 @@ import tvm._ffi from tvm.runtime import Object from .cost_model import RandomModel -from .loop_state import State from . import _ffi_api