Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/tvm/autotvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@
from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
register_topi_compute, register_topi_schedule, \
DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best
DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
ApplyGraphBest as apply_graph_best
from .env import GLOBAL_SCOPE
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \
FallbackContext, clear_fallback_cache
FallbackContext, clear_fallback_cache, ApplyGraphBest

from .topi_integration import register_topi_compute, register_topi_schedule
from .nnvm_integration import extract_from_graph, extract_from_multiple_graph
80 changes: 80 additions & 0 deletions python/tvm/autotvm/task/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,83 @@ def clear_fallback_cache(target, workload):
while not isinstance(context, FallbackContext):
context = context._old_ctx
context.clear_cache(target, workload)

class ApplyGraphBest(DispatchContext):
    """Dispatch context that replays graph-level tuning results.

    The records must be sorted in ascending order of the node index of
    the target operator — the ordering produced by the graph tuner. An
    internal counter tracks which node's config is handed out next, so
    queries are answered positionally rather than by workload lookup.
    """
    def __init__(self, records):
        """
        Parameters
        ----------
        records : str or iterator of (MeasureInput, MeasureResult)
            Collection of tuning records.
            If it is a str, it should be the filename of a records log
            file, where each row is an encoded record pair.
            Otherwise, it is an iterator of record pairs.
        """
        from ..record import load_from_file

        super(ApplyGraphBest, self).__init__()
        loaded = load_from_file(records) if isinstance(records, str) else records
        self._records = list(loaded)
        # Index of the next record to replay; advanced by _query_inside.
        self._counter = 0
        # Configs shared by explicit key rather than by query order.
        self._global_cfg_dict = {}

    def _query_inside(self, target, workload):
        """Return the config for the next node and advance the counter.

        Parameters
        ----------
        target : Target
            The current target (unused; replay is purely positional).
        workload : Workload
            The current workload (unused; replay is purely positional).

        Returns
        -------
        cfg : ConfigSpace
            The specific configuration.
        """
        record = self._records[self._counter]
        self._counter += 1
        return record[0].config

    def query_global_dict(self, key):
        """Look up a config stored in the global config dictionary.

        Parameters
        ----------
        key : str
            Key to query the config.

        Returns
        -------
        cfg : ConfigSpace
            The specific configuration.
        """
        return self._global_cfg_dict[key]

    def update_global_dict(self, key, val):
        """Store a config in the global config dictionary under ``key``.

        Parameters
        ----------
        key : str
            Key of config.

        val : ConfigSpace
            Value of config.
        """
        self._global_cfg_dict[key] = val
Loading