diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index a9d893ff5402..fca2839cb363 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -55,6 +55,9 @@ class TargetNode : public Object { Array keys; /*! \brief Collection of attributes */ Map attrs; + /*! \brief Target features */ + Map features; + /*! * \brief The raw string representation of the target * \return the full device string to pass to codegen::Build @@ -80,6 +83,7 @@ class TargetNode : public Object { v->Visit("tag", &tag); v->Visit("keys", &keys); v->Visit("attrs", &attrs); + v->Visit("features", &features); v->Visit("host", &host); } @@ -114,6 +118,42 @@ class TargetNode : public Object { Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { return GetAttr(attr_key, Optional(default_value)); } + + /*! + * \brief Get a Target feature + * + * \param feature_key The feature key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetTargetFeature(const Target& target) { + * Bool has_feature = target->GetFeature("has_feature", false).value(); + * } + * + * \endcode + */ + template + Optional GetFeature( + const std::string& feature_key, + Optional default_value = Optional(nullptr)) const { + Optional feature = Downcast>(features.Get(feature_key)); + if (!feature) { + return default_value; + } + return feature; + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetFeature(const std::string& attr_key, TObjectRef default_value) const { + return GetFeature(attr_key, Optional(default_value)); + } + /*! \brief Get the keys for this target as a vector of string */ TVM_DLL std::vector GetKeys() const; /*! \brief Get the keys for this target as an unordered_set of string */ diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 5dc3fe093858..0b14c80bdb6f 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -216,6 +216,7 @@ class Device(ctypes.Structure): "stackvm": 1, "cpu": 1, "c": 1, + "test": 1, "hybrid": 1, "composite": 1, "cuda": 2, diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 2518527083aa..ab646ab83c63 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -43,6 +43,14 @@ def options_from_name(kind_name: str): return dict(_ffi_api.ListTargetKindOptionsFromName(kind_name)) +class TargetFeatures: + def __init__(self, target): + self.target = target + + def __getattr__(self, name: str): + return _ffi_api.TargetGetFeature(self.target, name) + + @tvm._ffi.register_object class Target(Object): """Target device information, use through TVM API. @@ -207,6 +215,10 @@ def supports_integer_dot_product(self): def libs(self): return list(self.attrs.get("libs", [])) + @property + def features(self): + return TargetFeatures(self) + def get_kind_attr(self, attr_name): """Get additional attribute about the target kind. diff --git a/src/target/target.cc b/src/target/target.cc index 207a399a77ee..9ccd755540ca 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -722,8 +722,11 @@ ObjectPtr TargetInternal::FromConfig(Map config) { const String kKeys = "keys"; const String kDeviceName = "device"; const String kHost = "host"; + const String kFeatures = "features"; ObjectPtr target = make_object(); + ICHECK(!config.count(kFeatures)) << "Target Features should be generated by Target parser"; + // parse 'kind' if (config.count(kKind)) { if (const auto* kind = config[kKind].as()) { @@ -735,6 +738,10 @@ ObjectPtr TargetInternal::FromConfig(Map config) { if (target->kind->target_parser != nullptr) { VLOG(9) << "TargetInternal::FromConfig - Running target_parser"; config = target->kind->target_parser(config); + if (config.count(kFeatures)) { + target->features = Downcast>(config[kFeatures]); + config.erase(kFeatures); + } } config.erase(kKind); @@ -914,6 +921,10 @@ TVM_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::Exi TVM_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current); TVM_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export); TVM_REGISTER_GLOBAL("target.WithHost").set_body_typed(TargetInternal::WithHost); +TVM_REGISTER_GLOBAL("target.TargetGetFeature") + .set_body_typed([](const Target& target, const String& feature_key) { + return target->GetFeature(feature_key); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 7620c6fc2e53..0d3e7b0a424c 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -244,6 +244,17 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { return target; } +/*! + * \brief Test Target Parser + * \param target The Target to update + * \return The updated attributes + */ +TargetJSON TestTargetParser(TargetJSON target) { + Map features = {{"is_test", Bool(true)}}; + target.Set("features", features); + return target; +} + /********** Register Target kinds and attributes **********/ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) @@ -416,6 +427,9 @@ TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break TVM_REGISTER_TARGET_KIND("composite", kDLCPU) // line break .add_attr_option>("devices"); +TVM_REGISTER_TARGET_KIND("test", kDLCPU) // line break + .set_target_parser(TestTargetParser); + /********** Registry **********/ TVM_REGISTER_GLOBAL("target.TargetKindGetAttr") diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 6854fc661d0b..cb5eaa18b576 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -38,6 +38,7 @@ TargetJSON TestTargetParser(TargetJSON target) { String mcpu = Downcast(target.at("mcpu")); target.Set("mcpu", String("super_") + mcpu); target.Set("keys", Array({"super"})); + target.Set("features", Map{{"test", Bool(true)}}); return target; } @@ -174,13 +175,32 @@ TEST(TargetCreation, TargetParser) { ASSERT_EQ(test_target->keys[1], "cpu"); } +TEST(TargetCreation, TargetFeatures) { + Target test_target_with_parser("TestTargetParser -mcpu=woof"); + ASSERT_EQ(test_target_with_parser->GetFeature("test").value(), true); + + Target test_target_no_parser("TestTargetKind"); + ASSERT_EQ(test_target_no_parser->GetFeature("test"), nullptr); + ASSERT_EQ(test_target_no_parser->GetFeature("test", Bool(true)).value(), true); +} + +TEST(TargetCreation, TargetFeaturesBeforeParser) { + Map features = {{"test", Bool(true)}}; + Map config = { + {"kind", String("TestTargetParser")}, + {"mcpu", String("woof")}, + {"features", features}, + }; + EXPECT_THROW(Target test(config), InternalError); +} + TEST(TargetCreation, TargetAttrsPreProcessor) { Target test_target("TestAttrsPreprocessor -mattr=cake"); ASSERT_EQ(test_target->GetAttr("mattr").value(), "woof"); } TEST(TargetCreation, ClashingTargetProcessing) { - EXPECT_THROW(Target("TestClashingPreprocessor -mcpu=woof -mattr=cake"), InternalError); + EXPECT_THROW(Target test("TestClashingPreprocessor -mcpu=woof -mattr=cake"), InternalError); } TVM_REGISTER_TARGET_KIND("test_external_codegen_0", kDLCUDA) diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index 5a5c17e196dc..ef55abfa4dcd 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -470,5 +470,15 @@ def test_target_attr_bool_value(): assert target3.attrs["supports_float16"] == 0 +def test_target_features(): + target_no_features = Target("cuda") + assert target_no_features.features + assert not target_no_features.features.is_test + + target_with_features = Target("test") + assert target_with_features.features.is_test + assert not target_with_features.features.is_missing + + if __name__ == "__main__": tvm.testing.main()