Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 72 additions & 36 deletions python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
"""RPC Runner Micro"""

from contextlib import contextmanager
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union
from collections import namedtuple
import signal
import random

from tvm import micro
from tvm import nd
Expand All @@ -44,10 +45,11 @@ def __init__(
self,
platform: str = "crt",
project_options: Optional[dict] = None,
rpc_config: Optional[RPCConfig] = None,
rpc_configs: Optional[List[RPCConfig]] = None,
evaluator_config: Optional[EvaluatorConfig] = None,
max_workers: Optional[int] = None,
initializer: Optional[Callable[[], None]] = None,
session_timeout_sec: int = 300,
) -> None:
"""Constructor

Expand All @@ -65,21 +67,25 @@ def __init__(
The maximum number of connections. Defaults to number of logical CPU cores.
initializer: Optional[Callable[[], None]]
The initializer function.
session_timeout_sec: int
The session timeout, including the pending time. If the number of candidates sent to the runner
is larger than the number of runner workers, increase the timeout.
"""
super().__init__()
self.platform = platform
if project_options is None:
project_options = {}
self.project_options = project_options
self.rpc_config = RPCConfig._normalized(rpc_config)
self.rpc_configs = rpc_configs
self.evaluator_config = EvaluatorConfig._normalized(evaluator_config)
self.session_timeout_sec = session_timeout_sec

if max_workers is None:
max_workers = cpu_count(logical=True)
logger.info("RPCRunner: max_workers = %d", max_workers)
self.pool = PopenPoolExecutor(
max_workers=max_workers,
timeout=rpc_config.session_timeout_sec,
timeout=session_timeout_sec,
initializer=initializer,
)

