From d891a52311776d0b5c83b9ebc0ac86ffc43de0d1 Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Mon, 4 Jan 2021 16:56:24 +0800
Subject: [PATCH 01/12] Enhancement for autoscheduler cost model

---
 .../auto_scheduler/cost_model/xgb_model.py  | 20 ++++++++++++++++++-
 python/tvm/auto_scheduler/task_scheduler.py | 11 +++++-----
 src/auto_scheduler/feature.cc               | 18 +++++++++++------
 3 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index eb14dff0815c..f1e4c342d117 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -88,7 +88,7 @@ class XGBModel(PythonBasedModel):
     their predictions.
     """

-    def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
+    def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None, model_file=None):
         global xgb
         try:
             if xgb is None:
@@ -116,12 +116,17 @@ def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
         self.plan_size = 32
         self.num_warmup_sample = num_warmup_sample
         self.verbose_eval = verbose_eval
+        self.model_file = model_file
+        if model_file:
+            logger.info("XGBModel: Load pretrained model from %s..." % model_file)
+            self.load(model_file)

         super().__init__()

         # cache measurement input/result pairs and extracted features
         self.inputs = []
         self.results = []
+        self.last_train_length = 0
         self.inputs_feature_cache = []

     def update(self, inputs, results):
@@ -141,6 +146,15 @@ def update(self, inputs, results):
         self.inputs.extend(inputs)
         self.results.extend(results)

+        print("self.inputs: ", len(self.inputs))
+        print("self.last_train_length", self.last_train_length)
+
+        if len(self.inputs) - self.last_train_length < self.last_train_length / 5:
+            # Skip if the added
+            return
+        else:
+            self.last_train_length = len(self.inputs)
+
         # extract feature
         n_cached = len(self.inputs_feature_cache)
         features, normalized_throughputs, task_ids = get_per_store_features_from_measure_pairs(
@@ -176,6 +190,9 @@ def update(self, inputs, results):
             ],
         )

+        if self.model_file:
+            self.save(self.model_file)
+
     def predict(self, task, states):
         """Predict the scores of states
         Parameters
@@ -298,6 +315,7 @@ def load(self, file_name: str):
         file_name: str
             The filename
         """
+        print(file_name)
         if self.bst is None:
             self.bst = xgb.Booster(self.xgb_params)
         self.bst.load_model(file_name)
diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
index ab83ff40c461..a13011d1e0ed 100644
--- a/python/tvm/auto_scheduler/task_scheduler.py
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -82,11 +82,12 @@ def make_search_policies(
     if isinstance(search_policy, str):
         policy_type, model_type = search_policy.split(".")
         if model_type == "xgb":
-            cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measures_per_round)
-            if load_model_file:
-                logger.info("TaskScheduler: Load pretrained model...")
-                cost_model.load(load_model_file)
-            elif load_log_file:
+            cost_model = XGBModel(
+                num_warmup_sample=len(tasks) * num_measures_per_round,
+                model_file=load_model_file,
+            )
+            if load_log_file:
+                logger.info("TaskScheduler: Reload measured states and pretrain model...")
                 cost_model.update_from_file(load_log_file)
         elif model_type == "random":
             cost_model = RandomModel()
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 47b9fb60aab4..a5d4958af769 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -1462,12 +1462,18 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
    if (find_res == task_cache.end()) {
      if (inputs[i]->task->compute_dag.defined()) {  // the measure input is complete
        task = inputs[i]->task;
-      } else {  // the measure input is incomplete
-        // rebuild task for incomplete measure pairs read from file
-        Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
-        task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
-                          inputs[i]->task->target_host, inputs[i]->task->hardware_params,
-                          inputs[i]->task->layout_rewrite_option);
+      } else {
+        // The measure input is incomplete, rebuild task for incomplete measure pairs read from file
+        try {
+          Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
+          task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
+                            inputs[i]->task->target_host, inputs[i]->task->hardware_params,
+                            inputs[i]->task->layout_rewrite_option);
+        } catch (std::exception& e) {
+          // Cannot build ComputeDAG from workload key, the task may have not been registered in
+          // this search round
+          continue;
+        }
      }
      task_id = task_cache.size();

From 27bbe860bda2ffac005f2eef4a78138246baa5ca Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Mon, 4 Jan 2021 16:56:48 +0800
Subject: [PATCH 02/12] Bug fix for graph_runtime_debug

---
 src/runtime/graph/debug/graph_runtime_debug.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc
index 3353c117318b..5561fdc54879 100644
--- a/src/runtime/graph/debug/graph_runtime_debug.cc
+++ b/src/runtime/graph/debug/graph_runtime_debug.cc
@@ -154,7 +154,7 @@ class GraphRuntimeDebug : public GraphRuntime {
     TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
     auto op_tend = std::chrono::high_resolution_clock::now();
     double op_duration =
-        std::chrono::duration_cast<std::chrono::duration<double> >(op_tend - op_tbegin).count();
+        std::chrono::duration_cast(op_tend - op_tbegin).count();
     return op_duration;
   }

From 1022c581e0e34a0839fba38ea2b24634c9ef0cb9 Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Mon, 4 Jan 2021 17:07:09 +0800
Subject: [PATCH 03/12] Update

---
 python/tvm/auto_scheduler/cost_model/xgb_model.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index f1e4c342d117..e92142c3d311 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -146,11 +146,9 @@ def update(self, inputs, results):
         self.inputs.extend(inputs)
         self.results.extend(results)

-        print("self.inputs: ", len(self.inputs))
-        print("self.last_train_length", self.last_train_length)
-
         if len(self.inputs) - self.last_train_length < self.last_train_length / 5:
-            # Skip if the added
+            # Set a training threshold related to `last_train_length` to reduce the training
+            # overhead when there're too many logs
             return
         else:
             self.last_train_length = len(self.inputs)
@@ -190,6 +188,7 @@ def update(self, inputs, results):
             ],
         )

+        # Update the model file if it has been set
         if self.model_file:
             self.save(self.model_file)
@@ -315,7 +314,6 @@ def load(self, file_name: str):
         file_name: str
             The filename
         """
-        print(file_name)
         if self.bst is None:
             self.bst = xgb.Booster(self.xgb_params)
         self.bst.load_model(file_name)
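
The change in patches 01 and 03 above makes XGBModel.update() skip a training round unless the set of measured records has grown by at least 20% (one fifth) since the model was last fit. A minimal, standalone sketch of that threshold logic follows; it is not the real XGBModel code, and every name except the 1/5 ratio is illustrative:

    class AdaptiveTrainer:
        """Standalone sketch of the retraining threshold used in the patches above."""

        def __init__(self):
            self.inputs = []              # all measurement records seen so far
            self.last_train_length = 0    # dataset size at the last (re)training

        def update(self, new_inputs):
            self.inputs.extend(new_inputs)
            # Retrain only when the dataset grew by >= 20% since the last fit,
            # so training overhead stays bounded as the measurement log grows.
            if len(self.inputs) - self.last_train_length < self.last_train_length / 5:
                return False  # skip this round
            self.last_train_length = len(self.inputs)
            # ... fit the cost model on self.inputs here ...
            return True
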
From 76610651c732b752ba14e68e1ed2092207742bb4 Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Mon, 4 Jan 2021 17:25:18 +0800
Subject: [PATCH 04/12] Lint fix

---
 python/tvm/auto_scheduler/cost_model/xgb_model.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index e92142c3d311..7798139c8272 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -118,7 +118,7 @@ def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None, model_file
         self.verbose_eval = verbose_eval
         self.model_file = model_file
         if model_file:
-            logger.info("XGBModel: Load pretrained model from %s..." % model_file)
+            logger.info("XGBModel: Load pretrained model from %s...", model_file)
             self.load(model_file)

         super().__init__()
@@ -150,8 +150,7 @@ def update(self, inputs, results):
             # Set a training threshold related to `last_train_length` to reduce the training
             # overhead when there're too many logs
             return
-        else:
-            self.last_train_length = len(self.inputs)
+        self.last_train_length = len(self.inputs)

         # extract feature
         n_cached = len(self.inputs_feature_cache)

From 25f6cf5215cc867d384590d4794d9da0238a750a Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Mon, 4 Jan 2021 17:54:14 +0800
Subject: [PATCH 05/12] Update

---
 python/tvm/auto_scheduler/task_scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
index a13011d1e0ed..73fccf2abb5a 100644
--- a/python/tvm/auto_scheduler/task_scheduler.py
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -87,7 +87,7 @@ def make_search_policies(
                 model_file=load_model_file,
             )
             if load_log_file:
-                logger.info("TaskScheduler: Reload measured states and pretrain model...")
+                logger.info("TaskScheduler: Reload measured states and train the model...")
                 cost_model.update_from_file(load_log_file)
         elif model_type == "random":
             cost_model = RandomModel()

From 7659b7afbd4dec1609c3e7630289c160ae91179f Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Mon, 4 Jan 2021 19:14:36 +0800
Subject: [PATCH 06/12] Update

---
 src/runtime/graph/debug/graph_runtime_debug.cc | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc
index 5561fdc54879..288c5e2248b0 100644
--- a/src/runtime/graph/debug/graph_runtime_debug.cc
+++ b/src/runtime/graph/debug/graph_runtime_debug.cc
@@ -153,9 +153,10 @@ class GraphRuntimeDebug : public GraphRuntime {
     const TVMContext& ctx = data_entry_[entry_id(index, 0)]->ctx;
     TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
     auto op_tend = std::chrono::high_resolution_clock::now();
-    double op_duration =
-        std::chrono::duration_cast(op_tend - op_tbegin).count();
-    return op_duration;
+    double op_duration_us =
+        std::chrono::duration_cast<std::chrono::duration<double> >(op_tend - op_tbegin).count() *
+        1e6;
+    return op_duration_us;
   }

   /*!
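
For context on the unit change in the patch above: casting to a double-valued std::chrono duration yields seconds, and the added factor of 1e6 converts the per-operator time to microseconds. A rough Python analogue of the measurement, for illustration only (this is not the TVM debug runtime API):

    import time

    def time_op_us(run_op):
        # Take wall-clock timestamps around the operator call and convert the
        # elapsed seconds to microseconds, mirroring the `* 1e6` above.
        t_begin = time.perf_counter()
        run_op()
        t_end = time.perf_counter()
        return (t_end - t_begin) * 1e6
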
From 0230ddf3436f82a57719a8761fcccf39626ee2be Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Mon, 4 Jan 2021 20:45:49 +0800
Subject: [PATCH 07/12] Add file exist check for cost model load

---
 python/tvm/auto_scheduler/cost_model/xgb_model.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index 7798139c8272..33ac36332c22 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -19,6 +19,7 @@
 """Cost model based on xgboost"""
 import multiprocessing
 import logging
+import os
 from collections import defaultdict

 import numpy as np
@@ -313,6 +314,9 @@ def load(self, file_name: str):
         file_name: str
             The filename
         """
+        if not os.path.isfile(file_name):
+            return
+
         if self.bst is None:
             self.bst = xgb.Booster(self.xgb_params)
         self.bst.load_model(file_name)

From a600630aed6072f4cc8c1d38f16076c391906e7b Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Tue, 5 Jan 2021 17:46:01 +0800
Subject: [PATCH 08/12] Update

---
 python/tvm/auto_scheduler/cost_model/xgb_model.py | 4 ----
 python/tvm/auto_scheduler/task_scheduler.py       | 5 ++++-
 src/runtime/graph/debug/graph_runtime_debug.cc    | 7 +++----
 3 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index 33ac36332c22..7798139c8272 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -19,7 +19,6 @@
 """Cost model based on xgboost"""
 import multiprocessing
 import logging
-import os
 from collections import defaultdict

 import numpy as np
@@ -314,9 +313,6 @@ def load(self, file_name: str):
         file_name: str
             The filename
         """
-        if not os.path.isfile(file_name):
-            return
-
         if self.bst is None:
             self.bst = xgb.Booster(self.xgb_params)
         self.bst.load_model(file_name)
diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
index 73fccf2abb5a..97466e98f230 100644
--- a/python/tvm/auto_scheduler/task_scheduler.py
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -86,7 +86,10 @@ def make_search_policies(
                 num_warmup_sample=len(tasks) * num_measures_per_round,
                 model_file=load_model_file,
             )
-            if load_log_file:
+            if load_model_file:
+                logger.info("TaskScheduler: Load pretrained model...")
+                cost_model.load(load_model_file)
+            elif load_log_file:
                 logger.info("TaskScheduler: Reload measured states and train the model...")
                 cost_model.update_from_file(load_log_file)
         elif model_type == "random":
             cost_model = RandomModel()
diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc
index 288c5e2248b0..3353c117318b 100644
--- a/src/runtime/graph/debug/graph_runtime_debug.cc
+++ b/src/runtime/graph/debug/graph_runtime_debug.cc
@@ -153,10 +153,9 @@ class GraphRuntimeDebug : public GraphRuntime {
     const TVMContext& ctx = data_entry_[entry_id(index, 0)]->ctx;
     TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
     auto op_tend = std::chrono::high_resolution_clock::now();
-    double op_duration_us =
-        std::chrono::duration_cast<std::chrono::duration<double> >(op_tend - op_tbegin).count() *
-        1e6;
-    return op_duration_us;
+    double op_duration =
+        std::chrono::duration_cast<std::chrono::duration<double> >(op_tend - op_tbegin).count();
+    return op_duration;
   }

   /*!
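
Patch 07 added the file-existence check inside XGBModel.load(), and patch 08 above moves loading back into the task scheduler; patch 11 later guards that call with os.path.isfile. The fallback order the series converges on is: use a saved model file when it exists, otherwise retrain from previously measured records. A sketch of that order, with an illustrative helper name rather than the exact TVM call site:

    import os

    def init_cost_model(cost_model, load_model_file=None, load_log_file=None):
        # Prefer a previously saved booster when the file is actually on disk.
        if load_model_file and os.path.isfile(load_model_file):
            cost_model.load(load_model_file)
        # Otherwise rebuild the model from measurement records, if a log is given.
        elif load_log_file:
            cost_model.update_from_file(load_log_file)
        return cost_model
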
From 85cd50a555f005991791a8a02fc6aef6f5cab6cb Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Tue, 5 Jan 2021 18:00:44 +0800
Subject: [PATCH 09/12] Update

---
 python/tvm/auto_scheduler/cost_model/xgb_model.py | 9 ++++-----
 python/tvm/auto_scheduler/task_scheduler.py       | 5 +++++
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index 7798139c8272..cd1d82438a51 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -88,7 +88,8 @@ class XGBModel(PythonBasedModel):
     their predictions.
     """

-    def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None, model_file=None):
+    def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None, model_file=None,
+                 adapative_training=False):
         global xgb
         try:
             if xgb is None:
@@ -117,9 +118,6 @@ def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None, model_file
         self.num_warmup_sample = num_warmup_sample
         self.verbose_eval = verbose_eval
         self.model_file = model_file
-        if model_file:
-            logger.info("XGBModel: Load pretrained model from %s...", model_file)
-            self.load(model_file)

         super().__init__()

@@ -144,7 +144,8 @@ def update(self, inputs, results):
         self.inputs.extend(inputs)
         self.results.extend(results)

-        if len(self.inputs) - self.last_train_length < self.last_train_length / 5:
+        if self.adapative_training and \
+                len(self.inputs) - self.last_train_length < self.last_train_length / 5:
             # Set a training threshold related to `last_train_length` to reduce the training
             # overhead when there're too many logs
             return
diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py
index 97466e98f230..68913ade59f8 100644
--- a/python/tvm/auto_scheduler/task_scheduler.py
+++ b/python/tvm/auto_scheduler/task_scheduler.py
@@ -47,6 +47,7 @@ def make_search_policies(
     verbose,
     load_model_file=None,
     load_log_file=None,
+    adapative_training=False,
 ):
     """Make a list of search policies for a list of search tasks.
     It creates one policy per task.
@@ -70,6 +71,9 @@ def make_search_policies(
     load_log_file: Optional[str]
         Load measurement records from this file. If it is not None, the status of the
        task scheduler, search policies and cost models will be restored according to this file.
+    adapative_training: bool = False
+        Option used for XGBModel, which will reduce the model training frequency when there're too
+        many logs

     Returns
     -------
@@ -85,6 +89,7 @@ def make_search_policies(
             cost_model = XGBModel(
                 num_warmup_sample=len(tasks) * num_measures_per_round,
                 model_file=load_model_file,
+                adapative_training=adapative_training,
             )
             if load_model_file:
                 logger.info("TaskScheduler: Load pretrained model...")
                 cost_model.load(load_model_file)
             elif load_log_file:

From 076fab7608276039478d0bfb0687f553fe300f1b Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Tue, 5 Jan 2021 19:11:53 +0800
Subject: [PATCH 10/12] Lint fix

---
 .../tvm/auto_scheduler/cost_model/xgb_model.py | 16 ++++++++++++----
 python/tvm/auto_scheduler/task_scheduler.py    |  2 +-
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py
index cd1d82438a51..f14d9f57f2e1 100644
--- a/python/tvm/auto_scheduler/cost_model/xgb_model.py
+++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py
@@ -88,8 +88,14 @@ class XGBModel(PythonBasedModel):
     their predictions.
""" - def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None, model_file=None, - adapative_training=False): + def __init__( + self, + verbose_eval=25, + num_warmup_sample=100, + seed=None, + model_file=None, + adapative_training=False, + ): global xgb try: if xgb is None: @@ -144,8 +150,10 @@ def update(self, inputs, results): self.inputs.extend(inputs) self.results.extend(results) - if self.adapative_training and \ - len(self.inputs) - self.last_train_length < self.last_train_length / 5: + if ( + self.adapative_training + and len(self.inputs) - self.last_train_length < self.last_train_length / 5 + ): # Set a training threshold related to `last_train_length` to reduce the training # overhead when there're too many logs return diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index 68913ade59f8..5d8ca8af7438 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -73,7 +73,7 @@ def make_search_policies( task scheduler, search policies and cost models will be restored according to this file. adapative_training: bool = False Option used for XGBModel, which will reduce the model training frequency when there're too - many logs + many logs. Returns ------- From adc5c42d579af1225358a08a4892c985f01902be Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 5 Jan 2021 19:30:29 +0800 Subject: [PATCH 11/12] Update --- python/tvm/auto_scheduler/task_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index 5d8ca8af7438..975306f7be54 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -91,7 +91,7 @@ def make_search_policies( model_file=load_model_file, adapative_training=adapative_training, ) - if load_model_file: + if load_model_file and os.path.isfile(load_model_file): logger.info("TaskScheduler: Load pretrained model...") cost_model.load(load_model_file) elif load_log_file: From 3a3f35bfc40287f25b2fa73f0653016161d76f25 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 5 Jan 2021 20:52:23 +0800 Subject: [PATCH 12/12] Bug fix --- python/tvm/auto_scheduler/cost_model/xgb_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/auto_scheduler/cost_model/xgb_model.py b/python/tvm/auto_scheduler/cost_model/xgb_model.py index f14d9f57f2e1..f42648288bfa 100644 --- a/python/tvm/auto_scheduler/cost_model/xgb_model.py +++ b/python/tvm/auto_scheduler/cost_model/xgb_model.py @@ -124,6 +124,7 @@ def __init__( self.num_warmup_sample = num_warmup_sample self.verbose_eval = verbose_eval self.model_file = model_file + self.adapative_training = adapative_training super().__init__()