From 32aed40a5ffe61e599c08c711b2a69f189be1f89 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 28 Jan 2021 00:43:31 +0000 Subject: [PATCH 1/3] [AutoScheduler] Add sampling to dispatcher --- python/tvm/auto_scheduler/__init__.py | 2 +- python/tvm/auto_scheduler/dispatcher.py | 86 ++++++++++++++++++- .../relay/test_auto_scheduler_tuning.py | 17 +++- 3 files changed, 99 insertions(+), 6 deletions(-) diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index 57e58309525c..06ca44d997e5 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -33,7 +33,7 @@ # Shortcut from .compute_dag import ComputeDAG, LayoutRewriteOption, get_shape_from_rewritten_layout from .cost_model import RandomModel, XGBModel -from .dispatcher import DispatchContext, ApplyHistoryBest +from .dispatcher import DispatchContext, ApplyHistoryBest, ApplyHistoryBestOrSample from .measure import ( MeasureInput, MeasureResult, diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py index f2d7536bea88..b5a091bcce0d 100644 --- a/python/tvm/auto_scheduler/dispatcher.py +++ b/python/tvm/auto_scheduler/dispatcher.py @@ -28,8 +28,13 @@ import numpy as np +from tvm.contrib.utils import tempdir from tvm.tir.expr import FloatImm -from .measure_record import load_records +from .cost_model import RandomModel, XGBModel +from .measure import LocalRPCMeasureContext +from .measure_record import RecordToFile, load_records +from .search_policy import PreloadMeasuredStates, SketchPolicy +from .search_task import SearchTask, TuningOptions from .utils import calc_workload_dis_factor, decode_workload_key logger = logging.getLogger("auto_scheduler") @@ -301,6 +306,85 @@ def update(self, target, workload_key, state): entry[workload_args] = (state, 1) +class ApplyHistoryBestOrSample(ApplyHistoryBest): + """ + Apply the history best config, or sample a valid schedule if no config is found. 
+ + Parameters + ---------- + records : str or iterator of (auto_scheduler.measure.MeasureInput,\ + auto_scheduler.measure.MeasureResult) + Collection of tuning records. + If is str, then it should be the filename of a records log file. + Each row of this file is an encoded record pair. Otherwise, it is an iterator. + sample_simple_workloads: bool + When False, sampling will not apply to simple workloads (w/o reduction). + cost_model_file: str + The filename of the pre-trained XGBoost cost model. If not present, then random + model will be used. + """ + + def __init__(self, records, sample_simple_workloads=False, cost_model_file=None): + self.sample_simple_workloads = sample_simple_workloads + self.log_dir = tempdir() + if cost_model_file is None: + self.cost_model = RandomModel() + else: + self.cost_model = XGBModel(num_warmup_sample=1, model_file=cost_model_file) + + super(ApplyHistoryBestOrSample, self).__init__( + records, n_lines=None, include_compatible=True + ) + + def query(self, target, workload_key, has_complex_op, dag): + if has_complex_op or self.sample_simple_workloads: + ret = self._query_inside(target, workload_key) + else: + ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key) + + if ret is None: + ret = self._old_ctx.query(target, workload_key, has_complex_op, dag) + return ret + + def _query_inside(self, target, workload_key): + ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key) + if ret is not None: + return ret + + # Sampling valid schedules when no existing records can be used. 
+ task = SearchTask(workload_key=workload_key, target=target) + measure_ctx = LocalRPCMeasureContext(min_repeat_ms=300) + + log_file = self.log_dir.relpath("%s.log" % decode_workload_key(workload_key)[0]) + + while ret is None: + tune_option = TuningOptions( + num_measure_trials=2, + runner=measure_ctx.runner, + measure_callbacks=[RecordToFile(log_file)], + verbose=0, + ) + search_policy = SketchPolicy( + task, + self.cost_model, + params={ + "eps_greedy": 0.01, + "sample_init_min_population": 64, + "evolutionary_search_num_iters": 0, + }, + init_search_callbacks=[PreloadMeasuredStates(log_file)], + verbose=0, + ) + task.tune(tune_option, search_policy) + + # Load the sampled records and query again. + self.load(log_file) + ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key) + + del measure_ctx + return ret + + class FallbackContext(DispatchContext): """ A fallback dispatch context. diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py index 4ae434d72a20..3b61052ae3b0 100644 --- a/tests/python/relay/test_auto_scheduler_tuning.py +++ b/tests/python/relay/test_auto_scheduler_tuning.py @@ -56,9 +56,16 @@ def tune_network(network, target): ): lib = relay.build(mod, target=target, params=params) + # Sample a schedule when missing + with auto_scheduler.ApplyHistoryBestOrSample(None): + with tvm.transform.PassContext( + opt_level=3, config={"relay.backend.use_auto_scheduler": True} + ): + lib2 = relay.build(mod, target=target, params=params) + # Compile without auto-scheduler and any other optimization for correctness check with tvm.transform.PassContext(opt_level=0): - lib2 = relay.build(mod, target=target, params=params) + ref_lib = relay.build(mod, target=target, params=params) # Check the correctness def get_output(data, lib): @@ -76,10 +83,12 @@ def get_output(data, lib): else: raise ValueError("Unknown network: " + network) - actual_output = get_output(data, lib) - expected_output 
= get_output(data, lib2) + actual_output1 = get_output(data, lib) + actual_output2 = get_output(data, lib2) + expected_output = get_output(data, ref_lib) - tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4, atol=1e-4) + tvm.testing.assert_allclose(actual_output1, expected_output, rtol=1e-4, atol=1e-4) + tvm.testing.assert_allclose(actual_output2, expected_output, rtol=1e-4, atol=1e-4) @tvm.testing.requires_cuda From f63c5a697395d63ddd5f423509bfe50686d78cae Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 5 Feb 2021 17:44:05 +0000 Subject: [PATCH 2/3] address comment --- python/tvm/auto_scheduler/dispatcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py index b5a091bcce0d..bde83ab9112d 100644 --- a/python/tvm/auto_scheduler/dispatcher.py +++ b/python/tvm/auto_scheduler/dispatcher.py @@ -330,7 +330,8 @@ def __init__(self, records, sample_simple_workloads=False, cost_model_file=None) if cost_model_file is None: self.cost_model = RandomModel() else: - self.cost_model = XGBModel(num_warmup_sample=1, model_file=cost_model_file) + self.cost_model = XGBModel() + self.cost_model.load(cost_model_file) super(ApplyHistoryBestOrSample, self).__init__( records, n_lines=None, include_compatible=True From c1d96ef3197cebc4d68dca101e4ead996c811ef2 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 8 Feb 2021 19:23:10 +0000 Subject: [PATCH 3/3] make measurement configurable --- python/tvm/auto_scheduler/dispatcher.py | 10 ++++++++-- tests/python/relay/test_auto_scheduler_tuning.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py index bde83ab9112d..6a25960fe7b7 100644 --- a/python/tvm/auto_scheduler/dispatcher.py +++ b/python/tvm/auto_scheduler/dispatcher.py @@ -322,10 +322,16 @@ class ApplyHistoryBestOrSample(ApplyHistoryBest): cost_model_file: str The
filename of the pre-trained XGBoost cost model. If not present, then random model will be used. + num_measure: int + Measure the top-N rank of sampled schedules on the device. The default -1 means + no measurement and simply returns the top-1 schedule ranked by the cost model. """ - def __init__(self, records, sample_simple_workloads=False, cost_model_file=None): + def __init__( + self, records, sample_simple_workloads=False, cost_model_file=None, num_measure=-1 + ): self.sample_simple_workloads = sample_simple_workloads + self.num_measure = num_measure self.log_dir = tempdir() if cost_model_file is None: self.cost_model = RandomModel() @@ -360,7 +366,7 @@ def _query_inside(self, target, workload_key): while ret is None: tune_option = TuningOptions( - num_measure_trials=2, + num_measure_trials=self.num_measure, runner=measure_ctx.runner, measure_callbacks=[RecordToFile(log_file)], verbose=0, diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py index 3b61052ae3b0..1ec0e305311a 100644 --- a/tests/python/relay/test_auto_scheduler_tuning.py +++ b/tests/python/relay/test_auto_scheduler_tuning.py @@ -57,7 +57,7 @@ def tune_network(network, target): lib = relay.build(mod, target=target, params=params) # Sample a schedule when missing - with auto_scheduler.ApplyHistoryBestOrSample(None): + with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2): with tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_auto_scheduler": True} ):