Expand All @@ -92,13 +98,13 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
_worker_func,
self.platform,
self.project_options or {},
self.rpc_config,
self.rpc_configs,
self.evaluator_config,
str(runner_input.artifact_path),
str(runner_input.device_type),
tuple(arg_info.as_json() for arg_info in runner_input.args_info),
),
timeout_sec=self.rpc_config.session_timeout_sec,
timeout_sec=self.session_timeout_sec,
)
results.append(future) # type: ignore
return results
Expand All @@ -107,7 +113,7 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
def _worker_func(
platform: str,
project_options: dict,
rpc_config: RPCConfig,
rpc_configs: List[RPCConfig],
evaluator_config: EvaluatorConfig,
artifact_path: str,
device_type: str,
Expand All @@ -119,13 +125,15 @@ def _worker_func(
project_options=project_options,
)

rpc_config = random.choice(rpc_configs)
remote_kw = {
"device_key": rpc_config.tracker_key,
"host": rpc_config.tracker_host,
"port": rpc_config.tracker_port,
"priority": 0,
"timeout": 100,
}

build_result = namedtuple("BuildResult", ["filename"])(artifact_path)

with module_loader(remote_kw, build_result) as (remote, mod):
Expand Down Expand Up @@ -156,36 +164,35 @@ def _worker_func(
def get_rpc_runner_micro(
platform,
options,
rpc_config: RPCConfig = None,
evaluator_config: EvaluatorConfig = None,
session_timeout_sec=300,
tracker_host: Optional[str] = None,
tracker_port: Union[None, int, str] = None,
session_timeout_sec: int = 300,
rpc_timeout_sec: int = 10,
serial_numbers: List[str] = None,
):
"""Parameters
----------
platform: str
The platform used for project generation.
project_options: dict
options: dict
The options for the generated micro project.
rpc_config: RPCConfig
The rpc configuration.
evaluator_config: EvaluatorConfig
The evaluator configuration.
tracker_host: Optional[str]
The host url of the rpc server.
tracker_port: Union[None, int, str]
The TCP port to bind to
session_timeout_sec: int
The session timeout. If the number of candidates sent to the runner is larger
than the number of runner workers, increase the timeout.
rpc_timeout_sec:
The rpc session timeout.
serial_numbers:
List of board serial numbers to be used during tuning.
For "CRT" and "QEMU" platforms the serial numbers are not used,
but the length of the list determines the number of runner instances.
"""
if rpc_config is None:
tracker_host = "127.0.0.1"
tracker_port = 9000
tracker_key = "$local$device$%d" % tracker_port
rpc_config = RPCConfig(
tracker_host=tracker_host,
tracker_port=tracker_port,
tracker_key=tracker_key,
session_priority=0,
session_timeout_sec=session_timeout_sec,
)
tracker_port_end = rpc_config.tracker_port + 1000

if evaluator_config is None:
evaluator_config = EvaluatorConfig(
Expand All @@ -195,26 +202,54 @@ def get_rpc_runner_micro(
enable_cpu_cache_flush=False,
)

if tracker_host is None:
tracker_host = "127.0.0.1"

if tracker_port is None:
tracker_port = 9000
else:
tracker_port = int(tracker_port)
tracker_port_end = tracker_port + 1000

if not (serial_numbers):
serial_numbers = ["$local$device"]

tracker = Tracker(
port=rpc_config.tracker_port,
port=tracker_port,
port_end=tracker_port_end,
silent=True,
reuse_addr=True,
timeout=60,
)
server = Server(
port=rpc_config.tracker_port,
port_end=tracker_port_end,
key=rpc_config.tracker_key,
silent=True,
tracker_addr=(rpc_config.tracker_host, rpc_config.tracker_port),
reuse_addr=True,
timeout=60,
)

servers = []
rpc_configs = []
for serial_number in serial_numbers:
key = serial_number
rpc_config = RPCConfig(
tracker_host=tracker_host,
tracker_port=tracker_port,
tracker_key=key,
session_priority=0,
session_timeout_sec=rpc_timeout_sec,
)
rpc_configs.append(rpc_config)

server = Server(
port=tracker_port,
port_end=tracker_port_end,
key=key,
silent=True,
tracker_addr=(tracker_host, tracker_port),
reuse_addr=True,
timeout=60,
)
servers.append(server)

def terminate():
tracker.terminate()
server.terminate()
for server in servers:
server.terminate()

def handle_SIGINT(signal, frame):
terminate()
Expand All @@ -226,8 +261,9 @@ def handle_SIGINT(signal, frame):
yield RPCRunnerMicro(
platform=platform,
project_options=options,
rpc_config=rpc_config,
rpc_configs=rpc_configs,
evaluator_config=evaluator_config,
session_timeout_sec=session_timeout_sec,
)
finally:
terminate()
7 changes: 7 additions & 0 deletions python/tvm/micro/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ def __call__(self, remote_kw, build_result):
with open(build_result.filename, "rb") as build_file:
build_result_bin = build_file.read()

# In case we are tuning on multiple physical boards (with Meta-schedule), the tracker
# device_key is the serial_number of the board that will be used in generating the micro session.
# For CRT projects, and in cases that the serial number is not provided
# (including tuning with AutoTVM), the serial number field doesn't change.
if "board" in self._project_options and "$local$device" not in remote_kw["device_key"]:
self._project_options["serial_number"] = remote_kw["device_key"]

tracker = _rpc.connect_tracker(remote_kw["host"], remote_kw["port"])
remote = tracker.request(
remote_kw["device_key"],
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/micro/testing/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,9 @@ def pytest_configure(config):

@pytest.fixture
def serial_number(request):
return request.config.getoption("--serial-number")
serial_number = request.config.getoption("--serial-number")
if serial_number:
serial_number_splitted = serial_number.split(",")
if len(serial_number_splitted) > 1:
return serial_number_splitted
return serial_number
15 changes: 13 additions & 2 deletions tests/micro/zephyr/test_ms_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def create_relay_module():


@tvm.testing.requires_micro
@pytest.mark.xfail_on_fvp()
@pytest.mark.skip_boards(["mps2_an521", "mps3_an547"])
def test_ms_tuning_conv2d(workspace_dir, board, microtvm_debug, use_fvp, serial_number):
"""Test meta-schedule tuning for microTVM Zephyr"""

Expand All @@ -80,6 +80,14 @@ def test_ms_tuning_conv2d(workspace_dir, board, microtvm_debug, use_fvp, serial_
"serial_number": serial_number,
"config_main_stack_size": 4096,
}
if isinstance(serial_number, list):
project_options["serial_number"] = serial_number[0] # project_api expects an string.
serial_numbers = serial_number
else:
if serial_number is not None: # use a single device in tuning
serial_numbers = [serial_number]
else: # use two dummy serial numbers (for testing with QEMU)
serial_numbers = [str(i) for i in range(2)]

boards_file = pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr")) / "boards.json"
with open(boards_file) as f:
Expand All @@ -95,7 +103,10 @@ def test_ms_tuning_conv2d(workspace_dir, board, microtvm_debug, use_fvp, serial_
builder = get_local_builder_micro()
with ms.Profiler() as profiler:
with get_rpc_runner_micro(
platform=platform, options=project_options, session_timeout_sec=120
platform=platform,
options=project_options,
session_timeout_sec=120,
serial_numbers=serial_numbers,
) as runner:

db: ms.Database = ms.relay_integration.tune_relay(
Expand Down