From c47ac7431554b113b6f0a0a19f561f4206ac5147 Mon Sep 17 00:00:00 2001 From: Mercy Date: Wed, 13 Jun 2018 04:09:06 +0800 Subject: [PATCH 1/2] add silent mode to rpc server and rpc tracker --- python/tvm/contrib/rpc/base.py | 10 ++-- python/tvm/contrib/rpc/server.py | 80 +++++++++++++++++++++---------- python/tvm/contrib/rpc/tracker.py | 15 ++++-- python/tvm/exec/rpc_server.py | 10 ++-- python/tvm/exec/rpc_tracker.py | 13 +++-- src/runtime/rpc/rpc_server_env.cc | 1 - 6 files changed, 89 insertions(+), 40 deletions(-) diff --git a/python/tvm/contrib/rpc/base.py b/python/tvm/contrib/rpc/base.py index 67e6d6b43bd1..d0004f0c86f8 100644 --- a/python/tvm/contrib/rpc/base.py +++ b/python/tvm/contrib/rpc/base.py @@ -120,7 +120,7 @@ def random_key(prefix, cmap=None): return prefix + str(random.random()) -def connect_with_retry(addr, timeout=60, retry_period=5): +def connect_with_retry(addr, timeout=60, retry_period=5, silent=False): """Connect to a TPC address with retry This function is only reliable to short period of server restart. @@ -135,6 +135,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5): retry_period : float Number of seconds before we retry again. + + silent: bool + whether run in silent mode """ tstart = time.time() while True: @@ -149,8 +152,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5): if period > timeout: raise RuntimeError( "Failed to connect to server %s" % str(addr)) - logging.info("Cannot connect to tracker%s, retry in %g secs...", - str(addr), retry_period) + if not silent: + logging.info("Cannot connect to tracker%s, retry in %g secs...", + str(addr), retry_period) time.sleep(retry_period) diff --git a/python/tvm/contrib/rpc/server.py b/python/tvm/contrib/rpc/server.py index bef85bc5711c..97b2a42e2359 100644 --- a/python/tvm/contrib/rpc/server.py +++ b/python/tvm/contrib/rpc/server.py @@ -19,6 +19,7 @@ import multiprocessing import subprocess import time +import sys from ..._ffi.function import register_func from ..._ffi.base import py_str @@ -28,7 +29,7 @@ from . import base from . base import TrackerCode -def _server_env(load_library): +def _server_env(load_library, logger): """Server environment function return temp dir""" temp = util.tempdir() # pylint: disable=unused-variable @@ -41,7 +42,7 @@ def load_module(file_name): """Load module from remote side.""" path = temp.relpath(file_name) m = _load_module(path) - logging.info("load_module %s", path) + logger.info("load_module %s", path) return m libs = [] @@ -49,18 +50,21 @@ def load_module(file_name): for file_name in load_library: file_name = find_lib_path(file_name)[0] libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) - logging.info("Load additional library %s", file_name) + logger.info("Load additional library %s", file_name) temp.libs = libs return temp -def _serve_loop(sock, addr, load_library): +def _serve_loop(sock, addr, load_library, silent): """Server loop""" + logger = logging.getLogger("RPCServer") + if silent: + logger.disabled = True sockfd = sock.fileno() - temp = _server_env(load_library) + temp = _server_env(load_library, logger) base._ServerLoop(sockfd) temp.remove() - logging.info("Finish serving %s", addr) + logger.info("Finish serving %s", addr) def _parse_server_opt(opts): @@ -71,8 +75,12 @@ def _parse_server_opt(opts): ret["timeout"] = float(kv[9:]) return ret -def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): +def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, silent): """Lisenting loop of the server master.""" + logger = logging.getLogger("RPCServer") + if silent: + logger.disabled = True + def _accept_conn(listen_sock, tracker_conn, ping_period=2): """Accept connection from the other places. @@ -115,7 +123,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2): unmatch_period_count = 0 # regenerate match key if key is acquired but not used for a while if unmatch_period_count * ping_period > unmatch_timeout + ping_period: - logging.info("RPCServer: no incoming connections, regenerate key ...") + logger.info("no incoming connections, regenerate key ...") matchkey = base.random_key(rpc_key + ":", old_keyset) base.sendjson(tracker_conn, [TrackerCode.PUT, rpc_key, (port, matchkey), @@ -136,7 +144,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2): if arr[0] != expect_header: conn.sendall(struct.pack(" max_retry: raise RuntimeError("Maximum retry error: last error: %s" % str(err)) time.sleep(retry_period) @@ -264,6 +281,9 @@ class Server(object): This is recommended to switch on if we want to do local RPC demonstration for GPU devices to avoid fork safety issues. + silent: bool, optional + Whether run this server in silent mode. + key : str, optional The key used to identify the server in Proxy connection. @@ -276,6 +296,7 @@ def __init__(self, port_end=9199, is_proxy=False, use_popen=False, + silent=False, tracker_addr=None, key="", load_library=None, @@ -290,8 +311,12 @@ def __init__(self, self.libs = [] self.custom_addr = custom_addr + self.logger = logging.getLogger("RPCServer") + if silent: + self.logger.disabled = True + if use_popen: - cmd = ["python", + cmd = [sys.executable, "-m", "tvm.exec.rpc_server", "--host=%s" % host, "--port=%s" % port] @@ -303,11 +328,14 @@ def __init__(self, cmd += ["--load-library", load_library] if custom_addr: cmd += ["--custom-addr", custom_addr] + if silent: + cmd += ["--silent"] + self.proc = multiprocessing.Process( target=subprocess.check_call, args=(cmd,)) self.proc.deamon = True self.proc.start() - time.sleep(1) + time.sleep(0.5) elif not is_proxy: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.port = None @@ -321,19 +349,19 @@ def __init__(self, continue else: raise sock_err - if not self.port: - raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) - logging.info("RPCServer: bind to %s:%d", host, self.port) + if not self.port: + raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) + self.logger.info("bind to %s:%d", host, self.port) sock.listen(1) self.sock = sock self.proc = multiprocessing.Process( target=_listen_loop, args=( - self.sock, self.port, key, tracker_addr, load_library, self.custom_addr)) + self.sock, self.port, key, tracker_addr, load_library, self.custom_addr, silent)) self.proc.deamon = True self.proc.start() else: self.proc = multiprocessing.Process( - target=_connect_proxy_loop, args=((host, port), key, load_library)) + target=_connect_proxy_loop, args=((host, port), key, load_library, silent)) self.proc.deamon = True self.proc.start() diff --git a/python/tvm/contrib/rpc/tracker.py b/python/tvm/contrib/rpc/tracker.py index 812b3a9770ab..47e81859c934 100644 --- a/python/tvm/contrib/rpc/tracker.py +++ b/python/tvm/contrib/rpc/tracker.py @@ -309,7 +309,6 @@ def run(self): def _tracker_server(listen_sock, stop_key): handler = TrackerServerHandler(listen_sock, stop_key) handler.run() - logging.info("Tracker Stop signal received, terminating...") class Tracker(object): @@ -327,11 +326,19 @@ class Tracker(object): port_end : int, optional The end TCP port to search + + silent: bool, optional + Whether run in silent mode """ def __init__(self, host, port=9190, - port_end=9199): + port_end=9199, + silent=False): + self.logger = logging.getLogger("RPCTracker") + if silent: + self.logger.disabled = True + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.port = None self.stop_key = base.random_key("tracker") @@ -347,7 +354,7 @@ def __init__(self, raise sock_err if not self.port: raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) - logging.info("RPCTracker: bind to %s:%d", host, self.port) + self.logger.info("bind to %s:%d", host, self.port) sock.listen(1) self.proc = multiprocessing.Process( target=_tracker_server, args=(sock, self.stop_key)) @@ -373,7 +380,7 @@ def terminate(self): self._stop_tracker() self.proc.join(1) if self.proc.is_alive(): - logging.info("Terminating Tracker Server...") + self.logger.info("Terminating Tracker Server...") self.proc.terminate() self.proc = None diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index d874ed63b673..e26dbc3b2025 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -27,7 +27,8 @@ def main(args): key=args.key, tracker_addr=tracker_addr, load_library=args.load_library, - custom_addr=args.custom_addr) + custom_addr=args.custom_addr, + silent=args.silent) server.proc.join() @@ -51,6 +52,8 @@ def main(args): and ROCM compilers.") parser.add_argument('--custom-addr', type=str, help="Custom IP Address to Report to RPC Tracker") + parser.add_argument('--silent', action='store_true', + help="Whether run in silent mode.") parser.set_defaults(fork=True) args = parser.parse_args() @@ -62,6 +65,7 @@ def main(args): ) multiprocessing.set_start_method('spawn') else: - logging.info("If you are running ROCM/Metal, \ - fork with cause compiler internal error. Try to launch with arg ```--no-fork```") + if not args.silent: + logging.info("If you are running ROCM/Metal, fork will cause " + "compiler internal error. Try to launch with arg ```--no-fork```") main(args) diff --git a/python/tvm/exec/rpc_tracker.py b/python/tvm/exec/rpc_tracker.py index 3b76f57eb689..3e4c63e20f9b 100644 --- a/python/tvm/exec/rpc_tracker.py +++ b/python/tvm/exec/rpc_tracker.py @@ -11,7 +11,8 @@ def main(args): """Main funciton""" - tracker = Tracker(args.host, port=args.port) + tracker = Tracker(args.host, port=args.port, port_end=args.port_end, + silent=args.silent) tracker.proc.join() @@ -21,10 +22,15 @@ def main(args): help='the hostname of the tracker') parser.add_argument('--port', type=int, default=9190, help='The port of the PRC') + parser.add_argument('--port-end', type=int, default=9199, + help='The end search port of the PRC') parser.add_argument('--no-fork', dest='fork', action='store_false', help="Use spawn mode to avoid fork. This option \ is able to avoid potential fork problems with Metal, OpenCL \ and ROCM compilers.") + parser.add_argument('--silent', action='store_true', + help="Whether run in silent mode.") + parser.set_defaults(fork=True) args = parser.parse_args() logging.basicConfig(level=logging.INFO) @@ -35,6 +41,7 @@ def main(args): ) multiprocessing.set_start_method('spawn') else: - logging.info("If you are running ROCM/Metal, \ - fork with cause compiler internal error. Try to launch with arg ```--no-fork```") + if not args.silent: + logging.info("If you are running ROCM/Metal, fork will cause " + "compiler internal error. Try to launch with arg ```--no-fork```") main(args) diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index 17ee2abd97b1..a995a953bf79 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -20,7 +20,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload"). set_body([](TVMArgs args, TVMRetValue *rv) { std::string file_name = RPCGetPath(args[0]); std::string data = args[1]; - LOG(INFO) << "Upload " << file_name << "... nbytes=" << data.length(); SaveBinaryToFile(file_name, data); }); From d6786eefb5c9ef69ae7c5630e5df2c1216493031 Mon Sep 17 00:00:00 2001 From: Mercy Date: Wed, 13 Jun 2018 08:13:59 +0800 Subject: [PATCH 2/2] fix --- python/tvm/contrib/rpc/proxy.py | 2 +- python/tvm/contrib/rpc/server.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/python/tvm/contrib/rpc/proxy.py b/python/tvm/contrib/rpc/proxy.py index 315354edeb92..b1168c1a843f 100644 --- a/python/tvm/contrib/rpc/proxy.py +++ b/python/tvm/contrib/rpc/proxy.py @@ -536,7 +536,7 @@ def _fsend(data): def _connect(key): conn = yield websocket.websocket_connect(url) on_message = create_on_message(conn) - temp = _server_env(None) + temp = _server_env(None, None) # Start connecton conn.write_message(struct.pack('