diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index e5ec73db51b9..c3b0056eb591 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -379,7 +379,7 @@ def _update_tracker(self, period_update=False): if need_update_info: keylist = "[" + ",".join(self._key_set) + "]" - cinfo = {"key": "server:proxy" + keylist} + cinfo = {"key": "server:proxy" + keylist, "addr": [None, self._listen_port]} base.sendjson(self._tracker_conn, [TrackerCode.UPDATE_INFO, cinfo]) assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS self._tracker_pending_puts = [] diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index 74c1f7ac07aa..5a576a705e8a 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -337,9 +337,10 @@ def request(self, key, user, priority, callback): def close(self, conn): self._connections.remove(conn) if "key" in conn._info: - key = conn._info["key"].split(":")[1] # 'server:rasp3b' -> 'rasp3b' for value in conn.put_values: - self._scheduler_map[key].remove(value) + _, _, _, key = value + rpc_key = key.split(":")[0] + self._scheduler_map[rpc_key].remove(value) def stop(self): """Safely stop tracker.""" diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 22aea8d1fcea..6e1fc815d66d 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -29,6 +29,7 @@ from tvm import rpc from tvm.contrib import utils, cc from tvm.rpc.tracker import Tracker +from tvm.rpc.proxy import Proxy if __name__ == "__main__": @@ -538,3 +539,46 @@ def test_rpc_tracker_request(): proc2.join() server.terminate() tracker.terminate() + + +@tvm.testing.requires_rpc +def test_rpc_tracker_via_proxy(): + """ + tracker + / \ + Host -- Proxy -- RPC server + """ + + device_key = "test_device" + + tracker_server = Tracker(port=9000, port_end=9100) + proxy_server = Proxy( + host=tracker_server.host, + port=8888, + port_end=8988, + tracker_addr=(tracker_server.host, tracker_server.port), + ) + + server1 = rpc.Server( + host=proxy_server.host, + port=proxy_server.port, + key=device_key, + tracker_addr=(tracker_server.host, tracker_server.port), + is_proxy=True, + ) + server2 = rpc.Server( + host=proxy_server.host, + port=proxy_server.port, + key=device_key, + tracker_addr=(tracker_server.host, tracker_server.port), + is_proxy=True, + ) + + client = rpc.connect_tracker(tracker_server.host, tracker_server.port) + remote1 = client.request(device_key, session_timeout=30) # pylint: disable=unused-variable + remote2 = client.request(device_key, session_timeout=30) # pylint: disable=unused-variable + + server2.terminate() + server1.terminate() + proxy_server.terminate() + tracker_server.terminate()