diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 83f665b229d2..175c2fa06c39 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -30,7 +30,7 @@ from .compute_dag import ComputeDAG, LayoutRewriteOption from .cost_model import XGBModel from .search_policy import SketchPolicy -from .workload_registry import register_workload_tensors +from .workload_registry import WORKLOAD_FUNC_REGISTRY, register_workload_tensors from . import _ffi_api @@ -335,11 +335,12 @@ def __setstate__(self, state): except Exception: # pylint: disable=broad-except raise RuntimeError("Invalid workload key %s" % state["workload_key"]) - # The workload from a compute DAG does not have arguments and is not registered - # by default so we register it here. If the workload has already been registered, - # the later registration overrides the prvious one. - if len(workload) == 1: - register_workload_tensors(workload[0], state["compute_dag"].tensors) + # workload[0] is either the compute function name or the ComputeDAG hash. + # The compute functions are already registered when importing TVM, so here + # we only register the ComputeDAG workloads. If the same workload has + # already been registered, the later registration overrides the prvious one. + if workload[0] not in WORKLOAD_FUNC_REGISTRY: + register_workload_tensors(state["workload_key"], state["compute_dag"].tensors) self.__init_handle_by_constructor__( _ffi_api.SearchTask, diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py index 60b986ec37b2..b303ef56c1d2 100644 --- a/tests/python/unittest/test_auto_scheduler_compute_dag.py +++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py @@ -121,7 +121,7 @@ def test_stage_order(): ) task2 = pickle.loads(pickle.dumps(task)) - assert "test-key" in auto_scheduler.workload_registry.WORKLOAD_FUNC_REGISTRY + assert '["test-key"]' in auto_scheduler.workload_registry.WORKLOAD_FUNC_REGISTRY assert str(task.compute_dag.get_init_state()) == str(task2.compute_dag.get_init_state()) assert len(task.compute_dag.get_init_state().stage_ops) == len( task2.compute_dag.get_init_state().stage_ops