From 26210c9c324cb6a282fddee7c53c9fa9e0a558ec Mon Sep 17 00:00:00 2001 From: Gustavo Romero Date: Fri, 10 Sep 2021 01:11:58 +0000 Subject: [PATCH 01/10] [microTVM] zephyr: Make platform options comply with RFC-0020 Make Zephyr platform options comply with RFC-0020 specification. Project options now need to specify the required metadata for every option, i.e. 'required', 'optional', and 'type'. Signed-off-by: Gustavo Romero --- .../template_project/microtvm_api_server.py | 42 ++++++++++++++++--- python/tvm/micro/project_api/server.py | 13 +++++- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index 7e13f928b288..3c96f31dfe22 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -234,49 +234,81 @@ def _get_nrf_device_args(options): PROJECT_OPTIONS = [ server.ProjectOption( "extra_files_tar", + optional=["generate_project"], + type="str", help="If given, during generate_project, uncompress the tarball at this path into the project dir.", ), server.ProjectOption( - "gdbserver_port", help=("If given, port number to use when running the local gdbserver.") + "gdbserver_port", + help=("If given, port number to use when running the local gdbserver."), + optional=["open_transport"], + type="int", ), server.ProjectOption( "nrfjprog_snr", + optional=["open_transport"], + type="int", help=("When used with nRF targets, serial # of the attached board to use, from nrfjprog."), ), server.ProjectOption( "openocd_serial", + optional=["open_transport"], + type="int", help=("When used with OpenOCD targets, serial # of the attached board to use."), ), server.ProjectOption( "project_type", - help="Type of project to generate.", choices=tuple(PROJECT_TYPES), + required=["generate_project"], + type="str", + help="Type of project to generate.", + ), + server.ProjectOption( + "verbose", + optional=["build"], + type="bool", + help="Run build with verbose output.", ), - server.ProjectOption("verbose", help="Run build with verbose output.", choices=(True, False)), server.ProjectOption( "west_cmd", + optional=["generate_project"], + default=sys.executable + " -m west" if sys.executable else None, + type="str", help=( "Path to the west tool. If given, supersedes both the zephyr_base " "option and ZEPHYR_BASE environment variable." ), ), - server.ProjectOption("zephyr_base", help="Path to the zephyr base directory."), + server.ProjectOption( + "zephyr_base", + optional=["build", "open_transport"], + default=os.getenv("ZEPHYR_BASE"), + type="str", + help="Path to the zephyr base directory.", + ), server.ProjectOption( "zephyr_board", + required=["generate_project", "build", "flash", "open_transport"], choices=list(BOARD_PROPERTIES), + type="str", help="Name of the Zephyr board to build for.", ), server.ProjectOption( "config_main_stack_size", + optional=["generate_project"], + type="int", help="Sets CONFIG_MAIN_STACK_SIZE for Zephyr board.", ), server.ProjectOption( "warning_as_error", - choices=(True, False), + optional=["generate_project"], + type="bool", help="Treat warnings as errors and raise an Exception.", ), server.ProjectOption( "compile_definitions", + optional=["generate_project"], + type="str", help="Extra definitions added project compile.", ), ] diff --git a/python/tvm/micro/project_api/server.py b/python/tvm/micro/project_api/server.py index cee0205303f0..7dd26008fc7a 100644 --- a/python/tvm/micro/project_api/server.py +++ b/python/tvm/micro/project_api/server.py @@ -42,15 +42,24 @@ _LOG = logging.getLogger(__name__) -_ProjectOption = collections.namedtuple("ProjectOption", ("name", "choices", "help")) +_ProjectOption = collections.namedtuple( + "ProjectOption", ("name", "choices", "default", "type", "required", "optional", "help") +) class ProjectOption(_ProjectOption): def __new__(cls, name, **kw): """Override __new__ to force all options except name to be specified as kwargs.""" assert "name" not in kw + assert ( + "required" in kw or "optional" in kw + ), "at least one of 'required' or 'optional' must be specified." + assert "type" in kw, "'type' field must be specified." + kw["name"] = name - kw.setdefault("choices", None) + for param in ["choices", "default", "required", "optional"]: + kw.setdefault(param, None) + return super().__new__(cls, **kw) From 6e186cf6193d3e73bb5960f1da26afa78aa68f9a Mon Sep 17 00:00:00 2001 From: Gustavo Romero Date: Fri, 24 Sep 2021 21:32:59 +0000 Subject: [PATCH 02/10] [microTVM] arduino: Make platform options comply with RFC-0020 Make Arduino platform options comply with RFC-0020 specification. Project options now need to specify the required metadata for every option, i.e. 'required', 'optional', and 'type'. Signed-off-by: Gustavo Romero --- .../template_project/microtvm_api_server.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index 1768c61197a9..3039eb313908 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -71,22 +71,40 @@ class BoardAutodetectFailed(Exception): PROJECT_OPTIONS = [ server.ProjectOption( "arduino_board", + required=["build", "flash", "open_transport"], choices=list(BOARD_PROPERTIES), - help="Name of the Arduino board to build for", + type="str", + help="Name of the Arduino board to build for.", + ), + server.ProjectOption( + "arduino_cli_cmd", + required=["build", "flash", "open_transport"], + type="str", + help="Path to the arduino-cli tool.", + ), + server.ProjectOption( + "port", + optional=["flash", "open_transport"], + type="int", + help="Port to use for connecting to hardware.", ), - server.ProjectOption("arduino_cli_cmd", help="Path to the arduino-cli tool."), - server.ProjectOption("port", help="Port to use for connecting to hardware"), server.ProjectOption( "project_type", - help="Type of project to generate.", + required=["generate_project"], choices=tuple(PROJECT_TYPES), + type="str", + help="Type of project to generate.", ), server.ProjectOption( - "verbose", help="True to pass --verbose flag to arduino-cli compile and upload" + "verbose", + optional=["build", "flash"], + type="bool", + help="Run arduino-cli compile and upload with verbose output.", ), server.ProjectOption( "warning_as_error", - choices=(True, False), + optional=["generate_project"], + type="bool", help="Treat warnings as errors and raise an Exception.", ), ] From c0cd0b96d1581cac6005b512303a315bc9f7fb03 Mon Sep 17 00:00:00 2001 From: Gustavo Romero Date: Thu, 18 Nov 2021 01:17:50 +0000 Subject: [PATCH 03/10] [microTVM] crt: Make crt options comply with RFC-0020 Make crt project options comply with RFC-0020 specification. Project options now need to specify the required metadata for every option, i.e. 'required', 'optional', and 'type'. Signed-off-by: Gustavo Romero --- src/runtime/crt/host/microtvm_api_server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/runtime/crt/host/microtvm_api_server.py b/src/runtime/crt/host/microtvm_api_server.py index 546ac1448011..925dc4ea7597 100644 --- a/src/runtime/crt/host/microtvm_api_server.py +++ b/src/runtime/crt/host/microtvm_api_server.py @@ -53,7 +53,10 @@ def server_info_query(self, tvm_version): else PROJECT_DIR / MODEL_LIBRARY_FORMAT_RELPATH, project_options=[ server.ProjectOption( - "verbose", help="Run make with verbose output", choices=(True, False) + "verbose", + optional=["build"], + type="bool", + help="Run make with verbose output", ) ], ) From fcd0e51b70f6c489f09a779db32a794bff4094cf Mon Sep 17 00:00:00 2001 From: Gustavo Romero Date: Thu, 18 Nov 2021 03:51:59 +0000 Subject: [PATCH 04/10] [microTVM][Unittest] Adapt test to RFC-0020 Adapt test to new metadata fields accordingly to RFC-0020 specification. Signed-off-by: Gustavo Romero --- .../python/unittest/test_micro_project_api.py | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_micro_project_api.py b/tests/python/unittest/test_micro_project_api.py index 1e511c41d73e..1dd8940fecec 100644 --- a/tests/python/unittest/test_micro_project_api.py +++ b/tests/python/unittest/test_micro_project_api.py @@ -42,8 +42,16 @@ class BaseTestHandler_Impl(project_api.server.ProjectAPIHandler): is_template=True, model_library_format_path="./model-library-format-path.sh", project_options=[ - project_api.server.ProjectOption(name="foo", help="Option foo"), - project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"), + project_api.server.ProjectOption( + name="foo", optional=["build"], type="bool", help="Option foo" + ), + project_api.server.ProjectOption( + name="bar", + required=["generate_project"], + type="str", + choices=["qux"], + help="Option bar", + ), ], ) @@ -141,8 +149,24 @@ def test_server_info_query(BaseTestHandler): assert reply["is_template"] == True assert reply["model_library_format_path"] == "./model-library-format-path.sh" assert reply["project_options"] == [ - {"name": "foo", "choices": None, "help": "Option foo"}, - {"name": "bar", "choices": ["qux"], "help": "Option bar"}, + { + "name": "foo", + "choices": None, + "default": None, + "type": "bool", + "required": None, + "optional": ["build"], + "help": "Option foo", + }, + { + "name": "bar", + "choices": ["qux"], + "default": None, + "type": "str", + "required": ["generate_project"], + "optional": None, + "help": "Option bar", + }, ] From 6f959c8429e1a1b5e2fe7c29ece88e00313716ea Mon Sep 17 00:00:00 2001 From: Gustavo Romero Date: Tue, 5 Oct 2021 23:34:47 +0000 Subject: [PATCH 05/10] [microTVM] Add info() method to GeneratedProject class Add info() method to GeneratedProject class so one can use the Project API to query options for project dirs instead of only for template projects. This commit also adds for the sake of convenience a setter and a getter for 'options' in case it's necessary to set or get 'options' after a GeneratedProject class is instantiated without initializing 'options'. Signed-off-by: Gustavo Romero --- python/tvm/micro/project.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/tvm/micro/project.py b/python/tvm/micro/project.py index a5e54aa816a3..907590dcd2cf 100644 --- a/python/tvm/micro/project.py +++ b/python/tvm/micro/project.py @@ -83,6 +83,17 @@ def flash(self): def transport(self): return ProjectTransport(self._api_client, self._options) + def info(self): + return self._info + + @property + def options(self): + return self._options + + @options.setter + def options(self, options): + self._options = options + class NotATemplateProjectError(Exception): """Raised when the API server given to TemplateProject reports is_template=false.""" From 517a134ab0c121fa288967897500e2ef3894acd5 Mon Sep 17 00:00:00 2001 From: Gustavo Romero Date: Sun, 24 Oct 2021 21:45:44 +0000 Subject: [PATCH 06/10] [microTVM] Fix typo in python/tvm/micro/session.py Fix typo in comment. Signed-off-by: Gustavo Romero --- python/tvm/micro/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py index d545f2e7daa4..4f754d9d442c 100644 --- a/python/tvm/micro/session.py +++ b/python/tvm/micro/session.py @@ -73,7 +73,7 @@ def __init__( ---------- transport_context_manager : ContextManager[transport.Transport] If given, `flasher` and `binary` should not be given. On entry, this context manager - should establish a tarnsport between this TVM instance and the device. + should establish a transport between this TVM instance and the device. session_name : str Name of the session, used for debugging. timeout_override : TransportTimeouts From 1dbf8dc7d7cdff280fe1f96a1c64fe3b04cf7af9 Mon Sep 17 00:00:00 2001 From: Gustavo Romero Date: Sun, 15 Aug 2021 23:56:27 +0000 Subject: [PATCH 07/10] Allow multiple runs on micro targets Currently there is a limitation on microTVM / TVM which doesn't allow running a model multiple times in sequence without previously flashing the model to the device. Root cause is that RPCModuleNode class destructor is called once a run finishes. The destructor sends a RPCCode::kFreeHandle packet with type_code = kTVMModuleHandle to the device which wipes entries in crt/src/runtime/crt/common/crt_runtime_api.c:147:static const TVMModule* registered_modules[TVM_CRT_MAX_REGISTERED_MODULES] when TVMFreeMod() is called when the target receives a kFreeHandle packet. Hence when one tries to re-run a model registered_modules[0] == NULL causes a backtrace on the host side. Probably never before a model on microTVM was run without being flashed just before the run, so tvmc run implementation for micro targets exposed the issue. This commit fixes it by not calling TVMFreeMod() for system_lib_handle on the target side when a session terminates so the pointer to the system_lib_handle is not flushed from 'registered_modules', allowing multiple runs on micro targets. Signed-off-by: Gustavo Romero --- src/runtime/crt/common/crt_runtime_api.c | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index ea986a3bf096..49a699c3ce13 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -183,7 +183,16 @@ int TVMModCreateFromCModule(const TVMModule* mod, TVMModuleHandle* out_handle) { return -1; } +static const TVMModuleHandle kTVMModuleHandleUninitialized = (TVMModuleHandle)(~0UL); + +static TVMModuleHandle system_lib_handle; + int TVMModFree(TVMModuleHandle mod) { + /* Never free system_lib_handler */ + if (mod == system_lib_handle && system_lib_handle != kTVMModuleHandleUninitialized) { + return 0; + } + tvm_module_index_t module_index; if (DecodeModuleHandle(mod, &module_index) != 0) { return -1; @@ -193,10 +202,6 @@ int TVMModFree(TVMModuleHandle mod) { return 0; } -static const TVMModuleHandle kTVMModuleHandleUninitialized = (TVMModuleHandle)(~0UL); - -static TVMModuleHandle system_lib_handle; - int SystemLibraryCreate(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val, int* ret_type_codes) { const TVMModule* system_lib; From 9ac3129783a78e5bcd9e85b0f79e44567366877d Mon Sep 17 00:00:00 2001 From: Gustavo Romero Date: Fri, 1 Oct 2021 23:58:27 +0000 Subject: [PATCH 08/10] [TVMC] Pass main parser when calling add_*_parser functions Currently when a add_*_parser functions are called in main.py to build and add the various subparsers to the main parser only a subparser is passed to the functions. However if one of these functions need to build a dynamic parser it needs also to call the main parser at least once to parse once the command line and get the arguments necessary to finally build the complete parser. This commit fixes that limitation by passing also the main parser when calling the subparser builders so it can be used to build the dynamic subparses. Signed-off-by: Gustavo Romero --- python/tvm/driver/tvmc/autotuner.py | 2 +- python/tvm/driver/tvmc/compiler.py | 2 +- python/tvm/driver/tvmc/main.py | 2 +- python/tvm/driver/tvmc/runner.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index 92d13a99acd5..60bec38f0d1a 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -46,7 +46,7 @@ @register_parser -def add_tune_parser(subparsers): +def add_tune_parser(subparsers, _): """Include parser for 'tune' subcommand""" parser = subparsers.add_parser("tune", help="auto-tune a model") diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 7623a141c27a..a51a16f7e017 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -38,7 +38,7 @@ @register_parser -def add_compile_parser(subparsers): +def add_compile_parser(subparsers, _): """Include parser for 'compile' subcommand""" parser = subparsers.add_parser("compile", help="compile a model.") diff --git a/python/tvm/driver/tvmc/main.py b/python/tvm/driver/tvmc/main.py index 2574daab02ac..ab72d59fbea6 100644 --- a/python/tvm/driver/tvmc/main.py +++ b/python/tvm/driver/tvmc/main.py @@ -66,7 +66,7 @@ def _main(argv): subparser = parser.add_subparsers(title="commands") for make_subparser in REGISTERED_PARSER: - make_subparser(subparser) + make_subparser(subparser, parser) args = parser.parse_args(argv) if args.verbose > 4: diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 659df7ceef33..52af02bc43d8 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -41,7 +41,7 @@ @register_parser -def add_run_parser(subparsers): +def add_run_parser(subparsers, main_parser): """Include parser for 'run' subcommand""" parser = subparsers.add_parser("run", help="run a compiled module") From d1ddc7564e8d7ccea2b334f1267f31fc506020c1 Mon Sep 17 00:00:00 2001 From: Gustavo Romero Date: Thu, 24 Jun 2021 18:25:37 +0000 Subject: [PATCH 09/10] [TVMC] micro: Add new micro context This commit introduces support for micro targets (targets supported by microTVM). It creates a new micro context under the new TVMC command 'tvmc micro'. Moreover, three new subcommands are made available in the new context under 'tvmc micro': 'create-project', 'build', and 'flash'. The new support relies on the Project API to query all the options available for a selected platform (like Zephyr and Arduino) and also from any adhoc platform template directory which provides a custom Project API server. Signed-off-by: Gustavo Romero --- python/tvm/driver/tvmc/__init__.py | 3 +- python/tvm/driver/tvmc/common.py | 232 +++++++++++++++++++ python/tvm/driver/tvmc/fmtopt.py | 116 ++++++++++ python/tvm/driver/tvmc/main.py | 6 + python/tvm/driver/tvmc/micro.py | 300 +++++++++++++++++++++++++ python/tvm/micro/project_api/server.py | 2 + 6 files changed, 658 insertions(+), 1 deletion(-) create mode 100644 python/tvm/driver/tvmc/fmtopt.py create mode 100644 python/tvm/driver/tvmc/micro.py diff --git a/python/tvm/driver/tvmc/__init__.py b/python/tvm/driver/tvmc/__init__.py index 42184c34df74..70747cbb2d74 100644 --- a/python/tvm/driver/tvmc/__init__.py +++ b/python/tvm/driver/tvmc/__init__.py @@ -19,9 +19,10 @@ TVMC - TVM driver command-line interface """ +from . import micro +from . import runner from . import autotuner from . import compiler -from . import runner from . import result_utils from .frontends import load_model as load from .compiler import compile_model as compile diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 65b0c3dbc0aa..97b7c5206a38 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -23,6 +23,7 @@ import os.path import argparse +from collections import defaultdict from urllib.parse import urlparse import tvm @@ -31,6 +32,7 @@ from tvm import relay from tvm import transform from tvm._ffi import registry +from .fmtopt import format_option # pylint: disable=invalid-name logger = logging.getLogger("TVMC") @@ -40,6 +42,37 @@ class TVMCException(Exception): """TVMC Exception""" +class TVMCSuppressedArgumentParser(argparse.ArgumentParser): + """ + A silent ArgumentParser class. + + This class is meant to be used as a helper for creating dynamic parsers in + TVMC. It will create a "supressed" parser based on an existing one (parent) + which does not include a help message, does not print a usage message (even + when -h or --help is passed) and does not exit on invalid choice parse + errors but rather throws a TVMCException so it can be handled and the + dynamic parser construction is not interrupted prematurely. + + """ + + def __init__(self, parent, **kwargs): + # Don't add '-h' or '--help' options to the newly created parser. Don't print usage message. + # 'add_help=False' won't supress existing '-h' and '--help' options from the parser (and its + # subparsers) present in 'parent'. However that class is meant to be used with the main + # parser, which is created with `add_help=False` - the help is added only later. Hence it + # the newly created parser won't have help options added in its (main) root parser. The + # subparsers in the main parser will eventually have help activated, which is enough for its + # use in TVMC. + super().__init__(parents=[parent], add_help=False, usage=argparse.SUPPRESS, **kwargs) + + def exit(self, status=0, message=None): + # Don't exit on error when parsing the command line. + # This won't catch all the errors generated when parsing tho. For instance, it won't catch + # errors due to missing required arguments. But this will catch "error: invalid choice", + # which is what it's necessary for its use in TVMC. + raise TVMCException() + + def convert_graph_layout(mod, desired_layout): """Alter the layout of the input graph. @@ -554,3 +587,202 @@ def parse_configs(input_configs): pass_context_configs[name] = parsed_value return pass_context_configs + + +def get_project_options(project_info): + """Get all project options as returned by Project API 'server_info_query' + and return them in a dict indexed by the API method they belong to. + + + Parameters + ---------- + project_info: dict of list + a dict of lists as returned by Project API 'server_info_query' among + which there is a list called 'project_options' containing all the + project options available for a given project/platform. + + Returns + ------- + options_by_method: dict of list + a dict indexed by the API method names (e.g. "generate_project", + "build", "flash", or "open_transport") of lists containing all the + options (plus associated metadata and formatted help text) that belong + to a method. + + The metadata associated to the options include the field 'choices' and + 'required' which are convenient for parsers. + + The formatted help text field 'help_text' is a string that contains the + name of the option, the choices for the option, and the option's default + value. + """ + options = project_info["project_options"] + + options_by_method = defaultdict(list) + for opt in options: + # Get list of methods associated with an option based on the + # existance of a 'required' or 'optional' lists. API specification + # guarantees at least one of these lists will exist. If a list does + # not exist it's returned as None by the API. + metadata = ["required", "optional"] + om = [(opt[md], bool(md == "required")) for md in metadata if opt[md]] + for methods, is_opt_required in om: + for method in methods: + name = opt["name"] + + # Only for boolean options set 'choices' accordingly to the + # option type. API returns 'choices' associated to them + # as None but 'choices' can be deduced from 'type' in this case. + if opt["type"] == "bool": + opt["choices"] = ["true", "false"] + + if opt["choices"]: + choices = "{" + ", ".join(opt["choices"]) + "}" + else: + choices = opt["name"].upper() + option_choices_text = f"{name}={choices}" + + help_text = opt["help"][0].lower() + opt["help"][1:] + + if opt["default"]: + default_text = f"Defaults to '{opt['default']}'." + else: + default_text = None + + formatted_help_text = format_option( + option_choices_text, help_text, default_text, is_opt_required + ) + + option = { + "name": opt["name"], + "choices": opt["choices"], + "help_text": formatted_help_text, + "required": is_opt_required, + } + options_by_method[method].append(option) + + return options_by_method + + +def get_options(options): + """Get option and option value from the list options returned by the parser. + + Parameters + ---------- + options: list of str + list of strings of the form "option=value" as returned by the parser. + + Returns + ------- + opts: dict + dict indexed by option names and associated values. + """ + + opts = {} + for option in options: + try: + k, v = option.split("=") + opts[k] = v + except ValueError: + raise TVMCException(f"Invalid option format: {option}. Please use OPTION=VALUE.") + + return opts + + +def check_options(options, valid_options): + """Check if an option (required or optional) is valid. i.e. in the list of valid options. + + Parameters + ---------- + options: dict + dict indexed by option name of options and options values to be checked. + + valid_options: list of dict + list of all valid options and choices for a platform. + + Returns + ------- + None. Raise TVMCException if check fails, i.e. if an option is not in the list of valid options. + + """ + required_options = [opt["name"] for opt in valid_options if opt["required"]] + for required_option in required_options: + if required_option not in options: + raise TVMCException( + f"Option '{required_option}' is required but was not specified. Use --list-options " + "to see all required options." + ) + + remaining_options = set(options) - set(required_options) + optional_options = [opt["name"] for opt in valid_options if not opt["required"]] + for option in remaining_options: + if option not in optional_options: + raise TVMCException( + f"Option '{option}' is invalid. Use --list-options to see all available options." + ) + + +def check_options_choices(options, valid_options): + """Check if an option value is among the option's choices, when choices exist. + + Parameters + ---------- + options: dict + dict indexed by option name of options and options values to be checked. + + valid_options: list of dict + list of all valid options and choices for a platform. + + Returns + ------- + None. Raise TVMCException if check fails, i.e. if an option value is not valid. + + """ + # Dict of all valid options and associated valid choices. + # Options with no choices are excluded from the dict. + valid_options_choices = { + opt["name"]: opt["choices"] for opt in valid_options if opt["choices"] is not None + } + + for option in options: + if option in valid_options_choices: + if options[option] not in valid_options_choices[option]: + raise TVMCException( + f"Choice '{options[option]}' for option '{option}' is invalid. " + "Use --list-options to see all available choices for that option." + ) + + +def get_and_check_options(passed_options, valid_options): + """Get options and check if they are valid. If choices exist for them, check values against it. + + Parameters + ---------- + passed_options: list of str + list of strings in the "key=value" form as captured by argparse. + + valid_option: list + list with all options available for a given API method / project as returned by + get_project_options(). + + Returns + ------- + opts: dict + dict indexed by option names and associated values. + + Or None if passed_options is None. + + """ + + if passed_options is None: + # No options to check + return None + + # From a list of k=v strings, make a dict options[k]=v + opts = get_options(passed_options) + # Check if passed options are valid + check_options(opts, valid_options) + # Check (when a list of choices exists) if the passed values are valid + check_options_choices(opts, valid_options) + + return opts diff --git a/python/tvm/driver/tvmc/fmtopt.py b/python/tvm/driver/tvmc/fmtopt.py new file mode 100644 index 000000000000..7f27826d77bf --- /dev/null +++ b/python/tvm/driver/tvmc/fmtopt.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Utils to format help text for project options. +""" +from textwrap import TextWrapper + + +# Maximum column length for accommodating option name and its choices. +# Help text is placed after it in a new line. +MAX_OPTNAME_CHOICES_TEXT_COL_LEN = 80 + + +# Maximum column length for accommodating help text. +# 0 turns off formatting for the help text. +MAX_HELP_TEXT_COL_LEN = 0 + + +# Justification of help text placed below option name + choices text. +HELP_TEXT_JUST = 2 + + +def format_option(option_text, help_text, default_text, required=True): + """Format option name, choices, and default text into a single help text. + + Parameters + ---------- + options_text: str + String containing the option name and option's choices formatted as: + optname={opt0, opt1, ...} + + help_text: str + Help text string. + + default_text: str + Default text string. + + required: bool + Flag that controls if a "(required)" text mark needs to be added to the final help text to + inform if the option is a required one. + + Returns + ------- + help_text_just: str + Single justified help text formatted as: + optname={opt0, opt1, ... } + HELP_TEXT. "(required)" | "Defaults to 'DEFAULT'." + + """ + optname, choices_text = option_text.split("=", 1) + + # Prepare optname + choices text chunck. + + optname_len = len(optname) + wrapper = TextWrapper(width=MAX_OPTNAME_CHOICES_TEXT_COL_LEN - optname_len) + choices_lines = wrapper.wrap(choices_text) + + # Set first choices line which merely appends to optname string. + # No justification is necessary for the first line since first + # line was wrapped based on MAX_OPTNAME_CHOICES_TEXT_COL_LEN - optname_len, + # i.e. considering optname_len, hence only append justified choices_lines[0] line. + choices_just_lines = [optname + "=" + choices_lines[0]] + + # Justify the remaining lines based on first optname + '='. + for line in choices_lines[1:]: + line_len = len(line) + line_just = line.rjust( + optname_len + 1 + line_len + ) # add 1 to align after '{' in the line above + choices_just_lines.append(line_just) + + choices_text_just_chunk = "\n".join(choices_just_lines) + + # Prepare help text chunck. + + help_text = help_text[0].lower() + help_text[1:] + if MAX_HELP_TEXT_COL_LEN > 0: + wrapper = TextWrapper(width=MAX_HELP_TEXT_COL_LEN) + help_text_lines = wrapper.wrap(help_text) + else: + # Don't format help text. + help_text_lines = [help_text] + + help_text_just_lines = [] + for line in help_text_lines: + line_len = len(line) + line_just = line.rjust(HELP_TEXT_JUST + line_len) + help_text_just_lines.append(line_just) + + help_text_just_chunk = "\n".join(help_text_just_lines) + + # An option might be required for one method but optional for another one. + # If the option is required for one method it means there is no default for + # it when used in that method, hence suppress default text in that case. + if default_text and not required: + help_text_just_chunk += " " + default_text + + if required: + help_text_just_chunk += " (required)" + + help_text_just = choices_text_just_chunk + "\n" + help_text_just_chunk + return help_text_just diff --git a/python/tvm/driver/tvmc/main.py b/python/tvm/driver/tvmc/main.py index ab72d59fbea6..0a8df4b1599d 100644 --- a/python/tvm/driver/tvmc/main.py +++ b/python/tvm/driver/tvmc/main.py @@ -60,6 +60,9 @@ def _main(argv): formatter_class=argparse.RawDescriptionHelpFormatter, description="TVM compiler driver", epilog=__doc__, + # Help action will be added later, after all subparsers are created, + # so it doesn't interfere with the creation of the dynamic subparsers. + add_help=False, ) parser.add_argument("-v", "--verbose", action="count", default=0, help="increase verbosity") parser.add_argument("--version", action="store_true", help="print the version and exit") @@ -68,6 +71,9 @@ def _main(argv): for make_subparser in REGISTERED_PARSER: make_subparser(subparser, parser) + # Finally, add help for the main parser. + parser.add_argument("-h", "--help", action="help", help="show this help message and exit.") + args = parser.parse_args(argv) if args.verbose > 4: args.verbose = 4 diff --git a/python/tvm/driver/tvmc/micro.py b/python/tvm/driver/tvmc/micro.py new file mode 100644 index 000000000000..9b31d3278b91 --- /dev/null +++ b/python/tvm/driver/tvmc/micro.py @@ -0,0 +1,300 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Provides support for micro targets (microTVM). +""" +import argparse +import os +from pathlib import Path +import shutil +import sys + +import tvm.micro.project as project +from tvm.micro import get_microtvm_template_projects +from tvm.micro.build import MicroTVMTemplateProjectNotFoundError +from tvm.micro.project_api.server import ServerError +from tvm.micro.project_api.client import ProjectAPIServerNotFoundError +from .main import register_parser +from .common import ( + TVMCException, + TVMCSuppressedArgumentParser, + get_project_options, + get_and_check_options, +) + + +TEMPLATES = {} +for p in ("zephyr", "arduino"): + try: + TEMPLATES[p] = get_microtvm_template_projects(p) + except MicroTVMTemplateProjectNotFoundError: + pass + + +@register_parser +def add_micro_parser(subparsers, main_parser): + """Includes parser for 'micro' context and associated subcommands: + create-project, build, and flash. + """ + + micro = subparsers.add_parser("micro", help="select micro context.") + micro.set_defaults(func=drive_micro) + + micro_parser = micro.add_subparsers(title="subcommands") + # Selecting a subcommand under 'micro' is mandatory + micro_parser.required = True + micro_parser.dest = "subcommand" + + # 'create_project' subcommand + create_project_parser = micro_parser.add_parser( + "create-project", + aliases=["create"], + help="create a project template of a given type or given a template dir.", + ) + create_project_parser.set_defaults(subcommand_handler=create_project_handler) + create_project_parser.add_argument( + "project_dir", + help="project dir where the new project based on the template dir will be created.", + ) + create_project_parser.add_argument("MLF", help="Model Library Format (MLF) .tar archive.") + create_project_parser.add_argument( + "-f", + "--force", + action="store_true", + help="force project creating even if the specified project directory already exists.", + ) + + # 'build' subcommand + build_parser = micro_parser.add_parser( + "build", + help="build a project dir, generally creating an image to be flashed, e.g. zephyr.elf.", + ) + build_parser.set_defaults(subcommand_handler=build_handler) + build_parser.add_argument("project_dir", help="project dir to build.") + build_parser.add_argument("-f", "--force", action="store_true", help="Force rebuild.") + + # 'flash' subcommand + flash_parser = micro_parser.add_parser( + "flash", help="flash the built image on a given micro target." + ) + flash_parser.set_defaults(subcommand_handler=flash_handler) + flash_parser.add_argument("project_dir", help="project dir where the built image is.") + + # For each platform add arguments detected automatically using Project API info query. + + # Create subparsers for the platforms under 'create-project', 'build', and 'flash' subcommands. + help_msg = ( + "you must select a platform from the list. You can pass '-h' for a selected " + "platform to list its options." + ) + create_project_platforms_parser = create_project_parser.add_subparsers( + title="platforms", help=help_msg, dest="platform" + ) + build_platforms_parser = build_parser.add_subparsers( + title="platforms", help=help_msg, dest="platform" + ) + flash_platforms_parser = flash_parser.add_subparsers( + title="platforms", help=help_msg, dest="platform" + ) + + subcmds = { + # API method name Parser associated to method Handler func to call after parsing + "generate_project": [create_project_platforms_parser, create_project_handler], + "build": [build_platforms_parser, build_handler], + "flash": [flash_platforms_parser, flash_handler], + } + + # Helper to add a platform parser to a subcmd parser. + def _add_parser(parser, platform): + platform_name = platform[0].upper() + platform[1:] + " platform" + platform_parser = parser.add_parser( + platform, add_help=False, help=f"select {platform_name}." + ) + platform_parser.set_defaults(platform=platform) + return platform_parser + + parser_by_subcmd = {} + for subcmd, subcmd_parser_handler in subcmds.items(): + subcmd_parser = subcmd_parser_handler[0] + subcmd_parser.required = True # Selecting a platform or template is mandatory + parser_by_platform = {} + for platform in TEMPLATES: + new_parser = _add_parser(subcmd_parser, platform) + parser_by_platform[platform] = new_parser + + # Besides adding the parsers for each default platform (like Zephyr and Arduino), add a + # parser for 'template' to deal with adhoc projects/platforms. + new_parser = subcmd_parser.add_parser( + "template", add_help=False, help="select an adhoc template." + ) + new_parser.add_argument( + "--template-dir", required=True, help="Project API template directory." + ) + new_parser.set_defaults(platform="template") + parser_by_platform["template"] = new_parser + + parser_by_subcmd[subcmd] = parser_by_platform + + disposable_parser = TVMCSuppressedArgumentParser(main_parser) + try: + known_args, _ = disposable_parser.parse_known_args() + except TVMCException: + return + + try: + subcmd = known_args.subcommand + platform = known_args.platform + except AttributeError: + # No subcommand or platform, hence no need to augment the parser for micro targets. + return + + # Augment parser with project options. + + if platform == "template": + # adhoc template + template_dir = str(Path(known_args.template_dir).resolve()) + else: + # default template + template_dir = TEMPLATES[platform] + + try: + template = project.TemplateProject.from_directory(template_dir) + except ProjectAPIServerNotFoundError: + sys.exit(f"Error: Project API server not found in {template_dir}!") + + template_info = template.info() + + options_by_method = get_project_options(template_info) + + # TODO(gromero): refactor to remove this map. + subcmd_to_method = { + "create-project": "generate_project", + "create": "generate_project", + "build": "build", + "flash": "flash", + } + + method = subcmd_to_method[subcmd] + parser_by_subcmd_n_platform = parser_by_subcmd[method][platform] + _, handler = subcmds[method] + + parser_by_subcmd_n_platform.formatter_class = ( + # Set raw help text so help_text format works + argparse.RawTextHelpFormatter + ) + parser_by_subcmd_n_platform.set_defaults( + subcommand_handler=handler, + valid_options=options_by_method[method], + template_dir=template_dir, + ) + + required = any([opt["required"] for opt in options_by_method[method]]) + nargs = "+" if required else "*" + + help_text_by_option = [opt["help_text"] for opt in options_by_method[method]] + help_text = "\n\n".join(help_text_by_option) + "\n\n" + + parser_by_subcmd_n_platform.add_argument( + "--project-option", required=required, metavar="OPTION=VALUE", nargs=nargs, help=help_text + ) + + parser_by_subcmd_n_platform.add_argument( + "-h", + "--help", + "--list-options", + action="help", + help="show this help message which includes platform-specific options and exit.", + ) + + +def drive_micro(args): + # Call proper handler based on subcommand parsed. + args.subcommand_handler(args) + + +def create_project_handler(args): + """Creates a new project dir.""" + + if os.path.exists(args.project_dir): + if args.force: + shutil.rmtree(args.project_dir) + else: + raise TVMCException( + "The specified project dir already exists. " + "To force overwriting it use '-f' or '--force'." + ) + project_dir = args.project_dir + + template_dir = str(Path(args.template_dir).resolve()) + if not os.path.exists(template_dir): + raise TVMCException(f"Template directory {template_dir} does not exist!") + + mlf_path = str(Path(args.MLF).resolve()) + if not os.path.exists(mlf_path): + raise TVMCException(f"MLF file {mlf_path} does not exist!") + + options = get_and_check_options(args.project_option, args.valid_options) + + try: + project.generate_project_from_mlf(template_dir, project_dir, mlf_path, options) + except ServerError as error: + print("The following error occured on the Project API server side: \n", error) + sys.exit(1) + + +def build_handler(args): + """Builds a firmware image given a project dir.""" + + if not os.path.exists(args.project_dir): + raise TVMCException(f"{args.project_dir} doesn't exist.") + + if os.path.exists(args.project_dir + "/build"): + if args.force: + shutil.rmtree(args.project_dir + "/build") + else: + raise TVMCException( + f"There is already a build in {args.project_dir}. " + "To force rebuild it use '-f' or '--force'." + ) + + project_dir = args.project_dir + + options = get_and_check_options(args.project_option, args.valid_options) + + try: + prj = project.GeneratedProject.from_directory(project_dir, options=options) + prj.build() + except ServerError as error: + print("The following error occured on the Project API server side: ", error) + sys.exit(1) + + +def flash_handler(args): + """Flashes a firmware image to a target device given a project dir.""" + if not os.path.exists(args.project_dir + "/build"): + raise TVMCException(f"Could not find a build in {args.project_dir}") + + project_dir = args.project_dir + + options = get_and_check_options(args.project_option, args.valid_options) + + try: + prj = project.GeneratedProject.from_directory(project_dir, options=options) + prj.flash() + except ServerError as error: + print("The following error occured on the Project API server side: ", error) + sys.exit(1) diff --git a/python/tvm/micro/project_api/server.py b/python/tvm/micro/project_api/server.py index 7dd26008fc7a..ed26733cfc3c 100644 --- a/python/tvm/micro/project_api/server.py +++ b/python/tvm/micro/project_api/server.py @@ -48,6 +48,8 @@ class ProjectOption(_ProjectOption): + """Class used to keep the metadata associated to project options.""" + def __new__(cls, name, **kw): """Override __new__ to force all options except name to be specified as kwargs.""" assert "name" not in kw From ad45c677f220a78543849c5525f934a9ee912218 Mon Sep 17 00:00:00 2001 From: Gustavo Romero Date: Mon, 23 Aug 2021 15:00:10 +0000 Subject: [PATCH 10/10] [TVMC] run: Add support for micro devices Add support for micro devices using the Project API to query all options available for a given platform and open a session with an specified micro device. Use of 'tvmc run' with micro device is enabled via the '--device micro' option in addition to the project directory. Once the project directory is specified 'tvmc run' will make all options specific to the platform found in the project dir available as options in 'tvmc run'. They can be listed by '--list-options' and passed via '--options'. Signed-off-by: Gustavo Romero --- python/tvm/driver/tvmc/runner.py | 330 ++++++++++++++++++++++--------- 1 file changed, 241 insertions(+), 89 deletions(-) diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 52af02bc43d8..eb571143e551 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -17,21 +17,33 @@ """ Provides support to run compiled networks both locally and remotely. """ +from contextlib import ExitStack import json import logging +import pathlib from typing import Dict, List, Optional, Union from tarfile import ReadError - +import argparse +import os +import sys import numpy as np + import tvm from tvm import rpc from tvm.autotvm.measure import request_remote from tvm.contrib import graph_executor as runtime from tvm.contrib.debugger import debug_executor from tvm.relay.param_dict import load_param_dict - +import tvm.micro.project as project +from tvm.micro.project import TemplateProjectError +from tvm.micro.project_api.client import ProjectAPIServerNotFoundError from . import common -from .common import TVMCException +from .common import ( + TVMCException, + TVMCSuppressedArgumentParser, + get_project_options, + get_and_check_options, +) from .main import register_parser from .model import TVMCPackage, TVMCResult from .result_utils import get_top_results @@ -44,14 +56,16 @@ def add_run_parser(subparsers, main_parser): """Include parser for 'run' subcommand""" - parser = subparsers.add_parser("run", help="run a compiled module") + # Use conflict_handler='resolve' to allow '--list-options' option to be properly overriden when + # augmenting the parser with the micro device options (i.e. when '--device micro'). + parser = subparsers.add_parser("run", help="run a compiled module", conflict_handler="resolve") parser.set_defaults(func=drive_run) # TODO --device needs to be extended and tested to support other targets, # like 'webgpu', etc (@leandron) parser.add_argument( "--device", - choices=["cpu", "cuda", "cl", "metal", "vulkan", "rocm"], + choices=["cpu", "cuda", "cl", "metal", "vulkan", "rocm", "micro"], default="cpu", help="target device to run the compiled module. Defaults to 'cpu'", ) @@ -66,7 +80,9 @@ def add_run_parser(subparsers, main_parser): parser.add_argument("-i", "--inputs", help="path to the .npz input file") parser.add_argument("-o", "--outputs", help="path to the .npz output file") parser.add_argument( - "--print-time", action="store_true", help="record and print the execution time(s)" + "--print-time", + action="store_true", + help="record and print the execution time(s). (non-micro devices only)", ) parser.add_argument( "--print-top", @@ -80,7 +96,7 @@ def add_run_parser(subparsers, main_parser): help="generate profiling data from the runtime execution. " "Using --profile requires the Graph Executor Debug enabled on TVM. " "Profiling may also have an impact on inference time, " - "making it take longer to be generated.", + "making it take longer to be generated. (non-micro devices only)", ) parser.add_argument( "--repeat", metavar="N", type=int, default=1, help="run the model n times. Defaults to '1'" @@ -90,14 +106,69 @@ def add_run_parser(subparsers, main_parser): ) parser.add_argument( "--rpc-key", - help="the RPC tracker key of the target device", + help="the RPC tracker key of the target device. (non-micro devices only)", ) parser.add_argument( "--rpc-tracker", help="hostname (required) and port (optional, defaults to 9090) of the RPC tracker, " - "e.g. '192.168.0.100:9999'", + "e.g. '192.168.0.100:9999'. (non-micro devices only)", + ) + parser.add_argument( + "PATH", + help="path to the compiled module file or to the project directory if '--device micro' " + "is selected.", + ) + parser.add_argument( + "--list-options", + action="store_true", + help="show all run options and option choices when '--device micro' is selected. " + "(micro devices only)", + ) + + disposable_parser = TVMCSuppressedArgumentParser(main_parser) + try: + known_args, _ = disposable_parser.parse_known_args() + except TVMCException: + return + + if vars(known_args).get("device") != "micro": + # No need to augment the parser for micro targets. + return + + project_dir = known_args.PATH + + try: + project_ = project.GeneratedProject.from_directory(project_dir, None) + except ProjectAPIServerNotFoundError: + sys.exit(f"Error: Project API server not found in {project_dir}!") + except TemplateProjectError: + sys.exit( + f"Error: Project directory error. That usually happens when model.tar is not found." + ) + + project_info = project_.info() + options_by_method = get_project_options(project_info) + + parser.formatter_class = ( + argparse.RawTextHelpFormatter + ) # Set raw help text so customized help_text format works + parser.set_defaults(valid_options=options_by_method["open_transport"]) + + required = any([opt["required"] for opt in options_by_method["open_transport"]]) + nargs = "+" if required else "*" + + help_text_by_option = [opt["help_text"] for opt in options_by_method["open_transport"]] + help_text = "\n\n".join(help_text_by_option) + "\n\n" + + parser.add_argument( + "--project-option", required=required, metavar="OPTION=VALUE", nargs=nargs, help=help_text + ) + + parser.add_argument( + "--list-options", + action="help", + help="show this help message with platform-specific options and exit.", ) - parser.add_argument("FILE", help="path to the compiled module file") def drive_run(args): @@ -109,21 +180,61 @@ def drive_run(args): Arguments from command line parser. """ - rpc_hostname, rpc_port = common.tracker_host_port_from_cli(args.rpc_tracker) + path = pathlib.Path(args.PATH) + options = None + if args.device == "micro": + path = path / "model.tar" + if not path.is_file(): + TVMCException( + f"Could not find model (model.tar) in the specified project dir {path.dirname()}." + ) - try: - inputs = np.load(args.inputs) if args.inputs else {} - except IOError as ex: - raise TVMCException("Error loading inputs file: %s" % ex) + # Check for options unavailable for micro targets. + + if args.rpc_key or args.rpc_tracker: + raise TVMCException( + "--rpc-key and/or --rpc-tracker can't be specified for micro targets." + ) + + if args.device != "micro": + raise TVMCException( + f"Device '{args.device}' not supported. " + "Only device 'micro' is supported to run a model in MLF, " + "i.e. when '--device micro'." + ) + + if args.profile: + raise TVMCException("--profile is not currently supported for micro devices.") + + if args.print_time: + raise TVMCException("--print-time is not currently supported for micro devices.") + + # Get and check options for micro targets. + options = get_and_check_options(args.project_option, args.valid_options) + + else: + # Check for options only availabe for micro targets. + + if args.list_options: + raise TVMCException( + "--list-options is only availabe on micro targets, i.e. when '--device micro'." + ) try: - tvmc_package = TVMCPackage(package_path=args.FILE) + tvmc_package = TVMCPackage(package_path=path) except IsADirectoryError: - raise TVMCException(f"File {args.FILE} must be an archive, not a directory.") + raise TVMCException(f"File {path} must be an archive, not a directory.") except FileNotFoundError: - raise TVMCException(f"File {args.FILE} does not exist.") + raise TVMCException(f"File {path} does not exist.") except ReadError: - raise TVMCException(f"Could not read model from archive {args.FILE}!") + raise TVMCException(f"Could not read model from archive {path}!") + + rpc_hostname, rpc_port = common.tracker_host_port_from_cli(args.rpc_tracker) + + try: + inputs = np.load(args.inputs) if args.inputs else {} + except IOError as ex: + raise TVMCException("Error loading inputs file: %s" % ex) result = run_module( tvmc_package, @@ -136,6 +247,7 @@ def drive_run(args): repeat=args.repeat, number=args.number, profile=args.profile, + options=options, ) if args.print_time: @@ -318,6 +430,7 @@ def run_module( repeat: int = 10, number: int = 10, profile: bool = False, + options: dict = None, ): """Run a compiled graph executor module locally or remotely with optional input values. @@ -366,79 +479,118 @@ def run_module( "Try calling tvmc.compile on the model before running it." ) - # Currently only two package formats are supported: "classic" and - # "mlf". The later can only be used for micro targets, i.e. with microTVM. - if tvmc_package.type == "mlf": - raise TVMCException( - "You're trying to run a model saved using the Model Library Format (MLF)." - "MLF can only be used to run micro targets (microTVM)." - ) + with ExitStack() as stack: + # Currently only two package formats are supported: "classic" and + # "mlf". The later can only be used for micro targets, i.e. with microTVM. + if device == "micro": + if tvmc_package.type != "mlf": + raise TVMCException(f"Model {tvmc_package.package_path} is not a MLF archive.") - if hostname: - if isinstance(port, str): - port = int(port) - # Remote RPC - if rpc_key: - logger.debug("Running on remote RPC tracker with key %s.", rpc_key) - session = request_remote(rpc_key, hostname, port, timeout=1000) - else: - logger.debug("Running on remote RPC with no key.") - session = rpc.connect(hostname, port) - else: - # Local - logger.debug("Running a local session.") - session = rpc.LocalSession() - - session.upload(tvmc_package.lib_path) - lib = session.load_module(tvmc_package.lib_name) - - # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron) - logger.debug("Device is %s.", device) - if device == "cuda": - dev = session.cuda() - elif device == "cl": - dev = session.cl() - elif device == "metal": - dev = session.metal() - elif device == "vulkan": - dev = session.vulkan() - elif device == "rocm": - dev = session.rocm() - else: - assert device == "cpu" - dev = session.cpu() - - if profile: - logger.debug("Creating runtime with profiling enabled.") - module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof") - else: - logger.debug("Creating runtime with profiling disabled.") - module = runtime.create(tvmc_package.graph, lib, dev) + project_dir = os.path.dirname(tvmc_package.package_path) - logger.debug("Loading params into the runtime module.") - module.load_params(tvmc_package.params) - - shape_dict, dtype_dict = get_input_info(tvmc_package.graph, tvmc_package.params) - inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) - - logger.debug("Setting inputs to the module.") - module.set_input(**inputs_dict) + # This is guaranteed to work since project_dir was already checked when + # building the dynamic parser to accommodate the project options, so no + # checks are in place when calling GeneratedProject. + project_ = project.GeneratedProject.from_directory(project_dir, options) + else: + if tvmc_package.type == "mlf": + raise TVMCException( + "You're trying to run a model saved using the Model Library Format (MLF). " + "MLF can only be used to run micro device ('--device micro')." + ) - # Run must be called explicitly if profiling - if profile: - logger.info("Running the module with profiling enabled.") - report = module.profile() - # This print is intentional - print(report) + if hostname: + if isinstance(port, str): + port = int(port) + # Remote RPC + if rpc_key: + logger.debug("Running on remote RPC tracker with key %s.", rpc_key) + session = request_remote(rpc_key, hostname, port, timeout=1000) + else: + logger.debug("Running on remote RPC with no key.") + session = rpc.connect(hostname, port) + elif device == "micro": + # Remote RPC (running on a micro target) + logger.debug("Running on remote RPC (micro target).") + try: + session = tvm.micro.Session(project_.transport()) + stack.enter_context(session) + except: + raise TVMCException("Could not open a session with the micro target.") + else: + # Local + logger.debug("Running a local session.") + session = rpc.LocalSession() + + # Micro targets don't support uploading a model. The model to be run + # must be already flashed into the micro target before one tries + # to run it. Hence skip model upload for micro targets. + if device != "micro": + session.upload(tvmc_package.lib_path) + lib = session.load_module(tvmc_package.lib_name) + + # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron) + logger.debug("Device is %s.", device) + if device == "cuda": + dev = session.cuda() + elif device == "cl": + dev = session.cl() + elif device == "metal": + dev = session.metal() + elif device == "vulkan": + dev = session.vulkan() + elif device == "rocm": + dev = session.rocm() + elif device == "micro": + dev = session.device + lib = session.get_system_lib() + else: + assert device == "cpu" + dev = session.cpu() - # call the benchmarking function of the executor - times = module.benchmark(dev, number=number, repeat=repeat) + # TODO(gromero): Adjust for micro targets. + if profile: + logger.debug("Creating runtime with profiling enabled.") + module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof") + else: + if device == "micro": + logger.debug("Creating runtime (micro) with profiling disabled.") + module = tvm.micro.create_local_graph_executor(tvmc_package.graph, lib, dev) + else: + logger.debug("Creating runtime with profiling disabled.") + module = runtime.create(tvmc_package.graph, lib, dev) + + logger.debug("Loading params into the runtime module.") + module.load_params(tvmc_package.params) + + shape_dict, dtype_dict = get_input_info(tvmc_package.graph, tvmc_package.params) + inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) + + logger.debug("Setting inputs to the module.") + module.set_input(**inputs_dict) + + # Run must be called explicitly if profiling + if profile: + logger.info("Running the module with profiling enabled.") + report = module.profile() + # This print is intentional + print(report) + + if device == "micro": + # TODO(gromero): Fix time_evaluator() for micro targets. Once it's + # fixed module.benchmark() can be used instead and this if/else can + # be removed. + module.run() + times = [] + else: + # call the benchmarking function of the executor + times = module.benchmark(dev, number=number, repeat=repeat) - logger.debug("Collecting the output tensors.") - num_outputs = module.get_num_outputs() - outputs = {} - for i in range(num_outputs): - output_name = "output_{}".format(i) - outputs[output_name] = module.get_output(i).numpy() + logger.debug("Collecting the output tensors.") + num_outputs = module.get_num_outputs() + outputs = {} + for i in range(num_outputs): + output_name = "output_{}".format(i) + outputs[output_name] = module.get_output(i).numpy() - return TVMCResult(outputs, times) + return TVMCResult(outputs, times)