diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 0d60ca929d7b..668832b8a86c 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -185,7 +185,7 @@ def __getstate__(self): "config_space": self.config_space, "flop": self.flop, "target": self.target, - "target_host": self.target.host, + "target_host": self.target_host, "func": cloudpickle.dumps(self.func), } @@ -465,7 +465,7 @@ def create(task_name, args, target, target_host=None): ret.flop = ret.config_space.flop or compute_flop(sch) ret.target = target - ret.target_host = target.host + ret.target_host = target_host return ret diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 6d0a0635221e..baf07602bde6 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -182,6 +182,9 @@ def check_and_update_host_consist(target, host=None, target_is_dict_key=True): target_is_dict_key : Bool When the type of target is dict, whether Target is the key (Otherwise the value) """ + if target is None: + assert host is None, "Target host is not empty when target is empty." + return target, host if isinstance(target, dict) and "kind" not in target: new_target = {} for tgt, mod in target.items(): diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py index 45e0958a0240..55c8e5643c71 100644 --- a/tests/python/integration/test_tuning.py +++ b/tests/python/integration/test_tuning.py @@ -30,6 +30,7 @@ from tvm import autotvm from tvm.autotvm.tuner import RandomTuner +from tvm.target import Target import tvm.testing @@ -131,8 +132,7 @@ def teardown_module(): def get_sample_task(target=tvm.target.cuda(), target_host=None): - target = tvm.target.Target(target, target_host) - target_host = target.host + target, target_host = Target.check_and_update_host_consist(target, target_host) """return a sample task for testing""" task = autotvm.task.create( "testing/conv2d_no_batching", args=(1, 7, 7, 512, 512, 3, 3), target=target diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index 2f885d39335b..98a9edc7a517 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -18,7 +18,7 @@ import sys import pytest import tvm -from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost +from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, Target @tvm.target.generic_func @@ -268,5 +268,36 @@ def test_target_with_host(): assert tgt.host.attrs["registers_per_block"] == 32768 +def test_check_and_update_host_consist_0(): + target = None + host = None + target, host = Target.check_and_update_host_consist(target, host) + + +def test_check_and_update_host_consist_1(): + target = None + host = "llvm" + with pytest.raises(AssertionError, match=r"Target host is not empty when target is empty."): + target, host = Target.check_and_update_host_consist(target, host) + + +def test_check_and_update_host_consist_2(): + target = Target("cuda") + host = Target("llvm") + target, host = Target.check_and_update_host_consist(target, host) + assert target.kind.name == "cuda" + assert target.host.kind.name == "llvm" + + +def test_check_and_update_host_consist_3(): + target = Target(target="cuda", host="llvm") + host = None + target, host = Target.check_and_update_host_consist(target, host) + assert target.kind.name == "cuda" + assert target.host.kind.name == "llvm" + assert host.kind.name == "llvm" + assert target.host == host + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))