Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 72 additions & 20 deletions python/tvm/contrib/hexagon/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import pathlib
import signal
import socket
import stat
import subprocess
from typing import Union
Expand Down Expand Up @@ -304,6 +305,7 @@ def __init__(
self._serial_number = serial_number
adb_socket = rpc_info["adb_server_socket"] if rpc_info["adb_server_socket"] else "tcp:5037"
self._adb_device_sub_cmd = ["adb", "-L", adb_socket, "-s", self._serial_number]
self.forwarded_ports_ = []

super(HexagonLauncherAndroid, self).__init__(rpc_info, workspace)

Expand Down Expand Up @@ -356,26 +358,46 @@ def _copy_binaries(self):
for item in self.ANDROID_HEXAGON_RPC_FILES:
self._copy_to_remote(lib_dir / item, self._workspace / item)

def _process_forwarded_ports(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Rename to _existing_forwarded_ports or _get_forwarded_ports. The current name reads as if this is performing some processing on each forwarded port.

forwarded_ports = subprocess.check_output(self._adb_device_sub_cmd + ["forward", "--list"])
existing_forwards = []
for forward in str(forwarded_ports).split("\\n"):
entry = forward.split()
if len(entry) == 3:
_, local, _ = entry
existing_forwards.append(int(local.strip("tcp:")))
return existing_forwards

def _forward_ports(self, rpc_server_port, existing_forwards):
# Enable port forward for RPC server. We forward the first ten open ports
# starting from the rpc_server_port
port = rpc_server_port
while len(self.forwarded_ports_) < 10:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the total number of ports isn't important, instead trying to match the port range that may be attempted a server, when it searches for an available port to listen on. How often is the port already forwarded or in use, and would be be better to throw an error in those cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure throwing an error in these cases makes sense given that the user only cares about the functionality of their tests on hardware and doesn't want to see test failures based on bad port configurations.

If we change the RPC server to fail rather than search for a new port as it does now, then we can revisit the launcher code and have it try until successfully binding to a single port and remove the port range nonsense.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, and that would make a lot of sense to me. I like the semantics of "bind here or error" much more than "bind here or maybe somewhere else or error". That makes sense for it to be a later change, and not something needed at the moment.

if port not in existing_forwards and not _is_port_in_use(port):
subprocess.check_call(
self._adb_device_sub_cmd + ["forward", f"tcp:{port}", f"tcp:{port}"]
)
self.forwarded_ports_.append(port)
port += 1

def _reverse_ports(self, rpc_tracker_port):
subprocess.check_call(
self._adb_device_sub_cmd
+ ["reverse", f"tcp:{rpc_tracker_port}", f"tcp:{rpc_tracker_port}"]
)

def _run_server_script(self):
"""Setup the ADB connection and execute the server script."""

# Removed pre-defined forward/reverse rules
subprocess.check_call(self._adb_device_sub_cmd + ["forward", "--remove-all"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to clean up port forwards that have been opened for more than N days? The .stop_server() method may not be called if a segfault occurs, and that may leave ports open unnecessarily.

subprocess.check_call(self._adb_device_sub_cmd + ["reverse", "--remove-all"])

# Collect any existing adb port forwarding to avoid duplication
# with another running process
existing_forwards = self._process_forwarded_ports()
# Enable port reverse for RPC tracker
rpc_tracker_port = self._rpc_info["rpc_tracker_port"]
rpc_server_port = self._rpc_info["rpc_server_port"]
subprocess.check_call(
self._adb_device_sub_cmd
+ ["reverse", f"tcp:{rpc_tracker_port}", f"tcp:{rpc_tracker_port}"]
)
# Enable port forward for RPC server. We forward 9 ports after the rpc_server_port.
for i in range(0, 10):
subprocess.check_call(
self._adb_device_sub_cmd
+ ["forward", f"tcp:{rpc_server_port+i}", f"tcp:{rpc_server_port+i}"]
)

self._reverse_ports(rpc_tracker_port)
self._forward_ports(rpc_server_port, existing_forwards)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Since existing_forwards isn't used elsewhere, may be cleaner to generated it inside _forward_ports.


# Run server and connect to tracker
subprocess.Popen(
Expand All @@ -385,13 +407,27 @@ def _run_server_script(self):
stderr=subprocess.PIPE,
)

def start_server(self):
"""Abstract method implementation. See description in HexagonLauncherRPC."""
self._copy_binaries()
self._run_server_script()
def _cleanup_port_forwarding(self):
# Removed pre-defined forward/reverse rules
rpc_tracker_port = self._rpc_info["rpc_tracker_port"]
subprocess.check_call(
self._adb_device_sub_cmd + ["reverse", "--remove", f"tcp:{rpc_tracker_port}"]
)
for port in self.forwarded_ports_:
subprocess.check_call(self._adb_device_sub_cmd + ["forward", "--remove", f"tcp:{port}"])

def stop_server(self):
"""Abstract method implementation. See description in HexagonLauncherRPC."""
def _terminate_remote(self):
# Send interupt to main and child processes
subprocess.Popen(
self._adb_device_sub_cmd
+ ["shell", f"pkill -l sigint -P `cat {self._workspace}/rpc_pid.txt`"]
)
subprocess.Popen(
self._adb_device_sub_cmd
+ ["shell", f"kill -s sigint `cat {self._workspace}/rpc_pid.txt`"]
)
# Wait for processes to destruct cleanly after receiving the intrupt
subprocess.Popen(self._adb_device_sub_cmd + ["shell", "sleep", "0.1s"])
# Kill process children
subprocess.Popen(
self._adb_device_sub_cmd + ["shell", f"pkill -P `cat {self._workspace}/rpc_pid.txt`"]
Expand All @@ -401,6 +437,16 @@ def stop_server(self):
self._adb_device_sub_cmd + ["shell", f"kill `cat {self._workspace}/rpc_pid.txt`"]
)

def start_server(self):
"""Abstract method implementation. See description in HexagonLauncherRPC."""
self._copy_binaries()
self._run_server_script()

def stop_server(self):
"""Abstract method implementation. See description in HexagonLauncherRPC."""
self._cleanup_port_forwarding()
self._terminate_remote()


class HexagonLauncherSimulator(HexagonLauncherRPC):
"""Hexagon Launcher for Hexagon simulator."""
Expand Down Expand Up @@ -501,6 +547,12 @@ def stop_server(self):
self._server_process.terminate()


# https://stackoverflow.com/a/52872579/2689797
def _is_port_in_use(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0


# pylint: disable=invalid-name
def HexagonLauncher(
serial_number: str,
Expand Down
6 changes: 1 addition & 5 deletions tests/python/contrib/test_hexagon/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,14 @@ def android_serial_number() -> Optional[str]:


def get_free_port():
# https://stackoverflow.com/a/52872579/2689797
def is_port_in_use(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0

global previous_port
if previous_port is None:
port = random.randint(listen_port_min, listen_port_max)
else:
port = previous_port + 1

while is_port_in_use(port):
while tvm.contrib.hexagon.build._is_port_in_use(port):
port = port + 1 if port < listen_port_max else listen_port_min

previous_port = port
Expand Down