From 48d65374c4d0fd90916e3401d15d3a9c4dc82f45 Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Wed, 17 Nov 2021 10:45:32 +0000 Subject: [PATCH] [TVMC] Split common tvmc file into more specific files This follows from #9206 and splits common.py into multiple smaller and more focussed files. --- python/tvm/driver/tvmc/__init__.py | 11 +- python/tvm/driver/tvmc/arguments.py | 52 ++ python/tvm/driver/tvmc/autotuner.py | 21 +- python/tvm/driver/tvmc/common.py | 799 ------------------ python/tvm/driver/tvmc/compiler.py | 19 +- python/tvm/driver/tvmc/composite_target.py | 2 +- python/tvm/driver/tvmc/frontends.py | 3 +- python/tvm/driver/tvmc/main.py | 3 +- python/tvm/driver/tvmc/micro.py | 6 +- python/tvm/driver/tvmc/model.py | 3 +- python/tvm/driver/tvmc/pass_config.py | 122 +++ python/tvm/driver/tvmc/pass_list.py | 54 ++ python/tvm/driver/tvmc/project.py | 233 +++++ python/tvm/driver/tvmc/registry.py | 2 +- python/tvm/driver/tvmc/runner.py | 11 +- python/tvm/driver/tvmc/shape_parser.py | 67 ++ python/tvm/driver/tvmc/target.py | 278 ++++++ python/tvm/driver/tvmc/tracker.py | 57 ++ python/tvm/driver/tvmc/transform.py | 62 ++ tests/python/driver/tvmc/test_autotuner.py | 2 +- tests/python/driver/tvmc/test_compiler.py | 6 +- .../driver/tvmc/test_composite_target.py | 2 +- tests/python/driver/tvmc/test_frontends.py | 13 +- tests/python/driver/tvmc/test_pass_config.py | 16 +- tests/python/driver/tvmc/test_pass_list.py | 8 +- .../driver/tvmc/test_registry_options.py | 2 +- tests/python/driver/tvmc/test_runner.py | 2 +- tests/python/driver/tvmc/test_shape_parser.py | 22 +- tests/python/driver/tvmc/test_target.py | 43 +- .../python/driver/tvmc/test_target_options.py | 11 +- tests/python/driver/tvmc/test_tracker.py | 8 +- 31 files changed, 1036 insertions(+), 904 deletions(-) create mode 100644 python/tvm/driver/tvmc/arguments.py delete mode 100644 python/tvm/driver/tvmc/common.py create mode 100644 python/tvm/driver/tvmc/pass_config.py create mode 100644 python/tvm/driver/tvmc/pass_list.py create mode 100644 python/tvm/driver/tvmc/project.py create mode 100644 python/tvm/driver/tvmc/shape_parser.py create mode 100644 python/tvm/driver/tvmc/tracker.py create mode 100644 python/tvm/driver/tvmc/transform.py diff --git a/python/tvm/driver/tvmc/__init__.py b/python/tvm/driver/tvmc/__init__.py index 70747cbb2d74..24bb2bc22146 100644 --- a/python/tvm/driver/tvmc/__init__.py +++ b/python/tvm/driver/tvmc/__init__.py @@ -14,11 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=redefined-builtin +# pylint: disable=redefined-builtin,wrong-import-position """ TVMC - TVM driver command-line interface """ + +class TVMCException(Exception): + """TVMC Exception""" + + +class TVMCImportError(TVMCException): + """TVMC TVMCImportError""" + + from . import micro from . import runner from . import autotuner diff --git a/python/tvm/driver/tvmc/arguments.py b/python/tvm/driver/tvmc/arguments.py new file mode 100644 index 000000000000..57b6ee2f967a --- /dev/null +++ b/python/tvm/driver/tvmc/arguments.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +TVMC Argument Parsing +""" + +import argparse + +from tvm.driver.tvmc import TVMCException + + +class TVMCSuppressedArgumentParser(argparse.ArgumentParser): + """ + A silent ArgumentParser class. + This class is meant to be used as a helper for creating dynamic parsers in + TVMC. It will create a "supressed" parser based on an existing one (parent) + which does not include a help message, does not print a usage message (even + when -h or --help is passed) and does not exit on invalid choice parse + errors but rather throws a TVMCException so it can be handled and the + dynamic parser construction is not interrupted prematurely. + """ + + def __init__(self, parent, **kwargs): + # Don't add '-h' or '--help' options to the newly created parser. Don't print usage message. + # 'add_help=False' won't supress existing '-h' and '--help' options from the parser (and its + # subparsers) present in 'parent'. However that class is meant to be used with the main + # parser, which is created with `add_help=False` - the help is added only later. Hence it + # the newly created parser won't have help options added in its (main) root parser. The + # subparsers in the main parser will eventually have help activated, which is enough for its + # use in TVMC. + super().__init__(parents=[parent], add_help=False, usage=argparse.SUPPRESS, **kwargs) + + def exit(self, status=0, message=None): + # Don't exit on error when parsing the command line. + # This won't catch all the errors generated when parsing tho. For instance, it won't catch + # errors due to missing required arguments. But this will catch "error: invalid choice", + # which is what it's necessary for its use in TVMC. + raise TVMCException() diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index 60bec38f0d1a..8f14c80b9695 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -34,11 +34,12 @@ from tvm.autotvm.tuner import XGBTuner from tvm.target import Target -from . import common, composite_target, frontends -from .common import TVMCException +from . import TVMCException, composite_target, frontends from .main import register_parser from .model import TVMCModel -from .target import generate_target_args, reconstruct_target_args +from .target import target_from_cli, generate_target_args, reconstruct_target_args +from .shape_parser import parse_shape_string +from .transform import convert_graph_layout # pylint: disable=invalid-name @@ -220,7 +221,7 @@ def add_tune_parser(subparsers, _): "--input-shapes", help="specify non-generic shapes for model to run, format is " '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"', - type=common.parse_shape_string, + type=parse_shape_string, ) @@ -256,9 +257,7 @@ def drive_tune(args): logger.info("RPC tracker port: %s", rpc_port) if not args.rpc_key: - raise common.TVMCException( - "need to provide an RPC tracker key (--rpc-key) for remote tuning" - ) + raise TVMCException("need to provide an RPC tracker key (--rpc-key) for remote tuning") else: rpc_hostname = None rpc_port = None @@ -376,7 +375,7 @@ def tune_model( tuning_records : str The path to the produced tuning log file. """ - target, extra_targets = common.target_from_cli(target, additional_target_options) + target, extra_targets = target_from_cli(target, additional_target_options) target, target_host = Target.check_and_update_host_consist(target, target_host) # TODO(jwfromm) Remove this deepcopy once AlterOpLayout bug that mutates source # model is fixed. For now, creating a clone avoids the issue. @@ -399,7 +398,7 @@ def tune_model( if rpc_key: if hostname is None or port is None: - raise common.TVMCException( + raise TVMCException( "You must provide a hostname and port to connect to a remote RPC device." ) if isinstance(port, str): @@ -520,7 +519,7 @@ def autotvm_get_tuning_tasks( target, target_host = Target.check_and_update_host_consist(target, target_host) if alter_layout: - mod = common.convert_graph_layout(mod, alter_layout) + mod = convert_graph_layout(mod, alter_layout) tasks = autotvm.task.extract_from_program( mod["main"], @@ -569,7 +568,7 @@ def autoscheduler_get_tuning_tasks( target, target_host = Target.check_and_update_host_consist(target, target_host) if alter_layout: - mod = common.convert_graph_layout(mod, alter_layout) + mod = convert_graph_layout(mod, alter_layout) # Extract the tasks tasks, task_weights = auto_scheduler.extract_tasks( diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py deleted file mode 100644 index 498da2341853..000000000000 --- a/python/tvm/driver/tvmc/common.py +++ /dev/null @@ -1,799 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Common utility functions shared by TVMC modules. -""" -import re -import json -import logging -import os.path -import argparse -import pathlib -from typing import Union -from collections import defaultdict -from urllib.parse import urlparse - -import tvm -from tvm.driver import tvmc -from tvm import relay -from tvm import transform -from tvm._ffi import registry -from .fmtopt import format_option - -# pylint: disable=invalid-name -logger = logging.getLogger("TVMC") - - -class TVMCException(Exception): - """TVMC Exception""" - - -class TVMCSuppressedArgumentParser(argparse.ArgumentParser): - """ - A silent ArgumentParser class. - - This class is meant to be used as a helper for creating dynamic parsers in - TVMC. It will create a "supressed" parser based on an existing one (parent) - which does not include a help message, does not print a usage message (even - when -h or --help is passed) and does not exit on invalid choice parse - errors but rather throws a TVMCException so it can be handled and the - dynamic parser construction is not interrupted prematurely. - - """ - - def __init__(self, parent, **kwargs): - # Don't add '-h' or '--help' options to the newly created parser. Don't print usage message. - # 'add_help=False' won't supress existing '-h' and '--help' options from the parser (and its - # subparsers) present in 'parent'. However that class is meant to be used with the main - # parser, which is created with `add_help=False` - the help is added only later. Hence it - # the newly created parser won't have help options added in its (main) root parser. The - # subparsers in the main parser will eventually have help activated, which is enough for its - # use in TVMC. - super().__init__(parents=[parent], add_help=False, usage=argparse.SUPPRESS, **kwargs) - - def exit(self, status=0, message=None): - # Don't exit on error when parsing the command line. - # This won't catch all the errors generated when parsing tho. For instance, it won't catch - # errors due to missing required arguments. But this will catch "error: invalid choice", - # which is what it's necessary for its use in TVMC. - raise TVMCException() - - -class TVMCImportError(TVMCException): - """TVMC TVMCImportError""" - - -def convert_graph_layout(mod, desired_layout): - """Alter the layout of the input graph. - - Parameters - ---------- - mod : tvm.IRModule - The relay module to convert. - desired_layout : str - The layout to convert to. - - Returns - ------- - mod : tvm.IRModule - The converted module. - """ - - # Assume for the time being that graphs only have - # conv2d as heavily-sensitive operators. - desired_layouts = { - "nn.conv2d": [desired_layout, "default"], - "nn.conv2d_transpose": [desired_layout, "default"], - "qnn.conv2d": [desired_layout, "default"], - } - - # Convert the layout of the graph where possible. - seq = transform.Sequential( - [ - relay.transform.RemoveUnusedFunctions(), - relay.transform.ConvertLayout(desired_layouts), - ] - ) - - with transform.PassContext(opt_level=3): - try: - return seq(mod) - except Exception as err: - raise TVMCException( - "Error converting layout to {0}: {1}".format(desired_layout, str(err)) - ) - - -def validate_targets(parse_targets, additional_target_options=None): - """ - Apply a series of validations in the targets provided via CLI. - """ - tvm_target_kinds = tvm.target.Target.list_kinds() - targets = [t["name"] for t in parse_targets] - - if len(targets) > len(set(targets)): - raise TVMCException("Duplicate target definitions are not allowed") - - if targets[-1] not in tvm_target_kinds: - tvm_target_names = ", ".join(tvm_target_kinds) - raise TVMCException( - f"The last target needs to be a TVM target. Choices: {tvm_target_names}" - ) - - tvm_targets = [t for t in targets if t in tvm_target_kinds] - if len(tvm_targets) > 2: - verbose_tvm_targets = ", ".join(tvm_targets) - raise TVMCException( - "Only two of the following targets can be used at a time. " - f"Found: {verbose_tvm_targets}." - ) - - if additional_target_options is not None: - for target_name in additional_target_options: - if not any([target for target in parse_targets if target["name"] == target_name]): - first_option = list(additional_target_options[target_name].keys())[0] - raise TVMCException( - f"Passed --target-{target_name}-{first_option}" - f" but did not specify {target_name} target" - ) - - -def tokenize_target(target): - """ - Extract a list of tokens from a target specification text. - - It covers some corner-cases that are not covered by the built-in - module 'shlex', such as the use of "+" as a punctuation character. - - - Example - ------- - - For the input `foo -op1=v1 -op2="v ,2", bar -op3=v-4` we - should obtain: - - ["foo", "-op1=v1", "-op2="v ,2"", ",", "bar", "-op3=v-4"] - - Parameters - ---------- - target : str - Target options sent via CLI arguments - - Returns - ------- - list of str - a list of parsed tokens extracted from the target string - """ - - # Regex to tokenize the "--target" value. It is split into five parts - # to match with: - # 1. target and option names e.g. llvm, -mattr=, -mcpu= - # 2. option values, all together, without quotes e.g. -mattr=+foo,+opt - # 3. option values, when single quotes are used e.g. -mattr='+foo, +opt' - # 4. option values, when double quotes are used e.g. -mattr="+foo ,+opt" - # 5. commas that separate different targets e.g. "my-target, llvm" - target_pattern = ( - r"(\-{0,2}[\w\-]+\=?" - r"(?:[\w\+\-\.]+(?:,[\w\+\-\.])*" - r"|[\'][\w\+\-,\s\.]+[\']" - r"|[\"][\w\+\-,\s\.]+[\"])*" - r"|,)" - ) - - return re.findall(target_pattern, target) - - -def parse_target(target): - """ - Parse a plain string of targets provided via a command-line - argument. - - To send more than one codegen, a comma-separated list - is expected. Options start with -=. - - We use python standard library 'shlex' to parse the argument in - a POSIX compatible way, so that if options are defined as - strings with spaces or commas, for example, this is considered - and parsed accordingly. - - - Example - ------- - - For the input `--target="foo -op1=v1 -op2="v ,2", bar -op3=v-4"` we - should obtain: - - [ - { - name: "foo", - opts: {"op1":"v1", "op2":"v ,2"}, - raw: 'foo -op1=v1 -op2="v ,2"' - }, - { - name: "bar", - opts: {"op3":"v-4"}, - raw: 'bar -op3=v-4' - } - ] - - Parameters - ---------- - target : str - Target options sent via CLI arguments - - Returns - ------- - codegens : list of dict - This list preserves the order in which codegens were - provided via command line. Each Dict contains three keys: - 'name', containing the name of the codegen; 'opts' containing - a key-value for all options passed via CLI; 'raw', - containing the plain string for this codegen - """ - codegen_names = tvmc.composite_target.get_codegen_names() - codegens = [] - - tvm_target_kinds = tvm.target.Target.list_kinds() - parsed_tokens = tokenize_target(target) - - split_codegens = [] - current_codegen = [] - split_codegens.append(current_codegen) - for token in parsed_tokens: - # every time there is a comma separating - # two codegen definitions, prepare for - # a new codegen - if token == ",": - current_codegen = [] - split_codegens.append(current_codegen) - else: - # collect a new token for the current - # codegen being parsed - current_codegen.append(token) - - # at this point we have a list of lists, - # each item on the first list is a codegen definition - # in the comma-separated values - for codegen_def in split_codegens: - # the first is expected to be the name - name = codegen_def[0] - is_tvm_target = name in tvm_target_kinds and name not in codegen_names - raw_target = " ".join(codegen_def) - all_opts = codegen_def[1:] if len(codegen_def) > 1 else [] - opts = {} - for opt in all_opts: - try: - # deal with -- prefixed flags - if opt.startswith("--"): - opt_name = opt[2:] - opt_value = True - else: - opt = opt[1:] if opt.startswith("-") else opt - opt_name, opt_value = opt.split("=", maxsplit=1) - - # remove quotes from the value: quotes are only parsed if they match, - # so it is safe to assume that if the string starts with quote, it ends - # with quote. - opt_value = opt_value[1:-1] if opt_value[0] in ('"', "'") else opt_value - except ValueError: - raise ValueError(f"Error when parsing '{opt}'") - - opts[opt_name] = opt_value - - codegens.append( - {"name": name, "opts": opts, "raw": raw_target, "is_tvm_target": is_tvm_target} - ) - - return codegens - - -def is_inline_json(target): - try: - json.loads(target) - return True - except json.decoder.JSONDecodeError: - return False - - -def _combine_target_options(target, additional_target_options=None): - if additional_target_options is None: - return target - if target["name"] in additional_target_options: - target["opts"].update(additional_target_options[target["name"]]) - return target - - -def _recombobulate_target(target): - name = target["name"] - opts = " ".join([f"-{key}={value}" for key, value in target["opts"].items()]) - return f"{name} {opts}" - - -def target_from_cli(target, additional_target_options=None): - """ - Create a tvm.target.Target instance from a - command line interface (CLI) string. - - Parameters - ---------- - target : str - compilation target as plain string, - inline JSON or path to a JSON file - - additional_target_options: Optional[Dict[str, Dict[str,str]]] - dictionary of additional target options to be - combined with parsed targets - - Returns - ------- - tvm.target.Target - an instance of target device information - extra_targets : list of dict - This list preserves the order in which extra targets were - provided via command line. Each Dict contains three keys: - 'name', containing the name of the codegen; 'opts' containing - a key-value for all options passed via CLI; 'raw', - containing the plain string for this codegen - """ - extra_targets = [] - - if os.path.isfile(target): - with open(target) as target_file: - logger.debug("target input is a path: %s", target) - target = "".join(target_file.readlines()) - elif is_inline_json(target): - logger.debug("target input is inline JSON: %s", target) - else: - logger.debug("target input is plain text: %s", target) - try: - parsed_targets = parse_target(target) - except ValueError as ex: - raise TVMCException(f"Error parsing target string '{target}'.\nThe error was: {ex}") - - validate_targets(parsed_targets, additional_target_options) - tvm_targets = [ - _combine_target_options(t, additional_target_options) - for t in parsed_targets - if t["is_tvm_target"] - ] - - # Validated target strings have 1 or 2 tvm targets, otherwise - # `validate_targets` above will fail. - if len(tvm_targets) == 1: - target = _recombobulate_target(tvm_targets[0]) - target_host = None - else: - assert len(tvm_targets) == 2 - target = _recombobulate_target(tvm_targets[0]) - target_host = _recombobulate_target(tvm_targets[1]) - - extra_targets = [t for t in parsed_targets if not t["is_tvm_target"]] - - return tvm.target.Target(target, host=target_host), extra_targets - - -def tracker_host_port_from_cli(rpc_tracker_str): - """Extract hostname and (optional) port from strings - like "1.2.3.4:9090" or "4.3.2.1". - - Used as a helper function to cover --rpc-tracker - command line argument, in different subcommands. - - Parameters - ---------- - rpc_tracker_str : str - hostname (or IP address) and port of the RPC tracker, - in the format 'hostname[:port]'. - - Returns - ------- - rpc_hostname : str or None - hostname or IP address, extracted from input. - rpc_port : int or None - port number extracted from input (9090 default). - """ - - rpc_hostname = rpc_port = None - - if rpc_tracker_str: - parsed_url = urlparse("//%s" % rpc_tracker_str) - rpc_hostname = parsed_url.hostname - rpc_port = parsed_url.port or 9090 - logger.info("RPC tracker hostname: %s", rpc_hostname) - logger.info("RPC tracker port: %s", rpc_port) - - return rpc_hostname, rpc_port - - -def parse_pass_list_str(input_string): - """Parse an input string for existing passes - - Parameters - ---------- - input_string: str - Possibly comma-separated string with the names of passes - - Returns - ------- - list: a list of existing passes. - """ - _prefix = "relay._transform." - pass_list = input_string.split(",") - missing_list = [ - p.strip() - for p in pass_list - if len(p.strip()) > 0 and tvm.get_global_func(_prefix + p.strip(), True) is None - ] - if len(missing_list) > 0: - available_list = [ - n[len(_prefix) :] for n in registry.list_global_func_names() if n.startswith(_prefix) - ] - raise argparse.ArgumentTypeError( - "Following passes are not registered within tvm: {}. Available: {}.".format( - ", ".join(missing_list), ", ".join(sorted(available_list)) - ) - ) - return pass_list - - -def parse_shape_string(inputs_string): - """Parse an input shape dictionary string to a usable dictionary. - - Parameters - ---------- - inputs_string: str - A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that - indicates the desired shape for specific model inputs. Colons, forward slashes and dots - within input_names are supported. Spaces are supported inside of dimension arrays. - - Returns - ------- - shape_dict: dict - A dictionary mapping input names to their shape for use in relay frontend converters. - """ - - # Create a regex pattern that extracts each separate input mapping. - # We want to be able to handle: - # * Spaces inside arrays - # * forward slashes inside names (but not at the beginning or end) - # * colons inside names (but not at the beginning or end) - # * dots inside names - pattern = r"(?:\w+\/)?[:\w.]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" - input_mappings = re.findall(pattern, inputs_string) - if not input_mappings: - raise argparse.ArgumentTypeError( - "--input-shapes argument must be of the form " - '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"' - ) - shape_dict = {} - for mapping in input_mappings: - # Remove whitespace. - mapping = mapping.replace(" ", "") - # Split mapping into name and shape. - name, shape_string = mapping.rsplit(":", 1) - # Convert shape string into a list of integers or Anys if negative. - shape = [int(x) if int(x) > 0 else relay.Any() for x in shape_string.strip("][").split(",")] - # Add parsed mapping to shape dictionary. - shape_dict[name] = shape - - return shape_dict - - -def get_pass_config_value(name, value, config_type): - """Get a PassContext configuration value, based on its config data type. - - Parameters - ---------- - name: str - config identifier name. - value: str - value assigned to the config, provided via command line. - config_type: str - data type defined to the config, as string. - - Returns - ------- - parsed_value: bool, int or str - a representation of the input value, converted to the type - specified by config_type. - """ - - if config_type == "IntImm": - # "Bool" configurations in the PassContext are recognized as - # IntImm, so deal with this case here - mapping_values = { - "false": False, - "true": True, - } - - if value.isdigit(): - parsed_value = int(value) - else: - # if not an int, accept only values on the mapping table, case insensitive - parsed_value = mapping_values.get(value.lower(), None) - - if parsed_value is None: - raise TVMCException(f"Invalid value '{value}' for configuration '{name}'. ") - - if config_type == "runtime.String": - parsed_value = value - - return parsed_value - - -def parse_configs(input_configs): - """Parse configuration values set via command line. - - Parameters - ---------- - input_configs: list of str - list of configurations provided via command line. - - Returns - ------- - pass_context_configs: dict - a dict containing key-value configs to be used in the PassContext. - """ - if not input_configs: - return {} - - all_configs = tvm.ir.transform.PassContext.list_configs() - supported_config_types = ("IntImm", "runtime.String") - supported_configs = [ - name for name in all_configs.keys() if all_configs[name]["type"] in supported_config_types - ] - - pass_context_configs = {} - - for config in input_configs: - if not config: - raise TVMCException( - f"Invalid format for configuration '{config}', use =" - ) - - # Each config is expected to be provided as "name=value" - try: - name, value = config.split("=") - name = name.strip() - value = value.strip() - except ValueError: - raise TVMCException( - f"Invalid format for configuration '{config}', use =" - ) - - if name not in all_configs: - raise TVMCException( - f"Configuration '{name}' is not defined in TVM. " - f"These are the existing configurations: {', '.join(all_configs)}" - ) - - if name not in supported_configs: - raise TVMCException( - f"Configuration '{name}' uses a data type not supported by TVMC. " - f"The following configurations are supported: {', '.join(supported_configs)}" - ) - - parsed_value = get_pass_config_value(name, value, all_configs[name]["type"]) - pass_context_configs[name] = parsed_value - - return pass_context_configs - - -def get_project_options(project_info): - """Get all project options as returned by Project API 'server_info_query' - and return them in a dict indexed by the API method they belong to. - - - Parameters - ---------- - project_info: dict of list - a dict of lists as returned by Project API 'server_info_query' among - which there is a list called 'project_options' containing all the - project options available for a given project/platform. - - Returns - ------- - options_by_method: dict of list - a dict indexed by the API method names (e.g. "generate_project", - "build", "flash", or "open_transport") of lists containing all the - options (plus associated metadata and formatted help text) that belong - to a method. - - The metadata associated to the options include the field 'choices' and - 'required' which are convenient for parsers. - - The formatted help text field 'help_text' is a string that contains the - name of the option, the choices for the option, and the option's default - value. - """ - options = project_info["project_options"] - - options_by_method = defaultdict(list) - for opt in options: - # Get list of methods associated with an option based on the - # existance of a 'required' or 'optional' lists. API specification - # guarantees at least one of these lists will exist. If a list does - # not exist it's returned as None by the API. - metadata = ["required", "optional"] - om = [(opt[md], bool(md == "required")) for md in metadata if opt[md]] - for methods, is_opt_required in om: - for method in methods: - name = opt["name"] - - # Only for boolean options set 'choices' accordingly to the - # option type. API returns 'choices' associated to them - # as None but 'choices' can be deduced from 'type' in this case. - if opt["type"] == "bool": - opt["choices"] = ["true", "false"] - - if opt["choices"]: - choices = "{" + ", ".join(opt["choices"]) + "}" - else: - choices = opt["name"].upper() - option_choices_text = f"{name}={choices}" - - help_text = opt["help"][0].lower() + opt["help"][1:] - - if opt["default"]: - default_text = f"Defaults to '{opt['default']}'." - else: - default_text = None - - formatted_help_text = format_option( - option_choices_text, help_text, default_text, is_opt_required - ) - - option = { - "name": opt["name"], - "choices": opt["choices"], - "help_text": formatted_help_text, - "required": is_opt_required, - } - options_by_method[method].append(option) - - return options_by_method - - -def get_options(options): - """Get option and option value from the list options returned by the parser. - - Parameters - ---------- - options: list of str - list of strings of the form "option=value" as returned by the parser. - - Returns - ------- - opts: dict - dict indexed by option names and associated values. - """ - - opts = {} - for option in options: - try: - k, v = option.split("=") - opts[k] = v - except ValueError: - raise TVMCException(f"Invalid option format: {option}. Please use OPTION=VALUE.") - - return opts - - -def check_options(options, valid_options): - """Check if an option (required or optional) is valid. i.e. in the list of valid options. - - Parameters - ---------- - options: dict - dict indexed by option name of options and options values to be checked. - - valid_options: list of dict - list of all valid options and choices for a platform. - - Returns - ------- - None. Raise TVMCException if check fails, i.e. if an option is not in the list of valid options. - - """ - required_options = [opt["name"] for opt in valid_options if opt["required"]] - for required_option in required_options: - if required_option not in options: - raise TVMCException( - f"Option '{required_option}' is required but was not specified. Use --list-options " - "to see all required options." - ) - - remaining_options = set(options) - set(required_options) - optional_options = [opt["name"] for opt in valid_options if not opt["required"]] - for option in remaining_options: - if option not in optional_options: - raise TVMCException( - f"Option '{option}' is invalid. Use --list-options to see all available options." - ) - - -def check_options_choices(options, valid_options): - """Check if an option value is among the option's choices, when choices exist. - - Parameters - ---------- - options: dict - dict indexed by option name of options and options values to be checked. - - valid_options: list of dict - list of all valid options and choices for a platform. - - Returns - ------- - None. Raise TVMCException if check fails, i.e. if an option value is not valid. - - """ - # Dict of all valid options and associated valid choices. - # Options with no choices are excluded from the dict. - valid_options_choices = { - opt["name"]: opt["choices"] for opt in valid_options if opt["choices"] is not None - } - - for option in options: - if option in valid_options_choices: - if options[option] not in valid_options_choices[option]: - raise TVMCException( - f"Choice '{options[option]}' for option '{option}' is invalid. " - "Use --list-options to see all available choices for that option." - ) - - -def get_and_check_options(passed_options, valid_options): - """Get options and check if they are valid. If choices exist for them, check values against it. - - Parameters - ---------- - passed_options: list of str - list of strings in the "key=value" form as captured by argparse. - - valid_option: list - list with all options available for a given API method / project as returned by - get_project_options(). - - Returns - ------- - opts: dict - dict indexed by option names and associated values. - - Or None if passed_options is None. - - """ - - if passed_options is None: - # No options to check - return None - - # From a list of k=v strings, make a dict options[k]=v - opts = get_options(passed_options) - # Check if passed options are valid - check_options(opts, valid_options) - # Check (when a list of choices exists) if the passed values are valid - check_options_choices(opts, valid_options) - - return opts - - -def get_project_dir(project_dir: Union[pathlib.Path, str]) -> str: - """Get project directory path""" - if not os.path.isabs(project_dir): - return os.path.abspath(project_dir) - return project_dir diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index dbf7e46ad003..d260c98b6721 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -29,11 +29,14 @@ from tvm.target import Target from tvm.relay.backend import Executor, Runtime -from . import common, composite_target, frontends +from . import composite_target, frontends from .model import TVMCModel, TVMCPackage from .main import register_parser -from .target import generate_target_args, reconstruct_target_args - +from .target import target_from_cli, generate_target_args, reconstruct_target_args +from .pass_config import parse_configs +from .pass_list import parse_pass_list_str +from .transform import convert_graph_layout +from .shape_parser import parse_shape_string # pylint: disable=invalid-name logger = logging.getLogger("TVMC") @@ -124,13 +127,13 @@ def add_compile_parser(subparsers, _): "--input-shapes", help="specify non-generic shapes for model to run, format is " '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]".', - type=common.parse_shape_string, + type=parse_shape_string, default=None, ) parser.add_argument( "--disabled-pass", help="disable specific passes, comma-separated list of pass names.", - type=common.parse_pass_list_str, + type=parse_pass_list_str, default="", ) @@ -249,12 +252,12 @@ def compile_model( """ mod, params = tvmc_model.mod, tvmc_model.params - config = common.parse_configs(pass_context_configs) + config = parse_configs(pass_context_configs) if desired_layout: - mod = common.convert_graph_layout(mod, desired_layout) + mod = convert_graph_layout(mod, desired_layout) - tvm_target, extra_targets = common.target_from_cli(target, additional_target_options) + tvm_target, extra_targets = target_from_cli(target, additional_target_options) tvm_target, target_host = Target.check_and_update_host_consist(tvm_target, target_host) for codegen_from_cli in extra_targets: diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py index 848af1e4ee4e..f347158e5e0c 100644 --- a/python/tvm/driver/tvmc/composite_target.py +++ b/python/tvm/driver/tvmc/composite_target.py @@ -31,7 +31,7 @@ from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai -from .common import TVMCException +from tvm.driver.tvmc import TVMCException # pylint: disable=invalid-name diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index b6773dca9a7b..a3222782c68e 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -31,8 +31,7 @@ import numpy as np from tvm import relay -from tvm.driver.tvmc.common import TVMCException -from tvm.driver.tvmc.common import TVMCImportError +from tvm.driver.tvmc import TVMCException, TVMCImportError from tvm.driver.tvmc.model import TVMCModel diff --git a/python/tvm/driver/tvmc/main.py b/python/tvm/driver/tvmc/main.py index 3fb8cd7e77ef..b74cc7d6eefb 100644 --- a/python/tvm/driver/tvmc/main.py +++ b/python/tvm/driver/tvmc/main.py @@ -25,8 +25,7 @@ import tvm -from tvm.driver.tvmc.common import TVMCException -from tvm.driver.tvmc.common import TVMCImportError +from tvm.driver.tvmc import TVMCException, TVMCImportError REGISTERED_PARSER = [] diff --git a/python/tvm/driver/tvmc/micro.py b/python/tvm/driver/tvmc/micro.py index a9c17b840ca6..4f478c7c3aa4 100644 --- a/python/tvm/driver/tvmc/micro.py +++ b/python/tvm/driver/tvmc/micro.py @@ -23,10 +23,10 @@ import shutil import sys +from . import TVMCException from .main import register_parser -from .common import ( - TVMCException, - TVMCSuppressedArgumentParser, +from .arguments import TVMCSuppressedArgumentParser +from .project import ( get_project_options, get_and_check_options, get_project_dir, diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index 5110aed21378..9a2617f3ed53 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -54,6 +54,7 @@ import tvm.contrib.cc from tvm import relay from tvm.contrib import utils +from tvm.driver.tvmc import TVMCException from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule from tvm.runtime.module import BenchmarkResult @@ -62,8 +63,6 @@ except ImportError: export_model_library_format = None -from .common import TVMCException - class TVMCModel(object): """Initialize a TVMC model from a relay model definition or a saved file. diff --git a/python/tvm/driver/tvmc/pass_config.py b/python/tvm/driver/tvmc/pass_config.py new file mode 100644 index 000000000000..7cf0f0143e60 --- /dev/null +++ b/python/tvm/driver/tvmc/pass_config.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +TVMC PassContext Interface +""" + +import tvm +from tvm.driver.tvmc import TVMCException + + +def get_pass_config_value(name, value, config_type): + """Get a PassContext configuration value, based on its config data type. + + Parameters + ---------- + name: str + config identifier name. + value: str + value assigned to the config, provided via command line. + config_type: str + data type defined to the config, as string. + + Returns + ------- + parsed_value: bool, int or str + a representation of the input value, converted to the type + specified by config_type. + """ + + if config_type == "IntImm": + # "Bool" configurations in the PassContext are recognized as + # IntImm, so deal with this case here + mapping_values = { + "false": False, + "true": True, + } + + if value.isdigit(): + parsed_value = int(value) + else: + # if not an int, accept only values on the mapping table, case insensitive + parsed_value = mapping_values.get(value.lower(), None) + + if parsed_value is None: + raise TVMCException(f"Invalid value '{value}' for configuration '{name}'. ") + + if config_type == "runtime.String": + parsed_value = value + + return parsed_value + + +def parse_configs(input_configs): + """Parse configuration values set via command line. + + Parameters + ---------- + input_configs: list of str + list of configurations provided via command line. + + Returns + ------- + pass_context_configs: dict + a dict containing key-value configs to be used in the PassContext. + """ + if not input_configs: + return {} + + all_configs = tvm.ir.transform.PassContext.list_configs() + supported_config_types = ("IntImm", "runtime.String") + supported_configs = [ + name for name in all_configs.keys() if all_configs[name]["type"] in supported_config_types + ] + + pass_context_configs = {} + + for config in input_configs: + if not config: + raise TVMCException( + f"Invalid format for configuration '{config}', use =" + ) + + # Each config is expected to be provided as "name=value" + try: + name, value = config.split("=") + name = name.strip() + value = value.strip() + except ValueError: + raise TVMCException( + f"Invalid format for configuration '{config}', use =" + ) + + if name not in all_configs: + raise TVMCException( + f"Configuration '{name}' is not defined in TVM. " + f"These are the existing configurations: {', '.join(all_configs)}" + ) + + if name not in supported_configs: + raise TVMCException( + f"Configuration '{name}' uses a data type not supported by TVMC. " + f"The following configurations are supported: {', '.join(supported_configs)}" + ) + + parsed_value = get_pass_config_value(name, value, all_configs[name]["type"]) + pass_context_configs[name] = parsed_value + + return pass_context_configs diff --git a/python/tvm/driver/tvmc/pass_list.py b/python/tvm/driver/tvmc/pass_list.py new file mode 100644 index 000000000000..09ec6aaf9102 --- /dev/null +++ b/python/tvm/driver/tvmc/pass_list.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language +""" +TVMC Pass List Management +""" + +import argparse + +import tvm +from tvm._ffi import registry + + +def parse_pass_list_str(input_string): + """Parse an input string for existing passes + + Parameters + ---------- + input_string: str + Possibly comma-separated string with the names of passes + + Returns + ------- + list: a list of existing passes. + """ + _prefix = "relay._transform." + pass_list = input_string.split(",") + missing_list = [ + p.strip() + for p in pass_list + if len(p.strip()) > 0 and tvm.get_global_func(_prefix + p.strip(), True) is None + ] + if len(missing_list) > 0: + available_list = [ + n[len(_prefix) :] for n in registry.list_global_func_names() if n.startswith(_prefix) + ] + raise argparse.ArgumentTypeError( + "Following passes are not registered within tvm: {}. Available: {}.".format( + ", ".join(missing_list), ", ".join(sorted(available_list)) + ) + ) + return pass_list diff --git a/python/tvm/driver/tvmc/project.py b/python/tvm/driver/tvmc/project.py new file mode 100644 index 000000000000..d9b22a2d6fc3 --- /dev/null +++ b/python/tvm/driver/tvmc/project.py @@ -0,0 +1,233 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +TVMC Project Generation Functions +""" + +import os +import pathlib +from collections import defaultdict +from typing import Union + +from . import TVMCException +from .fmtopt import format_option + + +def get_project_options(project_info): + """Get all project options as returned by Project API 'server_info_query' + and return them in a dict indexed by the API method they belong to. + + + Parameters + ---------- + project_info: dict of list + a dict of lists as returned by Project API 'server_info_query' among + which there is a list called 'project_options' containing all the + project options available for a given project/platform. + + Returns + ------- + options_by_method: dict of list + a dict indexed by the API method names (e.g. "generate_project", + "build", "flash", or "open_transport") of lists containing all the + options (plus associated metadata and formatted help text) that belong + to a method. + + The metadata associated to the options include the field 'choices' and + 'required' which are convenient for parsers. + + The formatted help text field 'help_text' is a string that contains the + name of the option, the choices for the option, and the option's default + value. + """ + options = project_info["project_options"] + + options_by_method = defaultdict(list) + for opt in options: + # Get list of methods associated with an option based on the + # existance of a 'required' or 'optional' lists. API specification + # guarantees at least one of these lists will exist. If a list does + # not exist it's returned as None by the API. + metadata = ["required", "optional"] + option_methods = [(opt[md], bool(md == "required")) for md in metadata if opt[md]] + for methods, is_opt_required in option_methods: + for method in methods: + name = opt["name"] + + # Only for boolean options set 'choices' accordingly to the + # option type. API returns 'choices' associated to them + # as None but 'choices' can be deduced from 'type' in this case. + if opt["type"] == "bool": + opt["choices"] = ["true", "false"] + + if opt["choices"]: + choices = "{" + ", ".join(opt["choices"]) + "}" + else: + choices = opt["name"].upper() + option_choices_text = f"{name}={choices}" + + help_text = opt["help"][0].lower() + opt["help"][1:] + + if opt["default"]: + default_text = f"Defaults to '{opt['default']}'." + else: + default_text = None + + formatted_help_text = format_option( + option_choices_text, help_text, default_text, is_opt_required + ) + + option = { + "name": opt["name"], + "choices": opt["choices"], + "help_text": formatted_help_text, + "required": is_opt_required, + } + options_by_method[method].append(option) + + return options_by_method + + +def get_options(options): + """Get option and option value from the list options returned by the parser. + + Parameters + ---------- + options: list of str + list of strings of the form "option=value" as returned by the parser. + + Returns + ------- + opts: dict + dict indexed by option names and associated values. + """ + + opts = {} + for option in options: + try: + k, v = option.split("=") + opts[k] = v + except ValueError: + raise TVMCException(f"Invalid option format: {option}. Please use OPTION=VALUE.") + + return opts + + +def check_options(options, valid_options): + """Check if an option (required or optional) is valid. i.e. in the list of valid options. + + Parameters + ---------- + options: dict + dict indexed by option name of options and options values to be checked. + + valid_options: list of dict + list of all valid options and choices for a platform. + + Returns + ------- + None. Raise TVMCException if check fails, i.e. if an option is not in the list of valid options. + + """ + required_options = [opt["name"] for opt in valid_options if opt["required"]] + for required_option in required_options: + if required_option not in options: + raise TVMCException( + f"Option '{required_option}' is required but was not specified. Use --list-options " + "to see all required options." + ) + + remaining_options = set(options) - set(required_options) + optional_options = [opt["name"] for opt in valid_options if not opt["required"]] + for option in remaining_options: + if option not in optional_options: + raise TVMCException( + f"Option '{option}' is invalid. Use --list-options to see all available options." + ) + + +def check_options_choices(options, valid_options): + """Check if an option value is among the option's choices, when choices exist. + + Parameters + ---------- + options: dict + dict indexed by option name of options and options values to be checked. + + valid_options: list of dict + list of all valid options and choices for a platform. + + Returns + ------- + None. Raise TVMCException if check fails, i.e. if an option value is not valid. + + """ + # Dict of all valid options and associated valid choices. + # Options with no choices are excluded from the dict. + valid_options_choices = { + opt["name"]: opt["choices"] for opt in valid_options if opt["choices"] is not None + } + + for option in options: + if option in valid_options_choices: + if options[option] not in valid_options_choices[option]: + raise TVMCException( + f"Choice '{options[option]}' for option '{option}' is invalid. " + "Use --list-options to see all available choices for that option." + ) + + +def get_and_check_options(passed_options, valid_options): + """Get options and check if they are valid. If choices exist for them, check values against it. + + Parameters + ---------- + passed_options: list of str + list of strings in the "key=value" form as captured by argparse. + + valid_option: list + list with all options available for a given API method / project as returned by + get_project_options(). + + Returns + ------- + opts: dict + dict indexed by option names and associated values. + + Or None if passed_options is None. + + """ + + if passed_options is None: + # No options to check + return None + + # From a list of k=v strings, make a dict options[k]=v + opts = get_options(passed_options) + # Check if passed options are valid + check_options(opts, valid_options) + # Check (when a list of choices exists) if the passed values are valid + check_options_choices(opts, valid_options) + + return opts + + +def get_project_dir(project_dir: Union[pathlib.Path, str]) -> str: + """Get project directory path""" + if not os.path.isabs(project_dir): + return os.path.abspath(project_dir) + return project_dir diff --git a/python/tvm/driver/tvmc/registry.py b/python/tvm/driver/tvmc/registry.py index 384a3bd1baf6..334aa1b61be8 100644 --- a/python/tvm/driver/tvmc/registry.py +++ b/python/tvm/driver/tvmc/registry.py @@ -18,7 +18,7 @@ This file contains functions for processing registry based inputs for the TVMC CLI """ -from tvm.driver.tvmc.common import TVMCException +from tvm.driver.tvmc import TVMCException # We can't tell the type inside an Array but all current options are strings so # it can default to that. Bool is used alongside Integer but aren't distinguished diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index fd342a569956..a2343962af95 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -33,17 +33,18 @@ from tvm.contrib import graph_executor as runtime from tvm.contrib.debugger import debug_executor from tvm.relay.param_dict import load_param_dict -from . import common -from .common import ( - TVMCException, - TVMCSuppressedArgumentParser, +from . import TVMCException +from .arguments import TVMCSuppressedArgumentParser +from .project import ( get_project_options, get_and_check_options, get_project_dir, ) + from .main import register_parser from .model import TVMCPackage, TVMCResult from .result_utils import get_top_results +from .tracker import tracker_host_port_from_cli try: import tvm.micro.project as project @@ -245,7 +246,7 @@ def drive_run(args): except ReadError: raise TVMCException(f"Could not read model from archive {path}!") - rpc_hostname, rpc_port = common.tracker_host_port_from_cli(args.rpc_tracker) + rpc_hostname, rpc_port = tracker_host_port_from_cli(args.rpc_tracker) try: inputs = np.load(args.inputs) if args.inputs else {} diff --git a/python/tvm/driver/tvmc/shape_parser.py b/python/tvm/driver/tvmc/shape_parser.py new file mode 100644 index 000000000000..24b7727703d6 --- /dev/null +++ b/python/tvm/driver/tvmc/shape_parser.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +TVMC Shape Parsing +""" + +import argparse +import re + +from tvm import relay + + +def parse_shape_string(inputs_string): + """Parse an input shape dictionary string to a usable dictionary. + + Parameters + ---------- + inputs_string: str + A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that + indicates the desired shape for specific model inputs. Colons, forward slashes and dots + within input_names are supported. Spaces are supported inside of dimension arrays. + + Returns + ------- + shape_dict: dict + A dictionary mapping input names to their shape for use in relay frontend converters. + """ + + # Create a regex pattern that extracts each separate input mapping. + # We want to be able to handle: + # * Spaces inside arrays + # * forward slashes inside names (but not at the beginning or end) + # * colons inside names (but not at the beginning or end) + # * dots inside names + pattern = r"(?:\w+\/)?[:\w.]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" + input_mappings = re.findall(pattern, inputs_string) + if not input_mappings: + raise argparse.ArgumentTypeError( + "--input-shapes argument must be of the form " + '"input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]"' + ) + shape_dict = {} + for mapping in input_mappings: + # Remove whitespace. + mapping = mapping.replace(" ", "") + # Split mapping into name and shape. + name, shape_string = mapping.rsplit(":", 1) + # Convert shape string into a list of integers or Anys if negative. + shape = [int(x) if int(x) > 0 else relay.Any() for x in shape_string.strip("][").split(",")] + # Add parsed mapping to shape dictionary. + shape_dict[name] = shape + + return shape_dict diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index 067a3610f57e..15ed19d20549 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -18,9 +18,19 @@ This file contains functions for processing target inputs for the TVMC CLI """ +import os +import logging +import json +import re + +import tvm from tvm.driver import tvmc +from tvm.driver.tvmc import TVMCException from tvm.target import Target, TargetKind +# pylint: disable=invalid-name +logger = logging.getLogger("TVMC") + # We can't tell the type inside an Array but all current options are strings so # it can default to that. Bool is used alongside Integer but aren't distinguished # between as both are represented by IntImm @@ -74,3 +84,271 @@ def reconstruct_target_args(args): if kind_options: reconstructed[target_kind] = kind_options return reconstructed + + +def validate_targets(parse_targets, additional_target_options=None): + """ + Apply a series of validations in the targets provided via CLI. + """ + tvm_target_kinds = tvm.target.Target.list_kinds() + targets = [t["name"] for t in parse_targets] + + if len(targets) > len(set(targets)): + raise TVMCException("Duplicate target definitions are not allowed") + + if targets[-1] not in tvm_target_kinds: + tvm_target_names = ", ".join(tvm_target_kinds) + raise TVMCException( + f"The last target needs to be a TVM target. Choices: {tvm_target_names}" + ) + + tvm_targets = [t for t in targets if t in tvm_target_kinds] + if len(tvm_targets) > 2: + verbose_tvm_targets = ", ".join(tvm_targets) + raise TVMCException( + "Only two of the following targets can be used at a time. " + f"Found: {verbose_tvm_targets}." + ) + + if additional_target_options is not None: + for target_name in additional_target_options: + if not any([target for target in parse_targets if target["name"] == target_name]): + first_option = list(additional_target_options[target_name].keys())[0] + raise TVMCException( + f"Passed --target-{target_name}-{first_option}" + f" but did not specify {target_name} target" + ) + + +def tokenize_target(target): + """ + Extract a list of tokens from a target specification text. + + It covers some corner-cases that are not covered by the built-in + module 'shlex', such as the use of "+" as a punctuation character. + + + Example + ------- + + For the input `foo -op1=v1 -op2="v ,2", bar -op3=v-4` we + should obtain: + + ["foo", "-op1=v1", "-op2="v ,2"", ",", "bar", "-op3=v-4"] + + Parameters + ---------- + target : str + Target options sent via CLI arguments + + Returns + ------- + list of str + a list of parsed tokens extracted from the target string + """ + + # Regex to tokenize the "--target" value. It is split into five parts + # to match with: + # 1. target and option names e.g. llvm, -mattr=, -mcpu= + # 2. option values, all together, without quotes e.g. -mattr=+foo,+opt + # 3. option values, when single quotes are used e.g. -mattr='+foo, +opt' + # 4. option values, when double quotes are used e.g. -mattr="+foo ,+opt" + # 5. commas that separate different targets e.g. "my-target, llvm" + target_pattern = ( + r"(\-{0,2}[\w\-]+\=?" + r"(?:[\w\+\-\.]+(?:,[\w\+\-\.])*" + r"|[\'][\w\+\-,\s\.]+[\']" + r"|[\"][\w\+\-,\s\.]+[\"])*" + r"|,)" + ) + + return re.findall(target_pattern, target) + + +def parse_target(target): + """ + Parse a plain string of targets provided via a command-line + argument. + + To send more than one codegen, a comma-separated list + is expected. Options start with -=. + + We use python standard library 'shlex' to parse the argument in + a POSIX compatible way, so that if options are defined as + strings with spaces or commas, for example, this is considered + and parsed accordingly. + + + Example + ------- + + For the input `--target="foo -op1=v1 -op2="v ,2", bar -op3=v-4"` we + should obtain: + + [ + { + name: "foo", + opts: {"op1":"v1", "op2":"v ,2"}, + raw: 'foo -op1=v1 -op2="v ,2"' + }, + { + name: "bar", + opts: {"op3":"v-4"}, + raw: 'bar -op3=v-4' + } + ] + + Parameters + ---------- + target : str + Target options sent via CLI arguments + + Returns + ------- + codegens : list of dict + This list preserves the order in which codegens were + provided via command line. Each Dict contains three keys: + 'name', containing the name of the codegen; 'opts' containing + a key-value for all options passed via CLI; 'raw', + containing the plain string for this codegen + """ + codegen_names = tvmc.composite_target.get_codegen_names() + codegens = [] + + tvm_target_kinds = tvm.target.Target.list_kinds() + parsed_tokens = tokenize_target(target) + + split_codegens = [] + current_codegen = [] + split_codegens.append(current_codegen) + for token in parsed_tokens: + # every time there is a comma separating + # two codegen definitions, prepare for + # a new codegen + if token == ",": + current_codegen = [] + split_codegens.append(current_codegen) + else: + # collect a new token for the current + # codegen being parsed + current_codegen.append(token) + + # at this point we have a list of lists, + # each item on the first list is a codegen definition + # in the comma-separated values + for codegen_def in split_codegens: + # the first is expected to be the name + name = codegen_def[0] + is_tvm_target = name in tvm_target_kinds and name not in codegen_names + raw_target = " ".join(codegen_def) + all_opts = codegen_def[1:] if len(codegen_def) > 1 else [] + opts = {} + for opt in all_opts: + try: + # deal with -- prefixed flags + if opt.startswith("--"): + opt_name = opt[2:] + opt_value = True + else: + opt = opt[1:] if opt.startswith("-") else opt + opt_name, opt_value = opt.split("=", maxsplit=1) + + # remove quotes from the value: quotes are only parsed if they match, + # so it is safe to assume that if the string starts with quote, it ends + # with quote. + opt_value = opt_value[1:-1] if opt_value[0] in ('"', "'") else opt_value + except ValueError: + raise ValueError(f"Error when parsing '{opt}'") + + opts[opt_name] = opt_value + + codegens.append( + {"name": name, "opts": opts, "raw": raw_target, "is_tvm_target": is_tvm_target} + ) + + return codegens + + +def is_inline_json(target): + try: + json.loads(target) + return True + except json.decoder.JSONDecodeError: + return False + + +def _combine_target_options(target, additional_target_options=None): + if additional_target_options is None: + return target + if target["name"] in additional_target_options: + target["opts"].update(additional_target_options[target["name"]]) + return target + + +def _recombobulate_target(target): + name = target["name"] + opts = " ".join([f"-{key}={value}" for key, value in target["opts"].items()]) + return f"{name} {opts}" + + +def target_from_cli(target, additional_target_options=None): + """ + Create a tvm.target.Target instance from a + command line interface (CLI) string. + + Parameters + ---------- + target : str + compilation target as plain string, + inline JSON or path to a JSON file + + additional_target_options: Optional[Dict[str, Dict[str,str]]] + dictionary of additional target options to be + combined with parsed targets + + Returns + ------- + tvm.target.Target + an instance of target device information + extra_targets : list of dict + This list preserves the order in which extra targets were + provided via command line. Each Dict contains three keys: + 'name', containing the name of the codegen; 'opts' containing + a key-value for all options passed via CLI; 'raw', + containing the plain string for this codegen + """ + extra_targets = [] + + if os.path.isfile(target): + with open(target) as target_file: + logger.debug("target input is a path: %s", target) + target = "".join(target_file.readlines()) + elif is_inline_json(target): + logger.debug("target input is inline JSON: %s", target) + else: + logger.debug("target input is plain text: %s", target) + try: + parsed_targets = parse_target(target) + except ValueError as error: + raise TVMCException(f"Error parsing target string '{target}'.\nThe error was: {error}") + + validate_targets(parsed_targets, additional_target_options) + tvm_targets = [ + _combine_target_options(t, additional_target_options) + for t in parsed_targets + if t["is_tvm_target"] + ] + + # Validated target strings have 1 or 2 tvm targets, otherwise + # `validate_targets` above will fail. + if len(tvm_targets) == 1: + target = _recombobulate_target(tvm_targets[0]) + target_host = None + else: + assert len(tvm_targets) == 2 + target = _recombobulate_target(tvm_targets[0]) + target_host = _recombobulate_target(tvm_targets[1]) + + extra_targets = [t for t in parsed_targets if not t["is_tvm_target"]] + + return tvm.target.Target(target, host=target_host), extra_targets diff --git a/python/tvm/driver/tvmc/tracker.py b/python/tvm/driver/tvmc/tracker.py new file mode 100644 index 000000000000..65fda42ac541 --- /dev/null +++ b/python/tvm/driver/tvmc/tracker.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language +""" +TVMC Remote Tracker +""" + +import logging +from urllib.parse import urlparse + +# pylint: disable=invalid-name +logger = logging.getLogger("TVMC") + + +def tracker_host_port_from_cli(rpc_tracker_str): + """Extract hostname and (optional) port from strings + like "1.2.3.4:9090" or "4.3.2.1". + + Used as a helper function to cover --rpc-tracker + command line argument, in different subcommands. + + Parameters + ---------- + rpc_tracker_str : str + hostname (or IP address) and port of the RPC tracker, + in the format 'hostname[:port]'. + + Returns + ------- + rpc_hostname : str or None + hostname or IP address, extracted from input. + rpc_port : int or None + port number extracted from input (9090 default). + """ + + rpc_hostname = rpc_port = None + + if rpc_tracker_str: + parsed_url = urlparse("//%s" % rpc_tracker_str) + rpc_hostname = parsed_url.hostname + rpc_port = parsed_url.port or 9090 + logger.info("RPC tracker hostname: %s", rpc_hostname) + logger.info("RPC tracker port: %s", rpc_port) + + return rpc_hostname, rpc_port diff --git a/python/tvm/driver/tvmc/transform.py b/python/tvm/driver/tvmc/transform.py new file mode 100644 index 000000000000..3f7776577876 --- /dev/null +++ b/python/tvm/driver/tvmc/transform.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language +""" +TVMC Graph Transforms +""" + +from tvm import relay, transform +from tvm.driver.tvmc import TVMCException + + +def convert_graph_layout(mod, desired_layout): + """Alter the layout of the input graph. + + Parameters + ---------- + mod : tvm.IRModule + The relay module to convert. + desired_layout : str + The layout to convert to. + + Returns + ------- + mod : tvm.IRModule + The converted module. + """ + + # Assume for the time being that graphs only have + # conv2d as heavily-sensitive operators. + desired_layouts = { + "nn.conv2d": [desired_layout, "default"], + "nn.conv2d_transpose": [desired_layout, "default"], + "qnn.conv2d": [desired_layout, "default"], + } + + # Convert the layout of the graph where possible. + seq = transform.Sequential( + [ + relay.transform.RemoveUnusedFunctions(), + relay.transform.ConvertLayout(desired_layouts), + ] + ) + + with transform.PassContext(opt_level=3): + try: + return seq(mod) + except Exception as err: + raise TVMCException( + "Error converting layout to {0}: {1}".format(desired_layout, str(err)) + ) diff --git a/tests/python/driver/tvmc/test_autotuner.py b/tests/python/driver/tvmc/test_autotuner.py index 52c331737e6b..a1915a0251e9 100644 --- a/tests/python/driver/tvmc/test_autotuner.py +++ b/tests/python/driver/tvmc/test_autotuner.py @@ -153,7 +153,7 @@ def test_tune_tasks__invalid_tuner(onnx_mnist, tmpdir_factory): tasks = _get_tasks(onnx_mnist) log_file = os.path.join(tmpdir_factory.mktemp("data"), "log2.txt") - with pytest.raises(tvmc.common.TVMCException): + with pytest.raises(tvmc.TVMCException): tvmc.autotuner.tune_tasks(tasks, log_file, _get_measure_options(), "invalid_tuner", 1, 1) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 73f3a0f27eba..5ebcb8eea27d 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -70,7 +70,7 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): verify_compile_tflite_module(tflite_mobilenet_v1_1_quant) # Check with manual shape override shape_string = "input:[1,224,224,3]" - shape_dict = tvmc.common.parse_shape_string(shape_string) + shape_dict = tvmc.shape_parser.parse_shape_string(shape_string) verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict) @@ -218,7 +218,7 @@ def test_compile_onnx_module(onnx_resnet50): verify_compile_onnx_module(onnx_resnet50) # Test with manual shape dict shape_string = "data:[1,3,200,200]" - shape_dict = tvmc.common.parse_shape_string(shape_string) + shape_dict = tvmc.shape_parser.parse_shape_string(shape_string) verify_compile_onnx_module(onnx_resnet50, shape_dict) @@ -296,7 +296,7 @@ def test_compile_paddle_module(paddle_resnet50): verify_compile_paddle_module(paddle_resnet50) # Check with manual shape override shape_string = "inputs:[1,3,224,224]" - shape_dict = tvmc.common.parse_shape_string(shape_string) + shape_dict = tvmc.shape_parser.parse_shape_string(shape_string) verify_compile_paddle_module(paddle_resnet50, shape_dict) diff --git a/tests/python/driver/tvmc/test_composite_target.py b/tests/python/driver/tvmc/test_composite_target.py index 80b4d1be93d5..dfaf30c9e2b1 100644 --- a/tests/python/driver/tvmc/test_composite_target.py +++ b/tests/python/driver/tvmc/test_composite_target.py @@ -27,7 +27,7 @@ from tvm.driver import tvmc -from tvm.driver.tvmc.common import TVMCException +from tvm.driver.tvmc import TVMCException def test_get_codegen_names(): diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index e887857093f7..b76066994cb2 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -24,8 +24,7 @@ from tvm.ir.module import IRModule from tvm.driver import tvmc -from tvm.driver.tvmc.common import TVMCException -from tvm.driver.tvmc.common import TVMCImportError +from tvm.driver.tvmc import TVMCException, TVMCImportError from tvm.driver.tvmc.model import TVMCModel @@ -268,7 +267,7 @@ def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant): before = tvmc_model.mod expected_layout = "NCHW" - after = tvmc.common.convert_graph_layout(before, expected_layout) + after = tvmc.transform.convert_graph_layout(before, expected_layout) layout_transform_calls = [] @@ -293,7 +292,7 @@ def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50): before = tvmc_model.mod expected_layout = "NHWC" - after = tvmc.common.convert_graph_layout(before, expected_layout) + after = tvmc.transform.convert_graph_layout(before, expected_layout) layout_transform_calls = [] @@ -318,7 +317,7 @@ def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50): before = tvmc_model.mod expected_layout = "NHWC" - after = tvmc.common.convert_graph_layout(before, expected_layout) + after = tvmc.transform.convert_graph_layout(before, expected_layout) layout_transform_calls = [] @@ -343,7 +342,7 @@ def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_ before = tvmc_model.mod expected_layout = "NHWC" - after = tvmc.common.convert_graph_layout(before, expected_layout) + after = tvmc.transform.convert_graph_layout(before, expected_layout) layout_transform_calls = [] @@ -368,7 +367,7 @@ def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50): before = tvmc_model.mod expected_layout = "NCHW" - after = tvmc.common.convert_graph_layout(before, expected_layout) + after = tvmc.transform.convert_graph_layout(before, expected_layout) layout_transform_calls = [] diff --git a/tests/python/driver/tvmc/test_pass_config.py b/tests/python/driver/tvmc/test_pass_config.py index d8ffd7d4d521..bb815e1dc8aa 100644 --- a/tests/python/driver/tvmc/test_pass_config.py +++ b/tests/python/driver/tvmc/test_pass_config.py @@ -18,33 +18,33 @@ import pytest from tvm.contrib.target.vitis_ai import vitis_ai_available -from tvm.driver import tvmc -from tvm.driver.tvmc.common import TVMCException +from tvm.driver.tvmc import TVMCException +from tvm.driver.tvmc.pass_config import parse_configs def test_config_invalid_format(): with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"]) + _ = parse_configs(["relay.backend.use_auto_scheduler.missing.value"]) def test_config_missing_from_tvm(): with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"]) + _ = parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"]) def test_config_unsupported_tvmc_config(): with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs(["tir.LoopPartition=value"]) + _ = parse_configs(["tir.LoopPartition=value"]) def test_config_empty(): with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs([""]) + _ = parse_configs([""]) def test_config_valid_config_bool(): - configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"]) + configs = parse_configs(["relay.backend.use_auto_scheduler=true"]) assert len(configs) == 1 assert "relay.backend.use_auto_scheduler" in configs.keys() @@ -56,7 +56,7 @@ def test_config_valid_config_bool(): reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", ) def test_config_valid_multiple_configs(): - configs = tvmc.common.parse_configs( + configs = parse_configs( [ "relay.backend.use_auto_scheduler=false", "tir.detect_global_barrier=10", diff --git a/tests/python/driver/tvmc/test_pass_list.py b/tests/python/driver/tvmc/test_pass_list.py index de50b04f415a..f43da6371b9b 100644 --- a/tests/python/driver/tvmc/test_pass_list.py +++ b/tests/python/driver/tvmc/test_pass_list.py @@ -17,15 +17,15 @@ import argparse import pytest -from tvm.driver import tvmc +from tvm.driver.tvmc.pass_list import parse_pass_list_str def test_parse_pass_list_str(): - assert [""] == tvmc.common.parse_pass_list_str("") - assert ["FoldScaleAxis", "FuseOps"] == tvmc.common.parse_pass_list_str("FoldScaleAxis,FuseOps") + assert [""] == parse_pass_list_str("") + assert ["FoldScaleAxis", "FuseOps"] == parse_pass_list_str("FoldScaleAxis,FuseOps") with pytest.raises(argparse.ArgumentTypeError) as ate: - tvmc.common.parse_pass_list_str("MyYobaPass,MySuperYobaPass,FuseOps") + parse_pass_list_str("MyYobaPass,MySuperYobaPass,FuseOps") assert "MyYobaPass" in str(ate.value) assert "MySuperYobaPass" in str(ate.value) diff --git a/tests/python/driver/tvmc/test_registry_options.py b/tests/python/driver/tvmc/test_registry_options.py index 458d0a88d1f7..dbd7cc050091 100644 --- a/tests/python/driver/tvmc/test_registry_options.py +++ b/tests/python/driver/tvmc/test_registry_options.py @@ -19,7 +19,7 @@ import pytest -from tvm.driver.tvmc.common import TVMCException +from tvm.driver.tvmc import TVMCException from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity from tvm.relay.backend import Executor diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py index 2ce363ab5911..30ce2c6f2191 100644 --- a/tests/python/driver/tvmc/test_runner.py +++ b/tests/python/driver/tvmc/test_runner.py @@ -48,7 +48,7 @@ def test_generate_tensor_data_random(): def test_generate_tensor_data__type_unknown(): - with pytest.raises(tvmc.common.TVMCException) as e: + with pytest.raises(tvmc.TVMCException) as e: tvmc.runner.generate_tensor_data((2, 3), "float32", "whatever") diff --git a/tests/python/driver/tvmc/test_shape_parser.py b/tests/python/driver/tvmc/test_shape_parser.py index f49d89ac7c0f..1e3cde12928a 100644 --- a/tests/python/driver/tvmc/test_shape_parser.py +++ b/tests/python/driver/tvmc/test_shape_parser.py @@ -19,19 +19,19 @@ import pytest -from tvm.driver import tvmc +from tvm.driver.tvmc.shape_parser import parse_shape_string def test_shape_parser(): # Check that a valid input is parsed correctly shape_string = "input:[10,10,10]" - shape_dict = tvmc.common.parse_shape_string(shape_string) + shape_dict = parse_shape_string(shape_string) assert shape_dict == {"input": [10, 10, 10]} def test_alternate_syntax(): shape_string = "input:0:[10,10,10] input2:[20,20,20,20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) + shape_dict = parse_shape_string(shape_string) assert shape_dict == {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]} @@ -44,14 +44,14 @@ def test_alternate_syntax(): ], ) def test_alternate_syntaxes(shape_string): - shape_dict = tvmc.common.parse_shape_string(shape_string) + shape_dict = parse_shape_string(shape_string) assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} def test_negative_dimensions(): # Check that negative dimensions parse to Any correctly. shape_string = "input:[-1,3,224,224]" - shape_dict = tvmc.common.parse_shape_string(shape_string) + shape_dict = parse_shape_string(shape_string) # Convert to strings to allow comparison with Any. assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" @@ -59,7 +59,7 @@ def test_negative_dimensions(): def test_multiple_valid_gpu_inputs(): # Check that multiple valid gpu inputs are parsed correctly. shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]" - shape_dict = tvmc.common.parse_shape_string(shape_string) + shape_dict = parse_shape_string(shape_string) expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}" assert str(shape_dict) == expected @@ -67,19 +67,19 @@ def test_multiple_valid_gpu_inputs(): def test_invalid_pattern(): shape_string = "input:[a,10]" with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) + parse_shape_string(shape_string) def test_invalid_separators(): shape_string = "input:5,10 input2:10,10" with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) + parse_shape_string(shape_string) def test_invalid_colon(): shape_string = "gpu_0/data_0:5,10 :test:10,10" with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) + parse_shape_string(shape_string) @pytest.mark.parametrize( @@ -93,11 +93,11 @@ def test_invalid_colon(): ) def test_invalid_slashes(shape_string): with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) + parse_shape_string(shape_string) def test_dot(): # Check dot in input name shape_string = "input.1:[10,10,10]" - shape_dict = tvmc.common.parse_shape_string(shape_string) + shape_dict = parse_shape_string(shape_string) assert shape_dict == {"input.1": [10, 10, 10]} diff --git a/tests/python/driver/tvmc/test_target.py b/tests/python/driver/tvmc/test_target.py index 06db5c47ea7e..532ecbeb0a1a 100644 --- a/tests/python/driver/tvmc/test_target.py +++ b/tests/python/driver/tvmc/test_target.py @@ -17,33 +17,32 @@ import pytest -from tvm.driver import tvmc - -from tvm.driver.tvmc.common import TVMCException +from tvm.driver.tvmc import TVMCException +from tvm.driver.tvmc.target import target_from_cli, tokenize_target, parse_target def test_target_from_cli__error_duplicate(): with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("llvm, llvm") + _ = target_from_cli("llvm, llvm") def test_target_invalid_more_than_two_tvm_targets(): with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("cuda, opencl, llvm") + _ = target_from_cli("cuda, opencl, llvm") def test_target_from_cli__error_target_not_found(): with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("invalidtarget") + _ = target_from_cli("invalidtarget") def test_target_from_cli__error_no_tvm_target(): with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("ethos-n77") + _ = target_from_cli("ethos-n77") def test_target_two_tvm_targets(): - tvm_target, extra_targets = tvmc.common.target_from_cli( + tvm_target, extra_targets = target_from_cli( "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu" ) @@ -55,7 +54,7 @@ def test_target_two_tvm_targets(): def test_tokenize_target_with_opts(): - tokens = tvmc.common.tokenize_target("foo -opt1=value1 --flag, bar -opt2=value2") + tokens = tokenize_target("foo -opt1=value1 --flag, bar -opt2=value2") expected_tokens = ["foo", "-opt1=value1", "--flag", ",", "bar", "-opt2=value2"] assert len(tokens) == len(expected_tokens) @@ -63,7 +62,7 @@ def test_tokenize_target_with_opts(): def test_tokenize_target_with_plus_sign(): - tokens = tvmc.common.tokenize_target("foo -opt1=+value1 --flag, bar -opt2=test,+v") + tokens = tokenize_target("foo -opt1=+value1 --flag, bar -opt2=test,+v") expected_tokens = ["foo", "-opt1=+value1", "--flag", ",", "bar", "-opt2=test,+v"] assert len(tokens) == len(expected_tokens) @@ -71,7 +70,7 @@ def test_tokenize_target_with_plus_sign(): def test_tokenize_target_with_commas(): - tokens = tvmc.common.tokenize_target("foo -opt1=v,a,l,u,e,1 --flag") + tokens = tokenize_target("foo -opt1=v,a,l,u,e,1 --flag") expected_tokens = ["foo", "-opt1=v,a,l,u,e,1", "--flag"] assert len(tokens) == len(expected_tokens) @@ -79,7 +78,7 @@ def test_tokenize_target_with_commas(): def test_tokenize_target_with_commas_and_single_quotes(): - tokens = tvmc.common.tokenize_target("foo -opt1='v, a, l, u, e', bar") + tokens = tokenize_target("foo -opt1='v, a, l, u, e', bar") expected_tokens = ["foo", "-opt1='v, a, l, u, e'", ",", "bar"] assert len(tokens) == len(expected_tokens) @@ -87,7 +86,7 @@ def test_tokenize_target_with_commas_and_single_quotes(): def test_tokenize_target_with_commas_and_double_quotes(): - tokens = tvmc.common.tokenize_target('foo -opt1="v, a, l, u, e", bar') + tokens = tokenize_target('foo -opt1="v, a, l, u, e", bar') expected_tokens = ["foo", '-opt1="v, a, l, u, e"', ",", "bar"] assert len(tokens) == len(expected_tokens) @@ -95,7 +94,7 @@ def test_tokenize_target_with_commas_and_double_quotes(): def test_tokenize_target_with_dashes(): - tokens = tvmc.common.tokenize_target("foo-bar1 -opt-1=t-e-s-t, baz") + tokens = tokenize_target("foo-bar1 -opt-1=t-e-s-t, baz") expected_tokens = ["foo-bar1", "-opt-1=t-e-s-t", ",", "baz"] assert len(tokens) == len(expected_tokens) @@ -103,7 +102,7 @@ def test_tokenize_target_with_dashes(): def test_parse_single_target_with_opts(): - targets = tvmc.common.parse_target("llvm -device=arm_cpu -mattr=+fp") + targets = parse_target("llvm -device=arm_cpu -mattr=+fp") assert len(targets) == 1 assert "device" in targets[0]["opts"] @@ -111,7 +110,7 @@ def test_parse_single_target_with_opts(): def test_parse_multiple_target(): - targets = tvmc.common.parse_target("compute-library, llvm -device=arm_cpu") + targets = parse_target("compute-library, llvm -device=arm_cpu") assert len(targets) == 2 assert "compute-library" == targets[0]["name"] @@ -120,7 +119,7 @@ def test_parse_multiple_target(): def test_parse_hybrid_target(): """Hybrid Target and external codegen""" - targets = tvmc.common.parse_target( + targets = parse_target( "cmsis-nn -accelerator_config=ethos-u55-256, llvm -device=arm_cpu --system-lib" ) @@ -132,9 +131,9 @@ def test_parse_hybrid_target(): def test_parse_quotes_and_separators_on_options(): - targets_no_quote = tvmc.common.parse_target("foo -option1=+v1.0x,+value,+bar") - targets_single_quote = tvmc.common.parse_target("foo -option1='+v1.0x,+value'") - targets_double_quote = tvmc.common.parse_target('foo -option1="+v1.0x,+value"') + targets_no_quote = parse_target("foo -option1=+v1.0x,+value,+bar") + targets_single_quote = parse_target("foo -option1='+v1.0x,+value'") + targets_double_quote = parse_target('foo -option1="+v1.0x,+value"') assert len(targets_no_quote) == 1 assert "+v1.0x,+value,+bar" == targets_no_quote[0]["opts"]["option1"] @@ -147,7 +146,7 @@ def test_parse_quotes_and_separators_on_options(): def test_parse_multiple_target_with_opts_ethos_n77(): - targets = tvmc.common.parse_target("ethos-n77 -myopt=value, llvm -device=arm_cpu --system-lib") + targets = parse_target("ethos-n77 -myopt=value, llvm -device=arm_cpu --system-lib") assert len(targets) == 2 assert "ethos-n77" == targets[0]["name"] @@ -157,7 +156,7 @@ def test_parse_multiple_target_with_opts_ethos_n77(): def test_parse_multiple_target_with_opts_ethos_n78(): - targets = tvmc.common.parse_target("ethos-n78 -myopt=value, llvm -device=arm_cpu --system-lib") + targets = parse_target("ethos-n78 -myopt=value, llvm -device=arm_cpu --system-lib") assert len(targets) == 2 assert "ethos-n78" == targets[0]["name"] diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py index b592d504fe7f..1bcad488c9dc 100644 --- a/tests/python/driver/tvmc/test_target_options.py +++ b/tests/python/driver/tvmc/test_target_options.py @@ -19,9 +19,8 @@ import pytest -from tvm.driver import tvmc -from tvm.driver.tvmc.common import TVMCException -from tvm.driver.tvmc.target import generate_target_args, reconstruct_target_args +from tvm.driver.tvmc import TVMCException +from tvm.driver.tvmc.target import generate_target_args, reconstruct_target_args, target_from_cli def test_target_to_argparse(): @@ -53,13 +52,13 @@ def test_skip_target_from_codegen(): def test_target_recombobulation_single(): - tvm_target, _ = tvmc.common.target_from_cli("llvm", {"llvm": {"mcpu": "cortex-m3"}}) + tvm_target, _ = target_from_cli("llvm", {"llvm": {"mcpu": "cortex-m3"}}) assert str(tvm_target) == "llvm -keys=cpu -link-params=0 -mcpu=cortex-m3" def test_target_recombobulation_many(): - tvm_target, _ = tvmc.common.target_from_cli( + tvm_target, _ = target_from_cli( "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu", {"llvm": {"mcpu": "cortex-m3"}, "opencl": {"max_num_threads": 404}}, ) @@ -75,7 +74,7 @@ def test_error_if_target_missing(): TVMCException, match="Passed --target-opencl-max_num_threads but did not specify opencl target", ): - tvmc.common.target_from_cli( + target_from_cli( "llvm", {"opencl": {"max_num_threads": 404}}, ) diff --git a/tests/python/driver/tvmc/test_tracker.py b/tests/python/driver/tvmc/test_tracker.py index 2ca0fae8f45e..8734ad5c421f 100644 --- a/tests/python/driver/tvmc/test_tracker.py +++ b/tests/python/driver/tvmc/test_tracker.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from tvm.driver import tvmc +from tvm.driver.tvmc.tracker import tracker_host_port_from_cli def test_tracker_host_port_from_cli__hostname_port(): @@ -23,7 +23,7 @@ def test_tracker_host_port_from_cli__hostname_port(): expected_host = "1.2.3.4" expected_port = 9090 - actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + actual_host, actual_port = tracker_host_port_from_cli(input_str) assert expected_host == actual_host assert expected_port == actual_port @@ -32,7 +32,7 @@ def test_tracker_host_port_from_cli__hostname_port(): def test_tracker_host_port_from_cli__hostname_port__empty(): input_str = "" - actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + actual_host, actual_port = tracker_host_port_from_cli(input_str) assert actual_host is None assert actual_port is None @@ -43,7 +43,7 @@ def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090(): expected_host = "1.2.3.4" expected_port = 9090 - actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + actual_host, actual_port = tracker_host_port_from_cli(input_str) assert expected_host == actual_host assert expected_port == actual_port