From 48816559aa26e47dfa50fe7688a60a111aa59c6f Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Wed, 6 Oct 2021 18:10:15 +0300 Subject: [PATCH 1/3] Fixed connecting the server to the tracker through a proxy --- python/tvm/rpc/proxy.py | 2 +- python/tvm/rpc/tracker.py | 5 +-- tests/python/unittest/test_runtime_rpc.py | 40 +++++++++++++++++++++++ 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index e5ec73db51b9..c3b0056eb591 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -379,7 +379,7 @@ def _update_tracker(self, period_update=False): if need_update_info: keylist = "[" + ",".join(self._key_set) + "]" - cinfo = {"key": "server:proxy" + keylist} + cinfo = {"key": "server:proxy" + keylist, "addr": [None, self._listen_port]} base.sendjson(self._tracker_conn, [TrackerCode.UPDATE_INFO, cinfo]) assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS self._tracker_pending_puts = [] diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index 74c1f7ac07aa..2b3482962e4d 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -337,9 +337,10 @@ def request(self, key, user, priority, callback): def close(self, conn): self._connections.remove(conn) if "key" in conn._info: - key = conn._info["key"].split(":")[1] # 'server:rasp3b' -> 'rasp3b' for value in conn.put_values: - self._scheduler_map[key].remove(value) + _, host, port, key = value + rpc_key = key.split(":")[0] + self._scheduler_map[rpc_key].remove(value) def stop(self): """Safely stop tracker.""" diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 22aea8d1fcea..fa6bd75ea6e7 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -29,6 +29,7 @@ from tvm import rpc from tvm.contrib import utils, cc from tvm.rpc.tracker import Tracker +from tvm.rpc.proxy import Proxy if __name__ == "__main__": @@ -538,3 +539,42 @@ def test_rpc_tracker_request(): proc2.join() server.terminate() tracker.terminate() + + +@tvm.testing.requires_rpc +def test_rpc_tracker_via_proxy(): + device_key = "test_device" + + tracker_server = Tracker(port=9000, port_end=9100) + proxy_server = Proxy( + host=tracker_server.host, + port=8888, + port_end=8988, + tracker_addr=(tracker_server.host, tracker_server.port), + ) + + server1 = rpc.Server( + host=proxy_server.host, + port=proxy_server.port, + key=device_key, + tracker_addr=(tracker_server.host, tracker_server.port), + is_proxy=True, + ) + time.sleep(0.1) + server2 = rpc.Server( + host=proxy_server.host, + port=proxy_server.port, + key=device_key, + tracker_addr=(tracker_server.host, tracker_server.port), + is_proxy=True, + ) + time.sleep(0.1) + + client = rpc.connect_tracker(tracker_server.host, tracker_server.port) + summary = client.summary() + assert summary["queue_info"][device_key]["free"] == 2 + + server2.terminate() + server1.terminate() + proxy_server.terminate() + tracker_server.terminate() From a2b15185ca62460522d364ddbf5581d9aa8500a1 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Fri, 8 Oct 2021 17:41:37 +0300 Subject: [PATCH 2/3] fix lint --- python/tvm/rpc/tracker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index 2b3482962e4d..5a576a705e8a 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -338,7 +338,7 @@ def close(self, conn): self._connections.remove(conn) if "key" in conn._info: for value in conn.put_values: - _, host, port, key = value + _, _, _, key = value rpc_key = key.split(":")[0] self._scheduler_map[rpc_key].remove(value) From 36c4092ec6d7e260bd2e0c5607e89401fd6c47b4 Mon Sep 17 00:00:00 2001 From: KJlaccHoeUM9l Date: Mon, 11 Oct 2021 15:27:56 +0300 Subject: [PATCH 3/3] remove time.sleep from test --- tests/python/unittest/test_runtime_rpc.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index fa6bd75ea6e7..6e1fc815d66d 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -543,6 +543,12 @@ def test_rpc_tracker_request(): @tvm.testing.requires_rpc def test_rpc_tracker_via_proxy(): + """ + tracker + / \ + Host -- Proxy -- RPC server + """ + device_key = "test_device" tracker_server = Tracker(port=9000, port_end=9100) @@ -560,7 +566,6 @@ def test_rpc_tracker_via_proxy(): tracker_addr=(tracker_server.host, tracker_server.port), is_proxy=True, ) - time.sleep(0.1) server2 = rpc.Server( host=proxy_server.host, port=proxy_server.port, @@ -568,11 +573,10 @@ def test_rpc_tracker_via_proxy(): tracker_addr=(tracker_server.host, tracker_server.port), is_proxy=True, ) - time.sleep(0.1) client = rpc.connect_tracker(tracker_server.host, tracker_server.port) - summary = client.summary() - assert summary["queue_info"][device_key]["free"] == 2 + remote1 = client.request(device_key, session_timeout=30) # pylint: disable=unused-variable + remote2 = client.request(device_key, session_timeout=30) # pylint: disable=unused-variable server2.terminate() server1.terminate()