diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index 4975d924dac1..4e8ee32c3719 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -56,15 +56,6 @@ raise FileNotFoundError(f"Board file {{{BOARDS}}} does not exist.") -def get_cmsis_path(cmsis_path: pathlib.Path) -> pathlib.Path: - """Returns CMSIS dependency path""" - if cmsis_path: - return pathlib.Path(cmsis_path) - if os.environ.get("CMSIS_PATH"): - return pathlib.Path(os.environ.get("CMSIS_PATH")) - assert False, "'cmsis_path' option not passed!" - - class BoardAutodetectFailed(Exception): """Raised when no attached hardware is found matching the requested board""" @@ -78,7 +69,11 @@ class BoardAutodetectFailed(Exception): ) + [ server.ProjectOption( "arduino_cli_cmd", - required=(["generate_project", "flash", "open_transport"] if not ARDUINO_CLI_CMD else None), + required=( + ["generate_project", "build", "flash", "open_transport"] + if not ARDUINO_CLI_CMD + else None + ), optional=( ["generate_project", "build", "flash", "open_transport"] if ARDUINO_CLI_CMD else None ), @@ -337,7 +332,7 @@ def _copy_cmsis(self, project_path: pathlib.Path, cmsis_path: str): However, the latest release does not include header files that are copied in this function. """ (project_path / "include" / "cmsis").mkdir() - cmsis_path = get_cmsis_path(cmsis_path) + cmsis_path = pathlib.Path(cmsis_path) for item in self.CMSIS_INCLUDE_HEADERS: shutil.copy2( cmsis_path / "CMSIS" / "NN" / "Include" / item, @@ -357,7 +352,7 @@ def _populate_makefile( flags = { "FQBN": self._get_fqbn(board), "VERBOSE_FLAG": "--verbose" if verbose else "", - "ARUINO_CLI_CMD": self._get_arduino_cli_cmd(arduino_cli_cmd), + "ARUINO_CLI_CMD": arduino_cli_cmd, "BOARD": board, "BUILD_EXTRA_FLAGS": build_extra_flags, } @@ -377,9 +372,10 @@ def _populate_makefile( def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): # List all used project options board = options["board"] - verbose = options.get("verbose") project_type = options["project_type"] - arduino_cli_cmd = options.get("arduino_cli_cmd") + arduino_cli_cmd = options["arduino_cli_cmd"] + verbose = options["verbose"] + cmsis_path = options.get("cmsis_path") compile_definitions = options.get("compile_definitions") extra_files_tar = options.get("extra_files_tar") @@ -455,12 +451,6 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec build_extra_flags, ) - def _get_arduino_cli_cmd(self, arduino_cli_cmd: str): - if not arduino_cli_cmd: - arduino_cli_cmd = ARDUINO_CLI_CMD - assert arduino_cli_cmd, "'arduino_cli_cmd' command not passed and not found by default!" - return arduino_cli_cmd - def _get_platform_version(self, arduino_cli_path: str) -> float: # sample output of this command: # 'arduino-cli alpha Version: 0.18.3 Commit: d710b642 Date: 2021-05-14T12:36:58Z\n' @@ -494,11 +484,10 @@ def _get_fqbn(self, board: str): def build(self, options): # List all used project options - arduino_cli_cmd = options.get("arduino_cli_cmd") + arduino_cli_cmd = options["arduino_cli_cmd"] warning_as_error = options.get("warning_as_error") - cli_command = self._get_arduino_cli_cmd(arduino_cli_cmd) - self._check_platform_version(cli_command, warning_as_error) + self._check_platform_version(arduino_cli_cmd, warning_as_error) compile_cmd = ["make", "build"] # Specify project to compile subprocess.run(compile_cmd, check=True, cwd=API_SERVER_DIR) @@ -539,7 +528,7 @@ def _parse_connected_boards(self, tabular_str): def _auto_detect_port(self, arduino_cli_cmd: str, board: str) -> str: # It is assumed only one board with this type is connected to this host machine. - list_cmd = [self._get_arduino_cli_cmd(arduino_cli_cmd), "board", "list"] + list_cmd = [arduino_cli_cmd, "board", "list"] list_cmd_output = subprocess.run( list_cmd, check=True, stdout=subprocess.PIPE ).stdout.decode("utf-8") @@ -599,7 +588,7 @@ def _get_board_from_makefile(self, makefile_path: pathlib.Path) -> str: def flash(self, options): # List all used project options - arduino_cli_cmd = options.get("arduino_cli_cmd") + arduino_cli_cmd = options["arduino_cli_cmd"] warning_as_error = options.get("warning_as_error") port = options.get("port") board = options.get("board") @@ -608,9 +597,8 @@ def flash(self, options): if not board: board = self._get_board_from_makefile(API_SERVER_DIR / MAKEFILE_FILENAME) - cli_command = self._get_arduino_cli_cmd(arduino_cli_cmd) - self._check_platform_version(cli_command, warning_as_error) - port = self._get_arduino_port(cli_command, board, port, serial_number) + self._check_platform_version(arduino_cli_cmd, warning_as_error) + port = self._get_arduino_port(arduino_cli_cmd, board, port, serial_number) upload_cmd = ["make", "flash", f"PORT={port}"] for _ in range(self.FLASH_MAX_RETRIES): @@ -639,7 +627,7 @@ def open_transport(self, options): import serial.tools.list_ports # List all used project options - arduino_cli_cmd = options.get("arduino_cli_cmd") + arduino_cli_cmd = options["arduino_cli_cmd"] port = options.get("port") board = options.get("board") serial_number = options.get("serial_number") diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index 9a8015d62571..b0cd21e4adb2 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -175,7 +175,7 @@ def _find_board_from_cmake_file(cmake_file: Union[str, pathlib.Path]) -> str: def _find_platform_from_cmake_file(cmake_file: Union[str, pathlib.Path]) -> str: emu_platform = None - with open(API_SERVER_DIR / CMAKELIST_FILENAME) as cmake_f: + with open(cmake_file) as cmake_f: for line in cmake_f: set_platform = re.match("set\(EMU_PLATFORM (.*)\)", line) if set_platform: @@ -184,14 +184,14 @@ def _find_platform_from_cmake_file(cmake_file: Union[str, pathlib.Path]) -> str: return emu_platform -def _get_device_args(options): +def _get_device_args(serial_number: str = None): flash_runner = _get_flash_runner() if flash_runner == "nrfjprog": - return _get_nrf_device_args(options) + return _get_nrf_device_args(serial_number) if flash_runner == "openocd": - return _get_openocd_device_args(options) + return _get_openocd_device_args(serial_number) raise BoardError( f"Don't know how to find serial terminal for board {_find_board_from_cmake_file(API_SERVER_DIR / CMAKELIST_FILENAME)} with flash " @@ -199,14 +199,8 @@ def _get_device_args(options): ) -def _get_board_mem_size_bytes(options): - board_file_path = ( - pathlib.Path(get_zephyr_base(options)) - / "boards" - / "arm" - / options["board"] - / (options["board"] + ".yaml") - ) +def _get_board_mem_size_bytes(zephyr_base: str, board: str): + board_file_path = pathlib.Path(zephyr_base) / "boards" / "arm" / board / (board + ".yaml") try: with open(board_file_path) as f: board_data = yaml.load(f, Loader=yaml.FullLoader) @@ -219,14 +213,14 @@ def _get_board_mem_size_bytes(options): DEFAULT_HEAP_SIZE_BYTES = 216 * 1024 -def _get_recommended_heap_size_bytes(options): - prop = BOARD_PROPERTIES[options["board"]] +def _get_recommended_heap_size_bytes(board: str): + prop = BOARD_PROPERTIES[board] if "recommended_heap_size_bytes" in prop: return prop["recommended_heap_size_bytes"] return DEFAULT_HEAP_SIZE_BYTES -def generic_find_serial_port(serial_number=None): +def generic_find_serial_port(serial_number: str = None): """Find a USB serial port based on its serial number or its VID:PID. This method finds a USB serial port device path based on the port's serial number (if given) or @@ -264,12 +258,11 @@ def generic_find_serial_port(serial_number=None): return serial_ports[0].device -def _get_openocd_device_args(options): - serial_number = options.get("serial_number") +def _get_openocd_device_args(serial_number: str = None): return ["--serial", generic_find_serial_port(serial_number)] -def _get_nrf_device_args(serial_number: str): +def _get_nrf_device_args(serial_number: str = None): nrfjprog_args = ["nrfjprog", "--ids"] nrfjprog_ids = subprocess.check_output(nrfjprog_args, encoding="utf-8") if not nrfjprog_ids.strip("\n"): @@ -369,26 +362,6 @@ def _get_nrf_device_args(serial_number: str): ] -def get_zephyr_base(options: dict) -> str: - """Returns Zephyr base path""" - zephyr_base = options.get("zephyr_base", ZEPHYR_BASE) - assert zephyr_base, "'zephyr_base' option not passed and not found by default!" - return zephyr_base - - -def get_cmsis_path(options: dict) -> pathlib.Path: - """Returns CMSIS dependency path""" - cmsis_path = options.get("cmsis_path", os.environ.get("CMSIS_PATH", None)) - if cmsis_path: - return pathlib.Path(cmsis_path) - return None - - -def get_west_cmd(options: dict) -> str: - """Returns west command""" - return options.get("west_cmd", WEST_CMD) - - class Handler(server.ProjectAPIHandler): def __init__(self): super(Handler, self).__init__() @@ -546,23 +519,22 @@ def _generate_cmake_args( def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): zephyr_board = options["board"] project_type = options["project_type"] + zephyr_base = options["zephyr_base"] + west_cmd = options["west_cmd"] - zephyr_base = get_zephyr_base(options) warning_as_error = options.get("warning_as_error") use_fvp = options.get("use_fvp") - west_cmd = get_west_cmd(options) verbose = options.get("verbose") - recommended_heap_size = _get_recommended_heap_size_bytes(options) + recommended_heap_size = _get_recommended_heap_size_bytes(zephyr_board) heap_size_bytes = options.get("heap_size_bytes") or recommended_heap_size - board_mem_size = _get_board_mem_size_bytes(options) + board_mem_size = _get_board_mem_size_bytes(zephyr_base, zephyr_board) compile_definitions = options.get("compile_definitions") config_main_stack_size = options.get("config_main_stack_size") extra_files_tar = options.get("extra_files_tar") - - cmsis_path = get_cmsis_path(options) + cmsis_path = options.get("cmsis_path") # Check Zephyr version version = self._get_platform_version(zephyr_base) @@ -679,7 +651,7 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec tf.extractall(project_dir) def build(self, options): - verbose = options.get("verbose", None) + verbose = options.get("verbose") if BUILD_DIR.exists(): shutil.rmtree(BUILD_DIR) @@ -737,7 +709,7 @@ def _has_fpu(cls, zephyr_board): def flash(self, options): serial_number = options.get("serial_number") - west_cmd_list = get_west_cmd(options).split(" ") + west_cmd_list = options["west_cmd"].split(" ") if _find_platform_from_cmake_file(API_SERVER_DIR / CMAKELIST_FILENAME): return # NOTE: qemu requires no flash step--it is launched from open_transport. @@ -766,11 +738,16 @@ def open_transport(self, options): zephyr_board = _find_board_from_cmake_file(API_SERVER_DIR / CMAKELIST_FILENAME) emu_platform = _find_platform_from_cmake_file(API_SERVER_DIR / CMAKELIST_FILENAME) if self._is_fvp(zephyr_board, emu_platform == "armfvp"): - transport = ZephyrFvpTransport(options) + arm_fvp_path = options["arm_fvp_path"] + verbose = options.get("verbose") + transport = ZephyrFvpTransport(arm_fvp_path, verbose) elif self._is_qemu(zephyr_board): - transport = ZephyrQemuTransport(options) + gdbserver_port = options.get("gdbserver_port") + transport = ZephyrQemuTransport(gdbserver_port) else: - transport = ZephyrSerialTransport(options) + zephyr_base = options["zephyr_base"] + serial_number = options.get("serial_number") + transport = ZephyrSerialTransport(zephyr_base, serial_number) to_return = transport.open() self._transport = transport @@ -811,14 +788,12 @@ class ZephyrSerialTransport: NRF5340_DK_BOARD_VCOM_BY_PRODUCT_ID = {0x1055: "VCOM2", 0x1051: "VCOM1"} @classmethod - def _lookup_baud_rate(cls, options): + def _lookup_baud_rate(cls, zephyr_base: str): # TODO(mehrdadh): remove this hack once dtlib.py is a standalone project # https://github.com/zephyrproject-rtos/zephyr/blob/v2.7-branch/scripts/dts/README.txt sys.path.insert( 0, - os.path.join( - get_zephyr_base(options), "scripts", "dts", "python-devicetree", "src", "devicetree" - ), + os.path.join(zephyr_base, "scripts", "dts", "python-devicetree", "src", "devicetree"), ) try: import dtlib # pylint: disable=import-outside-toplevel @@ -838,9 +813,9 @@ def _lookup_baud_rate(cls, options): return uart_baud @classmethod - def _find_nrf_serial_port(cls, options): + def _find_nrf_serial_port(cls, serial_number: str = None): com_ports = subprocess.check_output( - ["nrfjprog", "--com"] + _get_device_args(options), encoding="utf-8" + ["nrfjprog", "--com"] + _get_device_args(serial_number), encoding="utf-8" ) ports_by_vcom = {} for line in com_ports.split("\n")[:-1]: @@ -860,43 +835,43 @@ def _find_nrf_serial_port(cls, options): return ports_by_vcom[vcom_port] @classmethod - def _find_openocd_serial_port(cls, options): - serial_number = options.get("serial_number") + def _find_openocd_serial_port(cls, serial_number: str = None): return generic_find_serial_port(serial_number) @classmethod - def _find_jlink_serial_port(cls, options): - return generic_find_serial_port() + def _find_jlink_serial_port(cls, serial_number: str = None): + return generic_find_serial_port(serial_number) @classmethod - def _find_stm32cubeprogrammer_serial_port(cls, options): - return generic_find_serial_port() + def _find_stm32cubeprogrammer_serial_port(cls, serial_number: str = None): + return generic_find_serial_port(serial_number) @classmethod - def _find_serial_port(cls, options): + def _find_serial_port(cls, serial_number: str = None): flash_runner = _get_flash_runner() if flash_runner == "nrfjprog": - return cls._find_nrf_serial_port(options) + return cls._find_nrf_serial_port(serial_number) if flash_runner == "openocd": - return cls._find_openocd_serial_port(options) + return cls._find_openocd_serial_port(serial_number) if flash_runner == "jlink": - return cls._find_jlink_serial_port(options) + return cls._find_jlink_serial_port(serial_number) if flash_runner == "stm32cubeprogrammer": - return cls._find_stm32cubeprogrammer_serial_port(options) + return cls._find_stm32cubeprogrammer_serial_port(serial_number) raise RuntimeError(f"Don't know how to deduce serial port for flash runner {flash_runner}") - def __init__(self, options): - self._options = options + def __init__(self, zephyr_base: str, serial_number: str = None): + self._zephyr_base = zephyr_base + self._serial_number = serial_number self._port = None def open(self): - port_path = self._find_serial_port(self._options) - self._port = serial.Serial(port_path, baudrate=self._lookup_baud_rate(self._options)) + port_path = self._find_serial_port(self._serial_number) + self._port = serial.Serial(port_path, baudrate=self._lookup_baud_rate(self._zephyr_base)) return server.TransportTimeouts( session_start_retry_timeout_sec=2.0, session_start_timeout_sec=5.0, @@ -933,8 +908,8 @@ class ZephyrQemuMakeResult(enum.Enum): class ZephyrQemuTransport: """The user-facing Zephyr QEMU transport class.""" - def __init__(self, options): - self.options = options + def __init__(self, gdbserver_port: int = None): + self._gdbserver_port = gdbserver_port self.proc = None self.pipe_dir = None self.read_fd = None @@ -954,9 +929,9 @@ def open(self): os.mkfifo(self.read_pipe) env = None - if self.options.get("gdbserver_port"): + if self._gdbserver_port: env = os.environ.copy() - env["TVM_QEMU_GDBSERVER_PORT"] = self.options["gdbserver_port"] + env["TVM_QEMU_GDBSERVER_PORT"] = self._gdbserver_port self.proc = subprocess.Popen( ["ninja", "run"], @@ -1102,20 +1077,18 @@ def write(self, data): class ZephyrFvpTransport: """A transport class that communicates with the ARM FVP via Iris server.""" - def __init__(self, options): - self.options = options + def __init__(self, arm_fvp_path: str, verbose: bool = False): + self._arm_fvp_path = arm_fvp_path + self._verbose = verbose self.proc = None self._queue = queue.Queue() self._import_iris() def _import_iris(self): - assert "arm_fvp_path" in self.options, "arm_fvp_path is not defined." + assert self._arm_fvp_path, "arm_fvp_path is not defined." # Location as seen in the FVP_Corstone_SSE-300_11.15_24 tar. iris_lib_path = ( - pathlib.Path(self.options["arm_fvp_path"]).parent.parent.parent - / "Iris" - / "Python" - / "iris" + pathlib.Path(self._arm_fvp_path).parent.parent.parent / "Iris" / "Python" / "iris" ) sys.path.insert(0, str(iris_lib_path.parent)) @@ -1142,7 +1115,7 @@ def _convertStringToU64Array(strValue): def open(self): args = ["ninja"] - if self.options.get("verbose"): + if self._verbose: args.append("-v") args.append("run") env = dict(os.environ) diff --git a/python/tvm/micro/project.py b/python/tvm/micro/project.py index 9dd57123676b..32d2cbf4db71 100644 --- a/python/tvm/micro/project.py +++ b/python/tvm/micro/project.py @@ -28,6 +28,17 @@ from .transport import Transport, TransportTimeouts +def add_unspecified_options(options: dict, server_project_options: list) -> dict: + """Adds default value of project template options that are not specified by user.""" + if not options: + options = dict() + for option in server_project_options: + name = option["name"] + if name not in options.keys(): + options[name] = option["default"] + return options + + class ProjectTransport(Transport): """A Transport implementation that uses the Project API client.""" @@ -69,10 +80,10 @@ def from_directory(cls, project_dir: Union[pathlib.Path, str], options: dict): def __init__(self, api_client, options): self._api_client = api_client - self._options = options self._info = self._api_client.server_info_query(__version__) if self._info["is_template"]: raise TemplateProjectError() + self._options = add_unspecified_options(options, self._info["project_options"]) def build(self): self._api_client.build(self._options) @@ -124,6 +135,8 @@ def _check_project_options(self, options: dict): def generate_project_from_mlf(self, model_library_format_path, project_dir, options: dict): """Generate a project from MLF file.""" self._check_project_options(options) + options = add_unspecified_options(options, self._info["project_options"]) + self._api_client.generate_project( model_library_format_path=str(model_library_format_path), standalone_crt_dir=get_standalone_crt_dir(), diff --git a/python/tvm/micro/project_api/server.py b/python/tvm/micro/project_api/server.py index 5aed3a896241..2d5db09f4bbe 100644 --- a/python/tvm/micro/project_api/server.py +++ b/python/tvm/micro/project_api/server.py @@ -804,7 +804,7 @@ def default_project_options(**kw) -> typing.List[ProjectOption]: "cmsis_path", optional=["generate_project"], type="str", - default=None, + default=os.environ.get("CMSIS_PATH", None), help="Path to the CMSIS directory.", ), ProjectOption( diff --git a/tests/micro/arduino/conftest.py b/tests/micro/arduino/conftest.py index a5ce8127c0bb..ffa1376efe12 100644 --- a/tests/micro/arduino/conftest.py +++ b/tests/micro/arduino/conftest.py @@ -22,20 +22,7 @@ import pytest -def pytest_addoption(parser): - parser.addoption( - "--arduino-cli-cmd", - default="arduino-cli", - help="Path to `arduino-cli` command for flashing device.", - ) - - def pytest_configure(config): config.addinivalue_line( "markers", "requires_hardware: mark test to run only when an Arduino board is connected" ) - - -@pytest.fixture(scope="session") -def arduino_cli_cmd(request): - return request.config.getoption("--arduino-cli-cmd") diff --git a/tests/micro/arduino/test_arduino_error_detection.py b/tests/micro/arduino/test_arduino_error_detection.py index f1278094b484..75b97fa86ca3 100644 --- a/tests/micro/arduino/test_arduino_error_detection.py +++ b/tests/micro/arduino/test_arduino_error_detection.py @@ -24,10 +24,8 @@ @pytest.fixture -def project(board, arduino_cli_cmd, microtvm_debug, workspace_dir, serial_number): - return test_utils.make_kws_project( - board, arduino_cli_cmd, microtvm_debug, workspace_dir, serial_number - ) +def project(board, microtvm_debug, workspace_dir, serial_number): + return test_utils.make_kws_project(board, microtvm_debug, workspace_dir, serial_number) def test_blank_project_compiles(workspace_dir, project): diff --git a/tests/micro/arduino/test_arduino_rpc_server.py b/tests/micro/arduino/test_arduino_rpc_server.py index ae22fb9499b8..38f34de82beb 100644 --- a/tests/micro/arduino/test_arduino_rpc_server.py +++ b/tests/micro/arduino/test_arduino_rpc_server.py @@ -41,7 +41,6 @@ def _make_session( model, arduino_board, - arduino_cli_cmd, workspace_dir, mod, build_config, @@ -53,7 +52,6 @@ def _make_session( workspace_dir / "project", { "board": arduino_board, - "arduino_cli_cmd": arduino_cli_cmd, "project_type": "host_driven", "verbose": bool(build_config.get("debug")), "serial_number": serial_number, @@ -67,7 +65,6 @@ def _make_session( def _make_sess_from_op( model, arduino_board, - arduino_cli_cmd, workspace_dir, op_name, sched, @@ -80,14 +77,10 @@ def _make_sess_from_op( with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): mod = tvm.build(sched, arg_bufs, target=target, runtime=runtime, name=op_name) - return _make_session( - model, arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config, serial_number - ) + return _make_session(model, arduino_board, workspace_dir, mod, build_config, serial_number) -def _make_add_sess( - model, arduino_board, arduino_cli_cmd, workspace_dir, build_config, serial_number: str = None -): +def _make_add_sess(model, arduino_board, workspace_dir, build_config, serial_number: str = None): A = tvm.te.placeholder((2,), dtype="int8") B = tvm.te.placeholder((1,), dtype="int8") C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C") @@ -95,7 +88,6 @@ def _make_add_sess( return _make_sess_from_op( model, arduino_board, - arduino_cli_cmd, workspace_dir, "add", sched, @@ -108,7 +100,7 @@ def _make_add_sess( # The same test code can be executed on both the QEMU simulation and on real hardware. @tvm.testing.requires_micro @pytest.mark.requires_hardware -def test_compile_runtime(board, arduino_cli_cmd, microtvm_debug, workspace_dir, serial_number): +def test_compile_runtime(board, microtvm_debug, workspace_dir, serial_number): """Test compiling the on-device runtime.""" model = test_utils.ARDUINO_BOARDS[board] @@ -127,15 +119,13 @@ def test_basic_add(sess): system_lib.get_function("add")(A_data, B_data, C_data) assert (C_data.numpy() == np.array([6, 7])).all() - with _make_add_sess( - model, board, arduino_cli_cmd, workspace_dir, build_config, serial_number - ) as sess: + with _make_add_sess(model, board, workspace_dir, build_config, serial_number) as sess: test_basic_add(sess) @tvm.testing.requires_micro @pytest.mark.requires_hardware -def test_platform_timer(board, arduino_cli_cmd, microtvm_debug, workspace_dir, serial_number): +def test_platform_timer(board, microtvm_debug, workspace_dir, serial_number): """Test compiling the on-device runtime.""" model = test_utils.ARDUINO_BOARDS[board] @@ -159,15 +149,13 @@ def test_basic_add(sess): assert result.mean > 0 assert len(result.results) == 3 - with _make_add_sess( - model, board, arduino_cli_cmd, workspace_dir, build_config, serial_number - ) as sess: + with _make_add_sess(model, board, workspace_dir, build_config, serial_number) as sess: test_basic_add(sess) @tvm.testing.requires_micro @pytest.mark.requires_hardware -def test_relay(board, arduino_cli_cmd, microtvm_debug, workspace_dir, serial_number): +def test_relay(board, microtvm_debug, workspace_dir, serial_number): """Testing a simple relay graph""" model = test_utils.ARDUINO_BOARDS[board] build_config = {"debug": microtvm_debug} @@ -186,9 +174,7 @@ def test_relay(board, arduino_cli_cmd, microtvm_debug, workspace_dir, serial_num with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): mod = tvm.relay.build(func, target=target, runtime=runtime) - with _make_session( - model, board, arduino_cli_cmd, workspace_dir, mod, build_config, serial_number - ) as session: + with _make_session(model, board, workspace_dir, mod, build_config, serial_number) as session: graph_mod = tvm.micro.create_local_graph_executor( mod.get_graph_json(), session.get_system_lib(), session.device ) @@ -202,7 +188,7 @@ def test_relay(board, arduino_cli_cmd, microtvm_debug, workspace_dir, serial_num @tvm.testing.requires_micro @pytest.mark.requires_hardware -def test_onnx(board, arduino_cli_cmd, microtvm_debug, workspace_dir, serial_number): +def test_onnx(board, microtvm_debug, workspace_dir, serial_number): """Testing a simple ONNX model.""" model = test_utils.ARDUINO_BOARDS[board] build_config = {"debug": microtvm_debug} @@ -232,7 +218,7 @@ def test_onnx(board, arduino_cli_cmd, microtvm_debug, workspace_dir, serial_numb graph = lowered.get_graph_json() with _make_session( - model, board, arduino_cli_cmd, workspace_dir, lowered, build_config, serial_number + model, board, workspace_dir, lowered, build_config, serial_number ) as session: graph_mod = tvm.micro.create_local_graph_executor( graph, session.get_system_lib(), session.device @@ -256,7 +242,6 @@ def check_result( relay_mod, model, arduino_board, - arduino_cli_cmd, workspace_dir, map_inputs, out_shape, @@ -272,7 +257,7 @@ def check_result( mod = tvm.relay.build(relay_mod, target=target, runtime=runtime) with _make_session( - model, arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config, serial_number + model, arduino_board, workspace_dir, mod, build_config, serial_number ) as session: rt_mod = tvm.micro.create_local_graph_executor( mod.get_graph_json(), session.get_system_lib(), session.device @@ -294,7 +279,7 @@ def check_result( @tvm.testing.requires_micro @pytest.mark.requires_hardware -def test_byoc_microtvm(board, arduino_cli_cmd, microtvm_debug, workspace_dir, serial_number): +def test_byoc_microtvm(board, microtvm_debug, workspace_dir, serial_number): """This is a simple test case to check BYOC capabilities of microTVM""" model = test_utils.ARDUINO_BOARDS[board] build_config = {"debug": microtvm_debug} @@ -352,7 +337,6 @@ def test_byoc_microtvm(board, arduino_cli_cmd, microtvm_debug, workspace_dir, se model=model, build_config=build_config, arduino_board=board, - arduino_cli_cmd=arduino_cli_cmd, workspace_dir=workspace_dir, serial_number=serial_number, ) @@ -361,7 +345,6 @@ def test_byoc_microtvm(board, arduino_cli_cmd, microtvm_debug, workspace_dir, se def _make_add_sess_with_shape( model, arduino_board, - arduino_cli_cmd, workspace_dir, shape, build_config, @@ -373,7 +356,6 @@ def _make_add_sess_with_shape( return _make_sess_from_op( model, arduino_board, - arduino_cli_cmd, workspace_dir, "add", sched, @@ -393,9 +375,7 @@ def _make_add_sess_with_shape( ) @tvm.testing.requires_micro @pytest.mark.requires_hardware -def test_rpc_large_array( - board, arduino_cli_cmd, microtvm_debug, workspace_dir, shape, serial_number -): +def test_rpc_large_array(board, microtvm_debug, workspace_dir, shape, serial_number): """Test large RPC array transfer.""" model = test_utils.ARDUINO_BOARDS[board] build_config = {"debug": microtvm_debug} @@ -410,7 +390,7 @@ def test_tensors(sess): assert (C_data.numpy() == np.zeros(shape)).all() with _make_add_sess_with_shape( - model, board, arduino_cli_cmd, workspace_dir, shape, build_config, serial_number + model, board, workspace_dir, shape, build_config, serial_number ) as sess: test_tensors(sess) diff --git a/tests/micro/arduino/test_arduino_workflow.py b/tests/micro/arduino/test_arduino_workflow.py index 51898424aee5..42874ad6c349 100644 --- a/tests/micro/arduino/test_arduino_workflow.py +++ b/tests/micro/arduino/test_arduino_workflow.py @@ -54,12 +54,10 @@ def project_dir(workflow_workspace_dir): # We MUST pass workspace_dir, not project_dir, or the workspace will be dereferenced # too soon. We can't use the board fixture either for the reason mentioned above. @pytest.fixture(scope="module") -def project(request, arduino_cli_cmd, microtvm_debug, workflow_workspace_dir): +def project(request, microtvm_debug, workflow_workspace_dir): board = request.config.getoption("--board") serial_number = request.config.getoption("--serial-number") - return test_utils.make_kws_project( - board, arduino_cli_cmd, microtvm_debug, workflow_workspace_dir, serial_number - ) + return test_utils.make_kws_project(board, microtvm_debug, workflow_workspace_dir, serial_number) def _get_directory_elements(directory): diff --git a/tests/micro/arduino/test_utils.py b/tests/micro/arduino/test_utils.py index d81edc845b98..1456e1f7591e 100644 --- a/tests/micro/arduino/test_utils.py +++ b/tests/micro/arduino/test_utils.py @@ -61,7 +61,7 @@ def make_workspace_dir(test_name, board): return t -def make_kws_project(board, arduino_cli_cmd, microtvm_debug, workspace_dir, serial_number: str): +def make_kws_project(board, microtvm_debug, workspace_dir, serial_number: str): this_dir = pathlib.Path(__file__).parent model = ARDUINO_BOARDS[board] build_config = {"debug": microtvm_debug} @@ -85,7 +85,6 @@ def make_kws_project(board, arduino_cli_cmd, microtvm_debug, workspace_dir, seri workspace_dir / "project", { "board": board, - "arduino_cli_cmd": arduino_cli_cmd, "project_type": "example_project", "verbose": bool(build_config.get("debug")), "serial_number": serial_number, diff --git a/tests/micro/project_api/test_arduino_microtvm_api_server.py b/tests/micro/project_api/test_arduino_microtvm_api_server.py index ad9bd4a56a2d..39e5780af6dc 100644 --- a/tests/micro/project_api/test_arduino_microtvm_api_server.py +++ b/tests/micro/project_api/test_arduino_microtvm_api_server.py @@ -136,14 +136,15 @@ def test_auto_detect_port(self, mock_run): arduino_cli_cmd = self.DEFAULT_OPTIONS.get("arduino_cli_cmd") warning_as_error = self.DEFAULT_OPTIONS.get("warning_as_error") - cli_command = handler._get_arduino_cli_cmd(arduino_cli_cmd) - handler._check_platform_version(cli_command=cli_command, warning_as_error=warning_as_error) + handler._check_platform_version( + cli_command=arduino_cli_cmd, warning_as_error=warning_as_error + ) assert handler._version == version.parse("0.21.1") handler = microtvm_api_server.Handler() mock_run.return_value.stdout = bytes(self.BAD_CLI_VERSION, "utf-8") with pytest.raises(server.ServerError) as error: - handler._check_platform_version(cli_command=cli_command, warning_as_error=True) + handler._check_platform_version(cli_command=arduino_cli_cmd, warning_as_error=True) mock_run.reset_mock() @mock.patch("subprocess.run") diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index a8fb26133970..6b49c043cc3d 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -605,7 +605,6 @@ def test_schedule_build_with_cmsis_dependency(workspace_dir, board, microtvm_deb "project_type": "host_driven", "verbose": bool(build_config.get("debug")), "board": board, - "cmsis_path": os.getenv("CMSIS_PATH"), "use_fvp": bool(use_fvp), }