python/tvm/auto_scheduler/compute_dag.py (15 changes: 11 additions & 4 deletions)
@@ -222,20 +222,27 @@ def rewrite_layout_from_state(self, state):

     def workload_key(self):
         """Return the workload key of this compute DAG.
-        The workload key is a JSON string from a tuple of (hash-key, tensor shapes...)
+        The workload key is a JSON string from a tuple of (hash of DAG, tensor shapes...)
 
         Returns
         -------
         key: str
             The workload key of this compute DAG
         """
         str_dag = _ffi_api.ComputeDAGPrintDAG(self, True)
-        str_dag = str_dag.encode(encoding="utf-8")
-        hash_key = hashlib.md5(str_dag).hexdigest()
+        hash_func = tvm._ffi.get_global_func(
+            "auto_scheduler.compute_dag.hash_func", allow_missing=True
+        )
+
+        if hash_func is None:
+            str_dag = str_dag.encode("utf-8")
+            hash_key = hashlib.md5(str_dag).hexdigest()
+        else:
+            hash_key = hash_func(str_dag)
 
         io_shapes = []
         for tensor in self.tensors:
-            io_shapes += get_const_tuple(tensor.shape)
+            io_shapes.append(get_const_tuple(tensor.shape))
         return json.dumps([hash_key] + io_shapes)
 
     def __str__(self):
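Note: with this change, workload_key() falls back to MD5 only when no global function named "auto_scheduler.compute_dag.hash_func" is registered. A minimal sketch of overriding the hash from Python, assuming an existing ComputeDAG instance dag; the lookup name matches the one above, while the SHA-256 choice is purely illustrative:

import hashlib
import tvm._ffi

@tvm._ffi.register_func("auto_scheduler.compute_dag.hash_func")
def sha256_hash(str_dag):
    # str_dag is the printed DAG; the custom function receives it
    # un-encoded, so encode before hashing
    return hashlib.sha256(str_dag.encode("utf-8")).hexdigest()

key = dag.workload_key()
# e.g. '["<sha256 hex digest>", [1, 3, 224, 224], [64, 3, 7, 7], ...]'
# Shapes stay nested now that io_shapes appends each tensor shape as its own list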
python/tvm/auto_scheduler/relay_integration.py (11 changes: 9 additions & 2 deletions)
@@ -22,6 +22,7 @@
 2. Provide auto-scheduling for all TOPI compute functions
 """
 
+import json
 import logging
 import threading
 from copy import deepcopy
@@ -30,11 +31,10 @@
 from tvm import autotvm, transform
 from tvm.ir.transform import PassContext
 from tvm.runtime import convert_to_object
-
+from tvm.target import Target
 from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
 from tvm.tir import Reduce
 from tvm.tir import expr as _expr
-from tvm.target import Target
 
 from . import _ffi_api
 from .compute_dag import ComputeDAG, LayoutRewriteOption
@@ -97,6 +97,7 @@ def extract_tasks(
     target_host=None,
     hardware_params=None,
     include_simple_tasks=False,
+    dump_workload_to_dag_log=None,
     opt_level=3,
 ):
     """Extract tuning tasks from a relay program.
@@ -115,6 +116,8 @@
         Hardware parameters used for the search tasks
     include_simple_tasks: bool
         Whether to extract simple tasks that do not include complicated ops.
+    dump_workload_to_dag_log: Optional[str]
+        A file to dump the association between workload keys and their actual DAGs
     opt_level : Optional[int]
         The optimization level of the task extractions.
 
@@ -170,6 +173,10 @@
         )
         weights.append(weight)
 
+    if dump_workload_to_dag_log is not None:
+        with open(dump_workload_to_dag_log, "w") as f:
+            json.dump({task.workload_key: str(task.compute_dag) for task in tasks}, f)
+
     return tasks, weights
 
 
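A hedged usage sketch of the new dump_workload_to_dag_log argument, assuming mod and params come from an existing Relay model; the file name is illustrative:

import json
from tvm import auto_scheduler

tasks, weights = auto_scheduler.extract_tasks(
    mod["main"], params, "llvm", dump_workload_to_dag_log="workload_to_dag.json"
)
# The dump maps each task's workload key to its printed compute DAG,
# which helps match tuning logs back to the operators they came from
with open("workload_to_dag.json") as f:
    workload_to_dag = json.load(f)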
tests/python/relay/test_auto_scheduler_task_extraction.py (44 changes: 43 additions & 1 deletion)
@@ -15,10 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test task extraction for auto-scheduler"""
-import pytest
+import json
+import tempfile
 
+import pytest
 import tvm.relay.testing
 import tvm.testing
+from tvm import _ffi as _ffi_api
 from tvm import auto_scheduler, relay
 
 
@@ -248,5 +251,44 @@ def verify_task_extraction(func_name, expected_task, include_simple_tasks=False)
     verify_task_extraction(*params)
 
 
+def test_dump_workload_to_dag_extract_tasks():
+    mod, _ = get_network("mobilenet", layout="NHWC")
+    with tempfile.NamedTemporaryFile() as f:
+        tasks, _ = auto_scheduler.extract_tasks(
+            mod["main"], None, "llvm", include_simple_tasks=True, dump_workload_to_dag_log=f.name
+        )
+        expected = {task.workload_key: str(task.compute_dag) for task in tasks}
+        actual = json.load(f)
+        assert expected == actual
+
+
+def test_custom_hash_func_extract_tasks():
+    @_ffi_api.register_func("auto_scheduler.compute_dag.hash_func")
+    def counting_unique_hash(str_dag):
+        ret = counting_unique_hash.i
+        counting_unique_hash.i += 1
+        return ret
+
+    counting_unique_hash.i = 0
+
+    mod, _ = get_network("mobilenet", layout="NHWC")
+    tasks, _ = auto_scheduler.extract_tasks(mod["main"], None, "llvm", include_simple_tasks=True)
+
+    hash_values = []
+    for task in tasks:
+        # task.workload_key should look like
+        # [43, [3, 3, 1024, 1], [1024], [3, 3, 1024, 1]] where the first int is the hash result
+        # Extract the hash and keep track of every hash
+        hash_value = int(task.workload_key[1:].split(",")[0])
+        hash_values.append(hash_value)
+
+    # All values are unique, and we know the min and max
+    # This is a sufficient condition to show that the hashes are distinct counter
+    # values in the range 0 .. counting_unique_hash.i - 1, i.e. the custom hash was used
+    assert len(hash_values) == len(set(hash_values))
+    assert min(hash_values) == 0
+    assert max(hash_values) == counting_unique_hash.i - 1
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
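Aside: since workload keys are JSON strings, the hash in the test above can also be recovered with json.loads instead of string slicing; a small sketch, assuming tasks comes from extract_tasks as in the tests:

decoded = json.loads(tasks[0].workload_key)  # e.g. [43, [3, 3, 1024, 1], ...]
hash_value, io_shapes = decoded[0], decoded[1:]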