diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 8a2bbcbd0121..e24f8ec876ef 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -201,6 +201,12 @@ class TargetKindRegEntry { * \return The entry names. */ TVM_DLL static Array ListTargetKinds(); + /*! + * \brief Get all supported option names and types for a given Target kind. + * \return Map of option name to type + */ + TVM_DLL static Map ListTargetKindOptions(const TargetKind& kind); + /*! * \brief Register or get a new entry. * \param target_kind_name The name of the TargetKind. diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index dab855abfb11..92d13a99acd5 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -21,7 +21,7 @@ import logging import time from copy import deepcopy -from typing import Optional, Dict, List, Union +from typing import Any, Optional, Dict, List, Union from urllib.parse import urlparse @@ -38,6 +38,7 @@ from .common import TVMCException from .main import register_parser from .model import TVMCModel +from .target import generate_target_args, reconstruct_target_args # pylint: disable=invalid-name @@ -106,16 +107,14 @@ def add_tune_parser(subparsers): help="hostname (required) and port (optional, defaults to 9090) of the RPC tracker, " "e.g. '192.168.0.100:9999'", ) - parser.add_argument( - "--target", - help="compilation target as plain string, inline JSON or path to a JSON file", - required=True, - ) + + generate_target_args(parser) parser.add_argument( "--target-host", help="the host compilation target, defaults to 'llvm'", default="llvm", ) + parser.add_argument("--timeout", type=int, default=10, help="compilation timeout, in seconds") parser.add_argument( "--trials", @@ -286,6 +285,7 @@ def drive_tune(args): hardware_params=hardware_params, include_simple_tasks=args.include_simple_tasks, log_estimated_latency=args.log_estimated_latency, + additional_target_options=reconstruct_target_args(args), ) @@ -311,6 +311,7 @@ def tune_model( hardware_params: Optional[HardwareParams] = None, include_simple_tasks: bool = False, log_estimated_latency: bool = False, + additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, ): """Use tuning to automatically optimize the functions in a model. @@ -367,13 +368,15 @@ def tune_model( the autoscheduler. log_estimated_latency : bool, optional If using the autoscheduler, write the estimated latency at each step of tuning to file. + additional_target_options: Optional[Dict[str, Dict[str, Any]]] + Additional target options in a dictionary to combine with initial Target arguments Returns ------- tuning_records : str The path to the produced tuning log file. """ - target, extra_targets = common.target_from_cli(target) + target, extra_targets = common.target_from_cli(target, additional_target_options) target, target_host = Target.check_and_update_host_consist(target, target_host) # TODO(jwfromm) Remove this deepcopy once AlterOpLayout bug that mutates source # model is fixed. For now, creating a clone avoids the issue. diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 9ef2f6f1fbfa..f4bc3ec027d7 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -80,7 +80,7 @@ def convert_graph_layout(mod, desired_layout): ) -def validate_targets(parse_targets): +def validate_targets(parse_targets, additional_target_options=None): """ Apply a series of validations in the targets provided via CLI. """ @@ -104,6 +104,15 @@ def validate_targets(parse_targets): f"Found: {verbose_tvm_targets}." ) + if additional_target_options is not None: + for target_name in additional_target_options: + if not any([target for target in parse_targets if target["name"] == target_name]): + first_option = list(additional_target_options[target_name].keys())[0] + raise TVMCException( + f"Passed --target-{target_name}-{first_option}" + f" but did not specify {target_name} target" + ) + def tokenize_target(target): """ @@ -261,7 +270,21 @@ def is_inline_json(target): return False -def target_from_cli(target): +def _combine_target_options(target, additional_target_options=None): + if additional_target_options is None: + return target + if target["name"] in additional_target_options: + target["opts"].update(additional_target_options[target["name"]]) + return target + + +def _recombobulate_target(target): + name = target["name"] + opts = " ".join([f"-{key}={value}" for key, value in target["opts"].items()]) + return f"{name} {opts}" + + +def target_from_cli(target, additional_target_options=None): """ Create a tvm.target.Target instance from a command line interface (CLI) string. @@ -272,6 +295,10 @@ def target_from_cli(target): compilation target as plain string, inline JSON or path to a JSON file + additional_target_options: Optional[Dict[str, Dict[str,str]]] + dictionary of additional target options to be + combined with parsed targets + Returns ------- tvm.target.Target @@ -298,18 +325,22 @@ def target_from_cli(target): except ValueError as ex: raise TVMCException(f"Error parsing target string '{target}'.\nThe error was: {ex}") - validate_targets(parsed_targets) - tvm_targets = [t for t in parsed_targets if t["is_tvm_target"]] + validate_targets(parsed_targets, additional_target_options) + tvm_targets = [ + _combine_target_options(t, additional_target_options) + for t in parsed_targets + if t["is_tvm_target"] + ] # Validated target strings have 1 or 2 tvm targets, otherwise # `validate_targets` above will fail. if len(tvm_targets) == 1: - target = tvm_targets[0]["raw"] + target = _recombobulate_target(tvm_targets[0]) target_host = None else: assert len(tvm_targets) == 2 - target = tvm_targets[0]["raw"] - target_host = tvm_targets[1]["raw"] + target = _recombobulate_target(tvm_targets[0]) + target_host = _recombobulate_target(tvm_targets[1]) extra_targets = [t for t in parsed_targets if not t["is_tvm_target"]] diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 9eb85a4934cb..7623a141c27a 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -19,7 +19,7 @@ """ import logging import os.path -from typing import Optional, Dict, List, Union, Callable +from typing import Any, Optional, Dict, List, Union, Callable from pathlib import Path import tvm @@ -30,6 +30,7 @@ from . import common, composite_target, frontends from .model import TVMCModel, TVMCPackage from .main import register_parser +from .target import generate_target_args, reconstruct_target_args # pylint: disable=invalid-name @@ -91,11 +92,7 @@ def add_compile_parser(subparsers): "times, each one to set one configuration value, " "e.g. '--pass-config relay.backend.use_auto_scheduler=0'.", ) - parser.add_argument( - "--target", - help="compilation targets as comma separated string, inline JSON or path to a JSON file.", - required=True, - ) + generate_target_args(parser) parser.add_argument( "--tuning-records", metavar="PATH", @@ -154,6 +151,7 @@ def drive_compile(args): desired_layout=args.desired_layout, disabled_pass=args.disabled_pass, pass_context_configs=args.pass_config, + additional_target_options=reconstruct_target_args(args), ) return 0 @@ -172,6 +170,7 @@ def compile_model( desired_layout: Optional[str] = None, disabled_pass: Optional[str] = None, pass_context_configs: Optional[List[str]] = None, + additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, ): """Compile a model from a supported framework into a TVM module. @@ -215,6 +214,8 @@ def compile_model( pass_context_configs: list[str], optional List of strings containing a set of configurations to be passed to the PassContext. + additional_target_options: Optional[Dict[str, Dict[str, Any]]] + Additional target options in a dictionary to combine with initial Target arguments Returns @@ -230,7 +231,7 @@ def compile_model( if desired_layout: mod = common.convert_graph_layout(mod, desired_layout) - tvm_target, extra_targets = common.target_from_cli(target) + tvm_target, extra_targets = common.target_from_cli(target, additional_target_options) tvm_target, target_host = Target.check_and_update_host_consist(tvm_target, target_host) for codegen_from_cli in extra_targets: diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py new file mode 100644 index 000000000000..7a078b8be087 --- /dev/null +++ b/python/tvm/driver/tvmc/target.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This file contains functions for processing target inputs for the TVMC CLI +""" + +from tvm.target import Target + +# We can't tell the type inside an Array but all current options are strings so +# it can default to that. Bool is used alongside Integer but aren't distinguished +# between as both are represented by IntImm +INTERNAL_TO_NATIVE_TYPE = {"runtime.String": str, "IntImm": int, "Array": str} +INTERNAL_TO_HELP = {"runtime.String": " string", "IntImm": "", "Array": " options"} + + +def _generate_target_kind_args(parser, kind): + target_group = parser.add_argument_group(f"target {kind.name}") + for target_option, target_type in kind.options.items(): + if target_type in INTERNAL_TO_NATIVE_TYPE: + target_group.add_argument( + f"--target-{kind.name}-{target_option}", + type=INTERNAL_TO_NATIVE_TYPE[target_type], + help=f"target {kind.name} {target_option}{INTERNAL_TO_HELP[target_type]}", + ) + + +def generate_target_args(parser): + """Walks through the TargetKind registry and generates arguments for each Target's options""" + parser.add_argument( + "--target", + help="compilation target as plain string, inline JSON or path to a JSON file", + required=True, + ) + target_kinds = Target.list_kinds() + for target_kind in target_kinds: + target = Target(target_kind) + _generate_target_kind_args(parser, target.kind) + + +def _reconstruct_target_kind_args(args, kind): + kind_options = {} + for target_option, target_type in kind.options.items(): + if target_type in INTERNAL_TO_NATIVE_TYPE: + var_name = f"target_{kind.name}_{target_option.replace('-', '_')}" + option_value = getattr(args, var_name) + if option_value is not None: + kind_options[target_option] = getattr(args, var_name) + return kind_options + + +def reconstruct_target_args(args): + """Reconstructs the target options from the arguments""" + target_kinds = Target.list_kinds() + reconstructed = {} + for target_kind in target_kinds: + target = Target(target_kind) + kind_options = _reconstruct_target_kind_args(args, target.kind) + if kind_options: + reconstructed[target.kind.name] = kind_options + return reconstructed diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 4e5826f5b2a2..9af09296e9cc 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -31,6 +31,11 @@ class TargetKind(Object): """Kind of a compilation target""" + @property + def options(self): + """Returns the dict of available option names and types""" + return dict(_ffi_api.ListTargetKindOptions(self)) + @tvm._ffi.register_object class Target(Object): diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index d719386d204b..7cd329f83738 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -49,6 +49,14 @@ Array TargetKindRegEntry::ListTargetKinds() { return TargetKindRegistry::Global()->ListAllNames(); } +Map TargetKindRegEntry::ListTargetKindOptions(const TargetKind& target_kind) { + Map options; + for (const auto& kv : target_kind->key2vtype_) { + options.Set(kv.first, kv.second.type_key); + } + return options; +} + TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const String& target_kind_name) { return TargetKindRegistry::Global()->RegisterOrGet(target_kind_name); } @@ -359,5 +367,7 @@ TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option>("de /********** Registry **********/ TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds); +TVM_REGISTER_GLOBAL("target.ListTargetKindOptions") + .set_body_typed(TargetKindRegEntry::ListTargetKindOptions); } // namespace tvm diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 2e8ba11c0262..6106eb2225e1 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -152,8 +152,18 @@ TEST(TargetCreation, DeduplicateKeys) { ICHECK_EQ(target->GetAttr("link-params"), false); } -TEST(TargetKindRegistryListTargetKinds, Basic) { +TEST(TargetKindRegistry, ListTargetKinds) { Array names = TargetKindRegEntry::ListTargetKinds(); ICHECK_EQ(names.empty(), false); ICHECK_EQ(std::count(std::begin(names), std::end(names), "llvm"), 1); } + +TEST(TargetKindRegistry, ListTargetOptions) { + TargetKind llvm = TargetKind::Get("llvm").value(); + Map attrs = TargetKindRegEntry::ListTargetKindOptions(llvm); + ICHECK_EQ(attrs.empty(), false); + + ICHECK_EQ(attrs["mattr"], "Array"); + ICHECK_EQ(attrs["mcpu"], "runtime.String"); + ICHECK_EQ(attrs["system-lib"], "IntImm"); +} diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 2e4687fb7985..1d93d73256bc 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -397,7 +397,7 @@ def test_compile_tflite_module_with_external_codegen_cmsisnn( tvmc_package = tvmc.compiler.compile_model( tvmc_model, - target=f"cmsis-nn, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 --executor=aot", + target=f"cmsis-nn, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 -executor=aot", output_format="mlf", package_path=output_file_name, pass_context_configs=["tir.disable_vectorize=true"], @@ -455,7 +455,7 @@ def test_compile_tflite_module_with_external_codegen_ethosu( tvmc_package = tvmc.compiler.compile_model( tvmc_model, - target=f"ethos-u -accelerator_config={accel_type}, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 --executor=aot", + target=f"ethos-u -accelerator_config={accel_type}, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 -executor=aot", output_format="mlf", package_path=output_file_name, pass_context_configs=["tir.disable_vectorize=true"], diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index 569c42020817..4d2fb56c5d4e 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -14,11 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os -import tarfile import pytest +import tvm from tvm.ir.module import IRModule from tvm.driver import tvmc @@ -229,3 +228,128 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): model_format="pytorch", shape_dict={"input": [1, 3, 224, 224]}, ) + + +def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + before = tvmc_model.mod + + expected_layout = "NCHW" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NHWC" + and node.attrs.dst_layout == "NCHW" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert any(layout_transform_calls), "Expected 'layout_transform NHWC->NCHW' not found" + + +def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50): + # some CI environments wont offer ONNX, so skip in case it is not present + pytest.importorskip("onnx") + + tvmc_model = tvmc.frontends.load_model(onnx_resnet50) + before = tvmc_model.mod + + expected_layout = "NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NHWC" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" + + +def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50): + # some CI environments wont offer Paddle, so skip in case it is not present + pytest.importorskip("paddle") + + tvmc_model = tvmc.frontends.load_model(paddle_resnet50, "paddle") + before = tvmc_model.mod + + expected_layout = "NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NHWC" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" + + +def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant): + # some CI environments wont offer TFLite, so skip in case it is not present + pytest.importorskip("tflite") + + tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) + before = tvmc_model.mod + + expected_layout = "NHWC" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NHWC" + and node.attrs.dst_layout == "NHWC" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" + + +def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50): + # some CI environments wont offer ONNX, so skip in case it is not present + pytest.importorskip("onnx") + + tvmc_model = tvmc.frontends.load_model(onnx_resnet50) + before = tvmc_model.mod + + expected_layout = "NCHW" + after = tvmc.common.convert_graph_layout(before, expected_layout) + + layout_transform_calls = [] + + def _is_layout_transform(node): + if isinstance(node, tvm.relay.expr.Call): + layout_transform_calls.append( + node.op.name == "layout_transform" + and node.attrs.src_layout == "NCHW" + and node.attrs.dst_layout == "NCHW" + ) + + tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) + + assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" diff --git a/tests/python/driver/tvmc/test_mlf.py b/tests/python/driver/tvmc/test_mlf.py index 0426f5678153..11306bd58848 100644 --- a/tests/python/driver/tvmc/test_mlf.py +++ b/tests/python/driver/tvmc/test_mlf.py @@ -27,7 +27,7 @@ @pytest.mark.parametrize( - "target,pass_configs", [["llvm", []], ["c --executor=aot", ["tir.disable_vectorize=1"]]] + "target,pass_configs", [["llvm", []], ["c -executor=aot", ["tir.disable_vectorize=1"]]] ) def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory, target, pass_configs): pytest.importorskip("tflite") @@ -114,7 +114,7 @@ def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_compile tflite_compiled_model_mlf = tflite_compile_model( tflite_mobilenet_v1_1_quant, - target="c --executor=aot", + target="c -executor=aot", output_format="mlf", pass_context_configs=["tir.disable_vectorize=1"], ) diff --git a/tests/python/driver/tvmc/test_pass_config.py b/tests/python/driver/tvmc/test_pass_config.py new file mode 100644 index 000000000000..d8ffd7d4d521 --- /dev/null +++ b/tests/python/driver/tvmc/test_pass_config.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from tvm.contrib.target.vitis_ai import vitis_ai_available +from tvm.driver import tvmc + +from tvm.driver.tvmc.common import TVMCException + + +def test_config_invalid_format(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"]) + + +def test_config_missing_from_tvm(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"]) + + +def test_config_unsupported_tvmc_config(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs(["tir.LoopPartition=value"]) + + +def test_config_empty(): + with pytest.raises(TVMCException): + _ = tvmc.common.parse_configs([""]) + + +def test_config_valid_config_bool(): + configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"]) + + assert len(configs) == 1 + assert "relay.backend.use_auto_scheduler" in configs.keys() + assert configs["relay.backend.use_auto_scheduler"] == True + + +@pytest.mark.skipif( + not vitis_ai_available(), + reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", +) +def test_config_valid_multiple_configs(): + configs = tvmc.common.parse_configs( + [ + "relay.backend.use_auto_scheduler=false", + "tir.detect_global_barrier=10", + "relay.ext.vitis_ai.options.build_dir=mystring", + ] + ) + + assert len(configs) == 3 + assert "relay.backend.use_auto_scheduler" in configs.keys() + assert configs["relay.backend.use_auto_scheduler"] == False + assert "tir.detect_global_barrier" in configs.keys() + assert configs["tir.detect_global_barrier"] == 10 + assert "relay.ext.vitis_ai.options.build_dir" in configs.keys() + assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring" diff --git a/tests/python/driver/tvmc/test_common.py b/tests/python/driver/tvmc/test_pass_list.py similarity index 97% rename from tests/python/driver/tvmc/test_common.py rename to tests/python/driver/tvmc/test_pass_list.py index 5cac6a1378a5..de50b04f415a 100644 --- a/tests/python/driver/tvmc/test_common.py +++ b/tests/python/driver/tvmc/test_pass_list.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import argparse import pytest from tvm.driver import tvmc -def test_common_parse_pass_list_str(): +def test_parse_pass_list_str(): assert [""] == tvmc.common.parse_pass_list_str("") assert ["FoldScaleAxis", "FuseOps"] == tvmc.common.parse_pass_list_str("FoldScaleAxis,FuseOps") diff --git a/tests/python/driver/tvmc/test_shape_parser.py b/tests/python/driver/tvmc/test_shape_parser.py new file mode 100644 index 000000000000..c021078630ed --- /dev/null +++ b/tests/python/driver/tvmc/test_shape_parser.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse + +import pytest + +from tvm.driver import tvmc + + +def test_shape_parser(): + # Check that a valid input is parsed correctly + shape_string = "input:[10,10,10]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input": [10, 10, 10]} + + +def test_alternate_syntax(): + shape_string = "input:0:[10,10,10] input2:[20,20,20,20]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]} + + +@pytest.mark.parametrize( + "shape_string", + [ + "input:[10,10,10] input2:[20,20,20,20]", + "input: [10, 10, 10] input2: [20, 20, 20, 20]", + "input:[10,10,10],input2:[20,20,20,20]", + ], +) +def test_alternate_syntaxes(shape_string): + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} + + +def test_negative_dimensions(): + # Check that negative dimensions parse to Any correctly. + shape_string = "input:[-1,3,224,224]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + # Convert to strings to allow comparison with Any. + assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" + + +def test_multiple_valid_gpu_inputs(): + # Check that multiple valid gpu inputs are parsed correctly. + shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}" + assert str(shape_dict) == expected + + +def test_invalid_pattern(): + shape_string = "input:[a,10]" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + + +def test_invalid_separators(): + shape_string = "input:5,10 input2:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + + +def test_invalid_colon(): + shape_string = "gpu_0/data_0:5,10 :test:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + + +@pytest.mark.parametrize( + "shape_string", + [ + "gpu_0/data_0:5,10 /:10,10", + "gpu_0/data_0:5,10 data/:10,10", + "gpu_0/data_0:5,10 /data:10,10", + "gpu_0/invalid/data_0:5,10 data_1:10,10", + ], +) +def test_invalid_slashes(shape_string): + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) diff --git a/tests/python/driver/tvmc/test_target.py b/tests/python/driver/tvmc/test_target.py new file mode 100644 index 000000000000..afb099f3add6 --- /dev/null +++ b/tests/python/driver/tvmc/test_target.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from tvm.driver import tvmc + +from tvm.driver.tvmc.common import TVMCException + + +def test_target_from_cli__error_duplicate(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("llvm, llvm") + + +def test_target_invalid_more_than_two_tvm_targets(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("cuda, opencl, llvm") + + +def test_target_from_cli__error_target_not_found(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("invalidtarget") + + +def test_target_from_cli__error_no_tvm_target(): + with pytest.raises(TVMCException): + _ = tvmc.common.target_from_cli("ethos-n77") + + +def test_target_two_tvm_targets(): + tvm_target, extra_targets = tvmc.common.target_from_cli( + "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu" + ) + + assert "opencl" in str(tvm_target) + assert "llvm" in str(tvm_target.host) + + # No extra targets + assert 0 == len(extra_targets) + + +def test_tokenize_target_with_opts(): + tokens = tvmc.common.tokenize_target("foo -opt1=value1 --flag, bar -opt2=value2") + expected_tokens = ["foo", "-opt1=value1", "--flag", ",", "bar", "-opt2=value2"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_plus_sign(): + tokens = tvmc.common.tokenize_target("foo -opt1=+value1 --flag, bar -opt2=test,+v") + expected_tokens = ["foo", "-opt1=+value1", "--flag", ",", "bar", "-opt2=test,+v"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_commas(): + tokens = tvmc.common.tokenize_target("foo -opt1=v,a,l,u,e,1 --flag") + expected_tokens = ["foo", "-opt1=v,a,l,u,e,1", "--flag"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_commas_and_single_quotes(): + tokens = tvmc.common.tokenize_target("foo -opt1='v, a, l, u, e', bar") + expected_tokens = ["foo", "-opt1='v, a, l, u, e'", ",", "bar"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_commas_and_double_quotes(): + tokens = tvmc.common.tokenize_target('foo -opt1="v, a, l, u, e", bar') + expected_tokens = ["foo", '-opt1="v, a, l, u, e"', ",", "bar"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_tokenize_target_with_dashes(): + tokens = tvmc.common.tokenize_target("foo-bar1 -opt-1=t-e-s-t, baz") + expected_tokens = ["foo-bar1", "-opt-1=t-e-s-t", ",", "baz"] + + assert len(tokens) == len(expected_tokens) + assert tokens == expected_tokens + + +def test_parse_single_target_with_opts(): + targets = tvmc.common.parse_target("llvm -device=arm_cpu --system-lib") + + assert len(targets) == 1 + assert "device" in targets[0]["opts"] + assert "system-lib" in targets[0]["opts"] + + +def test_parse_multiple_target(): + targets = tvmc.common.parse_target("compute-library, llvm -device=arm_cpu --system-lib") + + assert len(targets) == 2 + assert "compute-library" == targets[0]["name"] + assert "llvm" == targets[1]["name"] + + +def test_parse_multiple_target_with_opts(): + targets = tvmc.common.parse_target("ethos-n77 -myopt=value, llvm -device=arm_cpu --system-lib") + + assert len(targets) == 2 + assert "ethos-n77" == targets[0]["name"] + assert "myopt" in targets[0]["opts"] + assert "value" == targets[0]["opts"]["myopt"] + assert "llvm" == targets[1]["name"] + + +def test_parse_quotes_and_separators_on_options(): + targets_no_quote = tvmc.common.parse_target("foo -option1=+v1.0x,+value,+bar") + targets_single_quote = tvmc.common.parse_target("foo -option1='+v1.0x,+value'") + targets_double_quote = tvmc.common.parse_target('foo -option1="+v1.0x,+value"') + + assert len(targets_no_quote) == 1 + assert "+v1.0x,+value,+bar" == targets_no_quote[0]["opts"]["option1"] + + assert len(targets_single_quote) == 1 + assert "+v1.0x,+value" == targets_single_quote[0]["opts"]["option1"] + + assert len(targets_double_quote) == 1 + assert "+v1.0x,+value" == targets_double_quote[0]["opts"]["option1"] diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py new file mode 100644 index 000000000000..f6942299b751 --- /dev/null +++ b/tests/python/driver/tvmc/test_target_options.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse + +import pytest + +from tvm.driver import tvmc +from tvm.driver.tvmc.common import TVMCException +from tvm.driver.tvmc.target import generate_target_args, reconstruct_target_args + + +def test_target_to_argparse(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, _ = parser.parse_known_args( + ["--target=llvm", "--target-llvm-mattr=+fp,+mve", "--target-llvm-mcpu=cortex-m3"] + ) + assert parsed.target == "llvm" + assert parsed.target_llvm_mcpu == "cortex-m3" + assert parsed.target_llvm_mattr == "+fp,+mve" + + +def test_mapping_target_args(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, _ = parser.parse_known_args(["--target=llvm", "--target-llvm-mcpu=cortex-m3"]) + assert reconstruct_target_args(parsed) == {"llvm": {"mcpu": "cortex-m3"}} + + +def test_target_recombobulation_single(): + tvm_target, _ = tvmc.common.target_from_cli("llvm", {"llvm": {"mcpu": "cortex-m3"}}) + + assert str(tvm_target) == "llvm -keys=cpu -link-params=0 -mcpu=cortex-m3" + + +def test_target_recombobulation_many(): + tvm_target, _ = tvmc.common.target_from_cli( + "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu", + {"llvm": {"mcpu": "cortex-m3"}, "opencl": {"max_num_threads": 404}}, + ) + + assert "-max_num_threads=404" in str(tvm_target) + assert "-device=mali" in str(tvm_target) + assert "-mtriple=aarch64-linux-gnu" in str(tvm_target.host) + assert "-mcpu=cortex-m3" in str(tvm_target.host) + + +def test_error_if_target_missing(): + with pytest.raises( + TVMCException, + match="Passed --target-opencl-max_num_threads but did not specify opencl target", + ): + tvmc.common.target_from_cli( + "llvm", + {"opencl": {"max_num_threads": 404}}, + ) diff --git a/tests/python/driver/tvmc/test_tracker.py b/tests/python/driver/tvmc/test_tracker.py new file mode 100644 index 000000000000..2ca0fae8f45e --- /dev/null +++ b/tests/python/driver/tvmc/test_tracker.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from tvm.driver import tvmc + + +def test_tracker_host_port_from_cli__hostname_port(): + input_str = "1.2.3.4:9090" + expected_host = "1.2.3.4" + expected_port = 9090 + + actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + + assert expected_host == actual_host + assert expected_port == actual_port + + +def test_tracker_host_port_from_cli__hostname_port__empty(): + input_str = "" + + actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + + assert actual_host is None + assert actual_port is None + + +def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090(): + input_str = "1.2.3.4" + expected_host = "1.2.3.4" + expected_port = 9090 + + actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) + + assert expected_host == actual_host + assert expected_port == actual_port diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py deleted file mode 100644 index bdfdb48ce6a0..000000000000 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ /dev/null @@ -1,413 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import argparse - -import pytest - -import tvm -from tvm.contrib.target.vitis_ai import vitis_ai_available -from tvm.driver import tvmc - -from tvm.driver.tvmc.common import TVMCException - - -def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant): - # some CI environments wont offer TFLite, so skip in case it is not present - pytest.importorskip("tflite") - - tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) - before = tvmc_model.mod - - expected_layout = "NCHW" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NHWC" - and node.attrs.dst_layout == "NCHW" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert any(layout_transform_calls), "Expected 'layout_transform NHWC->NCHW' not found" - - -def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50): - # some CI environments wont offer ONNX, so skip in case it is not present - pytest.importorskip("onnx") - - tvmc_model = tvmc.frontends.load_model(onnx_resnet50) - before = tvmc_model.mod - - expected_layout = "NHWC" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NCHW" - and node.attrs.dst_layout == "NHWC" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" - - -def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50): - # some CI environments wont offer Paddle, so skip in case it is not present - pytest.importorskip("paddle") - - tvmc_model = tvmc.frontends.load_model(paddle_resnet50, "paddle") - before = tvmc_model.mod - - expected_layout = "NHWC" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NCHW" - and node.attrs.dst_layout == "NHWC" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found" - - -def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant): - # some CI environments wont offer TFLite, so skip in case it is not present - pytest.importorskip("tflite") - - tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant) - before = tvmc_model.mod - - expected_layout = "NHWC" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NHWC" - and node.attrs.dst_layout == "NHWC" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" - - -def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50): - # some CI environments wont offer ONNX, so skip in case it is not present - pytest.importorskip("onnx") - - tvmc_model = tvmc.frontends.load_model(onnx_resnet50) - before = tvmc_model.mod - - expected_layout = "NCHW" - after = tvmc.common.convert_graph_layout(before, expected_layout) - - layout_transform_calls = [] - - def _is_layout_transform(node): - if isinstance(node, tvm.relay.expr.Call): - layout_transform_calls.append( - node.op.name == "layout_transform" - and node.attrs.src_layout == "NCHW" - and node.attrs.dst_layout == "NCHW" - ) - - tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform) - - assert not any(layout_transform_calls), "Unexpected 'layout_transform' call" - - -def test_tracker_host_port_from_cli__hostname_port(): - input_str = "1.2.3.4:9090" - expected_host = "1.2.3.4" - expected_port = 9090 - - actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) - - assert expected_host == actual_host - assert expected_port == actual_port - - -def test_tracker_host_port_from_cli__hostname_port__empty(): - input_str = "" - - actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) - - assert actual_host is None - assert actual_port is None - - -def test_tracker_host_port_from_cli__only_hostname__default_port_is_9090(): - input_str = "1.2.3.4" - expected_host = "1.2.3.4" - expected_port = 9090 - - actual_host, actual_port = tvmc.common.tracker_host_port_from_cli(input_str) - - assert expected_host == actual_host - assert expected_port == actual_port - - -def test_shape_parser(): - # Check that a valid input is parsed correctly - shape_string = "input:[10,10,10]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input": [10, 10, 10]} - # Check that multiple valid input shapes are parse correctly - shape_string = "input:[10,10,10] input2:[20,20,20,20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} - # Check that multiple valid input shapes with colons are parse correctly - shape_string = "input:0:[10,10,10] input2:[20,20,20,20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]} - # Check that alternate syntax parses correctly - shape_string = "input: [10, 10, 10] input2: [20, 20, 20, 20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} - shape_string = "input:[10,10,10],input2:[20,20,20,20]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} - # Check that negative dimensions parse to Any correctly. - shape_string = "input:[-1,3,224,224]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - # Convert to strings to allow comparison with Any. - assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" - # Check that multiple valid gpu inputs are parsed correctly. - shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]" - shape_dict = tvmc.common.parse_shape_string(shape_string) - expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}" - assert str(shape_dict) == expected - - # Check that invalid pattern raises expected error. - shape_string = "input:[a,10]" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with invalid separators raises error. - shape_string = "input:5,10 input2:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with a invalid slash raises error. - shape_string = "gpu_0/data_0:5,10 /:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with a invalid colon raises error. - shape_string = "gpu_0/data_0:5,10 :test:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with a invalid slash raises error. - shape_string = "gpu_0/data_0:5,10 data/:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with a invalid slash raises error. - shape_string = "gpu_0/data_0:5,10 /data:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - # Check that input with invalid slashes raises error. - shape_string = "gpu_0/invalid/data_0:5,10 data_1:10,10" - with pytest.raises(argparse.ArgumentTypeError): - tvmc.common.parse_shape_string(shape_string) - - -def test_target_from_cli__error_duplicate(): - with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("llvm, llvm") - - -def test_target_invalid_more_than_two_tvm_targets(): - with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("cuda, opencl, llvm") - - -def test_target_from_cli__error_target_not_found(): - with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("invalidtarget") - - -def test_target_from_cli__error_no_tvm_target(): - with pytest.raises(TVMCException): - _ = tvmc.common.target_from_cli("ethos-n77") - - -def test_target_two_tvm_targets(): - tvm_target, extra_targets = tvmc.common.target_from_cli( - "opencl -device=mali, llvm -mtriple=aarch64-linux-gnu" - ) - - assert "opencl" in str(tvm_target) - assert "llvm" in str(tvm_target.host) - - # No extra targets - assert 0 == len(extra_targets) - - -def test_tokenize_target_with_opts(): - tokens = tvmc.common.tokenize_target("foo -opt1=value1 --flag, bar -opt2=value2") - expected_tokens = ["foo", "-opt1=value1", "--flag", ",", "bar", "-opt2=value2"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_plus_sign(): - tokens = tvmc.common.tokenize_target("foo -opt1=+value1 --flag, bar -opt2=test,+v") - expected_tokens = ["foo", "-opt1=+value1", "--flag", ",", "bar", "-opt2=test,+v"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_commas(): - tokens = tvmc.common.tokenize_target("foo -opt1=v,a,l,u,e,1 --flag") - expected_tokens = ["foo", "-opt1=v,a,l,u,e,1", "--flag"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_commas_and_single_quotes(): - tokens = tvmc.common.tokenize_target("foo -opt1='v, a, l, u, e', bar") - expected_tokens = ["foo", "-opt1='v, a, l, u, e'", ",", "bar"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_commas_and_double_quotes(): - tokens = tvmc.common.tokenize_target('foo -opt1="v, a, l, u, e", bar') - expected_tokens = ["foo", '-opt1="v, a, l, u, e"', ",", "bar"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_tokenize_target_with_dashes(): - tokens = tvmc.common.tokenize_target("foo-bar1 -opt-1=t-e-s-t, baz") - expected_tokens = ["foo-bar1", "-opt-1=t-e-s-t", ",", "baz"] - - assert len(tokens) == len(expected_tokens) - assert tokens == expected_tokens - - -def test_parse_single_target_with_opts(): - targets = tvmc.common.parse_target("llvm -device=arm_cpu --system-lib") - - assert len(targets) == 1 - assert "device" in targets[0]["opts"] - assert "system-lib" in targets[0]["opts"] - - -def test_parse_multiple_target(): - targets = tvmc.common.parse_target("compute-library, llvm -device=arm_cpu --system-lib") - - assert len(targets) == 2 - assert "compute-library" == targets[0]["name"] - assert "llvm" == targets[1]["name"] - - -def test_parse_multiple_target_with_opts(): - targets = tvmc.common.parse_target("ethos-n77 -myopt=value, llvm -device=arm_cpu --system-lib") - - assert len(targets) == 2 - assert "ethos-n77" == targets[0]["name"] - assert "myopt" in targets[0]["opts"] - assert "value" == targets[0]["opts"]["myopt"] - assert "llvm" == targets[1]["name"] - - -def test_parse_quotes_and_separators_on_options(): - targets_no_quote = tvmc.common.parse_target("foo -option1=+v1.0x,+value,+bar") - targets_single_quote = tvmc.common.parse_target("foo -option1='+v1.0x,+value'") - targets_double_quote = tvmc.common.parse_target('foo -option1="+v1.0x,+value"') - - assert len(targets_no_quote) == 1 - assert "+v1.0x,+value,+bar" == targets_no_quote[0]["opts"]["option1"] - - assert len(targets_single_quote) == 1 - assert "+v1.0x,+value" == targets_single_quote[0]["opts"]["option1"] - - assert len(targets_double_quote) == 1 - assert "+v1.0x,+value" == targets_double_quote[0]["opts"]["option1"] - - -def test_config_invalid_format(): - with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"]) - - -def test_config_missing_from_tvm(): - with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"]) - - -def test_config_unsupported_tvmc_config(): - with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs(["tir.LoopPartition=value"]) - - -def test_config_empty(): - with pytest.raises(TVMCException): - _ = tvmc.common.parse_configs([""]) - - -def test_config_valid_config_bool(): - configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"]) - - assert len(configs) == 1 - assert "relay.backend.use_auto_scheduler" in configs.keys() - assert configs["relay.backend.use_auto_scheduler"] == True - - -@pytest.mark.skipif( - not vitis_ai_available(), - reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", -) -def test_config_valid_multiple_configs(): - configs = tvmc.common.parse_configs( - [ - "relay.backend.use_auto_scheduler=false", - "tir.detect_global_barrier=10", - "relay.ext.vitis_ai.options.build_dir=mystring", - ] - ) - - assert len(configs) == 3 - assert "relay.backend.use_auto_scheduler" in configs.keys() - assert configs["relay.backend.use_auto_scheduler"] == False - assert "tir.detect_global_barrier" in configs.keys() - assert configs["tir.detect_global_barrier"] == 10 - assert "relay.ext.vitis_ai.options.build_dir" in configs.keys() - assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring"