diff --git a/include/tvm/target/compilation_config.h b/include/tvm/target/compilation_config.h index 87b9798b20e8..8946a104dac4 100644 --- a/include/tvm/target/compilation_config.h +++ b/include/tvm/target/compilation_config.h @@ -27,6 +27,8 @@ #include +#include <string> + namespace tvm { /*! @@ -68,14 +70,20 @@ class CompilationConfigNode : public Object { * \p host_target, however the \p host_target should be used for all host computations and data. * Each \p Target will have \p host_target as its 'host'. * + * Primitive targets must be unique by their kind name. In this way the + * \p FindPrimitiveTargetForKind method will find the unique target for the given kind name. + * This method is used when transitioning from an external codegen "Compiler" attribute value + * to the external codegen target representing that compiler. + * * It is possible to have multiple primitive targets for the same device type. However given * primitive targets left and right where: * - left appears before right in the array * - left->kind->device_type == right->kind->device_type * then: * - right.IsExternalCodegenFor(left) must be true - * In this way the FindPrimitiveTargetOrFail method will find the 'most general' target for - * the requested device type. + * In this way the \p FindPrimitiveTargetForDeviceOrFail method will find the 'most general' + * target for the requested device type. This method is used when transitioning from a device + * constraint to the target needed to compile for that device. * * In the homogeneous case primitive_targets will have just one entry, which will be pointer equal to optional_homogeneous_target. @@ -114,11 +122,16 @@ class CompilationConfigNode : public Object { void VisitAttrs(AttrVisitor* v); /*! - * \brief Return the unique \p Target to use for \p device_type. Fail if no such target exists. + * \brief Returns the unique \p Target to use for \p device_type. Fails if no such target exists. * * This will be the first primitive target with matching device type. */ - Target FindPrimitiveTargetOrFail(DLDeviceType device_type) const; + Target FindPrimitiveTargetForDeviceOrFail(DLDeviceType device_type) const; + + /*! + * \brief Returns the unique \p Target to use for \p kind_name. Returns null if there is no such target. + */ + Optional<Target> FindPrimitiveTargetForKind(const std::string& kind_name) const; /*!
* \brief Returns a \p VirtualDevice agreeing with \p virtual_device on all its constrained diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 4148cdbd3c94..2a4a03bbe8e7 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -224,9 +224,7 @@ def recover_measure_input(inp, rebuild_state=False): from .search_task import SearchTask # lazily import to avoid recursive dependency task = inp.task - task.target, task.target_host = Target.check_and_update_host_consist( - task.target, task.target_host - ) + task.target, task.target_host = Target.canon_target_and_host(task.target, task.target_host) new_task = SearchTask( workload_key=task.workload_key, target=task.target, @@ -612,9 +610,7 @@ def _local_build_worker(inp_serialized, build_func, verbose): tic = time.time() inp = MeasureInput.deserialize(inp_serialized) task = inp.task - task.target, task.target_host = Target.check_and_update_host_consist( - task.target, task.target_host - ) + task.target, task.target_host = Target.canon_target_and_host(task.target, task.target_host) error_no = MeasureErrorNo.NO_ERROR error_msg = None diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 7ff1840c9123..e9bf1ccfd7cc 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -26,7 +26,6 @@ import logging import threading import traceback -import warnings import tvm from tvm import autotvm, transform @@ -115,13 +114,7 @@ def extract_tasks( The weight (i.e. the number of appearance) of extracted tasks """ # pylint: disable=import-outside-toplevel - if target_host is not None: - warnings.warn( - "target_host parameter is going to be deprecated. " - "Please pass in tvm.target.Target(target, host=target_host) instead." - ) - - target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) # Run the compiler to collect all TOPI calls during compilation. env = TracingEnvironment( diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index f1156998bdac..56dcb56abc6d 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -380,9 +380,9 @@ class SearchTask(Object): The ComputeDAG for the corresponding compute declaration. workload_key : str The workload key for the corresponding compute declaration. - target : tvm.target.Target + target : any target-like object, see Target.canon_target The target device of this search task. - target_host : Optional[tvm.target.Target] + target_host : None or any target-like object, see Target.canon_target The target host device of this search task. hardware_params : Optional[HardwareParams] Hardware parameters used in this search task. @@ -448,7 +448,7 @@ def __init__( assert target is not None, "Must specify a target." 
- target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) if layout_rewrite_option is None: layout_rewrite_option = LayoutRewriteOption.get_target_default(target) @@ -559,9 +559,7 @@ def print_best(self, log_file, print_mode="schedule"): raise ValueError("Invalid print_mode: %s" % print_mode) def __getstate__(self): - self.target, self.target_host = Target.check_and_update_host_consist( - self.target, self.target_host - ) + self.target, self.target_host = Target.canon_target_and_host(self.target, self.target_host) return { "compute_dag": self.compute_dag, "workload_key": self.workload_key, @@ -587,7 +585,7 @@ def __setstate__(self, state): if workload[0] not in WORKLOAD_FUNC_REGISTRY: register_workload_tensors(state["workload_key"], state["compute_dag"].tensors) - state["target"], state["target_host"] = Target.check_and_update_host_consist( + state["target"], state["target_host"] = Target.canon_target_and_host( state["target"], state["target_host"] ) self.__init_handle_by_constructor__( diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index 25d56cf8cf02..d4054bbd3701 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -443,7 +443,7 @@ def benchmark_layout_transform( Accept a user-supplied runner """ self._logger.info("Start to benchmark layout transformation...") - self._target, target_host = Target.check_and_update_host_consist(self._target, target_host) + self._target, target_host = Target.canon_target_and_host(self._target, target_host) if layout_records is None and infer_layout: raise RuntimeError("Requires some records to infer layout transformation time.") diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 6ebbbb653140..f582bd1974aa 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -496,7 +496,7 @@ def set_task(self, task): def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option=None): """Common part for building a configuration""" target, task, config = measure_input - target, task.target_host = Target.check_and_update_host_consist(target, task.target_host) + target, task.target_host = Target.canon_target_and_host(target, task.target_host) with target: s, args = task.instantiate(config) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 2643a01439e6..11f40ed62756 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -22,7 +22,6 @@ """ import threading import logging -import warnings import tvm from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext @@ -81,12 +80,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None): task: Array of autotvm.task.Task collected tasks """ - if target_host is not None: - warnings.warn( - "target_host parameter is going to be deprecated. " - "Please pass in tvm.target.Target(target, host=target_host) instead." 
- ) - target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) return extract_from_multiple_program([mod], [params], target, ops=ops) @@ -121,7 +115,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No env = TaskExtractEnv.get() # merge target and target host - target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) # run compiler to collect all TOPI calls during compilation env.reset(ops) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index ee1750896fca..18bc0720d514 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -177,9 +177,7 @@ def __getstate__(self): # and restore the function by name when unpickling it. import cloudpickle # pylint: disable=import-outside-toplevel - self.target, self.target_host = Target.check_and_update_host_consist( - self.target, self.target_host - ) + self.target, self.target_host = Target.canon_target_and_host(self.target, self.target_host) return { "name": self.name, "args": self.args, @@ -200,7 +198,7 @@ def __setstate__(self, state): self.config_space = state["config_space"] self.func = cloudpickle.loads(state["func"]) self.flop = state["flop"] - self.target, self.target_host = Target.check_and_update_host_consist( + self.target, self.target_host = Target.canon_target_and_host( state["target"], state["target_host"] ) @@ -471,10 +469,7 @@ def create(task_name, args, target, target_host=None): args = serialize_args(args) ret = Task(task_name, args) - if isinstance(target, str): - target = Target(target) - - target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) # init config space ret.config_space = ConfigSpace() diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index a69a33e27007..f30fe6e47096 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -253,8 +253,8 @@ def _set_device_type(self, module: Union[str, pathlib.Path, GraphExecutorFactory if not hasattr(module, "target"): self._requires_cpu_device = False else: - assert len(module.target.values()) == 1 - for target in module.target.values(): + assert len(module.target) == 1 + for target in module.target: target_type = str(target).split()[0] if target_type == "llvm": @@ -319,13 +319,13 @@ def _aot_executor_from_factory( hexagon_arch = set( target.mcpu.replace("hexagon", "") - for target in module.target.values() + for target in module.target if "hexagon" in target.keys ) self._set_device_type(module) - for target in module.target.values(): + for target in module.target: target_type = str(target).split()[0] assert hexagon_arch, "No hexagon target architecture found" diff --git a/python/tvm/contrib/peak.py b/python/tvm/contrib/peak.py index 4133aa31a50b..48d0d31a45b0 100644 --- a/python/tvm/contrib/peak.py +++ b/python/tvm/contrib/peak.py @@ -87,7 +87,7 @@ def measure_bandwidth_sum( GBPS: float gigabyte per second """ - target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) n, m = total_item, item_per_thread n //= lanes @@ -154,7 +154,7 @@ def measure_bandwidth_all_types( result: list a list of (type_name, GBPS) pairs """ - target, 
target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) max_threads = target.max_num_threads result = [] @@ -225,7 +225,7 @@ def measure_compute_mad( GOPS: float giga operation per second """ - target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) n = total_item @@ -318,7 +318,7 @@ def measure_compute_all_types( result: list a list of (type_name, GFLOPS/GIOPS) pairs """ - target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) result = [] for base_type in ["float", "int"]: @@ -364,7 +364,7 @@ def measure_peak_all(target, target_host, host, port): port: int """ - target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) remote = rpc.connect(host, port) n_times = 20 diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index faa246e34f0d..be31e43c96b6 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -17,8 +17,6 @@ # pylint: disable=invalid-name """The build utils in python.""" -import warnings - from typing import Union, Optional, List, Mapping import tvm.tir @@ -238,12 +236,6 @@ def build( f"but got {type(inputs)}." ) - if target_host is not None: - warnings.warn( - "target_host parameter is going to be deprecated. " - "Please pass in tvm.target.Target(target, host=target_host) instead." - ) - if not isinstance(inputs, (dict, container.Map)): target = Target.current() if target is None else target target = target if target else "llvm" @@ -261,11 +253,12 @@ def build( raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.") annotated_mods[tar] = mod.with_attr("runtime", runtime) - annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host) + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) + # TODO(mbs): CompilationConfig implements the same host target defaulting logic, but + # tir_to_runtime currently bypasses that. if not target_host: for tar, mod in annotated_mods.items(): - tar = Target(tar) device_type = ndarray.device(tar.kind.name, 0).device_type if device_type == ndarray.cpu(0).device_type: target_host = tar @@ -273,11 +266,11 @@ def build( if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host) + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host) - annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host) + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) if not isinstance(target_host, Target): target_host = Target(target_host) diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index c279b04f499d..f9ba427ffaa6 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -384,7 +384,7 @@ def tune_model( The path to the produced tuning log file. 
""" target, extra_targets = target_from_cli(target, additional_target_options) - target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) # TODO(jwfromm) Remove this deepcopy once AlterOpLayout bug that mutates source # model is fixed. For now, creating a clone avoids the issue. mod = deepcopy(tvmc_model.mod) @@ -524,7 +524,7 @@ def autotvm_get_tuning_tasks( tasks : list of autotvm.Tasks list of tasks to be tuned """ - target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) if alter_layout: mod = convert_graph_layout(mod, alter_layout) @@ -573,7 +573,7 @@ def autoscheduler_get_tuning_tasks( weights : List[int] the weight (i.e. the number of appearance) of extracted tasks """ - target, target_host = Target.check_and_update_host_consist(target, target_host) + target, target_host = Target.canon_target_and_host(target, target_host) if alter_layout: mod = convert_graph_layout(mod, alter_layout) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index a192b93d8cef..138504470459 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -278,7 +278,7 @@ def compile_model( mod = convert_graph_layout(mod, desired_layout) tvm_target, extra_targets = target_from_cli(target, additional_target_options) - tvm_target, target_host = Target.check_and_update_host_consist(tvm_target, target_host) + tvm_target, target_host = Target.canon_target_and_host(tvm_target, target_host) for codegen_from_cli in extra_targets: codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"]) diff --git a/python/tvm/exec/measure_peak.py b/python/tvm/exec/measure_peak.py index 6db61080eaf7..178e60089245 100644 --- a/python/tvm/exec/measure_peak.py +++ b/python/tvm/exec/measure_peak.py @@ -44,9 +44,7 @@ def main(): args = parser.parse_args() logging.basicConfig(level=logging.INFO) - args.target, args.target_host = Target.check_and_update_host_consist( - args.target, args.target_host - ) + args.target, args.target_host = Target.canon_target_and_host(args.target, args.target_host) measure_peak_all(args.target, args.target_host, args.rpc_host, args.rpc_port) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 6b95220b6794..1dd63b319dbd 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -31,7 +31,6 @@ from .._ffi import get_global_func from ..contrib import utils from ..driver import build_module -from ..runtime import ndarray as _nd from ..relay.backend import executor_factory from ..relay.backend.name_transforms import to_c_variable_style, prefix_generated_name from ..relay import param_dict @@ -313,7 +312,7 @@ def reset(tarinfo): tar_f.add(get_standalone_crt_dir(), arcname=STANDALONE_CRT_URL) -_GENERATED_VERSION = 5 +_GENERATED_VERSION = 6 def _export_graph_model_library_format( @@ -336,7 +335,7 @@ def _export_graph_model_library_format( "model_name": mod.libmod_name, "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), "memory": _build_memory_map(mod), - "target": {int(k): str(v) for k, v in mod.target.items()}, + "target": [str(t) for t in mod.target], "executors": executor, "style": "full-model", } @@ -423,7 +422,9 @@ def _eval_shape(param_name, buffer_shape): return shape memory_map = {} - for target_device_type, target in 
targets.items(): + for target in targets: + # TODO(mbs): The device type is not unique, better would be to use target.kind.name + target_device_type = target.kind.device_type ir_mod = ir_module_by_target[target] printer = get_global_func("tir.ModelLibraryFormatPrinter")(False, None, False) with open(src_dir / f"tir-{target_device_type}.txt", "w") as f: @@ -460,7 +461,7 @@ def _export_operator_model_library_format(mod: build_module.OperatorModule, temp file_name : str Path to the .tar archive to generate. """ - targets = {} + targets = [] for target in mod.ir_module_by_target.keys(): if str(target.kind) not in ("llvm", "c"): raise UnsupportedInModelLibraryFormatError( @@ -468,7 +469,7 @@ def _export_operator_model_library_format(mod: build_module.OperatorModule, temp "Model Library Format" ) - targets[int(_nd.device(str(target)).device_type)] = target + targets.append(target) src_dir = tempdir / "src" src_dir.mkdir() @@ -479,7 +480,7 @@ def _export_operator_model_library_format(mod: build_module.OperatorModule, temp "model_name": mod.name, "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), "memory": memory_map, - "target": {k: str(v) for k, v in targets.items()}, + "target": [str(t) for t in targets], "executors": [], "style": "operator", } diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 7378ed6beb8a..b377eefdb2c5 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -17,7 +17,6 @@ """The interface of expr function exposed from C++.""" import tvm._ffi import tvm.driver -from tvm.target import Target @tvm._ffi.register_func("relay.backend.build") @@ -41,8 +40,7 @@ def build(mod, target, target_host=None): The runtime module. """ target_host = None if target_host == "" else target_host - target, target_host = Target.check_and_update_host_consist(target, target_host) - return tvm.driver.build(mod, target=target) + return tvm.driver.build(mod, target=target, target_host=target_host) @tvm._ffi.register_func("relay._tensor_value_repr") diff --git a/python/tvm/relay/backend/graph_executor_codegen.py b/python/tvm/relay/backend/graph_executor_codegen.py index 531f9f69e0e0..aff41c76f89c 100644 --- a/python/tvm/relay/backend/graph_executor_codegen.py +++ b/python/tvm/relay/backend/graph_executor_codegen.py @@ -53,7 +53,7 @@ def __init__(self, mod, target): self._setup(mod, target) def _setup(self, mod, target): - raw_targets = Target.canonicalize_target_and_host(target) + raw_targets = Target.canon_multi_target_and_host(target) self._init(mod, raw_targets) def codegen(self, ir_module, func): diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 256293f6538b..d4a82cd8d427 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -39,12 +39,11 @@ def compile(mod, target=None, target_host=None, params=None): mod : tvm.IRModule The Relay module to build. - target : str, :any:`tvm.target.Target`, or dict of str(i.e. - device/context name) to str/tvm.target.Target, optional - For heterogeneous compilation, it is a dictionary indicating context - to target mapping. For homogeneous compilation, it is a build target. + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. 
- target_host : str or :any:`tvm.target.Target`, optional + target_host : None, or any target-like object, see Target.canon_target Host compilation target, if target is device. When TVM compiles device specific program such as CUDA, we also need host(CPU) side code to interact with the driver @@ -114,21 +113,14 @@ def lower(self, mod, target=None, target_host=None): mod : tvm.IRModule The Relay module to build. - target : str, :any:`tvm.target.Target`, or dict of str(i.e. - device/context name) to str/tvm.target.Target, optional - For heterogeneous compilation, it is a dictionary indicating context - to target mapping. For homogeneous compilation, it is a build target. + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. - target_host : str or :any:`tvm.target.Target`, optional + target_host : any target-like object, see Target.canon_target Host compilation target, if target is device. - When TVM compiles device specific program such as CUDA, - we also need host(CPU) side code to interact with the driver - to setup the dimensions and parameters correctly. - target_host is used to specify the host side codegen target. - By default, llvm is used if it is enabled, - otherwise a stackvm intepreter is used. """ - raw_targets = Target.canonicalize_target_and_host(target, target_host) + raw_targets = Target.canon_multi_target_and_host(target, target_host) tophub_context = self._tophub_context(raw_targets) with tophub_context: self._lower(mod, raw_targets) @@ -144,13 +136,12 @@ def optimize(self, mod, target=None, target_host=None, params=None): ---------- mod : tvm.IRModule - target : str, :any:`tvm.target.Target`, or dict of str (i.e. - device/context name) to str/tvm.target.Target, optional + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. - target_host : str or :any:`tvm.target.Target`, optional - The compilation target for host. - By default, llvm is used if it is enabled, - otherwise a stackvm intepreter is used. + target_host : any target-like object, see Target.canon_target + Host compilation target, if target is device. params : dict of str to NDArray Input parameters to the graph that do not change @@ -164,7 +155,7 @@ def optimize(self, mod, target=None, target_host=None, params=None): params : dict The parameters of the final module. """ - raw_targets = Target.canonicalize_target_and_host(target, target_host) + raw_targets = Target.canon_multi_target_and_host(target, target_host) if params: self.set_params(params) return self._optimize(mod, raw_targets), self.get_params() diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 06fa212ff396..9eeb20f5f1ce 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -24,7 +24,6 @@ from tvm.ir import IRModule from tvm.ir.transform import PassContext from tvm.target import Target -from tvm.tir import expr as tvm_expr from .. import autotvm from .. import nd as _nd @@ -46,44 +45,6 @@ from .transform import InferType -def build_target_by_device_type_map(target): - """Build a map from DLDevice device_type to a Target used with that device. - - At runtime, TVM assigns target code to DLDevices by determining a device_type for each Target. 
- This function handles this process at compile time and, as a side effect, validates that exactly - one target maps to one device_type. - - Parameters - ---------- - target : Target or str or dict - If a Target or str: assumes that exactly one device type is present in the model. - If a dict: keys are tvm.ndarray.device, values are the targets used for each device. - - Returns - ------- - - """ - target = target if target else Target.current() - if target is None: - raise ValueError("Target is not set in env or passed as argument.") - - tgts = {} - if isinstance(target, (str, Target)): - dev_type = tvm_expr.IntImm("int32", _nd.device(str(target)).device_type) - tgts[dev_type] = Target(target) - elif isinstance(target, dict): - for dev, tgt in target.items(): - dev_type = tvm_expr.IntImm("int32", _nd.device(dev).device_type) - tgts[dev_type] = Target(tgt) - else: - raise TypeError( - "target is expected to be str or " - + "tvm.target.Target, but received " - + "{}".format(type(target)) - ) - return tgts - - def _convert_param_map(params): inputs = {} for name, param in params.items(): @@ -128,12 +89,11 @@ def build( mod : :py:class:`~tvm.IRModule` The IRModule to build. - target : str, :any:`tvm.target.Target`, or dict of str(i.e. - device/context name) to str/tvm.target.Target, optional - For heterogeneous compilation, it is a dictionary indicating context - to target mapping. For homogeneous compilation, it is a build target. + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. - target_host : str or :any:`tvm.target.Target`, optional + target_host : None, or any target-like object, see Target.canon_target Host compilation target, if target is device. When TVM compiles device specific program such as CUDA, we also need host(CPU) side code to interact with the driver @@ -173,7 +133,7 @@ def build( params : dict The parameters of the final graph. """ - raw_targets = Target.canonicalize_target_and_host(target, target_host) + raw_targets = Target.canon_multi_target_and_host(target, target_host) # Setup the params. if params: @@ -208,10 +168,12 @@ def optimize(self, mod, target=None, target_host=None, params=None): mod : :py:class:`~tvm.IRModule` The IR module to build. - target : str, :any:`tvm.target.Target`, or dict of str(i.e. - device/context name) to str/tvm.target.Target, optional - For heterogeneous compilation, it is a dictionary indicating context - to target mapping. For homogeneous compilation, it is a build target. + target : any multi-target like object, see Target.canon_multi_target. + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. + + target_host : None, or any target-like object, see Target.canon_target + Host compilation target, if target is device. params : dict of str to NDArray Input parameters to the graph that do not change @@ -225,7 +187,7 @@ def optimize(self, mod, target=None, target_host=None, params=None): params : dict The parameters of the final graph. """ - raw_targets = Target.canonicalize_target_and_host(target, target_host) + raw_targets = Target.canon_multi_target_and_host(target, target_host) # Setup the params. 
if params: @@ -272,7 +234,7 @@ def get_params(self): return ret def get_irmodule(self): - """Returns the Target IRModule's post-lowering""" + """Returns the TargetIRModule's post-lowering""" return self._get_irmodule() @@ -283,8 +245,9 @@ def _module_export(module, file_name): # fcompile, addons, kwargs? @register_func("tvm.relay.build") def _build_module_no_factory_impl(mod, target, target_host, params, mod_name): - target, target_host = Target.check_and_update_host_consist(target, target_host) - return build(mod, target, params=params, mod_name=mod_name).module + return build( + mod, target=target, target_host=target_host, params=params, mod_name=mod_name + ).module def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"): @@ -377,18 +340,13 @@ def build( ir_mod : :py:class:`~tvm.IRModule` The IR module to build. Using relay.Function is deprecated. - target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional - For heterogeneous compilation, it is a dictionary indicating context to - target mapping. For homogeneous compilation, it is a build target. + target : None, or any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. + Defaults to the current target in the environment if None. - target_host : str or :any:`tvm.target.Target`, optional + target_host : None, or any target like object, see Target.canon_target Host compilation target, if target is device. - When TVM compiles device specific program such as CUDA, - we also need host(CPU) side code to interact with the driver - setup the dimensions and parameters correctly. - target_host is used to specify the host side codegen target. - By default, llvm is used if it is enabled, - otherwise a stackvm interpreter is used. executor : Optional[Executor] The executor configuration with which to build the model. @@ -431,25 +389,13 @@ def build( DeprecationWarning, ) - if target_host is not None: - warnings.warn( - "target_host parameter is going to be deprecated. " - "Please pass in tvm.target.Target(target, host=target_host) instead." 
- ) - - target, target_host = Target.check_and_update_host_consist( - target, target_host, target_is_dict_key=False - ) - - target = build_target_by_device_type_map(target) - if isinstance(target_host, (str, Target)): - target_host = Target(target_host) - elif target_host: - raise ValueError("target host must be the type of str, " + "tvm.target.Target, or None") + raw_targets = Target.canon_multi_target_and_host(Target.target_or_current(target), target_host) + assert len(raw_targets) > 0 + target_host = raw_targets[0].host # All of this logic is to raise deprecation warnings for various parameters # TODO(Mousius) Remove these after some time - deprecated_params_target = target_host or list(target.values())[0] + deprecated_params_target = target_host or list(raw_targets)[0] deprecated_executor, deprecated_runtime = _reconstruct_from_deprecated_options( deprecated_params_target ) @@ -461,7 +407,7 @@ def build( # If current dispatch context is fallback context (the default root context), # then load pre-tuned parameters from TopHub if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): - tophub_context = autotvm.tophub.context(list(target.values())) + tophub_context = autotvm.tophub.context(list(raw_targets)) else: tophub_context = autotvm.utils.EmptyContext() @@ -469,7 +415,7 @@ def build( bld_mod = BuildModule() graph_json, runtime_mod, params = bld_mod.build( mod=ir_mod, - target=target, + target=raw_targets, params=params, executor=executor, runtime=runtime, @@ -485,7 +431,7 @@ def build( executor_factory = _executor_factory.AOTExecutorFactoryModule( ir_mod, lowered_ir_mods, - target, + raw_targets, executor, runtime, runtime_mod, @@ -497,7 +443,14 @@ def build( ) elif str(executor) == "graph": executor_factory = _executor_factory.GraphExecutorFactoryModule( - ir_mod, target, executor, graph_json, runtime_mod, mod_name, params, func_metadata + ir_mod, + raw_targets, + executor, + graph_json, + runtime_mod, + mod_name, + params, + func_metadata, ) else: assert False, "Executor " + executor + " not supported" @@ -513,10 +466,10 @@ def optimize(mod, target=None, params=None): mod : :py:class:`~tvm.IRModule` The module to build. Using relay.Function is deprecated. - target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context - name) to str/tvm.target.Target, optional - For heterogeneous compilation, it is a dictionary indicating context to - target mapping. For homogeneous compilation, it is a build target. + target : None, or any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. + Defaults to the current target in the environment if None. 
params : dict of str to NDArray Input parameters to the graph that do not change @@ -543,18 +496,18 @@ def optimize(mod, target=None, params=None): DeprecationWarning, ) - target = build_target_by_device_type_map(target) + raw_targets = Target.canon_multi_target_and_host(Target.target_or_current(target)) # If current dispatch context is fallback context (the default root context), # then load pre-tuned parameters from TopHub if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): - tophub_context = autotvm.tophub.context(list(target.values())) + tophub_context = autotvm.tophub.context(raw_targets) else: tophub_context = autotvm.utils.EmptyContext() with tophub_context: bld_mod = BuildModule() - mod, params = bld_mod.optimize(mod, target=target, params=params) + mod, params = bld_mod.optimize(mod, target=raw_targets, params=params) return mod, params diff --git a/python/tvm/target/compilation_config.py b/python/tvm/target/compilation_config.py index 8a59a33c1a47..116f1dd8e99a 100644 --- a/python/tvm/target/compilation_config.py +++ b/python/tvm/target/compilation_config.py @@ -23,5 +23,5 @@ def make_compilation_config(ctxt, target, target_host=None): """Returns a CompilationConfig appropriate for target and target_host, using the same representation conventions as for the standard build interfaces. Intended only for unit testing.""" - raw_targets = tvm.target.Target.canonicalize_target_and_host(target, target_host) + raw_targets = tvm.target.Target.canon_multi_target_and_host(target, target_host) return _ffi_api.MakeCompilationConfig(ctxt, raw_targets) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 101980941fb0..34033b991ed1 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -218,13 +218,13 @@ def list_kinds(): return list(_ffi_api.ListTargetKinds()) @staticmethod - def canonicalize_target(target): + def canon_target(target): """Given a single target-like object, returns the TVM Target object representing it. Can convert from: - None (to None). - An existing TVM Target object. - - A string. - - A Python dictionary binding the target 'kind' and other attributes. + - A string, eg "cuda" or "cuda -arch=sm_80" + - A Python dictionary, eg {"kind": "cuda", "arch": "sm_80" } """ if target is None: return None @@ -233,86 +233,106 @@ def canonicalize_target(target): return Target(target) @staticmethod - def canonicalize_multi_targets(multi_targets): - """Given a single or collection of target-like objects, returns a TVM Array of Target - objects representing then. Can convert from: + def canon_target_and_host(target, target_host=None): + """Returns a TVM Target capturing target and target_host. Also returns the host in + canonical form. The given target can be in any form recognized by + Target.canon_target. If given, target_host can be in any form recognized by + Target.canon_target. If target_host is given it will be set as the 'host' in the + result Target object (and a warning given). + + Note that this method does not support heterogeneous compilation targets. + """ + target = Target.canon_target(target) + target_host = Target.canon_target(target_host) + if target is None: + assert target_host is None, "Target host is not empty when target is empty." + if target_host is not None: + warnings.warn( + "target_host parameter is going to be deprecated. " + "Please pass in tvm.target.Target(target, host=target_host) instead." 
+ ) + target = target.with_host(target_host) + if target is not None: + # In case the target already had a host, extract it here. + target_host = target.host + return target, target_host + + @staticmethod + def canon_multi_target(multi_targets): + """Given a single target-like object, or a collection-like object of target-like objects, + returns a TVM Array of TVM Target objects representing them. Can convert from: - None (to None). - - A single target-like object in a form recognized by canonicalize_target. + - A single target-like object in a form recognized by canon_target. - A Python list or TVM Array of target-like objects in a form recognized by - canonicalize_target. + canon_target. - A Python dict or TVM Map from TVM IntImm objects representing device types to - a target-like object in a form recognized by canonicalize_target. + a target-like object in a form recognized by canon_target. (This is a legacy + method to represent heterogeneous targets. The keys are ignored.) """ if multi_targets is None: return None if isinstance(multi_targets, (dict, Map)) and "kind" not in multi_targets: # Convert legacy heterogeneous map representation to ordinary list of targets. - return Target.canonicalize_multi_targets([t for _, t in multi_targets.items()]) + return Target.canon_multi_target(list(multi_targets.values())) if isinstance(multi_targets, (list, Array)): # Multiple Target results. - return convert([Target.canonicalize_target(t) for t in multi_targets]) + return convert([Target.canon_target(tgt) for tgt in multi_targets]) # Single Target result. - return convert([Target.canonicalize_target(multi_targets)]) + return convert([Target.canon_target(multi_targets)]) @staticmethod - def canonicalize_target_and_host(target, target_host=None): + def canon_multi_target_and_host(target, target_host=None): """Returns a TVM Array capturing target and target_host. The given target can be in - any form recognized by Target.canonicalize_target or Target.canonicalize_multi_targets. If - given target_host can be in any form recognized by Target.canonicalize_target. If - target_host is given it will be set as the 'host' in each result Target object (and a - warning given). + any form recognized by Target.canon_multi_target. If given, target_host can be in + any form recognized by Target.canon_target. If target_host is given it will be set + as the 'host' in each result Target object (and a warning given). """ # Convert target to Array, but not yet accounting for any host. - raw_targets = Target.canonicalize_multi_targets(target) + raw_targets = Target.canon_multi_target(target) assert raw_targets is not None # Convert host to Target, if given. - target_host = Target.canonicalize_target(target_host) - if target_host is None: - return raw_targets - warnings.warn( - "target_host parameter is going to be deprecated. " - "Please pass in tvm.target.Target(target, host=target_host) instead." - ) - # Make sure the (canonical) host is captured in all the (canonical) targets. - return convert([Target(t, target_host) for t in raw_targets]) + target_host = Target.canon_target(target_host) + if target_host is not None: + warnings.warn( + "target_host parameter is going to be deprecated. " + "Please pass in tvm.target.Target(target, host=target_host) instead." + ) + # Make sure the (canonical) host is captured in all the (canonical) targets.
+ raw_targets = convert([tgt.with_host(target_host) for tgt in raw_targets]) + return raw_targets @staticmethod - def check_and_update_host_consist(target, host=None, target_is_dict_key=True): - """A helper function that merges a legacy "target, target_host" pair, then returns - the merged target and its host field. The function is for legacy target and target - host pair only, and should not be used in the new target system. + def canon_target_map_and_host(target_map, target_host=None): + """Returns target_map as a map from TVM Target's in canonical form to IRModules. The keys + of the input target_map can be in any form recognized by Target.canon_target. + Similarly, if given, target_host can be in any form recognized by + Target.canon_target. The final target_map keys will capture the target_host in + canonical form. Also returns the target_host in canonical form.""" + if target_host is not None: + warnings.warn( + "target_host parameter is going to be deprecated. " + "Please pass in tvm.target.Target(target, host=target_host) instead." + ) + target_host = Target.canon_target(target_host) + new_target_map = {} + for tgt, mod in target_map.items(): + tgt = Target.canon_target(tgt) + assert tgt is not None + if target_host is not None: + tgt = tgt.with_host(target_host) + # In case the first target already has a host, extract it here. + target_host = tgt.host + new_target_map[tgt] = mod + return new_target_map, target_host - Parameters - ---------- - target : Union[str, Dict[str, Any], Target] - The target or heterogeneous target - host : Union[str, Dict[str, Any], Target, None] - The target host - target_is_dict_key : Bool - When the type of target is dict, whether Target is the key (Otherwise the value) - """ - if isinstance(target, (dict, str)): - target = convert(target) - if isinstance(host, (dict, str)): - host = convert(host) + @staticmethod + def target_or_current(target): + """Returns target, or the current target in the environment if target is None""" if target is None: - assert host is None, "Target host is not empty when target is empty." - return target, host - if isinstance(target, Map) and "kind" not in target: - new_target = {} - for tgt, mod in target.items(): - if not target_is_dict_key: - tgt, mod = mod, tgt - if isinstance(tgt, (Map, String, Target)): - tgt, host = Target.check_and_update_host_consist(tgt, host) - if not target_is_dict_key: - tgt, mod = mod, tgt - new_target[tgt] = mod - target = new_target - else: - target = Target(target, host) - host = target.host - return target, host + target = Target.current() + if target is None: + raise ValueError("Target is not set in env or passed as argument.") + return target # TODO(@tvm-team): Deprecate the helper functions below. Encourage the usage of config dict instead. 
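A minimal usage sketch of the renamed canonicalization helpers introduced in target.py above (illustrative only, not part of this patch; the behaviour shown follows the updated unit tests in tests/python/unittest/test_target_target.py further below):

from tvm.target import Target

# A legacy (target, target_host) pair is folded into one Target whose 'host' is set;
# both canonical forms are returned (a deprecation warning is raised when a host is passed).
target, host = Target.canon_target_and_host("cuda", "llvm")
assert target.kind.name == "cuda"
assert target.host.kind.name == "llvm"

# Heterogeneous forms (lists, or the legacy device-type -> target dict) canonicalize to an
# Array of Targets; the legacy dict keys are ignored.
raw_targets = Target.canon_multi_target_and_host({1: "llvm", 2: "cuda"})
assert [t.kind.name for t in raw_targets] == ["llvm", "cuda"]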
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 71b57aed81f6..5a0502d17548 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -926,17 +926,17 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, const Compila } for (const auto& dev_and_size : device_workspace) { - Target target = config->FindPrimitiveTargetOrFail(dev_and_size.first); + Target target = config->FindPrimitiveTargetForDeviceOrFail(dev_and_size.first); workspace_sizes.Set(target, dev_and_size.second); relay_primfuncs.Set(target, func); } for (const auto& dev_and_size : device_io) { - Target target = config->FindPrimitiveTargetOrFail(dev_and_size.first); + Target target = config->FindPrimitiveTargetForDeviceOrFail(dev_and_size.first); io_sizes.Set(target, dev_and_size.second); } for (const auto& dev_and_size : device_consts) { - Target target = config->FindPrimitiveTargetOrFail(dev_and_size.first); + Target target = config->FindPrimitiveTargetForDeviceOrFail(dev_and_size.first); ICHECK_EQ(constant_sizes.count(target), 0); constant_sizes.Set(target, dev_and_size.second); } diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc index 7260427bc1a1..cb50615ce6a5 100644 --- a/src/target/compilation_config.cc +++ b/src/target/compilation_config.cc @@ -38,15 +38,15 @@ void CompilationConfigNode::VisitAttrs(AttrVisitor* v) { // NOTE: The virtual_device_cache_ is not accessible via FFI. } -Target CompilationConfigNode::FindPrimitiveTargetOrFail(DLDeviceType device_type) const { +Target CompilationConfigNode::FindPrimitiveTargetForDeviceOrFail(DLDeviceType device_type) const { ICHECK_GT(device_type, 0) << "Invalid device type"; auto itr = std::find_if( primitive_targets.begin(), primitive_targets.end(), [device_type](const Target& target) { return target->kind->device_type == device_type; }); if (itr == primitive_targets.end()) { std::stringstream msg; - msg << "No target is specified for device '" << runtime::DeviceName(device_type) - << "' mapped to device type " << device_type << ". The available targets are:" << std::endl; + msg << "No target is specified for device type " << device_type + << ". The available device types and targets are:" << std::endl; for (const auto& target : primitive_targets) { msg << " " << target->kind->device_type << "-> " << target->ToDebugString() << std::endl; } @@ -55,6 +55,23 @@ return *itr; } +Optional<Target> CompilationConfigNode::FindPrimitiveTargetForKind( + const std::string& kind_name) const { + Optional<TargetKind> opt_kind = TargetKind::Get(kind_name); + if (!opt_kind.defined()) { + VLOG(1) << "No such target kind for '" << kind_name << "'"; + return {}; + } + auto itr = + std::find_if(primitive_targets.begin(), primitive_targets.end(), + [kind_name](const Target& target) { return target->kind->name == kind_name; }); + if (itr == primitive_targets.end()) { + VLOG(1) << "No target available matching kind '" << kind_name << "'"; + return {}; + } + return *itr; +} + VirtualDevice CompilationConfigNode::CanonicalVirtualDevice( const VirtualDevice& virtual_device) const { if (virtual_device->target.defined()) { @@ -64,7 +81,7 @@ VirtualDevice CompilationConfigNode::CanonicalVirtualDevice( // TODO(mbs): Proper diagnostics.
CHECK(device_type != kInvalidDeviceType) << "VirtualDevice annotations must include at least a device_type"; - Target target = FindPrimitiveTargetOrFail(virtual_device->device_type()); + Target target = FindPrimitiveTargetForDeviceOrFail(virtual_device->device_type()); return virtual_device_cache_.Unique(VirtualDevice(device_type, virtual_device->virtual_device_id, target, virtual_device->memory_scope)); } @@ -140,15 +157,20 @@ void CompilationConfigNode::Init(const transform::PassContext& pass_ctx, ICHECK_GT(primitive_targets.size(), 0U); // - // Check the primitive_targets are ordered correctly re Target::IsExternalCodegenFor. + // Check the primitive_targets are ordered correctly re Target::IsExternalCodegenFor, + // and make sure no two targets share a kind name. // // TODO(mbs): We could just sort the list, but given all the implicit defaulting for backwards // compat it seems we should avoid making this any more magical than necessary. But revisit // if usability suffers. std::unordered_set<DLDeviceType> primitive_target_device_types; + std::unordered_set<std::string> kind_names; for (const auto& target : primitive_targets) { primitive_target_device_types.emplace(static_cast<DLDeviceType>(target->kind->device_type)); + CHECK(kind_names.emplace(target->kind->name).second) << "Multiple targets have been given " "for the same kind name '" + << target->kind->name << "'"; } for (DLDeviceType device_type : primitive_target_device_types) { Target first_primitive_target; @@ -158,10 +180,7 @@ } if (!first_primitive_target.defined()) { first_primitive_target = current_primitive_target; - CHECK(!first_primitive_target.IsExternalCodegen()) - << "The first given target for device type " << device_type - << " must not be for an external codegen, however given " - << first_primitive_target->ToDebugString(); + // Note it is valid to have only one external codegen target.
} else { CHECK(current_primitive_target.IsExternalCodegenFor(first_primitive_target)) << "When given multiple targets for the device type " << device_type @@ -205,7 +224,7 @@ void CompilationConfigNode::Init(const transform::PassContext& pass_ctx, // default_primitive_virtual_device = virtual_device_cache_.Unique(VirtualDevice( default_primitive_device_type, - /*virtual_device_id=*/0, FindPrimitiveTargetOrFail(default_primitive_device_type))); + /*virtual_device_id=*/0, FindPrimitiveTargetForDeviceOrFail(default_primitive_device_type))); ICHECK(default_primitive_virtual_device.defined()); ICHECK(default_primitive_virtual_device->target.defined()); diff --git a/tests/cpp/target/compilation_config_test.cc b/tests/cpp/target/compilation_config_test.cc index 4568d11d6232..825cb5baeb8c 100644 --- a/tests/cpp/target/compilation_config_test.cc +++ b/tests/cpp/target/compilation_config_test.cc @@ -196,6 +196,27 @@ TEST(CompilationConfig, Constructor_Heterogeneous_InvalidOrdering) { CompilationConfig(pass_ctx, {ext_codegen1_target, cuda_target, ext_codegen2_target})); } +TEST(CompilationConfig, Constructor_Homogeneous_JustExternalCodegen) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + + Target host_target = TestDefaultCpuTarget(); + Target ext_codegen1_target = Target::WithHost(TestExtCodegenTarget1(), host_target); + + CompilationConfig config(pass_ctx, {ext_codegen1_target}); + ASSERT_EQ(config->primitive_targets.size(), 1); + EXPECT_TRUE(StructuralEqual()(config->primitive_targets[0], ext_codegen1_target)); +} + +TEST(CompilationConfig, Constructor_DuplicateKinds) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + + Target host_target = TestDefaultCpuTarget(); + Target cuda_target_1 = Target::WithHost(TestCudaTarget(), host_target); + Target cuda_target_2 = Target::WithHost(TestCudaTarget(), host_target); + + EXPECT_ANY_THROW(CompilationConfig(pass_ctx, {cuda_target_1, cuda_target_2})); +} + TEST(CompilationConfig, Constructor_NoTargets) { transform::PassContext pass_ctx = transform::PassContext::Create(); EXPECT_ANY_THROW(CompilationConfig(pass_ctx, {})); @@ -243,15 +264,26 @@ TEST(CompilationConfig, Constructor_Idempotent) { reconstructed_config->primitive_targets[1])); } -TEST(CompilationConfig, FindPrimitiveTargetOrFail_Valid) { +TEST(CompilationConfig, FindPrimitiveTargetForDeviceOrFail_Valid) { CompilationConfig config = TestCompilationConfig(); Target cpu_target = Target::WithHost(TestCpuTarget(), TestDefaultCpuTarget()); - ASSERT_TRUE(StructuralEqual()(config->FindPrimitiveTargetOrFail(kDLCPU), cpu_target)); + ASSERT_TRUE(StructuralEqual()(config->FindPrimitiveTargetForDeviceOrFail(kDLCPU), cpu_target)); +} + +TEST(CompilationConfig, FindPrimitiveTargetForDeviceOrFail_Invalid) { + CompilationConfig config = TestCompilationConfig(); + EXPECT_ANY_THROW(config->FindPrimitiveTargetForDeviceOrFail(kDLMetal)); +} + +TEST(CompilationConfig, FindPrimitiveTargetForKind_Found) { + CompilationConfig config = TestCompilationConfig(); + Target cuda_target = Target::WithHost(TestCudaTarget(), TestDefaultCpuTarget()); + ASSERT_TRUE(StructuralEqual()(config->FindPrimitiveTargetForKind("cuda").value(), cuda_target)); } -TEST(CompilationConfig, FindPrimitiveTargetOrFail_Invalid) { +TEST(CompilationConfig, FindPrimitiveTargetForKind_NotFound) { CompilationConfig config = TestCompilationConfig(); - EXPECT_ANY_THROW(config->FindPrimitiveTargetOrFail(kDLMetal)); + ASSERT_FALSE(config->FindPrimitiveTargetForKind("cutlass").defined()); } TEST(CompilationConfig,
CanonicalVirtualDevice) { diff --git a/tests/python/contrib/test_ethosn/infrastructure.py b/tests/python/contrib/test_ethosn/infrastructure.py index 49aa064edcb0..a4c20908151b 100644 --- a/tests/python/contrib/test_ethosn/infrastructure.py +++ b/tests/python/contrib/test_ethosn/infrastructure.py @@ -270,7 +270,7 @@ def test_error(mod, params, err_msg): with tvm.target.Target("llvm"): try: mod = relay.transform.InferType()(mod) - relay.build(mod, params) + relay.build(mod, params=params) except tvm.error.TVMError as e: caught = e.args[0] finally: diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 7ea813762796..1e8d307b33ea 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -1003,4 +1003,7 @@ def fully_connected(x): if __name__ == "__main__": - pytest.main([__file__]) + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py index 8933fc4267cd..5ca0e1ae67e7 100644 --- a/tests/python/integration/test_tuning.py +++ b/tests/python/integration/test_tuning.py @@ -129,8 +129,8 @@ def teardown_module(): def get_sample_task(target=tvm.target.cuda(), target_host=None): - target, target_host = Target.check_and_update_host_consist(target, target_host) """return a sample task for testing""" + target, target_host = Target.canon_target_and_host(target, target_host) task = autotvm.task.create( "testing/conv2d_no_batching", args=(1, 7, 7, 512, 512, 3, 3), target=target ) diff --git a/tests/python/relay/test_build_module.py b/tests/python/relay/test_build_module.py index b03e760a968a..b88115059eaf 100644 --- a/tests/python/relay/test_build_module.py +++ b/tests/python/relay/test_build_module.py @@ -64,7 +64,7 @@ def test_build_relay_graph_(): """Test to build a simple relay graph by using APIs directly""" def build_graph(mod, target): - target, target_host = tvm.target.Target.check_and_update_host_consist(target) + target, target_host = tvm.target.Target.canon_target_and_host(target) mod, _ = relay.optimize(mod, target) grc = graph_executor_codegen.GraphExecutorCodegen(None, target) _, lowered_funcs, _ = grc.codegen(mod, mod["main"]) diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 76e95c960482..87b8bcb2b99a 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -62,13 +62,13 @@ def test_export_operator_model_library_format(): with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 5 + assert metadata["version"] == 6 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) - assert metadata["target"] == {"1": str(target)} + assert metadata["target"] == [str(target)] assert metadata["memory"]["add"][0]["dtype"] == "int8" assert metadata["memory"]["add"][0]["shape"] == [2] @@ -156,13 +156,13 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 5 + assert metadata["version"] == 6 assert metadata["model_name"] 
== "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) - assert metadata["target"] == {"1": str(target)} + assert metadata["target"] == [str(target)] if str(executor) == "graph": assert metadata["memory"]["sids"] == [ {"storage_id": 0, "size_bytes": 2, "input_binding": "a"}, @@ -242,13 +242,13 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 5 + assert metadata["version"] == 6 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) - assert metadata["target"] == {"1": str(target)} + assert metadata["target"] == [str(target)] assert metadata["memory"]["sids"] == [ {"storage_id": 0, "size_bytes": 2, "input_binding": "a"}, {"storage_id": 1, "size_bytes": 8, "input_binding": "b"}, @@ -324,13 +324,13 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 5 + assert metadata["version"] == 6 assert metadata["model_name"] == "qnn_conv2d" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" ) assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) - assert metadata["target"] == {"1": str(target)} + assert metadata["target"] == [str(target)] assert metadata["memory"]["functions"]["main"] == [ { "constants_size_bytes": 0, diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index 9f5f62b8b991..d58c20d063e1 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -353,96 +353,111 @@ def test_target_with_host(): assert tgt.host.attrs["registers_per_block"] == 32768 -def test_check_and_update_host_consist_0(): +def test_canon_target_and_host_0(): target = None host = None - target, host = Target.check_and_update_host_consist(target, host) + target, host = Target.canon_target_and_host(target, host) + assert target is None + assert host is None -def test_check_and_update_host_consist_1(): +def test_canon_target_and_host_1(): target = None host = "llvm" with pytest.raises(AssertionError, match=r"Target host is not empty when target is empty."): - target, host = Target.check_and_update_host_consist(target, host) + target, host = Target.canon_target_and_host(target, host) -def test_check_and_update_host_consist_2(): +def test_canon_target_and_host_2(): target = Target("cuda") host = Target("llvm") - target, host = Target.check_and_update_host_consist(target, host) + target, host = Target.canon_target_and_host(target, host) assert target.kind.name == "cuda" assert target.host.kind.name == "llvm" -def test_check_and_update_host_consist_3(): +def test_canon_target_and_host_3(): target = Target(target="cuda", host="llvm") host = None - target, host = Target.check_and_update_host_consist(target, host) + target, host = Target.canon_target_and_host(target, host) assert target.kind.name == "cuda" assert target.host.kind.name == "llvm" assert host.kind.name == "llvm" assert target.host == host -def 
-    """Test `check_and_update_host_consist` by using TVM Objects"""
-    cuda_device_type = tvm.device("cuda").device_type
-    target = {cuda_device_type: Target(target="cuda", host="llvm")}
-    host = None
-    target_1, host_1 = Target.check_and_update_host_consist(target, host)
-    assert isinstance(target_1, dict)
-    assert target_1[cuda_device_type].kind.name == "cuda"
-    assert target_1[cuda_device_type].host.kind.name == "llvm"
-    assert host_1 is None
-
-    target = {cuda_device_type: Target(tvm.runtime.container.String("cuda"))}
-    host = Target(tvm.runtime.container.String("llvm"))
-    target = tvm.runtime.convert(target)
-    assert isinstance(target, tvm.ir.container.Map)
-    target_2, host_2 = Target.check_and_update_host_consist(target, host)
-    assert isinstance(target_2, dict)
-    assert target_2[cuda_device_type].kind.name == "cuda"
-    assert host_2.kind.name == "llvm"
-
-
-def test_canonicalize_target_and_host_0():
+def test_canon_multi_target_and_host_0():
     with pytest.raises(AssertionError):
-        Target.canonicalize_target_and_host(None)
+        Target.canon_multi_target_and_host(None)


-def test_canonicalize_target_and_host_1():
-    raw_targets = Target.canonicalize_target_and_host({"kind": "llvm"})
+def test_canon_multi_target_and_host_1():
+    raw_targets = Target.canon_multi_target_and_host({"kind": "llvm"})
     assert len(raw_targets) == 1
     assert raw_targets[0].kind.name == "llvm"


-def test_canonicalize_target_and_host_2():
-    raw_targets = Target.canonicalize_target_and_host({1: "llvm", 2: "cuda"})
+def test_canon_multi_target_and_host_2():
+    raw_targets = Target.canon_multi_target_and_host({1: "llvm", 2: "cuda"})
     assert len(raw_targets) == 2
     assert raw_targets[0].kind.name == "llvm"
     assert raw_targets[1].kind.name == "cuda"


-def test_canonicalize_target_and_host_3():
-    raw_targets = Target.canonicalize_target_and_host(["llvm", "cuda"])
+def test_canon_multi_target_and_host_3():
+    raw_targets = Target.canon_multi_target_and_host(["llvm", "cuda"])
     assert len(raw_targets) == 2
     assert raw_targets[0].kind.name == "llvm"
     assert raw_targets[1].kind.name == "cuda"


-def test_canonicalize_target_and_host_4():
-    raw_targets = Target.canonicalize_target_and_host("llvm")
+def test_canon_multi_target_and_host_4():
+    raw_targets = Target.canon_multi_target_and_host("llvm")
     assert len(raw_targets) == 1
     assert raw_targets[0].kind.name == "llvm"


-def test_canonicalize_target_and_host_5():
-    raw_targets = Target.canonicalize_target_and_host("cuda", "llvm")
+def test_canon_multi_target_and_host_5():
+    raw_targets = Target.canon_multi_target_and_host("cuda", "llvm")
     assert len(raw_targets) == 1
     assert raw_targets[0].kind.name == "cuda"
     assert raw_targets[0].host.kind.name == "llvm"


+def test_canon_multi_target_and_host_6():
+    """Test `canon_multi_target_and_host` by using TVM Objects"""
+    cuda_device_type = tvm.device("cuda").device_type
+    target = {cuda_device_type: Target(target="cuda", host="llvm")}
+    host = None
+    raw_targets_1 = Target.canon_multi_target_and_host(target, host)
+    assert len(raw_targets_1) == 1
+    assert raw_targets_1[0].kind.name == "cuda"
+    assert raw_targets_1[0].host.kind.name == "llvm"
+
+    target = {cuda_device_type: Target(tvm.runtime.container.String("cuda"))}
+    host = Target(tvm.runtime.container.String("llvm"))
+    target = tvm.runtime.convert(target)
+    assert isinstance(target, tvm.ir.container.Map)
+    raw_targets_2 = Target.canon_multi_target_and_host(target, host)
+    assert len(raw_targets_2) == 1
+    assert raw_targets_2[0].kind.name == "cuda"
+    assert raw_targets_2[0].host.kind.name == "llvm"
+
+
+def test_canon_target_map_and_host():
+    target_map = {"cuda": "cuda_module", "llvm": "cpu_module"}
+    target_map, host = Target.canon_target_map_and_host(target_map, "llvm")
+    assert host.kind.name == "llvm"
+    for t, v in target_map.items():
+        assert t.host.kind.name == "llvm"
+        if t.kind.name == "cuda":
+            assert v == "cuda_module"
+        elif t.kind.name == "llvm":
+            assert v == "cpu_module"
+        else:
+            assert False
+
+
 def test_target_attr_bool_value():
     target0 = Target("vulkan --supports_float16=True")
     assert target0.attrs["supports_float16"] == 1