Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def __getstate__(self):
"config_space": self.config_space,
"flop": self.flop,
"target": self.target,
"target_host": self.target.host,
"target_host": self.target_host,
"func": cloudpickle.dumps(self.func),
}

Expand Down Expand Up @@ -465,7 +465,7 @@ def create(task_name, args, target, target_host=None):

ret.flop = ret.config_space.flop or compute_flop(sch)
ret.target = target
ret.target_host = target.host
ret.target_host = target_host

return ret

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def check_and_update_host_consist(target, host=None, target_is_dict_key=True):
target_is_dict_key : Bool
When the type of target is dict, whether Target is the key (Otherwise the value)
"""
if target is None:
assert host is None, "Target host is not empty when target is empty."
return target, host
if isinstance(target, dict) and "kind" not in target:
new_target = {}
for tgt, mod in target.items():
Expand Down
4 changes: 2 additions & 2 deletions tests/python/integration/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from tvm import autotvm
from tvm.autotvm.tuner import RandomTuner
from tvm.target import Target

import tvm.testing

Expand Down Expand Up @@ -131,8 +132,7 @@ def teardown_module():


def get_sample_task(target=tvm.target.cuda(), target_host=None):
target = tvm.target.Target(target, target_host)
target_host = target.host
target, target_host = Target.check_and_update_host_consist(target, target_host)
"""return a sample task for testing"""
task = autotvm.task.create(
"testing/conv2d_no_batching", args=(1, 7, 7, 512, 512, 3, 3), target=target
Expand Down
33 changes: 32 additions & 1 deletion tests/python/unittest/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import sys
import pytest
import tvm
from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost
from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, Target


@tvm.target.generic_func
Expand Down Expand Up @@ -268,5 +268,36 @@ def test_target_with_host():
assert tgt.host.attrs["registers_per_block"] == 32768


def test_check_and_update_host_consist_0():
target = None
host = None
target, host = Target.check_and_update_host_consist(target, host)


def test_check_and_update_host_consist_1():
target = None
host = "llvm"
with pytest.raises(AssertionError, match=r"Target host is not empty when target is empty."):
target, host = Target.check_and_update_host_consist(target, host)


def test_check_and_update_host_consist_2():
target = Target("cuda")
host = Target("llvm")
target, host = Target.check_and_update_host_consist(target, host)
assert target.kind.name == "cuda"
assert target.host.kind.name == "llvm"


def test_check_and_update_host_consist_3():
target = Target(target="cuda", host="llvm")
host = None
target, host = Target.check_and_update_host_consist(target, host)
assert target.kind.name == "cuda"
assert target.host.kind.name == "llvm"
assert host.kind.name == "llvm"
assert target.host == host


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))