diff --git a/CMakeLists.txt b/CMakeLists.txt
index fc7c67c83a48..658ce59d09a3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,9 @@
-cmake_minimum_required(VERSION 3.2)
+if(WIN32)
+  cmake_minimum_required(VERSION 3.9)
+else()
+  cmake_minimum_required(VERSION 3.2)
+endif()
+
 project(tvm C CXX)
 
 # Utility functions
@@ -442,6 +447,9 @@ endif(INSTALL_DEV)
 
 # More target definitions
 if(MSVC)
+  set_property(TARGET tvm PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
+  set_property(TARGET tvm_topi PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
+  set_property(TARGET tvm_runtime PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
   target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
   target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
 endif()
diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc
index b5dc51b9e7ef..64c1a8ac761c 100644
--- a/apps/cpp_rpc/rpc_env.cc
+++ b/apps/cpp_rpc/rpc_env.cc
@@ -194,7 +194,7 @@ void WindowsShared(const std::string& output,
                    const std::string& options = "",
                    const std::string& cc = "clang") {
   std::string cmd = cc;
-  cmd += " -O2 -flto=full -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared ";
+  cmd += " -O2 -flto=full -fuse-ld=lld-link -shared ";
   cmd += " -o " + output;
   for (const auto& file : files) {
     cmd += " " + file;
diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py
index cf81e2b50e50..72f54c149297 100644
--- a/python/tvm/autotvm/measure/local_executor.py
+++ b/python/tvm/autotvm/measure/local_executor.py
@@ -17,8 +17,30 @@
 """Local based implementation of the executor using multiprocessing"""
 
 import signal
-
-from multiprocessing import Process, Queue
+import os
+
+if os.name == 'nt':
+    import queue as thread_queue
+    import threading
+    # Pathos uses dill, which can pickle things like functions
+    from pathos.helpers import ProcessPool
+    # On Windows there is no fork(), so a 'multiprocessing.Process' or
+    # ProcessPool has to 'build up' the script from scratch. We set these
+    # environment variables so each python.exe process does not allocate
+    # unneeded threads
+    os.environ['OMP_NUM_THREADS'] = "1"
+    os.environ['TVM_NUM_THREADS'] = "1"
+    # numpy seems to honor this
+    os.environ['MKL_NUM_THREADS'] = "1"
+
+    # Since there is no fork() on Windows, to mitigate the performance impact
+    # we use a process pool for executors, whereas the *nix-based systems
+    # fork() a new process for each executor
+    EXECUTOR_POOL = None
+
+#pylint: disable=wrong-import-position
+#pylint: disable=ungrouped-imports
+from multiprocessing import Process, Queue, cpu_count
 try:
     from queue import Empty
 except ImportError:
@@ -30,7 +52,8 @@
     psutil = None
 
 from . import executor
-
+#pylint: enable=ungrouped-imports
+#pylint: enable=wrong-import-position
 
 def kill_child_processes(parent_pid, sig=signal.SIGTERM):
     """kill all child processes recursively"""
@@ -68,6 +91,26 @@ def call_with_timeout(queue, timeout, func, args, kwargs):
     p.terminate()
     p.join()
 
+if os.name == 'nt':
+    def call_from_pool(func, args, kwargs, timeout, env):  # pylint: disable=unused-argument
+        """A wrapper to support timeout of a function call for a pool process"""
+
+        # Restore environment variables from the parent
+        for key, val in env.items():
+            os.environ[key] = val
+
+        queue = thread_queue.Queue(2)
+
+        # We use a thread here for Windows, because starting up a new Process can
+        # be heavy. This isn't as clean as the *nix implementation, which can kill
+        # a process that has timed out
+        thread = threading.Thread(target=_execute_func, args=(func, queue, args, kwargs))
+        thread.start()
+        thread.join()
+        queue.put(executor.TimeoutError())
+
+        res = queue.get()
+        return res
 
 class LocalFuture(executor.Future):
     """Local wrapper for the future
@@ -119,6 +162,31 @@ def done(self):
     def get(self, timeout=None):
         return self._result
 
+if os.name == 'nt':
+    class LocalFuturePool(executor.Future):
+        """Local wrapper for the future when using a process pool
+
+        Parameters
+        ----------
+        pool_results: AsyncResult from Pool.apply_async
+            handle for receiving the result of this task
+        """
+        def __init__(self, pool_results):
+            self._done = False
+            self._pool_results = pool_results
+
+        def done(self):
+            return self._done
+
+        def get(self, timeout=None):
+            try:
+                res = self._pool_results.get()
+            except Empty:
+                raise executor.TimeoutError()
+            self._done = True
+            return res
 
 class LocalExecutor(executor.Executor):
     """Local executor that runs workers on the same machine with multiprocessing.
@@ -145,8 +213,25 @@ def submit(self, func, *args, **kwargs):
         if not self.do_fork:
             return LocalFutureNoFork(func(*args, **kwargs))
 
-        queue = Queue(2)
-        process = Process(target=call_with_timeout,
-                          args=(queue, self.timeout, func, args, kwargs))
-        process.start()
-        return LocalFuture(process, queue)
+        if os.name != 'nt':
+            queue = Queue(2)
+            process = Process(target=call_with_timeout,
+                              args=(queue, self.timeout, func, args, kwargs))
+            process.start()
+            return LocalFuture(process, queue)
+
+        global EXECUTOR_POOL
+
+        if EXECUTOR_POOL is None:
+            # We use a static pool for executor processes because Process.start()
+            # is so slow on Windows that we would lose a lot of parallelism.
+            # Right now cpu_count() is used, which isn't optimal from a
+            # user-configuration perspective, but is reasonable at this time.
+            EXECUTOR_POOL = ProcessPool(cpu_count() * 2)
+
+        # Windows seemed to be missing some valuable environment variables
+        # on the pool's process side. We might be able to get away with
+        # just sending the PATH variable, but for now we clone our whole env
+        return LocalFuturePool(EXECUTOR_POOL.apply_async(call_from_pool,
+                                                         (func, args, kwargs,
+                                                          self.timeout, os.environ.copy())))
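
A note on the library choice above: on Windows, multiprocessing must spawn rather than fork, and the stock pickle module refuses lambdas, closures, and other local callables that autotvm routinely ships to workers, while dill (which pathos builds on) serializes them by value. A minimal sketch of the difference, assuming dill is installed; make_task is a hypothetical helper:

    import pickle
    import dill

    def make_task(scale):
        # A closure -- the kind of callable autotvm passes to executors
        return lambda x: x * scale

    task = make_task(3)

    try:
        pickle.dumps(task)  # the stock pickler cannot handle local objects
    except (pickle.PicklingError, AttributeError) as err:
        print("pickle failed:", err)

    restored = dill.loads(dill.dumps(task))  # dill serializes the code by value
    print(restored(14))  # prints 42
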
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index 6533e75eef93..07c4086334d6 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -27,7 +27,9 @@
 import os
 import threading
 import time
-from random import getrandbits
+# Import random directly: it appears dill will pull in 'getrandbits' itself, and
+# it will then always return the same number. Using random.getrandbits fixes this.
+import random
 from collections import namedtuple
 import tempfile
 
@@ -331,10 +333,12 @@ def set_task(self, task):
         from ...rpc.tracker import Tracker
         from ...rpc.server import Server
 
+        # Windows will not let you connect to 0.0.0.0
+        local_address = '0.0.0.0' if os.name != 'nt' else '127.0.0.1'
         self.task = task
-        tracker = Tracker('0.0.0.0', port=9000, port_end=10000, silent=True)
+        tracker = Tracker(local_address, port=9000, port_end=10000, silent=True)
         device_key = '$local$device$%d' % tracker.port
-        server = Server('0.0.0.0', port=9000, port_end=10000,
+        server = Server(local_address, port=9000, port_end=10000,
                         key=device_key,
                         use_popen=True, silent=True,
                         tracker_addr=(tracker.host, tracker.port))
@@ -392,6 +396,7 @@ def _wrap_build_func(build_func):
         raise AttributeError("Expect build_func to have the attribute output_format.")
     output_format = build_func.output_format
 
+
     def _wrapped(measure_input, tmp_dir, **kwargs):
         """
         Wrapped build func.
@@ -407,7 +412,7 @@ def _wrapped(measure_input, tmp_dir, **kwargs):
         tic = time.time()
         try:
             filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
-                getrandbits(64), output_format))
+                random.getrandbits(64), output_format))
             # TODO(tvm-team) consider linline _build_func_common
             func, arg_info = _build_func_common(measure_input, **kwargs)
             func.export_library(filename, build_func)
diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py
index de183db41e2c..ccf404b6da49 100644
--- a/python/tvm/autotvm/task/relay_integration.py
+++ b/python/tvm/autotvm/task/relay_integration.py
@@ -135,8 +135,12 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
 
         # wrap build call in thread to avoid multiprocessing problems
         build_thread = threading.Thread(target=_lower, args=(mod, target, param))
+        # The stack would overflow on some platforms (e.g. Windows) for some models
+        old_stack_size = threading.stack_size(1024 * 1024 * 3)
         build_thread.start()
         build_thread.join()
+        # Restore the stack size to its original value
+        threading.stack_size(old_stack_size)
 
     logger.disabled = old_state
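
The stack-size save/restore above works because threading.stack_size() both sets the size used for threads started after the call and returns the previous setting. A standalone sketch of the same pattern; deep_work is a placeholder for relay's recursive lowering:

    import threading

    def deep_work():
        # placeholder for a deeply recursive lowering call
        return sum(range(1000))

    # stack_size() returns the old size so it can be restored later; only
    # threads started after this call get the larger 3 MB stack
    old_size = threading.stack_size(3 * 1024 * 1024)
    worker = threading.Thread(target=deep_work)
    worker.start()
    worker.join()
    threading.stack_size(old_size)
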
""" + +import os import numpy as np from tvm import target as _target @@ -169,7 +171,9 @@ def __getstate__(self): "config_space": self.config_space, "flop": self.flop, "target": self.target, - "target_host": self.target_host + "target_host": self.target_host, + # On Windows we will use, dill, which can pickle functions + "func": self.func if os.name == 'nt' else None } def __setstate__(self, state): diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index 305244808a33..3e08e3ac59f2 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -20,7 +20,15 @@ import multiprocessing import logging import time - +import os + +if os.name == 'nt': + # Pathos' Pool does pickling via dill, which can pickle + # functions, which is required because Windows doesn't + # support fork() + from pathos.helpers import mp as pathos_multiprocess + from pathos.helpers import ProcessPool +#pylint: disable=wrong-import-position import numpy as np try: import xgboost as xgb @@ -31,7 +39,7 @@ from ..util import get_rank from .metric import max_curve, recall_curve, cover_curve from .model_based_tuner import CostModel, FeatureCache - +#pylint: enable=wrong-import-position logger = logging.getLogger('autotvm') class XGBoostCostModel(CostModel): @@ -153,14 +161,60 @@ def _reset_pool(self, space, target, task): self._close_pool() - # use global variable to pass common arguments - global _extract_space, _extract_target, _extract_task - _extract_space = space - _extract_target = target - _extract_task = task - self.pool = multiprocessing.Pool(self.num_threads) + if os.name == 'nt': + # For Windows, we need space, target, task to be pickled and set on the + # Pool's process side, where the *nix impl simply sets globals + # then forks. + # To ensure each process in the pool is properly set, we have to do + # some synchronization by sending an async call, setting space, target and task, then + # waiting for the queue to have an item set + num_threads = self.num_threads + pool_size = num_threads if num_threads is not None else multiprocessing.cpu_count() + if self.pool is None: + self.pool = ProcessPool(pool_size) + manager = pathos_multiprocess.Manager() + pipe_syncs = [] + + # A simple pathos.map would be cleaner, but it seems that in some cases, + # some of the pools processes will be missed, with some processes running + # the method twice. It seems that just passing a Queue in this manner, + # hits all the processes in the pool. 
diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py
index 305244808a33..3e08e3ac59f2 100644
--- a/python/tvm/autotvm/tuner/xgboost_cost_model.py
+++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py
@@ -20,7 +20,15 @@
 import multiprocessing
 import logging
 import time
-
+import os
+
+if os.name == 'nt':
+    # Pathos' Pool does its pickling via dill, which can pickle
+    # functions; this is required because Windows doesn't
+    # support fork()
+    from pathos.helpers import mp as pathos_multiprocess
+    from pathos.helpers import ProcessPool
+#pylint: disable=wrong-import-position
 import numpy as np
 try:
     import xgboost as xgb
@@ -31,7 +39,7 @@
 from ..util import get_rank
 from .metric import max_curve, recall_curve, cover_curve
 from .model_based_tuner import CostModel, FeatureCache
-
+#pylint: enable=wrong-import-position
 logger = logging.getLogger('autotvm')
 
 class XGBoostCostModel(CostModel):
@@ -153,14 +161,60 @@ def _reset_pool(self, space, target, task):
 
         self._close_pool()
 
-        # use global variable to pass common arguments
-        global _extract_space, _extract_target, _extract_task
-        _extract_space = space
-        _extract_target = target
-        _extract_task = task
-        self.pool = multiprocessing.Pool(self.num_threads)
+        if os.name == 'nt':
+            # On Windows, space, target and task need to be pickled and set on the
+            # Pool's process side, whereas the *nix implementation simply sets
+            # globals and then forks.
+            # To ensure each process in the pool is properly set up, we do some
+            # synchronization: send an async call that sets space, target and task,
+            # then wait for each queue to receive an item
+            num_threads = self.num_threads
+            pool_size = num_threads if num_threads is not None else multiprocessing.cpu_count()
+            if self.pool is None:
+                self.pool = ProcessPool(pool_size)
+            manager = pathos_multiprocess.Manager()
+            pipe_syncs = []
+
+            # A simple pathos.map would be cleaner, but it seems that in some cases
+            # some of the pool's processes will be missed, with other processes
+            # running the method twice. It seems that passing a Queue in this manner
+            # hits all the processes in the pool. An assertion could be built to
+            # verify this.
+            for _ in range(pool_size):
+                queue = manager.Queue(1)
+                results = {
+                    "queue": queue,
+                    "apipe": self.pool.apply_async(_set_pool_process_state,
+                                                   args=(space, target, task, queue))
+                }
+                pipe_syncs.append(results)
+
+            # wait until all of the async calls have completed
+            while True:
+                all_ready = True
+                for pipe_sync in pipe_syncs:
+                    if pipe_sync["apipe"].ready() is False:
+                        all_ready = False
+                        break
+                if all_ready:
+                    break
+                time.sleep(0.05)
+            # complete the async requests on the pool
+            for pipe_sync in pipe_syncs:
+                pipe_sync["apipe"].get()
+                # This may not be needed
+                pipe_sync["queue"].get(block=True)
+        else:
+            # use global variable to pass common arguments
+            global _extract_space, _extract_target, _extract_task
+            _extract_space = space
+            _extract_target = target
+            _extract_task = task
+            self.pool = multiprocessing.Pool(self.num_threads)
+
+    def _close_pool(self, force_close=False):
+        if os.name == 'nt' and not force_close:
+            return
 
-    def _close_pool(self):
         if self.pool:
             self.pool.terminate()
             self.pool.join()
@@ -195,7 +249,6 @@ def fit(self, xs, ys, plan_size):
             self.base_model = None
         else:
             dtrain.set_base_margin(discount * self.base_model.predict(xs, output_margin=True))
-
         self.bst = xgb.train(self.xgb_params, dtrain,
                              num_boost_round=8000,
                              callbacks=[custom_callback(
@@ -324,13 +377,23 @@ def _get_feature(self, indexes):
         return ret
 
     def __del__(self):
-        self._close_pool()
+        self._close_pool(force_close=True)
 
 _extract_space = None
 _extract_target = None
 _extract_task = None
 
+if os.name == 'nt':
+    def _set_pool_process_state(space, target, task, sync_queue):
+        """Sets process state for when fork() is not available"""
+        global _extract_space, _extract_target, _extract_task
+        _extract_space = space
+        _extract_target = target
+        _extract_task = task
+        # Notify the caller that we are done. We may be able to remove this
+        sync_queue.put(None)
+
 def _extract_itervar_feature_index(index):
     """extract iteration var feature for an index in extract_space"""
     try:
@@ -423,7 +486,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
     from xgboost.core import EarlyStopException
     from xgboost.callback import _fmt_metric
     from xgboost.training import aggcv
-
+    #pylint: enable=import-outside-toplevel
     state = {}
     metric_shortname = metric.split("-")[1]
@@ -472,7 +535,6 @@ def callback(env):
             res = [x.split(':') for x in bst_eval.split()]
             for kv in res[1:]:
                 res_dict[kv[0]] = [float(kv[1])]
-
         eval_res = []
         keys = list(res_dict.keys())
         keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x)
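
The synchronization loop above works around a real limitation: a pool load-balances, so N apply_async calls are not guaranteed to land on N distinct workers, which is why a plain map can skip some processes. When the state is known at pool-creation time, the standard library solves this with an initializer that runs exactly once in each worker. A sketch of that alternative, for illustration only (_init_worker and _probe are hypothetical names, and this is not what the PR does, since its pool is long-lived and re-targeted by every _reset_pool call):

    import multiprocessing

    _extract_target = None

    def _init_worker(target):
        # Runs once per worker process at pool start-up, even without fork()
        global _extract_target
        _extract_target = target

    def _probe(_):
        return _extract_target  # read back the per-worker global

    if __name__ == '__main__':
        pool = multiprocessing.Pool(processes=4,
                                    initializer=_init_worker,
                                    initargs=("llvm",))
        print(pool.map(_probe, range(4)))  # ['llvm', 'llvm', 'llvm', 'llvm']
        pool.close()
        pool.join()
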
diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py
index ae37923a1dcf..4a88061a8805 100644
--- a/python/tvm/contrib/cc.py
+++ b/python/tvm/contrib/cc.py
@@ -211,6 +211,7 @@ def _windows_shared(output, objects, options):
     except FileNotFoundError:
         raise RuntimeError("Can not find cl.exe,"
                            "please run this in Vistual Studio Command Prompt.")
+    print(py_str(out))
     if proc.returncode != 0:
         msg = "Compilation error:\n"
         msg += py_str(out)
@@ -226,7 +227,6 @@ def _windows_shared(output, objects, options):
         if obj.endswith(".o"):
             link_cmd += [obj]
 
-    link_cmd += ["-EXPORT:__tvm_main__"]
     link_cmd += [temp_path + "dllmain.obj"]
     link_cmd += ["-out:" + output]
diff --git a/python/tvm/exec/rpc_tracker.py b/python/tvm/exec/rpc_tracker.py
index 8f5bd1dc73a0..6a433b7c9d6b 100644
--- a/python/tvm/exec/rpc_tracker.py
+++ b/python/tvm/exec/rpc_tracker.py
@@ -17,7 +17,7 @@
 # pylint: disable=redefined-outer-name, invalid-name
 """Tool to start RPC tracker"""
 from __future__ import absolute_import
-
+import os
 import logging
 import argparse
 import multiprocessing
@@ -28,8 +28,11 @@ def main(args):
     """Main funciton"""
     tracker = Tracker(args.host, port=args.port, port_end=args.port_end,
                       silent=args.silent)
-    tracker.proc.join()
-
+    if os.name == 'nt':
+        # On Windows, keep the main thread alive by blocking on stdin
+        while True:
+            input()
+    else:
+        tracker.proc.join()
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py
index bc81534a12d9..3df366397d56 100644
--- a/python/tvm/rpc/base.py
+++ b/python/tvm/rpc/base.py
@@ -26,6 +26,7 @@
 import struct
 import random
 import logging
+import os
 
 import tvm._ffi
 from .._ffi.base import py_str
@@ -59,6 +60,16 @@ class TrackerCode(object):
 
 def get_addr_family(addr):
+    """Gets the address family"""
+    if os.name == 'nt':
+        # Windows cannot use the *nix implementation of this function. The call
+        # succeeds and appears to work, but it causes major problems: it leaves
+        # mysterious references that are held, so the RPCSession is not released
+        # immediately, causing timeouts with the RPCServer because the socket on
+        # the C++ side never loses all of its references.
+        # This isn't a 1:1 match for the *nix implementation; it deserves a closer
+        # look, as it probably doesn't work with IPv6 addresses
+        return socket.AF_INET
     res = socket.getaddrinfo(addr[0], addr[1], 0, 0, socket.IPPROTO_TCP)
     return res[0][0]
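
For reference, the resolution the Windows branch short-circuits looks like the sketch below: getaddrinfo returns candidate (family, type, proto, ...) tuples, and taking the family of the first one is what transparently supports both IPv4 and IPv6 on the *nix path:

    import socket

    def get_addr_family(addr):
        # resolve (host, port) and take the family of the first candidate
        res = socket.getaddrinfo(addr[0], addr[1], 0, 0, socket.IPPROTO_TCP)
        return res[0][0]

    print(get_addr_family(("127.0.0.1", 9190)) == socket.AF_INET)  # True
    # on a host with IPv6 configured:
    print(get_addr_family(("::1", 9190)) == socket.AF_INET6)       # True
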
diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py
index 627d67a0a835..7cb6b5c31c67 100644
--- a/python/tvm/rpc/server.py
+++ b/python/tvm/rpc/server.py
@@ -38,6 +38,13 @@
 import signal
 import platform
 import tvm._ffi
+import psutil
+#pylint: disable=wrong-import-position
+if os.name == 'nt':
+    from pathos.helpers import ProcessPool
+    import threading
+    #pylint: disable=ungrouped-imports
+    import multiprocessing.pool
 
 from tvm._ffi.base import py_str
 from tvm._ffi.libinfo import find_lib_path
@@ -45,29 +52,51 @@ from tvm.contrib import util
 from . import base
 from . base import TrackerCode
-
+#pylint: enable=wrong-import-position
 logger = logging.getLogger('RPCServer')
 
+_temp = None
+
+if os.name == 'nt':
+    class NoDaemonProcess(multiprocessing.Process):
+        # make the 'daemon' attribute always return False
+        def _get_daemon(self):
+            return False
+        def _set_daemon(self, value):
+            pass
+        daemon = property(_get_daemon, _set_daemon)
+
+if os.name == 'nt':
+    # We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool
+    # because the latter is only a wrapper function, not a proper class.
+    # pylint: disable=W0223
+    class MyPool(multiprocessing.pool.Pool):
+        Process = NoDaemonProcess
+
+# pylint: disable=unused-variable
+@tvm._ffi.register_func("tvm.rpc.server.workpath", override=True)
+def get_workpath(path):
+    global _temp
+    return _temp.relpath(path)
+
+@tvm._ffi.register_func("tvm.rpc.server.load_module", override=True)
+def load_module(file_name):
+    """Load module from remote side."""
+    global _temp
+    path = _temp.relpath(file_name)
+    m = _load_module(path)
+    logger.info("load_module %s", path)
+    return m
+
 def _server_env(load_library, work_path=None):
     """Server environment function return temp dir"""
+    global _temp
     if work_path:
         temp = work_path
     else:
         temp = util.tempdir()
-
-    # pylint: disable=unused-variable
-    @tvm._ffi.register_func("tvm.rpc.server.workpath")
-    def get_workpath(path):
-        return temp.relpath(path)
-
-    @tvm._ffi.register_func("tvm.rpc.server.load_module", override=True)
-    def load_module(file_name):
-        """Load module from remote side."""
-        path = temp.relpath(file_name)
-        m = _load_module(path)
-        logger.info("load_module %s", path)
-        return m
-
+    _temp = temp
     libs = []
     load_library = load_library.split(":") if load_library else []
    for file_name in load_library:
@@ -84,6 +113,25 @@ def _serve_loop(sock, addr, load_library, work_path=None):
     base._ServerLoop(sockfd)
     if not work_path:
         temp.remove()
+
     logger.info("Finish serving %s", addr)
 
+def _serve_loop_pool(args):
+    """Server loop for a pool worker"""
+    sock = args["sock"]
+    addr = args["addr"]
+    load_library = args["load_library"]
+    work_path = args["work_path"]
+
+    try:
+        sockfd = sock.fileno()
+        temp = _server_env(load_library, work_path)
+        base._ServerLoop(sockfd)
+        if not work_path:
+            temp.remove()
+    except Exception:  # pylint: disable=broad-except
+        pass
+    logger.info("Finish serving %s", addr)
 
 
 def _parse_server_opt(opts):
@@ -94,6 +142,11 @@ def _parse_server_opt(opts):
             ret["timeout"] = float(kv[9:])
     return ret
 
+# For Windows
+_trial_counter = 0
+# For Windows
+_executor_pool = None
+
 def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
     """Listening loop of the server master."""
     def _accept_conn(listen_sock, tracker_conn, ping_period=2):
@@ -191,30 +244,69 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2):
                     tracker_conn.close()
                     tracker_conn = None
                     continue
-            except RuntimeError as exc:
+            except RuntimeError as exc:  # pylint: disable=broad-except
                 raise exc
 
             # step 3: serving
             work_path = util.tempdir()
             logger.info("connection from %s", addr)
-            server_proc = multiprocessing.Process(target=_serve_loop,
-                                                  args=(conn, addr, load_library, work_path))
-            server_proc.deamon = True
-            server_proc.start()
-            # close from our side.
-            conn.close()
-            # wait until server process finish or timeout
-            server_proc.join(opts.get("timeout", None))
-            if server_proc.is_alive():
-                logger.info("Timeout in RPC session, kill..")
-                # pylint: disable=import-outside-toplevel
-                import psutil
-                parent = psutil.Process(server_proc.pid)
-                # terminate worker childs
-                for child in parent.children(recursive=True):
-                    child.terminate()
-                # terminate the worker
-                server_proc.terminate()
+
+            def handle_posix():
+                """Handles serving on a non-Windows OS"""
+                server_proc = multiprocessing.Process(target=_serve_loop,
+                                                      args=(conn, addr, load_library, work_path))
+                server_proc.daemon = True
+                server_proc.start()
+                # close from our side.
+                conn.close()
+                # wait until the server process finishes or times out
+                server_proc.join(opts.get("timeout", None))
+                if server_proc.is_alive():
+                    logger.info("Timeout in RPC session, kill..")
+
+                    try:
+                        parent = psutil.Process(server_proc.pid)
+                        # terminate worker children;
+                        # this can throw on Windows
+                        for child in parent.children(recursive=True):
+                            child.terminate()
+                    except Exception:  # pylint: disable=broad-except
+                        pass
+
+                    # terminate the worker
+                    server_proc.terminate()
+
+            def handle_win32():
+                """Handles serving on Windows"""
+                global _trial_counter
+                global _executor_pool
+
+                if _trial_counter % 5 == 0 or _executor_pool is None:
+                    if _executor_pool is not None:
+                        _executor_pool.terminate()
+                    _executor_pool = MyPool(processes=1)
+
+                _trial_counter += 1
+
+                args = {
+                    "sock" : conn,
+                    "addr" : addr,
+                    "load_library" : load_library,
+                    "work_path" : work_path
+                }
+
+                _executor_pool.map(_serve_loop_pool, [args])
+
+                try:
+                    conn.close()
+                    conn.shutdown(1)
+                except Exception:  # pylint: disable=broad-except
+                    pass
+
+            if os.name != 'nt':
+                handle_posix()
+            else:
+                handle_win32()
 
             work_path.remove()
@@ -233,7 +325,6 @@ def _connect_proxy_loop(addr, key, load_library):
        magic = struct.unpack("
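
A closing note on handle_win32: rebuilding the single-worker pool every five connections is hand-rolled worker recycling. The standard library expresses the same idea via maxtasksperchild, sketched below as an illustration with hypothetical names (serve stands in for one RPC session), not as what this PR uses; the NoDaemonProcess override above stays relevant either way, because daemonic pool workers are forbidden from spawning children of their own:

    import multiprocessing

    def serve(task_id):
        # stand-in for one RPC session served by a pool worker
        return task_id * 2

    if __name__ == '__main__':
        # maxtasksperchild=5 retires each worker after five tasks and spawns
        # a fresh process in its place -- the built-in analogue of tearing
        # the pool down every five connections as handle_win32 does
        pool = multiprocessing.Pool(processes=1, maxtasksperchild=5)
        print(pool.map(serve, range(12)))
        pool.close()
        pool.join()
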