diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index f7a5f39c829a..c212d143f987 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -222,7 +222,7 @@ def rewrite_layout_from_state(self, state):
     def workload_key(self):
         """Return the workload key of this compute DAG.
 
-        The workload key is a JSON string from a tuple of (hash-key, tensor shapes...)
+        The workload key is a JSON string from a tuple of (hash of DAG, tensor shapes...)
 
         Returns
         -------
@@ -230,12 +230,21 @@ def workload_key(self):
             The workload key of this compute DAG
         """
         str_dag = _ffi_api.ComputeDAGPrintDAG(self, True)
-        str_dag = str_dag.encode(encoding="utf-8")
-        hash_key = hashlib.md5(str_dag).hexdigest()
+        # Use the globally registered custom hash function if there is one;
+        # otherwise fall back to the MD5 hex digest of the printed DAG.
+        hash_func = tvm._ffi.get_global_func(
+            "auto_scheduler.compute_dag.hash_func", allow_missing=True
+        )
+
+        if hash_func is None:
+            str_dag = str_dag.encode("utf-8")
+            hash_key = hashlib.md5(str_dag).hexdigest()
+        else:
+            hash_key = hash_func(str_dag)
 
         io_shapes = []
         for tensor in self.tensors:
-            io_shapes += get_const_tuple(tensor.shape)
+            io_shapes.append(get_const_tuple(tensor.shape))
         return json.dumps([hash_key] + io_shapes)
 
     def __str__(self):
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 850e50004337..8b68f4e9002a 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -22,6 +22,7 @@
 2. Provide auto-scheduling for all TOPI compute functions
 """
 
+import json
 import logging
 import threading
 from copy import deepcopy
@@ -30,11 +31,10 @@
 from tvm import autotvm, transform
 from tvm.ir.transform import PassContext
 from tvm.runtime import convert_to_object
-
+from tvm.target import Target
 from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
 from tvm.tir import Reduce
 from tvm.tir import expr as _expr
-from tvm.target import Target
 
 from . import _ffi_api
 from .compute_dag import ComputeDAG, LayoutRewriteOption
@@ -97,6 +97,7 @@ def extract_tasks(
     target_host=None,
     hardware_params=None,
     include_simple_tasks=False,
+    dump_workload_to_dag_log=None,
     opt_level=3,
 ):
     """Extract tuning tasks from a relay program.
@@ -115,6 +116,8 @@ def extract_tasks(
         Hardware parameters used for the search tasks
     include_simple_tasks: bool
         Whether to extract simple tasks that do not include complicated ops.
+    dump_workload_to_dag_log: Optional[str]
+        Path of a file to which a JSON map from each workload key to its DAG string is dumped
     opt_level : Optional[int]
         The optimization level of the task extractions.
 
@@ -170,6 +173,10 @@ def extract_tasks(
         )
         weights.append(weight)
 
+    if dump_workload_to_dag_log is not None:
+        with open(dump_workload_to_dag_log, "w") as f:
+            json.dump({task.workload_key: str(task.compute_dag) for task in tasks}, f)
+
     return tasks, weights
 
 
diff --git a/tests/python/relay/test_auto_scheduler_task_extraction.py b/tests/python/relay/test_auto_scheduler_task_extraction.py
index 39596186d211..a53b68cca885 100644
--- a/tests/python/relay/test_auto_scheduler_task_extraction.py
+++ b/tests/python/relay/test_auto_scheduler_task_extraction.py
@@ -15,10 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test task extraction for auto-scheduler"""
-import pytest
+import json
+import os
+import tempfile
 
+import pytest
 import tvm.relay.testing
 import tvm.testing
+from tvm import _ffi as _ffi_api
 from tvm import auto_scheduler, relay
 
 
@@ -248,5 +252,57 @@ def verify_task_extraction(func_name, expected_task, include_simple_tasks=False)
     verify_task_extraction(*params)
 
 
+def test_dump_workload_to_dag_extract_tasks():
+    """Check that extract_tasks dumps a JSON map from workload key to DAG string."""
+    mod, _ = get_network("mobilenet", layout="NHWC")
+    # Write to a fresh path inside a temporary directory. Re-opening a
+    # still-open NamedTemporaryFile by name is not portable (fails on Windows).
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        log_path = os.path.join(tmp_dir, "workload_to_dag.json")
+        tasks, _ = auto_scheduler.extract_tasks(
+            mod["main"],
+            None,
+            "llvm",
+            include_simple_tasks=True,
+            dump_workload_to_dag_log=log_path,
+        )
+        with open(log_path, "r") as f:
+            actual = json.load(f)
+
+    expected = {task.workload_key: str(task.compute_dag) for task in tasks}
+    assert expected == actual
+
+
+def test_custom_hash_func_extract_tasks():
+    """Check that a registered custom hash function produces the workload key hashes."""
+
+    # override=True keeps the test re-runnable within one process, where the
+    # name may already be registered by a previous run.
+    # NOTE(review): nothing unregisters this function afterwards, so workload
+    # keys computed later in the same process will also use it -- confirm.
+    @_ffi_api.register_func("auto_scheduler.compute_dag.hash_func", override=True)
+    def counting_unique_hash(str_dag):
+        ret = counting_unique_hash.i
+        counting_unique_hash.i += 1
+        return ret
+
+    counting_unique_hash.i = 0
+
+    mod, _ = get_network("mobilenet", layout="NHWC")
+    tasks, _ = auto_scheduler.extract_tasks(mod["main"], None, "llvm", include_simple_tasks=True)
+
+    # task.workload_key is a JSON list that should look like
+    # [43, [3, 3, 1024, 1], [1024], [3, 3, 1024, 1]] where the first int is the hash.
+    # Parse it as JSON instead of slicing the string, and keep track of every hash.
+    hash_values = [int(json.loads(task.workload_key)[0]) for task in tasks]
+
+    # All values are unique, and we know the min and max
+    # This is a sufficient condition to know that hashes in hash_values are an increasing list
+    # of hashes up to counting_unique_hash.i - 1
+    assert len(hash_values) == len(set(hash_values))
+    assert min(hash_values) == 0
+    assert max(hash_values) == counting_unique_hash.i - 1
+
+
 if __name__ == "__main__":
     pytest.main([__file__])