Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Target data structure."""
import os
import re
import json
import warnings
import tvm._ffi

Expand Down Expand Up @@ -347,13 +348,43 @@ def create_llvm(llvm_args):
return _ffi_api.TargetCreate('hexagon', *args_list)


def create(target_str):
def create(target):
"""Get a target given target string.

Parameters
----------
target_str : str
The target string.
target : str or dict
Can be one of a literal target string, a json string describing
a configuration, or a dictionary of configuration options.
When using a dictionary or json string to configure target, the
possible values are:

kind : str (required)
Which codegen path to use, for example 'llvm' or 'cuda'.
keys : List of str (optional)
A set of strategies that can be dispatched to. When using
"kind=opencl" for example, one could set keys to ["mali", "opencl", "gpu"].
device : str (optional)
A single key that corresponds to the actual device being run on.
This will be effectively appended to the keys.
libs : List of str (optional)
The set of external libraries to use. For example ['cblas', 'mkl'].
system-lib : bool (optional)
If True, build a module that contains self registered functions.
Useful for environments where dynamic loading like dlopen is banned.
mcpu : str (optional)
The specific cpu being run on. Serves only as an annotation.
model : str (optional)
An annotation indicating what model a workload came from.
runtime : str (optional)
An annotation indicating which runtime to use with a workload.
mtriple : str (optional)
The llvm triplet describing the target, for example "arm64-linux-android".
mattr : List of str (optional)
The llvm features to compile with, for example ["+avx512f", "+mmx"].
mfloat-abi : str (optional)
An llvm setting that is one of 'hard' or 'soft' indicating whether to use
hardware or software floating-point operations.

Returns
-------
Expand All @@ -364,9 +395,16 @@ def create(target_str):
----
See the note on :py:mod:`tvm.target` on target string format.
"""
if isinstance(target_str, Target):
return target_str
if not isinstance(target_str, str):
raise ValueError("target_str has to be string type")

return _ffi_api.TargetFromString(target_str)
if isinstance(target, Target):
return target
if isinstance(target, dict):
return _ffi_api.TargetFromConfig(target)
if isinstance(target, str):
# Check if target is a valid json string by trying to load it.
# If we cant, then assume it is a non-json target string.
try:
return _ffi_api.TargetFromConfig(json.loads(target))
except json.decoder.JSONDecodeError:
return _ffi_api.TargetFromString(target)

raise ValueError("target has to be a string or dictionary.")
4 changes: 3 additions & 1 deletion src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ Target Target::FromConfig(const Map<String, ObjectRef>& config_dict) {
const auto* cfg_keys = config[kKeys].as<ArrayNode>();
CHECK(cfg_keys != nullptr)
<< "AttributeError: Expect type of field 'keys' is an Array, but get: "
<< config[kTag]->GetTypeKey();
<< config[kKeys]->GetTypeKey();
for (const ObjectRef& e : *cfg_keys) {
const auto* key = e.as<StringObj>();
CHECK(key != nullptr) << "AttributeError: Expect 'keys' to be an array of strings, but it "
Expand Down Expand Up @@ -525,6 +525,8 @@ TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body_typed(Target::Current);

TVM_REGISTER_GLOBAL("target.TargetFromString").set_body_typed(Target::Create);

TVM_REGISTER_GLOBAL("target.TargetFromConfig").set_body_typed(Target::FromConfig);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TargetNode*>(node.get());
Expand Down
47 changes: 47 additions & 0 deletions tests/python/unittest/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import tvm
from tvm import te
from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, hexagon
Expand Down Expand Up @@ -80,7 +81,53 @@ def test_target_create():
assert tgt is not None


def test_target_config():
"""
Test that constructing a target from a dictionary works.
"""
target_config = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should also test map-type attributes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expand on this a little? Do you mean add testing that confirms it fails when invalid attribute types are provided?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just meant we should also test the attribute which type is a map. Since one purpose of testing Python binding is to make sure we can correctly pass supported data structures to the C++ container. Specifically, the Python version of this test: https://github.com/apache/incubator-tvm/blob/master/tests/cpp/target_test.cc#L121

Copy link
Contributor Author

@jwfromm jwfromm Aug 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see. It's a little strange to do this since none of the supported options are maps. But I could add a test case that passes a map and then confirm it fails like the cpp version you linked.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an analogous test in the latest commit. Let me know if this is what you were looking for.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw the map test you added, but it is intented to fail so I'm not sure if it is sufficient.
@junrushao1994 could you comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be clear, there are no valid arguments that have type map. The tests you linked in target_test.cc also are intentionally failing.

'kind': 'llvm',
'keys': ['arm_cpu', 'cpu'],
'device': 'arm_cpu',
'libs': ['cblas'],
'system-lib': True,
'mfloat-abi': 'hard',
'mattr': ['+neon', '-avx512f'],
}
# Convert config dictionary to json string.
target_config_str = json.dumps(target_config)
# Test both dictionary input and json string.
for config in [target_config, target_config_str]:
target = tvm.target.create(config)
assert target.kind.name == 'llvm'
assert all([key in target.keys for key in ['arm_cpu', 'cpu']])
assert target.device_name == 'arm_cpu'
assert target.libs == ['cblas']
assert 'system-lib' in str(target)
assert target.attrs['mfloat-abi'] == 'hard'
assert all([attr in target.attrs['mattr'] for attr in ['+neon', '-avx512f']])


def test_config_map():
"""
Confirm that constructing a target with invalid
attributes fails as expected.
"""
target_config = {
'kind': 'llvm',
'libs': {'a': 'b', 'c': 'd'}
}
failed = False
try:
target = tvm.target.create(target_config)
except AttributeError:
failed = True
assert failed == True


if __name__ == "__main__":
test_target_dispatch()
test_target_string_parse()
test_target_create()
test_target_config()
test_config_map()