diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 4f5b5d146d92..8e25574b7652 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -48,6 +48,18 @@ using PassContext = tvm::transform::PassContext; using PassContextNode = tvm::transform::PassContextNode; using Sequential = tvm::transform::Sequential; +/*! + * \brief RelayToTIR tvm::transform::Pass specific to a TargetKind + * + * Called before the default lowering passes. + * + * \param mod The module that an optimization pass runs on. + * \param pass_ctx The pass context that can provide information for the optimization. + * + * \return The transformed module. + */ +using FTVMRelayToTIR = tvm::transform::Pass; + /* * \brief Create a function pass. * diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 56d6a596b9b2..d47ac94e067e 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -25,7 +25,6 @@ #define TVM_TARGET_TARGET_H_ #include -#include #include #include #include @@ -284,14 +283,5 @@ class Target : public ObjectRef { */ void CheckAndUpdateHostConsistency(Target* target, Target* host); -/*! - * \brief Check and update host field of the given legacy heterogeneous targets and - * target host.Note that this function is for legacy target api compatibility issue only, - * not recommended for other use. - * \param ir_modules The pointer to a Map objects with keys being Target objects - * \param host The Target typed object for target host to be updated - */ -void CheckAndUpdateHostConsistency(Map* ir_modules, Target* host); - } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 19bcce3116b2..86f386abb827 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -24,7 +24,6 @@ #ifndef TVM_TARGET_TARGET_KIND_H_ #define TVM_TARGET_TARGET_KIND_H_ -#include #include #include @@ -50,31 +49,7 @@ using TargetFeatures = Map; * \return The transformed Target JSON object. */ using TargetJSON = Map; -using FTVMTargetParser = TypedPackedFunc; - -/*! - * \brief RelayToTIR tvm::transform::Pass specific to a TargetKind - * - * Called before the default lowering passes. - * - * \param mod The module that an optimization pass runs on. - * \param pass_ctx The pass context that can provide information for the optimization. - * - * \return The transformed module. - */ -using FTVMRelayToTIR = transform::Pass; - -/*! - * \brief TIRToRuntime conversion specific to a TargetKind - * - * This function is responsible for scanning an IRModule for appropriate Target-specific functions - and generating a Runtime module representing the compiled output - * - * \param ir_module Unified IRModule - * \param target Target to filter on or retrieve arguments from - * \return Runtime Module containing compiled functions - */ -using FTVMTIRToRuntime = runtime::TypedPackedFunc; +using FTVMTargetParser = runtime::TypedPackedFunc; namespace detail { template diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h index f49e9ceef794..a67350a2bb13 100644 --- a/include/tvm/tir/usmp/utils.h +++ b/include/tvm/tir/usmp/utils.h @@ -27,6 +27,7 @@ #include #include +#include #include #include #include diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d46fab716814..b7ba0ffe4468 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -431,6 +431,23 @@ std::pair SplitMixedModule(IRModule mod_mixed, const Target& return {host_mod, device_mod}; } +/*! + * \brief Check and update host field of the given legacy heterogeneous targets and + * target host.Note that this function is for legacy target api compatibility issue only, + * not recommended for other use. + * \param ir_modules The pointer to a Map objects with keys being Target objects + * \param host The Target typed object for target host to be updated + */ +void CheckAndUpdateHostConsistency(Map* targets, Target* host) { + Map new_targets; + for (auto& it : *targets) { + auto target = it.first; + CheckAndUpdateHostConsistency(&target, host); + new_targets.Set(target, it.second); + } + *targets = new_targets; +} + runtime::Module TIRToRuntime(const Map& inputs_arg, const Target& target_host_arg) { std::vector device_modules; diff --git a/src/relay/backend/contrib/cmsisnn/target.cc b/src/relay/backend/contrib/cmsisnn/target.cc index f14c106703b3..10125bf814ad 100644 --- a/src/relay/backend/contrib/cmsisnn/target.cc +++ b/src/relay/backend/contrib/cmsisnn/target.cc @@ -32,12 +32,13 @@ namespace cmsisnn { tvm::transform::Pass RelayToTIR(); runtime::Module TIRToRuntime(IRModule mod, Target target); +using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc; TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU) .add_attr_option>("mattr") .add_attr_option("mcpu") .add_attr_option("debug_last_error") - .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) + .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) .set_attr("TIRToRuntime", TIRToRuntime) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); diff --git a/src/relay/backend/contrib/codegen_c/target.cc b/src/relay/backend/contrib/codegen_c/target.cc index 623057ac1762..cd1e0283df28 100644 --- a/src/relay/backend/contrib/codegen_c/target.cc +++ b/src/relay/backend/contrib/codegen_c/target.cc @@ -34,7 +34,7 @@ namespace contrib { */ TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU) .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) - .set_attr(tvm::attr::kRelayToTIR, CCompilerPass()) + .set_attr(tvm::attr::kRelayToTIR, CCompilerPass()) // Value is prepended to every output CModule. .add_attr_option("header", String("")); diff --git a/src/relay/backend/contrib/cutlass/target.cc b/src/relay/backend/contrib/cutlass/target.cc index 7b377f340a57..50c8b84a9069 100644 --- a/src/relay/backend/contrib/cutlass/target.cc +++ b/src/relay/backend/contrib/cutlass/target.cc @@ -40,7 +40,7 @@ namespace cutlass { */ TVM_REGISTER_TARGET_KIND("cutlass", kDLCUDA) .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) - .set_attr("RelayToTIR", CompileForCutlass()) + .set_attr("RelayToTIR", CompileForCutlass()) // An integer specifying the compute capability. For example, 75 for Turing and // 80 or 86 for Ampere. .add_attr_option("sm", Integer(80)) diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index f35d4c6d48b2..54d0595c4634 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -47,6 +47,8 @@ namespace relay { namespace contrib { namespace ethosu { +using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc; + /*! * \brief This mutator outlines functions that are marked with a named * "Compiler" attribute. Functions that do not match this condition remain @@ -320,7 +322,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU) .set_attr("use_device_api", Bool(true)) - .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) + .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) .set_attr("TIRToRuntime", TIRToRuntime); } // namespace ethosu diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc index b01c23ed806a..b45987f6be33 100644 --- a/src/relay/backend/contrib/example_target_hooks/target.cc +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -22,6 +22,8 @@ namespace tvm { +using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc; + namespace relay { namespace contrib { namespace example_target_hooks { @@ -33,7 +35,8 @@ runtime::Module TIRToRuntime(IRModule mod, Target target); TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) .set_attr("use_device_api", Bool(true)) - .set_attr(attr::kRelayToTIR, relay::contrib::example_target_hooks::RelayToTIR()) + .set_attr(attr::kRelayToTIR, + relay::contrib::example_target_hooks::RelayToTIR()) .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime) .add_attr_option("example_attribute", Integer(0)); diff --git a/src/relay/backend/contrib/tensorrt/target.cc b/src/relay/backend/contrib/tensorrt/target.cc index 2e4581d30a3c..0277787a8c12 100644 --- a/src/relay/backend/contrib/tensorrt/target.cc +++ b/src/relay/backend/contrib/tensorrt/target.cc @@ -39,7 +39,7 @@ namespace tensorrt { */ TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA) .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) - .set_attr("RelayToTIR", CompileForTensorRT()) + .set_attr("RelayToTIR", CompileForTensorRT()) // A array of three integers given the major, minor, and patch numbers for the supported // TensorRT compiler version. If empty will be auto-detected from linked library. Default empty. .add_attr_option>("tensorrt_version", Array()) diff --git a/src/relay/backend/contrib/uma/targets.cc b/src/relay/backend/contrib/uma/targets.cc index e2fe644cb9bf..244f243749c1 100644 --- a/src/relay/backend/contrib/uma/targets.cc +++ b/src/relay/backend/contrib/uma/targets.cc @@ -28,6 +28,8 @@ namespace tvm { +using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc; + namespace relay { namespace contrib { namespace uma { @@ -57,8 +59,8 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget") .add_attr_option>("libs") .add_attr_option("host") .add_attr_option("from_device") - .set_attr(attr::kRelayToTIR, - relay::contrib::uma::RelayToTIR(target_name)) + .set_attr( + attr::kRelayToTIR, relay::contrib::uma::RelayToTIR(target_name)) .set_attr("TIRToRuntime", relay::contrib::uma::TIRToRuntime); // target kind attrs inventory diff --git a/src/target/codegen.cc b/src/target/codegen.cc index bbb2c15a647f..6e31db4f608f 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -38,8 +38,21 @@ #include namespace tvm { + namespace codegen { +/*! + * \brief TIRToRuntime conversion specific to a TargetKind + * + * This function is responsible for scanning an IRModule for appropriate Target-specific functions + and generating a Runtime module representing the compiled output + * + * \param ir_module Unified IRModule + * \param target Target to filter on or retrieve arguments from + * \return Runtime Module containing compiled functions + */ +using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc; + runtime::Module Build(IRModule mod, Target target) { if (transform::PassContext::Current() ->GetConfig("tir.disable_assert", Bool(false)) diff --git a/src/target/target.cc b/src/target/target.cc index 2f585188d0d3..cd2e3714e422 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -21,6 +21,7 @@ * \file src/target/target.cc */ #include +#include #include #include #include @@ -91,16 +92,6 @@ void CheckAndUpdateHostConsistency(Target* target, Target* host) { *host = (*target)->GetHost().value_or(Target()); } -void CheckAndUpdateHostConsistency(Map* targets, Target* host) { - Map new_targets; - for (auto& it : *targets) { - auto target = it.first; - CheckAndUpdateHostConsistency(&target, host); - new_targets.Set(target, it.second); - } - *targets = new_targets; -} - static std::vector DeduplicateKeys(const std::vector& keys) { std::vector new_keys; for (size_t i = 0; i < keys.size(); ++i) { @@ -614,8 +605,8 @@ Target::Target(TargetKind kind, Optional host, String tag, Array is_external_codegen_map = TargetKind::GetAttrMap(tvm::attr::kIsExternalCodegen); - TargetKindAttrMap relay_to_tir_map = - TargetKind::GetAttrMap(tvm::attr::kRelayToTIR); + TargetKindAttrMap relay_to_tir_map = + TargetKind::GetAttrMap(tvm::attr::kRelayToTIR); return is_external_codegen_map.get(get()->kind, Bool(false)) || relay_to_tir_map.count(get()->kind); } diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 37a8eeb44840..50a6f2f2ac16 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -458,7 +458,8 @@ TVM_REGISTER_TARGET_KIND("test_external_codegen_2", kDLMetal) .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_3", kDLCPU) - .set_attr(tvm::attr::kRelayToTIR, tvm::relay::transform::InferType()); + .set_attr(tvm::attr::kRelayToTIR, + tvm::relay::transform::InferType()); TEST(Target, ExternalCodegen) { Target regular("cuda");