From 52ad51e4a405c91aeb53a307e47f3eebbf910ba6 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Fri, 4 Oct 2019 19:55:44 -0700 Subject: [PATCH] [Relay][VM] Add autotvm context when compile --- python/tvm/relay/backend/vm.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 251d3cc3463e..a6cb91c2dfde 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -23,6 +23,7 @@ import numpy as np import tvm +from tvm import autotvm from tvm._ffi.runtime_ctypes import TVMByteArray from . import _vm from . import vmobj as _obj @@ -186,7 +187,16 @@ def compile(self, mod, target=None, target_host=None): if not target_host: target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" target_host = tvm.target.create(target_host) - self._compile(mod, target, target_host) + + # If current dispatch context is fallback context (the default root context), + # then load pre-tuned parameters from TopHub + if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): + tophub_context = autotvm.tophub.context(list(target.values())) + else: + tophub_context = autotvm.util.EmptyContext() + + with tophub_context: + self._compile(mod, target, target_host) return VirtualMachine(self._get_vm())