Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class TargetNode : public Object {
Array<String> keys;
/*! \brief Collection of attributes */
Map<String, ObjectRef> attrs;
/*! \brief Target features */
Map<String, ObjectRef> features;

/*!
* \brief The raw string representation of the target
* \return the full device string to pass to codegen::Build
Expand All @@ -80,6 +83,7 @@ class TargetNode : public Object {
v->Visit("tag", &tag);
v->Visit("keys", &keys);
v->Visit("attrs", &attrs);
v->Visit("features", &features);
v->Visit("host", &host);
}

Expand Down Expand Up @@ -114,6 +118,42 @@ class TargetNode : public Object {
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}

/*!
* \brief Get a Target feature
*
* \param feature_key The feature key.
* \param default_value The default value if the key does not exist, defaults to nullptr.
*
* \return The result
*
* \tparam TOBjectRef the expected object type.
* \throw Error if the key exists but the value does not match TObjectRef
*
* \code
*
* void GetTargetFeature(const Target& target) {
* Bool has_feature = target->GetFeature<Bool>("has_feature", false).value();
* }
*
* \endcode
*/
template <typename TObjectRef>
Optional<TObjectRef> GetFeature(
const std::string& feature_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
Optional<TObjectRef> feature = Downcast<Optional<TObjectRef>>(features.Get(feature_key));
if (!feature) {
return default_value;
}
return feature;
}
// variant that uses TObjectRef to enable implicit conversion to default value.
template <typename TObjectRef>
Optional<TObjectRef> GetFeature(const std::string& attr_key, TObjectRef default_value) const {
return GetFeature<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}

/*! \brief Get the keys for this target as a vector of string */
TVM_DLL std::vector<std::string> GetKeys() const;
/*! \brief Get the keys for this target as an unordered_set of string */
Expand Down
1 change: 1 addition & 0 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ class Device(ctypes.Structure):
"stackvm": 1,
"cpu": 1,
"c": 1,
"test": 1,
"hybrid": 1,
"composite": 1,
"cuda": 2,
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ def options_from_name(kind_name: str):
return dict(_ffi_api.ListTargetKindOptionsFromName(kind_name))


class TargetFeatures:
def __init__(self, target):
self.target = target

def __getattr__(self, name: str):
return _ffi_api.TargetGetFeature(self.target, name)


@tvm._ffi.register_object
class Target(Object):
"""Target device information, use through TVM API.
Expand Down Expand Up @@ -207,6 +215,10 @@ def supports_integer_dot_product(self):
def libs(self):
return list(self.attrs.get("libs", []))

@property
def features(self):
return TargetFeatures(self)

def get_kind_attr(self, attr_name):
"""Get additional attribute about the target kind.

Expand Down
11 changes: 11 additions & 0 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,11 @@ ObjectPtr<Object> TargetInternal::FromConfig(Map<String, ObjectRef> config) {
const String kKeys = "keys";
const String kDeviceName = "device";
const String kHost = "host";
const String kFeatures = "features";
ObjectPtr<TargetNode> target = make_object<TargetNode>();

ICHECK(!config.count(kFeatures)) << "Target Features should be generated by Target parser";

// parse 'kind'
if (config.count(kKind)) {
if (const auto* kind = config[kKind].as<StringObj>()) {
Expand All @@ -735,6 +738,10 @@ ObjectPtr<Object> TargetInternal::FromConfig(Map<String, ObjectRef> config) {
if (target->kind->target_parser != nullptr) {
VLOG(9) << "TargetInternal::FromConfig - Running target_parser";
config = target->kind->target_parser(config);
if (config.count(kFeatures)) {
target->features = Downcast<Map<String, ObjectRef>>(config[kFeatures]);
config.erase(kFeatures);
}
}

config.erase(kKind);
Expand Down Expand Up @@ -914,6 +921,10 @@ TVM_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::Exi
TVM_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current);
TVM_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export);
TVM_REGISTER_GLOBAL("target.WithHost").set_body_typed(TargetInternal::WithHost);
TVM_REGISTER_GLOBAL("target.TargetGetFeature")
.set_body_typed([](const Target& target, const String& feature_key) {
return target->GetFeature<ObjectRef>(feature_key);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& obj, ReprPrinter* p) {
Expand Down
14 changes: 14 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,17 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) {
return target;
}

/*!
* \brief Test Target Parser
* \param target The Target to update
* \return The updated attributes
*/
TargetJSON TestTargetParser(TargetJSON target) {
Map<String, ObjectRef> features = {{"is_test", Bool(true)}};
target.Set("features", features);
return target;
}

/********** Register Target kinds and attributes **********/

TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
Expand Down Expand Up @@ -416,6 +427,9 @@ TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break
TVM_REGISTER_TARGET_KIND("composite", kDLCPU) // line break
.add_attr_option<Array<Target>>("devices");

TVM_REGISTER_TARGET_KIND("test", kDLCPU) // line break
.set_target_parser(TestTargetParser);

/********** Registry **********/

TVM_REGISTER_GLOBAL("target.TargetKindGetAttr")
Expand Down
22 changes: 21 additions & 1 deletion tests/cpp/target_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ TargetJSON TestTargetParser(TargetJSON target) {
String mcpu = Downcast<String>(target.at("mcpu"));
target.Set("mcpu", String("super_") + mcpu);
target.Set("keys", Array<String>({"super"}));
target.Set("features", Map<String, ObjectRef>{{"test", Bool(true)}});
return target;
}

Expand Down Expand Up @@ -174,13 +175,32 @@ TEST(TargetCreation, TargetParser) {
ASSERT_EQ(test_target->keys[1], "cpu");
}

TEST(TargetCreation, TargetFeatures) {
Target test_target_with_parser("TestTargetParser -mcpu=woof");
ASSERT_EQ(test_target_with_parser->GetFeature<Bool>("test").value(), true);

Target test_target_no_parser("TestTargetKind");
ASSERT_EQ(test_target_no_parser->GetFeature<Bool>("test"), nullptr);
ASSERT_EQ(test_target_no_parser->GetFeature<Bool>("test", Bool(true)).value(), true);
}

TEST(TargetCreation, TargetFeaturesBeforeParser) {
Map<String, ObjectRef> features = {{"test", Bool(true)}};
Map<String, ObjectRef> config = {
{"kind", String("TestTargetParser")},
{"mcpu", String("woof")},
{"features", features},
};
EXPECT_THROW(Target test(config), InternalError);
}

TEST(TargetCreation, TargetAttrsPreProcessor) {
Target test_target("TestAttrsPreprocessor -mattr=cake");
ASSERT_EQ(test_target->GetAttr<String>("mattr").value(), "woof");
}

TEST(TargetCreation, ClashingTargetProcessing) {
EXPECT_THROW(Target("TestClashingPreprocessor -mcpu=woof -mattr=cake"), InternalError);
EXPECT_THROW(Target test("TestClashingPreprocessor -mcpu=woof -mattr=cake"), InternalError);
}

TVM_REGISTER_TARGET_KIND("test_external_codegen_0", kDLCUDA)
Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,5 +470,15 @@ def test_target_attr_bool_value():
assert target3.attrs["supports_float16"] == 0


def test_target_features():
target_no_features = Target("cuda")
assert target_no_features.features
assert not target_no_features.features.is_test

target_with_features = Target("test")
assert target_with_features.features.is_test
assert not target_with_features.features.is_missing


if __name__ == "__main__":
tvm.testing.main()