diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index be31e43c96b6..24e80686850d 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -19,6 +19,8 @@ """The build utils in python.""" from typing import Union, Optional, List, Mapping +import warnings + import tvm.tir from tvm.runtime import Module @@ -255,8 +257,13 @@ def build( annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - # TODO(mbs): CompilationConfig implements the same host target defaulting logic, but - # tir_to_runtime currently bypasses that. + # TODO(mbs): Both CompilationConfig and TIRToRuntime implement the same host target + # defaulting logic, but there's currently no way to get back the decided host. + if target_host is not None: + warnings.warn( + "target_host parameter is going to be deprecated. " + "Please pass in tvm.target.Target(target, host=target_host) instead." + ) if not target_host: for tar, mod in annotated_mods.items(): device_type = ndarray.device(tar.kind.name, 0).device_type diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index ec42984a448d..bcb284839d19 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -247,14 +247,15 @@ def canon_target_and_host(target, target_host=None): Note that this method does not support heterogeneous compilation targets. """ target = Target.canon_target(target) - target_host = Target.canon_target(target_host) if target is None: assert target_host is None, "Target host is not empty when target is empty." - if target_host is not None: + return target, target_host + if target.host is None and target_host is not None: warnings.warn( "target_host parameter is going to be deprecated. " "Please pass in tvm.target.Target(target, host=target_host) instead." ) + target_host = Target.canon_target(target_host) target = target.with_host(target_host) if target is not None: # In case the target already had a host, extract it here. @@ -293,15 +294,15 @@ def canon_multi_target_and_host(target, target_host=None): """ # Convert target to Array, but not yet accounting for any host. raw_targets = Target.canon_multi_target(target) - assert raw_targets is not None + assert raw_targets is not None and len(raw_targets) > 0 # Convert host to Target, if given. - target_host = Target.canon_target(target_host) - if target_host is not None: + if raw_targets[0].host is None and target_host is not None: warnings.warn( "target_host parameter is going to be deprecated. " "Please pass in tvm.target.Target(target, host=target_host) instead." ) # Make sure the (canonical) host is captured in all the (canonical) targets. + target_host = Target.canon_target(target_host) raw_targets = convert([tgt.with_host(target_host) for tgt in raw_targets]) return raw_targets @@ -312,22 +313,22 @@ def canon_target_map_and_host(target_map, target_host=None): Similarly, if given, target_host can be in any form recognized by Target.canon_target. The final target_map keys will capture the target_host in canonical form. Also returns the target_host in canonical form.""" - if target_host is not None: - warnings.warn( - "target_host parameter is going to be deprecated. " - "Please pass in tvm.target.Target(target, host=target_host) instead." - ) - target_host = Target.canon_target(target_host) new_target_map = {} + canonical_target_host = None for tgt, mod in target_map.items(): tgt = Target.canon_target(tgt) assert tgt is not None - if target_host is not None: - tgt = tgt.with_host(target_host) - # In case the first target already has a host, extract it here. - target_host = tgt.host + if canonical_target_host is None: + if tgt.host is not None: + canonical_target_host = tgt.host + elif target_host is not None: + # No deprecation warning in this case since host may have been manufactured + # behind the scenes in build_module.py build. + canonical_target_host = Target.canon_target(target_host) + if tgt.host is None and canonical_target_host is not None: + tgt = tgt.with_host(canonical_target_host) new_target_map[tgt] = mod - return new_target_map, target_host + return new_target_map, canonical_target_host @staticmethod def target_or_current(target):