From be46555e83c60e9b7e89468655cf290fcd86bef8 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 17 Dec 2020 20:40:23 +0800 Subject: [PATCH 01/14] Add layout rewrite options for measure --- include/tvm/auto_scheduler/search_task.h | 5 ++++- python/tvm/auto_scheduler/measure.py | 3 ++- python/tvm/auto_scheduler/search_task.py | 7 ++++++- src/auto_scheduler/compute_dag.cc | 5 ++++- src/auto_scheduler/feature.cc | 6 ++++-- src/auto_scheduler/measure_record.cc | 4 ++++ src/auto_scheduler/search_task.cc | 10 +++++++--- 7 files changed, 31 insertions(+), 9 deletions(-) diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index 60e721bd4389..131402d80910 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -118,6 +118,8 @@ class SearchTaskNode : public Object { Target target_host; /*! \brief Hardware parameters used in this search task. */ HardwareParams hardware_params; + /*! \brief */ + int layout_rewrite_option; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("compute_dag", &compute_dag); @@ -125,6 +127,7 @@ class SearchTaskNode : public Object { v->Visit("target", &target); v->Visit("target_host", &target_host); v->Visit("hardware_params", &hardware_params); + v->Visit("layout_rewrite_option", &layout_rewrite_option); } static constexpr const char* _type_key = "auto_scheduler.SearchTask"; @@ -146,7 +149,7 @@ class SearchTask : public ObjectRef { * \param hardware_params Hardware parameters used in this search task. */ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, - Optional hardware_params); + Optional hardware_params, int layout_rewrite_option); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); }; diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 7e4f14933819..87ac259e79a1 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -186,6 +186,7 @@ def recover_measure_input(inp, rebuild_state=False): target=task.target, target_host=task.target_host, hardware_params=task.hardware_params, + layout_rewrite_option=task.layout_rewrite_option ) if rebuild_state: @@ -551,7 +552,7 @@ def _timed_func(inp_serialized, build_func, verbose): try: sch, args = task.compute_dag.apply_steps_from_state( - inp.state, layout_rewrite=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED + inp.state, layout_rewrite=task.layout_rewrite_option ) # pylint: disable=broad-except except Exception: diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index be83e06bb89d..3ab91adb754f 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -178,6 +178,7 @@ class SearchTask(Object): The target host device of this search task. hardware_params : Optional[HardwareParams] Hardware parameters used in this search task. + layout_rewrite_option : LayoutRewriteOption = LayoutRewriteOption.NO_REWRITE Examples -------- @@ -204,6 +205,7 @@ def __init__( target=None, target_host=None, hardware_params=None, + layout_rewrite_option=LayoutRewriteOption.NO_REWRITE ): assert ( func is not None or workload_key is not None @@ -221,7 +223,8 @@ def __init__( target_host = Target(target_host) self.__init_handle_by_constructor__( - _ffi_api.SearchTask, compute_dag, workload_key, target, target_host, hardware_params + _ffi_api.SearchTask, compute_dag, workload_key, target, target_host, hardware_params, + layout_rewrite_option ) def tune(self, tuning_options, search_policy=None): @@ -305,6 +308,7 @@ def __getstate__(self): "target": self.target, "target_host": self.target_host, "hardware_params": self.hardware_params, + "layout_rewrite_option": self.layout_rewrite_option } def __setstate__(self, state): @@ -327,6 +331,7 @@ def __setstate__(self, state): state["target"], state["target_host"], state["hardware_params"], + state["layout_rewrite_option"] ) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index af45f2df8b04..6b7a15ec050a 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -1023,7 +1023,10 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, } original_compute_op = op; CHECK(!new_compute_op.defined()); - new_compute_op = te::ComputeOp(pop->name, pop->tag, pop->attrs, pop->axis, new_body); + auto new_attrs = pop->attrs; + new_attrs.Set("ori_placeholder_layout", tvm::String(origin_layout)); + new_attrs.Set("new_placeholder_layout", tvm::String(new_layout)); + new_compute_op = te::ComputeOp(pop->name, pop->tag, new_attrs, pop->axis, new_body); } } } diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index 53287a0eddeb..47b9fb60aab4 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1398,7 +1398,8 @@ void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int // rebuild task Array tensors = (*workload_key_to_tensors)(workload_key); task = SearchTask(ComputeDAG(tensors), workload_key, cur_inp->task->target, - cur_inp->task->target_host, cur_inp->task->hardware_params); + cur_inp->task->target_host, cur_inp->task->hardware_params, + cur_inp->task->layout_rewrite_option); task_id = task_cache.size(); // compute min cost for each task @@ -1465,7 +1466,8 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array& inputs, // rebuild task for incomplete measure pairs read from file Array tensors = (*workload_key_to_tensors)(workload_key); task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target, - inputs[i]->task->target_host, inputs[i]->task->hardware_params); + inputs[i]->task->target_host, inputs[i]->task->hardware_params, + inputs[i]->task->layout_rewrite_option); } task_id = task_cache.size(); diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index faf3fca4cfc4..ccb348ea1b3e 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -162,6 +162,7 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->BeginArray(false); writer->WriteArrayItem(std::string(data.workload_key)); writer->WriteArrayItem(data.target->str()); + writer->WriteArrayItem(data.layout_rewrite_option); writer->WriteArrayItem(*data.hardware_params.get()); if (data.target_host.defined()) { writer->WriteArrayItem(data.target_host->str()); @@ -182,6 +183,9 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { reader->Read(&str_value); data->target = ::tvm::Target(str_value); s = reader->NextArrayItem(); + ICHECK(s); + reader->Read(&(data->layout_rewrite_option)); + s = reader->NextArrayItem(); if (s) { reader->Read(hardware_params_node.get()); s = reader->NextArrayItem(); diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 93f34609cbbc..63a78d6ec3d1 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -113,7 +113,8 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target } SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, - Target target_host, Optional hardware_params) { + Target target_host, Optional hardware_params, + int layout_rewrite_option) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -125,6 +126,7 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe node->hardware_params = HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host); } + node->layout_rewrite_option = layout_rewrite_option; data_ = std::move(node); } @@ -139,8 +141,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams") TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask") .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target, - Target target_host, Optional hardware_params) { - return SearchTask(compute_dag, workload_key, target, target_host, hardware_params); + Target target_host, Optional hardware_params, + int layout_rewrite_option) { + return SearchTask(compute_dag, workload_key, target, target_host, hardware_params, + layout_rewrite_option); }); } // namespace auto_scheduler From feff2e3fc52cc843046a6c186afd131956a7fc1e Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 17 Dec 2020 20:40:51 +0800 Subject: [PATCH 02/14] Update schedule for inserted transform stage --- src/auto_scheduler/compute_dag.cc | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 6b7a15ec050a..851aae99a169 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -997,11 +997,20 @@ ComputeDAG ComputeDAG::RewriteLayout(Array* transform_steps, transform_steps->Set(i, std::move(step)); } } + + // Add schedule for the new added transform stage Array to_fuse; - for (size_t i = 0; i < new_shape.size() - 1; i++) { - to_fuse.push_back(i); + + if (new_shape.size() >= 5) { + to_fuse.push_back(0); + to_fuse.push_back(1); + to_fuse.push_back(2); + transform_steps->push_back(FuseStep(stage_id, to_fuse)); + } else if (new_shape.size() >= 3) { + to_fuse.push_back(0); + to_fuse.push_back(1); + transform_steps->push_back(FuseStep(stage_id, to_fuse)); } - transform_steps->push_back(FuseStep(stage_id, to_fuse)); transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel)); } From e091099ed030fdf926df5fb860c1bcee0c63271f Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 18 Dec 2020 10:50:33 +0800 Subject: [PATCH 03/14] Set layout rewrite when tuning for network --- python/tvm/auto_scheduler/relay_integration.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 2b26fc4931bd..1791b6736c47 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -33,7 +33,7 @@ from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor from tvm.tir import expr as _expr from . import _ffi_api -from .compute_dag import ComputeDAG +from .compute_dag import ComputeDAG, LayoutRewriteOption from .dispatcher import DispatchContext from .search_task import SearchTask from .workload_registry import register_workload_tensors @@ -126,6 +126,8 @@ def extract_tasks( target=target, target_host=target_host, hardware_params=hardware_params, + # In default, try to apply layout rewrite to improve the performance + layout_rewrite_option=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED ) ) weights.append(use_count_dict[ccache_key] + 1) From b36ed24ccfbaaf00717c6771db899da9fbc7f995 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 23 Dec 2020 13:46:08 +0800 Subject: [PATCH 04/14] Update --- include/tvm/auto_scheduler/search_task.h | 6 +++--- python/tvm/auto_scheduler/compute_dag.py | 7 ++++++- python/tvm/auto_scheduler/relay_integration.py | 3 ++- python/tvm/auto_scheduler/search_task.py | 16 ++++++++++------ src/auto_scheduler/measure_record.cc | 6 ++++-- src/auto_scheduler/search_task.cc | 4 ++-- 6 files changed, 27 insertions(+), 15 deletions(-) diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index 131402d80910..2c52f1e9d138 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -118,8 +118,8 @@ class SearchTaskNode : public Object { Target target_host; /*! \brief Hardware parameters used in this search task. */ HardwareParams hardware_params; - /*! \brief */ - int layout_rewrite_option; + /*! \brief Layout rewrite option used during program measuring. */ + LayoutRewriteOption layout_rewrite_option; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("compute_dag", &compute_dag); @@ -149,7 +149,7 @@ class SearchTask : public ObjectRef { * \param hardware_params Hardware parameters used in this search task. */ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, - Optional hardware_params, int layout_rewrite_option); + Optional hardware_params, LayoutRewriteOption layout_rewrite_option); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); }; diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 94cb640f3516..b1464aefe881 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -32,7 +32,12 @@ class LayoutRewriteOption: - """Options for applying layout rewrite.""" + """ + Options for applying layout rewrite. + + The NO_REWRITE and INSERT_TRANSFORM_STAGE is expected to be used when tuning a dependent op, + and the REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a network. + """ # Do not perform layout rewrite NO_REWRITE = 0 diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 1791b6736c47..56a196e9ba95 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -126,7 +126,8 @@ def extract_tasks( target=target, target_host=target_host, hardware_params=hardware_params, - # In default, try to apply layout rewrite to improve the performance + # When auto scheduler is used in end to end network, try to apply layout rewrite + # to improve the overall performance layout_rewrite_option=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED ) ) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 3ab91adb754f..78b1db64c08a 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -179,6 +179,11 @@ class SearchTask(Object): hardware_params : Optional[HardwareParams] Hardware parameters used in this search task. layout_rewrite_option : LayoutRewriteOption = LayoutRewriteOption.NO_REWRITE + The default layout rewrite option used during program measuring. + Cost model will adjust the auto scheduler to find a better schedule for the specified + layout rewrite option. + It's excepted to use NO_REWRITE or INSERT_TRANSFORM_STAGE when tuning a dependent op, and + to use REWRITE_FOR_PRE_TRANSFORMED when tuning ops inside a network for better performance. Examples -------- @@ -243,15 +248,18 @@ def tune(self, tuning_options, search_policy=None): _ffi_api.AutoSchedule(search_policy, tuning_options) - def apply_best(self, log_file, layout_rewrite_option=None): + def apply_best(self, log_file, layout_rewrite_option=LayoutRewriteOption.NO_REWRITE): """Apply the history best from a log file and return the schedule. Parameters ---------- log_file : str The name of the log file. - layout_rewrite_option : Optional[LayoutRewriteOption] + layout_rewrite_option : LayoutRewriteOption = LayoutRewriteOption.NO_REWRITE The layout rewrite option. + For a dependent op, NO_REWRITE or INSERT_TRANSFORM_STAGE may result on different + performance. In experience, op with large shape may get benefit from the option + INSERT_TRANSFORM_STAGE. Returns ------- @@ -263,10 +271,6 @@ def apply_best(self, log_file, layout_rewrite_option=None): "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file) ) - if layout_rewrite_option is None: - layout_rewrite_option = LayoutRewriteOption.NO_REWRITE - if self.target.kind.name == "llvm": - layout_rewrite_option = LayoutRewriteOption.INSERT_TRANSFORM_STAGE sch, args = self.compute_dag.apply_steps_from_state(inp.state, layout_rewrite_option) return sch, args diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index ccb348ea1b3e..ec9d5e24cd2a 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -162,7 +162,7 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->BeginArray(false); writer->WriteArrayItem(std::string(data.workload_key)); writer->WriteArrayItem(data.target->str()); - writer->WriteArrayItem(data.layout_rewrite_option); + writer->WriteArrayItem(static_cast(data.layout_rewrite_option)); writer->WriteArrayItem(*data.hardware_params.get()); if (data.target_host.defined()) { writer->WriteArrayItem(data.target_host->str()); @@ -172,6 +172,7 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) { bool s; std::string str_value; + int int_value; auto hardware_params_node = ::tvm::make_object<::tvm::auto_scheduler::HardwareParamsNode>(); reader->BeginArray(); s = reader->NextArrayItem(); @@ -184,7 +185,8 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { data->target = ::tvm::Target(str_value); s = reader->NextArrayItem(); ICHECK(s); - reader->Read(&(data->layout_rewrite_option)); + reader->Read(&int_value); + data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value); s = reader->NextArrayItem(); if (s) { reader->Read(hardware_params_node.get()); diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 63a78d6ec3d1..0abee16fceab 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -114,7 +114,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, - int layout_rewrite_option) { + LayoutRewriteOption layout_rewrite_option) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -144,7 +144,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask") Target target_host, Optional hardware_params, int layout_rewrite_option) { return SearchTask(compute_dag, workload_key, target, target_host, hardware_params, - layout_rewrite_option); + LayoutRewriteOption(layout_rewrite_option)); }); } // namespace auto_scheduler From cee5dd0ad8c64f6fb8aea9b3db76494bade9cc4a Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 24 Dec 2020 20:01:37 +0800 Subject: [PATCH 05/14] Update the log version --- include/tvm/auto_scheduler/measure_record.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h index 4d7952f74b40..ec40611d49b4 100755 --- a/include/tvm/auto_scheduler/measure_record.h +++ b/include/tvm/auto_scheduler/measure_record.h @@ -34,7 +34,7 @@ namespace tvm { namespace auto_scheduler { -const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*) +const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5"; // NOLINT(*) /*! \brief Callback for logging the input and results of measurements to file */ class RecordToFileNode : public MeasureCallbackNode { From 95f86262798565255f2a02d28897ffc9638f9ceb Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 24 Dec 2020 20:20:10 +0800 Subject: [PATCH 06/14] Update --- python/tvm/auto_scheduler/compute_dag.py | 2 +- python/tvm/auto_scheduler/search_task.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index b1464aefe881..ad04934eace6 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -35,7 +35,7 @@ class LayoutRewriteOption: """ Options for applying layout rewrite. - The NO_REWRITE and INSERT_TRANSFORM_STAGE is expected to be used when tuning a dependent op, + The NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone op, and the REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a network. """ diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 78b1db64c08a..a030254c9052 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -180,10 +180,10 @@ class SearchTask(Object): Hardware parameters used in this search task. layout_rewrite_option : LayoutRewriteOption = LayoutRewriteOption.NO_REWRITE The default layout rewrite option used during program measuring. - Cost model will adjust the auto scheduler to find a better schedule for the specified - layout rewrite option. - It's excepted to use NO_REWRITE or INSERT_TRANSFORM_STAGE when tuning a dependent op, and - to use REWRITE_FOR_PRE_TRANSFORMED when tuning ops inside a network for better performance. + Auto_scheduler will find a better schedule for the specified layout rewrite option. + The NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone + op, and the REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a + network. Examples -------- From 0060036d9ac274081f5e1b6ebe76a264c74bcabe Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 26 Dec 2020 15:03:57 +0800 Subject: [PATCH 07/14] Update --- include/tvm/auto_scheduler/measure_record.h | 2 +- .../tvm/auto_scheduler/relay_integration.py | 32 ++++++++++++++----- python/tvm/auto_scheduler/search_task.py | 22 ++++++++----- src/auto_scheduler/measure_record.cc | 10 +++--- 4 files changed, 45 insertions(+), 21 deletions(-) diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h index ec40611d49b4..4d7952f74b40 100755 --- a/include/tvm/auto_scheduler/measure_record.h +++ b/include/tvm/auto_scheduler/measure_record.h @@ -34,7 +34,7 @@ namespace tvm { namespace auto_scheduler { -const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5"; // NOLINT(*) +const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*) /*! \brief Callback for logging the input and results of measurements to file */ class RecordToFileNode : public MeasureCallbackNode { diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 56a196e9ba95..49e27b6d6c74 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -63,6 +63,27 @@ def call_all_topi_funcs(mod, params, target): autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent +def enable_layout_rewrite(target): + """Check if this target should enable_layout_rewrite. + + Parameters + ---------- + target: tvm.target.Target + The compilation target. + + Returns + ------- + enable_layout_rewrite: bool + """ + # only enable layout rewrite for cpu / mali backend + enable_layout_rewrite_targets = ["cpu", "mali"] + enable_layout_rewrite = any( + enable_layout_rewrite_target in target.keys + for enable_layout_rewrite_target in enable_layout_rewrite_targets + ) + return enable_layout_rewrite + + def extract_tasks( mod, params, target, target_host=None, hardware_params=None, include_simple_tasks=False ): @@ -128,7 +149,8 @@ def extract_tasks( hardware_params=hardware_params, # When auto scheduler is used in end to end network, try to apply layout rewrite # to improve the overall performance - layout_rewrite_option=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED + layout_rewrite_option=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED \ + if enable_layout_rewrite(target) else LayoutRewriteOption.NO_REWRITE ) ) weights.append(use_count_dict[ccache_key] + 1) @@ -262,13 +284,7 @@ def auto_schedule_topi(outs, has_complex_op): key = register_workload_tensors(dag.hash_key(), io_tensors) - # only enable layout rewrite for cpu / mali backend target = tvm.target.Target.current() - enable_layout_rewrite_targets = ["cpu", "mali"] - enable_layout_rewrite = any( - enable_layout_rewrite_target in target.keys - for enable_layout_rewrite_target in enable_layout_rewrite_targets - ) env = TracingEnvironment.current if env is None: @@ -287,7 +303,7 @@ def auto_schedule_topi(outs, has_complex_op): schedule = te.create_schedule([x.op for x in outs]) elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: # in prepare_layout_rewrite mode - if enable_layout_rewrite and has_layout_free: + if enable_layout_rewrite(target) and has_layout_free: dispatch_ctx = DispatchContext.current state = dispatch_ctx.query(target, key, has_complex_op, dag) if state is None: diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index a030254c9052..2d4646c40ee8 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -178,7 +178,7 @@ class SearchTask(Object): The target host device of this search task. hardware_params : Optional[HardwareParams] Hardware parameters used in this search task. - layout_rewrite_option : LayoutRewriteOption = LayoutRewriteOption.NO_REWRITE + layout_rewrite_option : Optional[LayoutRewriteOption] The default layout rewrite option used during program measuring. Auto_scheduler will find a better schedule for the specified layout rewrite option. The NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone @@ -210,7 +210,7 @@ def __init__( target=None, target_host=None, hardware_params=None, - layout_rewrite_option=LayoutRewriteOption.NO_REWRITE + layout_rewrite_option=None ): assert ( func is not None or workload_key is not None @@ -227,6 +227,12 @@ def __init__( if isinstance(target_host, str): target_host = Target(target_host) + if layout_rewrite_option is None: + layout_rewrite_option = LayoutRewriteOption.NO_REWRITE + if target.kind.name == "llvm" or \ + ("device" in target.attrs.keys and target.attrs["device"] == "mali"): + layout_rewrite_option = LayoutRewriteOption.INSERT_TRANSFORM_STAGE + self.__init_handle_by_constructor__( _ffi_api.SearchTask, compute_dag, workload_key, target, target_host, hardware_params, layout_rewrite_option @@ -248,18 +254,16 @@ def tune(self, tuning_options, search_policy=None): _ffi_api.AutoSchedule(search_policy, tuning_options) - def apply_best(self, log_file, layout_rewrite_option=LayoutRewriteOption.NO_REWRITE): + def apply_best(self, log_file, layout_rewrite_option=None): """Apply the history best from a log file and return the schedule. Parameters ---------- log_file : str The name of the log file. - layout_rewrite_option : LayoutRewriteOption = LayoutRewriteOption.NO_REWRITE + layout_rewrite_option : Optional[LayoutRewriteOption] The layout rewrite option. - For a dependent op, NO_REWRITE or INSERT_TRANSFORM_STAGE may result on different - performance. In experience, op with large shape may get benefit from the option - INSERT_TRANSFORM_STAGE. + Returns ------- @@ -271,7 +275,9 @@ def apply_best(self, log_file, layout_rewrite_option=LayoutRewriteOption.NO_REWR "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file) ) - sch, args = self.compute_dag.apply_steps_from_state(inp.state, layout_rewrite_option) + sch, args = self.compute_dag.apply_steps_from_state( + inp.state, layout_rewrite_option or self.layout_rewrite_option + ) return sch, args def print_best(self, log_file, print_mode="schedule"): diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index ec9d5e24cd2a..5f009aa56fd0 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -166,6 +166,8 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->WriteArrayItem(*data.hardware_params.get()); if (data.target_host.defined()) { writer->WriteArrayItem(data.target_host->str()); + } else { + writer->WriteArrayItem(""); } writer->EndArray(); } @@ -184,10 +186,6 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { reader->Read(&str_value); data->target = ::tvm::Target(str_value); s = reader->NextArrayItem(); - ICHECK(s); - reader->Read(&int_value); - data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value); - s = reader->NextArrayItem(); if (s) { reader->Read(hardware_params_node.get()); s = reader->NextArrayItem(); @@ -196,6 +194,10 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { reader->Read(&str_value); data->target_host = ::tvm::Target(str_value); s = reader->NextArrayItem(); + ICHECK(s); + reader->Read(&int_value); + data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value); + s = reader->NextArrayItem(); ICHECK(!s); } } From eadbbf08dde3726c129b2f38c0244a8cdd030385 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 26 Dec 2020 15:10:55 +0800 Subject: [PATCH 08/14] Update --- src/auto_scheduler/measure_record.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 5f009aa56fd0..abb8db190d3e 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -167,7 +167,7 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { if (data.target_host.defined()) { writer->WriteArrayItem(data.target_host->str()); } else { - writer->WriteArrayItem(""); + writer->WriteArrayItem(std::string("")); } writer->EndArray(); } From 5e82774a363a330b3332940715bc54a5aea7034e Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 26 Dec 2020 15:27:19 +0800 Subject: [PATCH 09/14] Bug fix for CI --- python/tvm/auto_scheduler/measure.py | 2 +- python/tvm/auto_scheduler/search_task.py | 2 +- src/auto_scheduler/measure_record.cc | 6 ++++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 87ac259e79a1..f0e2d014767f 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -186,7 +186,7 @@ def recover_measure_input(inp, rebuild_state=False): target=task.target, target_host=task.target_host, hardware_params=task.hardware_params, - layout_rewrite_option=task.layout_rewrite_option + layout_rewrite_option=task.layout_rewrite_option, ) if rebuild_state: diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 2d4646c40ee8..075237323670 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -230,7 +230,7 @@ def __init__( if layout_rewrite_option is None: layout_rewrite_option = LayoutRewriteOption.NO_REWRITE if target.kind.name == "llvm" or \ - ("device" in target.attrs.keys and target.attrs["device"] == "mali"): + ("device" in target.attrs and target.attrs["device"] == "mali"): layout_rewrite_option = LayoutRewriteOption.INSERT_TRANSFORM_STAGE self.__init_handle_by_constructor__( diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index abb8db190d3e..1120f437b176 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -162,13 +162,13 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->BeginArray(false); writer->WriteArrayItem(std::string(data.workload_key)); writer->WriteArrayItem(data.target->str()); - writer->WriteArrayItem(static_cast(data.layout_rewrite_option)); writer->WriteArrayItem(*data.hardware_params.get()); if (data.target_host.defined()) { writer->WriteArrayItem(data.target_host->str()); } else { writer->WriteArrayItem(std::string("")); } + writer->WriteArrayItem(static_cast(data.layout_rewrite_option)); writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) { @@ -192,7 +192,9 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { data->hardware_params = ::tvm::auto_scheduler::HardwareParams(hardware_params_node); if (s) { reader->Read(&str_value); - data->target_host = ::tvm::Target(str_value); + if (!str_value.empty()) { + data->target_host = ::tvm::Target(str_value); + } s = reader->NextArrayItem(); ICHECK(s); reader->Read(&int_value); From a4bc457eca621573e367f80f7b454219809e8d34 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 26 Dec 2020 15:32:04 +0800 Subject: [PATCH 10/14] Pylint fix --- .../tvm/auto_scheduler/relay_integration.py | 5 +++-- python/tvm/auto_scheduler/search_task.py | 22 ++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 49e27b6d6c74..9a45710efbfd 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -149,8 +149,9 @@ def extract_tasks( hardware_params=hardware_params, # When auto scheduler is used in end to end network, try to apply layout rewrite # to improve the overall performance - layout_rewrite_option=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED \ - if enable_layout_rewrite(target) else LayoutRewriteOption.NO_REWRITE + layout_rewrite_option=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED + if enable_layout_rewrite(target) + else LayoutRewriteOption.NO_REWRITE, ) ) weights.append(use_count_dict[ccache_key] + 1) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 075237323670..76c0e9580ea1 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -210,7 +210,7 @@ def __init__( target=None, target_host=None, hardware_params=None, - layout_rewrite_option=None + layout_rewrite_option=None, ): assert ( func is not None or workload_key is not None @@ -229,13 +229,19 @@ def __init__( if layout_rewrite_option is None: layout_rewrite_option = LayoutRewriteOption.NO_REWRITE - if target.kind.name == "llvm" or \ - ("device" in target.attrs and target.attrs["device"] == "mali"): + if target.kind.name == "llvm" or ( + "device" in target.attrs and target.attrs["device"] == "mali" + ): layout_rewrite_option = LayoutRewriteOption.INSERT_TRANSFORM_STAGE self.__init_handle_by_constructor__( - _ffi_api.SearchTask, compute_dag, workload_key, target, target_host, hardware_params, - layout_rewrite_option + _ffi_api.SearchTask, + compute_dag, + workload_key, + target, + target_host, + hardware_params, + layout_rewrite_option, ) def tune(self, tuning_options, search_policy=None): @@ -263,7 +269,7 @@ def apply_best(self, log_file, layout_rewrite_option=None): The name of the log file. layout_rewrite_option : Optional[LayoutRewriteOption] The layout rewrite option. - + Returns ------- @@ -318,7 +324,7 @@ def __getstate__(self): "target": self.target, "target_host": self.target_host, "hardware_params": self.hardware_params, - "layout_rewrite_option": self.layout_rewrite_option + "layout_rewrite_option": self.layout_rewrite_option, } def __setstate__(self, state): @@ -341,7 +347,7 @@ def __setstate__(self, state): state["target"], state["target_host"], state["hardware_params"], - state["layout_rewrite_option"] + state["layout_rewrite_option"], ) From 03db2586f67cf6eee130fa7f9fe500f5de002e8f Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 26 Dec 2020 15:41:27 +0800 Subject: [PATCH 11/14] Pylint fix --- python/tvm/auto_scheduler/measure.py | 1 - python/tvm/auto_scheduler/relay_integration.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index f0e2d014767f..cfe31d3bdbc1 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -53,7 +53,6 @@ make_traceback_info, request_remote, ) -from .compute_dag import LayoutRewriteOption from .workload_registry import ( serialize_workload_registry_entry, deserialize_workload_registry_entry, diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 9a45710efbfd..c67707e66c5c 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -77,11 +77,10 @@ def enable_layout_rewrite(target): """ # only enable layout rewrite for cpu / mali backend enable_layout_rewrite_targets = ["cpu", "mali"] - enable_layout_rewrite = any( + return any( enable_layout_rewrite_target in target.keys for enable_layout_rewrite_target in enable_layout_rewrite_targets ) - return enable_layout_rewrite def extract_tasks( From 6d6c57fb763f57e82cef118b0cf05988838f9dbf Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 26 Dec 2020 15:49:16 +0800 Subject: [PATCH 12/14] Update --- include/tvm/auto_scheduler/search_task.h | 1 + python/tvm/auto_scheduler/search_task.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index 2c52f1e9d138..8460cc795302 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -147,6 +147,7 @@ class SearchTask : public ObjectRef { * \param target The target device of this search task. * \param target_host The target host device of this search task. * \param hardware_params Hardware parameters used in this search task. + * \param layout_rewrite_option The default layout rewrite option used during program measuring. */ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, LayoutRewriteOption layout_rewrite_option); diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 76c0e9580ea1..abcc1bdff8ef 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -179,7 +179,8 @@ class SearchTask(Object): hardware_params : Optional[HardwareParams] Hardware parameters used in this search task. layout_rewrite_option : Optional[LayoutRewriteOption] - The default layout rewrite option used during program measuring. + The layout rewrite option used during program measuring. If None, the + INSERT_TRANSFORM_STAGE will be used for cpu and mali gpu, else NO_REWRITE will be used. Auto_scheduler will find a better schedule for the specified layout rewrite option. The NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone op, and the REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a From 37190a45b98f3c33efb95ff0c077714bf8c355f9 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 26 Dec 2020 17:29:54 +0800 Subject: [PATCH 13/14] Update --- include/tvm/auto_scheduler/measure_record.h | 2 +- include/tvm/auto_scheduler/search_task.h | 4 +-- python/tvm/auto_scheduler/compute_dag.py | 25 +++++++++++++++++ .../tvm/auto_scheduler/relay_integration.py | 27 +++---------------- python/tvm/auto_scheduler/search_task.py | 13 +++------ 5 files changed, 34 insertions(+), 37 deletions(-) diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h index 4d7952f74b40..ec40611d49b4 100755 --- a/include/tvm/auto_scheduler/measure_record.h +++ b/include/tvm/auto_scheduler/measure_record.h @@ -34,7 +34,7 @@ namespace tvm { namespace auto_scheduler { -const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*) +const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5"; // NOLINT(*) /*! \brief Callback for logging the input and results of measurements to file */ class RecordToFileNode : public MeasureCallbackNode { diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index 8460cc795302..9e7d3aa2cd32 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -118,7 +118,7 @@ class SearchTaskNode : public Object { Target target_host; /*! \brief Hardware parameters used in this search task. */ HardwareParams hardware_params; - /*! \brief Layout rewrite option used during program measuring. */ + /*! \brief The layout rewrite option used for measuring programs. */ LayoutRewriteOption layout_rewrite_option; void VisitAttrs(tvm::AttrVisitor* v) { @@ -147,7 +147,7 @@ class SearchTask : public ObjectRef { * \param target The target device of this search task. * \param target_host The target host device of this search task. * \param hardware_params Hardware parameters used in this search task. - * \param layout_rewrite_option The default layout rewrite option used during program measuring. + * \param layout_rewrite_option The layout rewrite option used for measuring programs. */ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, LayoutRewriteOption layout_rewrite_option); diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index ad04934eace6..38ddedbf7a0a 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -49,6 +49,31 @@ class LayoutRewriteOption: # so this option must be used along with `AutoSchedulerLayoutRewrite` pass in Relay. REWRITE_FOR_PRE_TRANSFORMED = 2 + @staticmethod + def get_target_default(target, in_relay_integration=False): + """ Get the default layout rewrite option for the specified target. + Currently we only enable layout rewrite for cpu / mali backend for now + + Parameters + ---------- + target: tvm.target.Target + The compilation target. + in_relay_integration: bool + If this check is ask for relay integration. + + Returns + ------- + layout_rewrite_option: LayoutRewriteOption + The default layout rewrite option for the specified target. + """ + layout_rewrite_option = LayoutRewriteOption.NO_REWRITE + if target.kind.name == "llvm" or ( + "device" in target.attrs and target.attrs["device"] == "mali" + ): + layout_rewrite_option = LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED \ + if in_relay_integration else LayoutRewriteOption.INSERT_TRANSFORM_STAGE + + return layout_rewrite_option @tvm._ffi.register_object("auto_scheduler.ComputeDAG") class ComputeDAG(Object): diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index c67707e66c5c..0ac65c8f3bc6 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -63,26 +63,6 @@ def call_all_topi_funcs(mod, params, target): autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent -def enable_layout_rewrite(target): - """Check if this target should enable_layout_rewrite. - - Parameters - ---------- - target: tvm.target.Target - The compilation target. - - Returns - ------- - enable_layout_rewrite: bool - """ - # only enable layout rewrite for cpu / mali backend - enable_layout_rewrite_targets = ["cpu", "mali"] - return any( - enable_layout_rewrite_target in target.keys - for enable_layout_rewrite_target in enable_layout_rewrite_targets - ) - - def extract_tasks( mod, params, target, target_host=None, hardware_params=None, include_simple_tasks=False ): @@ -148,9 +128,7 @@ def extract_tasks( hardware_params=hardware_params, # When auto scheduler is used in end to end network, try to apply layout rewrite # to improve the overall performance - layout_rewrite_option=LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED - if enable_layout_rewrite(target) - else LayoutRewriteOption.NO_REWRITE, + layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True), ) ) weights.append(use_count_dict[ccache_key] + 1) @@ -303,7 +281,8 @@ def auto_schedule_topi(outs, has_complex_op): schedule = te.create_schedule([x.op for x in outs]) elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: # in prepare_layout_rewrite mode - if enable_layout_rewrite(target) and has_layout_free: + if LayoutRewriteOption.get_target_default(target, True) != LayoutRewriteOption.NO_REWRITE \ + and has_layout_free: dispatch_ctx = DispatchContext.current state = dispatch_ctx.query(target, key, has_complex_op, dag) if state is None: diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index abcc1bdff8ef..bfa596a1dc61 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -179,8 +179,8 @@ class SearchTask(Object): hardware_params : Optional[HardwareParams] Hardware parameters used in this search task. layout_rewrite_option : Optional[LayoutRewriteOption] - The layout rewrite option used during program measuring. If None, the - INSERT_TRANSFORM_STAGE will be used for cpu and mali gpu, else NO_REWRITE will be used. + The layout rewrite option used for measuring programs. If None, the default value will be + set depending on the specified target. Auto_scheduler will find a better schedule for the specified layout rewrite option. The NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone op, and the REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a @@ -228,13 +228,6 @@ def __init__( if isinstance(target_host, str): target_host = Target(target_host) - if layout_rewrite_option is None: - layout_rewrite_option = LayoutRewriteOption.NO_REWRITE - if target.kind.name == "llvm" or ( - "device" in target.attrs and target.attrs["device"] == "mali" - ): - layout_rewrite_option = LayoutRewriteOption.INSERT_TRANSFORM_STAGE - self.__init_handle_by_constructor__( _ffi_api.SearchTask, compute_dag, @@ -242,7 +235,7 @@ def __init__( target, target_host, hardware_params, - layout_rewrite_option, + layout_rewrite_option or LayoutRewriteOption.get_target_default(target), ) def tune(self, tuning_options, search_policy=None): From ea2b1705ef942d2e055b984585130f48323f3f85 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Sat, 26 Dec 2020 17:33:06 +0800 Subject: [PATCH 14/14] Lint fix --- python/tvm/auto_scheduler/compute_dag.py | 10 +++++++--- python/tvm/auto_scheduler/relay_integration.py | 6 ++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 38ddedbf7a0a..d84eb1f7ad39 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -51,7 +51,7 @@ class LayoutRewriteOption: @staticmethod def get_target_default(target, in_relay_integration=False): - """ Get the default layout rewrite option for the specified target. + """Get the default layout rewrite option for the specified target. Currently we only enable layout rewrite for cpu / mali backend for now Parameters @@ -70,11 +70,15 @@ def get_target_default(target, in_relay_integration=False): if target.kind.name == "llvm" or ( "device" in target.attrs and target.attrs["device"] == "mali" ): - layout_rewrite_option = LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED \ - if in_relay_integration else LayoutRewriteOption.INSERT_TRANSFORM_STAGE + layout_rewrite_option = ( + LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED + if in_relay_integration + else LayoutRewriteOption.INSERT_TRANSFORM_STAGE + ) return layout_rewrite_option + @tvm._ffi.register_object("auto_scheduler.ComputeDAG") class ComputeDAG(Object): """ diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 0ac65c8f3bc6..3287f3d4a1e5 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -281,8 +281,10 @@ def auto_schedule_topi(outs, has_complex_op): schedule = te.create_schedule([x.op for x in outs]) elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: # in prepare_layout_rewrite mode - if LayoutRewriteOption.get_target_default(target, True) != LayoutRewriteOption.NO_REWRITE \ - and has_layout_free: + if ( + LayoutRewriteOption.get_target_default(target, True) != LayoutRewriteOption.NO_REWRITE + and has_layout_free + ): dispatch_ctx = DispatchContext.current state = dispatch_ctx.query(target, key, has_complex_op, dag) if state is None: