diff --git a/apps/bundle_deploy/build_model.py b/apps/bundle_deploy/build_model.py index 8fbc01bcf4a6..096fea2d27ad 100644 --- a/apps/bundle_deploy/build_model.py +++ b/apps/bundle_deploy/build_model.py @@ -20,15 +20,15 @@ import os from tvm import relay import tvm -from tvm import te, runtime +from tvm import runtime as tvm_runtime import logging -import json +from tvm.relay.backend import Runtime from tvm.contrib import cc as _cc -RUNTIMES = { - "c": "{name}_c.{ext}", - "c++": "{name}_cpp.{ext}", -} +RUNTIMES = [ + (Runtime("crt", {"system-lib": True}), "{name}_c.{ext}"), + (Runtime("cpp", {"system-lib": True}), "{name}_cpp.{ext}"), +] def build_module(opts): @@ -43,18 +43,16 @@ def build_module(opts): func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs ) - for runtime_name, file_format_str in RUNTIMES.items(): + for runtime, file_format_str in RUNTIMES: with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - graph, lib, params = relay.build( - func, f"llvm --runtime={runtime_name} --system-lib", params=params - ) + graph, lib, params = relay.build(func, "llvm", runtime=runtime, params=params) build_dir = os.path.abspath(opts.out_dir) if not os.path.isdir(build_dir): os.makedirs(build_dir) - ext = "tar" if runtime_name == "c" else "o" + ext = "tar" if str(runtime) == "crt" else "o" lib_file_name = os.path.join(build_dir, file_format_str.format(name="model", ext=ext)) - if runtime_name == "c": + if str(runtime) == "crt": lib.export_library(lib_file_name) else: # NOTE: at present, export_libarary will always create _another_ shared object, and you @@ -70,7 +68,7 @@ def build_module(opts): with open( os.path.join(build_dir, file_format_str.format(name="params", ext="bin")), "wb" ) as f_params: - f_params.write(runtime.save_param_dict(params)) + f_params.write(tvm_runtime.save_param_dict(params)) def build_test_module(opts): @@ -84,20 +82,21 @@ def build_test_module(opts): y_data = np.random.rand(1, 5).astype("float32") params = {"y": y_data} - for runtime_name, file_format_str in RUNTIMES.items(): + for runtime, file_format_str in RUNTIMES: with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): graph, lib, lowered_params = relay.build( tvm.IRModule.from_expr(func), - f"llvm --runtime={runtime_name} --system-lib", + "llvm", + runtime=runtime, params=params, ) build_dir = os.path.abspath(opts.out_dir) if not os.path.isdir(build_dir): os.makedirs(build_dir) - ext = "tar" if runtime_name == "c" else "o" + ext = "tar" if str(runtime) == "crt" else "o" lib_file_name = os.path.join(build_dir, file_format_str.format(name="test_model", ext=ext)) - if runtime_name == "c": + if str(runtime) == "crt": lib.export_library(lib_file_name) else: # NOTE: at present, export_libarary will always create _another_ shared object, and you @@ -113,7 +112,7 @@ def build_test_module(opts): with open( os.path.join(build_dir, file_format_str.format(name="test_params", ext="bin")), "wb" ) as f_params: - f_params.write(runtime.save_param_dict(lowered_params)) + f_params.write(tvm_runtime.save_param_dict(lowered_params)) with open( os.path.join(build_dir, file_format_str.format(name="test_data", ext="bin")), "wb" ) as fp: diff --git a/apps/microtvm/ethosu/run_demo.sh b/apps/microtvm/ethosu/run_demo.sh index 5d9efb359b24..5dd23bc12df8 100755 --- a/apps/microtvm/ethosu/run_demo.sh +++ b/apps/microtvm/ethosu/run_demo.sh @@ -130,8 +130,12 @@ curl --retry 64 -sSL ${mobilenet_url} | gunzip | tar -xvf - ./mobilenet_v1_1.0_2 # Compile model for Arm(R) Cortex(R)-M55 CPU and Ethos(TM)-U55 NPU # An alternative to using "python3 -m tvm.driver.tvmc" is to call # "tvmc" directly once TVM has been pip installed. -python3 -m tvm.driver.tvmc compile --target="ethos-u -accelerator_config=ethos-u55-256, \ - c -runtime=c --link-params -mcpu=cortex-m55 -executor=aot -interface-api=c -unpacked-api=1" \ +python3 -m tvm.driver.tvmc compile --target="ethos-u -accelerator_config=ethos-u55-256, c" \ + --target-c-mcpu=cortex-m55 \ + --runtime=crt \ + --executor=aot \ + --executor-aot-interface-api=c \ + --executor-aot-unpacked-api=1 \ --pass-config tir.disable_vectorize=1 ./mobilenet_v1_1.0_224_quant.tflite --output-format=mlf tar -xvf module.tar diff --git a/docs/arch/microtvm_design.rst b/docs/arch/microtvm_design.rst index 087b8166c226..f9c06c10b677 100644 --- a/docs/arch/microtvm_design.rst +++ b/docs/arch/microtvm_design.rst @@ -127,7 +127,7 @@ logs use it to rank measured performance (but see Future Work). Targets are currently represented as strings structured similarly to command-line arguments. An example target is shown below: - ``c -keys=arm_cpu -mcpu=cortex-m7 -link-params -model=stm32f746xx -runtime=c -system-lib=1`` + ``c -keys=arm_cpu -mcpu=cortex-m7 -model=stm32f746xx`` The relevant parts to microTVM are: @@ -135,10 +135,16 @@ The relevant parts to microTVM are: * ``-mcpu=cortex-m7``: used by TOPI to enable Cortex-M schedules, and, when the C source code generator is selected, included in the output as a comment to help identify the code and configure the downstream C compiler. - * ``-link-params``: include parameters as global constants to load from flash. - * ``-runtime=c``: build glue code to allow operators to work with the C runtime - * ``-system-lib=1``: emit a system library (i.e. which can be loaded by calling the PackedFunc - ``runtime.SystemLib``. + +Runtime and Executor configuration for microTVM +----------------------------------------------- + +When using microTVM, it's important to use the C Runtime (``Runtime('crt')``), which is the runtime that works best on micro devices rather than the more dynamic C++ Runtime. Alongside this, there are two executors which you could use in combination with the C runtime: + +* ``Executor("aot")`` - The Ahead of Time (AOT) executor precompiles the network into a runnable function which you can add directly into your micro application +* ``Executor("graph", {"link-params": True})`` - The Graph executor provides a JSON representation of your network and requires the C Runtime's system library to be generated to find functions in the function registry (``Runtime("crt", {"system-lib": True})``). ``{"link-params":True}`` enables parameters to be linked into the generated files rather than provided externally. + +These are specified when building a runtime module: ``relay.build(..., runtime=..., executor=...)``. Writing Schedules for microTVM ------------------------------ diff --git a/gallery/how_to/work_with_microtvm/micro_autotune.py b/gallery/how_to/work_with_microtvm/micro_autotune.py index d3106712aa99..394a946cf3d5 100644 --- a/gallery/how_to/work_with_microtvm/micro_autotune.py +++ b/gallery/how_to/work_with_microtvm/micro_autotune.py @@ -32,6 +32,7 @@ import pathlib import tvm +from tvm.relay.backend import Executor, Runtime #################### # Defining the model @@ -69,13 +70,15 @@ # Defining the target # ####################### # Now we define the TVM target that describes the execution environment. This looks very similar -# to target definitions from other microTVM tutorials. +# to target definitions from other microTVM tutorials. Alongside this we pick the C Runtime to code +# generate our model against. # # When running on physical hardware, choose a target and a board that # describe the hardware. There are multiple hardware targets that could be selected from # PLATFORM list in this tutorial. You can chose the platform by passing --platform argument when running # this tutorial. # +RUNTIME = Runtime("crt", {"system-lib": True}) TARGET = tvm.target.target.micro("host") # Compiling for physical hardware @@ -123,6 +126,7 @@ build_kwargs={"build_option": {"tir.disable_vectorize": True}}, do_fork=True, build_func=tvm.micro.autotvm_build_func, + runtime=RUNTIME, ) runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader) @@ -175,7 +179,7 @@ # the tuned operator. with pass_context: - lowered = tvm.relay.build(relay_mod, target=TARGET, params=params) + lowered = tvm.relay.build(relay_mod, target=TARGET, runtime=RUNTIME, params=params) temp_dir = tvm.contrib.utils.tempdir() @@ -218,7 +222,7 @@ with tvm.autotvm.apply_history_best("microtvm_autotune.log.txt"): with pass_context: - lowered_tuned = tvm.relay.build(relay_mod, target=TARGET, params=params) + lowered_tuned = tvm.relay.build(relay_mod, target=TARGET, runtime=RUNTIME, params=params) temp_dir = tvm.contrib.utils.tempdir() diff --git a/gallery/how_to/work_with_microtvm/micro_tflite.py b/gallery/how_to/work_with_microtvm/micro_tflite.py index 35b08d87b9ee..bd70fc581c5c 100644 --- a/gallery/how_to/work_with_microtvm/micro_tflite.py +++ b/gallery/how_to/work_with_microtvm/micro_tflite.py @@ -124,12 +124,9 @@ import os import numpy as np -import logging import tvm -import tvm.micro as micro from tvm.contrib.download import download_testdata -from tvm.contrib import graph_executor, utils from tvm import relay model_url = "https://people.linaro.org/~tom.gall/sine_model.tflite" @@ -179,9 +176,10 @@ # Now we create a build config for relay, turning off two options and then calling relay.build which # will result in a C source file for the selected TARGET. When running on a simulated target of the # same architecture as the host (where this Python script is executed) choose "host" below for the -# TARGET and a proper board/VM to run it (Zephyr will create the right QEMU VM based on BOARD. In -# the example below the x86 arch is selected and a x86 VM is picked up accordingly: +# TARGET, the C Runtime as the RUNTIME and a proper board/VM to run it (Zephyr will create the right +# QEMU VM based on BOARD. In the example below the x86 arch is selected and a x86 VM is picked up accordingly: # +RUNTIME = tvm.relay.backend.Runtime("crt", {"system-lib": True}) TARGET = tvm.target.target.micro("host") BOARD = "qemu_x86" # @@ -210,7 +208,7 @@ with tvm.transform.PassContext( opt_level=3, config={"tir.disable_vectorize": True}, disabled_pass=["AlterOpLayout"] ): - module = relay.build(mod, target=TARGET, params=params) + module = relay.build(mod, target=TARGET, runtime=RUNTIME, params=params) # Inspecting the compilation output diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 852c7d0d8a98..ec8c9b6c4b2c 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -301,6 +301,12 @@ class IRModuleNode : public Object { */ TVM_DLL void ImportFromStd(const String& path); + /*! + * \brief Should Link Parameters into the module + * \return Whether the Executor is configured to execute with linked parameters (Default: false) + */ + TVM_DLL Bool ShouldLinkParameters() const; + /*! * \brief The set of imported files. */ @@ -468,5 +474,27 @@ TVM_DLL String PrettyPrint(const ObjectRef& node); */ TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true, runtime::TypedPackedFunc annotate = nullptr); + +namespace attr { + +/*! + * \brief Executor targetted by the module + * + * Type: Executor + * + * \sa tvm::relay::Executor + */ +constexpr const char* kExecutor = "executor"; + +/*! + * \brief Runtime target of the module + * + * Type: Runtime + * + * \sa tvm::relay::Runtime + */ +constexpr const char* kRuntime = "runtime"; + +} // namespace attr } // namespace tvm #endif // TVM_IR_MODULE_H_ diff --git a/include/tvm/relay/executor.h b/include/tvm/relay/executor.h index 4f779e1dc0a4..4c3b1d3d8a9f 100644 --- a/include/tvm/relay/executor.h +++ b/include/tvm/relay/executor.h @@ -59,6 +59,14 @@ class ExecutorNode : public Object { /* \brief Additional attributes storing meta-data about the Executor. */ DictAttrs attrs; + /*! + * \brief Should Link Parameters into the module + * \return Whether the Executor is configured to execute modules with linked parameters + */ + Bool ShouldLinkParameters() const { + return name == "aot" || GetAttr("link-params").value_or(Bool(false)); + } + /*! * \brief Get an attribute. * @@ -114,6 +122,8 @@ class ExecutorNode : public Object { */ class Executor : public ObjectRef { public: + Executor() = default; + /*! * \brief Create a new Executor object using the registry * \throws Error if name is not registered @@ -121,7 +131,7 @@ class Executor : public ObjectRef { * \param attrs Attributes for the executor. * \return the new Executor object. */ - TVM_DLL static Executor Create(String name, Map attrs); + TVM_DLL static Executor Create(String name, Map attrs = {}); /*! * \brief List all registered Executors diff --git a/include/tvm/relay/runtime.h b/include/tvm/relay/runtime.h index cc2ea4193ff2..38b87c5c9c99 100644 --- a/include/tvm/relay/runtime.h +++ b/include/tvm/relay/runtime.h @@ -114,6 +114,8 @@ class RuntimeNode : public Object { */ class Runtime : public ObjectRef { public: + Runtime() = default; + /*! * \brief Create a new Runtime object using the registry * \throws Error if name is not registered @@ -121,7 +123,7 @@ class Runtime : public ObjectRef { * \param attrs Attributes for the Runtime. * \return the new Runtime object. */ - TVM_DLL static Runtime Create(String name, Map attrs); + TVM_DLL static Runtime Create(String name, Map attrs = {}); /*! * \brief List all registered Runtimes diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index 7299875bf28d..ed78cd689ec7 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -144,9 +144,7 @@ def _traverse_expr(node): mod = tvm.IRModule.from_expr(relay.Function(params, call)) relay.backend.te_compiler.get().clear() tracing_target = _replace_device_with_tracing(tvm_target) - build_thread = threading.Thread( - target=relay.build, args=(mod, tracing_target, None, None) - ) + build_thread = threading.Thread(target=relay.build, args=(mod, tracing_target)) build_thread.start() build_thread.join() elif isinstance(node, Var): diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index efe45daa1464..360e1a0d90bc 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -89,10 +89,18 @@ class LocalBuilder(Builder): If is callable, use it as custom build function, expect lib_format field. do_fork: bool If False, do not fork when building. Requires n_parallel=1. + runtime: Optional[Runtime] + Specify the runtime to generate artifacts for """ def __init__( - self, timeout=10, n_parallel=None, build_kwargs=None, build_func="default", do_fork=False + self, + timeout=10, + n_parallel=None, + build_kwargs=None, + build_func="default", + do_fork=False, + runtime=None, ): super(LocalBuilder, self).__init__(timeout, n_parallel, build_kwargs) @@ -105,7 +113,7 @@ def __init__( build_func = stackvm.build else: raise ValueError("Invalid build_func" + build_func) - self.build_func = _WrappedBuildFunc(build_func) + self.build_func = _WrappedBuildFunc(build_func, runtime) if not do_fork: assert n_parallel in ( None, @@ -455,7 +463,9 @@ def set_task(self, task): return server, tracker -def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None): +def _build_func_common( + measure_input, runtime=None, check_gpu=None, cuda_arch=None, build_option=None +): """Common part for building a configuration""" target, task, config = measure_input target, task.target_host = Target.check_and_update_host_consist(target, task.target_host) @@ -484,7 +494,7 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti func = vta.build(s, args, target_host=task.target_host) else: with tvm.ir.transform.PassContext(config=opts): - func = build(s, args, target_host=task.target_host) + func = build(s, args, target_host=task.target_host, runtime=runtime) return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args) @@ -499,6 +509,8 @@ class _WrappedBuildFunc: ---------- build_func : The compilation function We expect fcompile to contain an attr "output_format". + runtime : Optional[Runtime] + The runtime to generate artifacts for Returns ------- @@ -506,10 +518,11 @@ class _WrappedBuildFunc: The wrapped build function """ - def __init__(self, build_func): + def __init__(self, build_func, runtime=None): if not hasattr(build_func, "output_format"): raise AttributeError("Expect build_func to have the attribute output_format.") self.build_func = build_func + self.runtime = runtime def __call__(self, measure_input, tmp_dir, **kwargs): """ @@ -529,14 +542,13 @@ def __call__(self, measure_input, tmp_dir, **kwargs): tmp_dir, "tmp_func_%0x.%s" % (getrandbits(64), self.build_func.output_format) ) # TODO(tvm-team) consider linline _build_func_common - func, arg_info = _build_func_common(measure_input, **kwargs) + func, arg_info = _build_func_common(measure_input, self.runtime, **kwargs) if self.build_func.output_format == ".model-library-format": # Late import to preserve autoTVM with USE_MICRO OFF try: from tvm import micro # pylint: disable=import-outside-toplevel except ImportError: raise ImportError("Requires USE_MICRO") - micro.export_model_library_format(func, filename) else: func.export_library(filename, self.build_func) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index a7d998c33bc6..d2861cfe8ad6 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -46,7 +46,7 @@ def _lower(mod, target, params): with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): mod, _ = relay.optimize(mod, target, params) grc = graph_executor_codegen.GraphExecutorCodegen(None, target) - grc.codegen(mod["main"]) + grc.codegen(mod, mod["main"]) return compiler = relay.vm.VMCompiler() diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 58e2866668c5..6c19ca065fb4 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -139,6 +139,9 @@ def build( args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, target: Optional[Union[str, Target]] = None, target_host: Optional[Union[str, Target]] = None, + runtime: Optional[ + "tvm.relay.backend.Runtime" + ] = None, # Type is annotated this way to avoid cyclic dependency name: Optional[str] = "default_function", binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, ): @@ -166,6 +169,9 @@ def build( By default, llvm is used if it is enabled, otherwise a stackvm interpreter is used. + runtime : Optional[Runtime] + Runtime to generate artifacts for + name : Optional[str] The name of result function. @@ -243,18 +249,20 @@ def build( else: target_input_mod = inputs + # Because modules can be created from a variety of sources, we annotate them + # with the relevant attributes here to ensure they propagate + annotated_mods = {} for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, Target)): raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") if not isinstance(mod, tvm.IRModule): raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") + annotated_mods[tar] = mod.with_attr("runtime", runtime) - target_input_mod, target_host = Target.check_and_update_host_consist( - target_input_mod, target_host - ) + annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host) if not target_host: - for tar, mod in target_input_mod.items(): + for tar, mod in annotated_mods.items(): tar = Target(tar) device_type = ndarray.device(tar.kind.name, 0).device_type if device_type == ndarray.cpu(0).device_type: @@ -263,37 +271,30 @@ def build( if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - target_input_mod, target_host = Target.check_and_update_host_consist( - target_input_mod, target_host - ) + annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host) - rt_mod_host = _driver_ffi.preprocess_module(target_input_mod, target_host) + rt_mod_host = _driver_ffi.preprocess_module(annotated_mods, target_host) - target_input_mod, target_host = Target.check_and_update_host_consist( - target_input_mod, target_host - ) + annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host) if not isinstance(target_host, Target): target_host = Target(target_host) - if ( - target_host.attrs.get("runtime", tvm.runtime.String("c++")) == "c" - and target_host.attrs.get("system-lib", 0) == 1 - ): + + if str(runtime) == "crt" and runtime["system-lib"]: if target_host.kind.name == "c": create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) - + to_return = create_csource_crt_metadata_module([rt_mod_host], target_host, runtime) elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host, runtime) else: to_return = rt_mod_host - return OperatorModule.from_module(to_return, ir_module_by_target=target_input_mod, name=name) + return OperatorModule.from_module(to_return, ir_module_by_target=annotated_mods, name=name) class OperatorModule(Module): diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 7623a141c27a..0c693307b716 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -25,7 +25,9 @@ import tvm from tvm import autotvm, auto_scheduler from tvm import relay +from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity from tvm.target import Target +from tvm.relay.backend import Executor, Runtime from . import common, composite_target, frontends from .model import TVMCModel, TVMCPackage @@ -92,6 +94,7 @@ def add_compile_parser(subparsers): "times, each one to set one configuration value, " "e.g. '--pass-config relay.backend.use_auto_scheduler=0'.", ) + generate_target_args(parser) parser.add_argument( "--tuning-records", @@ -100,6 +103,9 @@ def add_compile_parser(subparsers): help="path to an auto-tuning log file by AutoTVM. If not presented, " "the fallback/tophub configs will be used.", ) + generate_registry_args(parser, Executor, "graph") + generate_registry_args(parser, Runtime, "cpp") + parser.add_argument("-v", "--verbose", action="count", default=0, help="increase verbosity.") # TODO (@leandron) This is a path to a physical file, but # can be improved in future to add integration with a modelzoo @@ -141,6 +147,8 @@ def drive_compile(args): compile_model( tvmc_model, args.target, + executor=reconstruct_registry_entity(args, Executor), + runtime=reconstruct_registry_entity(args, Runtime), tuning_records=args.tuning_records, package_path=args.output, cross=args.cross_compiler, @@ -160,6 +168,8 @@ def drive_compile(args): def compile_model( tvmc_model: TVMCModel, target: str, + executor: Optional[Executor] = Executor("graph"), + runtime: Optional[Runtime] = Runtime("cpp"), tuning_records: Optional[str] = None, package_path: Optional[str] = None, cross: Optional[Union[str, Callable]] = None, @@ -257,18 +267,24 @@ def compile_model( opt_level=3, config=config, disabled_pass=disabled_pass ): logger.debug("building relay graph with autoscheduler") - graph_module = relay.build(mod, target=tvm_target, params=params) + graph_module = relay.build( + mod, target=tvm_target, executor=executor, runtime=runtime, params=params + ) else: with autotvm.apply_history_best(tuning_records): with tvm.transform.PassContext( opt_level=3, config=config, disabled_pass=disabled_pass ): logger.debug("building relay graph with tuning records") - graph_module = relay.build(mod, target=tvm_target, params=params) + graph_module = relay.build( + mod, target=tvm_target, executor=executor, runtime=runtime, params=params + ) else: with tvm.transform.PassContext(opt_level=3, config=config, disabled_pass=disabled_pass): logger.debug("building relay graph (no tuning records provided)") - graph_module = relay.build(mod, target=tvm_target, params=params) + graph_module = relay.build( + mod, target=tvm_target, executor=executor, runtime=runtime, params=params + ) # Generate output dump files with sources if dump_code is None: diff --git a/python/tvm/driver/tvmc/registry.py b/python/tvm/driver/tvmc/registry.py new file mode 100644 index 000000000000..384a3bd1baf6 --- /dev/null +++ b/python/tvm/driver/tvmc/registry.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This file contains functions for processing registry based inputs for the TVMC CLI +""" + +from tvm.driver.tvmc.common import TVMCException + +# We can't tell the type inside an Array but all current options are strings so +# it can default to that. Bool is used alongside Integer but aren't distinguished +# between as both are represented by IntImm +INTERNAL_TO_NATIVE_TYPE = {"runtime.String": str, "IntImm": int, "Array": str} +INTERNAL_TO_HELP = {"runtime.String": " string", "IntImm": "", "Array": " options"} + + +def _generate_registry_option_args(parser, registry, name): + target_group = parser.add_argument_group(f"{registry.name} {name}") + for option_name, option_type in registry.list_registered_options(name).items(): + if option_type in INTERNAL_TO_NATIVE_TYPE: + target_group.add_argument( + f"--{registry.name}-{name}-{option_name}", + type=INTERNAL_TO_NATIVE_TYPE[option_type], + help=f"{registry.name.title()} {name} {option_name}{INTERNAL_TO_HELP[option_type]}", + ) + + +def generate_registry_args(parser, registry, default=None): + """Walks through the given registry and generates arguments for each of the available options""" + parser.add_argument( + f"--{registry.name}", + help=f"{registry.name.title()} to compile the model with", + required=False, + default=default, + ) + names = registry.list_registered() + for name in names: + _generate_registry_option_args(parser, registry, name) + + +def _reconstruct_registry_options(args, registry, name): + options = {} + for option, option_type in registry.list_registered_options(name).items(): + if option_type in INTERNAL_TO_NATIVE_TYPE: + var_name = f"{registry.name}_{name}_{option.replace('-', '_')}" + option_value = getattr(args, var_name) + if option_value is not None: + options[option] = option_value + return options + + +def reconstruct_registry_entity(args, registry): + """Reconstructs an entity from arguments generated from a registry""" + possible_names = registry.list_registered() + name = getattr(args, registry.name) + if name is None: + return None + + if name not in possible_names: + raise TVMCException(f'{registry.name.title()} "{name}" is not defined') + + reconstructed = { + possible_name: _reconstruct_registry_options(args, registry, possible_name) + for possible_name in possible_names + } + + for possible_name in possible_names: + if possible_name != name and reconstructed[possible_name]: + first_option = list(reconstructed[possible_name])[0] + raise TVMCException( + f"Passed --{registry.name}-{possible_name}-{first_option} " + f"but did not specify {possible_name} executor" + ) + + return registry(name, reconstructed[name]) diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 1a705b999b74..b94d50dbf20d 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -275,3 +275,38 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: return tvm._ffi.get_global_func("script.AsTVMScript")( self, tir_prefix, show_meta ) # type: ignore + + def get_attr(self, attr_key): + """Get the IRModule attribute. + + Parameters + ---------- + attr_key : str + The attribute key. + + Returns + ------- + attr_value : Any + Attribute value + """ + + return _ffi_api.Module_GetAttr(self, attr_key) + + def with_attr(self, attr_key, attr_value): + """Copy the IRModule and add an attribute to it. + + Parameters + ---------- + attr_key : str + The attribute key. + + attr_value : Object + The new attribute value. + + Returns + ------- + mod : IRModule + A new copy of the IRModule with the attribute + """ + + return _ffi_api.Module_WithAttr(self, attr_key, attr_value) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index b69fc05ed942..387248a38b68 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -270,7 +270,7 @@ def _get_inputs_and_outputs_from_module(mod): def _should_generate_interface_header(mod): - return any(target.attrs.get("interface-api") == "c" for target in mod.target.values()) + return "interface-api" in mod.executor and mod.executor["interface-api"] == "c" def _make_tar(source_dir, tar_file_path): diff --git a/python/tvm/relay/backend/executor.py b/python/tvm/relay/backend/executor.py index b3af565fe69e..9164d6a75ea3 100644 --- a/python/tvm/relay/backend/executor.py +++ b/python/tvm/relay/backend/executor.py @@ -27,6 +27,8 @@ class Executor(Object): """Executor configuration""" + name = "executor" + def __init__(self, name, options=None) -> None: if options is None: options = {} @@ -39,12 +41,15 @@ def __contains__(self, name): def __getitem__(self, name): return self._attrs[name] + def __eq__(self, other): + return str(other) == str(self) and dict(other._attrs) == dict(self._attrs) + @staticmethod - def list_executors(): + def list_registered(): """Returns a list of possible executors""" return list(_backend.ListExecutors()) @staticmethod - def list_executor_options(executor): + def list_registered_options(executor): """Returns the dict of available option names and types""" return dict(_backend.ListExecutorOptions(str(executor))) diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py index 7b96dd87604e..5f4a134270ac 100644 --- a/python/tvm/relay/backend/executor_factory.py +++ b/python/tvm/relay/backend/executor_factory.py @@ -79,6 +79,8 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule): The IR modules lowered per Target. target : tvm.Target The Target used to build this module. + executor : tvm.relay.backend.Executor + Internal representation of the Executor libmod : tvm.Module The module of the corresponding function libmod_name: str @@ -96,6 +98,7 @@ def __init__( ir_mod, lowered_ir_mods, target, + executor, libmod, libmod_name, params, @@ -105,6 +108,7 @@ def __init__( self.ir_mod = ir_mod self.lowered_ir_mods = lowered_ir_mods self.target = target + self.executor = executor self.lib = libmod self.libmod_name = libmod_name self.params = params @@ -135,6 +139,8 @@ class GraphExecutorFactoryModule(ExecutorFactoryModule): The IR module to build. target : tvm.Target The Target used to build this module. + executor : tvm.relay.backend.Executor + Internal representation of the Executor graph_json_str : the json graph to be deployed in json format output by graph compiler. The graph can contain operator(tvm_op) that points to the name of PackedFunc in the libmod. @@ -152,6 +158,7 @@ def __init__( self, ir_mod, target, + executor, graph_json_str, libmod, libmod_name, @@ -167,6 +174,7 @@ def __init__( self.ir_mod = ir_mod self.target = target + self.executor = executor self.module = fcreate(graph_json_str, libmod, libmod_name, *args) self.graph_json = graph_json_str self.lib = libmod diff --git a/python/tvm/relay/backend/graph_executor_codegen.py b/python/tvm/relay/backend/graph_executor_codegen.py index 58717a0ab482..02e6c4f61af2 100644 --- a/python/tvm/relay/backend/graph_executor_codegen.py +++ b/python/tvm/relay/backend/graph_executor_codegen.py @@ -64,11 +64,13 @@ def _setup(self, mod, target): tgts[_expr.IntImm("int32", 0)] = Target(target) self._init(mod, tgts) - def codegen(self, func): + def codegen(self, ir_module, func): """Compile a single function into a graph. Parameters ---------- + ir_module: tvm.ir.Module + The module to compile func: tvm.relay.Expr The function to compile. @@ -82,7 +84,7 @@ def codegen(self, func): Additional constant parameters. """ default_mod_name = mangle_module_name("default") - self._codegen(func, default_mod_name) + self._codegen(ir_module, func, default_mod_name) graph_json = self._get_graph_json() lowered_func = self._get_irmodule() param_names = self._list_params_name() diff --git a/python/tvm/relay/backend/runtime.py b/python/tvm/relay/backend/runtime.py index 81779a245dde..f2fd69a0f547 100644 --- a/python/tvm/relay/backend/runtime.py +++ b/python/tvm/relay/backend/runtime.py @@ -27,6 +27,8 @@ class Runtime(Object): """Runtime configuration""" + name = "runtime" + def __init__(self, name, options=None) -> None: if options is None: options = {} @@ -37,14 +39,18 @@ def __contains__(self, name): return name in self._attrs def __getitem__(self, name): + self._attrs = _backend.GetRuntimeAttrs(self) return self._attrs[name] + def __eq__(self, other): + return str(other) == str(self) and dict(other._attrs) == dict(self._attrs) + @staticmethod - def list_runtimes(): + def list_registered(): """Returns a list of possible runtimes""" return list(_backend.ListRuntimes()) @staticmethod - def list_runtime_options(runtime): + def list_registered_options(runtime): """Returns the dict of available option names and types""" return dict(_backend.ListRuntimeOptions(str(runtime))) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index b66d5fbec8c2..09b847a3ba91 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -35,7 +35,7 @@ from . import function as _function from .transform import InferType from .backend.utils import mangle_module_name -from .backend import executor_factory as _executor_factory +from .backend import executor_factory as _executor_factory, Executor, Runtime from .backend import interpreter as _interpreter from .backend.vm import VMExecutor @@ -105,7 +105,14 @@ def __init__(self): self._get_irmodule = self.mod["get_irmodule"] def build( - self, mod, target=None, target_host=None, params=None, executor="graph", mod_name=None + self, + mod, + target=None, + target_host=None, + executor=Executor("graph"), + runtime=Runtime("cpp"), + params=None, + mod_name=None, ): """ Parameters @@ -127,15 +134,18 @@ def build( By default, llvm is used if it is enabled, otherwise a stackvm interpreter is used. + executor : Optional[Executor] + The executor configuration with which to build the model. + Defaults to "graph" if no executor specified. + + runtime : Optional[Runtime] + Runtime configuration to use when building the model. + Defaults to "cpp" if no runtime specified. + params : dict of str to NDArray Input parameters to the graph that do not change during inference time. Used for constant folding. - executor: str[Optional] - The type of executor to be used in order to run the model: - - If "graph" is specified, then the graph_executor will be used - - If "aot" is specified, then the aot_executor will be used - mod_name: Optional[str] The module name we will build @@ -176,13 +186,13 @@ def build( mod_name = mangle_module_name(mod_name) - self._build(mod, target, target_host, executor, mod_name) + self._build(mod, target, target_host, executor, runtime, mod_name) autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent # Get artifacts mod = self.get_module() params = self.get_params() - executor_config = self.get_graph_json() if executor == "graph" else None + executor_config = self.get_graph_json() if str(executor) == "graph" else None return executor_config, mod, params @@ -270,34 +280,78 @@ def _build_module_no_factory(mod, target=None, target_host=None, params=None, mo return build(mod, target, params=params, mod_name=mod_name).module -def get_executor_from_target(target, target_host): - """Helper function to extract the executor parameter from the target +def _reconstruct_from_deprecated_options(deprecated_params_target): + executor = None + runtime = None - Parameters - ---------- - target : Dict of targets for heterogeneous compilation + deprecated_executor = None + deprecated_executor_args = {} + if "executor" in deprecated_params_target.attrs: + _deprecated_target_param_warning("Executor", "executor") + deprecated_executor = deprecated_params_target.attrs.get("executor", "graph") + if "interface-api" in deprecated_params_target.attrs: + _deprecated_target_sub_param_warning("Executor", "interface-api") + deprecated_executor_args.update( + {"interface-api": deprecated_params_target.attrs["interface-api"]} + ) + if "unpacked-api" in deprecated_params_target.attrs: + _deprecated_target_sub_param_warning("Executor", "unpacked-api") + deprecated_executor_args.update( + {"unpacked-api": deprecated_params_target.attrs["unpacked-api"]} + ) + if ( + "link-params" in deprecated_params_target.attrs + and deprecated_params_target.attrs["link-params"] + ): + _deprecated_target_sub_param_warning("Executor", "link-params") + if deprecated_executor != "aot": + deprecated_executor_args.update( + {"link-params": deprecated_params_target.attrs["link-params"]} + ) + if deprecated_executor or deprecated_executor_args: + executor = Executor(deprecated_executor or "graph", deprecated_executor_args) + + deprecated_runtime = None + deprecated_runtime_args = {} + if "runtime" in deprecated_params_target.attrs: + _deprecated_target_param_warning("Runtime", "runtime") + deprecated_runtime = deprecated_params_target.attrs.get("runtime", "cpp") + if deprecated_runtime == "c": + deprecated_runtime = "crt" + if "system-lib" in deprecated_params_target.attrs: + _deprecated_target_sub_param_warning("Runtime", "system-lib") + deprecated_runtime_args.update({"system-lib": deprecated_params_target.attrs["system-lib"]}) + if deprecated_runtime or deprecated_runtime_args: + runtime = Runtime(deprecated_runtime or "cpp", deprecated_runtime_args) + + return executor, runtime + + +def _deprecated_target_param_warning(registry, param): + warnings.warn( + f"Please use {registry} (tvm.relay.backend.{registry}) " + f"instead of deprecated Target parameter -{param}", + DeprecationWarning, + ) - target_host : Host compilation target - Returns - ------- - executor : str - A string representing the executor type - """ - - # Default executor is graph - executor = "graph" - cpu_device_type = 1 - if target_host: - executor = target_host.attrs.get("executor", "graph") - else: - for device_type in target: - if device_type == cpu_device_type: - executor = target[device_type].attrs.get("executor", "graph") - return executor +def _deprecated_target_sub_param_warning(registry, param): + warnings.warn( + f"Please use {registry} (tvm.relay.backend.{registry}) parameter {param} " + f"instead of deprecated Target parameter -{param}", + DeprecationWarning, + ) -def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"): +def build( + ir_mod, + target=None, + target_host=None, + executor=Executor("graph"), + runtime=Runtime("cpp"), + params=None, + mod_name="default", +): # fmt: off # pylint: disable=line-too-long """Helper function that builds a Relay function to run on TVM graph executor. @@ -320,6 +374,14 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" By default, llvm is used if it is enabled, otherwise a stackvm interpreter is used. + executor : Optional[Executor] + The executor configuration with which to build the model. + Defaults to "graph" if no executor specified. + + runtime : Optional[Runtime] + Runtime configuration to use when building the model. + Defaults to "cpp" if no runtime specified. + params : dict of str to NDArray Input parameters to the graph that do not change during inference time. Used for constant folding. @@ -364,8 +426,16 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" elif target_host: raise ValueError("target host must be the type of str, " + "tvm.target.Target, or None") - # Retrieve the executor from the target - executor = get_executor_from_target(target, target_host) + # All of this logic is to raise deprecation warnings for various parameters + # TODO(Mousius) Remove these after some time + deprecated_params_target = target_host or list(target.values())[0] + deprecated_executor, deprecated_runtime = _reconstruct_from_deprecated_options( + deprecated_params_target + ) + if deprecated_executor: + executor = deprecated_executor + if deprecated_runtime: + runtime = deprecated_runtime # If current dispatch context is fallback context (the default root context), # then load pre-tuned parameters from TopHub @@ -376,33 +446,33 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" with tophub_context: bld_mod = BuildModule() - executor_config, runtime_mod, params = bld_mod.build( - mod=ir_mod, target=target, params=params, executor=executor, mod_name=mod_name + graph_json, runtime_mod, params = bld_mod.build( + mod=ir_mod, + target=target, + params=params, + executor=executor, + runtime=runtime, + mod_name=mod_name, ) func_metadata = bld_mod.get_function_metadata() devices = bld_mod.get_devices() lowered_ir_mods = bld_mod.get_irmodule() - if executor == "aot": + if str(executor) == "aot": executor_factory = _executor_factory.AOTExecutorFactoryModule( ir_mod, lowered_ir_mods, target, + executor, runtime_mod, mod_name, params, func_metadata, devices, ) - elif executor == "graph": + elif str(executor) == "graph": executor_factory = _executor_factory.GraphExecutorFactoryModule( - ir_mod, - target, - executor_config, - runtime_mod, - mod_name, - params, - func_metadata, + ir_mod, target, executor, graph_json, runtime_mod, mod_name, params, func_metadata ) else: assert False, "Executor " + executor + " not supported" diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 9af09296e9cc..378c4e63a8bd 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -316,13 +316,10 @@ def micro(model="unknown", options=None): if model not in MICRO_SUPPORTED_MODELS: raise ValueError(f"Model {model} not supported by tvm.target.micro.") opts = _merge_opts( - MICRO_SUPPORTED_MODELS[model] + ["-runtime=c", f"-model={model}"], + MICRO_SUPPORTED_MODELS[model] + [f"-model={model}"], options, ) - if (not options) or (options and not any("-executor=aot" in o for o in options)): - opts = _merge_opts(opts, "--system-lib") - # NOTE: in the future, the default micro target will be LLVM except when # external dependencies are present. return Target(" ".join(["c"] + opts)) @@ -623,27 +620,13 @@ def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument args = [s.replace("=", "@") for s in llvm_options.split()] return "--llvm-options=" + ",".join(args) - # TVM target attributes string - def create_tvm_options(cpu_ver, config): # pylint: disable=unused-argument - """Create TVM target features string.""" - - features = { - "link_params": "link-params", - } - opts = "" - for k in config: - if k in features: - opts += " --" + features[k] + "=" + str(config[k]) - return opts - # Sim args os.environ["HEXAGON_SIM_ARGS"] = create_sim_options(cpu_ver, config) target_str = create_llvm_target(cpu_ver, config) llvm_str = create_llvm_options(cpu_ver, config) - tvm_str = create_tvm_options(cpu_ver, config) - args_list = target_str.split() + llvm_str.split() + tvm_str.split() + args_list = target_str.split() + llvm_str.split() return Target(" ".join(["hexagon"] + args_list)) diff --git a/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py b/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py index 0045b3b0557d..a6e6958e4637 100755 --- a/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py +++ b/rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py @@ -21,9 +21,7 @@ from os import path as osp import sys -import numpy as np -import tvm -from tvm import te, runtime +from tvm import runtime as tvm_runtime from tvm import relay from tvm.relay import testing @@ -41,7 +39,8 @@ def main(): dshape = (4, 8) net = _get_model(dshape) mod, params = testing.create_workload(net) - graph, lib, params = relay.build(mod, "llvm --system-lib", params=params) + runtime = relay.backend.Runtime("cpp", {"system-lib": True}) + graph, lib, params = relay.build(mod, "llvm", runtime=runtime, params=params) out_dir = sys.argv[1] lib.save(osp.join(sys.argv[1], "graph.o")) @@ -49,7 +48,7 @@ def main(): f_resnet.write(graph) with open(osp.join(out_dir, "graph.params"), "wb") as f_params: - f_params.write(runtime.save_param_dict(params)) + f_params.write(tvm_runtime.save_param_dict(params)) if __name__ == "__main__": diff --git a/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py b/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py index 2a9ca23a586a..d6e1922efa85 100755 --- a/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py +++ b/rust/tvm-graph-rt/tests/test_tvm_basic/src/build_test_lib.py @@ -22,6 +22,7 @@ import sys import tvm +from tvm.relay.backend import Runtime from tvm import te @@ -32,8 +33,9 @@ def main(): C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") s = tvm.te.create_schedule(C.op) s[C].parallel(s[C].op.axis[0]) + runtime = Runtime("cpp", {"system-lib": True}) print(tvm.lower(s, [A, B, C], simple_mode=True)) - tvm.build(s, [A, B, C], "llvm --system-lib").save(osp.join(sys.argv[1], "test.o")) + tvm.build(s, [A, B, C], "llvm", runtime=runtime).save(osp.join(sys.argv[1], "test.o")) if __name__ == "__main__": diff --git a/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py b/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py index 42da22df53a3..2bf327a31b1b 100755 --- a/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py +++ b/rust/tvm-graph-rt/tests/test_wasm32/src/build_test_lib.py @@ -23,6 +23,7 @@ import tvm from tvm import te +from tvm.relay.backend import Runtime def main(): @@ -33,7 +34,8 @@ def main(): s = tvm.te.create_schedule(C.op) s[C].parallel(s[C].op.axis[0]) print(tvm.lower(s, [A, B, C], simple_mode=True)) - tvm.build(s, [A, B, C], "llvm -mtriple=wasm32-unknown-unknown --system-lib").save( + runtime = Runtime("cpp", {"system-lib": True}) + tvm.build(s, [A, B, C], "llvm -mtriple=wasm32-unknown-unknown", runtime=runtime).save( osp.join(sys.argv[1], "test.o") ) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f49409c2baee..7dc7b28b968b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -24,6 +24,8 @@ #include #include #include +#include +#include #include #include #include @@ -57,8 +59,9 @@ bool LLVMEnabled() { return pf != nullptr; } -bool ShouldAnnotateEntryFunc(const Target target, const IRModule mod) { - const bool aot_executor = (target->GetAttr("executor").value_or("") == "aot"); +bool ShouldAnnotateEntryFunc(const IRModule mod) { + Optional executor = mod->GetAttr("executor"); + const bool aot_executor = executor.defined() && executor.value()->name == "aot"; const bool single_entry_func = (mod->functions.size() == 1); return single_entry_func && !aot_executor; } @@ -451,8 +454,10 @@ runtime::Module PreProcessModuleForBuild(const Map& inputs_arg // Update target host for all targets CheckAndUpdateHostConsistency(&inputs, &target_host); - IRModule mhost_all = IRModule(Map()); - + // Take the attrs from the first module so the eventual modules have them. + // Ideally this would just be one unified module all the way through; + IRModule first_module = (*inputs.begin()).second; + IRModule mhost_all = IRModule(Map(), {}, {}, {}, first_module->attrs); ICHECK(mhost_all.defined()) << "The host module must be defined"; for (const auto& it : inputs) { @@ -513,7 +518,10 @@ runtime::Module build(const Map& inputs_arg, const Target& tar // Update target host for all targets CheckAndUpdateHostConsistency(&inputs, &target_host); - IRModule mhost_all = IRModule(Map()); + // Take the attrs from the first module so the eventual modules have them. + // Ideally this would just be one unified module all the way through; + IRModule first_module = (*inputs.begin()).second; + IRModule mhost_all = IRModule(Map(), {}, {}, {}, first_module->attrs); ICHECK(mhost_all.defined()) << "The host module must be defined"; @@ -592,7 +600,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::VerifyMemory()); - if (ShouldAnnotateEntryFunc(target, mixed_mod)) { + if (ShouldAnnotateEntryFunc(mixed_mod)) { mixed_pass_list.push_back(AnnotateEntryFunc(true)); } @@ -609,10 +617,11 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::InferFragment()); mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); - // The host Target contains these parameters at the moment rather than - // the specific Target - // TODO(Mousius) - Move these to the Executor object rather than Target - if (target->GetHost().value()->GetAttr("unpacked-api").value_or(Bool(false))) { + bool unpacked_api = mixed_mod->GetAttr(tvm::attr::kExecutor) + .value_or(relay::Executor::Create("graph", {})) + ->GetAttr("unpacked-api") + .value_or(Bool(false)); + if (unpacked_api) { mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); } else { mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1)); diff --git a/src/ir/module.cc b/src/ir/module.cc index 8ea83cfb40f0..c63e1df79f2e 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -426,6 +427,14 @@ void IRModuleNode::ImportFromStd(const String& path) { this->Import(std_path + "/" + path); } +Bool IRModuleNode::ShouldLinkParameters() const { + Optional executor = GetAttr(tvm::attr::kExecutor); + if (!executor.defined()) { + return Bool(false); + } + return executor.value()->ShouldLinkParameters(); +} + std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } IRModule IRModule::FromText(const String& text, const String& source_path) { @@ -521,6 +530,15 @@ TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, S mod->ImportFromStd(path); }); +TVM_REGISTER_GLOBAL("ir.Module_WithAttr") + .set_body_typed([](IRModule mod, String key, ObjectRef value) -> IRModule { + return WithAttr(mod, key, value); + }); + +TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, String key) -> ObjectRef { + return mod->GetAttr(key); +}); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 9e2eb8dd527d..a1bd026958ca 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -692,26 +693,29 @@ class AOTExecutorCodegen : public MixedModeVisitor { public: AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host) - : mod_(mod), - targets_(targets), - target_host_(target_host), - use_unpacked_api_(target_host->GetAttr("unpacked-api").value_or(Bool(false))) {} - - LoweredOutput Codegen(relay::Function func, String mod_name) { - IRModule mod = IRModule::FromExpr(func); - IRModule lowered_mod = tec::LowerTEPass(mod_name, [this](BaseFunc func) { - // We need to maintain the constant map for external - // functions so we pass this processing function which - // allows us to process each function as we lower it. - if (func->GetAttr(attr::kCompiler).defined()) { - UpdateConstants(func, ¶ms_); - } - - // TODO(@areusch, @jroesch): We should refactor this to - // execute as a further pass, instead writing data to the - // lowering process directly. - tec::UpdateFunctionMetadata(func, this->function_metadata_); - })(mod); + : mod_(mod), targets_(targets), target_host_(target_host), use_unpacked_api_(Bool(false)) {} + + LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) { + Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); + String interface_api = executor_config->GetAttr("interface-api").value_or("packed"); + Integer workspace_byte_alignment = + executor_config->GetAttr("workspace-byte-alignment").value_or(16); + use_unpacked_api_ = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); + + IRModule lowered_mod = + tec::LowerTEPass(mod_name, [this, workspace_byte_alignment](BaseFunc func) { + // We need to maintain the constant map for external + // functions so we pass this processing function which + // allows us to process each function as we lower it. + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, ¶ms_); + } + + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. + tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment); + })(mod); auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); @@ -771,15 +775,13 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Build the TIR IRModule for the AOT function Map symbol_map; symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func); - IRModule mod_run(symbol_map); + IRModule mod_run(symbol_map, {}, {}, {}, mod->attrs); // Apply storage rewrite pass to the runner function to do memory planning auto storage_rewrite = tir::transform::StorageRewrite(); mod_run = storage_rewrite(mod_run); // The workspace for main function should be calculated after performing storage_rewrite for // the top level TIR function. - auto workspace_byte_alignment = - target_host_->GetAttr("workspace-byte-alignment").value_or(16); Integer main_workspace_size = CalculateWorkspaceBytes( Downcast(mod_run->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)), workspace_byte_alignment); @@ -816,9 +818,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { std::vector input_var_names(input_vars_.size()); std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(), [](Var input_var) -> String { return input_var->name_hint(); }); - - ret.metadata = runtime::Metadata(input_var_names, ListDevices(), return_sid_.size(), - runtime::kTvmExecutorAot, mod_name); + ret.metadata = + runtime::Metadata(input_var_names, ListDevices(), return_sid_.size(), + runtime::kTvmExecutorAot, mod_name, interface_api, use_unpacked_api_); return ret; } @@ -848,9 +850,10 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { }); } else if (name == "codegen") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - Function func = args[0]; - String mod_name = args[1]; - this->output_ = codegen(func, mod_name); + IRModule mod = args[0]; + Function func = args[1]; + String mod_name = args[2]; + this->output_ = this->codegen_->Codegen(mod, func, mod_name); }); } else if (name == "list_params_name") { return PackedFunc( @@ -905,10 +908,6 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { targets, target_host); } - LoweredOutput codegen(Function func, String mod_name) { - return this->codegen_->Codegen(func, mod_name); - } - Array list_params_name() { Array ret; for (const auto& kv : this->output_.params) { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 24706fbfb8d8..ab86dbf41b3f 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -24,8 +24,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -33,6 +35,7 @@ #include #include "../../target/func_registry_generator.h" +#include "../../target/metadata_module.h" #include "../../target/source/codegen_source_base.h" #include "te_compiler.h" #include "utils.h" @@ -58,7 +61,9 @@ struct BuildOutput { struct ExecutorCodegen { void Init(runtime::Module* m, TargetMap targets) { CallFunc("init", m, targets); } - void Codegen(const Function& func, String mod_name) { CallFunc("codegen", func, mod_name); } + void Codegen(IRModule mod, const Function& func, String mod_name) { + CallFunc("codegen", mod, func, mod_name); + } virtual void UpdateOutput(BuildOutput* ret) = 0; @@ -181,8 +186,8 @@ class RelayBuildModule : public runtime::ModuleNode { [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); } else if (name == "build") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.num_args, 5); - this->Build(args[0], args[1], args[2], args[3], args[4]); + ICHECK_EQ(args.num_args, 6); + this->Build(args[0], args[1], args[2], args[3], args[4], args[5]); }); } else if (name == "list_params") { return PackedFunc( @@ -285,13 +290,16 @@ class RelayBuildModule : public runtime::ModuleNode { * \param mod Relay IRModule * \param targets Target devices * \param target_host Host target device + * \param executor Executor to target + * \param runtime Runtime to codegen for + * \param mod_name Name of the module */ void Build(IRModule mod, const TargetMap& targets, const tvm::Target& target_host, - const String executor, const String mod_name) { + const Executor& executor, const Runtime& runtime, const String mod_name) { VLOG_CONTEXT << "Build"; executor_ = executor; + runtime_ = runtime; config_ = CompilationConfig(PassContext::Current(), targets, target_host); - BuildRelay(std::move(mod), mod_name); } @@ -394,13 +402,17 @@ class RelayBuildModule : public runtime::ModuleNode { // Relay IRModule -> IRModule optimizations. relay_module = OptimizeImpl(std::move(relay_module)); - // Get the updated function. - auto func = Downcast(relay_module->Lookup("main")); + // Get the updated function and new IRModule to build. + // Instead of recreating the IRModule, we should look at the differences between this and the + // incoming IRModule to see if we can just pass (IRModule, Function) to the code generator. + Function func = Downcast(relay_module->Lookup("main")); + IRModule func_module = WithAttrs(IRModule::FromExpr(func), {{tvm::attr::kExecutor, executor_}, + {tvm::attr::kRuntime, runtime_}}); // Generate code for the updated function. - executor_codegen_ = MakeExecutorCodegen(executor_); + executor_codegen_ = MakeExecutorCodegen(executor_->name); executor_codegen_->Init(nullptr, config_->legacy_target_map); - executor_codegen_->Codegen(func, mod_name); + executor_codegen_->Codegen(func_module, func, mod_name); executor_codegen_->UpdateOutput(&ret_); ret_.params = executor_codegen_->GetParams(); @@ -412,12 +424,13 @@ class RelayBuildModule : public runtime::ModuleNode { lowered_funcs.Set(ext_dev, IRModule()); } + const Target& host_target = config_->host_se_scope->target; const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); // Generate a placeholder function that attaches linked params as its arguments. - const Target& host_target = config_->host_se_scope->target; - if (host_target->GetAttr("link-params").value_or(Bool(false))) { - CHECK(pf != nullptr) << "Unable to link-params without llvm codegen."; + Bool should_link_params = func_module->ShouldLinkParameters(); + if (should_link_params) { + CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen."; auto param_ids = executor_codegen_->GetParamIds(); auto link_params = Map(); for (auto param : ret_.params) { @@ -431,7 +444,8 @@ class RelayBuildModule : public runtime::ModuleNode { auto prim = tir::PrimFunc(Array(), tir::SeqStmt(Array()), VoidType(), Map(), attrs); if (lowered_funcs.find(host_target) == lowered_funcs.end()) { - lowered_funcs.Set(host_target, IRModule(Map({}))); + lowered_funcs.Set(host_target, + IRModule(Map({}), {}, {}, {}, func_module->attrs)); } lowered_funcs[host_target]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim); @@ -455,7 +469,7 @@ class RelayBuildModule : public runtime::ModuleNode { auto ext_mods = executor_codegen_->GetExternalModules(); ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, host_target, - executor_codegen_->GetMetadata()); + runtime_, executor_codegen_->GetMetadata()); // Remove external params which were stored in metadata module. for (tvm::runtime::Module mod : ext_mods) { auto pf_var = mod.GetFunction("get_const_vars"); @@ -473,16 +487,14 @@ class RelayBuildModule : public runtime::ModuleNode { protected: std::unique_ptr executor_codegen_; + /*! \brief Executor to build for */ + Executor executor_; + /*! \brief Runtime to codegen for */ + Runtime runtime_; /*! \brief parameters */ std::unordered_map params_; /*! \brief building output */ BuildOutput ret_; - /*! - * \brief Executor used to execute the model: - * - graph: use the json graph executor - * - aot: use the aot executor - */ - String executor_; /*! \brief Collects all the targets and scopes we need during compilation. */ CompilationConfig config_; }; diff --git a/src/relay/backend/executor.cc b/src/relay/backend/executor.cc index 3f5c2f4cb00f..7c0c690c07aa 100644 --- a/src/relay/backend/executor.cc +++ b/src/relay/backend/executor.cc @@ -89,12 +89,11 @@ ExecutorRegEntry& ExecutorRegEntry::RegisterOrGet(const String& name) { TVM_REGISTER_EXECUTOR("aot") .add_attr_option("unpacked-api") - .add_attr_option("interface-api"); + .add_attr_option("interface-api") + .add_attr_option("workspace-byte-alignment"); TVM_REGISTER_EXECUTOR("graph").add_attr_option("link-params", Bool(false)); -TVM_REGISTER_EXECUTOR("vm"); - /********** Registry **********/ TVM_REGISTER_GLOBAL("relay.backend.CreateExecutor").set_body_typed(Executor::Create); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 051b325e3db3..1e647b5ba1d3 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -200,7 +200,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatoroutput_ = this->codegen_->Codegen(func, mod_name); + IRModule mod = args[0]; + Function func = args[1]; + String mod_name = args[2]; + this->output_ = this->codegen_->Codegen(mod, func, mod_name); }); } else if (name == "get_graph_json") { return PackedFunc( diff --git a/src/relay/backend/runtime.cc b/src/relay/backend/runtime.cc index 1c08cbd29d1e..786d6f937f14 100644 --- a/src/relay/backend/runtime.cc +++ b/src/relay/backend/runtime.cc @@ -88,9 +88,9 @@ RuntimeRegEntry& RuntimeRegEntry::RegisterOrGet(const String& name) { /********** Register Runtimes and options **********/ -TVM_REGISTER_RUNTIME("c").add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME("crt").add_attr_option("system-lib"); -TVM_REGISTER_RUNTIME("cpp"); +TVM_REGISTER_RUNTIME("cpp").add_attr_option("system-lib"); /********** Registry **********/ diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 85a03cb5bd16..cfb4c7923a49 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -848,7 +848,8 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa * \param function_metadata The map that stores all the function metadatas */ void UpdateFunctionMetadata(BaseFunc func, - Map& function_metadata) { // NOLINT(*) + Map& function_metadata, // NOLINT(*) + Integer workspace_byte_alignment) { VLOG_CONTEXT << "UpdateFunctionMetadata"; VLOG(1) << "updating function metadata for:" << std::endl << PrettyPrint(func); // Originally UpdateFunctionMetadata took in CCachedFunc and looped through all the funcs stored @@ -878,9 +879,6 @@ void UpdateFunctionMetadata(BaseFunc func, auto prim_fn = Downcast(kv.second); CHECK(prim_fn.defined()) << "the primitive function must be defined"; - auto workspace_byte_alignment = - relay_target.value()->GetAttr("workspace-byte-alignment").value_or(16); - Integer workspace_size = CalculateWorkspaceBytes(prim_fn, workspace_byte_alignment); // Workspace sizes diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 268d1a65a31b..cb36718df120 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -142,9 +142,11 @@ class TECompiler : public ObjectRef { * input/output sizes) * \param func The function to calculate function metadata for * \param function_metadata The map that stores all the function metadatas + * \param workspace_byte_alignment Byte alignment for allocations */ -void UpdateFunctionMetadata(BaseFunc func, - Map& function_metadata); // NOLINT(*) +void UpdateFunctionMetadata(BaseFunc relay_func, + Map& function_metadata, // NOLINT(*) + Integer workspace_byte_alignment = 16); /*! * \brief Obtain the Target from the device type. diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 490d6893964d..9bdd63a4b126 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -44,6 +45,7 @@ #include #include +#include "../../../target/metadata_module.h" #include "../../../target/source/codegen_source_base.h" #include "../../op/annotation/annotation.h" #include "../../op/op_common.h" @@ -1170,7 +1172,7 @@ void VMCompiler::Codegen() { lib = codegen::CSourceModuleCreate(";", "", Array{}); } lib = codegen::CreateMetadataModule(params_, lib, ext_mods, config_->host_target, - runtime::Metadata()); + Runtime::Create("cpp"), runtime::Metadata()); exec_->SetLib(lib); } diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index fd612b08ab0e..8996d1b76e1f 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -62,6 +62,10 @@ class MetadataNode : public Object { Array devices; /*! \brief the executor to be used to run the model */ String executor = kTvmExecutorGraph; + /*! \brief The external API (packed or c) in use */ + String interface_api; + /*! \brief The internal API (packed or unpacked) in use */ + bool unpacked_api; String mod_name = ""; @@ -76,12 +80,14 @@ class MetadataNode : public Object { class Metadata : public ObjectRef { public: TVM_DLL Metadata(Array inputs, Array devices, int num_outputs, String executor, - String mod_name) { + String mod_name, String interface_api = "packed", bool unpacked_api = false) { auto n = make_object(); n->inputs = inputs; n->devices = devices; n->num_outputs = num_outputs; n->executor = executor; + n->interface_api = interface_api; + n->unpacked_api = unpacked_api; n->mod_name = mod_name; data_ = std::move(n); } diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 86079b25aa90..dc10d7885c25 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -24,6 +24,7 @@ #ifdef TVM_LLVM_VERSION #include +#include #include #include #include @@ -215,8 +216,6 @@ class LLVMModuleNode final : public runtime::ModuleNode { void Init(const IRModule& mod, const Target& target) { InitializeLLVM(); tm_ = GetLLVMTargetMachine(target); - bool system_lib = target->GetAttr("system-lib").value_or(Bool(false)); - bool target_c_runtime = (target->GetAttr("runtime").value_or("") == kTvmRuntimeCrt); ctx_ = std::make_shared(); std::unique_ptr cg = CodeGenLLVM::Create(tm_.get()); @@ -224,7 +223,12 @@ class LLVMModuleNode final : public runtime::ModuleNode { std::string entry_func; Map linked_params; bool found_linked_params = false; - bool could_have_linked_params = target->GetAttr("link-params").value_or(Bool(false)); + bool could_have_linked_params = mod->ShouldLinkParameters(); + relay::Runtime runtime = + mod->GetAttr(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp")); + bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); + bool target_c_runtime = runtime->name == "crt"; + for (auto kv : mod->functions) { if (could_have_linked_params && kv.first->name_hint == ::tvm::runtime::symbol::tvm_lookup_linked_param) { @@ -508,7 +512,8 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob") return runtime::Module(n); }); -runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target) { +runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target, + tvm::relay::Runtime runtime) { Array func_names; for (runtime::Module mod : modules) { auto pf_funcs = mod.GetFunction("get_func_names"); @@ -522,8 +527,8 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module InitializeLLVM(); auto tm = GetLLVMTargetMachine(target); - bool system_lib = target->GetAttr("system-lib").value_or(Bool(false)); - bool target_c_runtime = (target->GetAttr("runtime").value_or("") == kTvmRuntimeCrt); + bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); + bool target_c_runtime = runtime->name == "crt"; ICHECK(system_lib && target_c_runtime) << "For LLVM C-runtime metadata module, must include --system-lib and --runtime=c; " << "got target: " << target->str(); @@ -556,9 +561,7 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module } TVM_REGISTER_GLOBAL("runtime.CreateLLVMCrtMetadataModule") - .set_body_typed([](const Array& modules, Target target) { - return CreateLLVMCrtMetadataModule(modules, target); - }); + .set_body_typed(CreateLLVMCrtMetadataModule); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 6b05d4bdf2d5..933030e213d2 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -33,7 +33,8 @@ namespace tvm { namespace codegen { -runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target); +runtime::Module CreateLLVMCrtMetadataModule(const Array& modules, Target target, + tvm::relay::Runtime runtime); } // namespace codegen } // namespace tvm diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index db4051e00fd2..2b190e5d66ed 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -23,6 +23,8 @@ */ #include "metadata_module.h" +#include + #include #include "../runtime/meta_data.h" @@ -32,21 +34,10 @@ namespace tvm { namespace codegen { -/*! - * \brief Create a metadata module wrapper. The helper is used by different - * codegens, such as graph executor codegen and the vm compiler. - * - * \param params The metadata for initialization of all modules. - * \param target_module the internal module that is compiled by tvm. - * \param ext_modules The external modules that needs to be imported inside the metadata - * module(s). - * \param target The target that all the modules are compiled for - * \return The created metadata module that manages initialization of metadata. - */ runtime::Module CreateMetadataModule( const std::unordered_map& params, tvm::runtime::Module target_module, const Array& ext_modules, Target target, - runtime::Metadata metadata) { + tvm::relay::Runtime runtime, runtime::Metadata metadata) { // Here we split modules into two groups: // 1. Those modules which can be exported to C-runtime. These are DSO-exportable // (i.e. llvm or c) modules which return nothing from get_const_vars(). @@ -58,8 +49,7 @@ runtime::Module CreateMetadataModule( return !std::strcmp(mod->type_key(), "llvm") || !std::strcmp(mod->type_key(), "c"); }; - bool is_targeting_crt = - target.defined() && target->GetAttr("runtime").value_or(String("")) == kTvmRuntimeCrt; + bool is_targeting_crt = runtime->name == "crt"; // Wrap all submodules in the initialization wrapper. std::unordered_map> sym_metadata; @@ -114,11 +104,12 @@ runtime::Module CreateMetadataModule( if (target->kind->name == "c") { crt_exportable_modules.push_back(target_module); - target_module = CreateCSourceCrtMetadataModule(crt_exportable_modules, target, metadata); + target_module = + CreateCSourceCrtMetadataModule(crt_exportable_modules, target, runtime, metadata); } else if (target->kind->name == "llvm") { #ifdef TVM_LLVM_VERSION crt_exportable_modules.push_back(target_module); - target_module = CreateLLVMCrtMetadataModule(crt_exportable_modules, target); + target_module = CreateLLVMCrtMetadataModule(crt_exportable_modules, target, runtime); #else // TVM_LLVM_VERSION LOG(FATAL) << "TVM was not built with LLVM enabled."; #endif // TVM_LLVM_VERSION diff --git a/src/target/metadata_module.h b/src/target/metadata_module.h index 9311ee78ca6a..ee6f7231b3a1 100644 --- a/src/target/metadata_module.h +++ b/src/target/metadata_module.h @@ -25,6 +25,7 @@ #ifndef TVM_TARGET_METADATA_MODULE_H_ #define TVM_TARGET_METADATA_MODULE_H_ +#include #include #include #include @@ -37,9 +38,22 @@ namespace tvm { namespace codegen { +/*! + * \brief Create a metadata module wrapper. The helper is used by different + * codegens, such as graph executor codegen and the vm compiler. + * + * \param params The metadata for initialization of all modules. + * \param target_module the internal module that is compiled by tvm. + * \param ext_modules The external modules that needs to be imported inside the metadata + * module(s). + * \param target The target that all the modules are compiled for + * \param runtime The runtime to codegen for + * \param metadata Module metadata + * \return The created metadata module that manages initialization of metadata. + */ runtime::Module CreateMetadataModule( - const std::unordered_map& params, - tvm::runtime::Module target_module, const Array& ext_modules, Target target, + const std::unordered_map& params, runtime::Module target_module, + const Array& ext_modules, Target target, tvm::relay::Runtime runtime, runtime::Metadata metadata); } // namespace codegen diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 80ace929b881..37d54571859e 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -22,6 +22,7 @@ */ #include "codegen_c_host.h" +#include #include #include #include @@ -384,7 +385,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) { Map linked_params; bool found_linked_params = false; - bool could_have_linked_params = target->GetAttr("link-params").value_or(Bool(false)); + bool could_have_linked_params = mod->ShouldLinkParameters(); PrimFunc aot_executor_fn; for (auto kv : mod->functions) { diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index ff0d079f5425..d938469b8969 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -145,19 +145,6 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, const Array& func_names, const Array& const_vars = {}); -/*! - * \brief Wrap the submodules in a metadata module. - * \param params The variable to constant mapping that is collected by the host - * module. - * \param target_module The main TIR-lowered internal runtime module - * \param modules All the external modules that needs to be imported inside the metadata module(s). - * \param target The target that all the modules are compiled for - * \return The wrapped module. - */ -runtime::Module CreateMetadataModule( - const std::unordered_map& params, runtime::Module target_module, - const Array& ext_modules, Target target, runtime::Metadata metadata); - /*! * \brief Create a source module for viewing and limited saving for device. * \param data The code data to be viewed. @@ -170,16 +157,6 @@ runtime::Module DeviceSourceModuleCreate( std::string data, std::string fmt, std::unordered_map fmap, std::string type_key, std::function fget_source = nullptr); -/*! - * \brief Wrap the submodules that are to be wrapped in a c-source metadata module for C runtime. - * \param modules The modules to be wrapped. - * \param target the target the modules are compiled for. - * \param metadata the metadata needed for code generation. - * \return The wrapped module. - */ -runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, - runtime::Metadata metadata); - } // namespace codegen } // namespace tvm #endif // TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 21f82c3a99f1..e01a3d93d087 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -130,8 +130,12 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { public: CSourceCrtMetadataModuleNode(const Array& func_names, const std::string& fmt, - Target target, runtime::Metadata metadata) - : fmt_(fmt), func_names_(func_names), target_(target), metadata_(metadata) { + Target target, relay::Runtime runtime, runtime::Metadata metadata) + : fmt_(fmt), + func_names_(func_names), + target_(target), + runtime_(runtime), + metadata_(metadata) { CreateSource(); } const char* type_key() const { return "c"; } @@ -159,6 +163,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { std::string fmt_; Array func_names_; Target target_; + relay::Runtime runtime_; runtime::Metadata metadata_; void CreateFuncRegistry() { @@ -298,22 +303,21 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { const std::string entrypoint_mangled = runtime::get_name_mangled(metadata_->mod_name, tvm_entrypoint_suffix); const std::string network_mangled = runtime::get_name_mangled(metadata_->mod_name, "network"); - auto unpacked_api = target_->GetAttr("unpacked-api").value_or(Bool(false)); - auto interface_api = target_->GetAttr("interface-api").value_or(String("packed")); code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n"; code_ << "#ifdef __cplusplus\n"; code_ << "extern \"C\" {\n"; code_ << "#endif\n"; - if (unpacked_api) { - if (interface_api == "c") { + if (metadata_->unpacked_api) { + if (metadata_->interface_api == "c") { GenerateCInterfaceEntrypoint(entrypoint_mangled, run_func_mangled, metadata_->mod_name); } else { GenerateEntrypointForUnpackedAPI(entrypoint_mangled, run_func_mangled); } } else { - ICHECK_EQ(interface_api, "packed") << "Packed interface required for packed operators"; + ICHECK_EQ(metadata_->interface_api, "packed") + << "Packed interface required for packed operators"; GenerateEntrypointForPackedAPI(entrypoint_mangled, run_func_mangled); } @@ -323,7 +327,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } void CreateSource() { - if (target_->GetAttr("system-lib").value_or(Bool(false)) && !func_names_.empty()) { + if (runtime_->GetAttr("system-lib").value_or(Bool(false)) && !func_names_.empty()) { CreateFuncRegistry(); GenerateCrtSystemLib(); } @@ -335,7 +339,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { }; runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, - runtime::Metadata metadata) { + relay::Runtime runtime, runtime::Metadata metadata) { Array func_names; for (runtime::Module mod : modules) { auto pf_funcs = mod.GetFunction("get_func_names"); @@ -346,7 +350,7 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array& mod } } } - auto n = make_object(func_names, "cc", target, metadata); + auto n = make_object(func_names, "cc", target, runtime, metadata); auto csrc_metadata_module = runtime::Module(n); for (const auto& mod : modules) { csrc_metadata_module.Import(mod); @@ -416,9 +420,10 @@ TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") }); TVM_REGISTER_GLOBAL("runtime.CreateCSourceCrtMetadataModule") - .set_body_typed([](const Array& modules, Target target) { + .set_body_typed([](const Array& modules, Target target, + relay::Runtime runtime) { // Note that we don't need metadata when we compile a single operator - return CreateCSourceCrtMetadataModule(modules, target, runtime::Metadata()); + return CreateCSourceCrtMetadataModule(modules, target, runtime, runtime::Metadata()); }); } // namespace codegen diff --git a/src/target/source/source_module.h b/src/target/source/source_module.h index 8ed08048cf2f..fde363c1198a 100644 --- a/src/target/source/source_module.h +++ b/src/target/source/source_module.h @@ -25,6 +25,7 @@ #ifndef TVM_TARGET_SOURCE_SOURCE_MODULE_H_ #define TVM_TARGET_SOURCE_SOURCE_MODULE_H_ +#include #include #include @@ -34,12 +35,15 @@ namespace tvm { namespace codegen { /*! - * \brief Create C-runtime targeted metadata module for "c" backend. - * \param modules Array of modules included in the compilation output. - * \param target TVM target. + * \brief Wrap the submodules that are to be wrapped in a c-source metadata module for C runtime. + * \param modules The modules to be wrapped. + * \param target the target the modules are compiled for. + * \param runtime the runtime to code generate against + * \param metadata the metadata needed for code generation. + * \return The wrapped module. */ -runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, - tvm::Target target, runtime::Metadata metadata); +runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, + relay::Runtime runtime, runtime::Metadata metadata); } // namespace codegen } // namespace tvm diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 5ef6ec0ad6f1..568338bcb868 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -21,9 +21,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -126,7 +128,7 @@ TEST(Relay, BuildModule) { targets.Set(0, llvm_tgt); auto relay_mod = tvm::IRModule::FromExpr(func); ICHECK(relay_mod.defined()) << "Module must be defined"; - build_f(relay_mod, targets, llvm_tgt, runtime::kTvmExecutorGraph, ""); + build_f(relay_mod, targets, llvm_tgt, Executor::Create("graph"), Runtime::Create("cpp"), ""); std::string json = json_f(); tvm::runtime::Module mod = mod_f(); // run diff --git a/tests/cpp/runtime_test.cc b/tests/cpp/runtime_test.cc index 46fc2f74af60..c87639fffd2c 100644 --- a/tests/cpp/runtime_test.cc +++ b/tests/cpp/runtime_test.cc @@ -21,9 +21,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -112,7 +114,7 @@ TEST(Runtime, ZeroCopy) { targets.Set(0, llvm_tgt); auto relay_mod = tvm::IRModule::FromExpr(func); ICHECK(relay_mod.defined()) << "Module must be defined"; - build_f(relay_mod, targets, llvm_tgt, runtime::kTvmExecutorGraph, ""); + build_f(relay_mod, targets, llvm_tgt, Executor::Create("graph"), Runtime::Create("cpp"), ""); // create graph executor std::string json = json_f(); tvm::runtime::Module mod = mod_f(); diff --git a/tests/micro/arduino/conftest.py b/tests/micro/arduino/conftest.py index 8625b4a45364..71cc810affe3 100644 --- a/tests/micro/arduino/conftest.py +++ b/tests/micro/arduino/conftest.py @@ -22,7 +22,8 @@ import tvm.target.target from tvm.micro import project -from tvm import micro, relay +from tvm import relay +from tvm.relay.backend import Executor, Runtime TEMPLATE_PROJECT_DIR = pathlib.Path(tvm.micro.get_microtvm_template_projects("arduino")) @@ -139,12 +140,12 @@ def make_kws_project(board, arduino_cli_cmd, tvm_debug, workspace_dir): tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) mod, params = relay.frontend.from_tflite(tflite_model) - target = tvm.target.target.micro( - model, options=["--link-params=1", "--unpacked-api=1", "--executor=aot"] - ) + target = tvm.target.target.micro(model) + runtime = Runtime("crt") + executor = Executor("aot", {"unpacked-api": True}) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = relay.build(mod, target, params=params) + mod = relay.build(mod, target, runtime=runtime, executor=executor, params=params) return tvm.micro.generate_project( str(TEMPLATE_PROJECT_DIR), diff --git a/tests/micro/arduino/test_arduino_rpc_server.py b/tests/micro/arduino/test_arduino_rpc_server.py index f157214241c9..a0dcb923a197 100644 --- a/tests/micro/arduino/test_arduino_rpc_server.py +++ b/tests/micro/arduino/test_arduino_rpc_server.py @@ -22,7 +22,6 @@ """ -import datetime import pathlib import sys @@ -31,8 +30,9 @@ import pytest import tvm from PIL import Image -from tvm import micro, relay +from tvm import relay from tvm.relay.testing import byoc +from tvm.relay.backend import Executor, Runtime import conftest @@ -191,9 +191,11 @@ def test_onnx(board, arduino_cli_cmd, tvm_debug, workspace_dir): relay_mod, params = relay.frontend.from_onnx(onnx_model, shape=shape, freeze_params=True) relay_mod = relay.transform.DynamicToStatic()(relay_mod) - target = tvm.target.target.micro(model, options=["-link-params=1"]) + target = tvm.target.target.micro(model) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - lowered = relay.build(relay_mod, target, params=params) + executor = Executor("graph", {"link-params": True}) + runtime = Runtime("crt", {"system-lib": True}) + lowered = relay.build(relay_mod, target, params=params, executor=executor, runtime=runtime) graph = lowered.get_graph_json() with _make_session( diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index 10759c3790db..4034008c6d45 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -17,7 +17,6 @@ import logging import os import pathlib -import subprocess import sys import logging @@ -28,6 +27,7 @@ import tvm import tvm.relay as relay +from tvm.relay.backend import Executor, Runtime from tvm.relay.testing import byoc from tvm.contrib import utils from tvm.micro.testing import check_tune_log @@ -40,10 +40,11 @@ def _make_sess_from_op( temp_dir, model, zephyr_board, west_cmd, op_name, sched, arg_bufs, build_config ): + runtime = Runtime("crt", {"system-lib": True}) target = tvm.target.target.micro(model) target = tvm.target.Target(target=target, host=target) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = tvm.build(sched, arg_bufs, target=target, name=op_name) + mod = tvm.build(sched, arg_bufs, target=target, runtime=runtime, name=op_name) return _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config) @@ -179,9 +180,10 @@ def test_relay(temp_dir, board, west_cmd, tvm_debug): func = relay.Function([x], z) ir_mod = tvm.IRModule.from_expr(func) + runtime = Runtime("crt", {"system-lib": True}) target = tvm.target.target.micro(model) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = tvm.relay.build(ir_mod, target=target) + mod = tvm.relay.build(ir_mod, target=target, runtime=runtime) with _make_session(temp_dir, board, west_cmd, mod, build_config) as session: graph_mod = tvm.micro.create_local_graph_executor( @@ -217,13 +219,15 @@ def test_onnx(temp_dir, board, west_cmd, tvm_debug): relay_mod, params = relay.frontend.from_onnx(onnx_model, shape=shape, freeze_params=True) relay_mod = relay.transform.DynamicToStatic()(relay_mod) - # We add the -link-params=1 option to ensure the model parameters are compiled in. + # We add the link-params=True option to ensure the model parameters are compiled in. # There is currently a bug preventing the host_driven environment from receiving # the model weights when set using graph_mod.set_input(). # See: https://github.com/apache/tvm/issues/7567 - target = tvm.target.target.micro(model, options=["-link-params=1"]) + target = tvm.target.target.micro(model) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - lowered = relay.build(relay_mod, target, params=params) + executor = Executor("graph", {"link-params": True}) + runtime = Runtime("crt", {"system-lib": True}) + lowered = relay.build(relay_mod, target, params=params, executor=executor, runtime=runtime) graph = lowered.get_graph_json() with _make_session(temp_dir, board, west_cmd, lowered, build_config) as session: @@ -249,9 +253,10 @@ def check_result( ): """Helper function to verify results""" TOL = 1e-5 + runtime = Runtime("crt", {"system-lib": True}) target = tvm.target.target.micro(model) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = tvm.relay.build(relay_mod, target=target) + mod = tvm.relay.build(relay_mod, target=target, runtime=runtime) with _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config) as session: rt_mod = tvm.micro.create_local_graph_executor( @@ -377,6 +382,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): if board != "qemu_x86": pytest.xfail(f"Autotune fails on {board}.") + runtime = Runtime("crt", {"system-lib": True}) model = test_utils.ZEPHYR_BOARDS[board] build_config = {"debug": tvm_debug} @@ -436,6 +442,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): build_kwargs={"build_option": {"tir.disable_vectorize": True}}, do_fork=True, build_func=tvm.micro.autotvm_build_func, + runtime=runtime, ) runner = tvm.autotvm.LocalRunner( number=1, repeat=1, timeout=timeout, module_loader=module_loader @@ -465,7 +472,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): # Build without tuning with pass_context: - lowered = tvm.relay.build(mod, target=target, params=params) + lowered = tvm.relay.build(mod, target=target, runtime=runtime, params=params) temp_dir = utils.tempdir() with _make_session(temp_dir, board, west_cmd, lowered, build_config) as session: @@ -480,7 +487,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): # Build using autotune logs with tvm.autotvm.apply_history_best(str(log_path)): with pass_context: - lowered_tuned = tvm.relay.build(mod, target=target, params=params) + lowered_tuned = tvm.relay.build(mod, target=target, runtime=runtime, params=params) temp_dir = utils.tempdir() with _make_session(temp_dir, board, west_cmd, lowered_tuned, build_config) as session: diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot.py index 4324570e1930..dbb5883f01bd 100644 --- a/tests/micro/zephyr/test_zephyr_aot.py +++ b/tests/micro/zephyr/test_zephyr_aot.py @@ -30,6 +30,7 @@ import tvm.testing from tvm.micro.project_api import server import tvm.relay as relay +from tvm.relay.backend import Executor, Runtime from tvm.contrib.download import download_testdata from tvm.micro.model_library_format import generate_c_interface_header @@ -65,20 +66,14 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug): tflite_model, shape_dict={"input_1": input_shape}, dtype_dict={"input_1 ": "int8"} ) - target = tvm.target.target.micro( - model, - options=[ - "-link-params=1", - "--executor=aot", - "--unpacked-api=1", - "--interface-api=c", - "--workspace-byte-alignment=4", - ], + target = tvm.target.target.micro(model) + executor = Executor( + "aot", {"unpacked-api": True, "interface-api": "c", "workspace-byte-alignment": 4} ) + runtime = Runtime("crt") with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - lowered = relay.build(relay_mod, target, params=params) + lowered = relay.build(relay_mod, target, params=params, runtime=runtime, executor=executor) - # Load sample and generate input/output header files sample_url = "https://github.com/tlc-pack/web-data/raw/967fc387dadb272c5a7f8c3461d34c060100dbf1/testdata/microTVM/data/keyword_spotting_int8_6.pyc.npy" sample_path = download_testdata(sample_url, "keyword_spotting_int8_6.pyc.npy", module="data") sample = np.load(sample_path) @@ -139,9 +134,11 @@ def test_qemu_make_fail(temp_dir, board, west_cmd, tvm_debug): func = relay.Function([x], z) ir_mod = tvm.IRModule.from_expr(func) - target = tvm.target.target.micro(model, options=["-link-params=1", "--executor=aot"]) + target = tvm.target.target.micro(model) + executor = Executor("aot") + runtime = Runtime("crt") with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - lowered = relay.build(ir_mod, target) + lowered = relay.build(ir_mod, target, executor=executor, runtime=runtime) # Generate input/output header files with tempfile.NamedTemporaryFile() as tar_temp_file: diff --git a/tests/micro/zephyr/test_zephyr_armv7m.py b/tests/micro/zephyr/test_zephyr_armv7m.py index 9364b54c153f..53ffb1e43961 100644 --- a/tests/micro/zephyr/test_zephyr_armv7m.py +++ b/tests/micro/zephyr/test_zephyr_armv7m.py @@ -34,6 +34,7 @@ from tvm.contrib.download import download_testdata from tvm.micro.model_library_format import generate_c_interface_header from tvm.micro.testing import aot_transport_init_wait, aot_transport_find_message +from tvm.relay.backend import Executor, Runtime import test_utils @@ -182,27 +183,11 @@ def test_armv7m_intrinsic(temp_dir, board, west_cmd, tvm_debug): # kernel layout "HWIO" is not supported by arm_cpu SIMD extension (see tvm\python\relay\op\strategy\arm_cpu.py) relay_mod_no_simd = _apply_desired_layout_no_simd(relay_mod) - target = tvm.target.target.micro( - model, - options=[ - "-keys=cpu", - "-link-params=1", - "--executor=aot", - "--unpacked-api=1", - "--interface-api=c", - ], - ) + target = tvm.target.target.micro(model, options=["-keys=cpu"]) + target_simd = tvm.target.target.micro(model, options=["-keys=arm_cpu,cpu"]) - target_simd = tvm.target.target.micro( - model, - options=[ - "-keys=arm_cpu,cpu", - "-link-params=1", - "--executor=aot", - "--unpacked-api=1", - "--interface-api=c", - ], - ) + executor = Executor("aot", {"unpacked-api": True, "interface-api": "c"}) + runtime = Runtime("crt") temp_dir_simd = temp_dir / "simd" temp_dir_no_simd = temp_dir / "nosimd" @@ -212,7 +197,9 @@ def test_armv7m_intrinsic(temp_dir, board, west_cmd, tvm_debug): with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): lowered_simd = relay.build(relay_mod_simd, target_simd, params=params) - lowered_no_simd = relay.build(relay_mod_no_simd, target, params=params) + lowered_no_simd = relay.build( + relay_mod_no_simd, target, params=params, runtime=runtime, executor=executor + ) result_simd, time_simd = _run_model( temp_dir_simd, board, west_cmd, lowered_simd, build_config, sample, output_shape ) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 2ef84d7f1a6f..1bb854c1cf0a 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -24,7 +24,9 @@ import pytest import tvm +import tvm.testing from tvm.testing.utils import ethosn_available +from tvm.relay.backend import Runtime, Executor from tvm.contrib.target.vitis_ai import vitis_ai_available @@ -398,9 +400,11 @@ def test_compile_tflite_module_with_external_codegen_cmsisnn( output_file_name = f"{output_dir}/file.tar" - tvmc_package = tvmc.compiler.compile_model( + tvmc.compiler.compile_model( tvmc_model, - target=f"cmsis-nn, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 -executor=aot", + target=f"cmsis-nn, c -mcpu=cortex-m55", + runtime=Runtime("crt", {"system-lib": True}), + executor=Executor("aot"), output_format="mlf", package_path=output_file_name, pass_context_configs=["tir.disable_vectorize=true"], @@ -476,9 +480,11 @@ def test_compile_tflite_module_with_external_codegen_ethosu( for accel_type in ACCEL_TYPES: output_file_name = f"{output_dir}/file_{accel_type}.tar" - tvmc_package = tvmc.compiler.compile_model( + tvmc.compiler.compile_model( tvmc_model, - target=f"ethos-u -accelerator_config={accel_type}, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 -executor=aot", + target=f"ethos-u -accelerator_config={accel_type}, c -mcpu=cortex-m55", + runtime=Runtime("crt", {"system-lib": True}), + executor=Executor("aot"), output_format="mlf", package_path=output_file_name, pass_context_configs=["tir.disable_vectorize=true"], diff --git a/tests/python/driver/tvmc/test_mlf.py b/tests/python/driver/tvmc/test_mlf.py index 11306bd58848..4f61aec946d7 100644 --- a/tests/python/driver/tvmc/test_mlf.py +++ b/tests/python/driver/tvmc/test_mlf.py @@ -21,15 +21,17 @@ import sys import tvm +from tvm.autotvm.measure.executor import Executor from tvm.driver import tvmc from tvm.driver.tvmc.main import _main from tvm.driver.tvmc.model import TVMCPackage, TVMCException +from tvm.relay import backend -@pytest.mark.parametrize( - "target,pass_configs", [["llvm", []], ["c -executor=aot", ["tir.disable_vectorize=1"]]] -) -def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory, target, pass_configs): +def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory): + target = "c" + executor = "aot" + pass_configs = ["tir.disable_vectorize=1"] pytest.importorskip("tflite") output_dir = tmpdir_factory.mktemp("mlf") @@ -38,7 +40,7 @@ def test_tvmc_cl_compile_run_mlf(tflite_mobilenet_v1_1_quant, tmpdir_factory, ta # Compile the input model and generate a Model Library Format (MLF) archive. pass_config_args = " ".join([f"--pass-config {pass_config}" for pass_config in pass_configs]) - tvmc_cmd = f"tvmc compile {input_model} --target='{target}' {pass_config_args} --output {output_file} --output-format mlf" + tvmc_cmd = f"tvmc compile {input_model} --target={target} --executor={executor} {pass_config_args} --output {output_file} --output-format mlf" tvmc_args = shlex.split(tvmc_cmd)[1:] _main(tvmc_args) assert os.path.exists(output_file), "Could not find the exported MLF archive." @@ -114,7 +116,8 @@ def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant, tflite_compile tflite_compiled_model_mlf = tflite_compile_model( tflite_mobilenet_v1_1_quant, - target="c -executor=aot", + target="c", + executor=backend.Executor("aot"), output_format="mlf", pass_context_configs=["tir.disable_vectorize=1"], ) diff --git a/tests/python/driver/tvmc/test_registry_options.py b/tests/python/driver/tvmc/test_registry_options.py new file mode 100644 index 000000000000..458d0a88d1f7 --- /dev/null +++ b/tests/python/driver/tvmc/test_registry_options.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse + +import pytest + +from tvm.driver.tvmc.common import TVMCException +from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity +from tvm.relay.backend import Executor + + +def test_registry_to_argparse(): + parser = argparse.ArgumentParser() + generate_registry_args(parser, Executor) + parsed, _ = parser.parse_known_args(["--executor=aot", "--executor-aot-interface-api=c"]) + + assert parsed.executor == "aot" + assert parsed.executor_aot_interface_api == "c" + + +def test_registry_to_argparse_default(): + parser = argparse.ArgumentParser() + generate_registry_args(parser, Executor, "aot") + parsed, _ = parser.parse_known_args([]) + + assert parsed.executor == "aot" + + +def test_mapping_registered_args(): + parser = argparse.ArgumentParser() + generate_registry_args(parser, Executor) + parsed, _ = parser.parse_known_args(["--executor=aot", "--executor-aot-interface-api=c"]) + entity = reconstruct_registry_entity(parsed, Executor) + + assert isinstance(entity, Executor) + assert "interface-api" in entity + assert entity["interface-api"] == "c" + + +def test_mapping_registered_args_no_match_for_name(): + parser = argparse.ArgumentParser() + generate_registry_args(parser, Executor) + parsed, _ = parser.parse_known_args(["--executor=woof"]) + + with pytest.raises(TVMCException, match='Executor "woof" is not defined'): + reconstruct_registry_entity(parsed, Executor) + + +def test_mapping_registered_args_no_arg(): + parser = argparse.ArgumentParser() + generate_registry_args(parser, Executor) + parsed, _ = parser.parse_known_args([]) + + assert reconstruct_registry_entity(parsed, Executor) == None + + +def test_mapping_registered_args_mismatch_for_arg(): + parser = argparse.ArgumentParser() + generate_registry_args(parser, Executor) + parsed, _ = parser.parse_known_args(["--executor=aot", "--executor-graph-link-params=1"]) + + with pytest.raises( + TVMCException, + match="Passed --executor-graph-link-params but did not specify graph executor", + ): + reconstruct_registry_entity(parsed, Executor) diff --git a/tests/python/driver/tvmc/test_target.py b/tests/python/driver/tvmc/test_target.py index 865542ee25c1..06db5c47ea7e 100644 --- a/tests/python/driver/tvmc/test_target.py +++ b/tests/python/driver/tvmc/test_target.py @@ -103,15 +103,15 @@ def test_tokenize_target_with_dashes(): def test_parse_single_target_with_opts(): - targets = tvmc.common.parse_target("llvm -device=arm_cpu --system-lib") + targets = tvmc.common.parse_target("llvm -device=arm_cpu -mattr=+fp") assert len(targets) == 1 assert "device" in targets[0]["opts"] - assert "system-lib" in targets[0]["opts"] + assert "mattr" in targets[0]["opts"] def test_parse_multiple_target(): - targets = tvmc.common.parse_target("compute-library, llvm -device=arm_cpu --system-lib") + targets = tvmc.common.parse_target("compute-library, llvm -device=arm_cpu") assert len(targets) == 2 assert "compute-library" == targets[0]["name"] diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index c73af1948b57..acb9ffc5c4f5 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -35,7 +35,7 @@ from tvm import relay from tvm import te from tvm.contrib import utils, graph_executor -from tvm.relay.backend import te_compiler +from tvm.relay.backend import te_compiler, Executor, Runtime from tvm.relay.backend.te_compiler import TECompiler from tvm.relay.backend.utils import mangle_module_name from tvm.micro import export_model_library_format @@ -579,6 +579,8 @@ def compile_models( workspace_byte_alignment: int = 8, enable_op_fusion: bool = True, pass_config: Dict[str, Any] = None, + use_runtime_executor: bool = True, + target: str = "c", target_opts: Dict = None, ) -> List[AOTCompiledTestModel]: """ @@ -587,12 +589,18 @@ def compile_models( if not isinstance(models, list): models = [models] - base_target = "c -runtime=c --link-params --executor=aot" - extra_target = f"--workspace-byte-alignment={workspace_byte_alignment} --interface-api={interface_api} --unpacked-api={int(use_unpacked_api)}" + runtime = Runtime("crt") + executor = Executor( + "aot", + { + "workspace-byte-alignment": workspace_byte_alignment, + "interface-api": interface_api, + "unpacked-api": use_unpacked_api, + }, + ) if target_opts: for key, val in target_opts.items(): - extra_target += f" {key}={val}" - target = f"{base_target} {extra_target}" + target += f" {key}={val}" config = {"tir.disable_vectorize": True} if pass_config: @@ -603,15 +611,29 @@ def compile_models( compiled_mods = list() for model in models: with tvm.transform.PassContext(opt_level=3, config=config): - executor_factory = tvm.relay.build( - model.module, - tvm.target.Target(target, host=target), - params=model.params, - mod_name=model.name, - ) - compiled_mods.append( - AOTCompiledTestModel(model=model, executor_factory=executor_factory) - ) + # TODO(Mousius) - Remove once executor/runtime are fully removed from Target + if use_runtime_executor: + executor_factory = tvm.relay.build( + model.module, + tvm.target.Target(target, host=target), + executor=executor, + runtime=runtime, + params=model.params, + mod_name=model.name, + ) + compiled_mods.append( + AOTCompiledTestModel(model=model, executor_factory=executor_factory) + ) + else: + executor_factory = tvm.relay.build( + model.module, + tvm.target.Target(target, host=target), + params=model.params, + mod_name=model.name, + ) + compiled_mods.append( + AOTCompiledTestModel(model=model, executor_factory=executor_factory) + ) return compiled_mods @@ -733,6 +755,8 @@ def compile_and_run( workspace_byte_alignment: int = 8, enable_op_fusion: bool = True, data_linkage: AOTDataLinkage = None, + use_runtime_executor: bool = True, + target: str = "c", target_opts: Dict = None, ): """This is a wrapper API to compile and run models as test for AoT""" @@ -743,6 +767,8 @@ def compile_and_run( workspace_byte_alignment=workspace_byte_alignment, enable_op_fusion=enable_op_fusion, pass_config=runner.pass_config, + use_runtime_executor=use_runtime_executor, + target=target, target_opts=target_opts, ) run_and_check( diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index e2bbb24c55d3..605d061918bd 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -28,6 +28,7 @@ from tvm.relay import testing, transform from tvm.relay.testing import byoc from tvm.relay.op.annotation import compiler_begin, compiler_end +from tvm.relay.backend import Executor, Runtime from aot_test_utils import ( AOTTestModel, AOT_DEFAULT_RUNNER, @@ -624,6 +625,39 @@ def test_name_sanitiser_name_clash(): ) +# This tests for deprecated AOT executor arguments +# TODO(Mousius) Remove deprecated arguments later +def test_deprecated_target_arguments(capsys): + """Tests we can still use relay.build with -executor, -runtime and -link-params""" + + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_DEFAULT_RUNNER + + x = relay.var("x", shape=(1, 10)) + y = relay.var("y", shape=(1, 10)) + z = relay.add(x, y) + func = relay.Function([x, y], z) + + x_in = np.ones((1, 10)).astype("float32") + y_in = np.random.uniform(size=(1, 10)).astype("float32") + + params = {"x": x_in} + inputs = {"y": y_in} + output_list = generate_ref_data(func, inputs, params) + + compile_and_run( + AOTTestModel( + module=IRModule.from_expr(func), inputs=inputs, outputs=output_list, params=params + ), + test_runner, + interface_api, + use_unpacked_api, + use_runtime_executor=False, + target="c -executor=aot --link-params -runtime=c -interface-api=c --unpacked-api", + ) + + @pytest.mark.parametrize( "workspace_byte_alignment,main_workspace_size,sum_workspace_size", [ @@ -634,10 +668,16 @@ def test_name_sanitiser_name_clash(): ) def test_memory_planning(workspace_byte_alignment, main_workspace_size, sum_workspace_size): mod, params = tvm.relay.testing.synthetic.get_workload() - - target = f"c -runtime=c --link-params --executor=aot --workspace-byte-alignment={workspace_byte_alignment}" + target = "c" + runtime = Runtime("crt") + executor = Executor( + "aot", + { + "workspace-byte-alignment": workspace_byte_alignment, + }, + ) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - lib = tvm.relay.build(mod, target, params=params) + lib = tvm.relay.build(mod, target, executor=executor, runtime=runtime, params=params) assert ( sum(lib.function_metadata["__tvm_main__"].workspace_sizes.values()) == main_workspace_size diff --git a/tests/python/relay/test_build_module.py b/tests/python/relay/test_build_module.py new file mode 100644 index 000000000000..d812ad8f92a4 --- /dev/null +++ b/tests/python/relay/test_build_module.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from tvm.target.target import Target +from tvm.relay.backend import Runtime, Executor +from tvm.relay.build_module import _reconstruct_from_deprecated_options + + +@pytest.mark.parametrize( + "target,executor,runtime", + [ + [Target("c"), None, None], + [Target("c -runtime=c"), None, Runtime("crt")], + [Target("c -system-lib"), None, Runtime("cpp", {"system-lib": True})], + [Target("c -runtime=c -system-lib"), None, Runtime("crt", {"system-lib": True})], + [Target("c -executor=aot"), Executor("aot"), None], + [ + Target("c -executor=aot -interface-api=c"), + Executor("aot", {"interface-api": "c"}), + None, + ], + [ + Target("c -executor=aot -unpacked-api=1"), + Executor("aot", {"unpacked-api": True}), + None, + ], + [Target("c -executor=aot -link-params=1"), Executor("aot"), None], + [Target("c -link-params=1"), Executor("graph", {"link-params": True}), None], + [ + Target( + "c -executor=aot -link-params=1 -interface-api=c" + " -unpacked-api=1 -runtime=c -system-lib" + ), + Executor("aot", {"unpacked-api": True, "interface-api": "c"}), + Runtime("crt", {"system-lib": True}), + ], + ], +) +def test_deprecated_target_parameters(target, executor, runtime): + actual_executor, actual_runtime = _reconstruct_from_deprecated_options(target) + assert executor == actual_executor + assert runtime == actual_runtime + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/python/relay/test_executor.py b/tests/python/relay/test_executor.py index ebda4ff47cac..866339cb89fe 100644 --- a/tests/python/relay/test_executor.py +++ b/tests/python/relay/test_executor.py @@ -63,16 +63,16 @@ def test_create_executor_attr_type_incorrect(): def test_list_executors(): - assert "aot" in Executor.list_executors() + assert "aot" in Executor.list_registered() @pytest.mark.parametrize("executor", [Executor("aot"), "aot"]) def test_list_executor_options(executor): - aot_options = Executor.list_executor_options(executor) + aot_options = Executor.list_registered_options(executor) assert "interface-api" in aot_options assert aot_options["interface-api"] == "runtime.String" def test_list_executor_options_not_found(): with pytest.raises(TVMError, match='Executor "woof" is not defined'): - Executor.list_executor_options("woof") + Executor.list_registered_options("woof") diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 90d88169225c..7d79698acf12 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -23,10 +23,11 @@ import tvm from tvm.relay.backend import te_compiler +from tvm.relay.backend.runtime import Runtime import tvm.relay.testing import tvm.relay.op as reg from tvm import relay -from tvm import runtime +from tvm import runtime as tvm_runtime from tvm.relay import transform from tvm.relay.testing import byoc from tvm.contrib import utils @@ -121,7 +122,15 @@ def visit_call(self, call): def check_result( - mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", device=tvm.cpu(), params=None + mod, + map_inputs, + out_shape, + result, + tol=1e-5, + target="llvm", + device=tvm.cpu(), + params=None, + runtime=Runtime("cpp"), ): if sys.platform == "win32": print("Skip test on Windows for now") @@ -138,7 +147,7 @@ def update_lib(lib): lib_name = "lib.so" lib_path = tmp_path.relpath(lib_name) lib.export_library(lib_path, fcompile=False, **kwargs) - lib = runtime.load_module(lib_path) + lib = tvm_runtime.load_module(lib_path) return lib @@ -148,10 +157,10 @@ def check_vm_result(): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() lib = update_lib(lib) - exe = runtime.vm.Executable.load_exec(code, lib) - vm = runtime.vm.VirtualMachine(exe, device) + exe = tvm_runtime.vm.Executable.load_exec(code, lib) + vm = tvm_runtime.vm.VirtualMachine(exe, device) outs = vm.run(**map_inputs) - outs = outs if isinstance(outs, runtime.container.ADT) else [outs] + outs = outs if isinstance(outs, tvm_runtime.container.ADT) else [outs] results = result if isinstance(result, list) else [result] for out, ref in zip(outs, results): tvm.testing.assert_allclose(out.numpy(), ref, rtol=tol, atol=tol) @@ -159,7 +168,7 @@ def check_vm_result(): def check_graph_executor_result(): te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): - json, lib, param = relay.build(mod, target=target, params=params) + json, lib, param = relay.build(mod, target=target, params=params, runtime=runtime) lib = update_lib(lib) rt_mod = tvm.contrib.graph_executor.create(json, lib, device) @@ -222,8 +231,8 @@ def test_multi_node_compiler(): map_inputs = {"w{}".format(i): w_data[i] for i in range(8)} map_inputs["x"] = x_data - targets = ["llvm", "c -runtime=c --system-lib"] - for tgt in targets: + targets = [("llvm", Runtime("cpp")), ("c", Runtime("crt", {"system-lib": True}))] + for tgt, rt in targets: check_result( mod, map_inputs, @@ -237,6 +246,7 @@ def test_multi_node_compiler(): axis=0, ), target=tgt, + runtime=rt, ) diff --git a/tests/python/relay/test_runtime.py b/tests/python/relay/test_runtime.py index d78b822411bc..ea15dd0d3c88 100644 --- a/tests/python/relay/test_runtime.py +++ b/tests/python/relay/test_runtime.py @@ -27,13 +27,13 @@ def test_create(): def test_create_runtime_with_options(): - runtime = Runtime("c", {"system-lib": True}) - assert str(runtime) == "c" + runtime = Runtime("crt", {"system-lib": True}) + assert str(runtime) == "crt" assert runtime["system-lib"] def test_attr_check(): - runtime = Runtime("c", {"system-lib": True}) + runtime = Runtime("crt", {"system-lib": True}) assert "woof" not in runtime assert "system-lib" in runtime @@ -45,7 +45,7 @@ def test_create_runtime_not_found(): def test_create_runtime_attr_not_found(): with pytest.raises(TVMError, match='Attribute "woof" is not available on this Runtime'): - Runtime("c", {"woof": "bark"}) + Runtime("crt", {"woof": "bark"}) def test_create_runtime_attr_type_incorrect(): @@ -54,20 +54,20 @@ def test_create_runtime_attr_type_incorrect(): match='Attribute "system-lib" should have type "IntImm"' ' but instead found "runtime.String"', ): - Runtime("c", {"system-lib": "woof"}) + Runtime("crt", {"system-lib": "woof"}) def test_list_runtimes(): - assert "c" in Runtime.list_runtimes() + assert "crt" in Runtime.list_registered() -@pytest.mark.parametrize("runtime", [Runtime("c"), "c"]) +@pytest.mark.parametrize("runtime", [Runtime("crt"), "crt"]) def test_list_runtime_options(runtime): - aot_options = Runtime.list_runtime_options(runtime) + aot_options = Runtime.list_registered_options(runtime) assert "system-lib" in aot_options assert aot_options["system-lib"] == "IntImm" def test_list_runtime_options_not_found(): with pytest.raises(TVMError, match='Runtime "woof" is not defined'): - Runtime.list_runtime_options("woof") + Runtime.list_registered_options("woof") diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py index f23b02c24298..9197a2097ebc 100644 --- a/tests/python/unittest/test_auto_scheduler_search_task.py +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -114,7 +114,7 @@ def test_search_task_record(): assert new_task.task_input_names[1] == "test_input_1" # Log with version 0.5 - v5_log = """["[\\\"matmul_auto_scheduler_test\\\", 64, 64, 64]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1]""" + v5_log = """["[\\\"matmul_auto_scheduler_test\\\", 64, 64, 64]", "llvm -keys=cpu", [6, 64, 64, 0, 0, 0, 0, 0], "", 1]""" new_task = auto_scheduler._ffi_api.DeserializeSearchTask(v5_log) assert task.workload_key == new_task.workload_key assert str(task.target) == str(new_task.target) @@ -191,7 +191,7 @@ def test_recover_measure_input_with_task_input(): assert new_task.task_input_names[1] == "test_input_1" # Log with version 0.5 - v5_log = """{"i": [["[\\\"matmul_auto_scheduler_test\\\", 512, 512, 512]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1], [[], []]], "r": [[0.1], 0, 0.2, 1], "v": "v0.6"}""" + v5_log = """{"i": [["[\\\"matmul_auto_scheduler_test\\\", 512, 512, 512]", "llvm -keys=cpu", [6, 64, 64, 0, 0, 0, 0, 0], "", 1], [[], []]], "r": [[0.1], 0, 0.2, 1], "v": "v0.6"}""" measure_log = auto_scheduler.measure_record.load_record_from_string(v5_log) new_task = measure_log[0].task assert task.workload_key == new_task.workload_key diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index fbf908170938..426484343679 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -31,6 +31,7 @@ import tvm.relay import tvm.testing from tvm.target import Target +from tvm.relay.backend import Runtime from tvm.topi.utils import get_const_tuple from tvm.topi.testing import conv2d_nchw_python @@ -42,8 +43,9 @@ def _make_sess_from_op(temp_dir, op_name, sched, arg_bufs): + runtime = Runtime("crt", {"system-lib": True}) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = tvm.build(sched, arg_bufs, Target(TARGET, TARGET), name=op_name) + mod = tvm.build(sched, arg_bufs, Target(TARGET, TARGET), runtime=runtime, name=op_name) return _make_session(temp_dir, mod) @@ -143,8 +145,9 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) { }""" ) + runtime = Runtime("crt", {"system-lib": True}) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - factory = tvm.relay.build(relay_mod, target=TARGET) + factory = tvm.relay.build(relay_mod, target=TARGET, runtime=runtime) with _make_session(temp_dir, factory) as sess: graph_mod = tvm.micro.create_local_graph_executor( @@ -221,6 +224,8 @@ def test_autotune(): import tvm.relay as relay from tvm.micro.testing import check_tune_log + runtime = Runtime("crt", {"system-lib": True}) + data = relay.var("data", relay.TensorType((1, 3, 64, 64), "float32")) weight = relay.var("weight", relay.TensorType((8, 3, 5, 5), "float32")) y = relay.nn.conv2d( @@ -261,6 +266,7 @@ def test_autotune(): build_kwargs={"build_option": {"tir.disable_vectorize": True}}, do_fork=True, build_func=tvm.micro.autotvm_build_func, + runtime=runtime, ) runner = tvm.autotvm.LocalRunner(number=1, repeat=1, module_loader=module_loader) @@ -288,7 +294,7 @@ def test_autotune(): # Build without tuning with pass_context: - lowered = tvm.relay.build(mod, target=TARGET, params=params) + lowered = tvm.relay.build(mod, target=TARGET, runtime=runtime, params=params) temp_dir = tvm.contrib.utils.tempdir() project = tvm.micro.generate_project(template_project_dir, lowered, temp_dir / "project") @@ -305,7 +311,7 @@ def test_autotune(): # Build using autotune logs with tvm.autotvm.apply_history_best(str(tune_log_file)): with pass_context: - lowered_tuned = tvm.relay.build(mod, target=target, params=params) + lowered_tuned = tvm.relay.build(mod, target=target, runtime=runtime, params=params) temp_dir = tvm.contrib.utils.tempdir() project = tvm.micro.generate_project(template_project_dir, lowered_tuned, temp_dir / "project") diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index 4f9cfffd0640..57ec70d879f6 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -19,7 +19,6 @@ import json import os import re -import struct import sys import tempfile @@ -29,6 +28,7 @@ import tvm import tvm.relay import tvm.testing +from tvm.relay.backend import Executor, Runtime from tvm.contrib import utils @@ -185,10 +185,13 @@ def test_llvm_link_params(): for dtype in LINKABLE_DTYPES: ir_mod, param_init = _make_mod_and_params(dtype) rand_input = _make_random_tensor(dtype, INPUT_SHAPE) - main_func = ir_mod["main"] - target = "llvm --runtime=c --system-lib --link-params" + target = "llvm" + runtime = Runtime("crt", {"system-lib": True}) + executor = Executor("graph", {"link-params": True}) with tvm.transform.PassContext(opt_level=3): - lib = tvm.relay.build(ir_mod, target, params=param_init) + lib = tvm.relay.build( + ir_mod, target, runtime=runtime, executor=executor, params=param_init + ) # NOTE: Need to export_library() and load_library() to link all the Module(llvm, ...) # against one another. @@ -213,8 +216,9 @@ def _run_linked(lib, mod): linked_output = _run_linked(lib, mod) + runtime = Runtime("cpp", {"system-lib": True}) with tvm.transform.PassContext(opt_level=3): - lib = tvm.relay.build(ir_mod, "llvm --system-lib", params=param_init) + lib = tvm.relay.build(ir_mod, "llvm", runtime=runtime, params=param_init) def _run_unlinked(lib): graph_json, mod, lowered_params = lib @@ -268,10 +272,10 @@ def test_c_link_params(): for dtype in LINKABLE_DTYPES: mod, param_init = _make_mod_and_params(dtype) rand_input = _make_random_tensor(dtype, INPUT_SHAPE) - main_func = mod["main"] - target = "c --link-params" + target = "c" + executor = Executor("graph", {"link-params": True}) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - lib = tvm.relay.build(mod, target, params=param_init) + lib = tvm.relay.build(mod, target, executor=executor, params=param_init) assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded src = lib.lib.get_source() @@ -353,10 +357,13 @@ def test_crt_link_params(): for dtype in LINKABLE_DTYPES: mod, param_init = _make_mod_and_params(dtype) rand_input = _make_random_tensor(dtype, INPUT_SHAPE) - main_func = mod["main"] - target = "c --system-lib --runtime=c --link-params" + target = "c" + runtime = Runtime("crt", {"system-lib": True}) + executor = Executor("graph", {"link-params": True}) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - factory = tvm.relay.build(mod, target, params=param_init) + factory = tvm.relay.build( + mod, target, runtime=runtime, executor=executor, params=param_init + ) assert set(factory.get_params().keys()) == {"p0", "p1"} # NOTE: op folded temp_dir = tvm.contrib.utils.tempdir() @@ -378,8 +385,9 @@ def test_crt_link_params(): graph_rt.run() linked_output = graph_rt.get_output(0).numpy() + runtime = Runtime("cpp", {"system-lib": True}) with tvm.transform.PassContext(opt_level=3): - lib = tvm.relay.build(mod, "llvm --system-lib", params=param_init) + lib = tvm.relay.build(mod, "llvm", runtime=runtime, params=param_init) def _run_unlinked(lib): graph_json, mod, lowered_params = lib diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 2baec5698e5e..48437eaf58d7 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -25,7 +25,7 @@ import tvm import tvm.relay -from tvm.relay.backend import executor_factory +from tvm.relay.backend import Executor, Runtime from tvm.relay.testing import byoc import tvm.runtime.module import tvm.testing @@ -42,7 +42,13 @@ def test_export_operator_model_library_format(): B = tvm.te.placeholder((1,), dtype="int8") C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C") sched = tvm.te.create_schedule(C.op) - mod = tvm.build(sched, [A, B, C], tvm.target.Target(target, target), name="add") + mod = tvm.build( + sched, + [A, B, C], + tvm.target.Target(target, target), + runtime=Runtime("crt", {"system-lib": True}), + name="add", + ) temp_dir = utils.tempdir() mlf_tar_path = temp_dir.relpath("lib.tar") @@ -81,7 +87,7 @@ def test_export_operator_model_library_format(): assert ( len(mod.ir_module_by_target) == 1 - ), f"expect 1 ir_model_by_target: {ir_module_by_target!r}" + ), f"expect 1 ir_model_by_target: {mod.ir_module_by_target!r}" for target, ir_mod in mod.ir_module_by_target.items(): assert int(tvm.runtime.ndarray.device(str(target)).device_type) == 1 with open(os.path.join(extract_dir, "src", "tir-1.txt")) as tir_f: @@ -102,20 +108,15 @@ def validate_graph_json(extract_dir, factory): @tvm.testing.requires_micro @pytest.mark.parametrize( - "executor,target,should_generate_interface", + "executor,runtime,should_generate_interface", [ - ("graph", tvm.target.target.micro("host"), False), - ("aot", tvm.target.target.micro("host", options="-executor=aot"), False), - ( - "aot", - tvm.target.target.micro( - "host", options="-executor=aot --unpacked-api=1 --interface-api=c" - ), - True, - ), + (Executor("graph"), Runtime("crt", {"system-lib": True}), False), + (Executor("aot"), Runtime("crt"), False), + (Executor("aot", {"unpacked-api": True, "interface-api": "c"}), Runtime("crt"), True), ], ) -def test_export_model_library_format_c(executor, target, should_generate_interface): +def test_export_model_library_format_c(executor, runtime, should_generate_interface): + target = tvm.target.target.micro("host") with utils.TempDirectory.set_keep_for_debug(True): with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): relay_mod = tvm.parser.fromtext( @@ -129,6 +130,8 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ factory = tvm.relay.build( relay_mod, target, + executor=executor, + runtime=runtime, mod_name="add", params={"c": numpy.array([[2.0, 4.0]], dtype="float32")}, ) @@ -153,7 +156,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) assert metadata["target"] == {"1": str(target)} - if executor == "graph": + if str(executor) == "graph": assert metadata["memory"]["sids"] == [ {"storage_id": 0, "size_bytes": 2, "input_binding": "a"}, {"storage_id": 1, "size_bytes": 8, "input_binding": "b"}, @@ -182,7 +185,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ os.path.join(extract_dir, "codegen", "host", "include", "tvmgen_add.h") ) - if executor == "graph": + if str(executor) == "graph": validate_graph_json(extract_dir, factory) with open(os.path.join(extract_dir, "src", "relay.txt")) as relay_f: @@ -211,6 +214,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ factory = tvm.relay.build( relay_mod, target, + runtime=Runtime("crt", {"system-lib": True}), mod_name="add", params={"c": numpy.array([[2.0, 4.0]], dtype="float32")}, ) @@ -271,10 +275,11 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ @tvm.testing.requires_micro @pytest.mark.parametrize( - "target", - [tvm.target.target.micro("host"), tvm.target.target.micro("host", options="-executor=aot")], + "executor,runtime", + [(Executor("graph"), Runtime("crt", {"system-lib": True})), (Executor("aot"), Runtime("crt"))], ) -def test_export_model_library_format_workspace(target): +def test_export_model_library_format_workspace(executor, runtime): + target = tvm.target.target.micro("host") with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): relay_mod = tvm.parser.fromtext( """ @@ -288,7 +293,13 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 } """ ) - factory = tvm.relay.build(relay_mod, target, mod_name="qnn_conv2d") + factory = tvm.relay.build( + relay_mod, + target, + executor=executor, + runtime=runtime, + mod_name="qnn_conv2d", + ) temp_dir = utils.tempdir() mlf_tar_path = temp_dir.relpath("lib.tar") @@ -381,7 +392,7 @@ def test_export_byoc_c_module(): mod = tvm.relay.transform.InferType()(mod) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - factory = tvm.relay.build(mod, tvm.target.target.micro("host")) + factory = tvm.relay.build(mod, tvm.target.target.micro("host"), runtime=Runtime("crt")) temp_dir = utils.tempdir() mlf_tar_path = temp_dir.relpath("lib.tar") diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py index 7bf4d72b047e..f17a615ce2c1 100644 --- a/tests/python/unittest/test_runtime_module_load.py +++ b/tests/python/unittest/test_runtime_module_load.py @@ -18,11 +18,11 @@ from tvm import te from tvm.contrib import cc, utils import ctypes -import os import sys import numpy as np import subprocess import tvm.testing +from tvm.relay.backend import Runtime runtime_py = """ import os @@ -117,7 +117,8 @@ def check_device(device): temp = utils.tempdir() name = "myadd_%s" % device if sys.platform == "darwin" or sys.platform.startswith("linux"): - f = tvm.build(s, [A, B], device, "llvm -system-lib", name=name) + runtime = Runtime("cpp", {"system-lib": True}) + f = tvm.build(s, [A, B], device, "llvm", runtime=runtime, name=name) elif sys.platform == "win32": f = tvm.build(s, [A, B], device, "llvm", name=name) else: @@ -198,8 +199,9 @@ def check_system_lib(): print("Skip because llvm is not enabled") return temp = utils.tempdir() - fadd1 = tvm.build(s, [A, B], "llvm -system-lib", name="myadd1") - fadd2 = tvm.build(s, [A, B], "llvm -system-lib", name="myadd2") + runtime = Runtime("cpp", {"system-lib": True}) + fadd1 = tvm.build(s, [A, B], "llvm", runtime=runtime, name="myadd1") + fadd2 = tvm.build(s, [A, B], "llvm", runtime=runtime, name="myadd2") path1 = temp.relpath("myadd1.o") path2 = temp.relpath("myadd2.o") path_dso = temp.relpath("mylib.so") @@ -207,7 +209,7 @@ def check_system_lib(): fadd2.save(path2) cc.create_shared(path_dso, [path1, path2]) # Load dll, will trigger system library registration - dll = ctypes.CDLL(path_dso) + ctypes.CDLL(path_dso) # Load the system wide library mm = tvm.runtime.system_lib() a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), dev) diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 6e1fc815d66d..f0ddcb60a1fd 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -17,7 +17,6 @@ import tvm from tvm import te import tvm.testing -import logging import multiprocessing import os import stat @@ -27,6 +26,7 @@ import pytest import numpy as np from tvm import rpc +from tvm.relay.backend import Runtime from tvm.contrib import utils, cc from tvm.rpc.tracker import Tracker from tvm.rpc.proxy import Proxy @@ -267,7 +267,8 @@ def check_minrpc(): return # export to minrpc temp = utils.tempdir() - f = tvm.build(s, [A, B], "llvm --system-lib", name="myadd") + runtime = Runtime("cpp", {"system-lib": True}) + f = tvm.build(s, [A, B], "llvm", name="myadd", runtime=runtime) path_minrpc = temp.relpath("dev_lib.minrpc") f.export_library(path_minrpc, rpc.with_minrpc(cc.create_executable)) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 5a1b33ae10b1..ed097de08699 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -22,10 +22,10 @@ import tvm import tvm.testing from tvm import te -from tvm import topi +from tvm.relay.backend import Runtime from tvm.contrib import utils, clang import numpy as np -import ctypes + import math import re import pytest @@ -747,7 +747,12 @@ def test_llvm_crt_static_lib(): B = te.placeholder((32,), dtype="bfloat16") d = te.compute((32,), lambda x: A[x] + B[x]) sch = te.create_schedule(d.op) - module = tvm.build(sch, [A, B, d], target=tvm.target.Target("llvm --system-lib --runtime=c")) + module = tvm.build( + sch, + [A, B, d], + target=tvm.target.Target("llvm"), + runtime=Runtime("crt", {"system-lib": True}), + ) print(module.get_source()) module.save("test.o") diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index 0a0ad49a7767..46ceeae57e34 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -104,7 +104,6 @@ def test_target_config(): "keys": ["arm_cpu", "cpu"], "device": "arm_cpu", "libs": ["cblas"], - "system-lib": True, "mfloat-abi": "hard", "mattr": ["+neon", "-avx512f"], } @@ -117,7 +116,6 @@ def test_target_config(): assert all([key in target.keys for key in ["arm_cpu", "cpu"]]) assert target.device_name == "arm_cpu" assert target.libs == ["cblas"] - assert "system-lib" in str(target) assert target.attrs["mfloat-abi"] == "hard" assert all([attr in target.attrs["mattr"] for attr in ["+neon", "-avx512f"]]) @@ -303,14 +301,14 @@ def test_check_and_update_host_consist_3(): def test_target_attr_bool_value(): - target0 = Target("llvm --link-params=True") - assert target0.attrs["link-params"] == 1 - target1 = Target("llvm --link-params=true") - assert target1.attrs["link-params"] == 1 - target2 = Target("llvm --link-params=False") - assert target2.attrs["link-params"] == 0 - target3 = Target("llvm --link-params=false") - assert target3.attrs["link-params"] == 0 + target0 = Target("vulkan --supports_float16=True") + assert target0.attrs["supports_float16"] == 1 + target1 = Target("vulkan --supports_float16=true") + assert target1.attrs["supports_float16"] == 1 + target2 = Target("vulkan --supports_float16=False") + assert target2.attrs["supports_float16"] == 0 + target3 = Target("vulkan --supports_float16=false") + assert target3.attrs["supports_float16"] == 0 if __name__ == "__main__": diff --git a/web/tests/python/prepare_test_libs.py b/web/tests/python/prepare_test_libs.py index fa086e6ff2c8..5c1f7c68c421 100644 --- a/web/tests/python/prepare_test_libs.py +++ b/web/tests/python/prepare_test_libs.py @@ -19,18 +19,20 @@ import tvm from tvm import te from tvm.contrib import emcc +from tvm.relay.backend import Runtime import os def prepare_test_libs(base_path): - target = "llvm -mtriple=wasm32-unknown-unknown-wasm -system-lib" + runtime = Runtime("cpp", {"system-lib": True}) + target = "llvm -mtriple=wasm32-unknown-unknown-wasm" if not tvm.runtime.enabled(target): raise RuntimeError("Target %s is not enbaled" % target) n = te.var("n") A = te.placeholder((n,), name="A") B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") s = te.create_schedule(B.op) - fadd = tvm.build(s, [A, B], target, name="add_one") + fadd = tvm.build(s, [A, B], target, runtime=runtime, name="add_one") wasm_path = os.path.join(base_path, "test_addone.wasm") fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py index ad27735c04ca..ac1a241a9baa 100644 --- a/web/tests/python/webgpu_rpc_test.py +++ b/web/tests/python/webgpu_rpc_test.py @@ -24,6 +24,7 @@ from tvm import te from tvm import rpc from tvm.contrib import utils, emcc +from tvm.relay.backend import Runtime import numpy as np proxy_host = "127.0.0.1" @@ -34,9 +35,8 @@ def test_rpc(): if not tvm.runtime.enabled("rpc"): return # generate the wasm library - target = tvm.target.Target( - "webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm -system-lib" - ) + target = tvm.target.Target("webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm") + runtime = Runtime("cpp", {"system-lib": True}) if not tvm.runtime.enabled(target_host): raise RuntimeError("Target %s is not enbaled" % target_host) @@ -50,7 +50,7 @@ def test_rpc(): s[B].bind(xi, te.thread_axis("threadIdx.x")) s[B].bind(xo, te.thread_axis("blockIdx.x")) - fadd = tvm.build(s, [A, B], target, name="addone") + fadd = tvm.build(s, [A, B], target, runtime=runtime, name="addone") temp = utils.tempdir() wasm_path = temp.relpath("addone_gpu.wasm") diff --git a/web/tests/python/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py index ee94e40a678c..9aab1759f8dd 100644 --- a/web/tests/python/websock_rpc_test.py +++ b/web/tests/python/websock_rpc_test.py @@ -24,6 +24,7 @@ from tvm import te from tvm import rpc from tvm.contrib import utils, emcc +from tvm.relay.backend import Runtime import numpy as np proxy_host = "127.0.0.1" @@ -34,7 +35,8 @@ def test_rpc(): if not tvm.runtime.enabled("rpc"): return # generate the wasm library - target = "llvm -mtriple=wasm32-unknown-unknown-wasm -system-lib" + runtime = Runtime("cpp", {"system-lib": True}) + target = "llvm -mtriple=wasm32-unknown-unknown-wasm" if not tvm.runtime.enabled(target): raise RuntimeError("Target %s is not enbaled" % target) n = te.var("n") @@ -42,7 +44,7 @@ def test_rpc(): B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") s = te.create_schedule(B.op) - fadd = tvm.build(s, [A, B], target, name="addone") + fadd = tvm.build(s, [A, B], target, runtime=runtime, name="addone") temp = utils.tempdir() wasm_path = temp.relpath("addone.wasm")