From ca40ff6713721a5e190bacb30a20ea54b7d89878 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 26 Aug 2021 09:51:40 -0700 Subject: [PATCH 01/10] change workload keys --- python/tvm/auto_scheduler/compute_dag.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index f7a5f39c829a..9edd4a6f63bf 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -231,12 +231,11 @@ def workload_key(self): """ str_dag = _ffi_api.ComputeDAGPrintDAG(self, True) str_dag = str_dag.encode(encoding="utf-8") - hash_key = hashlib.md5(str_dag).hexdigest() io_shapes = [] for tensor in self.tensors: io_shapes += get_const_tuple(tensor.shape) - return json.dumps([hash_key] + io_shapes) + return json.dumps([str_dag] + io_shapes) def __str__(self): # pretty print From d2821a30c33d6bf54262cc381e15903c0c777d96 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 26 Aug 2021 10:24:37 -0700 Subject: [PATCH 02/10] remove binary string comparison --- python/tvm/auto_scheduler/compute_dag.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 9edd4a6f63bf..7fc1b8eb2385 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -230,8 +230,6 @@ def workload_key(self): The workload key of this compute DAG """ str_dag = _ffi_api.ComputeDAGPrintDAG(self, True) - str_dag = str_dag.encode(encoding="utf-8") - io_shapes = [] for tensor in self.tensors: io_shapes += get_const_tuple(tensor.shape) From c4f83fd9f5173d083e92018779da0b08e539731c Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 26 Aug 2021 10:34:26 -0700 Subject: [PATCH 03/10] append the tuple not every integer --- python/tvm/auto_scheduler/compute_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 7fc1b8eb2385..a9e8d5d843f9 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -232,7 +232,7 @@ def workload_key(self): str_dag = _ffi_api.ComputeDAGPrintDAG(self, True) io_shapes = [] for tensor in self.tensors: - io_shapes += get_const_tuple(tensor.shape) + io_shapes.append(get_const_tuple(tensor.shape)) return json.dumps([str_dag] + io_shapes) def __str__(self): From c38baa1179fc077a8d6044c519de7069655bc55d Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 26 Aug 2021 15:20:11 -0700 Subject: [PATCH 04/10] clean up --- python/tvm/auto_scheduler/compute_dag.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index a9e8d5d843f9..007c71c9cc3f 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -222,18 +222,23 @@ def rewrite_layout_from_state(self, state): def workload_key(self): """Return the workload key of this compute DAG. - The workload key is a JSON string from a tuple of (hash-key, tensor shapes...) + The workload key is a JSON string from a tuple of (DAG string, tensor shapes...) Returns ------- key: str The workload key of this compute DAG """ - str_dag = _ffi_api.ComputeDAGPrintDAG(self, True) + hash_key = _ffi_api.ComputeDAGPrintDAG(self, True) + # TODO: forward a "use_hash" flag down here + # if use_hash: + # hash_key = hash_key.encode(encoding="utf-8") + # hash_key = hashlib.md5(hash_key).hexdigest() + io_shapes = [] for tensor in self.tensors: io_shapes.append(get_const_tuple(tensor.shape)) - return json.dumps([str_dag] + io_shapes) + return json.dumps([hash_key] + io_shapes) def __str__(self): # pretty print From 3a09575a9e8cbb32189802959f0adcbc35e4b531 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 26 Aug 2021 15:49:50 -0700 Subject: [PATCH 05/10] lint --- python/tvm/auto_scheduler/compute_dag.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 007c71c9cc3f..0bece830a5a1 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -18,7 +18,6 @@ """ The auto-scheduler's computational graph and related program analyses. """ -import hashlib import json import tvm._ffi From 0885b1f8b01a4c23d3938f4a0fa08e426cb46a8c Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Fri, 27 Aug 2021 11:06:50 -0700 Subject: [PATCH 06/10] dump workload keys to dags --- python/tvm/auto_scheduler/compute_dag.py | 15 ++++++++++----- python/tvm/auto_scheduler/relay_integration.py | 11 +++++++++-- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 0bece830a5a1..aabf7e6d10a5 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -18,6 +18,7 @@ """ The auto-scheduler's computational graph and related program analyses. """ +import hashlib import json import tvm._ffi @@ -228,11 +229,15 @@ def workload_key(self): key: str The workload key of this compute DAG """ - hash_key = _ffi_api.ComputeDAGPrintDAG(self, True) - # TODO: forward a "use_hash" flag down here - # if use_hash: - # hash_key = hash_key.encode(encoding="utf-8") - # hash_key = hashlib.md5(hash_key).hexdigest() + str_dag = _ffi_api.ComputeDAGPrintDAG(self, True) + hash_func = tvm._ffi.get_global_func( + "auto_scheduler.compute_dag.hash_func", allow_missing=True + ) + + if hash_func is None: + hash_key = hashlib.md5(str_dag).hexdigest() + else: + hash_key = hash_func(str_dag) io_shapes = [] for tensor in self.tensors: diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 850e50004337..e305dd098cac 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -22,6 +22,7 @@ 2. Provide auto-scheduling for all TOPI compute functions """ +import json import logging import threading from copy import deepcopy @@ -30,11 +31,10 @@ from tvm import autotvm, transform from tvm.ir.transform import PassContext from tvm.runtime import convert_to_object - +from tvm.target import Target from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor from tvm.tir import Reduce from tvm.tir import expr as _expr -from tvm.target import Target from . import _ffi_api from .compute_dag import ComputeDAG, LayoutRewriteOption @@ -97,6 +97,7 @@ def extract_tasks( target_host=None, hardware_params=None, include_simple_tasks=False, + dump_workload_to_dag_log=None, opt_level=3, ): """Extract tuning tasks from a relay program. @@ -115,6 +116,8 @@ def extract_tasks( Hardware parameters used for the search tasks include_simple_tasks: bool Whether to extract simple tasks that do not include complicated ops. + dump_workloads_extract_tasks: Optional[str] + A file to dump an association between the workload keys and the actual DAG opt_level : Optional[int] The optimization level of the task extractions. @@ -170,6 +173,10 @@ def extract_tasks( ) weights.append(weight) + if dump_workload_to_dag_log is not None: + with open(dump_workload_to_dag_log, "wb") as f: + json.dump({task.workload_key: str(task.compute_dag) for task in tasks}, f) + return tasks, weights From bc4b0c529c36ebdd40c1b5ae58382cf1afd8d266 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Fri, 27 Aug 2021 11:07:35 -0700 Subject: [PATCH 07/10] fix things --- python/tvm/auto_scheduler/compute_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index aabf7e6d10a5..8a1cf7e61dc5 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -222,7 +222,7 @@ def rewrite_layout_from_state(self, state): def workload_key(self): """Return the workload key of this compute DAG. - The workload key is a JSON string from a tuple of (DAG string, tensor shapes...) + The workload key is a JSON string from a tuple of (hash of DAG, tensor shapes...) Returns ------- From eb92bdf7018a8113a1d6780ea74b24070e229012 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Fri, 27 Aug 2021 11:08:45 -0700 Subject: [PATCH 08/10] change some strings --- python/tvm/auto_scheduler/relay_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index e305dd098cac..a8d73c7c99a8 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -116,7 +116,7 @@ def extract_tasks( Hardware parameters used for the search tasks include_simple_tasks: bool Whether to extract simple tasks that do not include complicated ops. - dump_workloads_extract_tasks: Optional[str] + dump_workload_to_dag_log: Optional[str] A file to dump an association between the workload keys and the actual DAG opt_level : Optional[int] The optimization level of the task extractions. From 868423fd2c3bfb1d69e72bc96544ada5ac4a3ee5 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Fri, 27 Aug 2021 11:43:24 -0700 Subject: [PATCH 09/10] misc fixes, add tests --- python/tvm/auto_scheduler/compute_dag.py | 1 + .../tvm/auto_scheduler/relay_integration.py | 2 +- .../test_auto_scheduler_task_extraction.py | 44 ++++++++++++++++++- 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 8a1cf7e61dc5..c212d143f987 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -235,6 +235,7 @@ def workload_key(self): ) if hash_func is None: + str_dag = str_dag.encode("utf-8") hash_key = hashlib.md5(str_dag).hexdigest() else: hash_key = hash_func(str_dag) diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index a8d73c7c99a8..8b68f4e9002a 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -174,7 +174,7 @@ def extract_tasks( weights.append(weight) if dump_workload_to_dag_log is not None: - with open(dump_workload_to_dag_log, "wb") as f: + with open(dump_workload_to_dag_log, "w") as f: json.dump({task.workload_key: str(task.compute_dag) for task in tasks}, f) return tasks, weights diff --git a/tests/python/relay/test_auto_scheduler_task_extraction.py b/tests/python/relay/test_auto_scheduler_task_extraction.py index 39596186d211..a53b68cca885 100644 --- a/tests/python/relay/test_auto_scheduler_task_extraction.py +++ b/tests/python/relay/test_auto_scheduler_task_extraction.py @@ -15,10 +15,13 @@ # specific language governing permissions and limitations # under the License. """Test task extraction for auto-scheduler""" -import pytest +import json +import tempfile +import pytest import tvm.relay.testing import tvm.testing +from tvm import _ffi as _ffi_api from tvm import auto_scheduler, relay @@ -248,5 +251,44 @@ def verify_task_extraction(func_name, expected_task, include_simple_tasks=False) verify_task_extraction(*params) +def test_dump_workload_to_dag_extract_tasks(): + mod, _ = get_network("mobilenet", layout="NHWC") + with tempfile.NamedTemporaryFile() as f: + tasks, _ = auto_scheduler.extract_tasks( + mod["main"], None, "llvm", include_simple_tasks=True, dump_workload_to_dag_log=f.name + ) + expected = {task.workload_key: str(task.compute_dag) for task in tasks} + actual = json.load(f) + assert expected == actual + + +def test_custom_hash_func_extract_tasks(): + @_ffi_api.register_func("auto_scheduler.compute_dag.hash_func") + def counting_unique_hash(str_dag): + ret = counting_unique_hash.i + counting_unique_hash.i += 1 + return ret + + counting_unique_hash.i = 0 + + mod, _ = get_network("mobilenet", layout="NHWC") + tasks, _ = auto_scheduler.extract_tasks(mod["main"], None, "llvm", include_simple_tasks=True) + + hash_values = [] + for task in tasks: + # task.workload_key should look like + # [43, [3, 3, 1024, 1], [1024], [3, 3, 1024, 1]] where the first int is the result of the hash + # Extract the hash and keep track of every hash + hash_value = int(task.workload_key[1:].split(",")[0]) + hash_values.append(hash_value) + + # All values are unique, and we know the min and max + # This is a sufficient condition to know that hashes in hash_values are an increasing list + # of hashes up to counting_unique_hash.i - 1 + assert len(hash_values) == len(set(hash_values)) + assert min(hash_values) == 0 + assert max(hash_values) == counting_unique_hash.i - 1 + + if __name__ == "__main__": pytest.main([__file__]) From 0b4914c7e5ff23a012ef4f39a01e18f0d2a2b786 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Fri, 27 Aug 2021 13:41:30 -0700 Subject: [PATCH 10/10] jostle ci