diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index 1768c61197a9..3039eb313908 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -71,22 +71,40 @@ class BoardAutodetectFailed(Exception): PROJECT_OPTIONS = [ server.ProjectOption( "arduino_board", + required=["build", "flash", "open_transport"], choices=list(BOARD_PROPERTIES), - help="Name of the Arduino board to build for", + type="str", + help="Name of the Arduino board to build for.", + ), + server.ProjectOption( + "arduino_cli_cmd", + required=["build", "flash", "open_transport"], + type="str", + help="Path to the arduino-cli tool.", + ), + server.ProjectOption( + "port", + optional=["flash", "open_transport"], + type="int", + help="Port to use for connecting to hardware.", ), - server.ProjectOption("arduino_cli_cmd", help="Path to the arduino-cli tool."), - server.ProjectOption("port", help="Port to use for connecting to hardware"), server.ProjectOption( "project_type", - help="Type of project to generate.", + required=["generate_project"], choices=tuple(PROJECT_TYPES), + type="str", + help="Type of project to generate.", ), server.ProjectOption( - "verbose", help="True to pass --verbose flag to arduino-cli compile and upload" + "verbose", + optional=["build", "flash"], + type="bool", + help="Run arduino-cli compile and upload with verbose output.", ), server.ProjectOption( "warning_as_error", - choices=(True, False), + optional=["generate_project"], + type="bool", help="Treat warnings as errors and raise an Exception.", ), ] diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index 7e13f928b288..3c96f31dfe22 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -234,49 +234,81 @@ def _get_nrf_device_args(options): PROJECT_OPTIONS = [ server.ProjectOption( "extra_files_tar", + optional=["generate_project"], + type="str", help="If given, during generate_project, uncompress the tarball at this path into the project dir.", ), server.ProjectOption( - "gdbserver_port", help=("If given, port number to use when running the local gdbserver.") + "gdbserver_port", + help=("If given, port number to use when running the local gdbserver."), + optional=["open_transport"], + type="int", ), server.ProjectOption( "nrfjprog_snr", + optional=["open_transport"], + type="int", help=("When used with nRF targets, serial # of the attached board to use, from nrfjprog."), ), server.ProjectOption( "openocd_serial", + optional=["open_transport"], + type="int", help=("When used with OpenOCD targets, serial # of the attached board to use."), ), server.ProjectOption( "project_type", - help="Type of project to generate.", choices=tuple(PROJECT_TYPES), + required=["generate_project"], + type="str", + help="Type of project to generate.", + ), + server.ProjectOption( + "verbose", + optional=["build"], + type="bool", + help="Run build with verbose output.", ), - server.ProjectOption("verbose", help="Run build with verbose output.", choices=(True, False)), server.ProjectOption( "west_cmd", + optional=["generate_project"], + default=sys.executable + " -m west" if sys.executable else None, + type="str", help=( "Path to the west tool. If given, supersedes both the zephyr_base " "option and ZEPHYR_BASE environment variable." ), ), - server.ProjectOption("zephyr_base", help="Path to the zephyr base directory."), + server.ProjectOption( + "zephyr_base", + optional=["build", "open_transport"], + default=os.getenv("ZEPHYR_BASE"), + type="str", + help="Path to the zephyr base directory.", + ), server.ProjectOption( "zephyr_board", + required=["generate_project", "build", "flash", "open_transport"], choices=list(BOARD_PROPERTIES), + type="str", help="Name of the Zephyr board to build for.", ), server.ProjectOption( "config_main_stack_size", + optional=["generate_project"], + type="int", help="Sets CONFIG_MAIN_STACK_SIZE for Zephyr board.", ), server.ProjectOption( "warning_as_error", - choices=(True, False), + optional=["generate_project"], + type="bool", help="Treat warnings as errors and raise an Exception.", ), server.ProjectOption( "compile_definitions", + optional=["generate_project"], + type="str", help="Extra definitions added project compile.", ), ] diff --git a/python/tvm/driver/tvmc/__init__.py b/python/tvm/driver/tvmc/__init__.py index 42184c34df74..70747cbb2d74 100644 --- a/python/tvm/driver/tvmc/__init__.py +++ b/python/tvm/driver/tvmc/__init__.py @@ -19,9 +19,10 @@ TVMC - TVM driver command-line interface """ +from . import micro +from . import runner from . import autotuner from . import compiler -from . import runner from . import result_utils from .frontends import load_model as load from .compiler import compile_model as compile diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index 92d13a99acd5..60bec38f0d1a 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -46,7 +46,7 @@ @register_parser -def add_tune_parser(subparsers): +def add_tune_parser(subparsers, _): """Include parser for 'tune' subcommand""" parser = subparsers.add_parser("tune", help="auto-tune a model") diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 65b0c3dbc0aa..97b7c5206a38 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -23,6 +23,7 @@ import os.path import argparse +from collections import defaultdict from urllib.parse import urlparse import tvm @@ -31,6 +32,7 @@ from tvm import relay from tvm import transform from tvm._ffi import registry +from .fmtopt import format_option # pylint: disable=invalid-name logger = logging.getLogger("TVMC") @@ -40,6 +42,37 @@ class TVMCException(Exception): """TVMC Exception""" +class TVMCSuppressedArgumentParser(argparse.ArgumentParser): + """ + A silent ArgumentParser class. + + This class is meant to be used as a helper for creating dynamic parsers in + TVMC. It will create a "supressed" parser based on an existing one (parent) + which does not include a help message, does not print a usage message (even + when -h or --help is passed) and does not exit on invalid choice parse + errors but rather throws a TVMCException so it can be handled and the + dynamic parser construction is not interrupted prematurely. + + """ + + def __init__(self, parent, **kwargs): + # Don't add '-h' or '--help' options to the newly created parser. Don't print usage message. + # 'add_help=False' won't supress existing '-h' and '--help' options from the parser (and its + # subparsers) present in 'parent'. However that class is meant to be used with the main + # parser, which is created with `add_help=False` - the help is added only later. Hence it + # the newly created parser won't have help options added in its (main) root parser. The + # subparsers in the main parser will eventually have help activated, which is enough for its + # use in TVMC. + super().__init__(parents=[parent], add_help=False, usage=argparse.SUPPRESS, **kwargs) + + def exit(self, status=0, message=None): + # Don't exit on error when parsing the command line. + # This won't catch all the errors generated when parsing tho. For instance, it won't catch + # errors due to missing required arguments. But this will catch "error: invalid choice", + # which is what it's necessary for its use in TVMC. + raise TVMCException() + + def convert_graph_layout(mod, desired_layout): """Alter the layout of the input graph. @@ -554,3 +587,202 @@ def parse_configs(input_configs): pass_context_configs[name] = parsed_value return pass_context_configs + + +def get_project_options(project_info): + """Get all project options as returned by Project API 'server_info_query' + and return them in a dict indexed by the API method they belong to. + + + Parameters + ---------- + project_info: dict of list + a dict of lists as returned by Project API 'server_info_query' among + which there is a list called 'project_options' containing all the + project options available for a given project/platform. + + Returns + ------- + options_by_method: dict of list + a dict indexed by the API method names (e.g. "generate_project", + "build", "flash", or "open_transport") of lists containing all the + options (plus associated metadata and formatted help text) that belong + to a method. + + The metadata associated to the options include the field 'choices' and + 'required' which are convenient for parsers. + + The formatted help text field 'help_text' is a string that contains the + name of the option, the choices for the option, and the option's default + value. + """ + options = project_info["project_options"] + + options_by_method = defaultdict(list) + for opt in options: + # Get list of methods associated with an option based on the + # existance of a 'required' or 'optional' lists. API specification + # guarantees at least one of these lists will exist. If a list does + # not exist it's returned as None by the API. + metadata = ["required", "optional"] + om = [(opt[md], bool(md == "required")) for md in metadata if opt[md]] + for methods, is_opt_required in om: + for method in methods: + name = opt["name"] + + # Only for boolean options set 'choices' accordingly to the + # option type. API returns 'choices' associated to them + # as None but 'choices' can be deduced from 'type' in this case. + if opt["type"] == "bool": + opt["choices"] = ["true", "false"] + + if opt["choices"]: + choices = "{" + ", ".join(opt["choices"]) + "}" + else: + choices = opt["name"].upper() + option_choices_text = f"{name}={choices}" + + help_text = opt["help"][0].lower() + opt["help"][1:] + + if opt["default"]: + default_text = f"Defaults to '{opt['default']}'." + else: + default_text = None + + formatted_help_text = format_option( + option_choices_text, help_text, default_text, is_opt_required + ) + + option = { + "name": opt["name"], + "choices": opt["choices"], + "help_text": formatted_help_text, + "required": is_opt_required, + } + options_by_method[method].append(option) + + return options_by_method + + +def get_options(options): + """Get option and option value from the list options returned by the parser. + + Parameters + ---------- + options: list of str + list of strings of the form "option=value" as returned by the parser. + + Returns + ------- + opts: dict + dict indexed by option names and associated values. + """ + + opts = {} + for option in options: + try: + k, v = option.split("=") + opts[k] = v + except ValueError: + raise TVMCException(f"Invalid option format: {option}. Please use OPTION=VALUE.") + + return opts + + +def check_options(options, valid_options): + """Check if an option (required or optional) is valid. i.e. in the list of valid options. + + Parameters + ---------- + options: dict + dict indexed by option name of options and options values to be checked. + + valid_options: list of dict + list of all valid options and choices for a platform. + + Returns + ------- + None. Raise TVMCException if check fails, i.e. if an option is not in the list of valid options. + + """ + required_options = [opt["name"] for opt in valid_options if opt["required"]] + for required_option in required_options: + if required_option not in options: + raise TVMCException( + f"Option '{required_option}' is required but was not specified. Use --list-options " + "to see all required options." + ) + + remaining_options = set(options) - set(required_options) + optional_options = [opt["name"] for opt in valid_options if not opt["required"]] + for option in remaining_options: + if option not in optional_options: + raise TVMCException( + f"Option '{option}' is invalid. Use --list-options to see all available options." + ) + + +def check_options_choices(options, valid_options): + """Check if an option value is among the option's choices, when choices exist. + + Parameters + ---------- + options: dict + dict indexed by option name of options and options values to be checked. + + valid_options: list of dict + list of all valid options and choices for a platform. + + Returns + ------- + None. Raise TVMCException if check fails, i.e. if an option value is not valid. + + """ + # Dict of all valid options and associated valid choices. + # Options with no choices are excluded from the dict. + valid_options_choices = { + opt["name"]: opt["choices"] for opt in valid_options if opt["choices"] is not None + } + + for option in options: + if option in valid_options_choices: + if options[option] not in valid_options_choices[option]: + raise TVMCException( + f"Choice '{options[option]}' for option '{option}' is invalid. " + "Use --list-options to see all available choices for that option." + ) + + +def get_and_check_options(passed_options, valid_options): + """Get options and check if they are valid. If choices exist for them, check values against it. + + Parameters + ---------- + passed_options: list of str + list of strings in the "key=value" form as captured by argparse. + + valid_option: list + list with all options available for a given API method / project as returned by + get_project_options(). + + Returns + ------- + opts: dict + dict indexed by option names and associated values. + + Or None if passed_options is None. + + """ + + if passed_options is None: + # No options to check + return None + + # From a list of k=v strings, make a dict options[k]=v + opts = get_options(passed_options) + # Check if passed options are valid + check_options(opts, valid_options) + # Check (when a list of choices exists) if the passed values are valid + check_options_choices(opts, valid_options) + + return opts diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 7623a141c27a..a51a16f7e017 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -38,7 +38,7 @@ @register_parser -def add_compile_parser(subparsers): +def add_compile_parser(subparsers, _): """Include parser for 'compile' subcommand""" parser = subparsers.add_parser("compile", help="compile a model.") diff --git a/python/tvm/driver/tvmc/fmtopt.py b/python/tvm/driver/tvmc/fmtopt.py new file mode 100644 index 000000000000..7f27826d77bf --- /dev/null +++ b/python/tvm/driver/tvmc/fmtopt.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Utils to format help text for project options. +""" +from textwrap import TextWrapper + + +# Maximum column length for accommodating option name and its choices. +# Help text is placed after it in a new line. +MAX_OPTNAME_CHOICES_TEXT_COL_LEN = 80 + + +# Maximum column length for accommodating help text. +# 0 turns off formatting for the help text. +MAX_HELP_TEXT_COL_LEN = 0 + + +# Justification of help text placed below option name + choices text. +HELP_TEXT_JUST = 2 + + +def format_option(option_text, help_text, default_text, required=True): + """Format option name, choices, and default text into a single help text. + + Parameters + ---------- + options_text: str + String containing the option name and option's choices formatted as: + optname={opt0, opt1, ...} + + help_text: str + Help text string. + + default_text: str + Default text string. + + required: bool + Flag that controls if a "(required)" text mark needs to be added to the final help text to + inform if the option is a required one. + + Returns + ------- + help_text_just: str + Single justified help text formatted as: + optname={opt0, opt1, ... } + HELP_TEXT. "(required)" | "Defaults to 'DEFAULT'." + + """ + optname, choices_text = option_text.split("=", 1) + + # Prepare optname + choices text chunck. + + optname_len = len(optname) + wrapper = TextWrapper(width=MAX_OPTNAME_CHOICES_TEXT_COL_LEN - optname_len) + choices_lines = wrapper.wrap(choices_text) + + # Set first choices line which merely appends to optname string. + # No justification is necessary for the first line since first + # line was wrapped based on MAX_OPTNAME_CHOICES_TEXT_COL_LEN - optname_len, + # i.e. considering optname_len, hence only append justified choices_lines[0] line. + choices_just_lines = [optname + "=" + choices_lines[0]] + + # Justify the remaining lines based on first optname + '='. + for line in choices_lines[1:]: + line_len = len(line) + line_just = line.rjust( + optname_len + 1 + line_len + ) # add 1 to align after '{' in the line above + choices_just_lines.append(line_just) + + choices_text_just_chunk = "\n".join(choices_just_lines) + + # Prepare help text chunck. + + help_text = help_text[0].lower() + help_text[1:] + if MAX_HELP_TEXT_COL_LEN > 0: + wrapper = TextWrapper(width=MAX_HELP_TEXT_COL_LEN) + help_text_lines = wrapper.wrap(help_text) + else: + # Don't format help text. + help_text_lines = [help_text] + + help_text_just_lines = [] + for line in help_text_lines: + line_len = len(line) + line_just = line.rjust(HELP_TEXT_JUST + line_len) + help_text_just_lines.append(line_just) + + help_text_just_chunk = "\n".join(help_text_just_lines) + + # An option might be required for one method but optional for another one. + # If the option is required for one method it means there is no default for + # it when used in that method, hence suppress default text in that case. + if default_text and not required: + help_text_just_chunk += " " + default_text + + if required: + help_text_just_chunk += " (required)" + + help_text_just = choices_text_just_chunk + "\n" + help_text_just_chunk + return help_text_just diff --git a/python/tvm/driver/tvmc/main.py b/python/tvm/driver/tvmc/main.py index 2574daab02ac..0a8df4b1599d 100644 --- a/python/tvm/driver/tvmc/main.py +++ b/python/tvm/driver/tvmc/main.py @@ -60,13 +60,19 @@ def _main(argv): formatter_class=argparse.RawDescriptionHelpFormatter, description="TVM compiler driver", epilog=__doc__, + # Help action will be added later, after all subparsers are created, + # so it doesn't interfere with the creation of the dynamic subparsers. + add_help=False, ) parser.add_argument("-v", "--verbose", action="count", default=0, help="increase verbosity") parser.add_argument("--version", action="store_true", help="print the version and exit") subparser = parser.add_subparsers(title="commands") for make_subparser in REGISTERED_PARSER: - make_subparser(subparser) + make_subparser(subparser, parser) + + # Finally, add help for the main parser. + parser.add_argument("-h", "--help", action="help", help="show this help message and exit.") args = parser.parse_args(argv) if args.verbose > 4: diff --git a/python/tvm/driver/tvmc/micro.py b/python/tvm/driver/tvmc/micro.py new file mode 100644 index 000000000000..9b31d3278b91 --- /dev/null +++ b/python/tvm/driver/tvmc/micro.py @@ -0,0 +1,300 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Provides support for micro targets (microTVM). +""" +import argparse +import os +from pathlib import Path +import shutil +import sys + +import tvm.micro.project as project +from tvm.micro import get_microtvm_template_projects +from tvm.micro.build import MicroTVMTemplateProjectNotFoundError +from tvm.micro.project_api.server import ServerError +from tvm.micro.project_api.client import ProjectAPIServerNotFoundError +from .main import register_parser +from .common import ( + TVMCException, + TVMCSuppressedArgumentParser, + get_project_options, + get_and_check_options, +) + + +TEMPLATES = {} +for p in ("zephyr", "arduino"): + try: + TEMPLATES[p] = get_microtvm_template_projects(p) + except MicroTVMTemplateProjectNotFoundError: + pass + + +@register_parser +def add_micro_parser(subparsers, main_parser): + """Includes parser for 'micro' context and associated subcommands: + create-project, build, and flash. + """ + + micro = subparsers.add_parser("micro", help="select micro context.") + micro.set_defaults(func=drive_micro) + + micro_parser = micro.add_subparsers(title="subcommands") + # Selecting a subcommand under 'micro' is mandatory + micro_parser.required = True + micro_parser.dest = "subcommand" + + # 'create_project' subcommand + create_project_parser = micro_parser.add_parser( + "create-project", + aliases=["create"], + help="create a project template of a given type or given a template dir.", + ) + create_project_parser.set_defaults(subcommand_handler=create_project_handler) + create_project_parser.add_argument( + "project_dir", + help="project dir where the new project based on the template dir will be created.", + ) + create_project_parser.add_argument("MLF", help="Model Library Format (MLF) .tar archive.") + create_project_parser.add_argument( + "-f", + "--force", + action="store_true", + help="force project creating even if the specified project directory already exists.", + ) + + # 'build' subcommand + build_parser = micro_parser.add_parser( + "build", + help="build a project dir, generally creating an image to be flashed, e.g. zephyr.elf.", + ) + build_parser.set_defaults(subcommand_handler=build_handler) + build_parser.add_argument("project_dir", help="project dir to build.") + build_parser.add_argument("-f", "--force", action="store_true", help="Force rebuild.") + + # 'flash' subcommand + flash_parser = micro_parser.add_parser( + "flash", help="flash the built image on a given micro target." + ) + flash_parser.set_defaults(subcommand_handler=flash_handler) + flash_parser.add_argument("project_dir", help="project dir where the built image is.") + + # For each platform add arguments detected automatically using Project API info query. + + # Create subparsers for the platforms under 'create-project', 'build', and 'flash' subcommands. + help_msg = ( + "you must select a platform from the list. You can pass '-h' for a selected " + "platform to list its options." + ) + create_project_platforms_parser = create_project_parser.add_subparsers( + title="platforms", help=help_msg, dest="platform" + ) + build_platforms_parser = build_parser.add_subparsers( + title="platforms", help=help_msg, dest="platform" + ) + flash_platforms_parser = flash_parser.add_subparsers( + title="platforms", help=help_msg, dest="platform" + ) + + subcmds = { + # API method name Parser associated to method Handler func to call after parsing + "generate_project": [create_project_platforms_parser, create_project_handler], + "build": [build_platforms_parser, build_handler], + "flash": [flash_platforms_parser, flash_handler], + } + + # Helper to add a platform parser to a subcmd parser. + def _add_parser(parser, platform): + platform_name = platform[0].upper() + platform[1:] + " platform" + platform_parser = parser.add_parser( + platform, add_help=False, help=f"select {platform_name}." + ) + platform_parser.set_defaults(platform=platform) + return platform_parser + + parser_by_subcmd = {} + for subcmd, subcmd_parser_handler in subcmds.items(): + subcmd_parser = subcmd_parser_handler[0] + subcmd_parser.required = True # Selecting a platform or template is mandatory + parser_by_platform = {} + for platform in TEMPLATES: + new_parser = _add_parser(subcmd_parser, platform) + parser_by_platform[platform] = new_parser + + # Besides adding the parsers for each default platform (like Zephyr and Arduino), add a + # parser for 'template' to deal with adhoc projects/platforms. + new_parser = subcmd_parser.add_parser( + "template", add_help=False, help="select an adhoc template." + ) + new_parser.add_argument( + "--template-dir", required=True, help="Project API template directory." + ) + new_parser.set_defaults(platform="template") + parser_by_platform["template"] = new_parser + + parser_by_subcmd[subcmd] = parser_by_platform + + disposable_parser = TVMCSuppressedArgumentParser(main_parser) + try: + known_args, _ = disposable_parser.parse_known_args() + except TVMCException: + return + + try: + subcmd = known_args.subcommand + platform = known_args.platform + except AttributeError: + # No subcommand or platform, hence no need to augment the parser for micro targets. + return + + # Augment parser with project options. + + if platform == "template": + # adhoc template + template_dir = str(Path(known_args.template_dir).resolve()) + else: + # default template + template_dir = TEMPLATES[platform] + + try: + template = project.TemplateProject.from_directory(template_dir) + except ProjectAPIServerNotFoundError: + sys.exit(f"Error: Project API server not found in {template_dir}!") + + template_info = template.info() + + options_by_method = get_project_options(template_info) + + # TODO(gromero): refactor to remove this map. + subcmd_to_method = { + "create-project": "generate_project", + "create": "generate_project", + "build": "build", + "flash": "flash", + } + + method = subcmd_to_method[subcmd] + parser_by_subcmd_n_platform = parser_by_subcmd[method][platform] + _, handler = subcmds[method] + + parser_by_subcmd_n_platform.formatter_class = ( + # Set raw help text so help_text format works + argparse.RawTextHelpFormatter + ) + parser_by_subcmd_n_platform.set_defaults( + subcommand_handler=handler, + valid_options=options_by_method[method], + template_dir=template_dir, + ) + + required = any([opt["required"] for opt in options_by_method[method]]) + nargs = "+" if required else "*" + + help_text_by_option = [opt["help_text"] for opt in options_by_method[method]] + help_text = "\n\n".join(help_text_by_option) + "\n\n" + + parser_by_subcmd_n_platform.add_argument( + "--project-option", required=required, metavar="OPTION=VALUE", nargs=nargs, help=help_text + ) + + parser_by_subcmd_n_platform.add_argument( + "-h", + "--help", + "--list-options", + action="help", + help="show this help message which includes platform-specific options and exit.", + ) + + +def drive_micro(args): + # Call proper handler based on subcommand parsed. + args.subcommand_handler(args) + + +def create_project_handler(args): + """Creates a new project dir.""" + + if os.path.exists(args.project_dir): + if args.force: + shutil.rmtree(args.project_dir) + else: + raise TVMCException( + "The specified project dir already exists. " + "To force overwriting it use '-f' or '--force'." + ) + project_dir = args.project_dir + + template_dir = str(Path(args.template_dir).resolve()) + if not os.path.exists(template_dir): + raise TVMCException(f"Template directory {template_dir} does not exist!") + + mlf_path = str(Path(args.MLF).resolve()) + if not os.path.exists(mlf_path): + raise TVMCException(f"MLF file {mlf_path} does not exist!") + + options = get_and_check_options(args.project_option, args.valid_options) + + try: + project.generate_project_from_mlf(template_dir, project_dir, mlf_path, options) + except ServerError as error: + print("The following error occured on the Project API server side: \n", error) + sys.exit(1) + + +def build_handler(args): + """Builds a firmware image given a project dir.""" + + if not os.path.exists(args.project_dir): + raise TVMCException(f"{args.project_dir} doesn't exist.") + + if os.path.exists(args.project_dir + "/build"): + if args.force: + shutil.rmtree(args.project_dir + "/build") + else: + raise TVMCException( + f"There is already a build in {args.project_dir}. " + "To force rebuild it use '-f' or '--force'." + ) + + project_dir = args.project_dir + + options = get_and_check_options(args.project_option, args.valid_options) + + try: + prj = project.GeneratedProject.from_directory(project_dir, options=options) + prj.build() + except ServerError as error: + print("The following error occured on the Project API server side: ", error) + sys.exit(1) + + +def flash_handler(args): + """Flashes a firmware image to a target device given a project dir.""" + if not os.path.exists(args.project_dir + "/build"): + raise TVMCException(f"Could not find a build in {args.project_dir}") + + project_dir = args.project_dir + + options = get_and_check_options(args.project_option, args.valid_options) + + try: + prj = project.GeneratedProject.from_directory(project_dir, options=options) + prj.flash() + except ServerError as error: + print("The following error occured on the Project API server side: ", error) + sys.exit(1) diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 659df7ceef33..eb571143e551 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -17,21 +17,33 @@ """ Provides support to run compiled networks both locally and remotely. """ +from contextlib import ExitStack import json import logging +import pathlib from typing import Dict, List, Optional, Union from tarfile import ReadError - +import argparse +import os +import sys import numpy as np + import tvm from tvm import rpc from tvm.autotvm.measure import request_remote from tvm.contrib import graph_executor as runtime from tvm.contrib.debugger import debug_executor from tvm.relay.param_dict import load_param_dict - +import tvm.micro.project as project +from tvm.micro.project import TemplateProjectError +from tvm.micro.project_api.client import ProjectAPIServerNotFoundError from . import common -from .common import TVMCException +from .common import ( + TVMCException, + TVMCSuppressedArgumentParser, + get_project_options, + get_and_check_options, +) from .main import register_parser from .model import TVMCPackage, TVMCResult from .result_utils import get_top_results @@ -41,17 +53,19 @@ @register_parser -def add_run_parser(subparsers): +def add_run_parser(subparsers, main_parser): """Include parser for 'run' subcommand""" - parser = subparsers.add_parser("run", help="run a compiled module") + # Use conflict_handler='resolve' to allow '--list-options' option to be properly overriden when + # augmenting the parser with the micro device options (i.e. when '--device micro'). + parser = subparsers.add_parser("run", help="run a compiled module", conflict_handler="resolve") parser.set_defaults(func=drive_run) # TODO --device needs to be extended and tested to support other targets, # like 'webgpu', etc (@leandron) parser.add_argument( "--device", - choices=["cpu", "cuda", "cl", "metal", "vulkan", "rocm"], + choices=["cpu", "cuda", "cl", "metal", "vulkan", "rocm", "micro"], default="cpu", help="target device to run the compiled module. Defaults to 'cpu'", ) @@ -66,7 +80,9 @@ def add_run_parser(subparsers): parser.add_argument("-i", "--inputs", help="path to the .npz input file") parser.add_argument("-o", "--outputs", help="path to the .npz output file") parser.add_argument( - "--print-time", action="store_true", help="record and print the execution time(s)" + "--print-time", + action="store_true", + help="record and print the execution time(s). (non-micro devices only)", ) parser.add_argument( "--print-top", @@ -80,7 +96,7 @@ def add_run_parser(subparsers): help="generate profiling data from the runtime execution. " "Using --profile requires the Graph Executor Debug enabled on TVM. " "Profiling may also have an impact on inference time, " - "making it take longer to be generated.", + "making it take longer to be generated. (non-micro devices only)", ) parser.add_argument( "--repeat", metavar="N", type=int, default=1, help="run the model n times. Defaults to '1'" @@ -90,14 +106,69 @@ def add_run_parser(subparsers): ) parser.add_argument( "--rpc-key", - help="the RPC tracker key of the target device", + help="the RPC tracker key of the target device. (non-micro devices only)", ) parser.add_argument( "--rpc-tracker", help="hostname (required) and port (optional, defaults to 9090) of the RPC tracker, " - "e.g. '192.168.0.100:9999'", + "e.g. '192.168.0.100:9999'. (non-micro devices only)", + ) + parser.add_argument( + "PATH", + help="path to the compiled module file or to the project directory if '--device micro' " + "is selected.", + ) + parser.add_argument( + "--list-options", + action="store_true", + help="show all run options and option choices when '--device micro' is selected. " + "(micro devices only)", + ) + + disposable_parser = TVMCSuppressedArgumentParser(main_parser) + try: + known_args, _ = disposable_parser.parse_known_args() + except TVMCException: + return + + if vars(known_args).get("device") != "micro": + # No need to augment the parser for micro targets. + return + + project_dir = known_args.PATH + + try: + project_ = project.GeneratedProject.from_directory(project_dir, None) + except ProjectAPIServerNotFoundError: + sys.exit(f"Error: Project API server not found in {project_dir}!") + except TemplateProjectError: + sys.exit( + f"Error: Project directory error. That usually happens when model.tar is not found." + ) + + project_info = project_.info() + options_by_method = get_project_options(project_info) + + parser.formatter_class = ( + argparse.RawTextHelpFormatter + ) # Set raw help text so customized help_text format works + parser.set_defaults(valid_options=options_by_method["open_transport"]) + + required = any([opt["required"] for opt in options_by_method["open_transport"]]) + nargs = "+" if required else "*" + + help_text_by_option = [opt["help_text"] for opt in options_by_method["open_transport"]] + help_text = "\n\n".join(help_text_by_option) + "\n\n" + + parser.add_argument( + "--project-option", required=required, metavar="OPTION=VALUE", nargs=nargs, help=help_text + ) + + parser.add_argument( + "--list-options", + action="help", + help="show this help message with platform-specific options and exit.", ) - parser.add_argument("FILE", help="path to the compiled module file") def drive_run(args): @@ -109,21 +180,61 @@ def drive_run(args): Arguments from command line parser. """ - rpc_hostname, rpc_port = common.tracker_host_port_from_cli(args.rpc_tracker) + path = pathlib.Path(args.PATH) + options = None + if args.device == "micro": + path = path / "model.tar" + if not path.is_file(): + TVMCException( + f"Could not find model (model.tar) in the specified project dir {path.dirname()}." + ) - try: - inputs = np.load(args.inputs) if args.inputs else {} - except IOError as ex: - raise TVMCException("Error loading inputs file: %s" % ex) + # Check for options unavailable for micro targets. + + if args.rpc_key or args.rpc_tracker: + raise TVMCException( + "--rpc-key and/or --rpc-tracker can't be specified for micro targets." + ) + + if args.device != "micro": + raise TVMCException( + f"Device '{args.device}' not supported. " + "Only device 'micro' is supported to run a model in MLF, " + "i.e. when '--device micro'." + ) + + if args.profile: + raise TVMCException("--profile is not currently supported for micro devices.") + + if args.print_time: + raise TVMCException("--print-time is not currently supported for micro devices.") + + # Get and check options for micro targets. + options = get_and_check_options(args.project_option, args.valid_options) + + else: + # Check for options only availabe for micro targets. + + if args.list_options: + raise TVMCException( + "--list-options is only availabe on micro targets, i.e. when '--device micro'." + ) try: - tvmc_package = TVMCPackage(package_path=args.FILE) + tvmc_package = TVMCPackage(package_path=path) except IsADirectoryError: - raise TVMCException(f"File {args.FILE} must be an archive, not a directory.") + raise TVMCException(f"File {path} must be an archive, not a directory.") except FileNotFoundError: - raise TVMCException(f"File {args.FILE} does not exist.") + raise TVMCException(f"File {path} does not exist.") except ReadError: - raise TVMCException(f"Could not read model from archive {args.FILE}!") + raise TVMCException(f"Could not read model from archive {path}!") + + rpc_hostname, rpc_port = common.tracker_host_port_from_cli(args.rpc_tracker) + + try: + inputs = np.load(args.inputs) if args.inputs else {} + except IOError as ex: + raise TVMCException("Error loading inputs file: %s" % ex) result = run_module( tvmc_package, @@ -136,6 +247,7 @@ def drive_run(args): repeat=args.repeat, number=args.number, profile=args.profile, + options=options, ) if args.print_time: @@ -318,6 +430,7 @@ def run_module( repeat: int = 10, number: int = 10, profile: bool = False, + options: dict = None, ): """Run a compiled graph executor module locally or remotely with optional input values. @@ -366,79 +479,118 @@ def run_module( "Try calling tvmc.compile on the model before running it." ) - # Currently only two package formats are supported: "classic" and - # "mlf". The later can only be used for micro targets, i.e. with microTVM. - if tvmc_package.type == "mlf": - raise TVMCException( - "You're trying to run a model saved using the Model Library Format (MLF)." - "MLF can only be used to run micro targets (microTVM)." - ) + with ExitStack() as stack: + # Currently only two package formats are supported: "classic" and + # "mlf". The later can only be used for micro targets, i.e. with microTVM. + if device == "micro": + if tvmc_package.type != "mlf": + raise TVMCException(f"Model {tvmc_package.package_path} is not a MLF archive.") - if hostname: - if isinstance(port, str): - port = int(port) - # Remote RPC - if rpc_key: - logger.debug("Running on remote RPC tracker with key %s.", rpc_key) - session = request_remote(rpc_key, hostname, port, timeout=1000) - else: - logger.debug("Running on remote RPC with no key.") - session = rpc.connect(hostname, port) - else: - # Local - logger.debug("Running a local session.") - session = rpc.LocalSession() - - session.upload(tvmc_package.lib_path) - lib = session.load_module(tvmc_package.lib_name) - - # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron) - logger.debug("Device is %s.", device) - if device == "cuda": - dev = session.cuda() - elif device == "cl": - dev = session.cl() - elif device == "metal": - dev = session.metal() - elif device == "vulkan": - dev = session.vulkan() - elif device == "rocm": - dev = session.rocm() - else: - assert device == "cpu" - dev = session.cpu() - - if profile: - logger.debug("Creating runtime with profiling enabled.") - module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof") - else: - logger.debug("Creating runtime with profiling disabled.") - module = runtime.create(tvmc_package.graph, lib, dev) + project_dir = os.path.dirname(tvmc_package.package_path) - logger.debug("Loading params into the runtime module.") - module.load_params(tvmc_package.params) - - shape_dict, dtype_dict = get_input_info(tvmc_package.graph, tvmc_package.params) - inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) - - logger.debug("Setting inputs to the module.") - module.set_input(**inputs_dict) + # This is guaranteed to work since project_dir was already checked when + # building the dynamic parser to accommodate the project options, so no + # checks are in place when calling GeneratedProject. + project_ = project.GeneratedProject.from_directory(project_dir, options) + else: + if tvmc_package.type == "mlf": + raise TVMCException( + "You're trying to run a model saved using the Model Library Format (MLF). " + "MLF can only be used to run micro device ('--device micro')." + ) - # Run must be called explicitly if profiling - if profile: - logger.info("Running the module with profiling enabled.") - report = module.profile() - # This print is intentional - print(report) + if hostname: + if isinstance(port, str): + port = int(port) + # Remote RPC + if rpc_key: + logger.debug("Running on remote RPC tracker with key %s.", rpc_key) + session = request_remote(rpc_key, hostname, port, timeout=1000) + else: + logger.debug("Running on remote RPC with no key.") + session = rpc.connect(hostname, port) + elif device == "micro": + # Remote RPC (running on a micro target) + logger.debug("Running on remote RPC (micro target).") + try: + session = tvm.micro.Session(project_.transport()) + stack.enter_context(session) + except: + raise TVMCException("Could not open a session with the micro target.") + else: + # Local + logger.debug("Running a local session.") + session = rpc.LocalSession() + + # Micro targets don't support uploading a model. The model to be run + # must be already flashed into the micro target before one tries + # to run it. Hence skip model upload for micro targets. + if device != "micro": + session.upload(tvmc_package.lib_path) + lib = session.load_module(tvmc_package.lib_name) + + # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron) + logger.debug("Device is %s.", device) + if device == "cuda": + dev = session.cuda() + elif device == "cl": + dev = session.cl() + elif device == "metal": + dev = session.metal() + elif device == "vulkan": + dev = session.vulkan() + elif device == "rocm": + dev = session.rocm() + elif device == "micro": + dev = session.device + lib = session.get_system_lib() + else: + assert device == "cpu" + dev = session.cpu() - # call the benchmarking function of the executor - times = module.benchmark(dev, number=number, repeat=repeat) + # TODO(gromero): Adjust for micro targets. + if profile: + logger.debug("Creating runtime with profiling enabled.") + module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof") + else: + if device == "micro": + logger.debug("Creating runtime (micro) with profiling disabled.") + module = tvm.micro.create_local_graph_executor(tvmc_package.graph, lib, dev) + else: + logger.debug("Creating runtime with profiling disabled.") + module = runtime.create(tvmc_package.graph, lib, dev) + + logger.debug("Loading params into the runtime module.") + module.load_params(tvmc_package.params) + + shape_dict, dtype_dict = get_input_info(tvmc_package.graph, tvmc_package.params) + inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) + + logger.debug("Setting inputs to the module.") + module.set_input(**inputs_dict) + + # Run must be called explicitly if profiling + if profile: + logger.info("Running the module with profiling enabled.") + report = module.profile() + # This print is intentional + print(report) + + if device == "micro": + # TODO(gromero): Fix time_evaluator() for micro targets. Once it's + # fixed module.benchmark() can be used instead and this if/else can + # be removed. + module.run() + times = [] + else: + # call the benchmarking function of the executor + times = module.benchmark(dev, number=number, repeat=repeat) - logger.debug("Collecting the output tensors.") - num_outputs = module.get_num_outputs() - outputs = {} - for i in range(num_outputs): - output_name = "output_{}".format(i) - outputs[output_name] = module.get_output(i).numpy() + logger.debug("Collecting the output tensors.") + num_outputs = module.get_num_outputs() + outputs = {} + for i in range(num_outputs): + output_name = "output_{}".format(i) + outputs[output_name] = module.get_output(i).numpy() - return TVMCResult(outputs, times) + return TVMCResult(outputs, times) diff --git a/python/tvm/micro/project.py b/python/tvm/micro/project.py index a5e54aa816a3..907590dcd2cf 100644 --- a/python/tvm/micro/project.py +++ b/python/tvm/micro/project.py @@ -83,6 +83,17 @@ def flash(self): def transport(self): return ProjectTransport(self._api_client, self._options) + def info(self): + return self._info + + @property + def options(self): + return self._options + + @options.setter + def options(self, options): + self._options = options + class NotATemplateProjectError(Exception): """Raised when the API server given to TemplateProject reports is_template=false.""" diff --git a/python/tvm/micro/project_api/server.py b/python/tvm/micro/project_api/server.py index cee0205303f0..ed26733cfc3c 100644 --- a/python/tvm/micro/project_api/server.py +++ b/python/tvm/micro/project_api/server.py @@ -42,15 +42,26 @@ _LOG = logging.getLogger(__name__) -_ProjectOption = collections.namedtuple("ProjectOption", ("name", "choices", "help")) +_ProjectOption = collections.namedtuple( + "ProjectOption", ("name", "choices", "default", "type", "required", "optional", "help") +) class ProjectOption(_ProjectOption): + """Class used to keep the metadata associated to project options.""" + def __new__(cls, name, **kw): """Override __new__ to force all options except name to be specified as kwargs.""" assert "name" not in kw + assert ( + "required" in kw or "optional" in kw + ), "at least one of 'required' or 'optional' must be specified." + assert "type" in kw, "'type' field must be specified." + kw["name"] = name - kw.setdefault("choices", None) + for param in ["choices", "default", "required", "optional"]: + kw.setdefault(param, None) + return super().__new__(cls, **kw) diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py index d545f2e7daa4..4f754d9d442c 100644 --- a/python/tvm/micro/session.py +++ b/python/tvm/micro/session.py @@ -73,7 +73,7 @@ def __init__( ---------- transport_context_manager : ContextManager[transport.Transport] If given, `flasher` and `binary` should not be given. On entry, this context manager - should establish a tarnsport between this TVM instance and the device. + should establish a transport between this TVM instance and the device. session_name : str Name of the session, used for debugging. timeout_override : TransportTimeouts diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index ea986a3bf096..49a699c3ce13 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -183,7 +183,16 @@ int TVMModCreateFromCModule(const TVMModule* mod, TVMModuleHandle* out_handle) { return -1; } +static const TVMModuleHandle kTVMModuleHandleUninitialized = (TVMModuleHandle)(~0UL); + +static TVMModuleHandle system_lib_handle; + int TVMModFree(TVMModuleHandle mod) { + /* Never free system_lib_handler */ + if (mod == system_lib_handle && system_lib_handle != kTVMModuleHandleUninitialized) { + return 0; + } + tvm_module_index_t module_index; if (DecodeModuleHandle(mod, &module_index) != 0) { return -1; @@ -193,10 +202,6 @@ int TVMModFree(TVMModuleHandle mod) { return 0; } -static const TVMModuleHandle kTVMModuleHandleUninitialized = (TVMModuleHandle)(~0UL); - -static TVMModuleHandle system_lib_handle; - int SystemLibraryCreate(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val, int* ret_type_codes) { const TVMModule* system_lib; diff --git a/src/runtime/crt/host/microtvm_api_server.py b/src/runtime/crt/host/microtvm_api_server.py index 546ac1448011..925dc4ea7597 100644 --- a/src/runtime/crt/host/microtvm_api_server.py +++ b/src/runtime/crt/host/microtvm_api_server.py @@ -53,7 +53,10 @@ def server_info_query(self, tvm_version): else PROJECT_DIR / MODEL_LIBRARY_FORMAT_RELPATH, project_options=[ server.ProjectOption( - "verbose", help="Run make with verbose output", choices=(True, False) + "verbose", + optional=["build"], + type="bool", + help="Run make with verbose output", ) ], ) diff --git a/tests/python/unittest/test_micro_project_api.py b/tests/python/unittest/test_micro_project_api.py index 1e511c41d73e..1dd8940fecec 100644 --- a/tests/python/unittest/test_micro_project_api.py +++ b/tests/python/unittest/test_micro_project_api.py @@ -42,8 +42,16 @@ class BaseTestHandler_Impl(project_api.server.ProjectAPIHandler): is_template=True, model_library_format_path="./model-library-format-path.sh", project_options=[ - project_api.server.ProjectOption(name="foo", help="Option foo"), - project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"), + project_api.server.ProjectOption( + name="foo", optional=["build"], type="bool", help="Option foo" + ), + project_api.server.ProjectOption( + name="bar", + required=["generate_project"], + type="str", + choices=["qux"], + help="Option bar", + ), ], ) @@ -141,8 +149,24 @@ def test_server_info_query(BaseTestHandler): assert reply["is_template"] == True assert reply["model_library_format_path"] == "./model-library-format-path.sh" assert reply["project_options"] == [ - {"name": "foo", "choices": None, "help": "Option foo"}, - {"name": "bar", "choices": ["qux"], "help": "Option bar"}, + { + "name": "foo", + "choices": None, + "default": None, + "type": "bool", + "required": None, + "optional": ["build"], + "help": "Option foo", + }, + { + "name": "bar", + "choices": ["qux"], + "default": None, + "type": "str", + "required": ["generate_project"], + "optional": None, + "help": "Option bar", + }, ]