diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 251d3cc3463e..a6cb91c2dfde 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -23,6 +23,7 @@ import numpy as np import tvm +from tvm import autotvm from tvm._ffi.runtime_ctypes import TVMByteArray from . import _vm from . import vmobj as _obj @@ -186,7 +187,16 @@ def compile(self, mod, target=None, target_host=None): if not target_host: target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" target_host = tvm.target.create(target_host) - self._compile(mod, target, target_host) + + # If current dispatch context is fallback context (the default root context), + # then load pre-tuned parameters from TopHub + if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): + tophub_context = autotvm.tophub.context(list(target.values())) + else: + tophub_context = autotvm.util.EmptyContext() + + with tophub_context: + self._compile(mod, target, target_host) return VirtualMachine(self._get_vm())