Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion python/tvm/auto_scheduler/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,14 @@ class XGBModel(PythonBasedModel):
their predictions.
"""

def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
def __init__(
self,
verbose_eval=25,
num_warmup_sample=100,
seed=None,
model_file=None,
adapative_training=False,
):
global xgb
try:
if xgb is None:
Expand Down Expand Up @@ -116,12 +123,15 @@ def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
self.plan_size = 32
self.num_warmup_sample = num_warmup_sample
self.verbose_eval = verbose_eval
self.model_file = model_file
self.adapative_training = adapative_training

super().__init__()

# cache measurement input/result pairs and extracted features
self.inputs = []
self.results = []
self.last_train_length = 0
self.inputs_feature_cache = []

def update(self, inputs, results):
Expand All @@ -141,6 +151,15 @@ def update(self, inputs, results):
self.inputs.extend(inputs)
self.results.extend(results)

if (
self.adapative_training
and len(self.inputs) - self.last_train_length < self.last_train_length / 5
):
# Set a training threshold relative to `last_train_length` to reduce the training
# overhead when there are too many logs
return
self.last_train_length = len(self.inputs)

# extract feature
n_cached = len(self.inputs_feature_cache)
features, normalized_throughputs, task_ids = get_per_store_features_from_measure_pairs(
Expand Down Expand Up @@ -176,6 +195,10 @@ def update(self, inputs, results):
],
)

# Update the model file if it has been set
if self.model_file:
self.save(self.model_file)

def predict(self, task, states):
"""Predict the scores of states
Parameters
Expand Down
13 changes: 11 additions & 2 deletions python/tvm/auto_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def make_search_policies(
verbose,
load_model_file=None,
load_log_file=None,
adapative_training=False,
):
"""Make a list of search policies for a list of search tasks.
It creates one policy per task.
Expand All @@ -70,6 +71,9 @@ def make_search_policies(
load_log_file: Optional[str]
Load measurement records from this file. If it is not None, the status of the
task scheduler, search policies and cost models will be restored according to this file.
adapative_training: bool = False
Option used by XGBModel to reduce the model training frequency when there are too
many logs.

Returns
-------
Expand All @@ -82,11 +86,16 @@ def make_search_policies(
if isinstance(search_policy, str):
policy_type, model_type = search_policy.split(".")
if model_type == "xgb":
cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measures_per_round)
if load_model_file:
cost_model = XGBModel(
num_warmup_sample=len(tasks) * num_measures_per_round,
model_file=load_model_file,
adapative_training=adapative_training,
)
if load_model_file and os.path.isfile(load_model_file):
logger.info("TaskScheduler: Load pretrained model...")
cost_model.load(load_model_file)
elif load_log_file:
logger.info("TaskScheduler: Reload measured states and train the model...")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load_model_file and load_log_file are mutually exclusive, because update_from_file will retrain a model and overwrite the loaded model.
I think the old code is better.

I don't know why the old code cannot satisfy your need.

Copy link
Contributor Author

@jcf94 jcf94 Jan 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old one is fine. I was just going to add a self.model_file for cost model saving after training, this was modified by the way.

cost_model.update_from_file(load_log_file)
elif model_type == "random":
cost_model = RandomModel()
Expand Down
18 changes: 12 additions & 6 deletions src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1462,12 +1462,18 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
if (find_res == task_cache.end()) {
if (inputs[i]->task->compute_dag.defined()) { // the measure input is complete
task = inputs[i]->task;
} else { // the measure input is incomplete
// rebuild task for incomplete measure pairs read from file
Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
inputs[i]->task->target_host, inputs[i]->task->hardware_params,
inputs[i]->task->layout_rewrite_option);
} else {
// The measure input is incomplete, rebuild task for incomplete measure pairs read from file
try {
Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
inputs[i]->task->target_host, inputs[i]->task->hardware_params,
inputs[i]->task->layout_rewrite_option);
} catch (std::exception& e) {
// Cannot build a ComputeDAG from the workload key; the task may not have been
// registered in this search round
continue;
}
}
task_id = task_cache.size();

Expand Down