diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index 16d3a30fd643..776faa9e9fd1 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -23,6 +23,7 @@ import os import pathlib import signal +import socket import stat import subprocess from typing import Union @@ -304,6 +305,7 @@ def __init__( self._serial_number = serial_number adb_socket = rpc_info["adb_server_socket"] if rpc_info["adb_server_socket"] else "tcp:5037" self._adb_device_sub_cmd = ["adb", "-L", adb_socket, "-s", self._serial_number] + self.forwarded_ports_ = [] super(HexagonLauncherAndroid, self).__init__(rpc_info, workspace) @@ -356,26 +358,46 @@ def _copy_binaries(self): for item in self.ANDROID_HEXAGON_RPC_FILES: self._copy_to_remote(lib_dir / item, self._workspace / item) + def _process_forwarded_ports(self): + forwarded_ports = subprocess.check_output(self._adb_device_sub_cmd + ["forward", "--list"]) + existing_forwards = [] + for forward in str(forwarded_ports).split("\\n"): + entry = forward.split() + if len(entry) == 3: + _, local, _ = entry + existing_forwards.append(int(local.strip("tcp:"))) + return existing_forwards + + def _forward_ports(self, rpc_server_port, existing_forwards): + # Enable port forward for RPC server. We forward the first ten open ports + # starting from the rpc_server_port + port = rpc_server_port + while len(self.forwarded_ports_) < 10: + if port not in existing_forwards and not _is_port_in_use(port): + subprocess.check_call( + self._adb_device_sub_cmd + ["forward", f"tcp:{port}", f"tcp:{port}"] + ) + self.forwarded_ports_.append(port) + port += 1 + + def _reverse_ports(self, rpc_tracker_port): + subprocess.check_call( + self._adb_device_sub_cmd + + ["reverse", f"tcp:{rpc_tracker_port}", f"tcp:{rpc_tracker_port}"] + ) + def _run_server_script(self): """Setup the ADB connection and execute the server script.""" - # Removed pre-defined forward/reverse rules - subprocess.check_call(self._adb_device_sub_cmd + ["forward", "--remove-all"]) - subprocess.check_call(self._adb_device_sub_cmd + ["reverse", "--remove-all"]) - + # Collect any existing adb port forwarding to avoid duplication + # with another running process + existing_forwards = self._process_forwarded_ports() # Enable port reverse for RPC tracker rpc_tracker_port = self._rpc_info["rpc_tracker_port"] rpc_server_port = self._rpc_info["rpc_server_port"] - subprocess.check_call( - self._adb_device_sub_cmd - + ["reverse", f"tcp:{rpc_tracker_port}", f"tcp:{rpc_tracker_port}"] - ) - # Enable port forward for RPC server. We forward 9 ports after the rpc_server_port. - for i in range(0, 10): - subprocess.check_call( - self._adb_device_sub_cmd - + ["forward", f"tcp:{rpc_server_port+i}", f"tcp:{rpc_server_port+i}"] - ) + + self._reverse_ports(rpc_tracker_port) + self._forward_ports(rpc_server_port, existing_forwards) # Run server and connect to tracker subprocess.Popen( @@ -385,13 +407,27 @@ def _run_server_script(self): stderr=subprocess.PIPE, ) - def start_server(self): - """Abstract method implementation. See description in HexagonLauncherRPC.""" - self._copy_binaries() - self._run_server_script() + def _cleanup_port_forwarding(self): + # Removed pre-defined forward/reverse rules + rpc_tracker_port = self._rpc_info["rpc_tracker_port"] + subprocess.check_call( + self._adb_device_sub_cmd + ["reverse", "--remove", f"tcp:{rpc_tracker_port}"] + ) + for port in self.forwarded_ports_: + subprocess.check_call(self._adb_device_sub_cmd + ["forward", "--remove", f"tcp:{port}"]) - def stop_server(self): - """Abstract method implementation. See description in HexagonLauncherRPC.""" + def _terminate_remote(self): + # Send interupt to main and child processes + subprocess.Popen( + self._adb_device_sub_cmd + + ["shell", f"pkill -l sigint -P `cat {self._workspace}/rpc_pid.txt`"] + ) + subprocess.Popen( + self._adb_device_sub_cmd + + ["shell", f"kill -s sigint `cat {self._workspace}/rpc_pid.txt`"] + ) + # Wait for processes to destruct cleanly after receiving the intrupt + subprocess.Popen(self._adb_device_sub_cmd + ["shell", "sleep", "0.1s"]) # Kill process children subprocess.Popen( self._adb_device_sub_cmd + ["shell", f"pkill -P `cat {self._workspace}/rpc_pid.txt`"] @@ -401,6 +437,16 @@ def stop_server(self): self._adb_device_sub_cmd + ["shell", f"kill `cat {self._workspace}/rpc_pid.txt`"] ) + def start_server(self): + """Abstract method implementation. See description in HexagonLauncherRPC.""" + self._copy_binaries() + self._run_server_script() + + def stop_server(self): + """Abstract method implementation. See description in HexagonLauncherRPC.""" + self._cleanup_port_forwarding() + self._terminate_remote() + class HexagonLauncherSimulator(HexagonLauncherRPC): """Hexagon Launcher for Hexagon simulator.""" @@ -501,6 +547,12 @@ def stop_server(self): self._server_process.terminate() +# https://stackoverflow.com/a/52872579/2689797 +def _is_port_in_use(port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 + + # pylint: disable=invalid-name def HexagonLauncher( serial_number: str, diff --git a/tests/python/contrib/test_hexagon/conftest.py b/tests/python/contrib/test_hexagon/conftest.py index 87bb69a34961..009150b1081c 100644 --- a/tests/python/contrib/test_hexagon/conftest.py +++ b/tests/python/contrib/test_hexagon/conftest.py @@ -85,10 +85,6 @@ def android_serial_number() -> Optional[str]: def get_free_port(): - # https://stackoverflow.com/a/52872579/2689797 - def is_port_in_use(port: int) -> bool: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(("localhost", port)) == 0 global previous_port if previous_port is None: @@ -96,7 +92,7 @@ def is_port_in_use(port: int) -> bool: else: port = previous_port + 1 - while is_port_in_use(port): + while tvm.contrib.hexagon.build._is_port_in_use(port): port = port + 1 if port < listen_port_max else listen_port_min previous_port = port