From f2285f2627e8825f5880f4a356d0ca05dbdc6c6a Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Sat, 9 Nov 2019 21:50:26 -0800 Subject: [PATCH 01/33] Initial checkin for Windows support --- python/tvm/autotvm/measure/local_executor.py | 95 +++++++++++++++++-- python/tvm/autotvm/measure/measure_methods.py | 6 +- python/tvm/autotvm/task/task.py | 15 ++- .../tvm/autotvm/tuner/xgboost_cost_model.py | 78 +++++++++++++-- python/tvm/rpc/base.py | 15 ++- python/tvm/rpc/server.py | 13 ++- 6 files changed, 197 insertions(+), 25 deletions(-) diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py index cf81e2b50e50..a8ed342cf965 100644 --- a/python/tvm/autotvm/measure/local_executor.py +++ b/python/tvm/autotvm/measure/local_executor.py @@ -17,8 +17,28 @@ """Local based implementation of the executor using multiprocessing""" import signal - -from multiprocessing import Process, Queue +import os + +if os.name == 'nt': + import queue as thread_queue + import threading + # Pathos uses dill, which can pickle things like functions + from pathos.helpers import ProcessPool + # On Windows, there is no fork(), a 'multiprocessing.Process' + # or ProcessPool has to 'build up' the script from scratch, we + # set these environment variables so each python.exe process + # does not allocate unneeded threads + os.environ['OMP_NUM_THREADS'] = "1" + os.environ['TVM_NUM_THREADS'] = "1" + # numpy seems to honor this + os.environ['MKL_NUM_THREADS'] = "1" + + # Since there is no fork() on Windows, to mitigate performance impact + # we will use a process pool for executers, vs the *nix based systems + # that will fork() a new process for each executor + executor_pool = None + +from multiprocessing import Process, Queue, cpu_count try: from queue import Empty except ImportError: @@ -68,6 +88,27 @@ def call_with_timeout(queue, timeout, func, args, kwargs): p.terminate() p.join() +if os.name == 'nt': + def call_from_pool(func, args, kwargs, timeout, env): + """A wrapper to support timeout of a function call for a pool process""" + + # Restore environment variables from parent + for key, val in env.items(): + os.environ[key] = val + + queue = thread_queue.Queue(2) + + # We use a thread here for Windows, because starting up a new Process can be heavy + # This isn't as clean as the *nix implementation, which can kill a process that + # has timed out + thread = threading.Thread(target=_execute_func, args=(func, queue, args, kwargs)) + thread.start() + thread.join(timeout=timeout) + + queue.put(executor.TimeoutError()) + + res = queue.get() + return res class LocalFuture(executor.Future): """Local wrapper for the future @@ -119,6 +160,31 @@ def done(self): def get(self, timeout=None): return self._result +if os.name == 'nt': + class LocalFuturePool(executor.Future): + """Local wrapper for the future using a Process pool + + Parameters + ---------- + thread: threading.Thread + Thread for running this task + pool_results: result from Pool.apply_async + queue for receiving the result of this task + """ + def __init__(self, pool_results): + self._done = False + self._pool_results = pool_results + + def done(self): + return self._done + + def get(self, timeout=None): + try: + res = self._pool_results.get(timeout=timeout) + except Empty: + raise executor.TimeoutError() + self._done = True + return res class LocalExecutor(executor.Executor): """Local executor that runs workers on the same machine with multiprocessing. 
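As a standalone illustration of the thread-plus-queue timeout pattern that `call_from_pool` relies on above, here is a minimal sketch. The names `run_with_timeout`, `_execute`, and `ExecTimeoutError` are invented for the example (the patch returns `executor.TimeoutError` instead); the key point is that the sentinel is enqueued only after `join(timeout)`, so whichever of the real result or the sentinel arrived first is what the caller gets, since a thread, unlike a process, cannot be killed once it has timed out.

```python
import queue
import threading
import time


class ExecTimeoutError(Exception):
    """Stand-in for executor.TimeoutError: returned when the call runs too long."""


def _execute(func, out_queue, args, kwargs):
    """Run func and push either its result or the raised exception."""
    try:
        res = func(*args, **kwargs)
    except Exception as exc:  # pylint: disable=broad-except
        res = exc
    out_queue.put(res)


def run_with_timeout(func, timeout, *args, **kwargs):
    """Run func in a worker thread and return its result, or a timeout sentinel.

    The queue is bounded at 2 so both the real result and the sentinel fit;
    whichever was enqueued first is what the caller sees.
    """
    out_queue = queue.Queue(2)
    worker = threading.Thread(target=_execute, args=(func, out_queue, args, kwargs))
    worker.daemon = True  # a timed-out worker must not keep the interpreter alive
    worker.start()
    worker.join(timeout=timeout)
    out_queue.put(ExecTimeoutError())
    return out_queue.get()


if __name__ == '__main__':
    print(run_with_timeout(lambda x: x * 2, 1.0, 21))                          # 42
    print(isinstance(run_with_timeout(time.sleep, 0.5, 5), ExecTimeoutError))  # True
```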
@@ -145,8 +211,23 @@ def submit(self, func, *args, **kwargs): if not self.do_fork: return LocalFutureNoFork(func(*args, **kwargs)) - queue = Queue(2) - process = Process(target=call_with_timeout, - args=(queue, self.timeout, func, args, kwargs)) - process.start() - return LocalFuture(process, queue) + if os.name != 'nt': + queue = Queue(2) + process = Process(target=call_with_timeout, + args=(queue, self.timeout, func, args, kwargs)) + process.start() + return LocalFuture(process, queue) + else: + global executor_pool + + if executor_pool is None: + # We use a static pool for executor processes because Process.start(entry) + # is so slow on Windows, we lose a lot of parallelism. + # Right now cpu_count() is used, which isn't optimal from a user configuration + # perspective, but is reasonable at this time. + executor_pool = ProcessPool(cpu_count()) + + # Windows seemed to be missing some valuable environ variables + # on the pool's process side. We might be able to get away with + # just sending the PATH variable, but for now, we just clone our env + return LocalFuturePool(executor_pool.apply_async(call_from_pool, (func, args, kwargs, self.timeout, os.environ.copy()))) diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 36efc881958e..fa63d9e058f1 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -27,7 +27,9 @@ import os import threading import time -from random import getrandbits +# Import random directly because it appears dill will pull in 'getrandbits' +# and it will always get the same random number. Using random.getrandbits fixes +import random from collections import namedtuple import tempfile @@ -403,7 +405,7 @@ def _wrapped(measure_input, tmp_dir, **kwargs): tic = time.time() try: filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % ( - getrandbits(64), output_format)) + random.getrandbits(64), output_format)) # TODO(tvm-team) consider linline _build_func_common func, arg_info = _build_func_common(measure_input, **kwargs) func.export_library(filename, build_func) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 4f3cc90b474e..bfa53df6df51 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -23,7 +23,7 @@ """ import numpy as np - +import os from ... 
import tensor, expr, container, target as _target from ..util import get_const_int, get_const_tuple, get_func_name @@ -97,7 +97,9 @@ def __getstate__(self): "workload": self.workload, "flop": self.flop, "target": self.target, - "target_host": self.target_host + "target_host": self.target_host, + # On Windows we will use, dill, which can pickle functions + "func": self.func if os.name == 'nt' else None } def __setstate__(self, state): @@ -105,7 +107,8 @@ def __setstate__(self, state): self.args = state["args"] self.kwargs = state["kwargs"] self.config_space = state["config_space"] - self.func = TASK_TABLE.get(state["name"], _raise_error) + # Use pickled function on Windows + self.func = state["func"] if os.name == 'nt' else TASK_TABLE.get(state["name"], _raise_error) self.workload = state["workload"] self.flop = state["flop"] self.target = state["target"] @@ -189,7 +192,11 @@ def create(func_name, args, target, target_host=None, template_key=None): with ctx: with target: sch, _ = func(*args) - ret.config_space.code_hash = getattr(sch, 'code_hash', None) + try: + # getattr will throw here on Windows, as of an Oct 2019 commit + ret.config_space.code_hash = getattr(sch, 'code_hash', None) + except: + ret.config_space.code_hash = None ret.workload = ctx.workload ret.flop = ret.config_space.flop or compute_flop(sch) diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index 265365144639..bc1f7cee75e5 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -20,6 +20,15 @@ import multiprocessing import logging import time +import os + +if os.name == 'nt': + # Pathos' Pool does pickling via dill, which can pickle + # functions, which is required because Windows doesn't + # support fork() + from pathos.helpers import mp as pathos_multiprocess + from pathos.helpers import ProcessPool + import pathos.multiprocessing import numpy as np try: @@ -153,12 +162,59 @@ def _reset_pool(self, space, target, task): self._close_pool() - # use global variable to pass common arguments - global _extract_space, _extract_target, _extract_task - _extract_space = space - _extract_target = target - _extract_task = task - self.pool = multiprocessing.Pool(self.num_threads) + if os.name == 'nt': + # For Windows, we need space, target, task to be pickled and set on the + # Pool's process side, where the *nix impl simply sets globals + # then forks. + # To ensure each process in the pool is properly set, we have to do + # some synchronization by sending an async call and waiting for + # the queue to have an item set + + # There seems to be diminishing returns on large pool sizes given + # the small job sizes mapped later in the code (largest seems to be 128) + # so the pool size is capped + pool_size = min(16, int(self.num_threads)) + + self.pool = ProcessPool(pool_size) + manager = pathos_multiprocess.Manager() + + pipe_syncs = [] + + # A simple pathos.map would be cleaner, but it seems that in some cases, + # some of the pools processes will be missed, with some processes running + # the method twice. It seems that just passing a Queue in this manner, + # hits all the processes in the pool. 
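For comparison with the `apply_async` plus `Manager.Queue` handshake used in `_reset_pool` here, the standard library's `multiprocessing.Pool` exposes an `initializer` hook that sets per-worker state exactly once per process. The patch cannot use it directly because the space/target/task objects and task functions need dill (hence pathos) rather than plain pickle, but a sketch with plain-picklable stand-ins shows the intent; all names below are invented for the example.

```python
import multiprocessing

# Per-worker globals, analogous to _extract_space/_extract_target/_extract_task.
_space = None
_target = None
_task = None


def _init_worker(space, target, task):
    """Runs once in every worker when the pool starts: the fork-free analogue
    of setting globals in the parent and then forking."""
    global _space, _target, _task
    _space, _target, _task = space, target, task


def _job(index):
    # Each job can now rely on the per-worker state set by the initializer.
    return "%s/%s/%s -> item %d" % (_space, _target, _task, index)


if __name__ == '__main__':
    pool = multiprocessing.Pool(processes=4,
                                initializer=_init_worker,
                                initargs=("space", "llvm", "conv2d"))
    try:
        print(pool.map(_job, range(8)))
    finally:
        pool.close()
        pool.join()
```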
Some assertion should be built to verify + for i in range(pool_size): + queue = manager.Queue(1) + results = { + "queue": queue, + "apipe": self.pool.apply_async(_set_pool_process_state, (space, target, task, queue)) + } + pipe_syncs.append(results) + + # wait loop until all async calls have completed + while True: + all_ready = True + for pipe_sync in pipe_syncs: + if pipe_sync["apipe"].ready() == False: + all_ready = False + break + if all_ready: + break; + else: + time.sleep(0.1) + # complete the async requests on the pool + for pipe_sync in pipe_syncs: + pipe_sync["apipe"].get() + # This may not be needed + pipe_sync["queue"].get(block=True) + else: + # use global variable to pass common arguments + global _extract_space, _extract_target, _extract_task + _extract_space = space + _extract_target = target + _extract_task = task + self.pool = multiprocessing.Pool(self.num_threads) def _close_pool(self): if self.pool: @@ -332,6 +388,16 @@ def __del__(self): _extract_target = None _extract_task = None +if os.name == 'nt': + def _set_pool_process_state(space, target, task, sync_queue): + """sets process state for when fork() is not available """ + global _extract_space, _extract_target, _extract_task + _extract_space = space + _extract_target = target + _extract_task = task + # Notify caller that we are done. We may be able to remove this + sync_queue.put(None) + def _extract_itervar_feature_index(index): """extract iteration var feature for an index in extract_space""" try: diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py index a1e837cd0c1f..683b8fe21700 100644 --- a/python/tvm/rpc/base.py +++ b/python/tvm/rpc/base.py @@ -26,6 +26,7 @@ import struct import random import logging +import os from .._ffi.function import _init_api from .._ffi.base import py_str @@ -59,8 +60,18 @@ class TrackerCode(object): def get_addr_family(addr): - res = socket.getaddrinfo(addr[0], addr[1], 0, 0, socket.IPPROTO_TCP) - return res[0][0] + if os.name == 'nt': + # WINDOWS CANNOT USE THE *NIX IMPL OF THIS! FUNCTION SUCCEEDS AND WORKS + # BUT IT CAUSES MAJOR PROBLEMS. IT LEAVES MYSTERIOUS REFERENCES THAT ARE + # HELD AND THE RPCSESSION WOULD NOT BE IMMEDIATE RELEASED, CAUSING + # TIMEOUTS WITH THE RPCSERVER BECAUSE THE SOCKET IN THE C++ DIDN'T LOSE ALL + # OF ITS REFERENCES. 
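For reference, this is what the *nix branch of `get_addr_family` computes and what the hard-coded `AF_INET` fallback gives up (IPv6 detection). It is only a sketch of the portable lookup, not a fix for the reference-holding problem described in the comment; the second call assumes the host has IPv6 support.

```python
import socket


def get_addr_family_portable(addr):
    """Resolve the address family the way the *nix branch does.

    addr is a (host, port) tuple.  getaddrinfo returns a list of
    (family, type, proto, canonname, sockaddr) tuples; the first entry's
    family is AF_INET for IPv4 hosts and AF_INET6 for IPv6 hosts.
    """
    res = socket.getaddrinfo(addr[0], addr[1], 0, 0, socket.IPPROTO_TCP)
    return res[0][0]


if __name__ == '__main__':
    print(get_addr_family_portable(("127.0.0.1", 9190)) == socket.AF_INET)   # True
    print(get_addr_family_portable(("::1", 9190)) == socket.AF_INET6)        # True
```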
+ # This isn't a 1:1 of the *nix implementation, should probably + # take a closer look as it probably doesn't work with IPV6 addresses + return socket.AF_INET + else: + res = socket.getaddrinfo(addr[0], addr[1], 0, 0, socket.IPPROTO_TCP) + return res[0][0] def recvall(sock, nbytes): diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 9e03097e89a7..00654446e5b5 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -211,10 +211,15 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2): if server_proc.is_alive(): logger.info("Timeout in RPC session, kill..") import psutil - parent = psutil.Process(server_proc.pid) - # terminate worker childs - for child in parent.children(recursive=True): - child.terminate() + try: + parent = psutil.Process(server_proc.pid) + # terminate worker childs + # this can throw on Windows + for child in parent.children(recursive=True): + child.terminate() + except: # pylint: disable=broad-except + pass + # terminate the worker server_proc.terminate() work_path.remove() From 74038efe00cc216dfef9611616271aac9a286186 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Mon, 11 Nov 2019 02:06:06 -0800 Subject: [PATCH 02/33] Work to support autotvm.LocalRunner in Windows --- python/tvm/autotvm/measure/local_executor.py | 466 +++---- python/tvm/autotvm/measure/measure_methods.py | 1225 +++++++++-------- python/tvm/rpc/server.py | 872 ++++++------ python/tvm/rpc/tracker.py | 50 +- 4 files changed, 1338 insertions(+), 1275 deletions(-) diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py index a8ed342cf965..d35c42b5dc35 100644 --- a/python/tvm/autotvm/measure/local_executor.py +++ b/python/tvm/autotvm/measure/local_executor.py @@ -1,233 +1,233 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-"""Local based implementation of the executor using multiprocessing""" - -import signal -import os - -if os.name == 'nt': - import queue as thread_queue - import threading - # Pathos uses dill, which can pickle things like functions - from pathos.helpers import ProcessPool - # On Windows, there is no fork(), a 'multiprocessing.Process' - # or ProcessPool has to 'build up' the script from scratch, we - # set these environment variables so each python.exe process - # does not allocate unneeded threads - os.environ['OMP_NUM_THREADS'] = "1" - os.environ['TVM_NUM_THREADS'] = "1" - # numpy seems to honor this - os.environ['MKL_NUM_THREADS'] = "1" - - # Since there is no fork() on Windows, to mitigate performance impact - # we will use a process pool for executers, vs the *nix based systems - # that will fork() a new process for each executor - executor_pool = None - -from multiprocessing import Process, Queue, cpu_count -try: - from queue import Empty -except ImportError: - from Queue import Empty - -try: - import psutil -except ImportError: - psutil = None - -from . import executor - - -def kill_child_processes(parent_pid, sig=signal.SIGTERM): - """kill all child processes recursively""" - try: - parent = psutil.Process(parent_pid) - except psutil.NoSuchProcess: - return - children = parent.children(recursive=True) - for process in children: - try: - process.send_signal(sig) - except psutil.NoSuchProcess: - return - -def _execute_func(func, queue, args, kwargs): - """execute function and return the result or exception to a queue""" - try: - res = func(*args, **kwargs) - except Exception as exc: # pylint: disable=broad-except - res = exc - queue.put(res) - - -def call_with_timeout(queue, timeout, func, args, kwargs): - """A wrapper to support timeout of a function call""" - - # start a new process for timeout (cannot use thread because we have c function) - p = Process(target=_execute_func, args=(func, queue, args, kwargs)) - p.start() - p.join(timeout=timeout) - - queue.put(executor.TimeoutError()) - - kill_child_processes(p.pid) - p.terminate() - p.join() - -if os.name == 'nt': - def call_from_pool(func, args, kwargs, timeout, env): - """A wrapper to support timeout of a function call for a pool process""" - - # Restore environment variables from parent - for key, val in env.items(): - os.environ[key] = val - - queue = thread_queue.Queue(2) - - # We use a thread here for Windows, because starting up a new Process can be heavy - # This isn't as clean as the *nix implementation, which can kill a process that - # has timed out - thread = threading.Thread(target=_execute_func, args=(func, queue, args, kwargs)) - thread.start() - thread.join(timeout=timeout) - - queue.put(executor.TimeoutError()) - - res = queue.get() - return res - -class LocalFuture(executor.Future): - """Local wrapper for the future - - Parameters - ---------- - process: multiprocessing.Process - process for running this task - queue: multiprocessing.Queue - queue for receiving the result of this task - """ - def __init__(self, process, queue): - self._done = False - self._process = process - self._queue = queue - - def done(self): - self._done = self._done or not self._queue.empty() - return self._done - - def get(self, timeout=None): - try: - res = self._queue.get(block=True, timeout=timeout) - except Empty: - raise executor.TimeoutError() - if self._process.is_alive(): - kill_child_processes(self._process.pid) - self._process.terminate() - self._process.join() - self._queue.close() - self._queue.join_thread() - self._done = 
True - del self._queue - del self._process - return res - - -class LocalFutureNoFork(executor.Future): - """Local wrapper for the future. - This is a none-fork version of LocalFuture. - Use this for the runtime that does not support fork (like cudnn) - """ - def __init__(self, result): - self._result = result - - def done(self): - return True - - def get(self, timeout=None): - return self._result - -if os.name == 'nt': - class LocalFuturePool(executor.Future): - """Local wrapper for the future using a Process pool - - Parameters - ---------- - thread: threading.Thread - Thread for running this task - pool_results: result from Pool.apply_async - queue for receiving the result of this task - """ - def __init__(self, pool_results): - self._done = False - self._pool_results = pool_results - - def done(self): - return self._done - - def get(self, timeout=None): - try: - res = self._pool_results.get(timeout=timeout) - except Empty: - raise executor.TimeoutError() - self._done = True - return res - -class LocalExecutor(executor.Executor): - """Local executor that runs workers on the same machine with multiprocessing. - - Parameters - ---------- - timeout: float, optional - timeout of a job. If time is out. A TimeoutError will be returned (not raised) - do_fork: bool, optional - For some runtime systems that do not support fork after initialization - (e.g. cuda runtime, cudnn). Set this to False if you have used these runtime - before submitting jobs. - """ - def __init__(self, timeout=None, do_fork=True): - self.timeout = timeout or executor.Executor.DEFAULT_TIMEOUT - self.do_fork = do_fork - - if self.do_fork: - if not psutil: - raise RuntimeError("Python package psutil is missing. " - "please try `pip install psutil`") - - def submit(self, func, *args, **kwargs): - if not self.do_fork: - return LocalFutureNoFork(func(*args, **kwargs)) - - if os.name != 'nt': - queue = Queue(2) - process = Process(target=call_with_timeout, - args=(queue, self.timeout, func, args, kwargs)) - process.start() - return LocalFuture(process, queue) - else: - global executor_pool - - if executor_pool is None: - # We use a static pool for executor processes because Process.start(entry) - # is so slow on Windows, we lose a lot of parallelism. - # Right now cpu_count() is used, which isn't optimal from a user configuration - # perspective, but is reasonable at this time. - executor_pool = ProcessPool(cpu_count()) - - # Windows seemed to be missing some valuable environ variables - # on the pool's process side. We might be able to get away with - # just sending the PATH variable, but for now, we just clone our env - return LocalFuturePool(executor_pool.apply_async(call_from_pool, (func, args, kwargs, self.timeout, os.environ.copy()))) +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Local based implementation of the executor using multiprocessing""" + +import signal +import os + +if os.name == 'nt': + import queue as thread_queue + import threading + # Pathos uses dill, which can pickle things like functions + from pathos.helpers import ProcessPool + # On Windows, there is no fork(), a 'multiprocessing.Process' + # or ProcessPool has to 'build up' the script from scratch, we + # set these environment variables so each python.exe process + # does not allocate unneeded threads + os.environ['OMP_NUM_THREADS'] = "1" + os.environ['TVM_NUM_THREADS'] = "1" + # numpy seems to honor this + os.environ['MKL_NUM_THREADS'] = "1" + + # Since there is no fork() on Windows, to mitigate performance impact + # we will use a process pool for executers, vs the *nix based systems + # that will fork() a new process for each executor + executor_pool = None + +from multiprocessing import Process, Queue, cpu_count +try: + from queue import Empty +except ImportError: + from Queue import Empty + +try: + import psutil +except ImportError: + psutil = None + +from . import executor + + +def kill_child_processes(parent_pid, sig=signal.SIGTERM): + """kill all child processes recursively""" + try: + parent = psutil.Process(parent_pid) + except psutil.NoSuchProcess: + return + children = parent.children(recursive=True) + for process in children: + try: + process.send_signal(sig) + except psutil.NoSuchProcess: + return + +def _execute_func(func, queue, args, kwargs): + """execute function and return the result or exception to a queue""" + try: + res = func(*args, **kwargs) + except Exception as exc: # pylint: disable=broad-except + res = exc + queue.put(res) + + +def call_with_timeout(queue, timeout, func, args, kwargs): + """A wrapper to support timeout of a function call""" + + # start a new process for timeout (cannot use thread because we have c function) + p = Process(target=_execute_func, args=(func, queue, args, kwargs)) + p.start() + p.join(timeout=timeout) + + queue.put(executor.TimeoutError()) + + kill_child_processes(p.pid) + p.terminate() + p.join() + +if os.name == 'nt': + def call_from_pool(func, args, kwargs, timeout, env): + """A wrapper to support timeout of a function call for a pool process""" + + # Restore environment variables from parent + for key, val in env.items(): + os.environ[key] = val + + queue = thread_queue.Queue(2) + + # We use a thread here for Windows, because starting up a new Process can be heavy + # This isn't as clean as the *nix implementation, which can kill a process that + # has timed out + thread = threading.Thread(target=_execute_func, args=(func, queue, args, kwargs)) + thread.start() + thread.join(timeout=timeout) + + queue.put(executor.TimeoutError()) + + res = queue.get() + return res + +class LocalFuture(executor.Future): + """Local wrapper for the future + + Parameters + ---------- + process: multiprocessing.Process + process for running this task + queue: multiprocessing.Queue + queue for receiving the result of this task + """ + def __init__(self, process, queue): + self._done = False + self._process = process + self._queue = queue + + def done(self): + self._done = self._done or not self._queue.empty() + return self._done + + def get(self, timeout=None): + try: + res = self._queue.get(block=True, timeout=timeout) + except Empty: + raise executor.TimeoutError() + if self._process.is_alive(): + kill_child_processes(self._process.pid) + self._process.terminate() + self._process.join() + self._queue.close() + self._queue.join_thread() + self._done = 
True + del self._queue + del self._process + return res + + +class LocalFutureNoFork(executor.Future): + """Local wrapper for the future. + This is a none-fork version of LocalFuture. + Use this for the runtime that does not support fork (like cudnn) + """ + def __init__(self, result): + self._result = result + + def done(self): + return True + + def get(self, timeout=None): + return self._result + +if os.name == 'nt': + class LocalFuturePool(executor.Future): + """Local wrapper for the future using a Process pool + + Parameters + ---------- + thread: threading.Thread + Thread for running this task + pool_results: result from Pool.apply_async + queue for receiving the result of this task + """ + def __init__(self, pool_results): + self._done = False + self._pool_results = pool_results + + def done(self): + return self._done + + def get(self, timeout=None): + try: + res = self._pool_results.get(timeout=timeout) + except Empty: + raise executor.TimeoutError() + self._done = True + return res + +class LocalExecutor(executor.Executor): + """Local executor that runs workers on the same machine with multiprocessing. + + Parameters + ---------- + timeout: float, optional + timeout of a job. If time is out. A TimeoutError will be returned (not raised) + do_fork: bool, optional + For some runtime systems that do not support fork after initialization + (e.g. cuda runtime, cudnn). Set this to False if you have used these runtime + before submitting jobs. + """ + def __init__(self, timeout=None, do_fork=True): + self.timeout = timeout or executor.Executor.DEFAULT_TIMEOUT + self.do_fork = do_fork + + if self.do_fork: + if not psutil: + raise RuntimeError("Python package psutil is missing. " + "please try `pip install psutil`") + + def submit(self, func, *args, **kwargs): + if not self.do_fork: + return LocalFutureNoFork(func(*args, **kwargs)) + + if os.name != 'nt': + queue = Queue(2) + process = Process(target=call_with_timeout, + args=(queue, self.timeout, func, args, kwargs)) + process.start() + return LocalFuture(process, queue) + else: + global executor_pool + + if executor_pool is None: + # We use a static pool for executor processes because Process.start(entry) + # is so slow on Windows, we lose a lot of parallelism. + # Right now cpu_count() is used, which isn't optimal from a user configuration + # perspective, but is reasonable at this time. + executor_pool = ProcessPool(cpu_count() * 2) + + # Windows seemed to be missing some valuable environ variables + # on the pool's process side. We might be able to get away with + # just sending the PATH variable, but for now, we just clone our env + return LocalFuturePool(executor_pool.apply_async(call_from_pool, (func, args, kwargs, self.timeout, os.environ.copy()))) diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index fa63d9e058f1..6db8b8f21083 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -1,611 +1,614 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name,too-many-function-args,too-many-nested-blocks -""" -Functions that run on executor for measurement. - -These functions are responsible for building the tvm module, uploading it to -remote devices, recording the running time costs, and checking the correctness of the output. -""" - -import logging -import shutil -import os -import threading -import time -# Import random directly because it appears dill will pull in 'getrandbits' -# and it will always get the same random number. Using random.getrandbits fixes -import random -from collections import namedtuple -import tempfile - -import numpy as np - -from ... import ir_pass, build, build_config, nd, TVMError, register_func, \ - rpc as _rpc, target as _target -from ...contrib import nvcc, ndk, tar - -from ..util import get_const_tuple -from ..env import AutotvmGlobalScope -from ..task.space import InstantiationError - -from .measure import MeasureResult, MeasureErrorNo, Builder, Runner -from .local_executor import LocalExecutor - -logger = logging.getLogger('autotvm') - -class BuildResult(namedtuple("BuildResult", ('filename', 'arg_info', 'error', 'time_cost'))): - """ - Stores all the necessary inputs for a measurement. - - Parameters - ---------- - filename : str - The filename of generated library - arg_info : Tuple - The shape and dtype information of tvm tensor arguments - error : Exception - The error happens during compilation. - time_cost : float - The time cost of building - """ - -class LocalBuilder(Builder): - """Run compilation on local machine - - Parameters - ---------- - timeout: float - The timeout of a compilation - n_parallel: int - The number of tasks run in parallel. "None" will use all cpu cores - build_func: callable or str - If is 'default', use default build function - If is 'ndk', use function for android ndk - If is callable, use it as custom build function, expect lib_format field. 
- """ - def __init__(self, timeout=10, n_parallel=None, build_func='default'): - super(LocalBuilder, self).__init__(timeout, n_parallel) - - if isinstance(build_func, str): - if build_func == 'default': - build_func = tar.tar - elif build_func == 'ndk': - build_func = ndk.create_shared - else: - raise ValueError("Invalid build_func" + build_func) - self.build_func = _wrap_build_func(build_func) - self.executor = LocalExecutor(timeout=timeout) - self.tmp_dir = tempfile.mkdtemp() - - def build(self, measure_inputs): - results = [] - - shutil.rmtree(self.tmp_dir) - self.tmp_dir = tempfile.mkdtemp() - - for i in range(0, len(measure_inputs), self.n_parallel): - futures = [] - for inp in measure_inputs[i:i + self.n_parallel]: - ret = self.executor.submit(self.build_func, - inp, - self.tmp_dir, - **self.build_kwargs) - futures.append(ret) - - for future in futures: - res = future.get() - - if isinstance(res, Exception): - # timeout or fleet error, return MeasureResult directly - results.append(MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT, - self.timeout, time.time())) - elif res.error is not None: - # instantiation error - if isinstance(res.error, InstantiationError): - results.append(MeasureResult((res.error,), - MeasureErrorNo.INSTANTIATION_ERROR, - res.time_cost, time.time())) - else: - if "InstantiationError" in str(res.error): - msg = str(res.error) - try: - msg = msg.split('\n')[-2].split(": ")[1] - except Exception: # pylint: disable=broad-except - pass - results.append(MeasureResult((InstantiationError(msg),), - MeasureErrorNo.INSTANTIATION_ERROR, - res.time_cost, time.time())) - else: # tvm error - results.append(MeasureResult((res.error,), - MeasureErrorNo.COMPILE_HOST, - res.time_cost, time.time())) - else: - # return BuildResult - results.append(res) - - return results - - -class RPCRunner(Runner): - """Run generated code on remove devices. - This function will ask a RPC Tracker to get device for measurement. - - Parameters - ---------- - timeout: float - The timeout of a compilation - n_parallel: int - The number of tasks run in parallel. "None" will use all cpu cores - key: str - The key of the device registered in the tracker - host: str - The host address of RPC Tracker - port: int - The port of RPC Tracker - number: int - The number of times to run the generated code for taking average. - We call these runs as one `repeat` of measurement. - repeat : int, optional - The number of times to repeat the measurement. - In total, the generated code will be run (1 + number x repeat) times, - where the first "1" is warm up and will be discarded. - The returned result contains `repeat` costs, - each of which is an average of `number` costs. - min_repeat_ms: int, optional - The minimum duration of one `repeat` in milliseconds. - By default, one `repeat` contains `number` runs. If this parameter is set, - the parameters `number` will be dynamically adjusted to meet the - minimum duration requirement of one `repeat`. - i.e., When the run time of one `repeat` falls below this time, the `number` parameter - will be automatically increased. - cooldown_interval: float, optional - The cool down interval between two measurements. - check_correctness: bool, optional - Whether check correctness after measurement. This will use llvm cpu target to - call your template and get the reference output. - This can work for TOPI templates, but may not work for your custom template. 
- """ - def __init__(self, - key, host, port, priority=1, - timeout=10, n_parallel=None, - number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1, - check_correctness=False): - super(RPCRunner, self).__init__(timeout, n_parallel) - - self.key = key - self.host = host - self.port = port - self.priority = priority - self.timeout = timeout - - self.number = number - self.repeat = repeat - self.min_repeat_ms = min_repeat_ms - - self.ref_input = None - self.ref_output = None - self.check_correctness = check_correctness - self.cooldown_interval = cooldown_interval - - self.executor = LocalExecutor() - - def set_task(self, task): - self.task = task - - if check_remote(task.target, self.key, self.host, self.port): - logger.info("Get devices for measurement successfully!") - else: - raise RuntimeError("Cannot get remote devices from the tracker. " - "Please check the status of tracker by " - "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' " - "and make sure you have free devices on the queue status.") - - if self.check_correctness: - # use llvm cpu to generate a reference input/output - # this option works for tuning topi, but might not work for you custom op - with _target.create("llvm"): - s, arg_bufs = task.instantiate(task.config_space.get(0)) - self.ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype) - for x in arg_bufs] - func = build(s, arg_bufs, "llvm") - tvm_buf = [nd.array(x) for x in self.ref_input] - func(*tvm_buf) - self.ref_output = [x.asnumpy() for x in tvm_buf] - - def get_build_kwargs(self): - kwargs = {} - if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys: - remote = request_remote(self.key, self.host, self.port) - ctx = remote.context(str(self.task.target), 0) - max_dims = ctx.max_thread_dimensions - kwargs['check_gpu'] = { - 'max_shared_memory_per_block': ctx.max_shared_memory_per_block, - 'max_threads_per_block': ctx.max_threads_per_block, - 'max_thread_x': max_dims[0], - 'max_thread_y': max_dims[1], - 'max_thread_z': max_dims[2], - } - - if 'cuda' in self.task.target.keys: - kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.')) - - return kwargs - - def run(self, measure_inputs, build_results): - results = [] - remote_args = (self.key, self.host, self.port, self.priority, self.timeout) - - for i in range(0, len(measure_inputs), self.n_parallel): - futures = [] - for measure_inp, build_res in zip(measure_inputs[i:i+self.n_parallel], - build_results[i:i+self.n_parallel]): - ret = self.executor.submit(run_through_rpc, - measure_inp, - build_res, - self.number, - self.repeat, - self.min_repeat_ms, - self.cooldown_interval, - remote_args, - self.ref_input, - self.ref_output) - futures.append(ret) - - for future in futures: - res = future.get() - if isinstance(res, Exception): # executor error or timeout - results.append(MeasureResult((str(res),), MeasureErrorNo.RUN_TIMEOUT, - self.timeout, time.time())) - else: - results.append(res) - - return results - -class LocalRunner(RPCRunner): - """Run generated code on local devices. - - Parameters - ---------- - timeout: float - The timeout of a compilation - number: int - The number of times to run the generated code for taking average. - We call these runs as one `repeat` of measurement. - repeat : int, optional - The number of times to repeat the measurement. - In total, the generated code will be run (1 + number x repeat) times, - where the first one is warm up and will be discarded. 
- The returned result contains `repeat` costs, - each of which is an average of `number` costs. - min_repeat_ms: int, optional - The minimum duration of one `repeat` in milliseconds. - By default, one `repeat` contains `number` runs. If this parameter is set, - the parameters `number` will be dynamically adjusted to meet the - minimum duration requirement of one `repeat`. - i.e., When the run time of one `repeat` falls below this time, the `number` parameter - will be automatically increased. - cooldown_interval: float, optional - The cool down interval between two measurements. - check_correctness: bool, optional - Whether check correctness after measurement. This will use llvm cpu target to - call your template and get the reference output. - This can work for TOPI templates, but may not work for your custom template. - - Note - ---- - This is a "fake" local mode. We start a silent rpc tracker and rpc server - for the user. In this way we reuse timeout/isolation mechanism in RPC infrastructure. - """ - def __init__(self, - timeout=10, - number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1, - check_correctness=False): - super(LocalRunner, self).__init__('', None, None, 0, - timeout=timeout, n_parallel=1, - number=number, repeat=repeat, - min_repeat_ms=min_repeat_ms, - cooldown_interval=cooldown_interval, - check_correctness=check_correctness) - self.tracker = None - self.server = None - - def set_task(self, task): - self.task = task - - from ...rpc.tracker import Tracker - from ...rpc.server import Server - - tracker = Tracker('0.0.0.0', port=9000, port_end=10000, silent=True) - device_key = '$local$device$%d' % tracker.port - server = Server('0.0.0.0', port=9000, port_end=10000, - key=device_key, - use_popen=True, silent=True, - tracker_addr=(tracker.host, tracker.port)) - self.key = device_key - self.host = tracker.host - self.port = tracker.port - - super(LocalRunner, self).set_task(task) - return server, tracker - - -def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None): - """Common part for building a configuration""" - target, task, config = measure_input - - with target: - s, args = task.instantiate(config) - - # check invalidity of template and code hash consistency - if not config.valid(): - raise InstantiationError(config.errors) - - opts = build_option or {} - if check_gpu: # Add verify pass to filter out invalid configs in advance. - opts["add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))] - if cuda_arch: - set_cuda_target_arch(cuda_arch) - - # if target is vta, we need to use vta build - if hasattr(measure_input.target, 'device_name') and \ - measure_input.target.device_name == 'vta': - import vta - func = vta.build(s, args, target_host=task.target_host) - else: - with build_config(**opts): - func = build(s, args, target_host=task.target_host) - return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args) - - -def _wrap_build_func(build_func): - """ - Wrap build_func to a function that can be used in measure. - - Parameters - ---------- - build_func : The compilation function - We expect fcompile to contain an attr "output_format" - - Returns - ------- - wrapped_build_func : function - The wrapped build function - """ - if not hasattr(build_func, "output_format"): - raise AttributeError("Expect build_func to have the attribute output_format.") - output_format = build_func.output_format - - def _wrapped(measure_input, tmp_dir, **kwargs): - """ - Wrapped build func. 
- - Parameters - ---------- - measure_input: MeasureInput - The input of measurement - - tmp_dir: str - The path of temporary directory to export generated library - """ - tic = time.time() - try: - filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % ( - random.getrandbits(64), output_format)) - # TODO(tvm-team) consider linline _build_func_common - func, arg_info = _build_func_common(measure_input, **kwargs) - func.export_library(filename, build_func) - except Exception as e: # pylint: disable=broad-except - return BuildResult(None, None, e, time.time() - tic) - return BuildResult(filename, arg_info, None, time.time() - tic) - return _wrapped - - -def run_through_rpc(measure_input, build_result, - number, repeat, min_repeat_ms, cooldown_interval, - remote_args, ref_input=None, ref_output=None): - """Run a generated library through rpc - - Parameters - ---------- - measure_input: MeasureInput - The raw measure input - build_result: BuildResult - The result returned from Builder. This contains the path to the generated library. - number: int - The number of times to run the generated code for taking average. - We call these runs as one `repeat` of measurement. - repeat : int, optional - The number of times to repeat the measurement. - In total, the generated code will be run (1 + number x repeat) times, - where the first one is warm up and will be discarded. - The returned result contains `repeat` costs, - each of which is an average of `number` costs. - min_repeat_ms: int, optional - The minimum duration of one `repeat` in milliseconds. - By default, one `repeat` contains `number` runs. If this parameter is set, - the parameters `number` will be dynamically adjusted to meet the - minimum duration requirement of one `repeat`. - i.e., When the run time of one `repeat` falls below this time, the `number` parameter - will be automatically increased. - cooldown_interval: float - The cool down interval between two measurements - remote_args: Tuple - The argument for request_remote - ref_input: List of np.ndarray - The reference input used for checking correctness - ref_output: List of np.ndarray - The reference output used for checking correctness - """ - if isinstance(build_result, MeasureResult): - return build_result - - tic = time.time() - errno = MeasureErrorNo.NO_ERROR - try: - # upload built module - remote = request_remote(*remote_args) - # Program the FPGA every single time when targeting VTA - if hasattr(measure_input.target, 'device_name') and \ - measure_input.target.device_name == 'vta': - from vta import program_fpga, reconfig_runtime - program_fpga(remote, None) - reconfig_runtime(remote) - remote.upload(build_result.filename) - func = remote.load_module(os.path.split(build_result.filename)[1]) - ctx = remote.context(str(measure_input.target), 0) - time_f = func.time_evaluator( - func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) - - # set input - if ref_input: - args = [nd.array(x, ctx=ctx) for x in ref_input] - else: - # create empty arrays on the remote device and copy them once. - # This can avoid some memory issues that make the measurement results unreliable. 
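The measurement code in `run_through_rpc` reports `repeat` averaged costs (with `number=4` and `repeat=3` the kernel runs 1 + 4 x 3 = 13 times) and, when more than two costs come back, drops the fastest and slowest before they are summarized. A small sketch of that post-processing step with made-up numbers:

```python
def summarize_costs(costs):
    """Mimic the variance-reduction step: drop the smallest and largest
    repeat averages when there are more than two of them."""
    costs = list(costs)
    if len(costs) > 2:
        costs.sort()
        costs = costs[1:-1]
    return sum(costs) / len(costs)


if __name__ == '__main__':
    # e.g. number=4, repeat=3 means 1 + 4 * 3 = 13 kernel runs; the time
    # evaluator hands back 3 averages (one per repeat), such as:
    repeat_costs = (1.9e-4, 2.1e-4, 9.0e-4)   # seconds, made-up values
    print(summarize_costs(repeat_costs))       # keeps only the middle 2.1e-4
```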
- args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info] - args = [nd.array(x, ctx=ctx) for x in args] - ctx.sync() - - costs = time_f(*args).results - - # clean up remote files - remote.remove(build_result.filename) - remote.remove(os.path.splitext(build_result.filename)[0] + '.so') - remote.remove('') - - if len(costs) > 2: # remove largest and smallest value to reduce variance - costs = list(costs) - costs.sort() - costs = tuple(costs[1:-1]) - - # check correctness of output - if ref_output: - for expected, real in zip(ref_output, args): - if not np.allclose(expected, real.asnumpy(), rtol=1e-4): - logger.warning("Wrong Answer!") - errno = MeasureErrorNo.WRONG_ANSWER - except TVMError as exc: - msg = str(exc) - if "Stack trace returned" in msg: - msg = msg[:msg.index("Stack trace returned")] - if "CUDA Source" in msg: - msg = msg[:msg.index("CUDA Source")] - costs = (RuntimeError(msg[:1024]),) - errno = MeasureErrorNo.RUNTIME_DEVICE - tstamp = time.time() - time.sleep(cooldown_interval) - return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp) - - -def request_remote(device_key, host=None, port=None, priority=1, timeout=60): - """Request a remote session - - Parameters - ---------- - device_key: string - The device key of registered device in tracker - host: host, optional - The host address of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_HOST" - port: int, optional - The port of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_PORT" - priority: int, optional - The priority of this request, larger is more prior - timeout: float, optional - The timeout of this session (units: second) - - Returns - ------ - session: RPCSession - """ - # connect to the tracker - host = host or os.environ['TVM_TRACKER_HOST'] - port = port or int(os.environ['TVM_TRACKER_PORT']) - - tracker = _rpc.connect_tracker(host, port) - remote = tracker.request(device_key, priority=priority, - session_timeout=timeout) - return remote - - -def check_remote(target, device_key, host=None, port=None, priority=100, timeout=10): - """ - Check the availability of a remote device - - Parameters - ---------- - target: Target - The wanted compilation target - device_key: string - device key of registered device in tracker - host: host, optional - The host address of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_HOST" - port: int, optional - The port address of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_PORT" - priority: int, optional - The priority of this request, larger is more prior - timeout: float, optional - The timeout of this check (units: seconds). - - Returns - ------- - available: bool - True if can find available device - """ - def _check(): - remote = request_remote(device_key, host, port, priority) - ctx = remote.context(str(target)) - while not ctx.exist: # wait until we get an available device - pass - t = threading.Thread(target=_check,) - t.start() - t.join(timeout) - return not t.is_alive() - - -@register_func -def tvm_callback_cuda_compile(code): - """use nvcc to generate ptx code for better optimization""" - ptx = nvcc.compile_cuda(code, target="ptx", arch=AutotvmGlobalScope.current.cuda_target_arch) - return ptx - - -def set_cuda_target_arch(arch): - """set target architecture of nvcc compiler - - Parameters - ---------- - arch: str - The argument of nvcc -arch. (e.g. 
"sm_51", "sm_62") - """ - AutotvmGlobalScope.current.cuda_target_arch = arch - - -def gpu_verify_pass(**kwargs): - """Verify the validity of a gpu kernel. - This pass will check memory usage and number of threads per block. - """ - def verify_pass(stmt): - valid = ir_pass.VerifyGPUCode(stmt, kwargs) - if not valid: - raise InstantiationError("Skipped because of invalid gpu kernel") - return stmt - return verify_pass +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,too-many-function-args,too-many-nested-blocks +""" +Functions that run on executor for measurement. + +These functions are responsible for building the tvm module, uploading it to +remote devices, recording the running time costs, and checking the correctness of the output. +""" + +import logging +import shutil +import os +import threading +import time +# Import random directly because it appears dill will pull in 'getrandbits' +# and it will always get the same random number. Using random.getrandbits fixes +import random +from collections import namedtuple +import tempfile + +import numpy as np + +from ... import ir_pass, build, build_config, nd, TVMError, register_func, \ + rpc as _rpc, target as _target +from ...contrib import nvcc, ndk, tar + +from ..util import get_const_tuple +from ..env import AutotvmGlobalScope +from ..task.space import InstantiationError + +from .measure import MeasureResult, MeasureErrorNo, Builder, Runner +from .local_executor import LocalExecutor + +logger = logging.getLogger('autotvm') + +class BuildResult(namedtuple("BuildResult", ('filename', 'arg_info', 'error', 'time_cost'))): + """ + Stores all the necessary inputs for a measurement. + + Parameters + ---------- + filename : str + The filename of generated library + arg_info : Tuple + The shape and dtype information of tvm tensor arguments + error : Exception + The error happens during compilation. + time_cost : float + The time cost of building + """ + +class LocalBuilder(Builder): + """Run compilation on local machine + + Parameters + ---------- + timeout: float + The timeout of a compilation + n_parallel: int + The number of tasks run in parallel. "None" will use all cpu cores + build_func: callable or str + If is 'default', use default build function + If is 'ndk', use function for android ndk + If is callable, use it as custom build function, expect lib_format field. 
+ """ + def __init__(self, timeout=10, n_parallel=None, build_func='default'): + super(LocalBuilder, self).__init__(timeout, n_parallel) + + if isinstance(build_func, str): + if build_func == 'default': + build_func = tar.tar + elif build_func == 'ndk': + build_func = ndk.create_shared + else: + raise ValueError("Invalid build_func" + build_func) + self.build_func = _wrap_build_func(build_func) + self.executor = LocalExecutor(timeout=timeout) + self.tmp_dir = tempfile.mkdtemp() + + def build(self, measure_inputs): + results = [] + + shutil.rmtree(self.tmp_dir) + self.tmp_dir = tempfile.mkdtemp() + + for i in range(0, len(measure_inputs), self.n_parallel): + futures = [] + for inp in measure_inputs[i:i + self.n_parallel]: + ret = self.executor.submit(self.build_func, + inp, + self.tmp_dir, + **self.build_kwargs) + futures.append(ret) + + for future in futures: + res = future.get() + + if isinstance(res, Exception): + # timeout or fleet error, return MeasureResult directly + results.append(MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT, + self.timeout, time.time())) + elif res.error is not None: + # instantiation error + if isinstance(res.error, InstantiationError): + results.append(MeasureResult((res.error,), + MeasureErrorNo.INSTANTIATION_ERROR, + res.time_cost, time.time())) + else: + if "InstantiationError" in str(res.error): + msg = str(res.error) + try: + msg = msg.split('\n')[-2].split(": ")[1] + except Exception: # pylint: disable=broad-except + pass + results.append(MeasureResult((InstantiationError(msg),), + MeasureErrorNo.INSTANTIATION_ERROR, + res.time_cost, time.time())) + else: # tvm error + results.append(MeasureResult((res.error,), + MeasureErrorNo.COMPILE_HOST, + res.time_cost, time.time())) + else: + # return BuildResult + results.append(res) + + return results + + +class RPCRunner(Runner): + """Run generated code on remove devices. + This function will ask a RPC Tracker to get device for measurement. + + Parameters + ---------- + timeout: float + The timeout of a compilation + n_parallel: int + The number of tasks run in parallel. "None" will use all cpu cores + key: str + The key of the device registered in the tracker + host: str + The host address of RPC Tracker + port: int + The port of RPC Tracker + number: int + The number of times to run the generated code for taking average. + We call these runs as one `repeat` of measurement. + repeat : int, optional + The number of times to repeat the measurement. + In total, the generated code will be run (1 + number x repeat) times, + where the first "1" is warm up and will be discarded. + The returned result contains `repeat` costs, + each of which is an average of `number` costs. + min_repeat_ms: int, optional + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. + i.e., When the run time of one `repeat` falls below this time, the `number` parameter + will be automatically increased. + cooldown_interval: float, optional + The cool down interval between two measurements. + check_correctness: bool, optional + Whether check correctness after measurement. This will use llvm cpu target to + call your template and get the reference output. + This can work for TOPI templates, but may not work for your custom template. 
+ """ + def __init__(self, + key, host, port, priority=1, + timeout=10, n_parallel=None, + number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1, + check_correctness=False): + super(RPCRunner, self).__init__(timeout, n_parallel) + + self.key = key + self.host = host + self.port = port + self.priority = priority + self.timeout = timeout + + self.number = number + self.repeat = repeat + self.min_repeat_ms = min_repeat_ms + + self.ref_input = None + self.ref_output = None + self.check_correctness = check_correctness + self.cooldown_interval = cooldown_interval + + self.executor = LocalExecutor() + + def set_task(self, task): + self.task = task + + if check_remote(task.target, self.key, self.host, self.port): + logger.info("Get devices for measurement successfully!") + else: + raise RuntimeError("Cannot get remote devices from the tracker. " + "Please check the status of tracker by " + "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' " + "and make sure you have free devices on the queue status.") + + if self.check_correctness: + # use llvm cpu to generate a reference input/output + # this option works for tuning topi, but might not work for you custom op + with _target.create("llvm"): + s, arg_bufs = task.instantiate(task.config_space.get(0)) + self.ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype) + for x in arg_bufs] + func = build(s, arg_bufs, "llvm") + tvm_buf = [nd.array(x) for x in self.ref_input] + func(*tvm_buf) + self.ref_output = [x.asnumpy() for x in tvm_buf] + + def get_build_kwargs(self): + kwargs = {} + if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys: + remote = request_remote(self.key, self.host, self.port) + ctx = remote.context(str(self.task.target), 0) + max_dims = ctx.max_thread_dimensions + kwargs['check_gpu'] = { + 'max_shared_memory_per_block': ctx.max_shared_memory_per_block, + 'max_threads_per_block': ctx.max_threads_per_block, + 'max_thread_x': max_dims[0], + 'max_thread_y': max_dims[1], + 'max_thread_z': max_dims[2], + } + + if 'cuda' in self.task.target.keys: + kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.')) + + return kwargs + + def run(self, measure_inputs, build_results): + results = [] + remote_args = (self.key, self.host, self.port, self.priority, self.timeout) + + for i in range(0, len(measure_inputs), self.n_parallel): + futures = [] + for measure_inp, build_res in zip(measure_inputs[i:i+self.n_parallel], + build_results[i:i+self.n_parallel]): + ret = self.executor.submit(run_through_rpc, + measure_inp, + build_res, + self.number, + self.repeat, + self.min_repeat_ms, + self.cooldown_interval, + remote_args, + self.ref_input, + self.ref_output) + futures.append(ret) + + for future in futures: + res = future.get() + if isinstance(res, Exception): # executor error or timeout + results.append(MeasureResult((str(res),), MeasureErrorNo.RUN_TIMEOUT, + self.timeout, time.time())) + else: + results.append(res) + + return results + +class LocalRunner(RPCRunner): + """Run generated code on local devices. + + Parameters + ---------- + timeout: float + The timeout of a compilation + number: int + The number of times to run the generated code for taking average. + We call these runs as one `repeat` of measurement. + repeat : int, optional + The number of times to repeat the measurement. + In total, the generated code will be run (1 + number x repeat) times, + where the first one is warm up and will be discarded. 
+ The returned result contains `repeat` costs, + each of which is an average of `number` costs. + min_repeat_ms: int, optional + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. + i.e., When the run time of one `repeat` falls below this time, the `number` parameter + will be automatically increased. + cooldown_interval: float, optional + The cool down interval between two measurements. + check_correctness: bool, optional + Whether check correctness after measurement. This will use llvm cpu target to + call your template and get the reference output. + This can work for TOPI templates, but may not work for your custom template. + + Note + ---- + This is a "fake" local mode. We start a silent rpc tracker and rpc server + for the user. In this way we reuse timeout/isolation mechanism in RPC infrastructure. + """ + def __init__(self, + timeout=10, + number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1, + check_correctness=False): + super(LocalRunner, self).__init__('', None, None, 0, + timeout=timeout, n_parallel=1, + number=number, repeat=repeat, + min_repeat_ms=min_repeat_ms, + cooldown_interval=cooldown_interval, + check_correctness=check_correctness) + self.tracker = None + self.server = None + + def set_task(self, task): + self.task = task + + from ...rpc.tracker import Tracker + from ...rpc.server import Server + + # Windows will not let you connect to 0.0.0.0 + local_address = '0.0.0.0' if os.name != 'nt' else '127.0.0.1' + + tracker = Tracker(local_address, port=9000, port_end=10000, silent=True) + device_key = '$local$device$%d' % tracker.port + server = Server(local_address, port=9000, port_end=10000, + key=device_key, + use_popen=True, silent=True, + tracker_addr=(tracker.host, tracker.port)) + self.key = device_key + self.host = tracker.host + self.port = tracker.port + + super(LocalRunner, self).set_task(task) + return server, tracker + + +def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None): + """Common part for building a configuration""" + target, task, config = measure_input + + with target: + s, args = task.instantiate(config) + + # check invalidity of template and code hash consistency + if not config.valid(): + raise InstantiationError(config.errors) + + opts = build_option or {} + if check_gpu: # Add verify pass to filter out invalid configs in advance. + opts["add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))] + if cuda_arch: + set_cuda_target_arch(cuda_arch) + + # if target is vta, we need to use vta build + if hasattr(measure_input.target, 'device_name') and \ + measure_input.target.device_name == 'vta': + import vta + func = vta.build(s, args, target_host=task.target_host) + else: + with build_config(**opts): + func = build(s, args, target_host=task.target_host) + return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args) + + +def _wrap_build_func(build_func): + """ + Wrap build_func to a function that can be used in measure. 
+ + Parameters + ---------- + build_func : The compilation function + We expect fcompile to contain an attr "output_format" + + Returns + ------- + wrapped_build_func : function + The wrapped build function + """ + if not hasattr(build_func, "output_format"): + raise AttributeError("Expect build_func to have the attribute output_format.") + output_format = build_func.output_format + + def _wrapped(measure_input, tmp_dir, **kwargs): + """ + Wrapped build func. + + Parameters + ---------- + measure_input: MeasureInput + The input of measurement + + tmp_dir: str + The path of temporary directory to export generated library + """ + tic = time.time() + try: + filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % ( + random.getrandbits(64), output_format)) + # TODO(tvm-team) consider linline _build_func_common + func, arg_info = _build_func_common(measure_input, **kwargs) + func.export_library(filename, build_func) + except Exception as e: # pylint: disable=broad-except + return BuildResult(None, None, e, time.time() - tic) + return BuildResult(filename, arg_info, None, time.time() - tic) + return _wrapped + + +def run_through_rpc(measure_input, build_result, + number, repeat, min_repeat_ms, cooldown_interval, + remote_args, ref_input=None, ref_output=None): + """Run a generated library through rpc + + Parameters + ---------- + measure_input: MeasureInput + The raw measure input + build_result: BuildResult + The result returned from Builder. This contains the path to the generated library. + number: int + The number of times to run the generated code for taking average. + We call these runs as one `repeat` of measurement. + repeat : int, optional + The number of times to repeat the measurement. + In total, the generated code will be run (1 + number x repeat) times, + where the first one is warm up and will be discarded. + The returned result contains `repeat` costs, + each of which is an average of `number` costs. + min_repeat_ms: int, optional + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. + i.e., When the run time of one `repeat` falls below this time, the `number` parameter + will be automatically increased. + cooldown_interval: float + The cool down interval between two measurements + remote_args: Tuple + The argument for request_remote + ref_input: List of np.ndarray + The reference input used for checking correctness + ref_output: List of np.ndarray + The reference output used for checking correctness + """ + if isinstance(build_result, MeasureResult): + return build_result + + tic = time.time() + errno = MeasureErrorNo.NO_ERROR + try: + # upload built module + remote = request_remote(*remote_args) + # Program the FPGA every single time when targeting VTA + if hasattr(measure_input.target, 'device_name') and \ + measure_input.target.device_name == 'vta': + from vta import program_fpga, reconfig_runtime + program_fpga(remote, None) + reconfig_runtime(remote) + remote.upload(build_result.filename) + func = remote.load_module(os.path.split(build_result.filename)[1]) + ctx = remote.context(str(measure_input.target), 0) + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) + + # set input + if ref_input: + args = [nd.array(x, ctx=ctx) for x in ref_input] + else: + # create empty arrays on the remote device and copy them once. 
+ # This can avoid some memory issues that make the measurement results unreliable. + args = [nd.empty(x[0], dtype=x[1], ctx=ctx) for x in build_result.arg_info] + args = [nd.array(x, ctx=ctx) for x in args] + ctx.sync() + + costs = time_f(*args).results + + # clean up remote files + remote.remove(build_result.filename) + remote.remove(os.path.splitext(build_result.filename)[0] + '.so') + remote.remove('') + + if len(costs) > 2: # remove largest and smallest value to reduce variance + costs = list(costs) + costs.sort() + costs = tuple(costs[1:-1]) + + # check correctness of output + if ref_output: + for expected, real in zip(ref_output, args): + if not np.allclose(expected, real.asnumpy(), rtol=1e-4): + logger.warning("Wrong Answer!") + errno = MeasureErrorNo.WRONG_ANSWER + except TVMError as exc: + msg = str(exc) + if "Stack trace returned" in msg: + msg = msg[:msg.index("Stack trace returned")] + if "CUDA Source" in msg: + msg = msg[:msg.index("CUDA Source")] + costs = (RuntimeError(msg[:1024]),) + errno = MeasureErrorNo.RUNTIME_DEVICE + tstamp = time.time() + time.sleep(cooldown_interval) + return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp) + + +def request_remote(device_key, host=None, port=None, priority=1, timeout=60): + """Request a remote session + + Parameters + ---------- + device_key: string + The device key of registered device in tracker + host: host, optional + The host address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_HOST" + port: int, optional + The port of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_PORT" + priority: int, optional + The priority of this request, larger is more prior + timeout: float, optional + The timeout of this session (units: second) + + Returns + ------ + session: RPCSession + """ + # connect to the tracker + host = host or os.environ['TVM_TRACKER_HOST'] + port = port or int(os.environ['TVM_TRACKER_PORT']) + + tracker = _rpc.connect_tracker(host, port) + remote = tracker.request(device_key, priority=priority, + session_timeout=timeout) + return remote + + +def check_remote(target, device_key, host=None, port=None, priority=100, timeout=10): + """ + Check the availability of a remote device + + Parameters + ---------- + target: Target + The wanted compilation target + device_key: string + device key of registered device in tracker + host: host, optional + The host address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_HOST" + port: int, optional + The port address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_PORT" + priority: int, optional + The priority of this request, larger is more prior + timeout: float, optional + The timeout of this check (units: seconds). + + Returns + ------- + available: bool + True if can find available device + """ + def _check(): + remote = request_remote(device_key, host, port, priority) + ctx = remote.context(str(target)) + while not ctx.exist: # wait until we get an available device + pass + t = threading.Thread(target=_check,) + t.start() + t.join(timeout) + return not t.is_alive() + + +@register_func +def tvm_callback_cuda_compile(code): + """use nvcc to generate ptx code for better optimization""" + ptx = nvcc.compile_cuda(code, target="ptx", arch=AutotvmGlobalScope.current.cuda_target_arch) + return ptx + + +def set_cuda_target_arch(arch): + """set target architecture of nvcc compiler + + Parameters + ---------- + arch: str + The argument of nvcc -arch. (e.g. 
"sm_51", "sm_62") + """ + AutotvmGlobalScope.current.cuda_target_arch = arch + + +def gpu_verify_pass(**kwargs): + """Verify the validity of a gpu kernel. + This pass will check memory usage and number of threads per block. + """ + def verify_pass(stmt): + valid = ir_pass.VerifyGPUCode(stmt, kwargs) + if not valid: + raise InstantiationError("Skipped because of invalid gpu kernel") + return stmt + return verify_pass diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 00654446e5b5..7faaeb08b429 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -1,418 +1,454 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""RPC server implementation. - -Note ----- -Server is TCP based with the following protocol: -- Initial handshake to the peer - - [RPC_MAGIC, keysize(int32), key-bytes] -- The key is in format - - {server|client}:device-type[:random-key] [-timeout=timeout] -""" -# pylint: disable=invalid-name - -from __future__ import absolute_import - -import os -import ctypes -import socket -import select -import struct -import logging -import multiprocessing -import subprocess -import time -import sys -import signal - -from .._ffi.function import register_func -from .._ffi.base import py_str -from .._ffi.libinfo import find_lib_path -from ..module import load as _load_module -from ..contrib import util -from . import base -from . 
base import TrackerCode - -logger = logging.getLogger('RPCServer') - -def _server_env(load_library, work_path=None): - """Server environment function return temp dir""" - if work_path: - temp = work_path - else: - temp = util.tempdir() - - # pylint: disable=unused-variable - @register_func("tvm.rpc.server.workpath") - def get_workpath(path): - return temp.relpath(path) - - @register_func("tvm.rpc.server.load_module", override=True) - def load_module(file_name): - """Load module from remote side.""" - path = temp.relpath(file_name) - m = _load_module(path) - logger.info("load_module %s", path) - return m - - libs = [] - load_library = load_library.split(":") if load_library else [] - for file_name in load_library: - file_name = find_lib_path(file_name)[0] - libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) - logger.info("Load additional library %s", file_name) - temp.libs = libs - return temp - -def _serve_loop(sock, addr, load_library, work_path=None): - """Server loop""" - sockfd = sock.fileno() - temp = _server_env(load_library, work_path) - base._ServerLoop(sockfd) - if not work_path: - temp.remove() - logger.info("Finish serving %s", addr) - -def _parse_server_opt(opts): - # parse client options - ret = {} - for kv in opts: - if kv.startswith("-timeout="): - ret["timeout"] = float(kv[9:]) - return ret - -def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): - """Listening loop of the server master.""" - def _accept_conn(listen_sock, tracker_conn, ping_period=2): - """Accept connection from the other places. - - Parameters - ---------- - listen_sock: Socket - The socket used by listening process. - - tracker_conn : connnection to tracker - Tracker connection - - ping_period : float, optional - ping tracker every k seconds if no connection is accepted. - """ - old_keyset = set() - # Report resource to tracker - if tracker_conn: - matchkey = base.random_key(rpc_key + ":") - base.sendjson(tracker_conn, - [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr]) - assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS - else: - matchkey = rpc_key - - unmatch_period_count = 0 - unmatch_timeout = 4 - # Wait until we get a valid connection - while True: - if tracker_conn: - trigger = select.select([listen_sock], [], [], ping_period) - if not listen_sock in trigger[0]: - base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS]) - pending_keys = base.recvjson(tracker_conn) - old_keyset.add(matchkey) - # if match key not in pending key set - # it means the key is acquired by a client but not used. 
- if matchkey not in pending_keys: - unmatch_period_count += 1 - else: - unmatch_period_count = 0 - # regenerate match key if key is acquired but not used for a while - if unmatch_period_count * ping_period > unmatch_timeout + ping_period: - logger.info("no incoming connections, regenerate key ...") - matchkey = base.random_key(rpc_key + ":", old_keyset) - base.sendjson(tracker_conn, - [TrackerCode.PUT, rpc_key, (port, matchkey), - custom_addr]) - assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS - unmatch_period_count = 0 - continue - conn, addr = listen_sock.accept() - magic = struct.unpack(" max_retry: - raise RuntimeError("Maximum retry error: last error: %s" % str(err)) - time.sleep(retry_period) - -def _popen(cmd): - proc = subprocess.Popen(cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - env=os.environ) - (out, _) = proc.communicate() - if proc.returncode != 0: - msg = "Server invoke error:\n" - msg += out - raise RuntimeError(msg) - - -class Server(object): - """Start RPC server on a separate process. - - This is a simple python implementation based on multi-processing. - It is also possible to implement a similar C based server with - TVM runtime which does not depend on the python. - - Parameters - ---------- - host : str - The host url of the server. - - port : int - The port to be bind to - - port_end : int, optional - The end port to search - - is_proxy : bool, optional - Whether the address specified is a proxy. - If this is true, the host and port actually corresponds to the - address of the proxy server. - - use_popen : bool, optional - Whether to use Popen to start a fresh new process instead of fork. - This is recommended to switch on if we want to do local RPC demonstration - for GPU devices to avoid fork safety issues. - - tracker_addr: Tuple (str, int) , optional - The address of RPC Tracker in tuple(host, ip) format. - If is not None, the server will register itself to the tracker. - - key : str, optional - The key used to identify the device type in tracker. - - load_library : str, optional - List of additional libraries to be loaded during execution. - - custom_addr: str, optional - Custom IP Address to Report to RPC Tracker - - silent: bool, optional - Whether run this server in silent mode. - """ - def __init__(self, - host, - port=9091, - port_end=9199, - is_proxy=False, - use_popen=False, - tracker_addr=None, - key="", - load_library=None, - custom_addr=None, - silent=False): - try: - if base._ServerLoop is None: - raise RuntimeError("Please compile with USE_RPC=1") - except NameError: - raise RuntimeError("Please compile with USE_RPC=1") - self.host = host - self.port = port - self.libs = [] - self.custom_addr = custom_addr - self.use_popen = use_popen - - if silent: - logger.setLevel(logging.ERROR) - - if use_popen: - cmd = [sys.executable, - "-m", "tvm.exec.rpc_server", - "--host=%s" % host, - "--port=%s" % port] - if tracker_addr: - assert key - cmd += ["--tracker=%s:%d" % tracker_addr, - "--key=%s" % key] - if load_library: - cmd += ["--load-library", load_library] - if custom_addr: - cmd += ["--custom-addr", custom_addr] - if silent: - cmd += ["--silent"] - - # prexec_fn is not thread safe and may result in deadlock. - # python 3.2 introduced the start_new_session parameter as - # an alternative to the common use case of - # prexec_fn=os.setsid. Once the minimum version of python - # supported by TVM reaches python 3.2 this code can be - # rewritten in favour of start_new_session. In the - # interim, stop the pylint diagnostic. 
- # - # pylint: disable=subprocess-popen-preexec-fn - self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid) - time.sleep(0.5) - elif not is_proxy: - sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) - self.port = None - for my_port in range(port, port_end): - try: - sock.bind((host, my_port)) - self.port = my_port - break - except socket.error as sock_err: - if sock_err.errno in [98, 48]: - continue - else: - raise sock_err - if not self.port: - raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) - logger.info("bind to %s:%d", host, self.port) - sock.listen(1) - self.sock = sock - self.proc = multiprocessing.Process( - target=_listen_loop, args=( - self.sock, self.port, key, tracker_addr, load_library, - self.custom_addr)) - self.proc.deamon = True - self.proc.start() - else: - self.proc = multiprocessing.Process( - target=_connect_proxy_loop, args=((host, port), key, load_library)) - self.proc.deamon = True - self.proc.start() - - def terminate(self): - """Terminate the server process""" - if self.use_popen: - if self.proc: - os.killpg(self.proc.pid, signal.SIGTERM) - self.proc = None - else: - if self.proc: - self.proc.terminate() - self.proc = None - - def __del__(self): - self.terminate() +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""RPC server implementation. + +Note +---- +Server is TCP based with the following protocol: +- Initial handshake to the peer + - [RPC_MAGIC, keysize(int32), key-bytes] +- The key is in format + - {server|client}:device-type[:random-key] [-timeout=timeout] +""" +# pylint: disable=invalid-name + +from __future__ import absolute_import + +import os +import ctypes +import socket +import select +import struct +import logging +import multiprocessing +import subprocess +import time +import sys +import signal + +if os.name == 'nt': + from pathos.helpers import ProcessPool + import threading + +from .._ffi.function import register_func +from .._ffi.base import py_str +from .._ffi.libinfo import find_lib_path +from ..module import load as _load_module +from ..contrib import util +from . import base +from . 
base import TrackerCode + +logger = logging.getLogger('RPCServer') + +def _server_env(load_library, work_path=None): + """Server environment function return temp dir""" + if work_path: + temp = work_path + else: + temp = util.tempdir() + + # pylint: disable=unused-variable + @register_func("tvm.rpc.server.workpath") + def get_workpath(path): + return temp.relpath(path) + + @register_func("tvm.rpc.server.load_module", override=True) + def load_module(file_name): + """Load module from remote side.""" + path = temp.relpath(file_name) + m = _load_module(path) + logger.info("load_module %s", path) + return m + + libs = [] + load_library = load_library.split(":") if load_library else [] + for file_name in load_library: + file_name = find_lib_path(file_name)[0] + libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) + logger.info("Load additional library %s", file_name) + temp.libs = libs + return temp + +def _serve_loop(sock, addr, load_library, work_path=None): + """Server loop""" + sockfd = sock.fileno() + temp = _server_env(load_library, work_path) + base._ServerLoop(sockfd) + if not work_path: + temp.remove() + logger.info("Finish serving %s", addr) + +def _parse_server_opt(opts): + # parse client options + ret = {} + for kv in opts: + if kv.startswith("-timeout="): + ret["timeout"] = float(kv[9:]) + return ret + +def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): + """Listening loop of the server master.""" + def _accept_conn(listen_sock, tracker_conn, ping_period=2): + """Accept connection from the other places. + + Parameters + ---------- + listen_sock: Socket + The socket used by listening process. + + tracker_conn : connnection to tracker + Tracker connection + + ping_period : float, optional + ping tracker every k seconds if no connection is accepted. + """ + old_keyset = set() + # Report resource to tracker + if tracker_conn: + matchkey = base.random_key(rpc_key + ":") + base.sendjson(tracker_conn, + [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr]) + assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS + else: + matchkey = rpc_key + + unmatch_period_count = 0 + unmatch_timeout = 4 + # Wait until we get a valid connection + while True: + if tracker_conn: + trigger = select.select([listen_sock], [], [], ping_period) + if not listen_sock in trigger[0]: + base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS]) + pending_keys = base.recvjson(tracker_conn) + old_keyset.add(matchkey) + # if match key not in pending key set + # it means the key is acquired by a client but not used. 
+ if matchkey not in pending_keys: + unmatch_period_count += 1 + else: + unmatch_period_count = 0 + # regenerate match key if key is acquired but not used for a while + if unmatch_period_count * ping_period > unmatch_timeout + ping_period: + logger.info("no incoming connections, regenerate key ...") + matchkey = base.random_key(rpc_key + ":", old_keyset) + base.sendjson(tracker_conn, + [TrackerCode.PUT, rpc_key, (port, matchkey), + custom_addr]) + assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS + unmatch_period_count = 0 + continue + conn, addr = listen_sock.accept() + magic = struct.unpack(" max_retry: + raise RuntimeError("Maximum retry error: last error: %s" % str(err)) + time.sleep(retry_period) + +def _popen(cmd): + proc = subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=os.environ) + (out, _) = proc.communicate() + if proc.returncode != 0: + msg = "Server invoke error:\n" + msg += out + raise RuntimeError(msg) + +if os.name == 'nt': + def start_server_from_pool(host, port, port_end, is_proxy, use_popen, + tracker_addr, key, load_library, custom_addr, silent): + def run(): + server = Server(host, + port, + port_end, + key=key, + tracker_addr=tracker_addr, + load_library=load_library, + custom_addr=custom_addr, + silent=silent) + t = threading.Thread(target=run) + t.daemon = True + t.start() + +class Server(object): + """Start RPC server on a separate process. + + This is a simple python implementation based on multi-processing. + It is also possible to implement a similar C based server with + TVM runtime which does not depend on the python. + + Parameters + ---------- + host : str + The host url of the server. + + port : int + The port to be bind to + + port_end : int, optional + The end port to search + + is_proxy : bool, optional + Whether the address specified is a proxy. + If this is true, the host and port actually corresponds to the + address of the proxy server. + + use_popen : bool, optional + Whether to use Popen to start a fresh new process instead of fork. + This is recommended to switch on if we want to do local RPC demonstration + for GPU devices to avoid fork safety issues. + + tracker_addr: Tuple (str, int) , optional + The address of RPC Tracker in tuple(host, ip) format. + If is not None, the server will register itself to the tracker. + + key : str, optional + The key used to identify the device type in tracker. + + load_library : str, optional + List of additional libraries to be loaded during execution. + + custom_addr: str, optional + Custom IP Address to Report to RPC Tracker + + silent: bool, optional + Whether run this server in silent mode. 
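A minimal usage sketch for the Server and Tracker pair documented above, mirroring what LocalRunner.set_task does earlier in this patch; the port range and device key are placeholders, and the loopback address is used because, as noted earlier, Windows cannot connect to 0.0.0.0.

from tvm.rpc.tracker import Tracker
from tvm.rpc.server import Server

host = '127.0.0.1'   # 0.0.0.0 is fine on *nix, but Windows cannot connect to it
tracker = Tracker(host, port=9000, port_end=10000, silent=True)
server = Server(host, port=9000, port_end=10000,
                key='example_device',
                use_popen=True, silent=True,
                tracker_addr=(tracker.host, tracker.port))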
+ """ + def __init__(self, + host, + port=9091, + port_end=9199, + is_proxy=False, + use_popen=False, + tracker_addr=None, + key="", + load_library=None, + custom_addr=None, + silent=False): + try: + if base._ServerLoop is None: + raise RuntimeError("Please compile with USE_RPC=1") + except NameError: + raise RuntimeError("Please compile with USE_RPC=1") + self.host = host + self.port = port + self.libs = [] + self.custom_addr = custom_addr + self.use_popen = use_popen + self.proc = None + + if silent: + logger.setLevel(logging.ERROR) + + if use_popen: + cmd = [sys.executable, + "-m", "tvm.exec.rpc_server", + "--host=%s" % host, + "--port=%s" % port] + if tracker_addr: + assert key + cmd += ["--tracker=%s:%d" % tracker_addr, + "--key=%s" % key] + if load_library: + cmd += ["--load-library", load_library] + if custom_addr: + cmd += ["--custom-addr", custom_addr] + if silent: + cmd += ["--silent"] + + if os.name == 'nt': + self.proc = ProcessPool(1) + self.proc.apply(start_server_from_pool, args=(host, port, port_end, is_proxy, + use_popen, tracker_addr, key, load_library, custom_addr, silent)) + else: + # prexec_fn is not thread safe and may result in deadlock. + # python 3.2 introduced the start_new_session parameter as + # an alternative to the common use case of + # prexec_fn=os.setsid. Once the minimum version of python + # supported by TVM reaches python 3.2 this code can be + # rewritten in favour of start_new_session. In the + # interim, stop the pylint diagnostic. + # + # pylint: disable=subprocess-popen-preexec-fn + self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid) + time.sleep(0.5) + elif not is_proxy: + sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) + self.port = None + for my_port in range(port, port_end): + try: + sock.bind((host, my_port)) + self.port = my_port + break + except socket.error as sock_err: + sock_errno = sock_err.errno + if os.name == 'nt': + # Win32 socket codes are offset 10000 + sock_errno -= 10000 + if sock_errno in [98, 48]: + continue + else: + raise sock_err + if not self.port: + raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) + logger.info("bind to %s:%d", host, self.port) + sock.listen(1) + self.sock = sock + + if os.name == 'nt': + _listen_loop(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr) + else: + self.proc = multiprocessing.Process( + target=_listen_loop, args=( + self.sock, self.port, key, tracker_addr, load_library, + self.custom_addr)) + self.proc.deamon = True + self.proc.start() + else: + self.proc = multiprocessing.Process( + target=_connect_proxy_loop, args=((host, port), key, load_library)) + self.proc.deamon = True + self.proc.start() + + def terminate(self): + """Terminate the server process""" + if self.use_popen: + if self.proc: + if os.name == 'nt': + self.proc.terminate() + else: + os.killpg(self.proc.pid, signal.SIGTERM) + self.proc = None + else: + if self.proc: + self.proc.terminate() + self.proc = None + + def __del__(self): + self.terminate() diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index b9b29a7fe4a1..609edd55a51b 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -49,7 +49,10 @@ import errno import struct import json - +import os +if os.name == 'nt': + import threading + from pathos.helpers import ProcessPool try: from tornado import ioloop from . 
import tornado_util @@ -352,8 +355,16 @@ def run(self): def _tracker_server(listen_sock, stop_key): handler = TrackerServerHandler(listen_sock, stop_key) - handler.run() + if os.name != 'nt': + handler.run() + else: + def run(): + handler.run() + + t = threading.Thread(target=run) + t.daemon = True + t.start() class Tracker(object): """Start RPC tracker on a seperate process. @@ -381,7 +392,7 @@ def __init__(self, silent=False): if silent: logger.setLevel(logging.WARN) - + self.proc = None sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) self.port = None self.stop_key = base.random_key("tracker") @@ -391,7 +402,11 @@ def __init__(self, self.port = my_port break except socket.error as sock_err: - if sock_err.errno in [98, 48]: + sock_errno = sock_err.errno + if os.name == 'nt': + # Win32 socket codes are offset 10000 + sock_errno -= 10000 + if sock_errno in [98, 48]: continue else: raise sock_err @@ -399,9 +414,14 @@ def __init__(self, raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) logger.info("bind to %s:%d", host, self.port) sock.listen(1) - self.proc = multiprocessing.Process( - target=_tracker_server, args=(sock, self.stop_key)) - self.proc.start() + + if os.name == 'nt': + self.proc = ProcessPool(1) + self.proc.apply_async(_tracker_server, args=(sock, self.stop_key)).get() + else: + self.proc = multiprocessing.Process( + target=_tracker_server, args=(sock, self.stop_key)) + self.proc.start() self.host = host # close the socket on this process sock.close() @@ -419,12 +439,16 @@ def _stop_tracker(self): def terminate(self): """Terminate the server process""" if self.proc: - if self.proc.is_alive(): - self._stop_tracker() - self.proc.join(1) - if self.proc.is_alive(): - logger.info("Terminating Tracker Server...") - self.proc.terminate() + if os.name =='nt': + self.proc.close() + self.proc.join() + else: + if self.proc.is_alive(): + self._stop_tracker() + self.proc.join(1) + if self.proc.is_alive(): + logger.info("Terminating Tracker Server...") + self.proc.terminate() self.proc = None def __del__(self): From 7af065e430cba74ff2b80a336e9bea9711e6812f Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Mon, 11 Nov 2019 02:36:17 -0800 Subject: [PATCH 03/33] Fix line endings --- python/tvm/autotvm/measure/local_executor.py | 466 +++++++++---------- 1 file changed, 233 insertions(+), 233 deletions(-) diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py index d35c42b5dc35..a2f795576e92 100644 --- a/python/tvm/autotvm/measure/local_executor.py +++ b/python/tvm/autotvm/measure/local_executor.py @@ -1,233 +1,233 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
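A short note on the Win32 socket errno adjustment in the server.py and tracker.py hunks above: Winsock reports WSAEADDRINUSE as 10048, and subtracting 10000 maps it onto the BSD EADDRINUSE value 48 that the existing [98, 48] check looks for. A tiny sketch, assuming the usual Winsock numbering:

WSAEADDRINUSE = 10048   # reported by Winsock when a port is already bound

def normalize_win32_sock_errno(code):
    # Win32 socket codes are offset by 10000 from their BSD counterparts
    return code - 10000

assert normalize_win32_sock_errno(WSAEADDRINUSE) in [98, 48]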
-"""Local based implementation of the executor using multiprocessing""" - -import signal -import os - -if os.name == 'nt': - import queue as thread_queue - import threading - # Pathos uses dill, which can pickle things like functions - from pathos.helpers import ProcessPool - # On Windows, there is no fork(), a 'multiprocessing.Process' - # or ProcessPool has to 'build up' the script from scratch, we - # set these environment variables so each python.exe process - # does not allocate unneeded threads - os.environ['OMP_NUM_THREADS'] = "1" - os.environ['TVM_NUM_THREADS'] = "1" - # numpy seems to honor this - os.environ['MKL_NUM_THREADS'] = "1" - - # Since there is no fork() on Windows, to mitigate performance impact - # we will use a process pool for executers, vs the *nix based systems - # that will fork() a new process for each executor - executor_pool = None - -from multiprocessing import Process, Queue, cpu_count -try: - from queue import Empty -except ImportError: - from Queue import Empty - -try: - import psutil -except ImportError: - psutil = None - -from . import executor - - -def kill_child_processes(parent_pid, sig=signal.SIGTERM): - """kill all child processes recursively""" - try: - parent = psutil.Process(parent_pid) - except psutil.NoSuchProcess: - return - children = parent.children(recursive=True) - for process in children: - try: - process.send_signal(sig) - except psutil.NoSuchProcess: - return - -def _execute_func(func, queue, args, kwargs): - """execute function and return the result or exception to a queue""" - try: - res = func(*args, **kwargs) - except Exception as exc: # pylint: disable=broad-except - res = exc - queue.put(res) - - -def call_with_timeout(queue, timeout, func, args, kwargs): - """A wrapper to support timeout of a function call""" - - # start a new process for timeout (cannot use thread because we have c function) - p = Process(target=_execute_func, args=(func, queue, args, kwargs)) - p.start() - p.join(timeout=timeout) - - queue.put(executor.TimeoutError()) - - kill_child_processes(p.pid) - p.terminate() - p.join() - -if os.name == 'nt': - def call_from_pool(func, args, kwargs, timeout, env): - """A wrapper to support timeout of a function call for a pool process""" - - # Restore environment variables from parent - for key, val in env.items(): - os.environ[key] = val - - queue = thread_queue.Queue(2) - - # We use a thread here for Windows, because starting up a new Process can be heavy - # This isn't as clean as the *nix implementation, which can kill a process that - # has timed out - thread = threading.Thread(target=_execute_func, args=(func, queue, args, kwargs)) - thread.start() - thread.join(timeout=timeout) - - queue.put(executor.TimeoutError()) - - res = queue.get() - return res - -class LocalFuture(executor.Future): - """Local wrapper for the future - - Parameters - ---------- - process: multiprocessing.Process - process for running this task - queue: multiprocessing.Queue - queue for receiving the result of this task - """ - def __init__(self, process, queue): - self._done = False - self._process = process - self._queue = queue - - def done(self): - self._done = self._done or not self._queue.empty() - return self._done - - def get(self, timeout=None): - try: - res = self._queue.get(block=True, timeout=timeout) - except Empty: - raise executor.TimeoutError() - if self._process.is_alive(): - kill_child_processes(self._process.pid) - self._process.terminate() - self._process.join() - self._queue.close() - self._queue.join_thread() - self._done = 
True - del self._queue - del self._process - return res - - -class LocalFutureNoFork(executor.Future): - """Local wrapper for the future. - This is a none-fork version of LocalFuture. - Use this for the runtime that does not support fork (like cudnn) - """ - def __init__(self, result): - self._result = result - - def done(self): - return True - - def get(self, timeout=None): - return self._result - -if os.name == 'nt': - class LocalFuturePool(executor.Future): - """Local wrapper for the future using a Process pool - - Parameters - ---------- - thread: threading.Thread - Thread for running this task - pool_results: result from Pool.apply_async - queue for receiving the result of this task - """ - def __init__(self, pool_results): - self._done = False - self._pool_results = pool_results - - def done(self): - return self._done - - def get(self, timeout=None): - try: - res = self._pool_results.get(timeout=timeout) - except Empty: - raise executor.TimeoutError() - self._done = True - return res - -class LocalExecutor(executor.Executor): - """Local executor that runs workers on the same machine with multiprocessing. - - Parameters - ---------- - timeout: float, optional - timeout of a job. If time is out. A TimeoutError will be returned (not raised) - do_fork: bool, optional - For some runtime systems that do not support fork after initialization - (e.g. cuda runtime, cudnn). Set this to False if you have used these runtime - before submitting jobs. - """ - def __init__(self, timeout=None, do_fork=True): - self.timeout = timeout or executor.Executor.DEFAULT_TIMEOUT - self.do_fork = do_fork - - if self.do_fork: - if not psutil: - raise RuntimeError("Python package psutil is missing. " - "please try `pip install psutil`") - - def submit(self, func, *args, **kwargs): - if not self.do_fork: - return LocalFutureNoFork(func(*args, **kwargs)) - - if os.name != 'nt': - queue = Queue(2) - process = Process(target=call_with_timeout, - args=(queue, self.timeout, func, args, kwargs)) - process.start() - return LocalFuture(process, queue) - else: - global executor_pool - - if executor_pool is None: - # We use a static pool for executor processes because Process.start(entry) - # is so slow on Windows, we lose a lot of parallelism. - # Right now cpu_count() is used, which isn't optimal from a user configuration - # perspective, but is reasonable at this time. - executor_pool = ProcessPool(cpu_count() * 2) - - # Windows seemed to be missing some valuable environ variables - # on the pool's process side. We might be able to get away with - # just sending the PATH variable, but for now, we just clone our env - return LocalFuturePool(executor_pool.apply_async(call_from_pool, (func, args, kwargs, self.timeout, os.environ.copy()))) +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
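One aside on the Windows branch of LocalExecutor.submit above, which feeds work to a pathos ProcessPool: pathos serializes jobs with dill, which, unlike the pickle module used by plain multiprocessing under Windows' spawn start method, can handle functions and other callables. A minimal sketch of the pattern, assuming pathos is installed; the worker and its arguments are illustrative placeholders.

import os
from multiprocessing import cpu_count
from pathos.helpers import ProcessPool

def work(x, env):
    # the patch observed that some environment variables were missing inside
    # the pool processes on Windows, so a copy is passed in and restored here
    os.environ.update(env)
    return x * x

if __name__ == '__main__':
    pool = ProcessPool(cpu_count())
    result = pool.apply_async(work, (3, os.environ.copy()))
    print(result.get())   # 9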
+"""Local based implementation of the executor using multiprocessing""" + +import signal +import os + +if os.name == 'nt': + import queue as thread_queue + import threading + # Pathos uses dill, which can pickle things like functions + from pathos.helpers import ProcessPool + # On Windows, there is no fork(), a 'multiprocessing.Process' + # or ProcessPool has to 'build up' the script from scratch, we + # set these environment variables so each python.exe process + # does not allocate unneeded threads + os.environ['OMP_NUM_THREADS'] = "1" + os.environ['TVM_NUM_THREADS'] = "1" + # numpy seems to honor this + os.environ['MKL_NUM_THREADS'] = "1" + + # Since there is no fork() on Windows, to mitigate performance impact + # we will use a process pool for executers, vs the *nix based systems + # that will fork() a new process for each executor + executor_pool = None + +from multiprocessing import Process, Queue, cpu_count +try: + from queue import Empty +except ImportError: + from Queue import Empty + +try: + import psutil +except ImportError: + psutil = None + +from . import executor + + +def kill_child_processes(parent_pid, sig=signal.SIGTERM): + """kill all child processes recursively""" + try: + parent = psutil.Process(parent_pid) + except psutil.NoSuchProcess: + return + children = parent.children(recursive=True) + for process in children: + try: + process.send_signal(sig) + except psutil.NoSuchProcess: + return + +def _execute_func(func, queue, args, kwargs): + """execute function and return the result or exception to a queue""" + try: + res = func(*args, **kwargs) + except Exception as exc: # pylint: disable=broad-except + res = exc + queue.put(res) + + +def call_with_timeout(queue, timeout, func, args, kwargs): + """A wrapper to support timeout of a function call""" + + # start a new process for timeout (cannot use thread because we have c function) + p = Process(target=_execute_func, args=(func, queue, args, kwargs)) + p.start() + p.join(timeout=timeout) + + queue.put(executor.TimeoutError()) + + kill_child_processes(p.pid) + p.terminate() + p.join() + +if os.name == 'nt': + def call_from_pool(func, args, kwargs, timeout, env): + """A wrapper to support timeout of a function call for a pool process""" + + # Restore environment variables from parent + for key, val in env.items(): + os.environ[key] = val + + queue = thread_queue.Queue(2) + + # We use a thread here for Windows, because starting up a new Process can be heavy + # This isn't as clean as the *nix implementation, which can kill a process that + # has timed out + thread = threading.Thread(target=_execute_func, args=(func, queue, args, kwargs)) + thread.start() + thread.join(timeout=timeout) + + queue.put(executor.TimeoutError()) + + res = queue.get() + return res + +class LocalFuture(executor.Future): + """Local wrapper for the future + + Parameters + ---------- + process: multiprocessing.Process + process for running this task + queue: multiprocessing.Queue + queue for receiving the result of this task + """ + def __init__(self, process, queue): + self._done = False + self._process = process + self._queue = queue + + def done(self): + self._done = self._done or not self._queue.empty() + return self._done + + def get(self, timeout=None): + try: + res = self._queue.get(block=True, timeout=timeout) + except Empty: + raise executor.TimeoutError() + if self._process.is_alive(): + kill_child_processes(self._process.pid) + self._process.terminate() + self._process.join() + self._queue.close() + self._queue.join_thread() + self._done = 
True + del self._queue + del self._process + return res + + +class LocalFutureNoFork(executor.Future): + """Local wrapper for the future. + This is a none-fork version of LocalFuture. + Use this for the runtime that does not support fork (like cudnn) + """ + def __init__(self, result): + self._result = result + + def done(self): + return True + + def get(self, timeout=None): + return self._result + +if os.name == 'nt': + class LocalFuturePool(executor.Future): + """Local wrapper for the future using a Process pool + + Parameters + ---------- + thread: threading.Thread + Thread for running this task + pool_results: result from Pool.apply_async + queue for receiving the result of this task + """ + def __init__(self, pool_results): + self._done = False + self._pool_results = pool_results + + def done(self): + return self._done + + def get(self, timeout=None): + try: + res = self._pool_results.get(timeout=timeout) + except Empty: + raise executor.TimeoutError() + self._done = True + return res + +class LocalExecutor(executor.Executor): + """Local executor that runs workers on the same machine with multiprocessing. + + Parameters + ---------- + timeout: float, optional + timeout of a job. If time is out. A TimeoutError will be returned (not raised) + do_fork: bool, optional + For some runtime systems that do not support fork after initialization + (e.g. cuda runtime, cudnn). Set this to False if you have used these runtime + before submitting jobs. + """ + def __init__(self, timeout=None, do_fork=True): + self.timeout = timeout or executor.Executor.DEFAULT_TIMEOUT + self.do_fork = do_fork + + if self.do_fork: + if not psutil: + raise RuntimeError("Python package psutil is missing. " + "please try `pip install psutil`") + + def submit(self, func, *args, **kwargs): + if not self.do_fork: + return LocalFutureNoFork(func(*args, **kwargs)) + + if os.name != 'nt': + queue = Queue(2) + process = Process(target=call_with_timeout, + args=(queue, self.timeout, func, args, kwargs)) + process.start() + return LocalFuture(process, queue) + else: + global executor_pool + + if executor_pool is None: + # We use a static pool for executor processes because Process.start(entry) + # is so slow on Windows, we lose a lot of parallelism. + # Right now cpu_count() is used, which isn't optimal from a user configuration + # perspective, but is reasonable at this time. + executor_pool = ProcessPool(cpu_count() * 2) + + # Windows seemed to be missing some valuable environ variables + # on the pool's process side. 
We might be able to get away with + # just sending the PATH variable, but for now, we just clone our env + return LocalFuturePool(executor_pool.apply_async(call_from_pool, (func, args, kwargs, self.timeout, os.environ.copy()))) From 3b0c75a4dda8361cf41a9d0f8cfcf0ff95f328d9 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Mon, 11 Nov 2019 20:52:32 -0800 Subject: [PATCH 04/33] Fixed socket.h build error --- src/common/socket.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/common/socket.h b/src/common/socket.h index 616991d021d1..07fcfa83c99e 100644 --- a/src/common/socket.h +++ b/src/common/socket.h @@ -58,9 +58,6 @@ static inline int poll(struct pollfd *pfd, int nfds, int timeout) { return WSAPoll(pfd, nfds, timeout); } -static inline int inet_pton(int family, const char* addr_str, void* addr_buf) { - return InetPton(family, addr_str, addr_buf); -} #else #include #endif // defined(_WIN32) From 73a3600742933e1a5d20f0edc7a13bcb80647c93 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Mon, 11 Nov 2019 23:23:15 -0800 Subject: [PATCH 05/33] fixed rpc_tracker exec on Windows. Added code comments to tracker.py --- python/tvm/exec/rpc_tracker.py | 9 ++++++--- python/tvm/rpc/tracker.py | 7 ++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/python/tvm/exec/rpc_tracker.py b/python/tvm/exec/rpc_tracker.py index 8f5bd1dc73a0..f57326834668 100644 --- a/python/tvm/exec/rpc_tracker.py +++ b/python/tvm/exec/rpc_tracker.py @@ -17,7 +17,7 @@ # pylint: disable=redefined-outer-name, invalid-name """Tool to start RPC tracker""" from __future__ import absolute_import - +import os import logging import argparse import multiprocessing @@ -28,8 +28,11 @@ def main(args): """Main funciton""" tracker = Tracker(args.host, port=args.port, port_end=args.port_end, silent=args.silent) - tracker.proc.join() - + if os.name =='nt': + while True: + input() + else: + tracker.proc.join() if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index 609edd55a51b..d70bb6114c45 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -359,6 +359,9 @@ def _tracker_server(listen_sock, stop_key): if os.name != 'nt': handler.run() else: + # on Windows, this function is the target of a + # ProcessPool, so we need the function to exit. 
+ # therefor we use a thread to start up the handler def run(): handler.run() @@ -416,8 +419,10 @@ def __init__(self, sock.listen(1) if os.name == 'nt': + # We use the process pool because there it is more likely that + # the pool process will die with this parent self.proc = ProcessPool(1) - self.proc.apply_async(_tracker_server, args=(sock, self.stop_key)).get() + self.proc.apply(_tracker_server, args=(sock, self.stop_key)) else: self.proc = multiprocessing.Process( target=_tracker_server, args=(sock, self.stop_key)) From 7202717262cd58e2474d57c45c3eb7a76ec48616 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Mon, 2 Dec 2019 17:20:05 -0800 Subject: [PATCH 06/33] Optimize process pool usage in xgboost --- .../tvm/autotvm/tuner/xgboost_cost_model.py | 22 +-- python/tvm/rpc/server.py | 154 +++++++++++++----- 2 files changed, 129 insertions(+), 47 deletions(-) diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index bc1f7cee75e5..b46fd9a7230f 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -170,12 +170,11 @@ def _reset_pool(self, space, target, task): # some synchronization by sending an async call and waiting for # the queue to have an item set - # There seems to be diminishing returns on large pool sizes given - # the small job sizes mapped later in the code (largest seems to be 128) - # so the pool size is capped - pool_size = min(16, int(self.num_threads)) + #pool_size = min(32, int(self.num_threads)) + pool_size = self.num_threads + if self.pool == None: + self.pool = ProcessPool(pool_size) - self.pool = ProcessPool(pool_size) manager = pathos_multiprocess.Manager() pipe_syncs = [] @@ -188,7 +187,7 @@ def _reset_pool(self, space, target, task): queue = manager.Queue(1) results = { "queue": queue, - "apipe": self.pool.apply_async(_set_pool_process_state, (space, target, task, queue)) + "apipe": self.pool.apply_async(_set_pool_process_state, args=(space, target, task, queue)) } pipe_syncs.append(results) @@ -216,7 +215,10 @@ def _reset_pool(self, space, target, task): _extract_task = task self.pool = multiprocessing.Pool(self.num_threads) - def _close_pool(self): + def _close_pool(self, force_close=False): + if os.name == 'nt' and not force_close: + return + if self.pool: self.pool.terminate() self.pool.join() @@ -251,7 +253,7 @@ def fit(self, xs, ys, plan_size): self.base_model = None else: dtrain.set_base_margin(discount * self.base_model.predict(xs, output_margin=True)) - + self.bst = xgb.train(self.xgb_params, dtrain, num_boost_round=8000, callbacks=[custom_callback( @@ -381,7 +383,7 @@ def _get_feature(self, indexes): return ret def __del__(self): - self._close_pool() + self._close_pool(force_close=True) _extract_space = None @@ -538,7 +540,7 @@ def callback(env): res = [x.split(':') for x in bst_eval.split()] for kv in res[1:]: res_dict[kv[0]] = [float(kv[1])] - + eval_res = [] keys = list(res_dict.keys()) keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x) diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 7faaeb08b429..1986a8cbae59 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -43,6 +43,7 @@ if os.name == 'nt': from pathos.helpers import ProcessPool import threading + import multiprocessing.pool from .._ffi.function import register_func from .._ffi.base import py_str @@ -54,26 +55,45 @@ logger = logging.getLogger('RPCServer') +_temp = None + +class NoDaemonProcess(multiprocessing.Process): 
+ # make 'daemon' attribute always return False + def _get_daemon(self): + return False + def _set_daemon(self, value): + pass + daemon = property(_get_daemon, _set_daemon) + +# We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool +# because the latter is only a wrapper function, not a proper class. +class MyPool(multiprocessing.pool.Pool): + Process = NoDaemonProcess + +# pylint: disable=unused-variable +@register_func("tvm.rpc.server.workpath", override=True) +def get_workpath(path): + global _temp + return _temp.relpath(path) + +@register_func("tvm.rpc.server.load_module", override=True) +def load_module(file_name): + """Load module from remote side.""" + global _temp + path = _temp.relpath(file_name) + m = _load_module(path) + logger.info("load_module %s", path) + return m + def _server_env(load_library, work_path=None): """Server environment function return temp dir""" + global _temp if work_path: temp = work_path else: temp = util.tempdir() - # pylint: disable=unused-variable - @register_func("tvm.rpc.server.workpath") - def get_workpath(path): - return temp.relpath(path) - - @register_func("tvm.rpc.server.load_module", override=True) - def load_module(file_name): - """Load module from remote side.""" - path = temp.relpath(file_name) - m = _load_module(path) - logger.info("load_module %s", path) - return m - + _temp = temp libs = [] load_library = load_library.split(":") if load_library else [] for file_name in load_library: @@ -85,13 +105,37 @@ def load_module(file_name): def _serve_loop(sock, addr, load_library, work_path=None): """Server loop""" - sockfd = sock.fileno() - temp = _server_env(load_library, work_path) - base._ServerLoop(sockfd) - if not work_path: - temp.remove() + try: + sockfd = sock.fileno() + temp = _server_env(load_library, work_path) + base._ServerLoop(sockfd) + if not work_path: + temp.remove() + except Exception as ex: + print(ex) + pass + logger.info("Finish serving %s", addr) +def _serve_loop_pool(args): + """Server loop""" + sock = args["sock"] + addr = args["addr"] + load_library = args["load_library"] + work_path = args["work_path"] + + try: + sockfd = sock.fileno() + temp = _server_env(load_library, work_path) + base._ServerLoop(sockfd) + if not work_path: + temp.remove() + except Exception as ex: + print(ex) + pass + + logger.info("Finish serving %s", addr) + def _parse_server_opt(opts): # parse client options ret = {} @@ -100,7 +144,12 @@ def _parse_server_opt(opts): ret["timeout"] = float(kv[9:]) return ret +trial_counter = 0 + def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): + global trial_counter + executerPool = MyPool(processes=1) + """Listening loop of the server master.""" def _accept_conn(listen_sock, tracker_conn, ping_period=2): """Accept connection from the other places. @@ -204,28 +253,59 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2): # step 3: serving work_path = util.tempdir() logger.info("connection from %s", addr) - server_proc = multiprocessing.Process(target=_serve_loop, - args=(conn, addr, load_library, work_path)) - server_proc.deamon = True - server_proc.start() - # close from our side. 
- conn.close() - # wait until server process finish or timeout - server_proc.join(opts.get("timeout", None)) - if server_proc.is_alive(): - logger.info("Timeout in RPC session, kill..") - import psutil + + def handle_posix(): + server_proc = multiprocessing.Process(target=_serve_loop, + args=(conn, addr, load_library, work_path)) + server_proc.deamon = True + server_proc.start() + # close from our side. + conn.close() + # wait until server process finish or timeout + server_proc.join(opts.get("timeout", None)) + if server_proc.is_alive(): + logger.info("Timeout in RPC session, kill..") + import psutil + try: + parent = psutil.Process(server_proc.pid) + # terminate worker childs + # this can throw on Windows + for child in parent.children(recursive=True): + child.terminate() + except: # pylint: disable=broad-except + pass + + # terminate the worker + server_proc.terminate() + def handle_win32(): + global trial_counter + nonlocal executerPool + + trial_counter += 1 + + args = { + "sock" : conn, + "addr" : addr, + "load_library" : load_library, + "work_path" : work_path + } + + executerPool.map(_serve_loop_pool, [args]) + if trial_counter % 1 == 0: + executerPool.terminate() + executerPool = MyPool(processes=1) + try: - parent = psutil.Process(server_proc.pid) - # terminate worker childs - # this can throw on Windows - for child in parent.children(recursive=True): - child.terminate() - except: # pylint: disable=broad-except + conn.close() + conn.shutdown(1) + except: pass - # terminate the worker - server_proc.terminate() + if os.name != 'nt': + handle_posix() + else: + handle_win32() + work_path.remove() From 255d46aa8a83128bf1ade35ef0cd0b9fdf87a53c Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Mon, 2 Dec 2019 17:21:09 -0800 Subject: [PATCH 07/33] Removed timeouts from local executor on Windows --- python/tvm/autotvm/measure/local_executor.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py index a2f795576e92..6784a7b1e083 100644 --- a/python/tvm/autotvm/measure/local_executor.py +++ b/python/tvm/autotvm/measure/local_executor.py @@ -103,8 +103,7 @@ def call_from_pool(func, args, kwargs, timeout, env): # has timed out thread = threading.Thread(target=_execute_func, args=(func, queue, args, kwargs)) thread.start() - thread.join(timeout=timeout) - + thread.join() queue.put(executor.TimeoutError()) res = queue.get() @@ -180,7 +179,7 @@ def done(self): def get(self, timeout=None): try: - res = self._pool_results.get(timeout=timeout) + res = self._pool_results.get() except Empty: raise executor.TimeoutError() self._done = True From 49f5e87abeb7b2cf1daa386c0f24877543696b08 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Mon, 2 Dec 2019 17:29:03 -0800 Subject: [PATCH 08/33] Added Windows support to C++ RPC Server --- CMakeLists.txt | 10 +- apps/cpp_rpc/CMakeLists.txt | 23 +++ apps/cpp_rpc/main.cc | 92 ++++++--- apps/cpp_rpc/rpc_env.cc | 316 +++++++++++++++-------------- apps/cpp_rpc/rpc_env.h | 6 +- apps/cpp_rpc/rpc_server.cc | 258 +++++++++++------------ apps/cpp_rpc/rpc_server.h | 23 ++- apps/cpp_rpc/win32_process.cc | 279 +++++++++++++++++++++++++ apps/cpp_rpc/win32_process.h | 43 ++++ src/common/ring_buffer.h | 2 +- src/runtime/rpc/rpc_socket_impl.cc | 12 +- 11 files changed, 746 insertions(+), 318 deletions(-) create mode 100644 apps/cpp_rpc/CMakeLists.txt create mode 100644 apps/cpp_rpc/win32_process.cc create mode 100644 apps/cpp_rpc/win32_process.h diff --git 
a/CMakeLists.txt b/CMakeLists.txt index bf18ffc9e856..ed905547cbae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.2) +cmake_minimum_required(VERSION 3.9) project(tvm C CXX) # Utility functions @@ -63,6 +63,7 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF) tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF) tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) +tvm_option(USE_CXX_RPC "Build CXX RPC" OFF) # include directories include_directories(${CMAKE_INCLUDE_PATH}) @@ -275,6 +276,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) add_library(tvm_topi SHARED ${TOPI_SRCS}) add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) +if(USE_CXX_RPC STREQUAL "ON") + add_subdirectory("apps/cpp_rpc") +endif() if(USE_RELAY_DEBUG) message(STATUS "Building Relay in debug mode...") @@ -405,6 +409,10 @@ endif(INSTALL_DEV) # More target definitions if(MSVC) + set_property(TARGET tvm PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) + set_property(TARGET tvm_topi PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) + set_property(TARGET tvm_runtime PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) + set_property(TARGET nnvm_compiler PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS) target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS) target_compile_definitions(nnvm_compiler PRIVATE -DNNVM_EXPORTS) diff --git a/apps/cpp_rpc/CMakeLists.txt b/apps/cpp_rpc/CMakeLists.txt new file mode 100644 index 000000000000..61c40c1affe6 --- /dev/null +++ b/apps/cpp_rpc/CMakeLists.txt @@ -0,0 +1,23 @@ +set(TVM_RPC_SOURCES + main.cc + rpc_env.cc + rpc_server.cc +) + +if(WIN32) + list(APPEND TVM_RPC_SOURCES win32_process.cc) +endif() + +# Set output to same directory as the other TVM libs +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) +add_executable(tvm_rpc ${TVM_RPC_SOURCES}) +set_property(TARGET tvm_rpc PROPERTY INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE) +target_compile_definitions(tvm_rpc PUBLIC -DNOMINMAX) +target_include_directories( + tvm_rpc + PUBLIC "../../include" + PUBLIC DLPACK_PATH + PUBLIC DMLC_PATH +) + +target_link_libraries(tvm_rpc tvm_runtime) \ No newline at end of file diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index 3cf2ed6a5d59..f37cf56d39d0 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -21,10 +21,12 @@ * \file rpc_server.cc * \brief RPC Server for TVM. */ -#include -#include -#include +#include +#include +#include +#if defined(__linux__) || defined(__ANDROID__) #include +#endif #include #include #include @@ -35,11 +37,15 @@ #include "../../src/common/socket.h" #include "rpc_server.h" +#if defined(_WIN32) +#include "win32_process.h" +#endif + using namespace std; using namespace tvm::runtime; using namespace tvm::common; -static const string kUSAGE = \ +static const string kUsage = \ "Command line usage\n" \ " server - Start the server\n" \ "--host - The hostname of the server, Default=0.0.0.0\n" \ @@ -73,13 +79,16 @@ struct RpcServerArgs { string key; string custom_addr; bool silent = false; +#if defined(WIN32) + std::string mmap_path; +#endif }; /*! 
* \brief PrintArgs print the contents of RpcServerArgs * \param args RpcServerArgs structure */ -void PrintArgs(struct RpcServerArgs args) { +void PrintArgs(const RpcServerArgs& args) { LOG(INFO) << "host = " << args.host; LOG(INFO) << "port = " << args.port; LOG(INFO) << "port_end = " << args.port_end; @@ -89,6 +98,7 @@ void PrintArgs(struct RpcServerArgs args) { LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False")); } +#if defined(__linux__) || defined(__ANDROID__) /*! * \brief CtrlCHandler, exits if Ctrl+C is pressed * \param s signal @@ -109,7 +119,7 @@ void HandleCtrlC() { sigIntHandler.sa_flags = 0; sigaction(SIGINT, &sigIntHandler, nullptr); } - +#endif /*! * \brief GetCmdOption Parse and find the command option. * \param argc arg counter @@ -129,7 +139,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) { } // We assume "=" is the end of option. CHECK_EQ(*option.rbegin(), '='); - cmd = arg.substr(arg.find("=") + 1); + cmd = arg.substr(arg.find('=') + 1); return cmd; } } @@ -156,41 +166,41 @@ bool ValidateTracker(string &tracker) { * \brief ParseCmdArgs parses the command line arguments. * \param argc arg counter * \param argv arg values - * \param args, the output structure which holds the parsed values + * \param args the output structure which holds the parsed values */ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { - string silent = GetCmdOption(argc, argv, "--silent", true); + const string silent = GetCmdOption(argc, argv, "--silent", true); if (!silent.empty()) { args.silent = true; // Only errors and fatal is logged dmlc::InitLogging("--minloglevel=2"); } - string host = GetCmdOption(argc, argv, "--host="); + const string host = GetCmdOption(argc, argv, "--host="); if (!host.empty()) { if (!ValidateIP(host)) { LOG(WARNING) << "Wrong host address format."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.host = host; } - string port = GetCmdOption(argc, argv, "--port="); + const string port = GetCmdOption(argc, argv, "--port="); if (!port.empty()) { if (!IsNumber(port) || stoi(port) > 65535) { LOG(WARNING) << "Wrong port number."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.port = stoi(port); } - string port_end = GetCmdOption(argc, argv, "--port_end="); + const string port_end = GetCmdOption(argc, argv, "--port_end="); if (!port_end.empty()) { if (!IsNumber(port_end) || stoi(port_end) > 65535) { LOG(WARNING) << "Wrong port_end number."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.port_end = stoi(port_end); @@ -200,26 +210,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { if (!tracker.empty()) { if (!ValidateTracker(tracker)) { LOG(WARNING) << "Wrong tracker address format."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.tracker = tracker; } - string key = GetCmdOption(argc, argv, "--key="); + const string key = GetCmdOption(argc, argv, "--key="); if (!key.empty()) { args.key = key; } - string custom_addr = GetCmdOption(argc, argv, "--custom_addr="); + const string custom_addr = GetCmdOption(argc, argv, "--custom_addr="); if (!custom_addr.empty()) { if (!ValidateIP(custom_addr)) { LOG(WARNING) << "Wrong custom address format."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.custom_addr = custom_addr; } +#if defined(WIN32) + const string mmap_path = GetCmdOption(argc, argv, "--child_proc="); + if(!mmap_path.empty()) { + args.mmap_path = mmap_path; + dmlc::InitLogging("--minloglevel=0"); + } +#endif + } 
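// As a rough usage sketch (the values below are placeholders, not part of this
// patch): once built, tvm_rpc is launched with the options parsed above, e.g.
//
//   tvm_rpc server --host=0.0.0.0 --port=9090 --port_end=9099 \
//       --tracker=10.77.1.234:9190 --key=win_cpu
//
// --tracker and --key are optional; when a tracker is used, the key must match
// the device key that clients request. On Windows, --child_proc= is not meant
// to be passed by users; the server passes it when it respawns itself as a
// child process to serve a single session.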
/*! @@ -229,17 +247,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { * \return result of operation. */ int RpcServer(int argc, char * argv[]) { - struct RpcServerArgs args; + RpcServerArgs args; /* parse the command line args */ ParseCmdArgs(argc, argv, args); PrintArgs(args); - // Ctrl+C handler LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop."; +#if defined(__linux__) || defined(__ANDROID__) + // Ctrl+C handler HandleCtrlC(); - tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker, - args.key, args.custom_addr, args.silent); +#endif + +#if defined(WIN32) + if(!args.mmap_path.empty()) { + int ret = 0; + + try { + ChildProcSocketHandler(args.mmap_path); + } catch (const std::exception&) { + ret = -1; + } + + return ret; + } +#endif + + RPCServerCreate(args.host, args.port, args.port_end, args.tracker, + args.key, args.custom_addr, args.silent); return 0; } @@ -251,15 +286,18 @@ int RpcServer(int argc, char * argv[]) { */ int main(int argc, char * argv[]) { if (argc <= 1) { - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; return 0; } + // Runs WSAStartup on Win32, no-op on POSIX + Socket::Startup(); + if (0 == strcmp(argv[1], "server")) { - RpcServer(argc, argv); - } else { - LOG(INFO) << kUSAGE; + return RpcServer(argc, argv); } + LOG(INFO) << kUsage; + return 0; } diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index 44f848dc749e..c4f77d3ffdfe 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -20,77 +20,74 @@ * \file rpc_env.cc * \brief Server environment of the RPC. */ +#include #include -#include -#ifndef _MSC_VER -#include +#ifndef _WIN32 #include +#include #include #else #include +#include +namespace { + int mkdir(const char* path, int /* ignored */) { return _mkdir(path); } +} #endif +#include #include -#include #include #include -#include +#include +#include -#include "rpc_env.h" #include "../../src/common/util.h" #include "../../src/runtime/file_util.h" +#include "rpc_env.h" + +namespace { +#if defined(__linux__) || defined(__ANDROID__) + const std::string untar_cmd = "tar -C "; +#elif defined(_WIN32) + const std::string untar_cmd = "wsl tar -C "; +#endif +}// Anonymous namespace namespace tvm { namespace runtime { - RPCEnv::RPCEnv() { - #if defined(__linux__) || defined(__ANDROID__) - base_ = "./rpc"; - mkdir(&base_[0], 0777); + base_ = "./rpc"; + mkdir(base_.c_str(), 0777); + TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCEnv env; + *rv = env.GetPath(args[0]); + }); - TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") - .set_body([](TVMArgs args, TVMRetValue* rv) { - static RPCEnv env; - *rv = env.GetPath(args[0]); - }); - - TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") - .set_body([](TVMArgs args, TVMRetValue *rv) { - static RPCEnv env; - std::string file_name = env.GetPath(args[0]); - *rv = Load(&file_name, ""); - LOG(INFO) << "Load module from " << file_name << " ..."; - }); - #else - LOG(FATAL) << "Only support RPC in linux environment"; - #endif + TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module").set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCEnv env; + std::string file_name = env.GetPath(args[0]); + *rv = Load(&file_name, ""); + LOG(INFO) << "Load module from " << file_name << " ..."; + }); } /*! - * \brief GetPath To get the workpath from packed function - * \param name The file name + * \brief GetPath To get the work path from packed function + * \param file_name The file name * \return The full path of file. 
*/ -std::string RPCEnv::GetPath(std::string file_name) { +std::string RPCEnv::GetPath(const std::string& file_name) const { // we assume file_name has "/" means file_name is the exact path // and does not create /.rpc/ - if (file_name.find("/") != std::string::npos) { - return file_name; - } else { - return base_ + "/" + file_name; - } + return file_name.find('/') != std::string::npos ? file_name : base_ + "/" + file_name; } /*! * \brief Remove The RPC Environment cleanup function */ -void RPCEnv::CleanUp() { - #if defined(__linux__) || defined(__ANDROID__) - CleanDir(&base_[0]); - int ret = rmdir(&base_[0]); - if (ret != 0) { - LOG(WARNING) << "Remove directory " << base_ << " failed"; - } - #else - LOG(FATAL) << "Only support RPC in linux environment"; - #endif +void RPCEnv::CleanUp() const { + CleanDir(base_); + const int ret = rmdir(base_.c_str()); + if (ret != 0) { + LOG(WARNING) << "Remove directory " << base_ << " failed"; + } } /*! @@ -98,53 +95,54 @@ void RPCEnv::CleanUp() { * \param dirname The root directory name * \return vector Files in directory. */ -std::vector ListDir(const std::string &dirname) { +std::vector ListDir(const std::string& dirname) { std::vector vec; - #ifndef _MSC_VER - DIR *dp = opendir(dirname.c_str()); - if (dp == nullptr) { - int errsv = errno; - LOG(FATAL) << "ListDir " << dirname <<" error: " << strerror(errsv); - } - dirent *d; - while ((d = readdir(dp)) != nullptr) { - std::string filename = d->d_name; - if (filename != "." && filename != "..") { - std::string f = dirname; - if (f[f.length() - 1] != '/') { - f += '/'; - } - f += d->d_name; - vec.push_back(f); +#ifndef _WIN32 + DIR* dp = opendir(dirname.c_str()); + if (dp == nullptr) { + int errsv = errno; + LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); + } + dirent* d; + while ((d = readdir(dp)) != nullptr) { + std::string filename = d->d_name; + if (filename != "." && filename != "..") { + std::string f = dirname; + if (f[f.length() - 1] != '/') { + f += '/'; } + f += d->d_name; + vec.push_back(f); } - closedir(dp); - #else - WIN32_FIND_DATA fd; - std::string pattern = dirname + "/*"; - HANDLE handle = FindFirstFile(pattern.c_str(), &fd); - if (handle == INVALID_HANDLE_VALUE) { - int errsv = GetLastError(); - LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); - } - do { - if (fd.cFileName != "." && fd.cFileName != "..") { - std::string f = dirname; - char clast = f[f.length() - 1]; - if (f == ".") { - f = fd.cFileName; - } else if (clast != '/' && clast != '\\') { - f += '/'; - f += fd.cFileName; - } - vec.push_back(f); + } + closedir(dp); +#elif defined(_WIN32) + WIN32_FIND_DATAA fd; + const std::string pattern = dirname + "/*"; + HANDLE handle = FindFirstFileA(pattern.c_str(), &fd); + if (handle == INVALID_HANDLE_VALUE) { + const int errsv = GetLastError(); + LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); + } + do { + std::string filename = fd.cFileName; + if (filename != "." && filename != "..") { + std::string f = dirname; + if (f[f.length() - 1] != '/') { + f += '/'; } - } while (FindNextFile(handle, &fd)); - FindClose(handle); - #endif + f += filename; + vec.push_back(f); + } + } while (FindNextFileA(handle, &fd)); + FindClose(handle); +#else + LOG(FATAL) << "Operating system not supported"; +#endif return vec; } +#if defined(__linux__) || defined(__ANDROID__) /*! 
* \brief LinuxShared Creates a linux shared library * \param output The output file name @@ -152,35 +150,66 @@ std::vector ListDir(const std::string &dirname) { * \param options The compiler options * \param cc The compiler */ -void LinuxShared(const std::string output, +void LinuxShared(const std::string output, const std::vector &files, - std::string options = "", + std::string options = "", std::string cc = "g++") { - std::string cmd = cc; - cmd += " -shared -fPIC "; - cmd += " -o " + output; - for (auto f = files.begin(); f != files.end(); ++f) { - cmd += " " + *f; - } - cmd += " " + options; - std::string err_msg; - auto executed_status = common::Execute(cmd, &err_msg); - if (executed_status) { - LOG(FATAL) << err_msg; - } + std::string cmd = cc; + cmd += " -shared -fPIC "; + cmd += " -o " + output; + for (const auto& file : files) { + cmd += " " + file; + } + cmd += " " + options; + std::string err_msg; + const auto executed_status = common::Execute(cmd, &err_msg); + if (executed_status) { + LOG(FATAL) << err_msg; + } } +#endif + +#ifdef _WIN32 +/*! + * \brief WindowsShared Creates a Windows shared library + * \param output The output file name + * \param files The files for building + * \param options The compiler options + * \param cc The compiler + */ +void WindowsShared(const std::string& output, + const std::vector& files, + const std::string& options = "", + const std::string& cc = "clang") { + std::string cmd = cc; + cmd += " -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared "; + cmd += " -o " + output; + for (const auto& file : files) { + cmd += " " + file; + } + cmd += " " + options; + std::string err_msg; + const auto executed_status = common::Execute(cmd, &err_msg); + if (executed_status) { + printf("compile error: %s\n", err_msg.c_str()); + LOG(FATAL) << err_msg; + } +} +#endif /*! * \brief CreateShared Creates a shared library * \param output The output file name * \param files The files for building */ -void CreateShared(const std::string output, const std::vector &files) { - #if defined(__linux__) || defined(__ANDROID__) - LinuxShared(output, files); - #else - LOG(FATAL) << "Do not support creating shared library"; - #endif +void CreateShared(const std::string& output, const std::vector& files) { +#if defined(__linux__) || defined(__ANDROID__) + LinuxShared(output, files); +#elif defined(_WIN32) + WindowsShared(output, files); +#else + LOG(FATAL) << "Operating system not supported"; +#endif } /*! 
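// To illustrate (a sketch only; the file names are hypothetical), a call such
// as CreateShared("tmp_func.so", {"tmp_func.o"}) on Windows composes and runs
// roughly
//
//   clang -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared -o tmp_func.so tmp_func.o
//
// so a clang/lld toolchain is assumed to be on PATH for the Windows RPC server
// to rebuild uploaded modules, just as g++ is assumed on Linux/Android.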
@@ -193,61 +222,52 @@ void CreateShared(const std::string output, const std::vector &file * \param fmt The format of file * \return Module The loaded module */ -Module Load(std::string *fileIn, const std::string fmt) { - std::string file = *fileIn; +Module Load(std::string *fileIn, const std::string& fmt) { + const std::string& file = *fileIn; if (common::EndsWith(file, ".so")) { - return Module::LoadFromFile(file, fmt); + return Module::LoadFromFile(file, fmt); } - #if defined(__linux__) || defined(__ANDROID__) - std::string file_name = file + ".so"; - if (common::EndsWith(file, ".o")) { - std::vector files; - files.push_back(file); - CreateShared(file_name, files); - } else if (common::EndsWith(file, ".tar")) { - std::string tmp_dir = "./rpc/tmp/"; - mkdir(&tmp_dir[0], 0777); - std::string cmd = "tar -C " + tmp_dir + " -zxf " + file; - std::string err_msg; - int executed_status = common::Execute(cmd, &err_msg); - if (executed_status) { - LOG(FATAL) << err_msg; - } - CreateShared(file_name, ListDir(tmp_dir)); - CleanDir(tmp_dir); - rmdir(&tmp_dir[0]); - } else { - file_name = file; + std::string file_name = file + ".so"; + if (common::EndsWith(file, ".o")) { + std::vector files; + files.push_back(file); + CreateShared(file_name, files); + } else if (common::EndsWith(file, ".tar")) { + const std::string tmp_dir = "./rpc/tmp/"; + mkdir(tmp_dir.c_str(), 0777); + + const std::string cmd = untar_cmd + tmp_dir + " -zxf " + file; + + std::string err_msg; + const int executed_status = common::Execute(cmd, &err_msg); + if (executed_status) { + LOG(FATAL) << err_msg; } - *fileIn = file_name; - return Module::LoadFromFile(file_name, fmt); - #else - LOG(FATAL) << "Do not support creating shared library"; - #endif + CreateShared(file_name, ListDir(tmp_dir)); + CleanDir(tmp_dir); + (void)rmdir(tmp_dir.c_str()); + } else { + file_name = file; + } + *fileIn = file_name; + return Module::LoadFromFile(file_name, fmt); } /*! * \brief CleanDir Removes the files from the directory * \param dirname The name of the directory */ -void CleanDir(const std::string &dirname) { - #if defined(__linux__) || defined(__ANDROID__) - DIR *dp = opendir(dirname.c_str()); - dirent *d; - while ((d = readdir(dp)) != nullptr) { - std::string filename = d->d_name; - if (filename != "." && filename != "..") { - filename = dirname + "/" + d->d_name; - int ret = std::remove(&filename[0]); - if (ret != 0) { - LOG(WARNING) << "Remove file " << filename << " failed"; - } - } +void CleanDir(const std::string& dirname) { + auto files = ListDir(dirname); + for (const auto& filename : files) { + std::string file_path = dirname + "/"; + file_path += filename; + const int ret = std::remove(filename.c_str()); + if (ret != 0) { + LOG(WARNING) << "Remove file " << filename << " failed"; } - #else - LOG(FATAL) << "Only support RPC in linux environment"; - #endif + } } } // namespace runtime diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h index 82409bae81a1..d046f6ecb480 100644 --- a/apps/cpp_rpc/rpc_env.h +++ b/apps/cpp_rpc/rpc_env.h @@ -40,7 +40,7 @@ namespace runtime { * \param file The format of file * \return Module The loaded module */ -Module Load(std::string *path, const std::string fmt = ""); +Module Load(std::string *path, const std::string& fmt = ""); /*! * \brief CleanDir Removes the files from the directory @@ -62,11 +62,11 @@ struct RPCEnv { * \param name The file name * \return The full path of file. */ - std::string GetPath(std::string file_name); + std::string GetPath(const std::string& file_name) const; /*! 
* \brief The RPC Environment cleanup function */ - void CleanUp(); + void CleanUp() const; private: /*! diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index b35a63bd67dc..f586b8f1faf6 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -22,24 +22,27 @@ * \brief RPC Server implementation. */ #include - #if defined(__linux__) || defined(__ANDROID__) #include #include #endif -#include -#include -#include -#include #include +#include +#include +#include #include -#include "rpc_server.h" -#include "rpc_env.h" -#include "rpc_tracker_client.h" +#include "../../src/common/socket.h" #include "../../src/runtime/rpc/rpc_session.h" #include "../../src/runtime/rpc/rpc_socket_impl.h" -#include "../../src/common/socket.h" +#include "rpc_env.h" +#include "rpc_server.h" +#include "rpc_tracker_client.h" +#if defined(_WIN32) +#include "win32_process.h" +#endif + +using namespace std::chrono; namespace tvm { namespace runtime { @@ -49,7 +52,7 @@ namespace runtime { * \param status status value */ #if defined(__linux__) || defined(__ANDROID__) -static pid_t waitPidEintr(int *status) { +static pid_t waitPidEintr(int* status) { pid_t pid = 0; while ((pid = waitpid(-1, status, 0)) == -1) { if (errno == EINTR) { @@ -76,34 +79,32 @@ class RPCServer { public: /*! * \brief Constructor. - */ - RPCServer(const std::string &host, - int port, - int port_end, - const std::string &tracker_addr, - const std::string &key, - const std::string &custom_addr) { - // Init the values - host_ = host; - port_ = port; - port_end_ = port_end; - tracker_addr_ = tracker_addr; - key_ = key; - custom_addr_ = custom_addr; + */ + RPCServer(std::string host, int port, int port_end, std::string tracker_addr, + std::string key, std::string custom_addr) : + host_(std::move(host)), port_(port), my_port_(0), port_end_(port_end), + tracker_addr_(std::move(tracker_addr)), key_(std::move(key)), + custom_addr_(std::move(custom_addr)) + { + } /*! * \brief Destructor. - */ + */ ~RPCServer() { - // Free the resources - tracker_sock_.Close(); - listen_sock_.Close(); + try { + // Free the resources + tracker_sock_.Close(); + listen_sock_.Close(); + } catch(...) { + + } } /*! * \brief Start Creates the RPC listen process and execution. 
- */ + */ void Start() { listen_sock_.Create(); my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_); @@ -130,102 +131,95 @@ class RPCServer { tracker.TryConnect(); // step 2: wait for in-coming connections AcceptConnection(&tracker, &conn, &addr, &opts); - } - catch (const char* msg) { + } catch (const char* msg) { LOG(WARNING) << "Socket exception: " << msg; // close tracker resource tracker.Close(); continue; - } - catch (std::exception& e) { - // Other errors + } catch (const std::exception& e) { + // close tracker resource + tracker.Close(); LOG(WARNING) << "Exception standard: " << e.what(); continue; } int timeout = GetTimeOutFromOpts(opts); - #if defined(__linux__) || defined(__ANDROID__) - // step 3: serving - if (timeout != 0) { - const pid_t timer_pid = fork(); - if (timer_pid == 0) { - // Timer process - sleep(timeout); - exit(0); - } +#if defined(__linux__) || defined(__ANDROID__) + // step 3: serving + if (timeout != 0) { + const pid_t timer_pid = fork(); + if (timer_pid == 0) { + // Timer process + sleep(timeout); + exit(0); + } - const pid_t worker_pid = fork(); - if (worker_pid == 0) { - // Worker process - ServerLoopProc(conn, addr); - exit(0); - } + const pid_t worker_pid = fork(); + if (worker_pid == 0) { + // Worker process + ServerLoopProc(conn, addr); + exit(0); + } - int status = 0; - const pid_t finished_first = waitPidEintr(&status); - if (finished_first == timer_pid) { - kill(worker_pid, SIGKILL); - } else if (finished_first == worker_pid) { - kill(timer_pid, SIGKILL); - } else { - LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; - } + int status = 0; + const pid_t finished_first = waitPidEintr(&status); + if (finished_first == timer_pid) { + kill(worker_pid, SIGKILL); + } else if (finished_first == worker_pid) { + kill(timer_pid, SIGKILL); + } else { + LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; + } - int status_second = 0; - waitPidEintr(&status_second); + int status_second = 0; + waitPidEintr(&status_second); - // Logging. - if (finished_first == timer_pid) { - LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout - << "), Process status = " << status_second; - } else if (finished_first == worker_pid) { - LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second; - } - } else { - auto pid = fork(); - if (pid == 0) { - ServerLoopProc(conn, addr); - exit(0); - } - // Wait for the result - int status = 0; - wait(&status); - LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status; + // Logging. 
+ if (finished_first == timer_pid) { + LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout + << "), Process status = " << status_second; + } else if (finished_first == worker_pid) { + LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second; } - #else - // step 3: serving - std::future proc(std::async(std::launch::async, - &RPCServer::ServerLoopProc, this, conn, addr)); - // wait until server process finish or timeout - if (timeout != 0) { - // Autoterminate after timeout - proc.wait_for(std::chrono::seconds(timeout)); - } else { - // Wait for the result - proc.get(); + } else { + auto pid = fork(); + if (pid == 0) { + ServerLoopProc(conn, addr); + exit(0); } - #endif + // Wait for the result + int status = 0; + wait(&status); + LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status; + } +#elif defined(WIN32) + auto start_time = high_resolution_clock::now(); + try { + SpawnRPCChild(conn.sockfd, seconds(timeout)); + } catch (const std::exception&) { + + } + auto dur = high_resolution_clock::now() - start_time; + + LOG(INFO) << "Serve Time " << duration_cast(dur).count() << "ms"; +#endif // close from our side. LOG(INFO) << "Socket Connection Closed"; conn.Close(); } } - /*! * \brief AcceptConnection Accepts the RPC Server connection. * \param tracker Tracker details. - * \param conn New connection information. + * \param conn_sock New connection information. * \param addr New connection address information. * \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ - void AcceptConnection(TrackerClient* tracker, - common::TCPSocket* conn_sock, - common::SockAddr* addr, - std::string* opts, - int ping_period = 2) { - std::set old_keyset; + void AcceptConnection(TrackerClient* tracker, common::TCPSocket* conn_sock, + common::SockAddr* addr, std::string* opts, int ping_period = 2) { + std::set old_keyset; std::string matchkey; // Report resource to tracker and get key @@ -236,7 +230,7 @@ class RPCServer { common::TCPSocket conn = listen_sock_.Accept(addr); int code = kRPCMagic; - CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); + CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); if (code != kRPCMagic) { conn.Close(); LOG(FATAL) << "Client connected is not TVM RPC server"; @@ -265,15 +259,15 @@ class RPCServer { std::string arg0; ssin >> arg0; if (arg0 != expect_header) { - code = kRPCMismatch; - CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); - conn.Close(); - LOG(WARNING) << "Mismatch key from" << addr->AsString(); - continue; + code = kRPCMismatch; + CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); + conn.Close(); + LOG(WARNING) << "Mismatch key from" << addr->AsString(); + continue; } else { code = kRPCSuccess; CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); - keylen = server_key.length(); + keylen = int(server_key.length()); CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); LOG(INFO) << "Connection success " << addr->AsString(); @@ -289,25 +283,23 @@ class RPCServer { * \param sock The socket information * \param addr The socket address information */ - void ServerLoopProc(common::TCPSocket sock, common::SockAddr addr) { - // Server loop - auto env = RPCEnv(); - RPCServerLoop(sock.sockfd); - LOG(INFO) << "Finish serving " << addr.AsString(); - env.CleanUp(); + static void ServerLoopProc(common::TCPSocket sock, common::SockAddr addr) { + // Server loop + 
const auto env = RPCEnv(); + RPCServerLoop(int(sock.sockfd)); + LOG(INFO) << "Finish serving " << addr.AsString(); + env.CleanUp(); } /*! * \brief GetTimeOutFromOpts Parse and get the timeout option. * \param opts The option string - * \param timeout value after parsing. */ - int GetTimeOutFromOpts(std::string opts) { - std::string cmd; - std::string option = "-timeout="; + int GetTimeOutFromOpts(const std::string& opts) const { + const std::string option = "-timeout="; if (opts.find(option) == 0) { - cmd = opts.substr(opts.find_last_of(option) + 1); + const std::string cmd = opts.substr(opts.find_last_of(option) + 1); CHECK(common::IsNumber(cmd)) << "Timeout is not valid"; return std::stoi(cmd); } @@ -325,35 +317,45 @@ class RPCServer { common::TCPSocket tracker_sock_; }; +#if defined(WIN32) +/*! +* \brief ServerLoopFromChild The Server loop process. +* \param socket The socket information +*/ +void ServerLoopFromChild(SOCKET socket) { + // Server loop + tvm::common::TCPSocket sock(socket); + const auto env = RPCEnv(); + RPCServerLoop(int(sock.sockfd)); + + sock.Close(); + env.CleanUp(); +} +#endif + /*! * \brief RPCServerCreate Creates the RPC Server. * \param host The hostname of the server, Default=0.0.0.0 * \param port The port of the RPC, Default=9090 * \param port_end The end search port of the RPC, Default=9199 - * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" * \param key The key used to identify the device type in tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. Default=True */ -void RPCServerCreate(std::string host, - int port, - int port_end, - std::string tracker_addr, - std::string key, - std::string custom_addr, - bool silent) { +void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr, + std::string key, std::string custom_addr, bool silent) { if (silent) { // Only errors and fatal is logged dmlc::InitLogging("--minloglevel=2"); } // Start the rpc server - RPCServer rpc(host, port, port_end, tracker_addr, key, custom_addr); + RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), std::move(custom_addr)); rpc.Start(); } -TVM_REGISTER_GLOBAL("rpc._ServerCreate") -.set_body([](TVMArgs args, TVMRetValue* rv) { - RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); - }); +TVM_REGISTER_GLOBAL("rpc._ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) { + RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); +}); } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h index 205182e4449a..db7c89d823dd 100644 --- a/apps/cpp_rpc/rpc_server.h +++ b/apps/cpp_rpc/rpc_server.h @@ -30,6 +30,15 @@ namespace tvm { namespace runtime { +#if defined(WIN32) +/*! + * \brief ServerLoopFromChild The Server loop process. + * \param sock The socket information + * \param addr The socket address information + */ +void ServerLoopFromChild(SOCKET socket); +#endif + /*! * \brief RPCServerCreate Creates the RPC Server. * \param host The hostname of the server, Default=0.0.0.0 @@ -40,13 +49,13 @@ namespace runtime { * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. 
Default=True */ -TVM_DLL void RPCServerCreate(std::string host = "", - int port = 9090, - int port_end = 9099, - std::string tracker_addr = "", - std::string key = "", - std::string custom_addr = "", - bool silent = true); +void RPCServerCreate(std::string host = "", + int port = 9090, + int port_end = 9099, + std::string tracker_addr = "", + std::string key = "", + std::string custom_addr = "", + bool silent = true); } // namespace runtime } // namespace tvm #endif // TVM_APPS_CPP_RPC_SERVER_H_ diff --git a/apps/cpp_rpc/win32_process.cc b/apps/cpp_rpc/win32_process.cc new file mode 100644 index 000000000000..4af222a4a8dd --- /dev/null +++ b/apps/cpp_rpc/win32_process.cc @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include "win32_process.h" +#include "rpc_server.h" + +using namespace std::chrono; +using namespace tvm::runtime; + +namespace { +// The prefix path for the memory mapped file used to store IPC information +const std::string kMemoryMapPrefix = "/MAPPED_FILE/TVM_RPC"; +// Used to construct unique names for named resources in the parent process +const std::string kParent = "parent"; +// Used to construct unique names for named resources in the child process +const std::string kChild = "child"; +// The timeout of the WIN32 events, in the parent and the child +const milliseconds kEventTimeout(2000); + +// Used to create unique WIN32 mmap paths and event names +int child_counter_ = 0; + +/*! + * \brief HandleDeleter Deleter for UniqueHandle smart pointer + * \param handle The WIN32 HANDLE to manage + */ +struct HandleDeleter { + void operator()(HANDLE handle) const { + if (handle != INVALID_HANDLE_VALUE && handle != nullptr) { + CloseHandle(handle); + } + } +}; + +/*! + * \brief UniqueHandle Smart pointer to manage a WIN32 HANDLE + */ +using UniqueHandle = std::unique_ptr; + +/*! + * \brief MakeUniqueHandle Helper method to construct a UniqueHandle + * \param handle The WIN32 HANDLE to manage + */ +UniqueHandle MakeUniqueHandle(HANDLE handle) { + if (handle == INVALID_HANDLE_VALUE || handle == nullptr) { + return nullptr; + } + + return UniqueHandle(handle); +} + +/*! 
+ * \brief GetSocket Gets the socket info from the parent process and duplicates the socket + * \param mmap_path The path to the memory mapped info set by the parent + */ +SOCKET GetSocket(const std::string& mmap_path) { + WSAPROTOCOL_INFO protocol_info; + + const std::string parent_event_name = mmap_path + kParent; + const std::string child_event_name = mmap_path + kChild; + + // Open the events + UniqueHandle parent_file_mapping_event; + if ((parent_file_mapping_event = MakeUniqueHandle(OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) { + LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); + } + + UniqueHandle child_file_mapping_event; + if ((child_file_mapping_event = MakeUniqueHandle(OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) { + LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); + } + + // Wait for the parent to set the event, notifying WSAPROTOCOL_INFO is ready to be read + if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { + LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError(); + } + + const UniqueHandle file_map = MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE, + false, + mmap_path.c_str())); + if (!file_map) { + LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); + } + + void* map_view = MapViewOfFile(file_map.get(), + FILE_MAP_READ | FILE_MAP_WRITE, + 0, 0, 0); + + SOCKET sock_duplicated = INVALID_SOCKET; + + if (map_view != nullptr) { + memcpy(&protocol_info, map_view, sizeof(WSAPROTOCOL_INFO)); + UnmapViewOfFile(map_view); + + // Creates the duplicate socket, that was created in the parent + sock_duplicated = WSASocket(FROM_PROTOCOL_INFO, + FROM_PROTOCOL_INFO, + FROM_PROTOCOL_INFO, + &protocol_info, + 0, + 0); + + // Let the parent know we are finished dupicating the socket + SetEvent(child_file_mapping_event.get()); + } else { + LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError(); + } + + return sock_duplicated; +} +}// Anonymous namespace + +namespace tvm { +namespace runtime { +/*! 
+ * \brief SpawnRPCChild Spawns a child process with a given timeout to run + * \param fd The client socket to duplicate in the child + * \param timeout The time in seconds to wait for the child to complete before termination + */ +void SpawnRPCChild(SOCKET fd, seconds timeout) { + STARTUPINFOA startup_info; + + memset(&startup_info, 0, sizeof(startup_info)); + startup_info.cb = sizeof(startup_info); + + std::string file_map_path = kMemoryMapPrefix + std::to_string(child_counter_++); + + const std::string parent_event_name = file_map_path + kParent; + const std::string child_event_name = file_map_path + kChild; + + // Create an event to let the child know the socket info was set to the mmap file + UniqueHandle parent_file_mapping_event; + if ((parent_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) { + LOG(FATAL) << "CreateEvent for parent file mapping failed"; + } + + UniqueHandle child_file_mapping_event; + // An event to let the parent know the socket info was read from the mmap file + if ((child_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) { + LOG(FATAL) << "CreateEvent for child file mapping failed"; + } + + char current_executable[MAX_PATH]; + + // Get the full path of the current executable + GetModuleFileNameA(nullptr, current_executable, MAX_PATH); + + std::string child_command_line = current_executable; + child_command_line += " server --child_proc="; + child_command_line += file_map_path; + + // CreateProcessA requires a non const char*, so we copy our std::string + std::unique_ptr command_line_ptr(new char[child_command_line.size() + 1]); + strcpy(command_line_ptr.get(), child_command_line.c_str()); + + PROCESS_INFORMATION child_process_info; + if (CreateProcessA(nullptr, + command_line_ptr.get(), + nullptr, + nullptr, + false, + CREATE_NO_WINDOW, + nullptr, + nullptr, + &startup_info, + &child_process_info)) { + // Child process and thread handles must be closed, so wrapped in RAII + auto child_process_handle = MakeUniqueHandle(child_process_info.hProcess); + auto child_process_thread_handle = MakeUniqueHandle(child_process_info.hThread); + + WSAPROTOCOL_INFO protocol_info; + // Get info needed to duplicate the socket + if (WSADuplicateSocket(fd, + child_process_info.dwProcessId, + &protocol_info) == SOCKET_ERROR) { + LOG(FATAL) << "WSADuplicateSocket(): failed. 
Error =" << WSAGetLastError(); + } + + // Create a mmap file to store the info needed for duplicating the SOCKET in the child proc + UniqueHandle file_map = MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE, + nullptr, + PAGE_READWRITE, + 0, + sizeof(WSAPROTOCOL_INFO), + file_map_path.c_str())); + if (!file_map) { + LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); + } + + if (GetLastError() == ERROR_ALREADY_EXISTS) { + LOG(FATAL) << "CreateFileMapping(): mapping file already exists"; + } else { + void* map_view = MapViewOfFile(file_map.get(), FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0); + + if (map_view != nullptr) { + memcpy(map_view, &protocol_info, sizeof(WSAPROTOCOL_INFO)); + UnmapViewOfFile(map_view); + + // Let child proc know the mmap file is ready to be read + SetEvent(parent_file_mapping_event.get()); + + // Wait for the child to finish reading mmap file + if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { + TerminateProcess(child_process_handle.get(), 0); + LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child process."; + } + } else { + TerminateProcess(child_process_handle.get(), 0); + LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError(); + } + } + + const DWORD process_timeout = timeout.count() + ? uint32_t(duration_cast(timeout).count()) + : INFINITE; + + // Wait for child process to exit, or hit configured timeout + if (WaitForSingleObject(child_process_handle.get(), process_timeout) != WAIT_OBJECT_0) { + LOG(INFO) << "Child process timeout. Terminating."; + TerminateProcess(child_process_handle.get(), 0); + } + } else { + LOG(INFO) << "Create child process failed: " << GetLastError(); + } +} +/*! + * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket + * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent + */ +void ChildProcSocketHandler(const std::string& mmap_path) { + SOCKET socket; + const auto last_thread_priority = GetThreadPriority(GetCurrentThread()); + + // Set high thread priority to avoid the thread scheduler from + // interfering with any measurements in the RPC server. + SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL); + + try { + if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) { + tvm::runtime::ServerLoopFromChild(socket); + } + else { + LOG(FATAL) << "GetSocket() failed"; + } + } catch (...) { + // Restore thread priority + SetThreadPriority(GetCurrentThread(), last_thread_priority); + throw; + } +} +} // namespace runtime +} // namespace tvm \ No newline at end of file diff --git a/apps/cpp_rpc/win32_process.h b/apps/cpp_rpc/win32_process.h new file mode 100644 index 000000000000..7d1a27680ed3 --- /dev/null +++ b/apps/cpp_rpc/win32_process.h @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + /*! + * \file win32_process.h + * \brief Win32 process code to mimic a POSIX fork() + */ +#ifndef TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ +#define TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ +#include +#include +namespace tvm { +namespace runtime { +/*! + * \brief SpawnRPCChild Spawns a child process with a given timeout to run + * \param fd The client socket to duplicate in the child + * \param timeout The time in seconds to wait for the child to complete before termination + */ +void SpawnRPCChild(SOCKET fd, std::chrono::seconds timeout); +/*! + * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket + * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent + */ +void ChildProcSocketHandler(const std::string& mmap_path); +} // namespace runtime +} // namespace tvm +#endif // TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ \ No newline at end of file diff --git a/src/common/ring_buffer.h b/src/common/ring_buffer.h index f548acf1846b..1ce4a88a83b3 100644 --- a/src/common/ring_buffer.h +++ b/src/common/ring_buffer.h @@ -63,7 +63,7 @@ class RingBuffer { size_t ncopy = head_ptr_ + bytes_available_ - old_size; memcpy(&ring_[0] + old_size, &ring_[0], ncopy); } - } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) { + } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity && bytes_available_ > 0) { // shrink too large temporary buffer to avoid out of memory on some embedded devices size_t old_bytes = bytes_available_; diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index cb59e723251b..2b7edbd8d45d 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -34,8 +34,13 @@ class SockChannel final : public RPCChannel { explicit SockChannel(common::TCPSocket sock) : sock_(sock) {} ~SockChannel() { - if (!sock_.BadSocket()) { - sock_.Close(); + try { + // BadSocket can throw + if (!sock_.BadSocket()) { + sock_.Close(); + } + } catch (...) 
{ + } } size_t Send(const void* data, size_t size) final { @@ -100,7 +105,8 @@ Module RPCClientConnect(std::string url, int port, std::string key) { return CreateRPCModule(RPCConnect(url, port, "client:" + key)); } -void RPCServerLoop(int sockfd) { +// TVM_DLL needed for MSVC +TVM_DLL void RPCServerLoop(int sockfd) { common::TCPSocket sock( static_cast(sockfd)); RPCSession::Create( From 9846d2c0d2480c77a5a2691fe4122757e0f248ff Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Mon, 2 Dec 2019 17:38:43 -0800 Subject: [PATCH 09/33] Fix upstream compilation error on MSVC --- include/tvm/runtime/container.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 2714ac237131..d6bbc419fc23 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -198,6 +198,7 @@ class ADTObj : public Object, public InplaceArrayBase { } friend class ADT; + template friend class InplaceArrayBase; }; From 029b5ce670a17a6242407910ad607881a28547b1 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Thu, 5 Dec 2019 16:09:31 -0800 Subject: [PATCH 10/33] XGBoostCostModel crash if num_threads==None --- python/tvm/autotvm/tuner/xgboost_cost_model.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index b46fd9a7230f..1001b342b98e 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -169,14 +169,10 @@ def _reset_pool(self, space, target, task): # To ensure each process in the pool is properly set, we have to do # some synchronization by sending an async call and waiting for # the queue to have an item set - - #pool_size = min(32, int(self.num_threads)) - pool_size = self.num_threads + pool_size = self.num_threads if self.num_threads != None else multiprocessing.cpu_count() if self.pool == None: self.pool = ProcessPool(pool_size) - manager = pathos_multiprocess.Manager() - pipe_syncs = [] # A simple pathos.map would be cleaner, but it seems that in some cases, From 3d4ed58f1239757a50ae1f6700c92f0dc6f477ec Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Thu, 5 Dec 2019 16:11:08 -0800 Subject: [PATCH 11/33] CXX RPC Server fix windows only defs --- apps/cpp_rpc/CMakeLists.txt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/apps/cpp_rpc/CMakeLists.txt b/apps/cpp_rpc/CMakeLists.txt index 61c40c1affe6..98887381c068 100644 --- a/apps/cpp_rpc/CMakeLists.txt +++ b/apps/cpp_rpc/CMakeLists.txt @@ -12,7 +12,11 @@ endif() set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) add_executable(tvm_rpc ${TVM_RPC_SOURCES}) set_property(TARGET tvm_rpc PROPERTY INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE) -target_compile_definitions(tvm_rpc PUBLIC -DNOMINMAX) + +if(WIN32) + target_compile_definitions(tvm_rpc PUBLIC -DNOMINMAX) +endif() + target_include_directories( tvm_rpc PUBLIC "../../include" From 7ed37458c50a86212caccc8affddede865ccfed9 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Thu, 5 Dec 2019 16:12:04 -0800 Subject: [PATCH 12/33] Removed unneeded SetThreadPriority for Win32 --- apps/cpp_rpc/win32_process.cc | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/apps/cpp_rpc/win32_process.cc b/apps/cpp_rpc/win32_process.cc index 4af222a4a8dd..c6c72d79ab81 100644 --- a/apps/cpp_rpc/win32_process.cc +++ b/apps/cpp_rpc/win32_process.cc @@ -256,24 +256,18 @@ void SpawnRPCChild(SOCKET fd, seconds 
timeout) { */ void ChildProcSocketHandler(const std::string& mmap_path) { SOCKET socket; - const auto last_thread_priority = GetThreadPriority(GetCurrentThread()); // Set high thread priority to avoid the thread scheduler from // interfering with any measurements in the RPC server. SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL); - - try { - if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) { - tvm::runtime::ServerLoopFromChild(socket); - } - else { - LOG(FATAL) << "GetSocket() failed"; - } - } catch (...) { - // Restore thread priority - SetThreadPriority(GetCurrentThread(), last_thread_priority); - throw; + + if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) { + tvm::runtime::ServerLoopFromChild(socket); + } + else { + LOG(FATAL) << "GetSocket() failed"; } + } } // namespace runtime } // namespace tvm \ No newline at end of file From ce4111cecd5fe70082ce100b1d906f0211df2500 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Thu, 5 Dec 2019 16:13:20 -0800 Subject: [PATCH 13/33] Changed windows clang compile command. Removed unneeded printf call --- apps/cpp_rpc/rpc_env.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index c4f77d3ffdfe..edd166f340ac 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -182,7 +182,7 @@ void WindowsShared(const std::string& output, const std::string& options = "", const std::string& cc = "clang") { std::string cmd = cc; - cmd += " -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared "; + cmd += " -O2 -flto=full -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared "; cmd += " -o " + output; for (const auto& file : files) { cmd += " " + file; @@ -191,7 +191,6 @@ void WindowsShared(const std::string& output, std::string err_msg; const auto executed_status = common::Execute(cmd, &err_msg); if (executed_status) { - printf("compile error: %s\n", err_msg.c_str()); LOG(FATAL) << err_msg; } } From 557270822a8d5967a03483620812b4efcd35e5ff Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Tue, 17 Dec 2019 13:48:34 -0800 Subject: [PATCH 14/33] Turned off CUDA git cache in c++ rpc server. 
Load modules that end in .dll in c++ rpc server --- apps/cpp_rpc/main.cc | 3 +++ apps/cpp_rpc/rpc_env.cc | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index f37cf56d39d0..dc4717f9f8d9 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -292,6 +292,9 @@ int main(int argc, char * argv[]) { // Runs WSAStartup on Win32, no-op on POSIX Socket::Startup(); +#if defined(_WIN32) + SetEnvironmentVariableA("CUDA_CACHE_DISABLE", "1"); +#endif if (0 == strcmp(argv[1], "server")) { return RpcServer(argc, argv); diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index edd166f340ac..fb761a1911e7 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -223,7 +223,7 @@ void CreateShared(const std::string& output, const std::vector& fil */ Module Load(std::string *fileIn, const std::string& fmt) { const std::string& file = *fileIn; - if (common::EndsWith(file, ".so")) { + if (common::EndsWith(file, ".so") || common::EndsWith(file, ".dll")) { return Module::LoadFromFile(file, fmt); } From 9b7278981ba4ae713c00130be8da45bdbc100996 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Tue, 17 Dec 2019 14:10:10 -0800 Subject: [PATCH 15/33] Fix superfluous formatting --- apps/cpp_rpc/rpc_server.cc | 3 +- .../tvm/autotvm/tuner/xgboost_cost_model.py | 2 - python/tvm/rpc/server.py | 1068 ++++++++--------- 3 files changed, 536 insertions(+), 537 deletions(-) diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index f586b8f1faf6..7f05e42ae721 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -354,7 +354,8 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track rpc.Start(); } -TVM_REGISTER_GLOBAL("rpc._ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("rpc._ServerCreate") +.set_body([](TVMArgs args, TVMRetValue* rv) { RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); }); } // namespace runtime diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index 1001b342b98e..4f825306d970 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -249,7 +249,6 @@ def fit(self, xs, ys, plan_size): self.base_model = None else: dtrain.set_base_margin(discount * self.base_model.predict(xs, output_margin=True)) - self.bst = xgb.train(self.xgb_params, dtrain, num_boost_round=8000, callbacks=[custom_callback( @@ -536,7 +535,6 @@ def callback(env): res = [x.split(':') for x in bst_eval.split()] for kv in res[1:]: res_dict[kv[0]] = [float(kv[1])] - eval_res = [] keys = list(res_dict.keys()) keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x) diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 1986a8cbae59..52fc4c520290 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -1,534 +1,534 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""RPC server implementation. - -Note ----- -Server is TCP based with the following protocol: -- Initial handshake to the peer - - [RPC_MAGIC, keysize(int32), key-bytes] -- The key is in format - - {server|client}:device-type[:random-key] [-timeout=timeout] -""" -# pylint: disable=invalid-name - -from __future__ import absolute_import - -import os -import ctypes -import socket -import select -import struct -import logging -import multiprocessing -import subprocess -import time -import sys -import signal - -if os.name == 'nt': - from pathos.helpers import ProcessPool - import threading - import multiprocessing.pool - -from .._ffi.function import register_func -from .._ffi.base import py_str -from .._ffi.libinfo import find_lib_path -from ..module import load as _load_module -from ..contrib import util -from . import base -from . base import TrackerCode - -logger = logging.getLogger('RPCServer') - -_temp = None - -class NoDaemonProcess(multiprocessing.Process): - # make 'daemon' attribute always return False - def _get_daemon(self): - return False - def _set_daemon(self, value): - pass - daemon = property(_get_daemon, _set_daemon) - -# We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool -# because the latter is only a wrapper function, not a proper class. -class MyPool(multiprocessing.pool.Pool): - Process = NoDaemonProcess - -# pylint: disable=unused-variable -@register_func("tvm.rpc.server.workpath", override=True) -def get_workpath(path): - global _temp - return _temp.relpath(path) - -@register_func("tvm.rpc.server.load_module", override=True) -def load_module(file_name): - """Load module from remote side.""" - global _temp - path = _temp.relpath(file_name) - m = _load_module(path) - logger.info("load_module %s", path) - return m - -def _server_env(load_library, work_path=None): - """Server environment function return temp dir""" - global _temp - if work_path: - temp = work_path - else: - temp = util.tempdir() - - _temp = temp - libs = [] - load_library = load_library.split(":") if load_library else [] - for file_name in load_library: - file_name = find_lib_path(file_name)[0] - libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) - logger.info("Load additional library %s", file_name) - temp.libs = libs - return temp - -def _serve_loop(sock, addr, load_library, work_path=None): - """Server loop""" - try: - sockfd = sock.fileno() - temp = _server_env(load_library, work_path) - base._ServerLoop(sockfd) - if not work_path: - temp.remove() - except Exception as ex: - print(ex) - pass - - logger.info("Finish serving %s", addr) - -def _serve_loop_pool(args): - """Server loop""" - sock = args["sock"] - addr = args["addr"] - load_library = args["load_library"] - work_path = args["work_path"] - - try: - sockfd = sock.fileno() - temp = _server_env(load_library, work_path) - base._ServerLoop(sockfd) - if not work_path: - temp.remove() - except Exception as ex: - print(ex) - pass - - logger.info("Finish serving %s", addr) - -def _parse_server_opt(opts): - # parse client options - ret = {} - for kv in opts: - if kv.startswith("-timeout="): - ret["timeout"] = float(kv[9:]) 
- return ret - -trial_counter = 0 - -def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): - global trial_counter - executerPool = MyPool(processes=1) - - """Listening loop of the server master.""" - def _accept_conn(listen_sock, tracker_conn, ping_period=2): - """Accept connection from the other places. - - Parameters - ---------- - listen_sock: Socket - The socket used by listening process. - - tracker_conn : connnection to tracker - Tracker connection - - ping_period : float, optional - ping tracker every k seconds if no connection is accepted. - """ - old_keyset = set() - # Report resource to tracker - if tracker_conn: - matchkey = base.random_key(rpc_key + ":") - base.sendjson(tracker_conn, - [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr]) - assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS - else: - matchkey = rpc_key - - unmatch_period_count = 0 - unmatch_timeout = 4 - # Wait until we get a valid connection - while True: - if tracker_conn: - trigger = select.select([listen_sock], [], [], ping_period) - if not listen_sock in trigger[0]: - base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS]) - pending_keys = base.recvjson(tracker_conn) - old_keyset.add(matchkey) - # if match key not in pending key set - # it means the key is acquired by a client but not used. - if matchkey not in pending_keys: - unmatch_period_count += 1 - else: - unmatch_period_count = 0 - # regenerate match key if key is acquired but not used for a while - if unmatch_period_count * ping_period > unmatch_timeout + ping_period: - logger.info("no incoming connections, regenerate key ...") - matchkey = base.random_key(rpc_key + ":", old_keyset) - base.sendjson(tracker_conn, - [TrackerCode.PUT, rpc_key, (port, matchkey), - custom_addr]) - assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS - unmatch_period_count = 0 - continue - conn, addr = listen_sock.accept() - magic = struct.unpack(" max_retry: - raise RuntimeError("Maximum retry error: last error: %s" % str(err)) - time.sleep(retry_period) - -def _popen(cmd): - proc = subprocess.Popen(cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - env=os.environ) - (out, _) = proc.communicate() - if proc.returncode != 0: - msg = "Server invoke error:\n" - msg += out - raise RuntimeError(msg) - -if os.name == 'nt': - def start_server_from_pool(host, port, port_end, is_proxy, use_popen, - tracker_addr, key, load_library, custom_addr, silent): - def run(): - server = Server(host, - port, - port_end, - key=key, - tracker_addr=tracker_addr, - load_library=load_library, - custom_addr=custom_addr, - silent=silent) - t = threading.Thread(target=run) - t.daemon = True - t.start() - -class Server(object): - """Start RPC server on a separate process. - - This is a simple python implementation based on multi-processing. - It is also possible to implement a similar C based server with - TVM runtime which does not depend on the python. - - Parameters - ---------- - host : str - The host url of the server. - - port : int - The port to be bind to - - port_end : int, optional - The end port to search - - is_proxy : bool, optional - Whether the address specified is a proxy. - If this is true, the host and port actually corresponds to the - address of the proxy server. - - use_popen : bool, optional - Whether to use Popen to start a fresh new process instead of fork. - This is recommended to switch on if we want to do local RPC demonstration - for GPU devices to avoid fork safety issues. 
- - tracker_addr: Tuple (str, int) , optional - The address of RPC Tracker in tuple(host, ip) format. - If is not None, the server will register itself to the tracker. - - key : str, optional - The key used to identify the device type in tracker. - - load_library : str, optional - List of additional libraries to be loaded during execution. - - custom_addr: str, optional - Custom IP Address to Report to RPC Tracker - - silent: bool, optional - Whether run this server in silent mode. - """ - def __init__(self, - host, - port=9091, - port_end=9199, - is_proxy=False, - use_popen=False, - tracker_addr=None, - key="", - load_library=None, - custom_addr=None, - silent=False): - try: - if base._ServerLoop is None: - raise RuntimeError("Please compile with USE_RPC=1") - except NameError: - raise RuntimeError("Please compile with USE_RPC=1") - self.host = host - self.port = port - self.libs = [] - self.custom_addr = custom_addr - self.use_popen = use_popen - self.proc = None - - if silent: - logger.setLevel(logging.ERROR) - - if use_popen: - cmd = [sys.executable, - "-m", "tvm.exec.rpc_server", - "--host=%s" % host, - "--port=%s" % port] - if tracker_addr: - assert key - cmd += ["--tracker=%s:%d" % tracker_addr, - "--key=%s" % key] - if load_library: - cmd += ["--load-library", load_library] - if custom_addr: - cmd += ["--custom-addr", custom_addr] - if silent: - cmd += ["--silent"] - - if os.name == 'nt': - self.proc = ProcessPool(1) - self.proc.apply(start_server_from_pool, args=(host, port, port_end, is_proxy, - use_popen, tracker_addr, key, load_library, custom_addr, silent)) - else: - # prexec_fn is not thread safe and may result in deadlock. - # python 3.2 introduced the start_new_session parameter as - # an alternative to the common use case of - # prexec_fn=os.setsid. Once the minimum version of python - # supported by TVM reaches python 3.2 this code can be - # rewritten in favour of start_new_session. In the - # interim, stop the pylint diagnostic. 
- # - # pylint: disable=subprocess-popen-preexec-fn - self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid) - time.sleep(0.5) - elif not is_proxy: - sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) - self.port = None - for my_port in range(port, port_end): - try: - sock.bind((host, my_port)) - self.port = my_port - break - except socket.error as sock_err: - sock_errno = sock_err.errno - if os.name == 'nt': - # Win32 socket codes are offset 10000 - sock_errno -= 10000 - if sock_errno in [98, 48]: - continue - else: - raise sock_err - if not self.port: - raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) - logger.info("bind to %s:%d", host, self.port) - sock.listen(1) - self.sock = sock - - if os.name == 'nt': - _listen_loop(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr) - else: - self.proc = multiprocessing.Process( - target=_listen_loop, args=( - self.sock, self.port, key, tracker_addr, load_library, - self.custom_addr)) - self.proc.deamon = True - self.proc.start() - else: - self.proc = multiprocessing.Process( - target=_connect_proxy_loop, args=((host, port), key, load_library)) - self.proc.deamon = True - self.proc.start() - - def terminate(self): - """Terminate the server process""" - if self.use_popen: - if self.proc: - if os.name == 'nt': - self.proc.terminate() - else: - os.killpg(self.proc.pid, signal.SIGTERM) - self.proc = None - else: - if self.proc: - self.proc.terminate() - self.proc = None - - def __del__(self): - self.terminate() +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""RPC server implementation. + +Note +---- +Server is TCP based with the following protocol: +- Initial handshake to the peer + - [RPC_MAGIC, keysize(int32), key-bytes] +- The key is in format + - {server|client}:device-type[:random-key] [-timeout=timeout] +""" +# pylint: disable=invalid-name + +from __future__ import absolute_import + +import os +import ctypes +import socket +import select +import struct +import logging +import multiprocessing +import subprocess +import time +import sys +import signal + +if os.name == 'nt': + from pathos.helpers import ProcessPool + import threading + import multiprocessing.pool + +from .._ffi.function import register_func +from .._ffi.base import py_str +from .._ffi.libinfo import find_lib_path +from ..module import load as _load_module +from ..contrib import util +from . import base +from . 
base import TrackerCode + +logger = logging.getLogger('RPCServer') + +_temp = None + +class NoDaemonProcess(multiprocessing.Process): + # make 'daemon' attribute always return False + def _get_daemon(self): + return False + def _set_daemon(self, value): + pass + daemon = property(_get_daemon, _set_daemon) + +# We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool +# because the latter is only a wrapper function, not a proper class. +class MyPool(multiprocessing.pool.Pool): + Process = NoDaemonProcess + +# pylint: disable=unused-variable +@register_func("tvm.rpc.server.workpath", override=True) +def get_workpath(path): + global _temp + return _temp.relpath(path) + +@register_func("tvm.rpc.server.load_module", override=True) +def load_module(file_name): + """Load module from remote side.""" + global _temp + path = _temp.relpath(file_name) + m = _load_module(path) + logger.info("load_module %s", path) + return m + +def _server_env(load_library, work_path=None): + """Server environment function return temp dir""" + global _temp + if work_path: + temp = work_path + else: + temp = util.tempdir() + + _temp = temp + libs = [] + load_library = load_library.split(":") if load_library else [] + for file_name in load_library: + file_name = find_lib_path(file_name)[0] + libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) + logger.info("Load additional library %s", file_name) + temp.libs = libs + return temp + +def _serve_loop(sock, addr, load_library, work_path=None): + """Server loop""" + try: + sockfd = sock.fileno() + temp = _server_env(load_library, work_path) + base._ServerLoop(sockfd) + if not work_path: + temp.remove() + except Exception as ex: + print(ex) + pass + + logger.info("Finish serving %s", addr) + +def _serve_loop_pool(args): + """Server loop""" + sock = args["sock"] + addr = args["addr"] + load_library = args["load_library"] + work_path = args["work_path"] + + try: + sockfd = sock.fileno() + temp = _server_env(load_library, work_path) + base._ServerLoop(sockfd) + if not work_path: + temp.remove() + except Exception as ex: + print(ex) + pass + + logger.info("Finish serving %s", addr) + +def _parse_server_opt(opts): + # parse client options + ret = {} + for kv in opts: + if kv.startswith("-timeout="): + ret["timeout"] = float(kv[9:]) + return ret + +trial_counter = 0 + +def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): + global trial_counter + executerPool = MyPool(processes=1) + + """Listening loop of the server master.""" + def _accept_conn(listen_sock, tracker_conn, ping_period=2): + """Accept connection from the other places. + + Parameters + ---------- + listen_sock: Socket + The socket used by listening process. + + tracker_conn : connnection to tracker + Tracker connection + + ping_period : float, optional + ping tracker every k seconds if no connection is accepted. 
+ """ + old_keyset = set() + # Report resource to tracker + if tracker_conn: + matchkey = base.random_key(rpc_key + ":") + base.sendjson(tracker_conn, + [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr]) + assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS + else: + matchkey = rpc_key + + unmatch_period_count = 0 + unmatch_timeout = 4 + # Wait until we get a valid connection + while True: + if tracker_conn: + trigger = select.select([listen_sock], [], [], ping_period) + if not listen_sock in trigger[0]: + base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS]) + pending_keys = base.recvjson(tracker_conn) + old_keyset.add(matchkey) + # if match key not in pending key set + # it means the key is acquired by a client but not used. + if matchkey not in pending_keys: + unmatch_period_count += 1 + else: + unmatch_period_count = 0 + # regenerate match key if key is acquired but not used for a while + if unmatch_period_count * ping_period > unmatch_timeout + ping_period: + logger.info("no incoming connections, regenerate key ...") + matchkey = base.random_key(rpc_key + ":", old_keyset) + base.sendjson(tracker_conn, + [TrackerCode.PUT, rpc_key, (port, matchkey), + custom_addr]) + assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS + unmatch_period_count = 0 + continue + conn, addr = listen_sock.accept() + magic = struct.unpack(" max_retry: + raise RuntimeError("Maximum retry error: last error: %s" % str(err)) + time.sleep(retry_period) + +def _popen(cmd): + proc = subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=os.environ) + (out, _) = proc.communicate() + if proc.returncode != 0: + msg = "Server invoke error:\n" + msg += out + raise RuntimeError(msg) + +if os.name == 'nt': + def start_server_from_pool(host, port, port_end, is_proxy, use_popen, + tracker_addr, key, load_library, custom_addr, silent): + def run(): + server = Server(host, + port, + port_end, + key=key, + tracker_addr=tracker_addr, + load_library=load_library, + custom_addr=custom_addr, + silent=silent) + t = threading.Thread(target=run) + t.daemon = True + t.start() + +class Server(object): + """Start RPC server on a separate process. + + This is a simple python implementation based on multi-processing. + It is also possible to implement a similar C based server with + TVM runtime which does not depend on the python. + + Parameters + ---------- + host : str + The host url of the server. + + port : int + The port to be bind to + + port_end : int, optional + The end port to search + + is_proxy : bool, optional + Whether the address specified is a proxy. + If this is true, the host and port actually corresponds to the + address of the proxy server. + + use_popen : bool, optional + Whether to use Popen to start a fresh new process instead of fork. + This is recommended to switch on if we want to do local RPC demonstration + for GPU devices to avoid fork safety issues. + + tracker_addr: Tuple (str, int) , optional + The address of RPC Tracker in tuple(host, ip) format. + If is not None, the server will register itself to the tracker. + + key : str, optional + The key used to identify the device type in tracker. + + load_library : str, optional + List of additional libraries to be loaded during execution. + + custom_addr: str, optional + Custom IP Address to Report to RPC Tracker + + silent: bool, optional + Whether run this server in silent mode. 
+ """ + def __init__(self, + host, + port=9091, + port_end=9199, + is_proxy=False, + use_popen=False, + tracker_addr=None, + key="", + load_library=None, + custom_addr=None, + silent=False): + try: + if base._ServerLoop is None: + raise RuntimeError("Please compile with USE_RPC=1") + except NameError: + raise RuntimeError("Please compile with USE_RPC=1") + self.host = host + self.port = port + self.libs = [] + self.custom_addr = custom_addr + self.use_popen = use_popen + self.proc = None + + if silent: + logger.setLevel(logging.ERROR) + + if use_popen: + cmd = [sys.executable, + "-m", "tvm.exec.rpc_server", + "--host=%s" % host, + "--port=%s" % port] + if tracker_addr: + assert key + cmd += ["--tracker=%s:%d" % tracker_addr, + "--key=%s" % key] + if load_library: + cmd += ["--load-library", load_library] + if custom_addr: + cmd += ["--custom-addr", custom_addr] + if silent: + cmd += ["--silent"] + + if os.name == 'nt': + self.proc = ProcessPool(1) + self.proc.apply(start_server_from_pool, args=(host, port, port_end, is_proxy, + use_popen, tracker_addr, key, load_library, custom_addr, silent)) + else: + # prexec_fn is not thread safe and may result in deadlock. + # python 3.2 introduced the start_new_session parameter as + # an alternative to the common use case of + # prexec_fn=os.setsid. Once the minimum version of python + # supported by TVM reaches python 3.2 this code can be + # rewritten in favour of start_new_session. In the + # interim, stop the pylint diagnostic. + # + # pylint: disable=subprocess-popen-preexec-fn + self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid) + time.sleep(0.5) + elif not is_proxy: + sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) + self.port = None + for my_port in range(port, port_end): + try: + sock.bind((host, my_port)) + self.port = my_port + break + except socket.error as sock_err: + sock_errno = sock_err.errno + if os.name == 'nt': + # Win32 socket codes are offset 10000 + sock_errno -= 10000 + if sock_errno in [98, 48]: + continue + else: + raise sock_err + if not self.port: + raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) + logger.info("bind to %s:%d", host, self.port) + sock.listen(1) + self.sock = sock + + if os.name == 'nt': + _listen_loop(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr) + else: + self.proc = multiprocessing.Process( + target=_listen_loop, args=( + self.sock, self.port, key, tracker_addr, load_library, + self.custom_addr)) + self.proc.deamon = True + self.proc.start() + else: + self.proc = multiprocessing.Process( + target=_connect_proxy_loop, args=((host, port), key, load_library)) + self.proc.deamon = True + self.proc.start() + + def terminate(self): + """Terminate the server process""" + if self.use_popen: + if self.proc: + if os.name == 'nt': + self.proc.terminate() + else: + os.killpg(self.proc.pid, signal.SIGTERM) + self.proc = None + else: + if self.proc: + self.proc.terminate() + self.proc = None + + def __del__(self): + self.terminate() From c31baf210fda5f17ca18c7f6ce56f79a6ff361c2 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Tue, 17 Dec 2019 15:33:04 -0800 Subject: [PATCH 16/33] Cleanup server.py --- python/tvm/rpc/server.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 52fc4c520290..ef7db3bad654 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -144,12 +144,12 @@ def _parse_server_opt(opts): 
ret["timeout"] = float(kv[9:]) return ret -trial_counter = 0 +# For Windows +_trial_counter = 0 +# For Windows +_executor_pool = None def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): - global trial_counter - executerPool = MyPool(processes=1) - """Listening loop of the server master.""" def _accept_conn(listen_sock, tracker_conn, ping_period=2): """Accept connection from the other places. @@ -278,10 +278,15 @@ def handle_posix(): # terminate the worker server_proc.terminate() def handle_win32(): - global trial_counter - nonlocal executerPool + global _trial_counter + global _executor_pool + + if _trial_counter % 5 == 0 or _executor_pool == None: + if _executor_pool != None: + _executor_pool.terminate() + _executor_pool = MyPool(processes=1) - trial_counter += 1 + _trial_counter += 1 args = { "sock" : conn, @@ -290,10 +295,7 @@ def handle_win32(): "work_path" : work_path } - executerPool.map(_serve_loop_pool, [args]) - if trial_counter % 1 == 0: - executerPool.terminate() - executerPool = MyPool(processes=1) + _executor_pool.map(_serve_loop_pool, [args]) try: conn.close() From 7cb0e33dc0d7a6453fc3d7d59fac26246f43c836 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Tue, 17 Dec 2019 15:33:37 -0800 Subject: [PATCH 17/33] Undo some formatting --- apps/cpp_rpc/rpc_server.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 7f05e42ae721..f018680901ef 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -356,7 +356,7 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track TVM_REGISTER_GLOBAL("rpc._ServerCreate") .set_body([](TVMArgs args, TVMRetValue* rv) { - RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); -}); + RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); + }); } // namespace runtime } // namespace tvm From e2d1e8fed992c9520e3f0666a16f96c26bd92986 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Tue, 17 Dec 2019 15:34:27 -0800 Subject: [PATCH 18/33] Update comments and sleep time in xgboost_cost_model.py --- python/tvm/autotvm/tuner/xgboost_cost_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index 4f825306d970..b9ae1764382e 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -167,8 +167,8 @@ def _reset_pool(self, space, target, task): # Pool's process side, where the *nix impl simply sets globals # then forks. # To ensure each process in the pool is properly set, we have to do - # some synchronization by sending an async call and waiting for - # the queue to have an item set + # some synchronization by sending an async call, setting space, target and task, then + # waiting for the queue to have an item set pool_size = self.num_threads if self.num_threads != None else multiprocessing.cpu_count() if self.pool == None: self.pool = ProcessPool(pool_size) @@ -178,7 +178,7 @@ def _reset_pool(self, space, target, task): # A simple pathos.map would be cleaner, but it seems that in some cases, # some of the pools processes will be missed, with some processes running # the method twice. It seems that just passing a Queue in this manner, - # hits all the processes in the pool. Some assertion should be built to verify + # hits all the processes in the pool. 
Some assertion could be built to verify for i in range(pool_size): queue = manager.Queue(1) results = { @@ -197,7 +197,7 @@ def _reset_pool(self, space, target, task): if all_ready: break; else: - time.sleep(0.1) + time.sleep(0.05) # complete the async requests on the pool for pipe_sync in pipe_syncs: pipe_sync["apipe"].get() From 514beac36ffd4f479ae99c4e8cc1f1b1412cfa98 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Tue, 17 Dec 2019 21:39:50 -0800 Subject: [PATCH 19/33] Removed unneeded print(...) and exception handling from server.py --- python/tvm/rpc/server.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index ef7db3bad654..d6c12483f9e9 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -105,15 +105,11 @@ def _server_env(load_library, work_path=None): def _serve_loop(sock, addr, load_library, work_path=None): """Server loop""" - try: - sockfd = sock.fileno() - temp = _server_env(load_library, work_path) - base._ServerLoop(sockfd) - if not work_path: - temp.remove() - except Exception as ex: - print(ex) - pass + sockfd = sock.fileno() + temp = _server_env(load_library, work_path) + base._ServerLoop(sockfd) + if not work_path: + temp.remove() logger.info("Finish serving %s", addr) @@ -130,8 +126,7 @@ def _serve_loop_pool(args): base._ServerLoop(sockfd) if not work_path: temp.remove() - except Exception as ex: - print(ex) + except Exception: pass logger.info("Finish serving %s", addr) From 98a56f078a28d984f051497871def4a2fed7e38a Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Fri, 20 Dec 2019 13:24:34 -0800 Subject: [PATCH 20/33] Removed redundant blank line at the start of a code block --- src/runtime/rpc/rpc_socket_impl.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 2b7edbd8d45d..74e61fe5af20 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -40,7 +40,6 @@ class SockChannel final : public RPCChannel { sock_.Close(); } } catch (...) 
{ - } } size_t Send(const void* data, size_t size) final { From 598b27afe102cfa938ef01ce459e1ce2067dde73 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Fri, 20 Dec 2019 14:48:02 -0800 Subject: [PATCH 21/33] Fix some python linter issues --- python/tvm/autotvm/measure/local_executor.py | 15 ++++++----- python/tvm/autotvm/task/task.py | 4 +-- .../tvm/autotvm/tuner/xgboost_cost_model.py | 17 ++++++------ python/tvm/exec/rpc_tracker.py | 2 +- python/tvm/rpc/base.py | 6 ++--- python/tvm/rpc/server.py | 27 +++++++++++-------- python/tvm/rpc/tracker.py | 2 +- 7 files changed, 40 insertions(+), 33 deletions(-) diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py index 6784a7b1e083..88453406e265 100644 --- a/python/tvm/autotvm/measure/local_executor.py +++ b/python/tvm/autotvm/measure/local_executor.py @@ -36,7 +36,7 @@ # Since there is no fork() on Windows, to mitigate performance impact # we will use a process pool for executers, vs the *nix based systems # that will fork() a new process for each executor - executor_pool = None + _executor_pool = None from multiprocessing import Process, Queue, cpu_count try: @@ -89,7 +89,7 @@ def call_with_timeout(queue, timeout, func, args, kwargs): p.join() if os.name == 'nt': - def call_from_pool(func, args, kwargs, timeout, env): + def call_from_pool(func, args, kwargs, timeout, env): # pylint: disable=unused-argument """A wrapper to support timeout of a function call for a pool process""" # Restore environment variables from parent @@ -213,20 +213,21 @@ def submit(self, func, *args, **kwargs): if os.name != 'nt': queue = Queue(2) process = Process(target=call_with_timeout, - args=(queue, self.timeout, func, args, kwargs)) + args=(queue, self.timeout, func, args, kwargs)) process.start() return LocalFuture(process, queue) else: - global executor_pool + global _executor_pool - if executor_pool is None: + if _executor_pool is None: # We use a static pool for executor processes because Process.start(entry) # is so slow on Windows, we lose a lot of parallelism. # Right now cpu_count() is used, which isn't optimal from a user configuration # perspective, but is reasonable at this time. - executor_pool = ProcessPool(cpu_count() * 2) + _executor_pool = ProcessPool(cpu_count() * 2) # Windows seemed to be missing some valuable environ variables # on the pool's process side. We might be able to get away with # just sending the PATH variable, but for now, we just clone our env - return LocalFuturePool(executor_pool.apply_async(call_from_pool, (func, args, kwargs, self.timeout, os.environ.copy()))) + return LocalFuturePool(_executor_pool.apply_async(call_from_pool, + (func, args, kwargs, self.timeout, os.environ.copy()))) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 6ab371b5d653..0a73590362f7 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -22,8 +22,8 @@ registers the standard task. """ -import numpy as np import os +import numpy as np from ... 
import tensor, expr, container, target as _target from ..util import get_const_int, get_const_tuple, get_func_name @@ -195,7 +195,7 @@ def create(func_name, args, target, target_host=None, template_key=None): try: # getattr will throw here on Windows, as of an Oct 2019 commit ret.config_space.code_hash = getattr(sch, 'code_hash', None) - except: + except: # pylint: disable=bare-except ret.config_space.code_hash = None ret.workload = ctx.workload diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index b9ae1764382e..a8824020aa11 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -28,7 +28,6 @@ # support fork() from pathos.helpers import mp as pathos_multiprocess from pathos.helpers import ProcessPool - import pathos.multiprocessing import numpy as np try: @@ -169,8 +168,9 @@ def _reset_pool(self, space, target, task): # To ensure each process in the pool is properly set, we have to do # some synchronization by sending an async call, setting space, target and task, then # waiting for the queue to have an item set - pool_size = self.num_threads if self.num_threads != None else multiprocessing.cpu_count() - if self.pool == None: + num_threads = self.num_threads + pool_size = num_threads if num_threads != None else multiprocessing.cpu_count() + if self.pool is None: self.pool = ProcessPool(pool_size) manager = pathos_multiprocess.Manager() pipe_syncs = [] @@ -179,11 +179,12 @@ def _reset_pool(self, space, target, task): # some of the pools processes will be missed, with some processes running # the method twice. It seems that just passing a Queue in this manner, # hits all the processes in the pool. Some assertion could be built to verify - for i in range(pool_size): + for _ in range(pool_size): queue = manager.Queue(1) results = { - "queue": queue, - "apipe": self.pool.apply_async(_set_pool_process_state, args=(space, target, task, queue)) + "queue": queue, + "apipe": self.pool.apply_async(_set_pool_process_state, + args=(space, target, task, queue)) } pipe_syncs.append(results) @@ -195,7 +196,7 @@ def _reset_pool(self, space, target, task): all_ready = False break if all_ready: - break; + break else: time.sleep(0.05) # complete the async requests on the pool @@ -214,7 +215,7 @@ def _reset_pool(self, space, target, task): def _close_pool(self, force_close=False): if os.name == 'nt' and not force_close: return - + if self.pool: self.pool.terminate() self.pool.join() diff --git a/python/tvm/exec/rpc_tracker.py b/python/tvm/exec/rpc_tracker.py index f57326834668..6a433b7c9d6b 100644 --- a/python/tvm/exec/rpc_tracker.py +++ b/python/tvm/exec/rpc_tracker.py @@ -28,7 +28,7 @@ def main(args): """Main funciton""" tracker = Tracker(args.host, port=args.port, port_end=args.port_end, silent=args.silent) - if os.name =='nt': + if os.name == 'nt': while True: input() else: diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py index 683b8fe21700..98861995a122 100644 --- a/python/tvm/rpc/base.py +++ b/python/tvm/rpc/base.py @@ -60,6 +60,7 @@ class TrackerCode(object): def get_addr_family(addr): + """Gets the address family""" if os.name == 'nt': # WINDOWS CANNOT USE THE *NIX IMPL OF THIS! FUNCTION SUCCEEDS AND WORKS # BUT IT CAUSES MAJOR PROBLEMS. 
IT LEAVES MYSTERIOUS REFERENCES THAT ARE @@ -69,9 +70,8 @@ def get_addr_family(addr): # This isn't a 1:1 of the *nix implementation, should probably # take a closer look as it probably doesn't work with IPV6 addresses return socket.AF_INET - else: - res = socket.getaddrinfo(addr[0], addr[1], 0, 0, socket.IPPROTO_TCP) - return res[0][0] + res = socket.getaddrinfo(addr[0], addr[1], 0, 0, socket.IPPROTO_TCP) + return res[0][0] def recvall(sock, nbytes): diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index d6c12483f9e9..2bfe6fe1ba05 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -117,7 +117,7 @@ def _serve_loop_pool(args): """Server loop""" sock = args["sock"] addr = args["addr"] - load_library = args["load_library"] + load_library = args["load_library"] work_path = args["work_path"] try: @@ -126,7 +126,7 @@ def _serve_loop_pool(args): base._ServerLoop(sockfd) if not work_path: temp.remove() - except Exception: + except Exception: # pylint: disable=broad-except pass logger.info("Finish serving %s", addr) @@ -242,7 +242,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2): tracker_conn.close() tracker_conn = None continue - except RuntimeError as exc: + except RuntimeError as exc: # pylint: disable=broad-except raise exc # step 3: serving @@ -250,8 +250,9 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2): logger.info("connection from %s", addr) def handle_posix(): + """Handles serving on non-Windows OS""" server_proc = multiprocessing.Process(target=_serve_loop, - args=(conn, addr, load_library, work_path)) + args=(conn, addr, load_library, work_path)) server_proc.deamon = True server_proc.start() # close from our side. @@ -267,15 +268,16 @@ def handle_posix(): # this can throw on Windows for child in parent.children(recursive=True): child.terminate() - except: # pylint: disable=broad-except + except Exception: # pylint: disable=broad-except pass # terminate the worker server_proc.terminate() def handle_win32(): + """Handles serving on Windows OS""" global _trial_counter global _executor_pool - + if _trial_counter % 5 == 0 or _executor_pool == None: if _executor_pool != None: _executor_pool.terminate() @@ -295,7 +297,7 @@ def handle_win32(): try: conn.close() conn.shutdown(1) - except: + except Exception: # pylint: disable=broad-except pass if os.name != 'nt': @@ -358,13 +360,16 @@ def _popen(cmd): raise RuntimeError(msg) if os.name == 'nt': - def start_server_from_pool(host, port, port_end, is_proxy, use_popen, - tracker_addr, key, load_library, custom_addr, silent): + def start_server_from_pool(host, port, port_end, is_proxy, use_popen, + tracker_addr, key, load_library, custom_addr, silent): + """Starts the RPC server from within a process pool""" def run(): server = Server(host, port, port_end, key=key, + is_proxy=is_proxy, + use_popen=use_popen, tracker_addr=tracker_addr, load_library=load_library, custom_addr=custom_addr, @@ -439,7 +444,7 @@ def __init__(self, self.custom_addr = custom_addr self.use_popen = use_popen self.proc = None - + if silent: logger.setLevel(logging.ERROR) @@ -462,7 +467,7 @@ def __init__(self, if os.name == 'nt': self.proc = ProcessPool(1) self.proc.apply(start_server_from_pool, args=(host, port, port_end, is_proxy, - use_popen, tracker_addr, key, load_library, custom_addr, silent)) + use_popen, tracker_addr, key, load_library, custom_addr, silent)) else: # prexec_fn is not thread safe and may result in deadlock. 
# python 3.2 introduced the start_new_session parameter as diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index d70bb6114c45..a2eb47b99039 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -444,7 +444,7 @@ def _stop_tracker(self): def terminate(self): """Terminate the server process""" if self.proc: - if os.name =='nt': + if os.name == 'nt': self.proc.close() self.proc.join() else: From da0c6102118e61ca4fb3ca90b3ef12e398c6f545 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Fri, 20 Dec 2019 15:42:43 -0800 Subject: [PATCH 22/33] Fixed more pylint warnings --- python/tvm/autotvm/measure/local_executor.py | 36 +++++++++--------- python/tvm/autotvm/task/task.py | 3 +- .../tvm/autotvm/tuner/xgboost_cost_model.py | 14 +++---- python/tvm/rpc/base.py | 2 +- python/tvm/rpc/server.py | 38 ++++++++++--------- python/tvm/rpc/tracker.py | 6 +-- 6 files changed, 52 insertions(+), 47 deletions(-) diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py index 88453406e265..0c675af83162 100644 --- a/python/tvm/autotvm/measure/local_executor.py +++ b/python/tvm/autotvm/measure/local_executor.py @@ -36,8 +36,9 @@ # Since there is no fork() on Windows, to mitigate performance impact # we will use a process pool for executers, vs the *nix based systems # that will fork() a new process for each executor - _executor_pool = None + EXECUTOR_POOL = None +#pylint: disable=wrong-import-position from multiprocessing import Process, Queue, cpu_count try: from queue import Empty @@ -50,7 +51,7 @@ psutil = None from . import executor - +#pylint: enable=wrong-import-position def kill_child_processes(parent_pid, sig=signal.SIGTERM): """kill all child processes recursively""" @@ -216,18 +217,19 @@ def submit(self, func, *args, **kwargs): args=(queue, self.timeout, func, args, kwargs)) process.start() return LocalFuture(process, queue) - else: - global _executor_pool - - if _executor_pool is None: - # We use a static pool for executor processes because Process.start(entry) - # is so slow on Windows, we lose a lot of parallelism. - # Right now cpu_count() is used, which isn't optimal from a user configuration - # perspective, but is reasonable at this time. - _executor_pool = ProcessPool(cpu_count() * 2) - - # Windows seemed to be missing some valuable environ variables - # on the pool's process side. We might be able to get away with - # just sending the PATH variable, but for now, we just clone our env - return LocalFuturePool(_executor_pool.apply_async(call_from_pool, - (func, args, kwargs, self.timeout, os.environ.copy()))) + + global EXECUTOR_POOL + + if EXECUTOR_POOL is None: + # We use a static pool for executor processes because Process.start(entry) + # is so slow on Windows, we lose a lot of parallelism. + # Right now cpu_count() is used, which isn't optimal from a user configuration + # perspective, but is reasonable at this time. + EXECUTOR_POOL = ProcessPool(cpu_count() * 2) + + # Windows seemed to be missing some valuable environ variables + # on the pool's process side. 
We might be able to get away with + # just sending the PATH variable, but for now, we just clone our env + return LocalFuturePool(EXECUTOR_POOL.apply_async(call_from_pool, + (func, args, kwargs, + self.timeout, os.environ.copy()))) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 0a73590362f7..2c88182d0315 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -108,7 +108,8 @@ def __setstate__(self, state): self.kwargs = state["kwargs"] self.config_space = state["config_space"] # Use pickled function on Windows - self.func = state["func"] if os.name == 'nt' else TASK_TABLE.get(state["name"], _raise_error) + self.func = state["func"] if os.name == 'nt' else \ + TASK_TABLE.get(state["name"], _raise_error) self.workload = state["workload"] self.flop = state["flop"] self.target = state["target"] diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index a8824020aa11..7c530dc362e9 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -28,7 +28,7 @@ # support fork() from pathos.helpers import mp as pathos_multiprocess from pathos.helpers import ProcessPool - +#pylint: disable=wrong-import-position import numpy as np try: import xgboost as xgb @@ -39,7 +39,7 @@ from ..util import get_rank from .metric import max_curve, recall_curve, cover_curve from .model_based_tuner import CostModel, FeatureCache - +#pylint: enable=wrong-import-position logger = logging.getLogger('autotvm') class XGBoostCostModel(CostModel): @@ -169,7 +169,7 @@ def _reset_pool(self, space, target, task): # some synchronization by sending an async call, setting space, target and task, then # waiting for the queue to have an item set num_threads = self.num_threads - pool_size = num_threads if num_threads != None else multiprocessing.cpu_count() + pool_size = num_threads if num_threads is not None else multiprocessing.cpu_count() if self.pool is None: self.pool = ProcessPool(pool_size) manager = pathos_multiprocess.Manager() @@ -192,13 +192,12 @@ def _reset_pool(self, space, target, task): while True: all_ready = True for pipe_sync in pipe_syncs: - if pipe_sync["apipe"].ready() == False: + if pipe_sync["apipe"].ready() is False: all_ready = False break if all_ready: break - else: - time.sleep(0.05) + time.sleep(0.05) # complete the async requests on the pool for pipe_sync in pipe_syncs: pipe_sync["apipe"].get() @@ -484,10 +483,11 @@ def _extract_curve_feature_log(arg): def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None, maximize=False, verbose_eval=True): """callback function for xgboost to support multiple custom evaluation functions""" + #pylint: disable=import-outside-toplevel from xgboost.core import EarlyStopException from xgboost.callback import _fmt_metric from xgboost.training import aggcv - + #pylint: enable=import-outside-toplevel state = {} metric_shortname = metric.split("-")[1] diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py index 98861995a122..8ced8161b1d2 100644 --- a/python/tvm/rpc/base.py +++ b/python/tvm/rpc/base.py @@ -64,7 +64,7 @@ def get_addr_family(addr): if os.name == 'nt': # WINDOWS CANNOT USE THE *NIX IMPL OF THIS! FUNCTION SUCCEEDS AND WORKS # BUT IT CAUSES MAJOR PROBLEMS. 
IT LEAVES MYSTERIOUS REFERENCES THAT ARE - # HELD AND THE RPCSESSION WOULD NOT BE IMMEDIATE RELEASED, CAUSING + # HELD AND THE RPCSESSION WOULD NOT BE IMMEDIATE RELEASED, CAUSING # TIMEOUTS WITH THE RPCSERVER BECAUSE THE SOCKET IN THE C++ DIDN'T LOSE ALL # OF ITS REFERENCES. # This isn't a 1:1 of the *nix implementation, should probably diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 2bfe6fe1ba05..1ced95819914 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -39,10 +39,12 @@ import time import sys import signal - +import psutil +#pylint: disable=wrong-import-position if os.name == 'nt': from pathos.helpers import ProcessPool import threading + #pylint: disable=ungrouped-imports import multiprocessing.pool from .._ffi.function import register_func @@ -52,7 +54,7 @@ from ..contrib import util from . import base from . base import TrackerCode - +#pylint: enable=wrong-import-position logger = logging.getLogger('RPCServer') _temp = None @@ -67,6 +69,7 @@ def _set_daemon(self, value): # We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool # because the latter is only a wrapper function, not a proper class. +# pylint: disable=W0223 class MyPool(multiprocessing.pool.Pool): Process = NoDaemonProcess @@ -129,7 +132,7 @@ def _serve_loop_pool(args): except Exception: # pylint: disable=broad-except pass - logger.info("Finish serving %s", addr) + logger.info("Finish serving %s", addr) def _parse_server_opt(opts): # parse client options @@ -211,11 +214,10 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2): conn.close() logger.warning("mismatch key from %s", addr) continue - else: - conn.sendall(struct.pack(" Date: Fri, 20 Dec 2019 15:51:54 -0800 Subject: [PATCH 23/33] pylint fixes --- python/tvm/autotvm/measure/local_executor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py index 0c675af83162..72f54c149297 100644 --- a/python/tvm/autotvm/measure/local_executor.py +++ b/python/tvm/autotvm/measure/local_executor.py @@ -39,6 +39,7 @@ EXECUTOR_POOL = None #pylint: disable=wrong-import-position +#pylint: disable=ungrouped-imports from multiprocessing import Process, Queue, cpu_count try: from queue import Empty @@ -51,6 +52,7 @@ psutil = None from . 
import executor +#pylint: enable=ungrouped-imports #pylint: enable=wrong-import-position def kill_child_processes(parent_pid, sig=signal.SIGTERM): From d8b0d6752549dd5f202808cf95562a31aa84f3b8 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Fri, 20 Dec 2019 16:02:53 -0800 Subject: [PATCH 24/33] Fixup CMakeLists.txt so it uses v3.2 on everything but Windows --- CMakeLists.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e6f58da35af9..6f1bba9517d6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,9 @@ -cmake_minimum_required(VERSION 3.9) +if(WIN32) + cmake_minimum_required(VERSION 3.9) +else() + cmake_minimum_required(VERSION 3.2) +endif() + project(tvm C CXX) # Utility functions From 2154ace31c25e274a9cb3bf8960b63db21fbdcd1 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Fri, 20 Dec 2019 16:15:24 -0800 Subject: [PATCH 25/33] Fix build error with linux in python rpc server --- python/tvm/rpc/server.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 1ced95819914..9b6f0fd056f7 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -59,19 +59,21 @@ _temp = None -class NoDaemonProcess(multiprocessing.Process): - # make 'daemon' attribute always return False - def _get_daemon(self): - return False - def _set_daemon(self, value): - pass - daemon = property(_get_daemon, _set_daemon) +if os.name == 'nt': + class NoDaemonProcess(multiprocessing.Process): + # make 'daemon' attribute always return False + def _get_daemon(self): + return False + def _set_daemon(self, value): + pass + daemon = property(_get_daemon, _set_daemon) -# We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool -# because the latter is only a wrapper function, not a proper class. -# pylint: disable=W0223 -class MyPool(multiprocessing.pool.Pool): - Process = NoDaemonProcess +if os.name == 'nt': + # We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool + # because the latter is only a wrapper function, not a proper class. 
+ # pylint: disable=W0223 + class MyPool(multiprocessing.pool.Pool): + Process = NoDaemonProcess # pylint: disable=unused-variable @register_func("tvm.rpc.server.workpath", override=True) From 0ab3ee9b8153112ab805a6f9a1c1cafe9acf3379 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Thu, 26 Dec 2019 09:35:24 -0800 Subject: [PATCH 26/33] Fixup CMakeLists.txt to remove dead nnvm proj setting --- CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 70445af89971..4efc34b03226 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -397,7 +397,6 @@ if(MSVC) set_property(TARGET tvm PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) set_property(TARGET tvm_topi PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) set_property(TARGET tvm_runtime PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) - set_property(TARGET nnvm_compiler PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS) target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS) endif() From 7c062b7b53cf02bfbd5970894a9ce0641dae29b9 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Thu, 26 Dec 2019 10:07:31 -0800 Subject: [PATCH 27/33] Removed unneeded exception handling in task.py, now that master fixed issue --- python/tvm/autotvm/task/task.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 2c88182d0315..87b96c167eef 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -193,11 +193,7 @@ def create(func_name, args, target, target_host=None, template_key=None): with ctx: with target: sch, _ = func(*args) - try: - # getattr will throw here on Windows, as of an Oct 2019 commit - ret.config_space.code_hash = getattr(sch, 'code_hash', None) - except: # pylint: disable=bare-except - ret.config_space.code_hash = None + ret.config_space.code_hash = getattr(sch, 'code_hash', None) ret.workload = ctx.workload ret.flop = ret.config_space.flop or compute_flop(sch) From 387ff125b2bb50b8635fc0890cd66b0b8051e9be Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Sat, 18 Jan 2020 13:09:46 -0800 Subject: [PATCH 28/33] Fixed changes in last merge in CPP server --- apps/cpp_rpc/rpc_server.cc | 38 +++++++------------------------------- 1 file changed, 7 insertions(+), 31 deletions(-) diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index c7fcbf340be4..ea4ab00c113b 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -32,10 +32,9 @@ #include #include -#include "../../src/common/socket.h" +#include "../../src/support/socket.h" #include "../../src/runtime/rpc/rpc_session.h" #include "../../src/runtime/rpc/rpc_socket_impl.h" -<<<<<<< HEAD #include "rpc_env.h" #include "rpc_server.h" #include "rpc_tracker_client.h" @@ -44,9 +43,6 @@ #endif using namespace std::chrono; -======= -#include "../../src/support/socket.h" ->>>>>>> upstream/master namespace tvm { namespace runtime { @@ -221,18 +217,12 @@ class RPCServer { * \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ -<<<<<<< HEAD - void AcceptConnection(TrackerClient* tracker, common::TCPSocket* conn_sock, - common::SockAddr* addr, std::string* opts, int ping_period = 2) { - std::set old_keyset; -======= - void AcceptConnection(TrackerClient* tracker, + void AcceptConnection(TrackerClient* tracker, support::TCPSocket* conn_sock, - support::SockAddr* addr, - std::string* opts, + support::SockAddr* addr, + std::string* 
opts, int ping_period = 2) { - std::set old_keyset; ->>>>>>> upstream/master + std::set old_keyset; std::string matchkey; // Report resource to tracker and get key @@ -296,21 +286,12 @@ class RPCServer { * \param sock The socket information * \param addr The socket address information */ -<<<<<<< HEAD - static void ServerLoopProc(common::TCPSocket sock, common::SockAddr addr) { + static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) { // Server loop const auto env = RPCEnv(); RPCServerLoop(int(sock.sockfd)); LOG(INFO) << "Finish serving " << addr.AsString(); env.CleanUp(); -======= - void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) { - // Server loop - auto env = RPCEnv(); - RPCServerLoop(sock.sockfd); - LOG(INFO) << "Finish serving " << addr.AsString(); - env.CleanUp(); ->>>>>>> upstream/master } /*! @@ -321,13 +302,8 @@ class RPCServer { const std::string option = "-timeout="; if (opts.find(option) == 0) { -<<<<<<< HEAD const std::string cmd = opts.substr(opts.find_last_of(option) + 1); - CHECK(common::IsNumber(cmd)) << "Timeout is not valid"; -======= - cmd = opts.substr(opts.find_last_of(option) + 1); CHECK(support::IsNumber(cmd)) << "Timeout is not valid"; ->>>>>>> upstream/master return std::stoi(cmd); } return 0; @@ -351,7 +327,7 @@ class RPCServer { */ void ServerLoopFromChild(SOCKET socket) { // Server loop - tvm::common::TCPSocket sock(socket); + tvm::support::TCPSocket sock(socket); const auto env = RPCEnv(); RPCServerLoop(int(sock.sockfd)); From cbd7638134546906cfa1053aacef885cf6790a99 Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Sat, 18 Jan 2020 13:10:22 -0800 Subject: [PATCH 29/33] Increased thread stack size that causes stack overflow on Windows with some models --- python/tvm/autotvm/task/relay_integration.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 55763afcf7bc..e62aceea7f50 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -154,8 +154,12 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None, # wrap build call in thread to avoid multiprocessing problems build_thread = threading.Thread(target=_lower, args=(mod, target, param)) + # Stack would overflow on some platforms (Windows) on some models + old_stack_size = threading.stack_size(1024 * 1024 * 3) build_thread.start() build_thread.join() + # Restore stacksize to original + threading.stack_size(old_stack_size) logger.disabled = old_state From 292c21897f21ab30b112d4b560183cfeb3b7af4b Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Mon, 24 Feb 2020 16:58:14 -0800 Subject: [PATCH 30/33] Fixup task.py from merge --- python/tvm/autotvm/task/task.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 48ed8b40af9e..5971400bb2a8 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -174,14 +174,7 @@ def __setstate__(self, state): self.args = state["args"] self.kwargs = state["kwargs"] self.config_space = state["config_space"] -<<<<<<< HEAD - # Use pickled function on Windows - self.func = state["func"] if os.name == 'nt' else \ - TASK_TABLE.get(state["name"], _raise_error) - self.workload = state["workload"] -======= self.func = TASK_TABLE.get(state["name"], _raise_error) ->>>>>>> upstream/master self.flop = state["flop"] self.target = state["target"] self.target_host = 
state["target_host"] From 4a329dfe0b14f29dfdcf3c4eb41e2fbc174eaf6f Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Wed, 26 Feb 2020 16:31:26 -0800 Subject: [PATCH 31/33] Change Windows untar to use python vs WSL --- apps/cpp_rpc/rpc_env.cc | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index d12ccf8d2b4b..b5dc51b9e7ef 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -45,11 +45,23 @@ namespace { #include "rpc_env.h" namespace { + std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) { + std::string untar_cmd; + untar_cmd.reserve(512); #if defined(__linux__) || defined(__ANDROID__) - const std::string untar_cmd = "tar -C "; + untar_cmd += "tar -C "; + untar_cmd += output_dir; + untar_cmd += " -zxf "; + untar_cmd += tar_file; #elif defined(_WIN32) - const std::string untar_cmd = "wsl tar -C "; -#endif + untar_cmd += "python -m tarfile -e "; + untar_cmd += tar_file; + untar_cmd += " "; + untar_cmd += output_dir; +#endif + return untar_cmd; + } + }// Anonymous namespace namespace tvm { @@ -236,7 +248,7 @@ Module Load(std::string *fileIn, const std::string& fmt) { const std::string tmp_dir = "./rpc/tmp/"; mkdir(tmp_dir.c_str(), 0777); - const std::string cmd = untar_cmd + tmp_dir + " -zxf " + file; + const std::string cmd = GenerateUntarCommand(file, tmp_dir); std::string err_msg; const int executed_status = support::Execute(cmd, &err_msg); From 85796ddacb1bef606aa9e30490d01e05c210cb7f Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Fri, 6 Mar 2020 12:37:54 -0800 Subject: [PATCH 32/33] Remove export all symbols in main cmake --- CMakeLists.txt | 5 +---- src/runtime/graph/graph_runtime.h | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index deb7b457fcaf..11a1e3630fe1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -279,10 +279,6 @@ endif() add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) add_library(tvm_topi SHARED ${TOPI_SRCS}) -if(WIN32) - set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) -endif() - add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) if(USE_CXX_RPC STREQUAL "ON") @@ -407,6 +403,7 @@ endif(INSTALL_DEV) if(MSVC) set_property(TARGET tvm PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) set_property(TARGET tvm_topi PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) + set_property(TARGET tvm_runtime PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS) target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS) endif() diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index c83d68e08159..b787c0a53726 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -65,7 +65,7 @@ struct TVMOpParam { * This runtime can be acccesibly in various language via * TVM runtime PackedFunc API. 
*/ -class GraphRuntime : public ModuleNode { +class TVM_DLL GraphRuntime : public ModuleNode { struct OpArgs { std::vector args; std::vector arg_values; From 1c786cfef10a15019afaf2550def07951b87787a Mon Sep 17 00:00:00 2001 From: Jeremiah Morrill Date: Tue, 28 Apr 2020 15:58:26 -0700 Subject: [PATCH 33/33] removed exporting for _tvm_main_ to get successful builds on Windows --- apps/cpp_rpc/rpc_env.cc | 2 +- python/tvm/contrib/cc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index b5dc51b9e7ef..64c1a8ac761c 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -194,7 +194,7 @@ void WindowsShared(const std::string& output, const std::string& options = "", const std::string& cc = "clang") { std::string cmd = cc; - cmd += " -O2 -flto=full -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared "; + cmd += " -O2 -flto=full -fuse-ld=lld-link -shared "; cmd += " -o " + output; for (const auto& file : files) { cmd += " " + file; diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index ae37923a1dcf..4a88061a8805 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -211,6 +211,7 @@ def _windows_shared(output, objects, options): except FileNotFoundError: raise RuntimeError("Can not find cl.exe," "please run this in Vistual Studio Command Prompt.") + print(py_str(out)) if proc.returncode != 0: msg = "Compilation error:\n" msg += py_str(out) @@ -226,7 +227,6 @@ def _windows_shared(output, objects, options): if obj.endswith(".o"): link_cmd += [obj] - link_cmd += ["-EXPORT:__tvm_main__"] link_cmd += [temp_path + "dllmain.obj"] link_cmd += ["-out:" + output]
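As a usage note for the series above: on Windows the Server constructor runs the pool-based _listen_loop in-process instead of forking a worker, so starting a server is a blocking call. The following minimal sketch shows how that path might be exercised once the patches are applied; it is not part of the patch itself, the host, port and key values are placeholders, and it assumes the pathos and psutil dependencies referenced in the diffs are installed.

    # Run as a standalone script on the Windows machine.
    from tvm import rpc

    # With is_proxy=False and use_popen=False, the Windows branch of
    # Server.__init__ binds a port in [9091, 9199) and then calls the
    # process-pool based _listen_loop directly, so this call blocks for
    # the lifetime of the server.
    rpc.Server("0.0.0.0", port=9091, port_end=9199, key="win32")

    # A client (for example an autotvm tuner) on another machine would then
    # attach with the usual RPC API, e.g. rpc.connect(host, 9091, key="win32"),
    # or indirectly through an RPC tracker if tracker_addr is supplied.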