From 846f45f693f9698c9f20eea1355e91b8c94c070e Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 2 Apr 2021 09:17:26 -0700 Subject: [PATCH 1/7] Fix empty target and host for autotvm task --- python/tvm/autotvm/task/task.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 0d60ca929d7b..29621da10ff6 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -175,9 +175,10 @@ def __getstate__(self): # and restore the function by name when unpickling it. import cloudpickle # pylint: disable=import-outside-toplevel - self.target, self.target_host = Target.check_and_update_host_consist( - self.target, self.target_host - ) + if self.target: + self.target, self.target_host = Target.check_and_update_host_consist( + self.target, self.target_host + ) return { "name": self.name, "args": self.args, @@ -185,7 +186,7 @@ def __getstate__(self): "config_space": self.config_space, "flop": self.flop, "target": self.target, - "target_host": self.target.host, + "target_host": self.target_host, "func": cloudpickle.dumps(self.func), } From 0c96f50b6f6b208aae6b3c7a63a4d307dc68bb99 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 2 Apr 2021 10:33:16 -0700 Subject: [PATCH 2/7] Fix setstate for autotvm task --- python/tvm/autotvm/task/task.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 29621da10ff6..f146eac5741d 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -199,9 +199,10 @@ def __setstate__(self, state): self.config_space = state["config_space"] self.func = cloudpickle.loads(state["func"]) self.flop = state["flop"] - self.target, self.target_host = Target.check_and_update_host_consist( - state["target"], state["target_host"] - ) + if self.target: + self.target, self.target_host = Target.check_and_update_host_consist( + state["target"], state["target_host"] + ) def __repr__(self): return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % ( From 725a3430a0589c87698a7f166eaee65684576861 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 2 Apr 2021 12:10:36 -0700 Subject: [PATCH 3/7] Check all occurrences and fix --- python/tvm/autotvm/task/task.py | 16 +++++++--------- python/tvm/target/target.py | 4 ++++ tests/python/integration/test_tuning.py | 4 ++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index f146eac5741d..668832b8a86c 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -175,10 +175,9 @@ def __getstate__(self): # and restore the function by name when unpickling it. import cloudpickle # pylint: disable=import-outside-toplevel - if self.target: - self.target, self.target_host = Target.check_and_update_host_consist( - self.target, self.target_host - ) + self.target, self.target_host = Target.check_and_update_host_consist( + self.target, self.target_host + ) return { "name": self.name, "args": self.args, @@ -199,10 +198,9 @@ def __setstate__(self, state): self.config_space = state["config_space"] self.func = cloudpickle.loads(state["func"]) self.flop = state["flop"] - if self.target: - self.target, self.target_host = Target.check_and_update_host_consist( - state["target"], state["target_host"] - ) + self.target, self.target_host = Target.check_and_update_host_consist( + state["target"], state["target_host"] + ) def __repr__(self): return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % ( @@ -467,7 +465,7 @@ def create(task_name, args, target, target_host=None): ret.flop = ret.config_space.flop or compute_flop(sch) ret.target = target - ret.target_host = target.host + ret.target_host = target_host return ret diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 6d0a0635221e..d3fa25d9b3c4 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -182,6 +182,10 @@ def check_and_update_host_consist(target, host=None, target_is_dict_key=True): target_is_dict_key : Bool When the type of target is dict, whether Target is the key (Otherwise the value) """ + if target is None: + if host is not None: + warnings.warn("Target host is not empty when target is empty.") + return target, host if isinstance(target, dict) and "kind" not in target: new_target = {} for tgt, mod in target.items(): diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py index 45e0958a0240..55c8e5643c71 100644 --- a/tests/python/integration/test_tuning.py +++ b/tests/python/integration/test_tuning.py @@ -30,6 +30,7 @@ from tvm import autotvm from tvm.autotvm.tuner import RandomTuner +from tvm.target import Target import tvm.testing @@ -131,8 +132,7 @@ def teardown_module(): def get_sample_task(target=tvm.target.cuda(), target_host=None): - target = tvm.target.Target(target, target_host) - target_host = target.host + target, target_host = Target.check_and_update_host_consist(target, target_host) """return a sample task for testing""" task = autotvm.task.create( "testing/conv2d_no_batching", args=(1, 7, 7, 512, 512, 3, 3), target=target From f78920cc540e0cbf6432ec3aadc66e4124f66627 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 2 Apr 2021 12:56:07 -0700 Subject: [PATCH 4/7] Add assertion on target is empty and host is not --- python/tvm/target/target.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index d3fa25d9b3c4..94fcf2eadb2e 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -183,8 +183,7 @@ def check_and_update_host_consist(target, host=None, target_is_dict_key=True): When the type of target is dict, whether Target is the key (Otherwise the value) """ if target is None: - if host is not None: - warnings.warn("Target host is not empty when target is empty.") + assert host is None return target, host if isinstance(target, dict) and "kind" not in target: new_target = {} From b798ed67106cd059620edcd8fdf9ef79f69266a9 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 2 Apr 2021 13:27:59 -0700 Subject: [PATCH 5/7] Add test for check_and_update_host_consist and add warning msg when assertion fail --- python/tvm/target/target.py | 2 +- tests/python/unittest/test_target_target.py | 32 ++++++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 94fcf2eadb2e..baf07602bde6 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -183,7 +183,7 @@ def check_and_update_host_consist(target, host=None, target_is_dict_key=True): When the type of target is dict, whether Target is the key (Otherwise the value) """ if target is None: - assert host is None + assert host is None, "Target host is not empty when target is empty." return target, host if isinstance(target, dict) and "kind" not in target: new_target = {} diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index 2f885d39335b..3291e78eeb78 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -18,7 +18,7 @@ import sys import pytest import tvm -from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost +from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, Target @tvm.target.generic_func @@ -268,5 +268,35 @@ def test_target_with_host(): assert tgt.host.attrs["registers_per_block"] == 32768 +def test_check_and_update_host_consist_0(): + target = None + host = None + target, host = Target.check_and_update_host_consist(target, host) + + +def test_check_and_update_host_consist_1(): + target = None + host = "llvm" + with pytest.raises(AssertionError, match=r"Target host is not empty when target is empty."): + target, host = Target.check_and_update_host_consist(target, host) + + +def test_check_and_update_host_consist_2(): + target = Target("cuda") + host = Target("llvm") + target, host = Target.check_and_update_host_consist(target, host) + assert target.kind.name == "cuda" + assert target.host.kind.name == "llvm" + + +def test_check_and_update_host_consist_3(): + target = Target("cuda --host=llvm") + host = None + target, host = Target.check_and_update_host_consist(target, host) + assert target.kind.name == "cuda" + assert target.host.kind.name == "llvm" + assert host.kind.name == "llvm" + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 66b03210e37f94751716a7ec5f32708d93756233 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 2 Apr 2021 13:30:29 -0700 Subject: [PATCH 6/7] Add minor test patch --- tests/python/unittest/test_target_target.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index 3291e78eeb78..b0c87b1c9f90 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -296,6 +296,7 @@ def test_check_and_update_host_consist_3(): assert target.kind.name == "cuda" assert target.host.kind.name == "llvm" assert host.kind.name == "llvm" + assert target.host == host if __name__ == "__main__": From a128c1740baff8cd48f0c83e55790b216516ac5c Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 2 Apr 2021 17:11:02 -0700 Subject: [PATCH 7/7] Try to rerun CI --- tests/python/unittest/test_target_target.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index b0c87b1c9f90..98a9edc7a517 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -290,7 +290,7 @@ def test_check_and_update_host_consist_2(): def test_check_and_update_host_consist_3(): - target = Target("cuda --host=llvm") + target = Target(target="cuda", host="llvm") host = None target, host = Target.check_and_update_host_consist(target, host) assert target.kind.name == "cuda"