From 36fd0e8d73fa9dbf91f5319681b5132d9db201f4 Mon Sep 17 00:00:00 2001 From: Mohamad Date: Thu, 19 Jan 2023 11:59:03 -0800 Subject: [PATCH 1/4] Allow multiple runners in tuning micro models with meta-schedule --- .../micro/meta_schedule/rpc_runner_micro.py | 108 ++++++++++++------ python/tvm/micro/build.py | 3 + python/tvm/micro/testing/pytest_plugin.py | 7 +- tests/micro/zephyr/test_ms_tuning.py | 13 ++- 4 files changed, 93 insertions(+), 38 deletions(-) diff --git a/python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py b/python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py index e4c08351841d..307855438e71 100644 --- a/python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py +++ b/python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py @@ -17,9 +17,10 @@ """RPC Runner Micro""" from contextlib import contextmanager -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Union from collections import namedtuple import signal +import random from tvm import micro from tvm import nd @@ -44,10 +45,11 @@ def __init__( self, platform: str = "crt", project_options: Optional[dict] = None, - rpc_config: Optional[RPCConfig] = None, + rpc_configs: Optional[List[RPCConfig]] = None, evaluator_config: Optional[EvaluatorConfig] = None, max_workers: Optional[int] = None, initializer: Optional[Callable[[], None]] = None, + session_timeout_sec: int = 300, ) -> None: """Constructor @@ -65,21 +67,25 @@ def __init__( The maximum number of connections. Defaults to number of logical CPU cores. initializer: Optional[Callable[[], None]] The initializer function. + session_timeout_sec: int + The session timeout, including the pending time. if the number of candidates sent to runner is larger + than the runner workers, increase the timeout. """ super().__init__() self.platform = platform if project_options is None: project_options = {} self.project_options = project_options - self.rpc_config = RPCConfig._normalized(rpc_config) + self.rpc_configs = rpc_configs self.evaluator_config = EvaluatorConfig._normalized(evaluator_config) + self.session_timeout_sec = session_timeout_sec if max_workers is None: max_workers = cpu_count(logical=True) logger.info("RPCRunner: max_workers = %d", max_workers) self.pool = PopenPoolExecutor( max_workers=max_workers, - timeout=rpc_config.session_timeout_sec, + timeout=session_timeout_sec, initializer=initializer, ) @@ -92,13 +98,13 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: _worker_func, self.platform, self.project_options or {}, - self.rpc_config, + self.rpc_configs, self.evaluator_config, str(runner_input.artifact_path), str(runner_input.device_type), tuple(arg_info.as_json() for arg_info in runner_input.args_info), ), - timeout_sec=self.rpc_config.session_timeout_sec, + timeout_sec=self.session_timeout_sec, ) results.append(future) # type: ignore return results @@ -107,7 +113,7 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: def _worker_func( platform: str, project_options: dict, - rpc_config: RPCConfig, + rpc_configs: List[RPCConfig], evaluator_config: EvaluatorConfig, artifact_path: str, device_type: str, @@ -119,6 +125,7 @@ def _worker_func( project_options=project_options, ) + rpc_config = random.choice(rpc_configs) remote_kw = { "device_key": rpc_config.tracker_key, "host": rpc_config.tracker_host, @@ -126,6 +133,7 @@ def _worker_func( "priority": 0, "timeout": 100, } + build_result = namedtuple("BuildResult", ["filename"])(artifact_path) with module_loader(remote_kw, build_result) as (remote, mod): @@ -156,36 +164,35 @@ def _worker_func( def get_rpc_runner_micro( platform, options, - rpc_config: RPCConfig = None, evaluator_config: EvaluatorConfig = None, - session_timeout_sec=300, + tracker_host: Optional[str] = None, + tracker_port: Union[None, int, str] = None, + session_timeout_sec: int = 300, + rpc_timeout_sec: int = 10, + serial_numbers: List[str] = None, ): """Parameters ---------- platform: str The platform used for project generation. - project_options: dict + options: dict The options for the generated micro project. - rpc_config: RPCConfig - The rpc configuration. evaluator_config: EvaluatorConfig The evaluator configuration. + tracker_host: Optional[str] + The host url of the rpc server. + tracker_port: Union[None, int, str] + The TCP port to bind to session_timeout_sec: int The session timeout. if the number of candidates sent to runner is larger than the runner workers, increase the timeout. + rpc_timeout_sec: + The rpc session timeout. + serial_numbers: + List of board serial numbers to be used during tuning. + For "CRT" and "QEMU" platforms the serial numners are not used, + but the length of the list determines the number of runner instances. """ - if rpc_config is None: - tracker_host = "127.0.0.1" - tracker_port = 9000 - tracker_key = "$local$device$%d" % tracker_port - rpc_config = RPCConfig( - tracker_host=tracker_host, - tracker_port=tracker_port, - tracker_key=tracker_key, - session_priority=0, - session_timeout_sec=session_timeout_sec, - ) - tracker_port_end = rpc_config.tracker_port + 1000 if evaluator_config is None: evaluator_config = EvaluatorConfig( @@ -195,26 +202,54 @@ def get_rpc_runner_micro( enable_cpu_cache_flush=False, ) + if tracker_host is None: + tracker_host = "127.0.0.1" + + if tracker_port is None: + tracker_port = 9000 + else: + tracker_port = int(tracker_port) + tracker_port_end = tracker_port + 1000 + + if not (serial_numbers): + serial_numbers = ["$local$device"] + tracker = Tracker( - port=rpc_config.tracker_port, + port=tracker_port, port_end=tracker_port_end, silent=True, reuse_addr=True, timeout=60, ) - server = Server( - port=rpc_config.tracker_port, - port_end=tracker_port_end, - key=rpc_config.tracker_key, - silent=True, - tracker_addr=(rpc_config.tracker_host, rpc_config.tracker_port), - reuse_addr=True, - timeout=60, - ) + + servers = [] + rpc_configs = [] + for serial_number in serial_numbers: + key = serial_number + rpc_config = RPCConfig( + tracker_host=tracker_host, + tracker_port=tracker_port, + tracker_key=key, + session_priority=0, + session_timeout_sec=rpc_timeout_sec, + ) + rpc_configs.append(rpc_config) + + server = Server( + port=tracker_port, + port_end=tracker_port_end, + key=key, + silent=True, + tracker_addr=(tracker_host, tracker_port), + reuse_addr=True, + timeout=60, + ) + servers.append(server) def terminate(): tracker.terminate() - server.terminate() + for server in servers: + server.terminate() def handle_SIGINT(signal, frame): terminate() @@ -226,8 +261,9 @@ def handle_SIGINT(signal, frame): yield RPCRunnerMicro( platform=platform, project_options=options, - rpc_config=rpc_config, + rpc_configs=rpc_configs, evaluator_config=evaluator_config, + session_timeout_sec=session_timeout_sec, ) finally: terminate() diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index e45054be98ca..f0c90db01cc3 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -150,6 +150,9 @@ def __call__(self, remote_kw, build_result): with open(build_result.filename, "rb") as build_file: build_result_bin = build_file.read() + if "board" in self._project_options and "$local$device" not in remote_kw["device_key"]: + self._project_options["serial_number"] = remote_kw["device_key"] + tracker = _rpc.connect_tracker(remote_kw["host"], remote_kw["port"]) remote = tracker.request( remote_kw["device_key"], diff --git a/python/tvm/micro/testing/pytest_plugin.py b/python/tvm/micro/testing/pytest_plugin.py index c32377fb7e7d..3a828ea3a01e 100644 --- a/python/tvm/micro/testing/pytest_plugin.py +++ b/python/tvm/micro/testing/pytest_plugin.py @@ -142,4 +142,9 @@ def pytest_configure(config): @pytest.fixture def serial_number(request): - return request.config.getoption("--serial-number") + serial_number = request.config.getoption("--serial-number") + if serial_number: + serial_number_splitted = serial_number.split(",") + if len(serial_number_splitted) > 1: + return serial_number_splitted + return serial_number diff --git a/tests/micro/zephyr/test_ms_tuning.py b/tests/micro/zephyr/test_ms_tuning.py index 3ce6ff68bc32..c9168a027d97 100644 --- a/tests/micro/zephyr/test_ms_tuning.py +++ b/tests/micro/zephyr/test_ms_tuning.py @@ -80,6 +80,14 @@ def test_ms_tuning_conv2d(workspace_dir, board, microtvm_debug, use_fvp, serial_ "serial_number": serial_number, "config_main_stack_size": 4096, } + if isinstance(serial_number, list): + project_options["serial_number"] = serial_number[0] # project_api expects an string. + serial_numbers = serial_number + else: + if serial_number is not None: # use a single device in tuning + serial_numbers = [serial_number] + else: # use two dummy serial numbers (for QEMU) + serial_numbers = [str(i) for i in range(2)] boards_file = pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr")) / "boards.json" with open(boards_file) as f: @@ -95,7 +103,10 @@ def test_ms_tuning_conv2d(workspace_dir, board, microtvm_debug, use_fvp, serial_ builder = get_local_builder_micro() with ms.Profiler() as profiler: with get_rpc_runner_micro( - platform=platform, options=project_options, session_timeout_sec=120 + platform=platform, + options=project_options, + session_timeout_sec=120, + serial_numbers=serial_numbers, ) as runner: db: ms.Database = ms.relay_integration.tune_relay( From 2111cc1b43895e4922cfa12e6690c26a1c01f4c2 Mon Sep 17 00:00:00 2001 From: Mohamad Date: Mon, 23 Jan 2023 15:01:47 -0800 Subject: [PATCH 2/4] adding clarifying comments --- python/tvm/micro/build.py | 4 ++++ tests/micro/zephyr/test_ms_tuning.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index f0c90db01cc3..5f9e21bd5cfd 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -150,6 +150,10 @@ def __call__(self, remote_kw, build_result): with open(build_result.filename, "rb") as build_file: build_result_bin = build_file.read() + # In case we are tuning on multiple physical boards (with Meta-schedule), the tracker + # device_key is the serial_number of the board that wil be used in generating micro session. + # For CRT projects, and in cases that the serial number is not provied + # (including tuning with AutoTVM), the serial number field doesn't change. if "board" in self._project_options and "$local$device" not in remote_kw["device_key"]: self._project_options["serial_number"] = remote_kw["device_key"] diff --git a/tests/micro/zephyr/test_ms_tuning.py b/tests/micro/zephyr/test_ms_tuning.py index c9168a027d97..dd8a7bb9c90b 100644 --- a/tests/micro/zephyr/test_ms_tuning.py +++ b/tests/micro/zephyr/test_ms_tuning.py @@ -86,7 +86,7 @@ def test_ms_tuning_conv2d(workspace_dir, board, microtvm_debug, use_fvp, serial_ else: if serial_number is not None: # use a single device in tuning serial_numbers = [serial_number] - else: # use two dummy serial numbers (for QEMU) + else: # use two dummy serial numbers (for testing with QEMU) serial_numbers = [str(i) for i in range(2)] boards_file = pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr")) / "boards.json" From a18655119ba32317e61b5c4de2bab8be54372274 Mon Sep 17 00:00:00 2001 From: Mohamad Date: Mon, 23 Jan 2023 15:04:13 -0800 Subject: [PATCH 3/4] nit --- python/tvm/micro/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index 5f9e21bd5cfd..e3b2a318a6b4 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -152,7 +152,7 @@ def __call__(self, remote_kw, build_result): # In case we are tuning on multiple physical boards (with Meta-schedule), the tracker # device_key is the serial_number of the board that wil be used in generating micro session. - # For CRT projects, and in cases that the serial number is not provied + # For CRT projects, and in cases that the serial number is not provided # (including tuning with AutoTVM), the serial number field doesn't change. if "board" in self._project_options and "$local$device" not in remote_kw["device_key"]: self._project_options["serial_number"] = remote_kw["device_key"] From 036965e1b5802cdbee1121fb4527374dc20c5abb Mon Sep 17 00:00:00 2001 From: Mohamad Date: Wed, 25 Jan 2023 10:52:55 -0800 Subject: [PATCH 4/4] disable parallel tuning test on fvp --- tests/micro/zephyr/test_ms_tuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/micro/zephyr/test_ms_tuning.py b/tests/micro/zephyr/test_ms_tuning.py index dd8a7bb9c90b..16d48ca4cdd6 100644 --- a/tests/micro/zephyr/test_ms_tuning.py +++ b/tests/micro/zephyr/test_ms_tuning.py @@ -61,7 +61,7 @@ def create_relay_module(): @tvm.testing.requires_micro -@pytest.mark.xfail_on_fvp() +@pytest.mark.skip_boards(["mps2_an521", "mps3_an547"]) def test_ms_tuning_conv2d(workspace_dir, board, microtvm_debug, use_fvp, serial_number): """Test meta-schedule tuning for microTVM Zephyr"""