Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions python/tvm/autotvm/tophub.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
TopHub: Tensor Operator Hub
To get the best performance, we typically need auto-tuning for the specific devices.
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
TVM will download these parameters for you when you call nnvm.compiler.build_module .
TVM will download these parameters for you when you call
nnvm.compiler.build_module or relay.build.
"""
# pylint: disable=invalid-name

Expand All @@ -30,6 +31,16 @@
from .. import target as _target
from ..contrib.download import download
from .record import load_from_file
from .util import EmptyContext

# environment variable to read TopHub location
AUTOTVM_TOPHUB_LOC_VAR = "TOPHUB_LOCATION"

# default location of TopHub
AUTOTVM_TOPHUB_DEFAULT_LOC = "https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub"

# value of AUTOTVM_TOPHUB_LOC_VAR to specify to not read from TopHub
AUTOTVM_TOPHUB_NONE_LOC = "NONE"

# root path to store TopHub files
AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")
Expand Down Expand Up @@ -61,6 +72,9 @@ def _alias(name):
}
return table.get(name, name)

def _get_tophub_location():
location = os.getenv(AUTOTVM_TOPHUB_LOC_VAR, None)
return AUTOTVM_TOPHUB_DEFAULT_LOC if location is None else location

def context(target, extra_files=None):
"""Return the dispatch context with pre-tuned parameters.
Expand All @@ -75,6 +89,10 @@ def context(target, extra_files=None):
extra_files: list of str, optional
Extra log files to load
"""
tophub_location = _get_tophub_location()
if tophub_location == AUTOTVM_TOPHUB_NONE_LOC:
return EmptyContext()

best_context = ApplyHistoryBest([])

targets = target if isinstance(target, (list, tuple)) else [target]
Expand All @@ -94,7 +112,7 @@ def context(target, extra_files=None):
for name in possible_names:
name = _alias(name)
if name in all_packages:
if not check_backend(name):
if not check_backend(tophub_location, name):
continue

filename = "%s_%s.log" % (name, PACKAGE_VERSION[name])
Expand All @@ -108,7 +126,7 @@ def context(target, extra_files=None):
return best_context


def check_backend(backend):
def check_backend(tophub_location, backend):
"""Check whether have pre-tuned parameters of the certain target.
If not, will download it.

Expand All @@ -135,18 +153,21 @@ def check_backend(backend):
else:
import urllib2
try:
download_package(package_name)
download_package(tophub_location, package_name)
return True
except urllib2.URLError as e:
logging.warning("Failed to download tophub package for %s: %s", backend, e)
return False


def download_package(package_name):
def download_package(tophub_location, package_name):
"""Download pre-tuned parameters of operators for a backend

Parameters
----------
tophub_location: str
The location to download TopHub parameters from

package_name: str
The name of package
"""
Expand All @@ -160,9 +181,9 @@ def download_package(package_name):
if not os.path.isdir(path):
os.mkdir(path)

logger.info("Download pre-tuned parameters package %s", package_name)
download("https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub/%s"
% package_name, os.path.join(rootpath, package_name), True, verbose=0)
download_url = "{0}/{1}".format(tophub_location, package_name)
logger.info("Download pre-tuned parameters package from %s", download_url)
download(download_url, os.path.join(rootpath, package_name), True, verbose=0)


# global cache for load_reference_log
Expand Down