From 799e540229a3e69cd02c1d2d382931e360e9778f Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 2 Feb 2021 21:41:15 +0000 Subject: [PATCH 1/3] [Bugfix][AutoScheduler] Fail to register ComputeDAG when deserialize tasks --- python/tvm/auto_scheduler/search_task.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 83f665b229d2..175c2fa06c39 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -30,7 +30,7 @@ from .compute_dag import ComputeDAG, LayoutRewriteOption from .cost_model import XGBModel from .search_policy import SketchPolicy -from .workload_registry import register_workload_tensors +from .workload_registry import WORKLOAD_FUNC_REGISTRY, register_workload_tensors from . import _ffi_api @@ -335,11 +335,12 @@ def __setstate__(self, state): except Exception: # pylint: disable=broad-except raise RuntimeError("Invalid workload key %s" % state["workload_key"]) - # The workload from a compute DAG does not have arguments and is not registered - # by default so we register it here. If the workload has already been registered, - # the later registration overrides the prvious one. - if len(workload) == 1: - register_workload_tensors(workload[0], state["compute_dag"].tensors) + # workload[0] is either the compute function name or the ComputeDAG hash. + # The compute functions are already registered when importing TVM, so here + # we only register the ComputeDAG workloads. If the same workload has + # already been registered, the later registration overrides the prvious one. + if workload[0] not in WORKLOAD_FUNC_REGISTRY: + register_workload_tensors(state["workload_key"], state["compute_dag"].tensors) self.__init_handle_by_constructor__( _ffi_api.SearchTask, From 47af519d98fdee7f1fad9131fb6e107cba3a624d Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 3 Feb 2021 00:20:35 +0000 Subject: [PATCH 2/3] fix test --- tests/python/unittest/test_auto_scheduler_compute_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py index 60b986ec37b2..b303ef56c1d2 100644 --- a/tests/python/unittest/test_auto_scheduler_compute_dag.py +++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py @@ -121,7 +121,7 @@ def test_stage_order(): ) task2 = pickle.loads(pickle.dumps(task)) - assert "test-key" in auto_scheduler.workload_registry.WORKLOAD_FUNC_REGISTRY + assert '["test-key"]' in auto_scheduler.workload_registry.WORKLOAD_FUNC_REGISTRY assert str(task.compute_dag.get_init_state()) == str(task2.compute_dag.get_init_state()) assert len(task.compute_dag.get_init_state().stage_ops) == len( task2.compute_dag.get_init_state().stage_ops From 3bdf174ed3629bca7def51e39b8542054187ef2d Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 4 Feb 2021 18:09:44 +0000 Subject: [PATCH 3/3] trigger ci