diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 745ef63d3afe..251d3cc3463e 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -178,6 +178,11 @@ def compile(self, mod, target=None, target_host=None): """ target = _update_target(target) target_host = None if target_host == "" else target_host + if not target_host: + for device_type, tgt in target.items(): + if device_type.value == tvm.nd.cpu(0).device_type: + target_host = tgt + break if not target_host: target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" target_host = tvm.target.create(target_host)