From 459f77307764139ee67522b89873f8bdc5b5a891 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 20 Jan 2021 11:27:14 +0800 Subject: [PATCH 01/35] Add sparse dense tuning tutorial --- python/tvm/auto_scheduler/measure.py | 123 +++++++- python/tvm/topi/nn/sparse.py | 6 +- tutorials/auto_scheduler/tune_sparse_x86.py | 295 ++++++++++++++++++++ 3 files changed, 419 insertions(+), 5 deletions(-) create mode 100644 tutorials/auto_scheduler/tune_sparse_x86.py diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 47ffde4327c4..348d0ba6074f 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -37,6 +37,8 @@ import tempfile import multiprocessing +import numpy as np + import tvm._ffi from tvm.runtime import Object, module, ndarray from tvm.driver import build_module @@ -719,6 +721,61 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo return results +def _process_sparse_input(args): + for arg in args: + if isinstance(arg.op, tvm.te.tensor.ComputeOp) and \ + arg.op.tag == "sparse_dense_sp_rhs_bsrmm": + # Get output shape + output_tensor = arg + M, N = output_tensor.shape + + # Get the input tensors + block_tensor = arg.op.input_tensors[0] + unsure_tensors = list(block_tensor.op.input_tensors) + assert len(unsure_tensors) == 4 + + # Get the input data + dense_data = None + for tensor in unsure_tensors: + if len(tensor.shape) == 2: + assert dense_data is None + dense_data = tensor + assert M == dense_data.shape[0] + K = dense_data.shape[1] + unsure_tensors.remove(dense_data) + + # Get the Sparse data + sparse_data = None + for tensor in unsure_tensors: + if len(tensor.shape) == 3: + assert sparse_data is None + sparse_data = tensor + block_size, BS_R, BS_C = sparse_data.shape + unsure_tensors.remove(sparse_data) + + # Get the Sparse indptr & indices + sparse_indices = None + for tensor in unsure_tensors: + assert len(tensor.shape) == 1 + if tensor.shape[0] == block_size: + assert sparse_indices is None + sparse_indices = tensor + unsure_tensors.remove(sparse_indices) + sparse_indptr = unsure_tensors[0] + + density = 1.0 + for i in sparse_data.shape: + density *= i + density /= (K * N) + density = density.value + sparse_prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % ( + M, N, K, BS_R, BS_C, density + ) + + return sparse_prefix, sparse_data, sparse_indices, sparse_indptr + + return None, None, None, None + def _timed_eval_func( inp_serialized, build_res, @@ -758,11 +815,30 @@ def _timed_eval_func( if error_no == 0: try: - args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] + # Check sparse op + sparse_prefix, sparse_data, sparse_indices, sparse_indptr = \ + _process_sparse_input(build_res.args) random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True) assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" - for arg in args: - random_fill(arg) + if sparse_prefix: + args = [] + for arg in build_res.args: + if arg == sparse_data: + args.append(ndarray.array(get_special_buffer(sparse_prefix+"W_data"), ctx)) + elif arg == sparse_indices: + args.append(ndarray.array(get_special_buffer(sparse_prefix+"W_indices"), ctx)) + elif arg == sparse_indptr: + args.append(ndarray.array(get_special_buffer(sparse_prefix+"W_indptr"), ctx)) + else: + empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) + random_fill(empty_array) + args.append(empty_array) + else: + args = [ + ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args + ] + for arg in args: + random_fill(arg) ctx.sync() costs = time_f(*args).results # pylint: disable=broad-except @@ -1132,3 +1208,44 @@ def rpc_runner_run( print("") return results + + +# The map stores special registered buffer for measurement +# This can be used for sparse workloads when we cannot use random tensors for measurment. +global special_buffer_table +special_buffer_table = {} + +def register_special_buffer(tensor_name, data): + """Register special buffer for measurement + This can be used for sparse workloads when we cannot use random tensors for measurment. + """ + if tensor_name in special_buffer_table.keys(): + return True + + if os.path.isfile(tensor_name): + print("Load ", tensor_name) + if tensor_name.startswith("sparse_dense_bsr"): + if tensor_name.endswith("data"): + data = np.fromfile(tensor_name, dtype="float32", sep=" ") + name_split = tensor_name.split("_") + BS_R = int(name_split[6]) + BS_C = int(name_split[7]) + data = data.reshape((data.shape[0] // BS_R // BS_C, BS_R, BS_C)) + else: + data = np.fromfile(tensor_name, dtype="int32", sep=" ") + elif data is None: + return False + + special_buffer_table[tensor_name] = data + + if not os.path.isfile(tensor_name): + data.tofile(tensor_name, " ") + + return True + +def get_special_buffer(tensor_name): + """Get special buffer for measurement. + This can be used for sparse workloads when we cannot use random tensors for measurment. + The buffers are registered by `register_special_buffer`. + """ + return special_buffer_table.get(tensor_name, None) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 8145ed80af47..d790d087d251 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -197,7 +197,7 @@ def _compute_block(nb_j, j, i): def _sparse_dense_sp_rhs_bsrmm(data, weight_data, weight_indices, weight_indptr): - (m, _) = get_const_tuple(data.shape) + (m, k) = get_const_tuple(data.shape) (_, bs_r, bs_c) = get_const_tuple(weight_data.shape) (num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape) num_blocks = num_blocks_plus_1 - 1 @@ -218,7 +218,9 @@ def _compute_block(i, nb_j, j): idxm = tvm.tir.indexmod bsrmm_block = te.compute( - (m, num_blocks, bs_r), _compute_block, tag="sparse_dense_sp_rhs_bsrmm_block" + (m, num_blocks, bs_r), _compute_block, + tag="sparse_dense_sp_rhs_bsrmm_block", + attrs={"FLOP": 2 * m * num_blocks * bs_r * k}, ) return te.compute( (m, num_blocks * bs_r), diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py new file mode 100644 index 000000000000..3634e4efc883 --- /dev/null +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -0,0 +1,295 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Auto-scheduling Matrix Multiplication for CPU +============================================= +**Author**: `Chengfan Jia `_ + +This is a tutorial on how to use the auto-scheduler for CPUs. + +Different from the template-based :ref:`autotvm ` which relies on +manual templates to define the search space, the auto-scheduler does not require any templates. +Users only need to write the computation declaration without any schedule commands or templates. +The auto-scheduler can automatically generate a large search space and +find a good schedule in the space. + +We use matrix multiplication as an example in this tutorial. + +Note that this tutorial will not run on Windows or recent versions of macOS. To +get it to run, you will need to wrap the body of this tutorial in a :code:`if +__name__ == "__main__":` block. +""" + +import os +import itertools + +import numpy as np +import tvm +from tvm import te, auto_scheduler, topi +from tvm.topi.utils import get_const_tuple + +import scipy.sparse as sp + +###################################################################### +# Define the computation +# ^^^^^^^^^^^^^^^^^^^^^^ +# To begin with, let us define the computation of a matmul with bias add. +# The function should return the list of input/output tensors. +# From these tensors, the auto-scheduler can get the whole computational graph. + + +def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): + import itertools + + Y = np.zeros((M, N), dtype=dtype) + assert M % BS_R == 0 + assert N % BS_C == 0 + nnz = int(density * M * N) + num_blocks = int(nnz / (BS_R * BS_C)) + 1 + candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C)))) + assert candidate_blocks.shape[0] == M // BS_R * N // BS_C + chosen_blocks = candidate_blocks[ + np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) + ] + for i in range(len(chosen_blocks)): + r, c = chosen_blocks[i] + Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C) + s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C)) + assert s.data.shape == (num_blocks, BS_R, BS_C) + assert s.indices.shape == (num_blocks,) + assert s.indptr.shape == (M // BS_R + 1,) + return s + + +###################################################################### +# Create the search task +# ^^^^^^^^^^^^^^^^^^^^^^ +# We then create a search task with N=L=M=1024 and dtype="float32" +# If your machine supports avx instructions, you can +# +# - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2 +# - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512 + +target = tvm.target.Target("llvm -mcpu=core-avx2") + +M = K = N = 512 +BS_R = 16 +BS_C = 1 +density = 0.6 + +X_np = np.random.randn(M, K).astype("float32") +W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32") +W_np = W_sp_np.todense() +Y_np = X_np @ W_np.T + +prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density) +auto_scheduler.measure.register_special_buffer(prefix + "W_data", W_sp_np.data) +auto_scheduler.measure.register_special_buffer(prefix + "W_indices", W_sp_np.indices) +auto_scheduler.measure.register_special_buffer(prefix + "W_indptr", W_sp_np.indptr) + +@auto_scheduler.register_workload +def sparse_dense(dense_shape, w_data_shape, w_indices_shape, w_indptr_shape, dtype): + X = te.placeholder(shape=dense_shape, dtype=dtype) + W_data = te.placeholder(shape=w_data_shape, dtype=dtype) + W_indices = te.placeholder(shape=w_indices_shape, dtype="int32") + W_indptr = te.placeholder(shape=w_indptr_shape, dtype="int32") + + out = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr) + + return [X, W_data, W_indices, W_indptr, out] + +task = tvm.auto_scheduler.SearchTask( + func=sparse_dense, + args=( + X_np.shape, + W_sp_np.data.shape, + W_sp_np.indices.shape, + W_sp_np.indptr.shape, + "float32" + ), + target=target +) + +# Inspect the computational graph +print("Computational DAG:") +print(task.compute_dag) + +###################################################################### + +def meet_condition_func(search_policy, state, stage_id): + state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) + if state.stages[stage_id].op.tag in [ + "sparse_dense_sp_rhs_bsrmm", "sparse_dense_sp_rhs_bsrmm_block" + ]: + return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST + else: + return auto_scheduler.PreloadCustomSketchRule.PASS + +def apply_func(search_policy, state, stage_id): + ret = [] + s0 = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) + if s0.stages[stage_id].op.tag == "sparse_dense_sp_rhs_bsrmm_block": + return [s0.state_object, stage_id - 1] + + sparse_dense = s0.stages[stage_id].op + sparse_dense_block = s0.stages[stage_id - 1].op + assert sparse_dense.tag == "sparse_dense_sp_rhs_bsrmm" + assert sparse_dense_block.tag == "sparse_dense_sp_rhs_bsrmm_block" + + s1 = s0.copy() + i, nb_j, j, row_offset, c = s0[sparse_dense_block].iters + m, n = s0[sparse_dense].iters + i0, i1, i2 = s0.split(sparse_dense_block, i, [None, None]) + m0, m1 = s0.follow_split(sparse_dense, m, len(s0.transform_steps) - 1, 1) + j0, j1 = s0.split(sparse_dense_block, nb_j, [None]) + n0, n1 = s0.follow_split(sparse_dense, n, len(s0.transform_steps) - 1, 1) + s0.reorder(sparse_dense_block, [i0, j0, i1, j1, row_offset, i2, j, c]) + s0.reorder(sparse_dense, [m0, n0, m1, n1]) + s0.compute_at(sparse_dense_block, sparse_dense, n0) + + ret.append([s0.state_object, stage_id - 2]) + + return ret + +###################################################################### +# Next, we set parameters for the auto-scheduler. +# +# * :code:`num_measure_trials` is the number of measurement trials we can use during the search. +# We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a +# good value for the search to converge. You can do more trials according to your time budget. +# * In addition, we use :code:`RecordToFile` to dump measurement records into a file `matmul.json`. +# The measurement records can be used to query the history best, resume the search, +# and do more analyses later. +# * see :any:`auto_scheduler.TuningOptions` for more parameters + +log_file = "sparse_dense.json" +tune_option = auto_scheduler.TuningOptions( + num_measure_trials=10, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, +) + +search_policy = auto_scheduler.SketchPolicy( + task, + program_cost_model=auto_scheduler.XGBModel(), + init_search_callbacks=[ + auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func, "SparseDense") + ] +) + +###################################################################### +# Run the search +# ^^^^^^^^^^^^^^ +# Now we get all inputs ready. Pretty simple, isn't it? +# We can kick off the search and let the auto-scheduler do its magic. +# After some measurement trials, we can load the best schedule from the log +# file and apply it. + +# Run auto-tuning (search) +task.tune(tune_option, search_policy) +# Apply the best schedule +sch, args = task.apply_best(log_file) + +# args = sparse_dense( +# X_np.shape, +# W_sp_np.data.shape, +# W_sp_np.indices.shape, +# W_sp_np.indptr.shape, +# "float32") + +# sch = tvm.te.create_schedule([arg.op for arg in args]) + +###################################################################### +# We can lower the schedule to see the IR after auto-scheduling. +# The auto-scheduler correctly performs optimizations including multi-level tiling, +# layout transformation, parallelization, vectorization, unrolling, and operator fusion. + +print("Lowered TIR:") +print(tvm.lower(sch, args, simple_mode=True)) + +###################################################################### +# Check correctness and evaluate performance +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# We build the binary and check its correctness and performance. + +func = tvm.build(sch, args, target) + +ctx = tvm.cpu() + +X_np = np.random.randn(M, K).astype("float32") +W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32") +W_np = W_sp_np.todense() +Y_np = X_np @ W_np.T + +X_tvm = tvm.nd.array(X_np, ctx=ctx) +W_data_tvm = tvm.nd.array(W_sp_np.data, ctx=ctx) +W_indices_tvm = tvm.nd.array(W_sp_np.indices, ctx=ctx) +W_indptr_tvm = tvm.nd.array(W_sp_np.indptr, ctx=ctx) +Y_tvm = tvm.nd.empty(Y_np.shape, ctx=ctx) + +func(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, Y_tvm) + +# Check results +tvm.testing.assert_allclose(Y_np, Y_tvm.asnumpy(), atol=1e-4, rtol=1e-4) + +# Evaluate execution time. +evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500) +print( + "Execution time of this operator: %.3f ms" + % (np.median(evaluator(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, Y_tvm).results) * 1000) +) + + +###################################################################### +# Using the record file +# ^^^^^^^^^^^^^^^^^^^^^ +# During the search, all measurement records are dumped into the record +# file "matmul.json". The measurement records can be used to re-apply search results, +# resume the search, and perform other analyses. + +###################################################################### +# Here is an example where we load the best schedule from a file, +# and print the equivalent python schedule API. This can be used for +# debugging and learning the behavior of the auto-scheduler. + +print("Equivalent python schedule:") +print(task.print_best(log_file)) + +###################################################################### +# A more complicated example is to resume the search. +# In this case, we need to create the search policy and cost model by ourselves +# and resume the status of search policy and cost model with the log file. +# In the example below we resume the status and do more 5 trials. + + +def resume_search(task, log_file): + print("Resume search:") + cost_model = auto_scheduler.XGBModel() + cost_model.update_from_file(log_file) + search_policy = auto_scheduler.SketchPolicy( + task, cost_model, init_search_callbacks=[ + auto_scheduler.PreloadMeasuredStates(log_file), + auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func, "SparseDense") + ] + ) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)] + ) + task.tune(tune_option, search_policy=search_policy) + + +resume_search(task, log_file) From b577eddb4c789220c24ffc1d0d8f677e896224fe Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 20 Jan 2021 11:40:07 +0800 Subject: [PATCH 02/35] Add sparse input fusion --- tutorials/auto_scheduler/tune_sparse_x86.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index 3634e4efc883..52b89ebeb07b 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -92,6 +92,7 @@ def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): density = 0.6 X_np = np.random.randn(M, K).astype("float32") +X_np = np.maximum(np.zeros((M, K), dtype="float32"), X_np) # Relu W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32") W_np = W_sp_np.todense() Y_np = X_np @ W_np.T @@ -108,7 +109,7 @@ def sparse_dense(dense_shape, w_data_shape, w_indices_shape, w_indptr_shape, dty W_indices = te.placeholder(shape=w_indices_shape, dtype="int32") W_indptr = te.placeholder(shape=w_indptr_shape, dtype="int32") - out = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr) + out = topi.nn.sparse_dense(topi.nn.relu(X), W_data, W_indices, W_indptr) return [X, W_data, W_indices, W_indptr, out] @@ -230,11 +231,6 @@ def apply_func(search_policy, state, stage_id): ctx = tvm.cpu() -X_np = np.random.randn(M, K).astype("float32") -W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32") -W_np = W_sp_np.todense() -Y_np = X_np @ W_np.T - X_tvm = tvm.nd.array(X_np, ctx=ctx) W_data_tvm = tvm.nd.array(W_sp_np.data, ctx=ctx) W_indices_tvm = tvm.nd.array(W_sp_np.indices, ctx=ctx) From 594fdb499d06fdc0ba64a372d760d6de23fb5ded Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 20 Jan 2021 16:05:04 +0800 Subject: [PATCH 03/35] Update the dag to support output fusion --- python/tvm/auto_scheduler/measure.py | 155 +++++++++++++------- src/auto_scheduler/search_policy/utils.cc | 16 ++ tutorials/auto_scheduler/tune_sparse_x86.py | 138 ++++++++++------- 3 files changed, 203 insertions(+), 106 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 348d0ba6074f..ef3b9be8d644 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -45,6 +45,7 @@ from tvm.ir import transform from tvm.autotvm.measure.measure_methods import set_cuda_target_arch from tvm.contrib import tar, ndk +from tvm.te import PlaceholderOp, ComputeOp from . import _ffi_api from .loop_state import StateObject @@ -722,59 +723,80 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo def _process_sparse_input(args): - for arg in args: - if isinstance(arg.op, tvm.te.tensor.ComputeOp) and \ - arg.op.tag == "sparse_dense_sp_rhs_bsrmm": - # Get output shape - output_tensor = arg - M, N = output_tensor.shape - - # Get the input tensors - block_tensor = arg.op.input_tensors[0] - unsure_tensors = list(block_tensor.op.input_tensors) - assert len(unsure_tensors) == 4 - - # Get the input data - dense_data = None - for tensor in unsure_tensors: - if len(tensor.shape) == 2: - assert dense_data is None - dense_data = tensor - assert M == dense_data.shape[0] - K = dense_data.shape[1] - unsure_tensors.remove(dense_data) - - # Get the Sparse data - sparse_data = None - for tensor in unsure_tensors: - if len(tensor.shape) == 3: - assert sparse_data is None - sparse_data = tensor - block_size, BS_R, BS_C = sparse_data.shape - unsure_tensors.remove(sparse_data) - - # Get the Sparse indptr & indices - sparse_indices = None - for tensor in unsure_tensors: - assert len(tensor.shape) == 1 - if tensor.shape[0] == block_size: - assert sparse_indices is None - sparse_indices = tensor - unsure_tensors.remove(sparse_indices) - sparse_indptr = unsure_tensors[0] - - density = 1.0 - for i in sparse_data.shape: - density *= i - density /= (K * N) - density = density.value - sparse_prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % ( - M, N, K, BS_R, BS_C, density - ) + sparse_prefix = sparse_data = sparse_indices = sparse_indptr = None + + def _process_inputs(input_tensors, M, N, prefix_init): + nonlocal sparse_prefix + nonlocal sparse_data + nonlocal sparse_indices + nonlocal sparse_indptr + + assert len(input_tensors) == 4 + unsure_tensors = list(input_tensors) + # Get the Dense data + dense_data = None + for tensor in unsure_tensors: + if len(tensor.shape) == 2: + assert dense_data is None + dense_data = tensor + assert M == dense_data.shape[0] + K = dense_data.shape[1] + unsure_tensors.remove(dense_data) + + # Get the Sparse data + sparse_data = None + for tensor in unsure_tensors: + if len(tensor.shape) == 3: + assert sparse_data is None + sparse_data = tensor + block_size, BS_R, BS_C = sparse_data.shape + unsure_tensors.remove(sparse_data) + + # Get the Sparse indptr & indices + sparse_indices = None + for tensor in unsure_tensors: + assert len(tensor.shape) == 1 + if tensor.shape[0] == block_size: + assert sparse_indices is None + sparse_indices = tensor + unsure_tensors.remove(sparse_indices) + assert len(unsure_tensors) == 1 + sparse_indptr = unsure_tensors[0] + + # Generate the sparse_prefix + density = 1.0 + for i in sparse_data.shape: + density *= i + density /= (K * N) + density = density.value + sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_" % ( + prefix_init, M, N, K, BS_R, BS_C, density + ) + + visited = set() + def _traverse(t): + # We cannot directly add tensors to the set, because the comparison of + # two tensors with ndim=0 is ambiguous. + assert t.handle is not None + if t.handle.value in visited: + return + if isinstance(t.op, ComputeOp): + # TODO(jcf94): Currently only support to tune one sparse op + if t.op.tag == "sparse_dense_sp_rhs_bsrmm": + M, N = t.shape + assert len(t.op.input_tensors) == 1 + block_tensor = t.op.input_tensors[0] + _process_inputs(block_tensor.op.input_tensors, M, N, "sparse_dense_bsr") + if sparse_prefix is not None: + return + for x in t.op.input_tensors: + _traverse(x) + visited.add(t.handle.value) - return sparse_prefix, sparse_data, sparse_indices, sparse_indptr + for arg in args: + _traverse(arg) - return None, None, None, None + return sparse_prefix, sparse_data, sparse_indices, sparse_indptr def _timed_eval_func( inp_serialized, @@ -815,11 +837,12 @@ def _timed_eval_func( if error_no == 0: try: + random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True) + assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" + # Check sparse op sparse_prefix, sparse_data, sparse_indices, sparse_indptr = \ _process_sparse_input(build_res.args) - random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True) - assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" if sparse_prefix: args = [] for arg in build_res.args: @@ -1019,18 +1042,36 @@ def _timed_rpc_run( if error_no == 0: try: - args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] try: random_fill = remote.get_function("tvm.contrib.random.random_fill") except AttributeError: raise AttributeError( "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices" ) - for arg in args: - random_fill(arg) - ctx.sync() + # Check sparse op + sparse_prefix, sparse_data, sparse_indices, sparse_indptr = \ + _process_sparse_input(build_res.args) + if sparse_prefix: + args = [] + for arg in build_res.args: + if arg == sparse_data: + args.append(ndarray.array(get_special_buffer(sparse_prefix+"W_data"), ctx)) + elif arg == sparse_indices: + args.append(ndarray.array(get_special_buffer(sparse_prefix+"W_indices"), ctx)) + elif arg == sparse_indptr: + args.append(ndarray.array(get_special_buffer(sparse_prefix+"W_indptr"), ctx)) + else: + empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) + random_fill(empty_array) + args.append(empty_array) + else: + args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] + for arg in args: + random_fill(arg) + ctx.sync() costs = time_f(*args).results + # clean up remote files remote.remove(build_res.filename) remote.remove(os.path.splitext(build_res.filename)[0] + ".so") diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc index d59df6965776..ce8dc39922e0 100644 --- a/src/auto_scheduler/search_policy/utils.cc +++ b/src/auto_scheduler/search_policy/utils.cc @@ -465,6 +465,22 @@ const std::vector& SplitFactorizationMemo::GetFactors(int n) { /********** Utils interface API for ffi **********/ +TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsGetConsumers") + .set_body_typed([](const SearchTask& task, const State& state, int stage_id) { + const std::set& consumers = GetConsumers(task, state, stage_id); + tvm::Map ret; + for (const auto& i : consumers) { + ret.Set(Integer(i), Integer(i)); + } + return ret; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsElementwiseMatch") + .set_body_typed([](const SearchTask& task, const State& state, int stage_id, + int target_stage_id) { + return ElementwiseMatch(task, state, stage_id, target_stage_id); + }); + TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsTiled") .set_body_typed([](const Stage& stage) { return IsTiled(stage); }); diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index 52b89ebeb07b..9c135ea5c4a1 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -15,19 +15,19 @@ # specific language governing permissions and limitations # under the License. """ -Auto-scheduling Matrix Multiplication for CPU -============================================= +Auto-scheduling Sparse Matrix Multiplication for CPU by Custom Sketch Rule +========================================================================== **Author**: `Chengfan Jia `_ -This is a tutorial on how to use the auto-scheduler for CPUs. +This is a tutorial on how to use the auto-scheduler to tune a sparse matrix multiplication for +CPUs. -Different from the template-based :ref:`autotvm ` which relies on -manual templates to define the search space, the auto-scheduler does not require any templates. -Users only need to write the computation declaration without any schedule commands or templates. -The auto-scheduler can automatically generate a large search space and -find a good schedule in the space. +Auto-scheduler is designed to explore the schedule with best performance for a given computation +declaration automatically. While sometimes, we may have a demand to try some special ops which may +not been well supported by auto-scheduler's default search policy. Auto-scheduler currently allows +user to provide a CustomSketch to cover these cases. -We use matrix multiplication as an example in this tutorial. +We use sparse matrix multiplication as an example in this tutorial. Note that this tutorial will not run on Windows or recent versions of macOS. To get it to run, you will need to wrap the body of this tutorial in a :code:`if @@ -40,6 +40,7 @@ import numpy as np import tvm from tvm import te, auto_scheduler, topi +from tvm.auto_scheduler import _ffi_api from tvm.topi.utils import get_const_tuple import scipy.sparse as sp @@ -47,11 +48,11 @@ ###################################################################### # Define the computation # ^^^^^^^^^^^^^^^^^^^^^^ -# To begin with, let us define the computation of a matmul with bias add. +# To begin with, let us define the computation of a sparse matmul with several relu and bias add. # The function should return the list of input/output tensors. # From these tensors, the auto-scheduler can get the whole computational graph. - +# We use this function to generate a random bsr matrix def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): import itertools @@ -74,49 +75,70 @@ def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): assert s.indptr.shape == (M // BS_R + 1,) return s +@auto_scheduler.register_workload +def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): + X = te.placeholder(shape=(M, K), dtype=dtype) + W_data = te.placeholder(shape=w_data_shape, dtype=dtype) + W_indices = te.placeholder(shape=w_indices_shape, dtype="int32") + W_indptr = te.placeholder(shape=w_indptr_shape, dtype="int32") + B = te.placeholder(shape=(M, N), dtype=dtype) + + out = topi.nn.sparse_dense( + topi.nn.relu(X), W_data, W_indices, W_indptr + ) + out = te.compute((M, N), lambda i, j: out[i, j] + B[i, j], name="BiasAdd") + out = topi.nn.relu(out) + + return [X, W_data, W_indices, W_indptr, B, out] ###################################################################### -# Create the search task -# ^^^^^^^^^^^^^^^^^^^^^^ -# We then create a search task with N=L=M=1024 and dtype="float32" -# If your machine supports avx instructions, you can +# Special step for sparse workload +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# During schedule tuning, auto-scheduler will use random inputs to measure the performance of a +# generated schedule. While we cannot directly use a random array as the input of a sparse op, for +# the "indices" and "indptr" array are meaningful for the computation. # -# - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2 -# - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512 - -target = tvm.target.Target("llvm -mcpu=core-avx2") +# To solve this problem, we register these as special buffers, and load them when process program +# measuring. +# See the :any:`auto_scheduler.measure` code for more details. +# Define the basic shapes of this sparse computation M = K = N = 512 BS_R = 16 BS_C = 1 density = 0.6 +# Generate the test data with numpy X_np = np.random.randn(M, K).astype("float32") X_np = np.maximum(np.zeros((M, K), dtype="float32"), X_np) # Relu W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32") W_np = W_sp_np.todense() -Y_np = X_np @ W_np.T +Y_np = X_np @ W_np.T # Process the matrix multiplication +B_np = np.random.randn(M, N).astype("float32") +Y_np = Y_np + B_np # Bias add +Y_np = np.maximum(np.zeros((M, N), dtype="float32"), Y_np) # Relu +# Register the sparse data to special buffer prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density) auto_scheduler.measure.register_special_buffer(prefix + "W_data", W_sp_np.data) auto_scheduler.measure.register_special_buffer(prefix + "W_indices", W_sp_np.indices) auto_scheduler.measure.register_special_buffer(prefix + "W_indptr", W_sp_np.indptr) -@auto_scheduler.register_workload -def sparse_dense(dense_shape, w_data_shape, w_indices_shape, w_indptr_shape, dtype): - X = te.placeholder(shape=dense_shape, dtype=dtype) - W_data = te.placeholder(shape=w_data_shape, dtype=dtype) - W_indices = te.placeholder(shape=w_indices_shape, dtype="int32") - W_indptr = te.placeholder(shape=w_indptr_shape, dtype="int32") - - out = topi.nn.sparse_dense(topi.nn.relu(X), W_data, W_indices, W_indptr) +###################################################################### +# Create the search task +# ^^^^^^^^^^^^^^^^^^^^^^ +# We then create a search task with M=N=K=512 and dtype="float32" +# If your machine supports avx instructions, you can +# +# - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2 +# - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512 - return [X, W_data, W_indices, W_indptr, out] +target = tvm.target.Target("llvm") task = tvm.auto_scheduler.SearchTask( func=sparse_dense, args=( - X_np.shape, + M, N, K, W_sp_np.data.shape, W_sp_np.indices.shape, W_sp_np.indptr.shape, @@ -130,6 +152,16 @@ def sparse_dense(dense_shape, w_data_shape, w_indices_shape, w_indptr_shape, dty print(task.compute_dag) ###################################################################### +# Write the custom sketch for sparse dense op +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Before tuning, we will need to define the CustomSketchRule for the sparse dense op. +# +# CustomSketchRule consists of two parts: the condition function and the apply function. +# +# - condition function: describe when to use this sketch rule. For example, we can match the op +# by their name or tag. +# - apply function: describe how to generate the initial sketch. Auto-scheduler provides a set of +# loop state APIs. def meet_condition_func(search_policy, state, stage_id): state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) @@ -151,16 +183,31 @@ def apply_func(search_policy, state, stage_id): assert sparse_dense.tag == "sparse_dense_sp_rhs_bsrmm" assert sparse_dense_block.tag == "sparse_dense_sp_rhs_bsrmm_block" - s1 = s0.copy() + # Set the default consumer of compute block + consumer = sparse_dense + + # If sparse dense has a single elementwise consumer + # We can compute inline the sparse_dense output stage + consumers = _ffi_api.SearchPolicyUtilsGetConsumers( + search_policy.search_task, s0.state_object, stage_id + ) + if len(consumers) == 1: + consumer_id = int(consumers.items()[0][0]) + if _ffi_api.SearchPolicyUtilsIsElementwiseMatch( + search_policy.search_task, s0.state_object, stage_id, consumer_id + ): + consumer = s0.stages[consumer_id].op + s0.compute_inline(sparse_dense) + i, nb_j, j, row_offset, c = s0[sparse_dense_block].iters - m, n = s0[sparse_dense].iters + m, n = s0[consumer].iters i0, i1, i2 = s0.split(sparse_dense_block, i, [None, None]) - m0, m1 = s0.follow_split(sparse_dense, m, len(s0.transform_steps) - 1, 1) + m0, m1 = s0.follow_split(consumer, m, len(s0.transform_steps) - 1, 1) j0, j1 = s0.split(sparse_dense_block, nb_j, [None]) - n0, n1 = s0.follow_split(sparse_dense, n, len(s0.transform_steps) - 1, 1) + n0, n1 = s0.follow_split(consumer, n, len(s0.transform_steps) - 1, 1) s0.reorder(sparse_dense_block, [i0, j0, i1, j1, row_offset, i2, j, c]) - s0.reorder(sparse_dense, [m0, n0, m1, n1]) - s0.compute_at(sparse_dense_block, sparse_dense, n0) + s0.reorder(consumer, [m0, n0, m1, n1]) + s0.compute_at(sparse_dense_block, consumer, n0) ret.append([s0.state_object, stage_id - 2]) @@ -176,6 +223,8 @@ def apply_func(search_policy, state, stage_id): # The measurement records can be used to query the history best, resume the search, # and do more analyses later. # * see :any:`auto_scheduler.TuningOptions` for more parameters +# * Here, we need to create a :code:`auto_scheduler.SketchPolicy` object, and add the custom sketch +# rule as a `init_search_callbacks`. log_file = "sparse_dense.json" tune_option = auto_scheduler.TuningOptions( @@ -195,7 +244,7 @@ def apply_func(search_policy, state, stage_id): ###################################################################### # Run the search # ^^^^^^^^^^^^^^ -# Now we get all inputs ready. Pretty simple, isn't it? +# Now we get all inputs ready. # We can kick off the search and let the auto-scheduler do its magic. # After some measurement trials, we can load the best schedule from the log # file and apply it. @@ -205,15 +254,6 @@ def apply_func(search_policy, state, stage_id): # Apply the best schedule sch, args = task.apply_best(log_file) -# args = sparse_dense( -# X_np.shape, -# W_sp_np.data.shape, -# W_sp_np.indices.shape, -# W_sp_np.indptr.shape, -# "float32") - -# sch = tvm.te.create_schedule([arg.op for arg in args]) - ###################################################################### # We can lower the schedule to see the IR after auto-scheduling. # The auto-scheduler correctly performs optimizations including multi-level tiling, @@ -235,9 +275,10 @@ def apply_func(search_policy, state, stage_id): W_data_tvm = tvm.nd.array(W_sp_np.data, ctx=ctx) W_indices_tvm = tvm.nd.array(W_sp_np.indices, ctx=ctx) W_indptr_tvm = tvm.nd.array(W_sp_np.indptr, ctx=ctx) +B_tvm = tvm.nd.array(B_np, ctx=ctx) Y_tvm = tvm.nd.empty(Y_np.shape, ctx=ctx) -func(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, Y_tvm) +func(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm) # Check results tvm.testing.assert_allclose(Y_np, Y_tvm.asnumpy(), atol=1e-4, rtol=1e-4) @@ -246,10 +287,9 @@ def apply_func(search_policy, state, stage_id): evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500) print( "Execution time of this operator: %.3f ms" - % (np.median(evaluator(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, Y_tvm).results) * 1000) + % (np.median(evaluator(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm).results) * 1000) ) - ###################################################################### # Using the record file # ^^^^^^^^^^^^^^^^^^^^^ From a4b025b985880e6bdf9da77a9f5d621c65115a81 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 28 Jan 2021 14:40:28 +0800 Subject: [PATCH 04/35] Update --- python/tvm/auto_scheduler/measure.py | 5 + .../auto_scheduler/tune_sparse_conv2d_x86.py | 379 ++++++++++++++++++ tutorials/auto_scheduler/tune_sparse_x86.py | 61 +-- 3 files changed, 395 insertions(+), 50 deletions(-) create mode 100644 tutorials/auto_scheduler/tune_sparse_conv2d_x86.py diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index ef3b9be8d644..ca362408640b 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -787,6 +787,11 @@ def _traverse(t): assert len(t.op.input_tensors) == 1 block_tensor = t.op.input_tensors[0] _process_inputs(block_tensor.op.input_tensors, M, N, "sparse_dense_bsr") + if t.op.tag == "sparse_conv2d_bsrmm": + N, OH = t.shape[0], t.shape[1] + assert len(t.op.input_tensors) == 1 + block_tensor = t.op.input_tensors[0] + _process_inputs(block_tensor.op.input_tensors, N, OH, "sparse_dense_bsr") if sparse_prefix is not None: return for x in t.op.input_tensors: diff --git a/tutorials/auto_scheduler/tune_sparse_conv2d_x86.py b/tutorials/auto_scheduler/tune_sparse_conv2d_x86.py new file mode 100644 index 000000000000..53fe32294840 --- /dev/null +++ b/tutorials/auto_scheduler/tune_sparse_conv2d_x86.py @@ -0,0 +1,379 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Auto-scheduling Sparse Matrix Multiplication for CPU by Custom Sketch Rule +========================================================================== +**Author**: `Chengfan Jia `_ + +This is a tutorial on how to use the auto-scheduler to tune a sparse matrix multiplication for +CPUs. + +Auto-scheduler is designed to explore the schedule with best performance for a given computation +declaration automatically. While sometimes, we may have a demand to try some special ops which may +not been well supported by auto-scheduler's default search policy. Auto-scheduler currently allows +user to provide a CustomSketch to cover these cases. + +We use sparse matrix multiplication as an example in this tutorial. + +Note that this tutorial will not run on Windows or recent versions of macOS. To +get it to run, you will need to wrap the body of this tutorial in a :code:`if +__name__ == "__main__":` block. +""" + +import os +import itertools + +import numpy as np +import tvm +from tvm import te, auto_scheduler, topi +from tvm.auto_scheduler import _ffi_api +from tvm.topi.utils import get_const_tuple + +import scipy.sparse as sp + +###################################################################### +# Define the computation +# ^^^^^^^^^^^^^^^^^^^^^^ +# To begin with, let us define the computation of a sparse matmul with several relu and bias add. +# The function should return the list of input/output tensors. +# From these tensors, the auto-scheduler can get the whole computational graph. + +# We use this function to generate a random bsr matrix +def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): + import itertools + + Y = np.zeros((M, N), dtype=dtype) + assert M % BS_R == 0 + assert N % BS_C == 0 + nnz = int(density * M * N) + num_blocks = int(nnz / (BS_R * BS_C)) + 1 + candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C)))) + assert candidate_blocks.shape[0] == M // BS_R * N // BS_C + chosen_blocks = candidate_blocks[ + np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) + ] + for i in range(len(chosen_blocks)): + r, c = chosen_blocks[i] + Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C) + s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C)) + assert s.data.shape == (num_blocks, BS_R, BS_C) + assert s.indices.shape == (num_blocks,) + assert s.indptr.shape == (M // BS_R + 1,) + return s + +def sparse_conv2d_bsr_compute(data, weight_data, weight_indices, weight_indptr): + ''' + Y = X * W^T + Y[m, h, w, n] = X[m, h, w, k] * W[1, 1, k, n] + Y[n, oh, ow, oc] = X[n, ih, iw, ic] * W[1, 1, ic, oc] + NHWC + HWIO + ''' + (m, h, w, k) = get_const_tuple(data.shape) + (_, bs_r, bs_c) = get_const_tuple(weight_data.shape) + (num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape) + num_blocks = num_blocks_plus_1 - 1 + + def _compute_block(i, h, w, nb_j, j): + row_start = weight_indptr[nb_j] + row_end = weight_indptr[nb_j + 1] + row_elems = row_end - row_start + elem_idx = te.reduce_axis((0, row_elems), name="elem_idx") + block_offset = row_start + elem_idx + c = te.reduce_axis((0, bs_c), name="c") + block_j = weight_indices[block_offset] + block_ij_val = weight_data[block_offset][j][c] + x_val = data[i, h, w, bs_c * block_j + c] + return te.sum(block_ij_val * x_val, axis=[elem_idx, c]) + + idxd = tvm.tir.indexdiv + idxm = tvm.tir.indexmod + + bsrmm_block = te.compute( + (m, h, w, num_blocks, bs_r), _compute_block, + tag="sparse_conv2d_sp_bsrmm_block", + attrs={"FLOP": 2 * m * num_blocks * bs_r * k}, + ) + return te.compute( + (m, h, w, num_blocks * bs_r), + lambda m, h, w, n: bsrmm_block[m, h, w, idxd(n, bs_r), idxm(n, bs_r)], + tag="sparse_conv2d_sp_bsrmm", + ) + +@auto_scheduler.register_workload +def sparse_conv(N, OH, OW, OC, IH, IW, IC, w_data_shape, w_indices_shape, w_indptr_shape, dtype): + X = te.placeholder(shape=(N, IH, IW, IC), dtype=dtype) + W_data = te.placeholder(shape=w_data_shape, dtype=dtype) + W_indices = te.placeholder(shape=w_indices_shape, dtype="int32") + W_indptr = te.placeholder(shape=w_indptr_shape, dtype="int32") + B = te.placeholder(shape=(N, OH, OW, OC), dtype=dtype) + + out = sparse_conv2d_bsr_compute( + topi.nn.relu(X), W_data, W_indices, W_indptr + ) + out = te.compute((N, OH, OW, OC), lambda i, j, k, l: out[i, j, k, l] + B[i, j, k, l], name="BiasAdd") + out = topi.nn.relu(out) + + return [X, W_data, W_indices, W_indptr, B, out] + +###################################################################### +# Special step for sparse workload +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# During schedule tuning, auto-scheduler will use random inputs to measure the performance of a +# generated schedule. While we cannot directly use a random array as the input of a sparse op, for +# the "indices" and "indptr" array are meaningful for the computation. +# +# To solve this problem, we register these as special buffers, and load them when process program +# measuring. +# See the :any:`auto_scheduler.measure` code for more details. + +# Define the basic shapes of this sparse computation + +#Y[n, oh, ow, oc] = X[n, ih, iw, ic] * W[1, 1, ic, oc] +N = 1 +OH = OW = 16 +OC = 128 +IH = IW = 16 +IC = 128 + +BS_R = 8 +BS_C = 1 +density = 0.6 + +# Generate the test data with numpy +X_np = np.random.randn(N, IH, IW, IC).astype("float32") +X_np = np.maximum(np.zeros((N, IH, IW, IC), dtype="float32"), X_np) # Relu +W_sp_np = random_bsr_matrix(IC, OC, BS_R, BS_C, density=density, dtype="float32") +W_np = W_sp_np.todense() + + +# Register the sparse data to special buffer +prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (N, OH, OW, BS_R, BS_C, density) +auto_scheduler.measure.register_special_buffer(prefix + "W_data", W_sp_np.data) +auto_scheduler.measure.register_special_buffer(prefix + "W_indices", W_sp_np.indices) +auto_scheduler.measure.register_special_buffer(prefix + "W_indptr", W_sp_np.indptr) + +###################################################################### +# Create the search task +# ^^^^^^^^^^^^^^^^^^^^^^ +# We then create a search task with M=N=K=512 and dtype="float32" +# If your machine supports avx instructions, you can +# +# - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2 +# - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512 + +target = tvm.target.Target("llvm -mcpu=core-avx2") + +task = tvm.auto_scheduler.SearchTask( + func=sparse_conv, + args=( + N, OH, OW, OC, IH, IW, IC, + W_sp_np.data.shape, + W_sp_np.indices.shape, + W_sp_np.indptr.shape, + "float32" + ), + target=target +) + +# Inspect the computational graph +print("Computational DAG:") +print(task.compute_dag) + +###################################################################### +# Write the custom sketch for sparse dense op +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Before tuning, we will need to define the CustomSketchRule for the sparse dense op. +# +# CustomSketchRule consists of two parts: the condition function and the apply function. +# +# - condition function: describe when to use this sketch rule. For example, we can match the op +# by their name or tag. +# - apply function: describe how to generate the initial sketch. Auto-scheduler provides a set of +# loop state APIs. + +def meet_condition_func(search_policy, state, stage_id): + state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) + if state.stages[stage_id].op.tag in [ + "sparse_conv2d_sp_bsrmm", "sparse_conv2d_sp_bsrmm_block" + ]: + return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST + else: + return auto_scheduler.PreloadCustomSketchRule.PASS + +def apply_func(search_policy, state, stage_id): + ret = [] + s0 = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) + if s0.stages[stage_id].op.tag == "sparse_conv2d_sp_bsrmm_block": + return [s0.state_object, stage_id - 1] + + sparse_dense = s0.stages[stage_id].op + sparse_dense_block = s0.stages[stage_id - 1].op + assert sparse_dense.tag == "sparse_conv2d_sp_bsrmm" + assert sparse_dense_block.tag == "sparse_conv2d_sp_bsrmm_block" + + # Set the default consumer of compute block + consumer = sparse_dense + + # If sparse dense has a single elementwise consumer + # We can compute inline the sparse_dense output stage + consumers = _ffi_api.SearchPolicyUtilsGetConsumers( + search_policy.search_task, s0.state_object, stage_id + ) + if len(consumers) == 1: + consumer_id = int(consumers.items()[0][0]) + if _ffi_api.SearchPolicyUtilsIsElementwiseMatch( + search_policy.search_task, s0.state_object, stage_id, consumer_id + ): + consumer = s0.stages[consumer_id].op + s0.compute_inline(sparse_dense) + + i, h, w, nb_j, j, row_offset, c = s0[sparse_dense_block].iters + m, x, y, n = s0[consumer].iters + + i0, i1, i2 = s0.split(sparse_dense_block, i, [None, None]) + m0, m1 = s0.follow_split(consumer, m, len(s0.transform_steps) - 1, 1) + h0, h1, h2 = s0.split(sparse_dense_block, h, [None, None]) + x0, x1 = s0.follow_split(consumer, x, len(s0.transform_steps) - 1, 1) + w0, w1, w2 = s0.split(sparse_dense_block, w, [None, None]) + y0, y1 = s0.follow_split(consumer, y, len(s0.transform_steps) - 1, 1) + j0, j1 = s0.split(sparse_dense_block, nb_j, [None]) + n0, n1 = s0.follow_split(consumer, n, len(s0.transform_steps) - 1, 1) + s0.reorder(sparse_dense_block, [i0, h0, w0, j0, i1, h1, w1, j1, row_offset, i2, h2, w2, j, c]) + s0.reorder(consumer, [m0, x0, y0, n0, m1, x1, y1, n1]) + s0.compute_at(sparse_dense_block, consumer, n0) + + ret.append([s0.state_object, stage_id - 2]) + + return ret + +###################################################################### +# Next, we set parameters for the auto-scheduler. +# +# * :code:`num_measure_trials` is the number of measurement trials we can use during the search. +# We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a +# good value for the search to converge. You can do more trials according to your time budget. +# * In addition, we use :code:`RecordToFile` to dump measurement records into a file `matmul.json`. +# The measurement records can be used to query the history best, resume the search, +# and do more analyses later. +# * see :any:`auto_scheduler.TuningOptions` for more parameters +# * Here, we need to create a :code:`auto_scheduler.SketchPolicy` object, and add the custom sketch +# rule as a `init_search_callbacks`. + +log_file = "sparse_conv.json" +tune_option = auto_scheduler.TuningOptions( + num_measure_trials=10, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, +) + +search_policy = auto_scheduler.SketchPolicy( + task, + program_cost_model=auto_scheduler.XGBModel(), + init_search_callbacks=[ + auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func, "SparseDense") + ] +) + +###################################################################### +# Run the search +# ^^^^^^^^^^^^^^ +# Now we get all inputs ready. +# We can kick off the search and let the auto-scheduler do its magic. +# After some measurement trials, we can load the best schedule from the log +# file and apply it. + +# Run auto-tuning (search) +task.tune(tune_option, search_policy) +# Apply the best schedule +sch, args = task.apply_best(log_file) + +###################################################################### +# We can lower the schedule to see the IR after auto-scheduling. +# The auto-scheduler correctly performs optimizations including multi-level tiling, +# layout transformation, parallelization, vectorization, unrolling, and operator fusion. + +print("Lowered TIR:") +print(tvm.lower(sch, args, simple_mode=True)) + +###################################################################### +# Check correctness and evaluate performance +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# We build the binary and check its correctness and performance. + +func = tvm.build(sch, args, target) + +ctx = tvm.cpu() + +X_tvm = tvm.nd.array(X_np, ctx=ctx) +W_data_tvm = tvm.nd.array(W_sp_np.data, ctx=ctx) +W_indices_tvm = tvm.nd.array(W_sp_np.indices, ctx=ctx) +W_indptr_tvm = tvm.nd.array(W_sp_np.indptr, ctx=ctx) +B_tvm = tvm.nd.array(B_np, ctx=ctx) +Y_tvm = tvm.nd.empty(Y_np.shape, ctx=ctx) + +func(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm) + +# Check results +tvm.testing.assert_allclose(Y_np, Y_tvm.asnumpy(), atol=1e-4, rtol=1e-4) + +# Evaluate execution time. +evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500) +print( + "Execution time of this operator: %.3f ms" + % (np.median(evaluator(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm).results) * 1000) +) + +###################################################################### +# Using the record file +# ^^^^^^^^^^^^^^^^^^^^^ +# During the search, all measurement records are dumped into the record +# file "matmul.json". The measurement records can be used to re-apply search results, +# resume the search, and perform other analyses. + +###################################################################### +# Here is an example where we load the best schedule from a file, +# and print the equivalent python schedule API. This can be used for +# debugging and learning the behavior of the auto-scheduler. + +print("Equivalent python schedule:") +print(task.print_best(log_file)) + +###################################################################### +# A more complicated example is to resume the search. +# In this case, we need to create the search policy and cost model by ourselves +# and resume the status of search policy and cost model with the log file. +# In the example below we resume the status and do more 5 trials. + + +def resume_search(task, log_file): + print("Resume search:") + cost_model = auto_scheduler.XGBModel() + cost_model.update_from_file(log_file) + search_policy = auto_scheduler.SketchPolicy( + task, cost_model, init_search_callbacks=[ + auto_scheduler.PreloadMeasuredStates(log_file), + auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func, "SparseDense") + ] + ) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)] + ) + task.tune(tune_option, search_policy=search_policy) + + +resume_search(task, log_file) diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index 9c135ea5c4a1..ea0422dd01b7 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """ -Auto-scheduling Sparse Matrix Multiplication for CPU by Custom Sketch Rule -========================================================================== +Auto-scheduling Sparse Matrix Multiplication on CPU with Custom Sketch Rule +=========================================================================== **Author**: `Chengfan Jia `_ This is a tutorial on how to use the auto-scheduler to tune a sparse matrix multiplication for @@ -24,10 +24,11 @@ Auto-scheduler is designed to explore the schedule with best performance for a given computation declaration automatically. While sometimes, we may have a demand to try some special ops which may -not been well supported by auto-scheduler's default search policy. Auto-scheduler currently allows -user to provide a CustomSketch to cover these cases. +not been well-supported by auto-scheduler's default sketch rules and result in poor performance. +Fortunately, auto-scheduler currently allows user to provide a CustomSketch to cover these cases. -We use sparse matrix multiplication as an example in this tutorial. +We use sparse matrix multiplication as an example in this tutorial to demonstrate how to implement +and plug a custom sketch rule to the auto-scheduler search policy. Note that this tutorial will not run on Windows or recent versions of macOS. To get it to run, you will need to wrap the body of this tutorial in a :code:`if @@ -158,10 +159,10 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): # # CustomSketchRule consists of two parts: the condition function and the apply function. # -# - condition function: describe when to use this sketch rule. For example, we can match the op -# by their name or tag. -# - apply function: describe how to generate the initial sketch. Auto-scheduler provides a set of -# loop state APIs. +# - condition function: describe when to apply this sketch rule. For example, we can only apply +# the rule to the sparse ops by matching their name and tag. +# - apply function: describe how to generate the initial sketch. You can implement it using +# auto-scheduler provided loop state APIs. def meet_condition_func(search_policy, state, stage_id): state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) @@ -214,7 +215,7 @@ def apply_func(search_policy, state, stage_id): return ret ###################################################################### -# Next, we set parameters for the auto-scheduler. +# Next, we set parameters for the auto-scheduler with the custom sketch plugged in. # # * :code:`num_measure_trials` is the number of measurement trials we can use during the search. # We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a @@ -289,43 +290,3 @@ def apply_func(search_policy, state, stage_id): "Execution time of this operator: %.3f ms" % (np.median(evaluator(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm).results) * 1000) ) - -###################################################################### -# Using the record file -# ^^^^^^^^^^^^^^^^^^^^^ -# During the search, all measurement records are dumped into the record -# file "matmul.json". The measurement records can be used to re-apply search results, -# resume the search, and perform other analyses. - -###################################################################### -# Here is an example where we load the best schedule from a file, -# and print the equivalent python schedule API. This can be used for -# debugging and learning the behavior of the auto-scheduler. - -print("Equivalent python schedule:") -print(task.print_best(log_file)) - -###################################################################### -# A more complicated example is to resume the search. -# In this case, we need to create the search policy and cost model by ourselves -# and resume the status of search policy and cost model with the log file. -# In the example below we resume the status and do more 5 trials. - - -def resume_search(task, log_file): - print("Resume search:") - cost_model = auto_scheduler.XGBModel() - cost_model.update_from_file(log_file) - search_policy = auto_scheduler.SketchPolicy( - task, cost_model, init_search_callbacks=[ - auto_scheduler.PreloadMeasuredStates(log_file), - auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func, "SparseDense") - ] - ) - tune_option = auto_scheduler.TuningOptions( - num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)] - ) - task.tune(tune_option, search_policy=search_policy) - - -resume_search(task, log_file) From a6223e0595db5a8034a8e7026d9646d119b26b96 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Feb 2021 20:55:11 +0800 Subject: [PATCH 05/35] Add task input to search_task --- .../auto_scheduler/tune_sparse_conv2d_x86.py | 379 ------------------ 1 file changed, 379 deletions(-) delete mode 100644 tutorials/auto_scheduler/tune_sparse_conv2d_x86.py diff --git a/tutorials/auto_scheduler/tune_sparse_conv2d_x86.py b/tutorials/auto_scheduler/tune_sparse_conv2d_x86.py deleted file mode 100644 index 53fe32294840..000000000000 --- a/tutorials/auto_scheduler/tune_sparse_conv2d_x86.py +++ /dev/null @@ -1,379 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Auto-scheduling Sparse Matrix Multiplication for CPU by Custom Sketch Rule -========================================================================== -**Author**: `Chengfan Jia `_ - -This is a tutorial on how to use the auto-scheduler to tune a sparse matrix multiplication for -CPUs. - -Auto-scheduler is designed to explore the schedule with best performance for a given computation -declaration automatically. While sometimes, we may have a demand to try some special ops which may -not been well supported by auto-scheduler's default search policy. Auto-scheduler currently allows -user to provide a CustomSketch to cover these cases. - -We use sparse matrix multiplication as an example in this tutorial. - -Note that this tutorial will not run on Windows or recent versions of macOS. To -get it to run, you will need to wrap the body of this tutorial in a :code:`if -__name__ == "__main__":` block. -""" - -import os -import itertools - -import numpy as np -import tvm -from tvm import te, auto_scheduler, topi -from tvm.auto_scheduler import _ffi_api -from tvm.topi.utils import get_const_tuple - -import scipy.sparse as sp - -###################################################################### -# Define the computation -# ^^^^^^^^^^^^^^^^^^^^^^ -# To begin with, let us define the computation of a sparse matmul with several relu and bias add. -# The function should return the list of input/output tensors. -# From these tensors, the auto-scheduler can get the whole computational graph. - -# We use this function to generate a random bsr matrix -def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): - import itertools - - Y = np.zeros((M, N), dtype=dtype) - assert M % BS_R == 0 - assert N % BS_C == 0 - nnz = int(density * M * N) - num_blocks = int(nnz / (BS_R * BS_C)) + 1 - candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C)))) - assert candidate_blocks.shape[0] == M // BS_R * N // BS_C - chosen_blocks = candidate_blocks[ - np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) - ] - for i in range(len(chosen_blocks)): - r, c = chosen_blocks[i] - Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C) - s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C)) - assert s.data.shape == (num_blocks, BS_R, BS_C) - assert s.indices.shape == (num_blocks,) - assert s.indptr.shape == (M // BS_R + 1,) - return s - -def sparse_conv2d_bsr_compute(data, weight_data, weight_indices, weight_indptr): - ''' - Y = X * W^T - Y[m, h, w, n] = X[m, h, w, k] * W[1, 1, k, n] - Y[n, oh, ow, oc] = X[n, ih, iw, ic] * W[1, 1, ic, oc] - NHWC - HWIO - ''' - (m, h, w, k) = get_const_tuple(data.shape) - (_, bs_r, bs_c) = get_const_tuple(weight_data.shape) - (num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape) - num_blocks = num_blocks_plus_1 - 1 - - def _compute_block(i, h, w, nb_j, j): - row_start = weight_indptr[nb_j] - row_end = weight_indptr[nb_j + 1] - row_elems = row_end - row_start - elem_idx = te.reduce_axis((0, row_elems), name="elem_idx") - block_offset = row_start + elem_idx - c = te.reduce_axis((0, bs_c), name="c") - block_j = weight_indices[block_offset] - block_ij_val = weight_data[block_offset][j][c] - x_val = data[i, h, w, bs_c * block_j + c] - return te.sum(block_ij_val * x_val, axis=[elem_idx, c]) - - idxd = tvm.tir.indexdiv - idxm = tvm.tir.indexmod - - bsrmm_block = te.compute( - (m, h, w, num_blocks, bs_r), _compute_block, - tag="sparse_conv2d_sp_bsrmm_block", - attrs={"FLOP": 2 * m * num_blocks * bs_r * k}, - ) - return te.compute( - (m, h, w, num_blocks * bs_r), - lambda m, h, w, n: bsrmm_block[m, h, w, idxd(n, bs_r), idxm(n, bs_r)], - tag="sparse_conv2d_sp_bsrmm", - ) - -@auto_scheduler.register_workload -def sparse_conv(N, OH, OW, OC, IH, IW, IC, w_data_shape, w_indices_shape, w_indptr_shape, dtype): - X = te.placeholder(shape=(N, IH, IW, IC), dtype=dtype) - W_data = te.placeholder(shape=w_data_shape, dtype=dtype) - W_indices = te.placeholder(shape=w_indices_shape, dtype="int32") - W_indptr = te.placeholder(shape=w_indptr_shape, dtype="int32") - B = te.placeholder(shape=(N, OH, OW, OC), dtype=dtype) - - out = sparse_conv2d_bsr_compute( - topi.nn.relu(X), W_data, W_indices, W_indptr - ) - out = te.compute((N, OH, OW, OC), lambda i, j, k, l: out[i, j, k, l] + B[i, j, k, l], name="BiasAdd") - out = topi.nn.relu(out) - - return [X, W_data, W_indices, W_indptr, B, out] - -###################################################################### -# Special step for sparse workload -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# During schedule tuning, auto-scheduler will use random inputs to measure the performance of a -# generated schedule. While we cannot directly use a random array as the input of a sparse op, for -# the "indices" and "indptr" array are meaningful for the computation. -# -# To solve this problem, we register these as special buffers, and load them when process program -# measuring. -# See the :any:`auto_scheduler.measure` code for more details. - -# Define the basic shapes of this sparse computation - -#Y[n, oh, ow, oc] = X[n, ih, iw, ic] * W[1, 1, ic, oc] -N = 1 -OH = OW = 16 -OC = 128 -IH = IW = 16 -IC = 128 - -BS_R = 8 -BS_C = 1 -density = 0.6 - -# Generate the test data with numpy -X_np = np.random.randn(N, IH, IW, IC).astype("float32") -X_np = np.maximum(np.zeros((N, IH, IW, IC), dtype="float32"), X_np) # Relu -W_sp_np = random_bsr_matrix(IC, OC, BS_R, BS_C, density=density, dtype="float32") -W_np = W_sp_np.todense() - - -# Register the sparse data to special buffer -prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (N, OH, OW, BS_R, BS_C, density) -auto_scheduler.measure.register_special_buffer(prefix + "W_data", W_sp_np.data) -auto_scheduler.measure.register_special_buffer(prefix + "W_indices", W_sp_np.indices) -auto_scheduler.measure.register_special_buffer(prefix + "W_indptr", W_sp_np.indptr) - -###################################################################### -# Create the search task -# ^^^^^^^^^^^^^^^^^^^^^^ -# We then create a search task with M=N=K=512 and dtype="float32" -# If your machine supports avx instructions, you can -# -# - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2 -# - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512 - -target = tvm.target.Target("llvm -mcpu=core-avx2") - -task = tvm.auto_scheduler.SearchTask( - func=sparse_conv, - args=( - N, OH, OW, OC, IH, IW, IC, - W_sp_np.data.shape, - W_sp_np.indices.shape, - W_sp_np.indptr.shape, - "float32" - ), - target=target -) - -# Inspect the computational graph -print("Computational DAG:") -print(task.compute_dag) - -###################################################################### -# Write the custom sketch for sparse dense op -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Before tuning, we will need to define the CustomSketchRule for the sparse dense op. -# -# CustomSketchRule consists of two parts: the condition function and the apply function. -# -# - condition function: describe when to use this sketch rule. For example, we can match the op -# by their name or tag. -# - apply function: describe how to generate the initial sketch. Auto-scheduler provides a set of -# loop state APIs. - -def meet_condition_func(search_policy, state, stage_id): - state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) - if state.stages[stage_id].op.tag in [ - "sparse_conv2d_sp_bsrmm", "sparse_conv2d_sp_bsrmm_block" - ]: - return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST - else: - return auto_scheduler.PreloadCustomSketchRule.PASS - -def apply_func(search_policy, state, stage_id): - ret = [] - s0 = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) - if s0.stages[stage_id].op.tag == "sparse_conv2d_sp_bsrmm_block": - return [s0.state_object, stage_id - 1] - - sparse_dense = s0.stages[stage_id].op - sparse_dense_block = s0.stages[stage_id - 1].op - assert sparse_dense.tag == "sparse_conv2d_sp_bsrmm" - assert sparse_dense_block.tag == "sparse_conv2d_sp_bsrmm_block" - - # Set the default consumer of compute block - consumer = sparse_dense - - # If sparse dense has a single elementwise consumer - # We can compute inline the sparse_dense output stage - consumers = _ffi_api.SearchPolicyUtilsGetConsumers( - search_policy.search_task, s0.state_object, stage_id - ) - if len(consumers) == 1: - consumer_id = int(consumers.items()[0][0]) - if _ffi_api.SearchPolicyUtilsIsElementwiseMatch( - search_policy.search_task, s0.state_object, stage_id, consumer_id - ): - consumer = s0.stages[consumer_id].op - s0.compute_inline(sparse_dense) - - i, h, w, nb_j, j, row_offset, c = s0[sparse_dense_block].iters - m, x, y, n = s0[consumer].iters - - i0, i1, i2 = s0.split(sparse_dense_block, i, [None, None]) - m0, m1 = s0.follow_split(consumer, m, len(s0.transform_steps) - 1, 1) - h0, h1, h2 = s0.split(sparse_dense_block, h, [None, None]) - x0, x1 = s0.follow_split(consumer, x, len(s0.transform_steps) - 1, 1) - w0, w1, w2 = s0.split(sparse_dense_block, w, [None, None]) - y0, y1 = s0.follow_split(consumer, y, len(s0.transform_steps) - 1, 1) - j0, j1 = s0.split(sparse_dense_block, nb_j, [None]) - n0, n1 = s0.follow_split(consumer, n, len(s0.transform_steps) - 1, 1) - s0.reorder(sparse_dense_block, [i0, h0, w0, j0, i1, h1, w1, j1, row_offset, i2, h2, w2, j, c]) - s0.reorder(consumer, [m0, x0, y0, n0, m1, x1, y1, n1]) - s0.compute_at(sparse_dense_block, consumer, n0) - - ret.append([s0.state_object, stage_id - 2]) - - return ret - -###################################################################### -# Next, we set parameters for the auto-scheduler. -# -# * :code:`num_measure_trials` is the number of measurement trials we can use during the search. -# We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a -# good value for the search to converge. You can do more trials according to your time budget. -# * In addition, we use :code:`RecordToFile` to dump measurement records into a file `matmul.json`. -# The measurement records can be used to query the history best, resume the search, -# and do more analyses later. -# * see :any:`auto_scheduler.TuningOptions` for more parameters -# * Here, we need to create a :code:`auto_scheduler.SketchPolicy` object, and add the custom sketch -# rule as a `init_search_callbacks`. - -log_file = "sparse_conv.json" -tune_option = auto_scheduler.TuningOptions( - num_measure_trials=10, - measure_callbacks=[auto_scheduler.RecordToFile(log_file)], - verbose=2, -) - -search_policy = auto_scheduler.SketchPolicy( - task, - program_cost_model=auto_scheduler.XGBModel(), - init_search_callbacks=[ - auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func, "SparseDense") - ] -) - -###################################################################### -# Run the search -# ^^^^^^^^^^^^^^ -# Now we get all inputs ready. -# We can kick off the search and let the auto-scheduler do its magic. -# After some measurement trials, we can load the best schedule from the log -# file and apply it. - -# Run auto-tuning (search) -task.tune(tune_option, search_policy) -# Apply the best schedule -sch, args = task.apply_best(log_file) - -###################################################################### -# We can lower the schedule to see the IR after auto-scheduling. -# The auto-scheduler correctly performs optimizations including multi-level tiling, -# layout transformation, parallelization, vectorization, unrolling, and operator fusion. - -print("Lowered TIR:") -print(tvm.lower(sch, args, simple_mode=True)) - -###################################################################### -# Check correctness and evaluate performance -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# We build the binary and check its correctness and performance. - -func = tvm.build(sch, args, target) - -ctx = tvm.cpu() - -X_tvm = tvm.nd.array(X_np, ctx=ctx) -W_data_tvm = tvm.nd.array(W_sp_np.data, ctx=ctx) -W_indices_tvm = tvm.nd.array(W_sp_np.indices, ctx=ctx) -W_indptr_tvm = tvm.nd.array(W_sp_np.indptr, ctx=ctx) -B_tvm = tvm.nd.array(B_np, ctx=ctx) -Y_tvm = tvm.nd.empty(Y_np.shape, ctx=ctx) - -func(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm) - -# Check results -tvm.testing.assert_allclose(Y_np, Y_tvm.asnumpy(), atol=1e-4, rtol=1e-4) - -# Evaluate execution time. -evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500) -print( - "Execution time of this operator: %.3f ms" - % (np.median(evaluator(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm).results) * 1000) -) - -###################################################################### -# Using the record file -# ^^^^^^^^^^^^^^^^^^^^^ -# During the search, all measurement records are dumped into the record -# file "matmul.json". The measurement records can be used to re-apply search results, -# resume the search, and perform other analyses. - -###################################################################### -# Here is an example where we load the best schedule from a file, -# and print the equivalent python schedule API. This can be used for -# debugging and learning the behavior of the auto-scheduler. - -print("Equivalent python schedule:") -print(task.print_best(log_file)) - -###################################################################### -# A more complicated example is to resume the search. -# In this case, we need to create the search policy and cost model by ourselves -# and resume the status of search policy and cost model with the log file. -# In the example below we resume the status and do more 5 trials. - - -def resume_search(task, log_file): - print("Resume search:") - cost_model = auto_scheduler.XGBModel() - cost_model.update_from_file(log_file) - search_policy = auto_scheduler.SketchPolicy( - task, cost_model, init_search_callbacks=[ - auto_scheduler.PreloadMeasuredStates(log_file), - auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func, "SparseDense") - ] - ) - tune_option = auto_scheduler.TuningOptions( - num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)] - ) - task.tune(tune_option, search_policy=search_policy) - - -resume_search(task, log_file) From f068e792faaf5bc7751bb8d1ec856b1eeec6c324 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Feb 2021 20:55:18 +0800 Subject: [PATCH 06/35] Update --- include/tvm/auto_scheduler/measure_record.h | 2 +- include/tvm/auto_scheduler/search_task.h | 9 +- python/tvm/auto_scheduler/measure.py | 40 ---- python/tvm/auto_scheduler/search_task.py | 86 +++++++++ src/auto_scheduler/feature.cc | 4 +- src/auto_scheduler/measure_record.cc | 37 ++++ src/auto_scheduler/search_task.cc | 20 +- .../test_auto_scheduler_search_task.py | 178 ++++++++++++++++++ tutorials/auto_scheduler/tune_sparse_x86.py | 18 +- 9 files changed, 340 insertions(+), 54 deletions(-) create mode 100644 tests/python/unittest/test_auto_scheduler_search_task.py diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h index ec40611d49b4..c82ed076eca7 100755 --- a/include/tvm/auto_scheduler/measure_record.h +++ b/include/tvm/auto_scheduler/measure_record.h @@ -34,7 +34,7 @@ namespace tvm { namespace auto_scheduler { -const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5"; // NOLINT(*) +const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.6"; // NOLINT(*) /*! \brief Callback for logging the input and results of measurements to file */ class RecordToFileNode : public MeasureCallbackNode { diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index 9e7d3aa2cd32..f74f7ad85930 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -27,6 +27,7 @@ #include #include +#include namespace tvm { namespace auto_scheduler { @@ -120,6 +121,8 @@ class SearchTaskNode : public Object { HardwareParams hardware_params; /*! \brief The layout rewrite option used for measuring programs. */ LayoutRewriteOption layout_rewrite_option; + /*! \brief A map ... */ + Map task_inputs; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("compute_dag", &compute_dag); @@ -128,6 +131,7 @@ class SearchTaskNode : public Object { v->Visit("target_host", &target_host); v->Visit("hardware_params", &hardware_params); v->Visit("layout_rewrite_option", &layout_rewrite_option); + v->Visit("task_inputs", &task_inputs); } static constexpr const char* _type_key = "auto_scheduler.SearchTask"; @@ -150,7 +154,10 @@ class SearchTask : public ObjectRef { * \param layout_rewrite_option The layout rewrite option used for measuring programs. */ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, - Optional hardware_params, LayoutRewriteOption layout_rewrite_option); + Optional hardware_params, LayoutRewriteOption layout_rewrite_option, + Map task_inputs); + + void AddTaskInput(String data_name, runtime::NDArray data); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); }; diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index ca362408640b..a248bd5a34bc 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -1255,43 +1255,3 @@ def rpc_runner_run( return results - -# The map stores special registered buffer for measurement -# This can be used for sparse workloads when we cannot use random tensors for measurment. -global special_buffer_table -special_buffer_table = {} - -def register_special_buffer(tensor_name, data): - """Register special buffer for measurement - This can be used for sparse workloads when we cannot use random tensors for measurment. - """ - if tensor_name in special_buffer_table.keys(): - return True - - if os.path.isfile(tensor_name): - print("Load ", tensor_name) - if tensor_name.startswith("sparse_dense_bsr"): - if tensor_name.endswith("data"): - data = np.fromfile(tensor_name, dtype="float32", sep=" ") - name_split = tensor_name.split("_") - BS_R = int(name_split[6]) - BS_C = int(name_split[7]) - data = data.reshape((data.shape[0] // BS_R // BS_C, BS_R, BS_C)) - else: - data = np.fromfile(tensor_name, dtype="int32", sep=" ") - elif data is None: - return False - - special_buffer_table[tensor_name] = data - - if not os.path.isfile(tensor_name): - data.tofile(tensor_name, " ") - - return True - -def get_special_buffer(tensor_name): - """Get special buffer for measurement. - This can be used for sparse workloads when we cannot use random tensors for measurment. - The buffers are registered by `register_special_buffer`. - """ - return special_buffer_table.get(tensor_name, None) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index d985ed1341f5..ce413eb5448d 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -18,6 +18,8 @@ """ The definiton of SearchTask """ import json +import os +import numpy as np import tvm._ffi from tvm.runtime import Object @@ -156,6 +158,77 @@ def __init__( measure_callbacks, ) +# The map stores special registered buffer for measurement +# This can be used for sparse workloads when we cannot use random tensors for measurment. +global task_input_buffer_table +# { +# "workload_key_0": { +# "task_input_0": Tensor(...), +# "task_input_1": Tensor(...) +# }, +# "workload_key_1": { +# "task_input_2": Tensor(...), +# "task_input_3": Tensor(...) +# }, +# ... +# } +task_input_buffer_table = {} + + +def register_task_input_buffer(workload_key, input_name, input_data, overwrite=False): + """Register special buffer for measurement + This can be used for sparse workloads when we cannot use random tensors for measurment. + """ + global task_input_buffer_table + + if not workload_key in task_input_buffer_table: + task_input_buffer_table[workload_key] = {} + + input_table = task_input_buffer_table[workload_key] + + if input_name in input_table.keys() and not overwrite: + return input_table[input_name] + + input_table[input_name] = input_data + + return input_data + + # print("reg ", data) + + # if os.path.isfile(tensor_name): + # print("Load ", tensor_name) + # if tensor_name.startswith("sparse_dense_bsr"): + # if tensor_name.endswith("data"): + # data = np.fromfile(tensor_name, dtype="float32", sep=" ") + # name_split = tensor_name.split("_") + # BS_R = int(name_split[6]) + # BS_C = int(name_split[7]) + # data = data.reshape((data.shape[0] // BS_R // BS_C, BS_R, BS_C)) + # else: + # data = np.fromfile(tensor_name, dtype="int32", sep=" ") + # elif data is None: + # return False + + # task_input_buffer_table[tensor_name] = data + + # if not os.path.isfile(tensor_name): + # data.asnumpy().tofile(tensor_name, " ") + + # return True + + +@tvm._ffi.register_func("auto_scheduler.search_task.get_task_input_buffer") +def get_task_input_buffer(workload_key, input_name): + """Get special buffer for measurement. + This can be used for sparse workloads when we cannot use random tensors for measurment. + The buffers are registered by `register_task_input_buffer`. + """ + global task_input_buffer_table + + input_table = task_input_buffer_table.get(workload_key, None) + + return input_table.get(input_name, None) + @tvm._ffi.register_object("auto_scheduler.SearchTask") class SearchTask(Object): @@ -239,6 +312,7 @@ def __init__( target_host, hardware_params, layout_rewrite_option, + {} ) def tune(self, tuning_options, search_policy=None): @@ -314,6 +388,12 @@ def print_best(self, log_file, print_mode="schedule"): return func.imported_modules[0].get_source() raise ValueError("Invalid print_mode: %s" % print_mode) + def add_task_input(self, input_name, input_data): + """ + """ + register_task_input_buffer(self.workload_key, input_name, input_data) + _ffi_api.SearchTaskAddTaskInput(self, input_name, input_data) + def __getstate__(self): return { "compute_dag": self.compute_dag, @@ -322,6 +402,7 @@ def __getstate__(self): "target_host": self.target_host, "hardware_params": self.hardware_params, "layout_rewrite_option": self.layout_rewrite_option, + "task_inputs": [i[0] for i in self.measure_inputs.items()] } def __setstate__(self, state): @@ -337,6 +418,10 @@ def __setstate__(self, state): if len(workload) == 1: register_workload_tensors(workload[0], state["compute_dag"].tensors) + task_inputs = {} + for data_name in state["task_inputs"]: + task_inputs[data_name] = get_task_input_buffer(state["workload_key"], data_name) + self.__init_handle_by_constructor__( _ffi_api.SearchTask, state["compute_dag"], @@ -345,6 +430,7 @@ def __setstate__(self, state): state["target_host"], state["hardware_params"], state["layout_rewrite_option"], + task_inputs, ) diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index cf516d8452e2..08ee5a8c4238 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1399,7 +1399,7 @@ void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int Array tensors = (*workload_key_to_tensors)(workload_key); task = SearchTask(ComputeDAG(tensors), workload_key, cur_inp->task->target, cur_inp->task->target_host, cur_inp->task->hardware_params, - cur_inp->task->layout_rewrite_option); + cur_inp->task->layout_rewrite_option, cur_inp->task->task_inputs); task_id = task_cache.size(); // compute min cost for each task @@ -1468,7 +1468,7 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array& inputs, Array tensors = (*workload_key_to_tensors)(workload_key); task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target, inputs[i]->task->target_host, inputs[i]->task->hardware_params, - inputs[i]->task->layout_rewrite_option); + inputs[i]->task->layout_rewrite_option, {}); } catch (std::exception& e) { // Cannot build ComputeDAG from workload key, the task may have not been registered in // this search round diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 1120f437b176..ba5755885594 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -169,6 +169,12 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->WriteArrayItem(std::string("")); } writer->WriteArrayItem(static_cast(data.layout_rewrite_option)); + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (const auto& i : data.task_inputs) { + writer->WriteArrayItem(std::string(i.first)); + } + writer->EndArray(); writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) { @@ -200,6 +206,20 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { reader->Read(&int_value); data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value); s = reader->NextArrayItem(); + if (s) { + std::vector input_data_names; + const auto& func = + tvm::runtime::Registry::Get("auto_scheduler.search_task.get_task_input_buffer"); + reader->BeginArray(); + s = reader->NextArrayItem(); + while (s) { + reader->Read(&str_value); + data->task_inputs.Set(str_value, (*func)(data->workload_key, str_value)); + s = reader->NextArrayItem(); + } + // Process the end of array + s = reader->NextArrayItem(); + } ICHECK(!s); } } @@ -444,5 +464,22 @@ TVM_REGISTER_GLOBAL("auto_scheduler.DeserializeMeasureInput").set_body_typed([]( reader.Read(inp.get()); return ObjectRef(inp); }); + +TVM_REGISTER_GLOBAL("auto_scheduler.SerializeSearchTask") + .set_body_typed([](const SearchTask& search_task) { + std::ostringstream os; + dmlc::JSONWriter writer(&os); + writer.Write(*search_task.get()); + return os.str(); + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.DeserializeSearchTask").set_body_typed([](String json) { + std::istringstream ss(json); + dmlc::JSONReader reader(&ss); + auto search_task = make_object(); + reader.Read(search_task.get()); + return ObjectRef(search_task); +}); + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 0abee16fceab..0c6a5edbdb99 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -114,7 +114,8 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, - LayoutRewriteOption layout_rewrite_option) { + LayoutRewriteOption layout_rewrite_option, + Map task_inputs) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -127,9 +128,17 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host); } node->layout_rewrite_option = layout_rewrite_option; + node->task_inputs = std::move(task_inputs); data_ = std::move(node); } +void SearchTask::AddTaskInput(String data_name, runtime::NDArray data) { + if (operator->()->task_inputs.count(data_name)) { + LOG(WARNING) << data_name << " already in memory"; + } + static_cast(get_mutable())->task_inputs.Set(data_name, data); +} + TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams") .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes, int max_shared_memory_per_block, int max_local_memory_per_block, @@ -142,9 +151,14 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams") TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask") .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, - int layout_rewrite_option) { + int layout_rewrite_option, Map task_inputs) { return SearchTask(compute_dag, workload_key, target, target_host, hardware_params, - LayoutRewriteOption(layout_rewrite_option)); + LayoutRewriteOption(layout_rewrite_option), task_inputs); + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.SearchTaskAddTaskInput") + .set_body_typed([](SearchTask search_task, String input_name, runtime::NDArray input_data) { + search_task.AddTaskInput(input_name, input_data); }); } // namespace auto_scheduler diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py new file mode 100644 index 000000000000..db07cfa59e31 --- /dev/null +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test search policy""" + +import random +import multiprocessing +import numpy as np +import tempfile + +import tvm +import tvm.testing +from tvm import auto_scheduler +from tvm.auto_scheduler.utils import get_const_tuple + +from test_auto_scheduler_common import ( + matmul_auto_scheduler_test, + zero_rank_compute_auto_scheduler_test, + zero_rank_reduce_auto_scheduler_test, +) +import multiprocessing + +def test_search_task_add_task_input(): + auto_scheduler.search_task.task_input_buffer_table.clear() + N = 64 + target = "llvm" + task = auto_scheduler.SearchTask( + func="matmul_auto_scheduler_test", args=(N, N, N), target=target + ) + + test_input_0 = tvm.runtime.ndarray.empty((64, 64)) + test_input_1 = tvm.runtime.ndarray.empty((10, 20)) + test_input_2 = tvm.runtime.ndarray.empty((30, 40, 50)) + task.add_task_input("test_input_0", test_input_0) + task.add_task_input("test_input_1", test_input_1) + task.add_task_input("test_input_2", test_input_2) + + assert len(task.task_inputs) == 3 + assert task.task_inputs["test_input_0"] == test_input_0 + assert task.task_inputs["test_input_1"] == test_input_1 + assert task.task_inputs["test_input_2"] == test_input_2 + + +def test_search_task_record(): + auto_scheduler.search_task.task_input_buffer_table.clear() + N = 64 + target = "llvm" + task = auto_scheduler.SearchTask( + func="matmul_auto_scheduler_test", args=(N, N, N), target=target + ) + task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) + new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) + + # Log with no task input + # TODO(jcf94): Check the compute dag & hardware parameter + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + + # Log with 1 task input + test_input_0 = tvm.runtime.ndarray.empty((64, 64)) + task.add_task_input("test_input_0", test_input_0) + task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) + new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_inputs) == 1 + assert new_task.task_inputs.items()[0][0] == "test_input_0" + assert new_task.task_inputs.items()[0][1] == test_input_0 + + # Log with multiple task inputs + test_input_1 = tvm.runtime.ndarray.empty((64, 64)) + task.add_task_input("test_input_1", test_input_1) + task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) + new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_inputs) == 2 + assert new_task.task_inputs.items()[0][0] == "test_input_0" + assert new_task.task_inputs.items()[0][1] == test_input_0 + assert new_task.task_inputs.items()[1][0] == "test_input_1" + assert new_task.task_inputs.items()[1][1] == test_input_1 + + # Log with version 0.5 + v5_log = """["[\\\"matmul_auto_scheduler_test\\\", 64, 64, 64]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1]""" + new_task = auto_scheduler._ffi_api.DeserializeSearchTask(v5_log) + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_inputs) == 0 + +def test_recover_measure_input_with_task_input(): + auto_scheduler.search_task.task_input_buffer_table.clear() + task = auto_scheduler.SearchTask( + func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm" + ) + + # Since this file is tests for search_task, we only check the search_task here + + # Log with no task input + inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) + res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) + measure_log = auto_scheduler.measure_record.load_record_from_string(measure_record) + new_task = measure_log[0].task + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + + # Log with 1 task input + test_input_0 = tvm.runtime.ndarray.empty((64, 64)) + task.add_task_input("test_input_0", test_input_0) + inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) + res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) + measure_log = auto_scheduler.measure_record.load_record_from_string(measure_record) + new_task = measure_log[0].task + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_inputs) == 1 + assert new_task.task_inputs.items()[0][0] == "test_input_0" + assert new_task.task_inputs.items()[0][1] == test_input_0 + + # Log with multiple task inputs + test_input_1 = tvm.runtime.ndarray.empty((64, 64)) + task.add_task_input("test_input_1", test_input_1) + inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) + res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) + measure_log = auto_scheduler.measure_record.load_record_from_string(measure_record) + new_task = measure_log[0].task + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_inputs) == 2 + assert new_task.task_inputs.items()[0][0] == "test_input_0" + assert new_task.task_inputs.items()[0][1] == test_input_0 + assert new_task.task_inputs.items()[1][0] == "test_input_1" + assert new_task.task_inputs.items()[1][1] == test_input_1 + + # Log with version 0.5 + v5_log = """{"i": [["[\\\"matmul_auto_scheduler_test\\\", 512, 512, 512]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1], [[], []]], "r": [[0.1], 0, 0.2, 1], "v": "v0.6"}""" + measure_log = auto_scheduler.measure_record.load_record_from_string(v5_log) + new_task = measure_log[0].task + assert task.workload_key == new_task.workload_key + assert str(task.target) == str(new_task.target) + assert str(task.target_host) == str(new_task.target_host) + assert task.layout_rewrite_option == new_task.layout_rewrite_option + assert len(new_task.task_inputs) == 0 + +if __name__ == "__main__": + test_search_task_add_task_input() + test_search_task_record() + test_recover_measure_input_with_task_input() diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index ea0422dd01b7..3f4ca0798a83 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -40,7 +40,7 @@ import numpy as np import tvm -from tvm import te, auto_scheduler, topi +from tvm import te, auto_scheduler, runtime, topi from tvm.auto_scheduler import _ffi_api from tvm.topi.utils import get_const_tuple @@ -119,12 +119,6 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): Y_np = Y_np + B_np # Bias add Y_np = np.maximum(np.zeros((M, N), dtype="float32"), Y_np) # Relu -# Register the sparse data to special buffer -prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density) -auto_scheduler.measure.register_special_buffer(prefix + "W_data", W_sp_np.data) -auto_scheduler.measure.register_special_buffer(prefix + "W_indices", W_sp_np.indices) -auto_scheduler.measure.register_special_buffer(prefix + "W_indptr", W_sp_np.indptr) - ###################################################################### # Create the search task # ^^^^^^^^^^^^^^^^^^^^^^ @@ -148,10 +142,20 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): target=target ) +# Register the sparse data to special buffer +prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density) +task.add_task_input(prefix + "W_data", runtime.ndarray.array(W_sp_np.data)) +task.add_task_input(prefix + "W_indices", runtime.ndarray.array(W_sp_np.indices)) +task.add_task_input(prefix + "W_indptr", runtime.ndarray.array(W_sp_np.indptr)) + # Inspect the computational graph print("Computational DAG:") print(task.compute_dag) +print(task.task_inputs) + +exit(0) + ###################################################################### # Write the custom sketch for sparse dense op # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From 1f1a7b50269b1fa17f59590d0f37b0bf6205b400 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Feb 2021 21:10:54 +0800 Subject: [PATCH 07/35] Add search_inputs to measure --- python/tvm/auto_scheduler/measure.py | 8 +++++--- python/tvm/auto_scheduler/search_task.py | 7 ++++++- tutorials/auto_scheduler/tune_sparse_x86.py | 4 +--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index a248bd5a34bc..633f11808535 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -226,6 +226,7 @@ def recover_measure_input(inp, rebuild_state=False): target_host=task.target_host, hardware_params=task.hardware_params, layout_rewrite_option=task.layout_rewrite_option, + task_inputs=task.task_inputs, ) if rebuild_state: @@ -814,6 +815,7 @@ def _timed_eval_func( verbose, ): inp = MeasureInput.deserialize(inp_serialized) + task_inputs = inp.task.task_inputs tic = time.time() error_no = 0 error_msg = None @@ -852,11 +854,11 @@ def _timed_eval_func( args = [] for arg in build_res.args: if arg == sparse_data: - args.append(ndarray.array(get_special_buffer(sparse_prefix+"W_data"), ctx)) + args.append(task_inputs[sparse_prefix+"W_data"]) elif arg == sparse_indices: - args.append(ndarray.array(get_special_buffer(sparse_prefix+"W_indices"), ctx)) + args.append(task_inputs[sparse_prefix+"W_indices"]) elif arg == sparse_indptr: - args.append(ndarray.array(get_special_buffer(sparse_prefix+"W_indptr"), ctx)) + args.append(task_inputs[sparse_prefix+"W_indptr"]) else: empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) random_fill(empty_array) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index ce413eb5448d..2bf310fd4c3f 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -285,6 +285,7 @@ def __init__( target_host=None, hardware_params=None, layout_rewrite_option=None, + task_inputs={}, ): assert ( func is not None or workload_key is not None @@ -312,7 +313,7 @@ def __init__( target_host, hardware_params, layout_rewrite_option, - {} + task_inputs ) def tune(self, tuning_options, search_policy=None): @@ -395,6 +396,8 @@ def add_task_input(self, input_name, input_data): _ffi_api.SearchTaskAddTaskInput(self, input_name, input_data) def __getstate__(self): + print("get state") + return { "compute_dag": self.compute_dag, "workload_key": self.workload_key, @@ -422,6 +425,8 @@ def __setstate__(self, state): for data_name in state["task_inputs"]: task_inputs[data_name] = get_task_input_buffer(state["workload_key"], data_name) + print("set state") + self.__init_handle_by_constructor__( _ffi_api.SearchTask, state["compute_dag"], diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index 3f4ca0798a83..3111f3c43e64 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -154,8 +154,6 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): print(task.task_inputs) -exit(0) - ###################################################################### # Write the custom sketch for sparse dense op # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -233,7 +231,7 @@ def apply_func(search_policy, state, stage_id): log_file = "sparse_dense.json" tune_option = auto_scheduler.TuningOptions( - num_measure_trials=10, + num_measure_trials=2, measure_callbacks=[auto_scheduler.RecordToFile(log_file)], verbose=2, ) From e83dfe44f7ddc236f82dc37b7f7429498395bb2c Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Feb 2021 22:04:53 +0800 Subject: [PATCH 08/35] Lint fix --- include/tvm/auto_scheduler/search_task.h | 2 +- python/tvm/auto_scheduler/measure.py | 7 ++++--- python/tvm/auto_scheduler/search_task.py | 11 +++++------ tutorials/auto_scheduler/tune_sparse_x86.py | 2 -- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index f74f7ad85930..40fa14866a5d 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -26,8 +26,8 @@ #define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_ #include -#include #include +#include namespace tvm { namespace auto_scheduler { diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 633f11808535..c3689c083e9b 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -1018,6 +1018,7 @@ def _timed_rpc_run( verbose, ): inp = MeasureInput.deserialize(inp_serialized) + task_inputs = inp.task.task_inputs tic = time.time() error_no = 0 error_msg = None @@ -1063,11 +1064,11 @@ def _timed_rpc_run( args = [] for arg in build_res.args: if arg == sparse_data: - args.append(ndarray.array(get_special_buffer(sparse_prefix+"W_data"), ctx)) + args.append(task_inputs[sparse_prefix+"W_data"]) elif arg == sparse_indices: - args.append(ndarray.array(get_special_buffer(sparse_prefix+"W_indices"), ctx)) + args.append(task_inputs[sparse_prefix+"W_indices"]) elif arg == sparse_indptr: - args.append(ndarray.array(get_special_buffer(sparse_prefix+"W_indptr"), ctx)) + args.append(task_inputs[sparse_prefix+"W_indptr"]) else: empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) random_fill(empty_array) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 2bf310fd4c3f..e633677d3a3b 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -258,6 +258,10 @@ class SearchTask(Object): The NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone op, and the REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a network. + task_inputs : Dict[str, Tensor] + Some special Tensor used as inputs in program measuring. Usually we do not need to care + about it, but for special workloads like Sparse computation the Sparse Tensor input are + meaningful that we cannot use random input directly. Examples -------- @@ -313,7 +317,7 @@ def __init__( target_host, hardware_params, layout_rewrite_option, - task_inputs + task_inputs, ) def tune(self, tuning_options, search_policy=None): @@ -396,8 +400,6 @@ def add_task_input(self, input_name, input_data): _ffi_api.SearchTaskAddTaskInput(self, input_name, input_data) def __getstate__(self): - print("get state") - return { "compute_dag": self.compute_dag, "workload_key": self.workload_key, @@ -424,9 +426,6 @@ def __setstate__(self, state): task_inputs = {} for data_name in state["task_inputs"]: task_inputs[data_name] = get_task_input_buffer(state["workload_key"], data_name) - - print("set state") - self.__init_handle_by_constructor__( _ffi_api.SearchTask, state["compute_dag"], diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index 3111f3c43e64..ac084ad1d385 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -152,8 +152,6 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): print("Computational DAG:") print(task.compute_dag) -print(task.task_inputs) - ###################################################################### # Write the custom sketch for sparse dense op # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From 7ca2bd4c0a13deb7a2444026c70dd5f3da5eabb2 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Feb 2021 22:12:11 +0800 Subject: [PATCH 09/35] Lint fix --- python/tvm/auto_scheduler/measure.py | 35 ++++++++++--------- python/tvm/auto_scheduler/search_task.py | 6 ++-- python/tvm/topi/nn/sparse.py | 3 +- .../test_auto_scheduler_search_task.py | 3 ++ tutorials/auto_scheduler/tune_sparse_x86.py | 29 +++++++-------- 5 files changed, 42 insertions(+), 34 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index c3689c083e9b..1c629ef298b6 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -768,13 +768,12 @@ def _process_inputs(input_tensors, M, N, prefix_init): density = 1.0 for i in sparse_data.shape: density *= i - density /= (K * N) + density /= K * N density = density.value - sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_" % ( - prefix_init, M, N, K, BS_R, BS_C, density - ) + sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_" % (prefix_init, M, N, K, BS_R, BS_C, density) visited = set() + def _traverse(t): # We cannot directly add tensors to the set, because the comparison of # two tensors with ndim=0 is ambiguous. @@ -804,6 +803,7 @@ def _traverse(t): return sparse_prefix, sparse_data, sparse_indices, sparse_indptr + def _timed_eval_func( inp_serialized, build_res, @@ -848,17 +848,18 @@ def _timed_eval_func( assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" # Check sparse op - sparse_prefix, sparse_data, sparse_indices, sparse_indptr = \ - _process_sparse_input(build_res.args) + sparse_prefix, sparse_data, sparse_indices, sparse_indptr = _process_sparse_input( + build_res.args + ) if sparse_prefix: args = [] for arg in build_res.args: if arg == sparse_data: - args.append(task_inputs[sparse_prefix+"W_data"]) + args.append(task_inputs[sparse_prefix + "W_data"]) elif arg == sparse_indices: - args.append(task_inputs[sparse_prefix+"W_indices"]) + args.append(task_inputs[sparse_prefix + "W_indices"]) elif arg == sparse_indptr: - args.append(task_inputs[sparse_prefix+"W_indptr"]) + args.append(task_inputs[sparse_prefix + "W_indptr"]) else: empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) random_fill(empty_array) @@ -1058,23 +1059,26 @@ def _timed_rpc_run( ) # Check sparse op - sparse_prefix, sparse_data, sparse_indices, sparse_indptr = \ - _process_sparse_input(build_res.args) + sparse_prefix, sparse_data, sparse_indices, sparse_indptr = _process_sparse_input( + build_res.args + ) if sparse_prefix: args = [] for arg in build_res.args: if arg == sparse_data: - args.append(task_inputs[sparse_prefix+"W_data"]) + args.append(task_inputs[sparse_prefix + "W_data"]) elif arg == sparse_indices: - args.append(task_inputs[sparse_prefix+"W_indices"]) + args.append(task_inputs[sparse_prefix + "W_indices"]) elif arg == sparse_indptr: - args.append(task_inputs[sparse_prefix+"W_indptr"]) + args.append(task_inputs[sparse_prefix + "W_indptr"]) else: empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) random_fill(empty_array) args.append(empty_array) else: - args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] + args = [ + ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args + ] for arg in args: random_fill(arg) ctx.sync() @@ -1257,4 +1261,3 @@ def rpc_runner_run( print("") return results - diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index e633677d3a3b..b6b009298be3 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -158,6 +158,7 @@ def __init__( measure_callbacks, ) + # The map stores special registered buffer for measurement # This can be used for sparse workloads when we cannot use random tensors for measurment. global task_input_buffer_table @@ -394,8 +395,7 @@ def print_best(self, log_file, print_mode="schedule"): raise ValueError("Invalid print_mode: %s" % print_mode) def add_task_input(self, input_name, input_data): - """ - """ + """""" register_task_input_buffer(self.workload_key, input_name, input_data) _ffi_api.SearchTaskAddTaskInput(self, input_name, input_data) @@ -407,7 +407,7 @@ def __getstate__(self): "target_host": self.target_host, "hardware_params": self.hardware_params, "layout_rewrite_option": self.layout_rewrite_option, - "task_inputs": [i[0] for i in self.measure_inputs.items()] + "task_inputs": [i[0] for i in self.measure_inputs.items()], } def __setstate__(self, state): diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index d790d087d251..eab26c5c0c69 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -218,7 +218,8 @@ def _compute_block(i, nb_j, j): idxm = tvm.tir.indexmod bsrmm_block = te.compute( - (m, num_blocks, bs_r), _compute_block, + (m, num_blocks, bs_r), + _compute_block, tag="sparse_dense_sp_rhs_bsrmm_block", attrs={"FLOP": 2 * m * num_blocks * bs_r * k}, ) diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py index db07cfa59e31..877a1ac1c430 100644 --- a/tests/python/unittest/test_auto_scheduler_search_task.py +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -34,6 +34,7 @@ ) import multiprocessing + def test_search_task_add_task_input(): auto_scheduler.search_task.task_input_buffer_table.clear() N = 64 @@ -109,6 +110,7 @@ def test_search_task_record(): assert task.layout_rewrite_option == new_task.layout_rewrite_option assert len(new_task.task_inputs) == 0 + def test_recover_measure_input_with_task_input(): auto_scheduler.search_task.task_input_buffer_table.clear() task = auto_scheduler.SearchTask( @@ -172,6 +174,7 @@ def test_recover_measure_input_with_task_input(): assert task.layout_rewrite_option == new_task.layout_rewrite_option assert len(new_task.task_inputs) == 0 + if __name__ == "__main__": test_search_task_add_task_input() test_search_task_record() diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index ac084ad1d385..2b943ab2aabd 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -76,6 +76,7 @@ def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): assert s.indptr.shape == (M // BS_R + 1,) return s + @auto_scheduler.register_workload def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): X = te.placeholder(shape=(M, K), dtype=dtype) @@ -84,14 +85,13 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): W_indptr = te.placeholder(shape=w_indptr_shape, dtype="int32") B = te.placeholder(shape=(M, N), dtype=dtype) - out = topi.nn.sparse_dense( - topi.nn.relu(X), W_data, W_indices, W_indptr - ) + out = topi.nn.sparse_dense(topi.nn.relu(X), W_data, W_indices, W_indptr) out = te.compute((M, N), lambda i, j: out[i, j] + B[i, j], name="BiasAdd") out = topi.nn.relu(out) return [X, W_data, W_indices, W_indptr, B, out] + ###################################################################### # Special step for sparse workload # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -132,14 +132,8 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): task = tvm.auto_scheduler.SearchTask( func=sparse_dense, - args=( - M, N, K, - W_sp_np.data.shape, - W_sp_np.indices.shape, - W_sp_np.indptr.shape, - "float32" - ), - target=target + args=(M, N, K, W_sp_np.data.shape, W_sp_np.indices.shape, W_sp_np.indptr.shape, "float32"), + target=target, ) # Register the sparse data to special buffer @@ -164,15 +158,18 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): # - apply function: describe how to generate the initial sketch. You can implement it using # auto-scheduler provided loop state APIs. + def meet_condition_func(search_policy, state, stage_id): state = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) if state.stages[stage_id].op.tag in [ - "sparse_dense_sp_rhs_bsrmm", "sparse_dense_sp_rhs_bsrmm_block" + "sparse_dense_sp_rhs_bsrmm", + "sparse_dense_sp_rhs_bsrmm_block", ]: return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST else: return auto_scheduler.PreloadCustomSketchRule.PASS + def apply_func(search_policy, state, stage_id): ret = [] s0 = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag) @@ -214,6 +211,7 @@ def apply_func(search_policy, state, stage_id): return ret + ###################################################################### # Next, we set parameters for the auto-scheduler with the custom sketch plugged in. # @@ -239,7 +237,7 @@ def apply_func(search_policy, state, stage_id): program_cost_model=auto_scheduler.XGBModel(), init_search_callbacks=[ auto_scheduler.PreloadCustomSketchRule(meet_condition_func, apply_func, "SparseDense") - ] + ], ) ###################################################################### @@ -288,5 +286,8 @@ def apply_func(search_policy, state, stage_id): evaluator = func.time_evaluator(func.entry_name, ctx, min_repeat_ms=500) print( "Execution time of this operator: %.3f ms" - % (np.median(evaluator(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm).results) * 1000) + % ( + np.median(evaluator(X_tvm, W_data_tvm, W_indices_tvm, W_indptr_tvm, B_tvm, Y_tvm).results) + * 1000 + ) ) From 486567a16b65afe16686972921f218a161f1d965 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 4 Feb 2021 15:02:41 +0800 Subject: [PATCH 10/35] Update --- include/tvm/auto_scheduler/search_task.h | 9 +- python/tvm/auto_scheduler/measure.py | 182 +++++------------- python/tvm/auto_scheduler/search_task.py | 57 ++---- python/tvm/topi/nn/sparse.py | 91 +++++++++ .../test_auto_scheduler_search_task.py | 6 +- 5 files changed, 171 insertions(+), 174 deletions(-) diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index 40fa14866a5d..5bb58ca961ac 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -121,7 +121,7 @@ class SearchTaskNode : public Object { HardwareParams hardware_params; /*! \brief The layout rewrite option used for measuring programs. */ LayoutRewriteOption layout_rewrite_option; - /*! \brief A map ... */ + /*! \brief A map that stores some user defined input data used in program measuring. */ Map task_inputs; void VisitAttrs(tvm::AttrVisitor* v) { @@ -157,7 +157,12 @@ class SearchTask : public ObjectRef { Optional hardware_params, LayoutRewriteOption layout_rewrite_option, Map task_inputs); - void AddTaskInput(String data_name, runtime::NDArray data); + /*! + * \brief Add a input Tensor to this task. + * \param input_name + * \param input_data + */ + void AddTaskInput(String input_name, runtime::NDArray input_data); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); }; diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 1c629ef298b6..7dcd25c4f4e3 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -45,7 +45,6 @@ from tvm.ir import transform from tvm.autotvm.measure.measure_methods import set_cuda_target_arch from tvm.contrib import tar, ndk -from tvm.te import PlaceholderOp, ComputeOp from . import _ffi_api from .loop_state import StateObject @@ -723,85 +722,22 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo return results -def _process_sparse_input(args): - sparse_prefix = sparse_data = sparse_indices = sparse_indptr = None - - def _process_inputs(input_tensors, M, N, prefix_init): - nonlocal sparse_prefix - nonlocal sparse_data - nonlocal sparse_indices - nonlocal sparse_indptr - - assert len(input_tensors) == 4 - unsure_tensors = list(input_tensors) - # Get the Dense data - dense_data = None - for tensor in unsure_tensors: - if len(tensor.shape) == 2: - assert dense_data is None - dense_data = tensor - assert M == dense_data.shape[0] - K = dense_data.shape[1] - unsure_tensors.remove(dense_data) - - # Get the Sparse data - sparse_data = None - for tensor in unsure_tensors: - if len(tensor.shape) == 3: - assert sparse_data is None - sparse_data = tensor - block_size, BS_R, BS_C = sparse_data.shape - unsure_tensors.remove(sparse_data) - - # Get the Sparse indptr & indices - sparse_indices = None - for tensor in unsure_tensors: - assert len(tensor.shape) == 1 - if tensor.shape[0] == block_size: - assert sparse_indices is None - sparse_indices = tensor - unsure_tensors.remove(sparse_indices) - assert len(unsure_tensors) == 1 - sparse_indptr = unsure_tensors[0] - - # Generate the sparse_prefix - density = 1.0 - for i in sparse_data.shape: - density *= i - density /= K * N - density = density.value - sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_" % (prefix_init, M, N, K, BS_R, BS_C, density) - - visited = set() - - def _traverse(t): - # We cannot directly add tensors to the set, because the comparison of - # two tensors with ndim=0 is ambiguous. - assert t.handle is not None - if t.handle.value in visited: - return - if isinstance(t.op, ComputeOp): - # TODO(jcf94): Currently only support to tune one sparse op - if t.op.tag == "sparse_dense_sp_rhs_bsrmm": - M, N = t.shape - assert len(t.op.input_tensors) == 1 - block_tensor = t.op.input_tensors[0] - _process_inputs(block_tensor.op.input_tensors, M, N, "sparse_dense_bsr") - if t.op.tag == "sparse_conv2d_bsrmm": - N, OH = t.shape[0], t.shape[1] - assert len(t.op.input_tensors) == 1 - block_tensor = t.op.input_tensors[0] - _process_inputs(block_tensor.op.input_tensors, N, OH, "sparse_dense_bsr") - if sparse_prefix is not None: - return - for x in t.op.input_tensors: - _traverse(x) - visited.add(t.handle.value) - - for arg in args: - _traverse(arg) - - return sparse_prefix, sparse_data, sparse_indices, sparse_indptr +def _prepare_input_map(args): + """This function deals with special task inputs.""" + # Lazy load topi + from tvm import topi + + # A dict that maps the input tensor arg to a buffer name + tensor_input_map = {} + + # Case 0: Check sparse op + sparse_input_map = topi.nn.sparse.try_get_sparse_input(args) + tensor_input_map.update(sparse_input_map) + + # Case 1: Check ... + # Process any other special buffers here and update them to tensor_input_map + + return tensor_input_map def _timed_eval_func( @@ -847,29 +783,22 @@ def _timed_eval_func( random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True) assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" - # Check sparse op - sparse_prefix, sparse_data, sparse_indices, sparse_indptr = _process_sparse_input( - build_res.args - ) - if sparse_prefix: - args = [] - for arg in build_res.args: - if arg == sparse_data: - args.append(task_inputs[sparse_prefix + "W_data"]) - elif arg == sparse_indices: - args.append(task_inputs[sparse_prefix + "W_indices"]) - elif arg == sparse_indptr: - args.append(task_inputs[sparse_prefix + "W_indptr"]) + tensor_input_map = _prepare_input_map(build_res.args) if task_inputs else {} + args = [] + for arg in build_res.args: + if arg in tensor_input_map: + tensor_name = tensor_input_map[arg] + if tensor_name in task_inputs: + args.append(task_inputs[tensor_name]) else: - empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) - random_fill(empty_array) - args.append(empty_array) - else: - args = [ - ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args - ] - for arg in args: - random_fill(arg) + raise ValueError( + "%s not found in task_inputs, " % (tensor_name) + + "should provide with SearchTask.AddTaskInput()" + ) + else: + empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) + random_fill(empty_array) + args.append(empty_array) ctx.sync() costs = time_f(*args).results # pylint: disable=broad-except @@ -1051,36 +980,27 @@ def _timed_rpc_run( if error_no == 0: try: - try: - random_fill = remote.get_function("tvm.contrib.random.random_fill") - except AttributeError: - raise AttributeError( - "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices" - ) - - # Check sparse op - sparse_prefix, sparse_data, sparse_indices, sparse_indptr = _process_sparse_input( - build_res.args - ) - if sparse_prefix: - args = [] - for arg in build_res.args: - if arg == sparse_data: - args.append(task_inputs[sparse_prefix + "W_data"]) - elif arg == sparse_indices: - args.append(task_inputs[sparse_prefix + "W_indices"]) - elif arg == sparse_indptr: - args.append(task_inputs[sparse_prefix + "W_indptr"]) + random_fill = remote.get_function("tvm.contrib.random.random_fill") + assert ( + random_fill + ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices" + + tensor_input_map = _prepare_input_map(build_res.args) if task_inputs else {} + args = [] + for arg in build_res.args: + if arg in tensor_input_map: + tensor_name = tensor_input_map[arg] + if tensor_name in task_inputs: + args.append(task_inputs[tensor_name]) else: - empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) - random_fill(empty_array) - args.append(empty_array) - else: - args = [ - ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args - ] - for arg in args: - random_fill(arg) + raise ValueError( + "%s not found in task_inputs, " % (tensor_name) + + "should provide with SearchTask.AddTaskInput()" + ) + else: + empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) + random_fill(empty_array) + args.append(empty_array) ctx.sync() costs = time_f(*args).results diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index b6b009298be3..d4329ae243ea 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -18,8 +18,6 @@ """ The definiton of SearchTask """ import json -import os -import numpy as np import tvm._ffi from tvm.runtime import Object @@ -161,7 +159,6 @@ def __init__( # The map stores special registered buffer for measurement # This can be used for sparse workloads when we cannot use random tensors for measurment. -global task_input_buffer_table # { # "workload_key_0": { # "task_input_0": Tensor(...), @@ -173,19 +170,18 @@ def __init__( # }, # ... # } -task_input_buffer_table = {} +TASK_INPUT_BUFFER_TABLE = {} def register_task_input_buffer(workload_key, input_name, input_data, overwrite=False): """Register special buffer for measurement This can be used for sparse workloads when we cannot use random tensors for measurment. """ - global task_input_buffer_table + global TASK_INPUT_BUFFER_TABLE - if not workload_key in task_input_buffer_table: - task_input_buffer_table[workload_key] = {} - - input_table = task_input_buffer_table[workload_key] + if not workload_key in TASK_INPUT_BUFFER_TABLE: + TASK_INPUT_BUFFER_TABLE[workload_key] = {} + input_table = TASK_INPUT_BUFFER_TABLE[workload_key] if input_name in input_table.keys() and not overwrite: return input_table[input_name] @@ -194,29 +190,6 @@ def register_task_input_buffer(workload_key, input_name, input_data, overwrite=F return input_data - # print("reg ", data) - - # if os.path.isfile(tensor_name): - # print("Load ", tensor_name) - # if tensor_name.startswith("sparse_dense_bsr"): - # if tensor_name.endswith("data"): - # data = np.fromfile(tensor_name, dtype="float32", sep=" ") - # name_split = tensor_name.split("_") - # BS_R = int(name_split[6]) - # BS_C = int(name_split[7]) - # data = data.reshape((data.shape[0] // BS_R // BS_C, BS_R, BS_C)) - # else: - # data = np.fromfile(tensor_name, dtype="int32", sep=" ") - # elif data is None: - # return False - - # task_input_buffer_table[tensor_name] = data - - # if not os.path.isfile(tensor_name): - # data.asnumpy().tofile(tensor_name, " ") - - # return True - @tvm._ffi.register_func("auto_scheduler.search_task.get_task_input_buffer") def get_task_input_buffer(workload_key, input_name): @@ -224,9 +197,9 @@ def get_task_input_buffer(workload_key, input_name): This can be used for sparse workloads when we cannot use random tensors for measurment. The buffers are registered by `register_task_input_buffer`. """ - global task_input_buffer_table + global TASK_INPUT_BUFFER_TABLE - input_table = task_input_buffer_table.get(workload_key, None) + input_table = TASK_INPUT_BUFFER_TABLE.get(workload_key, None) return input_table.get(input_name, None) @@ -259,7 +232,7 @@ class SearchTask(Object): The NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone op, and the REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a network. - task_inputs : Dict[str, Tensor] + task_inputs : Optional[Dict[str, Tensor]] Some special Tensor used as inputs in program measuring. Usually we do not need to care about it, but for special workloads like Sparse computation the Sparse Tensor input are meaningful that we cannot use random input directly. @@ -290,7 +263,7 @@ def __init__( target_host=None, hardware_params=None, layout_rewrite_option=None, - task_inputs={}, + task_inputs=None, ): assert ( func is not None or workload_key is not None @@ -318,7 +291,7 @@ def __init__( target_host, hardware_params, layout_rewrite_option, - task_inputs, + task_inputs or {}, ) def tune(self, tuning_options, search_policy=None): @@ -395,7 +368,15 @@ def print_best(self, log_file, print_mode="schedule"): raise ValueError("Invalid print_mode: %s" % print_mode) def add_task_input(self, input_name, input_data): - """""" + """Add a input Tensor to this task. + + Parameters + ---------- + input_name : str + ... + input_data : Tensor + ... + """ register_task_input_buffer(self.workload_key, input_name, input_data) _ffi_api.SearchTaskAddTaskInput(self, input_name, input_data) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index eab26c5c0c69..c5460a4ad83f 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -359,3 +359,94 @@ def sparse_dense_alter_layout(_attrs, _inputs, _tinfos, _out_type): Unlike other TOPI functions, this function operates on both graph level and operator level. """ return None + + +def try_get_sparse_input(args): + """Analise the input data from the given args. Return is a Dict[Tensor, str] that maps the + input Tensor to a buffer name. + """ + sparse_prefix = sparse_data = sparse_indices = sparse_indptr = None + + def _process_inputs(input_tensors, m, n, prefix_init): + nonlocal sparse_prefix + nonlocal sparse_data + nonlocal sparse_indices + nonlocal sparse_indptr + + assert len(input_tensors) == 4 + unsure_tensors = list(input_tensors) + # Get the Dense data + dense_data = None + for tensor in unsure_tensors: + if len(tensor.shape) == 2: + assert dense_data is None + dense_data = tensor + assert m == dense_data.shape[0] + k = dense_data.shape[1] + unsure_tensors.remove(dense_data) + + # Get the Sparse data + sparse_data = None + for tensor in unsure_tensors: + if len(tensor.shape) == 3: + assert sparse_data is None + sparse_data = tensor + block_size, bs_r, bs_c = sparse_data.shape + unsure_tensors.remove(sparse_data) + + # Get the Sparse indptr & indices + sparse_indices = None + for tensor in unsure_tensors: + assert len(tensor.shape) == 1 + if tensor.shape[0] == block_size: + assert sparse_indices is None + sparse_indices = tensor + unsure_tensors.remove(sparse_indices) + assert len(unsure_tensors) == 1 + sparse_indptr = unsure_tensors[0] + + # Generate the sparse_prefix + density = 1.0 + for i in sparse_data.shape: + density *= i + density /= k * n + density = density.value + sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_" % (prefix_init, m, n, k, bs_r, bs_c, density) + + visited = set() + + def _traverse(t): + # We cannot directly add tensors to the set, because the comparison of + # two tensors with ndim=0 is ambiguous. + assert t.handle is not None + if t.handle.value in visited: + return + + if isinstance(t.op, te.ComputeOp): + # TODO(jcf94): Currently only support to one sparse op, add more support here + if t.op.tag == "sparse_dense_sp_rhs_bsrmm": + m, n = t.shape + assert len(t.op.input_tensors) == 1 + block_tensor = t.op.input_tensors[0] + _process_inputs(block_tensor.op.input_tensors, m, n, "sparse_dense_bsr") + if sparse_prefix is not None: + # Early stop if we find a sparse_prefix + # Notice: If any workload has more than one sparse input, this may get problem + return + for x in t.op.input_tensors: + _traverse(x) + visited.add(t.handle.value) + + try: + for arg in args: + _traverse(arg) + # pylint: disable=broad-except + except Exception: + return {} + + sparse_input_map = {} + sparse_input_map[sparse_data] = sparse_prefix + "W_data" + sparse_input_map[sparse_indices] = sparse_prefix + "W_indices" + sparse_input_map[sparse_indptr] = sparse_prefix + "W_indptr" + + return sparse_input_map diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py index 877a1ac1c430..3f1db4d46c46 100644 --- a/tests/python/unittest/test_auto_scheduler_search_task.py +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -36,7 +36,7 @@ def test_search_task_add_task_input(): - auto_scheduler.search_task.task_input_buffer_table.clear() + auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear() N = 64 target = "llvm" task = auto_scheduler.SearchTask( @@ -57,7 +57,7 @@ def test_search_task_add_task_input(): def test_search_task_record(): - auto_scheduler.search_task.task_input_buffer_table.clear() + auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear() N = 64 target = "llvm" task = auto_scheduler.SearchTask( @@ -112,7 +112,7 @@ def test_search_task_record(): def test_recover_measure_input_with_task_input(): - auto_scheduler.search_task.task_input_buffer_table.clear() + auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear() task = auto_scheduler.SearchTask( func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm" ) From aa7abd2512bfde141faf635e92275f0616687f1c Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 4 Feb 2021 15:12:13 +0800 Subject: [PATCH 11/35] Update --- python/tvm/auto_scheduler/measure.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 7dcd25c4f4e3..5a94859339db 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -724,8 +724,8 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo def _prepare_input_map(args): """This function deals with special task inputs.""" - # Lazy load topi - from tvm import topi + # pylint: disable=import-outside-toplevel + from tvm import topi # lazily import to avoid recursive dependency # A dict that maps the input tensor arg to a buffer name tensor_input_map = {} From 48a4e617e86d6b17cea132b567eb1e6aeba63263 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 4 Feb 2021 15:13:09 +0800 Subject: [PATCH 12/35] Update --- python/tvm/auto_scheduler/measure.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 5a94859339db..320c6b4cf5c9 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -37,8 +37,6 @@ import tempfile import multiprocessing -import numpy as np - import tvm._ffi from tvm.runtime import Object, module, ndarray from tvm.driver import build_module From 5b2e25a4e9a29b23762e743ec238d68a3214cce7 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 4 Feb 2021 15:26:18 +0800 Subject: [PATCH 13/35] Update --- include/tvm/auto_scheduler/search_task.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index 5bb58ca961ac..fbc7b419bb80 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -152,6 +152,7 @@ class SearchTask : public ObjectRef { * \param target_host The target host device of this search task. * \param hardware_params Hardware parameters used in this search task. * \param layout_rewrite_option The layout rewrite option used for measuring programs. + * \param task_inputs .... */ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, LayoutRewriteOption layout_rewrite_option, @@ -159,8 +160,8 @@ class SearchTask : public ObjectRef { /*! * \brief Add a input Tensor to this task. - * \param input_name - * \param input_data + * \param input_name ...... + * \param input_data ...... */ void AddTaskInput(String input_name, runtime::NDArray input_data); From cc111b97339cb221f48531e1fc6cb57b560a7278 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 4 Feb 2021 17:10:01 +0800 Subject: [PATCH 14/35] Add file save load support --- python/tvm/auto_scheduler/search_task.py | 71 +++++++++++++++++++++--- 1 file changed, 64 insertions(+), 7 deletions(-) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index d4329ae243ea..0a5a30eb7c00 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -19,8 +19,12 @@ import json +import os +import numpy as np +import logging + import tvm._ffi -from tvm.runtime import Object +from tvm.runtime import Object, ndarray from tvm.driver.build_module import build from tvm.target import Target @@ -33,6 +37,8 @@ from .workload_registry import register_workload_tensors from . import _ffi_api +logger = logging.getLogger("auto_scheduler") + @tvm._ffi.register_object("auto_scheduler.HardwareParams") class HardwareParams(Object): @@ -173,6 +179,32 @@ def __init__( TASK_INPUT_BUFFER_TABLE = {} +def _try_load_buffer_from_file(buffer_name): + filelist = os.listdir() + + for file in filelist: + if file.startswith(buffer_name) and file.count("."): + meta_info = file.split(".")[-1].split("_") + shape = [int(i) for i in meta_info[:-1]] + dtype = meta_info[-1] + buffer_data = np.fromfile(file, dtype=dtype, sep=" ") + buffer_data = buffer_data.reshape(shape) + return ndarray.array(buffer_data) + + return None + + +def _save_buffer_to_file(buffer_name, buffer_data): + np_data = buffer_data.asnumpy() + + buffer_name += "." + for i in np_data.shape: + buffer_name += "%d_" % (i) + buffer_name += "%s" % (np_data.dtype) + + np_data.tofile(buffer_name, " ") + + def register_task_input_buffer(workload_key, input_name, input_data, overwrite=False): """Register special buffer for measurement This can be used for sparse workloads when we cannot use random tensors for measurment. @@ -183,11 +215,22 @@ def register_task_input_buffer(workload_key, input_name, input_data, overwrite=F TASK_INPUT_BUFFER_TABLE[workload_key] = {} input_table = TASK_INPUT_BUFFER_TABLE[workload_key] - if input_name in input_table.keys() and not overwrite: - return input_table[input_name] + if not overwrite: + if not input_name in input_table.keys(): + # Try to load buffer data from local file + tensor_from_file = _try_load_buffer_from_file(input_name) + if tensor_from_file: + input_table[input_name] = tensor_from_file + + if input_name in input_table.keys(): + logger.warning( + "Tensor %s exists in TASK_INPUT_BUFFER_TABLE, " % (input_name) + + "set overwrite to True or this Tensor will not be registered" + ) + return input_table[input_name] input_table[input_name] = input_data - + _save_buffer_to_file(input_name, input_data) return input_data @@ -199,9 +242,23 @@ def get_task_input_buffer(workload_key, input_name): """ global TASK_INPUT_BUFFER_TABLE - input_table = TASK_INPUT_BUFFER_TABLE.get(workload_key, None) + if not workload_key in TASK_INPUT_BUFFER_TABLE: + TASK_INPUT_BUFFER_TABLE[workload_key] = {} + input_table = TASK_INPUT_BUFFER_TABLE[workload_key] + + if not input_name in input_table.keys(): + # Try to load buffer data from local file + tensor_from_file = _try_load_buffer_from_file(input_name) + if tensor_from_file: + input_table[input_name] = tensor_from_file + + if input_name in input_table.keys(): + return input_table[input_name] - return input_table.get(input_name, None) + raise ValueError( + "%s not found in TASK_INPUT_BUFFER_TABLE, " % (input_name) + + "should provide with SearchTask.AddTaskInput()" + ) @tvm._ffi.register_object("auto_scheduler.SearchTask") @@ -388,7 +445,7 @@ def __getstate__(self): "target_host": self.target_host, "hardware_params": self.hardware_params, "layout_rewrite_option": self.layout_rewrite_option, - "task_inputs": [i[0] for i in self.measure_inputs.items()], + "task_inputs": [i[0] for i in self.task_inputs.items()], } def __setstate__(self, state): From 9c9d974edeefe5697806685c91f2977bae4f817c Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 4 Feb 2021 17:21:18 +0800 Subject: [PATCH 15/35] Update --- python/tvm/auto_scheduler/search_task.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 0a5a30eb7c00..24dfecae12b1 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -20,8 +20,8 @@ import json import os -import numpy as np import logging +import numpy as np import tvm._ffi from tvm.runtime import Object, ndarray @@ -37,6 +37,7 @@ from .workload_registry import register_workload_tensors from . import _ffi_api +# pylint: disable=invalid-name logger = logging.getLogger("auto_scheduler") @@ -224,8 +225,9 @@ def register_task_input_buffer(workload_key, input_name, input_data, overwrite=F if input_name in input_table.keys(): logger.warning( - "Tensor %s exists in TASK_INPUT_BUFFER_TABLE, " % (input_name) - + "set overwrite to True or this Tensor will not be registered" + "Tensor %s exists in TASK_INPUT_BUFFER_TABLE, %s", + input_name, + "set overwrite to True or this Tensor will not be registered", ) return input_table[input_name] From 2eac0c7300ae8f910151e4af50ebbef83eaf9ac9 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 2 Mar 2021 20:17:03 +0800 Subject: [PATCH 16/35] Update --- include/tvm/auto_scheduler/search_task.h | 8 ++++---- python/tvm/auto_scheduler/search_task.py | 4 ++-- .../unittest/test_auto_scheduler_search_task.py | 14 +++++++------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index fbc7b419bb80..623604bbd5a1 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -152,16 +152,16 @@ class SearchTask : public ObjectRef { * \param target_host The target host device of this search task. * \param hardware_params Hardware parameters used in this search task. * \param layout_rewrite_option The layout rewrite option used for measuring programs. - * \param task_inputs .... + * \param task_inputs A map that stores some user defined input data used in program measuring. */ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, LayoutRewriteOption layout_rewrite_option, Map task_inputs); /*! - * \brief Add a input Tensor to this task. - * \param input_name ...... - * \param input_data ...... + * \brief Add a input Tensor to this task, will be used in program measuring. + * \param input_name The name of input Tensor. + * \param input_data The input Tensor. */ void AddTaskInput(String input_name, runtime::NDArray input_data); diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 8b2d563199bc..255650e16c24 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -430,7 +430,7 @@ def print_best(self, log_file, print_mode="schedule"): return func.imported_modules[0].get_source() raise ValueError("Invalid print_mode: %s" % print_mode) - def add_task_input(self, input_name, input_data): + def add_task_input(self, input_name, input_data, overwrite=False): """Add a input Tensor to this task. Parameters @@ -440,7 +440,7 @@ def add_task_input(self, input_name, input_data): input_data : Tensor ... """ - register_task_input_buffer(self.workload_key, input_name, input_data) + register_task_input_buffer(self.workload_key, input_name, input_data, overwrite) _ffi_api.SearchTaskAddTaskInput(self, input_name, input_data) def __getstate__(self): diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py index 3f1db4d46c46..fba82e41df9c 100644 --- a/tests/python/unittest/test_auto_scheduler_search_task.py +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -46,9 +46,9 @@ def test_search_task_add_task_input(): test_input_0 = tvm.runtime.ndarray.empty((64, 64)) test_input_1 = tvm.runtime.ndarray.empty((10, 20)) test_input_2 = tvm.runtime.ndarray.empty((30, 40, 50)) - task.add_task_input("test_input_0", test_input_0) - task.add_task_input("test_input_1", test_input_1) - task.add_task_input("test_input_2", test_input_2) + task.add_task_input("test_input_0", test_input_0, overwrite=True) + task.add_task_input("test_input_1", test_input_1, overwrite=True) + task.add_task_input("test_input_2", test_input_2, overwrite=True) assert len(task.task_inputs) == 3 assert task.task_inputs["test_input_0"] == test_input_0 @@ -75,7 +75,7 @@ def test_search_task_record(): # Log with 1 task input test_input_0 = tvm.runtime.ndarray.empty((64, 64)) - task.add_task_input("test_input_0", test_input_0) + task.add_task_input("test_input_0", test_input_0, overwrite=True) task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) assert task.workload_key == new_task.workload_key @@ -88,7 +88,7 @@ def test_search_task_record(): # Log with multiple task inputs test_input_1 = tvm.runtime.ndarray.empty((64, 64)) - task.add_task_input("test_input_1", test_input_1) + task.add_task_input("test_input_1", test_input_1, overwrite=True) task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) assert task.workload_key == new_task.workload_key @@ -132,7 +132,7 @@ def test_recover_measure_input_with_task_input(): # Log with 1 task input test_input_0 = tvm.runtime.ndarray.empty((64, 64)) - task.add_task_input("test_input_0", test_input_0) + task.add_task_input("test_input_0", test_input_0, overwrite=True) inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) @@ -148,7 +148,7 @@ def test_recover_measure_input_with_task_input(): # Log with multiple task inputs test_input_1 = tvm.runtime.ndarray.empty((64, 64)) - task.add_task_input("test_input_1", test_input_1) + task.add_task_input("test_input_1", test_input_1, overwrite=True) inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) From 418f42cc314a256206a40d224be3d01e9d32f389 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 2 Mar 2021 21:00:39 +0800 Subject: [PATCH 17/35] Update --- python/tvm/auto_scheduler/measure.py | 17 +++- python/tvm/auto_scheduler/search_task.py | 96 ++++++++++++++++----- python/tvm/topi/nn/sparse.py | 17 +++- tutorials/auto_scheduler/tune_sparse_x86.py | 2 +- 4 files changed, 107 insertions(+), 25 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 320c6b4cf5c9..ce047363c218 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -721,7 +721,22 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo def _prepare_input_map(args): - """This function deals with special task inputs.""" + """This function deals with special task inputs. + + Parameters + ---------- + args : List[Tensor] + Input/output Tensor of a TVM subgraph. + + Returns + ------- + A Dict[Tensor, str] that maps the input Tensor to a buffer name. + + Note + ---- + The buffer name is specially designed, and these buffer should be provided in + `SearchTask.add_task_input()`. + """ # pylint: disable=import-outside-toplevel from tvm import topi # lazily import to avoid recursive dependency diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 255650e16c24..bdb52201afdc 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -180,7 +180,26 @@ def __init__( TASK_INPUT_BUFFER_TABLE = {} +def _save_buffer_to_file(buffer_name, buffer_data): + """Save the current Tensor buffer to a numpy file. + + File name will be: {buffer_name}.{buffer_shape}_{buffer_data_type} + """ + np_data = buffer_data.asnumpy() + + buffer_name += "." + for i in np_data.shape: + buffer_name += "%d_" % (i) + buffer_name += "%s" % (np_data.dtype) + + np_data.tofile(buffer_name, " ") + + def _try_load_buffer_from_file(buffer_name): + """Try to load buffer from a numpy file, if not found, return None. + + File name has a same format as `_save_buffer_to_file`. + """ filelist = os.listdir() for file in filelist: @@ -195,29 +214,41 @@ def _try_load_buffer_from_file(buffer_name): return None -def _save_buffer_to_file(buffer_name, buffer_data): - np_data = buffer_data.asnumpy() +def register_task_input_buffer( + workload_key, + input_name, + input_data, + overwrite=False, + save_to_file=False, +): + """Register special buffer for measurement. - buffer_name += "." - for i in np_data.shape: - buffer_name += "%d_" % (i) - buffer_name += "%s" % (np_data.dtype) + Parameters + ---------- + workload_key : str + The workload key of the SearchTask. - np_data.tofile(buffer_name, " ") + input_name : str + The name of input buffer. + input_data : Tensor + The input Tensor data. -def register_task_input_buffer(workload_key, input_name, input_data, overwrite=False): - """Register special buffer for measurement - This can be used for sparse workloads when we cannot use random tensors for measurment. + overwrite : bool = False + Whether overwrite the data if a name has already in the global table. + + save_to_file : bool = False + Whether record this buffer to a local file. This can be reused to continue the last tuning + process. """ global TASK_INPUT_BUFFER_TABLE - if not workload_key in TASK_INPUT_BUFFER_TABLE: + if workload_key not in TASK_INPUT_BUFFER_TABLE: TASK_INPUT_BUFFER_TABLE[workload_key] = {} input_table = TASK_INPUT_BUFFER_TABLE[workload_key] if not overwrite: - if not input_name in input_table.keys(): + if input_name not in input_table.keys(): # Try to load buffer data from local file tensor_from_file = _try_load_buffer_from_file(input_name) if tensor_from_file: @@ -232,23 +263,36 @@ def register_task_input_buffer(workload_key, input_name, input_data, overwrite=F return input_table[input_name] input_table[input_name] = input_data - _save_buffer_to_file(input_name, input_data) + if save_to_file: + _save_buffer_to_file(input_name, input_data) return input_data @tvm._ffi.register_func("auto_scheduler.search_task.get_task_input_buffer") def get_task_input_buffer(workload_key, input_name): """Get special buffer for measurement. - This can be used for sparse workloads when we cannot use random tensors for measurment. + The buffers are registered by `register_task_input_buffer`. + + Parameters + ---------- + workload_key : str + The workload key of the SearchTask. + + input_name : str + The name of input buffer. + + Returns + ------- + The registered input buffer. """ global TASK_INPUT_BUFFER_TABLE - if not workload_key in TASK_INPUT_BUFFER_TABLE: + if workload_key not in TASK_INPUT_BUFFER_TABLE: TASK_INPUT_BUFFER_TABLE[workload_key] = {} input_table = TASK_INPUT_BUFFER_TABLE[workload_key] - if not input_name in input_table.keys(): + if input_name not in input_table.keys(): # Try to load buffer data from local file tensor_from_file = _try_load_buffer_from_file(input_name) if tensor_from_file: @@ -259,7 +303,7 @@ def get_task_input_buffer(workload_key, input_name): raise ValueError( "%s not found in TASK_INPUT_BUFFER_TABLE, " % (input_name) - + "should provide with SearchTask.AddTaskInput()" + + "should provide with SearchTask.add_task_input()" ) @@ -430,17 +474,27 @@ def print_best(self, log_file, print_mode="schedule"): return func.imported_modules[0].get_source() raise ValueError("Invalid print_mode: %s" % print_mode) - def add_task_input(self, input_name, input_data, overwrite=False): + def add_task_input(self, input_name, input_data, overwrite=False, save_to_file=False): """Add a input Tensor to this task. Parameters ---------- input_name : str - ... + The name of input buffer. + input_data : Tensor - ... + The input Tensor data. + + overwrite : bool = False + Whether overwrite the data if a name has already in the global table. + + save_to_file : bool = False + Whether record this buffer to a local file. This can be reused to continue the last + tuning process. """ - register_task_input_buffer(self.workload_key, input_name, input_data, overwrite) + register_task_input_buffer( + self.workload_key, input_name, input_data, overwrite, save_to_file + ) _ffi_api.SearchTaskAddTaskInput(self, input_name, input_data) def __getstate__(self): diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index c5460a4ad83f..da7d5dee7bd8 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -362,8 +362,21 @@ def sparse_dense_alter_layout(_attrs, _inputs, _tinfos, _out_type): def try_get_sparse_input(args): - """Analise the input data from the given args. Return is a Dict[Tensor, str] that maps the - input Tensor to a buffer name. + """Analise the input data from the given args. + + Parameters + ---------- + args : List[Tensor] + Input/output Tensor of a TVM subgraph. + + Returns + ------- + A Dict[Tensor, str] that maps the input Tensor to a buffer name. + + Note + ---- + The buffer name is specially designed, and these buffer should be provided in + `SearchTask.add_task_input()`. """ sparse_prefix = sparse_data = sparse_indices = sparse_indptr = None diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index 2b943ab2aabd..ce26879d7570 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -128,7 +128,7 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): # - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2 # - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512 -target = tvm.target.Target("llvm") +target = tvm.target.Target("llvm -mcpu=core-avx2") task = tvm.auto_scheduler.SearchTask( func=sparse_dense, From 1d735c8d48b667750135231ce62b2f3f99766c41 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Mar 2021 11:54:07 +0800 Subject: [PATCH 18/35] Remove add_task_inputs API --- include/tvm/auto_scheduler/search_task.h | 11 +-- python/tvm/auto_scheduler/search_task.py | 51 +++++------- python/tvm/topi/nn/sparse.py | 2 +- src/auto_scheduler/feature.cc | 2 +- src/auto_scheduler/measure_record.cc | 7 +- src/auto_scheduler/search_task.cc | 16 +--- .../test_auto_scheduler_search_task.py | 79 +++++++++++-------- tutorials/auto_scheduler/tune_sparse_x86.py | 4 +- 8 files changed, 78 insertions(+), 94 deletions(-) diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index 623604bbd5a1..f59262d24e4d 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -122,7 +122,7 @@ class SearchTaskNode : public Object { /*! \brief The layout rewrite option used for measuring programs. */ LayoutRewriteOption layout_rewrite_option; /*! \brief A map that stores some user defined input data used in program measuring. */ - Map task_inputs; + Array task_inputs; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("compute_dag", &compute_dag); @@ -156,14 +156,7 @@ class SearchTask : public ObjectRef { */ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, LayoutRewriteOption layout_rewrite_option, - Map task_inputs); - - /*! - * \brief Add a input Tensor to this task, will be used in program measuring. - * \param input_name The name of input Tensor. - * \param input_data The input Tensor. - */ - void AddTaskInput(String input_name, runtime::NDArray input_data); + Array task_inputs); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); }; diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index bdb52201afdc..45c0b51b9919 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -231,7 +231,7 @@ def register_task_input_buffer( input_name : str The name of input buffer. - input_data : Tensor + input_data : tvm.nd.NDArray The input Tensor data. overwrite : bool = False @@ -339,6 +339,11 @@ class SearchTask(Object): Some special Tensor used as inputs in program measuring. Usually we do not need to care about it, but for special workloads like Sparse computation the Sparse Tensor input are meaningful that we cannot use random input directly. + task_inputs_overwrite : bool = False + Whether overwrite the data if a name has already in the global table. + task_inputs_save_to_file : bool = False + Whether record this buffer to a local file. This can be reused to continue the last + tuning process. Examples -------- @@ -366,7 +371,9 @@ def __init__( target_host=None, hardware_params=None, layout_rewrite_option=None, - task_inputs=None, + task_inputs={}, + task_inputs_overwrite=False, + task_inputs_save_to_file=False, ): assert ( func is not None or workload_key is not None @@ -386,6 +393,14 @@ def __init__( if layout_rewrite_option is None: layout_rewrite_option = LayoutRewriteOption.get_target_default(target) + task_input_names = [] + for input_name in task_inputs: + register_task_input_buffer( + workload_key, input_name, task_inputs[input_name], task_inputs_overwrite, + task_inputs_save_to_file + ) + task_input_names.append(input_name) + self.__init_handle_by_constructor__( _ffi_api.SearchTask, compute_dag, @@ -394,7 +409,7 @@ def __init__( target_host, hardware_params, layout_rewrite_option, - task_inputs or {}, + task_input_names, ) def tune(self, tuning_options, search_policy=None): @@ -474,29 +489,6 @@ def print_best(self, log_file, print_mode="schedule"): return func.imported_modules[0].get_source() raise ValueError("Invalid print_mode: %s" % print_mode) - def add_task_input(self, input_name, input_data, overwrite=False, save_to_file=False): - """Add a input Tensor to this task. - - Parameters - ---------- - input_name : str - The name of input buffer. - - input_data : Tensor - The input Tensor data. - - overwrite : bool = False - Whether overwrite the data if a name has already in the global table. - - save_to_file : bool = False - Whether record this buffer to a local file. This can be reused to continue the last - tuning process. - """ - register_task_input_buffer( - self.workload_key, input_name, input_data, overwrite, save_to_file - ) - _ffi_api.SearchTaskAddTaskInput(self, input_name, input_data) - def __getstate__(self): return { "compute_dag": self.compute_dag, @@ -505,7 +497,7 @@ def __getstate__(self): "target_host": self.target_host, "hardware_params": self.hardware_params, "layout_rewrite_option": self.layout_rewrite_option, - "task_inputs": [i[0] for i in self.task_inputs.items()], + "task_inputs": self.task_inputs, } def __setstate__(self, state): @@ -522,9 +514,6 @@ def __setstate__(self, state): if workload[0] not in WORKLOAD_FUNC_REGISTRY: register_workload_tensors(state["workload_key"], state["compute_dag"].tensors) - task_inputs = {} - for data_name in state["task_inputs"]: - task_inputs[data_name] = get_task_input_buffer(state["workload_key"], data_name) self.__init_handle_by_constructor__( _ffi_api.SearchTask, state["compute_dag"], @@ -533,7 +522,7 @@ def __setstate__(self, state): state["target_host"], state["hardware_params"], state["layout_rewrite_option"], - task_inputs, + state["task_inputs"], ) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index da7d5dee7bd8..aacd3775ae72 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -362,7 +362,7 @@ def sparse_dense_alter_layout(_attrs, _inputs, _tinfos, _out_type): def try_get_sparse_input(args): - """Analise the input data from the given args. + """Analyze the input data from the given args. Parameters ---------- diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index 08ee5a8c4238..dde43a7caacd 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1468,7 +1468,7 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array& inputs, Array tensors = (*workload_key_to_tensors)(workload_key); task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target, inputs[i]->task->target_host, inputs[i]->task->hardware_params, - inputs[i]->task->layout_rewrite_option, {}); + inputs[i]->task->layout_rewrite_option, inputs[i]->task->task_inputs); } catch (std::exception& e) { // Cannot build ComputeDAG from workload key, the task may have not been registered in // this search round diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index ba5755885594..00f717026498 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -172,7 +172,7 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->WriteArraySeperator(); writer->BeginArray(false); for (const auto& i : data.task_inputs) { - writer->WriteArrayItem(std::string(i.first)); + writer->WriteArrayItem(std::string(i)); } writer->EndArray(); writer->EndArray(); @@ -207,14 +207,11 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { data->layout_rewrite_option = ::tvm::auto_scheduler::LayoutRewriteOption(int_value); s = reader->NextArrayItem(); if (s) { - std::vector input_data_names; - const auto& func = - tvm::runtime::Registry::Get("auto_scheduler.search_task.get_task_input_buffer"); reader->BeginArray(); s = reader->NextArrayItem(); while (s) { reader->Read(&str_value); - data->task_inputs.Set(str_value, (*func)(data->workload_key, str_value)); + data->task_inputs.push_back(str_value); s = reader->NextArrayItem(); } // Process the end of array diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 0c6a5edbdb99..9e208ebec362 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -115,7 +115,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, LayoutRewriteOption layout_rewrite_option, - Map task_inputs) { + Array task_inputs) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -132,13 +132,6 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe data_ = std::move(node); } -void SearchTask::AddTaskInput(String data_name, runtime::NDArray data) { - if (operator->()->task_inputs.count(data_name)) { - LOG(WARNING) << data_name << " already in memory"; - } - static_cast(get_mutable())->task_inputs.Set(data_name, data); -} - TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams") .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes, int max_shared_memory_per_block, int max_local_memory_per_block, @@ -151,15 +144,10 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams") TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask") .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, - int layout_rewrite_option, Map task_inputs) { + int layout_rewrite_option, Array task_inputs) { return SearchTask(compute_dag, workload_key, target, target_host, hardware_params, LayoutRewriteOption(layout_rewrite_option), task_inputs); }); -TVM_REGISTER_GLOBAL("auto_scheduler.SearchTaskAddTaskInput") - .set_body_typed([](SearchTask search_task, String input_name, runtime::NDArray input_data) { - search_task.AddTaskInput(input_name, input_data); - }); - } // namespace auto_scheduler } // namespace tvm diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py index fba82e41df9c..da641b0ce100 100644 --- a/tests/python/unittest/test_auto_scheduler_search_task.py +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -39,34 +39,35 @@ def test_search_task_add_task_input(): auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear() N = 64 target = "llvm" - task = auto_scheduler.SearchTask( - func="matmul_auto_scheduler_test", args=(N, N, N), target=target - ) - test_input_0 = tvm.runtime.ndarray.empty((64, 64)) test_input_1 = tvm.runtime.ndarray.empty((10, 20)) test_input_2 = tvm.runtime.ndarray.empty((30, 40, 50)) - task.add_task_input("test_input_0", test_input_0, overwrite=True) - task.add_task_input("test_input_1", test_input_1, overwrite=True) - task.add_task_input("test_input_2", test_input_2, overwrite=True) + task = auto_scheduler.SearchTask( + func="matmul_auto_scheduler_test", args=(N, N, N), target=target, + task_inputs={ + "test_input_0": test_input_0, + "test_input_1": test_input_1, + "test_input_2": test_input_2, + }, task_inputs_overwrite=True + ) assert len(task.task_inputs) == 3 - assert task.task_inputs["test_input_0"] == test_input_0 - assert task.task_inputs["test_input_1"] == test_input_1 - assert task.task_inputs["test_input_2"] == test_input_2 + assert task.task_inputs[0] == "test_input_0" + assert task.task_inputs[1] == "test_input_1" + assert task.task_inputs[2] == "test_input_2" def test_search_task_record(): auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear() N = 64 target = "llvm" + + # Log with no task input task = auto_scheduler.SearchTask( func="matmul_auto_scheduler_test", args=(N, N, N), target=target ) task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) - - # Log with no task input # TODO(jcf94): Check the compute dag & hardware parameter assert task.workload_key == new_task.workload_key assert str(task.target) == str(new_task.target) @@ -75,7 +76,12 @@ def test_search_task_record(): # Log with 1 task input test_input_0 = tvm.runtime.ndarray.empty((64, 64)) - task.add_task_input("test_input_0", test_input_0, overwrite=True) + task = auto_scheduler.SearchTask( + func="matmul_auto_scheduler_test", args=(N, N, N), target=target, + task_inputs={ + "test_input_0": test_input_0 + }, task_inputs_overwrite=True + ) task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) assert task.workload_key == new_task.workload_key @@ -83,12 +89,17 @@ def test_search_task_record(): assert str(task.target_host) == str(new_task.target_host) assert task.layout_rewrite_option == new_task.layout_rewrite_option assert len(new_task.task_inputs) == 1 - assert new_task.task_inputs.items()[0][0] == "test_input_0" - assert new_task.task_inputs.items()[0][1] == test_input_0 + assert new_task.task_inputs[0] == "test_input_0" # Log with multiple task inputs test_input_1 = tvm.runtime.ndarray.empty((64, 64)) - task.add_task_input("test_input_1", test_input_1, overwrite=True) + task = auto_scheduler.SearchTask( + func="matmul_auto_scheduler_test", args=(N, N, N), target=target, + task_inputs={ + "test_input_0": test_input_0, + "test_input_1": test_input_1, + }, task_inputs_overwrite=True + ) task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) assert task.workload_key == new_task.workload_key @@ -96,10 +107,8 @@ def test_search_task_record(): assert str(task.target_host) == str(new_task.target_host) assert task.layout_rewrite_option == new_task.layout_rewrite_option assert len(new_task.task_inputs) == 2 - assert new_task.task_inputs.items()[0][0] == "test_input_0" - assert new_task.task_inputs.items()[0][1] == test_input_0 - assert new_task.task_inputs.items()[1][0] == "test_input_1" - assert new_task.task_inputs.items()[1][1] == test_input_1 + assert new_task.task_inputs[0] == "test_input_0" + assert new_task.task_inputs[1] == "test_input_1" # Log with version 0.5 v5_log = """["[\\\"matmul_auto_scheduler_test\\\", 64, 64, 64]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1]""" @@ -113,13 +122,13 @@ def test_search_task_record(): def test_recover_measure_input_with_task_input(): auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear() - task = auto_scheduler.SearchTask( - func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm" - ) # Since this file is tests for search_task, we only check the search_task here # Log with no task input + task = auto_scheduler.SearchTask( + func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm" + ) inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) @@ -132,7 +141,12 @@ def test_recover_measure_input_with_task_input(): # Log with 1 task input test_input_0 = tvm.runtime.ndarray.empty((64, 64)) - task.add_task_input("test_input_0", test_input_0, overwrite=True) + task = auto_scheduler.SearchTask( + func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm", + task_inputs={ + "test_input_0": test_input_0, + }, task_inputs_overwrite=True + ) inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) @@ -143,12 +157,17 @@ def test_recover_measure_input_with_task_input(): assert str(task.target_host) == str(new_task.target_host) assert task.layout_rewrite_option == new_task.layout_rewrite_option assert len(new_task.task_inputs) == 1 - assert new_task.task_inputs.items()[0][0] == "test_input_0" - assert new_task.task_inputs.items()[0][1] == test_input_0 + assert new_task.task_inputs[0] == "test_input_0" # Log with multiple task inputs test_input_1 = tvm.runtime.ndarray.empty((64, 64)) - task.add_task_input("test_input_1", test_input_1, overwrite=True) + task = auto_scheduler.SearchTask( + func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm", + task_inputs={ + "test_input_0": test_input_0, + "test_input_1": test_input_1, + }, task_inputs_overwrite=True + ) inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) measure_record = auto_scheduler.measure_record.dump_record_to_string(inp, res) @@ -159,10 +178,8 @@ def test_recover_measure_input_with_task_input(): assert str(task.target_host) == str(new_task.target_host) assert task.layout_rewrite_option == new_task.layout_rewrite_option assert len(new_task.task_inputs) == 2 - assert new_task.task_inputs.items()[0][0] == "test_input_0" - assert new_task.task_inputs.items()[0][1] == test_input_0 - assert new_task.task_inputs.items()[1][0] == "test_input_1" - assert new_task.task_inputs.items()[1][1] == test_input_1 + assert new_task.task_inputs[0] == "test_input_0" + assert new_task.task_inputs[1] == "test_input_1" # Log with version 0.5 v5_log = """{"i": [["[\\\"matmul_auto_scheduler_test\\\", 512, 512, 512]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1], [[], []]], "r": [[0.1], 0, 0.2, 1], "v": "v0.6"}""" diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index ce26879d7570..8d7eab2eed33 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -28,7 +28,7 @@ Fortunately, auto-scheduler currently allows user to provide a CustomSketch to cover these cases. We use sparse matrix multiplication as an example in this tutorial to demonstrate how to implement -and plug a custom sketch rule to the auto-scheduler search policy. +and plug a custom sketch rule to the auto-scheduler's search policy. Note that this tutorial will not run on Windows or recent versions of macOS. To get it to run, you will need to wrap the body of this tutorial in a :code:`if @@ -101,7 +101,7 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): # # To solve this problem, we register these as special buffers, and load them when process program # measuring. -# See the :any:`auto_scheduler.measure` code for more details. +# See the `tvm.auto_scheduler.measure.py` for more details. # Define the basic shapes of this sparse computation M = K = N = 512 From 208325463b34a37fda97f970f93f1e80215131a3 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Mar 2021 14:05:53 +0800 Subject: [PATCH 19/35] Update --- python/tvm/auto_scheduler/measure.py | 10 ++++++--- python/tvm/auto_scheduler/search_task.py | 25 ++++++++++++++------- python/tvm/auto_scheduler/utils.py | 2 +- tutorials/auto_scheduler/tune_sparse_x86.py | 13 ++++++----- 4 files changed, 32 insertions(+), 18 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index ce047363c218..a5b0a93db045 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -223,7 +223,7 @@ def recover_measure_input(inp, rebuild_state=False): target_host=task.target_host, hardware_params=task.hardware_params, layout_rewrite_option=task.layout_rewrite_option, - task_inputs=task.task_inputs, + task_inputs=list(task.task_inputs), ) if rebuild_state: @@ -763,6 +763,8 @@ def _timed_eval_func( enable_cpu_cache_flush, verbose, ): + from .search_task import get_task_input_buffer + inp = MeasureInput.deserialize(inp_serialized) task_inputs = inp.task.task_inputs tic = time.time() @@ -802,7 +804,7 @@ def _timed_eval_func( if arg in tensor_input_map: tensor_name = tensor_input_map[arg] if tensor_name in task_inputs: - args.append(task_inputs[tensor_name]) + args.append(get_task_input_buffer(inp.task.workload_key, tensor_name)) else: raise ValueError( "%s not found in task_inputs, " % (tensor_name) @@ -960,6 +962,8 @@ def _timed_rpc_run( enable_cpu_cache_flush, verbose, ): + from .search_task import get_task_input_buffer + inp = MeasureInput.deserialize(inp_serialized) task_inputs = inp.task.task_inputs tic = time.time() @@ -1004,7 +1008,7 @@ def _timed_rpc_run( if arg in tensor_input_map: tensor_name = tensor_input_map[arg] if tensor_name in task_inputs: - args.append(task_inputs[tensor_name]) + args.append(get_task_input_buffer(inp.task.workload_key, tensor_name)) else: raise ValueError( "%s not found in task_inputs, " % (tensor_name) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 45c0b51b9919..32519d1e054e 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -335,7 +335,8 @@ class SearchTask(Object): The NO_REWRITE and INSERT_TRANSFORM_STAGE are expected to be used when tuning a standalone op, and the REWRITE_FOR_PRE_TRANSFORMED is expected to be used when tuning ops inside a network. - task_inputs : Optional[Dict[str, Tensor]] + task_inputs : Union[Dict[str, tvm.nd.NDArray], List[str]] + A dict maps the input names to input tensors or a list of input names. Some special Tensor used as inputs in program measuring. Usually we do not need to care about it, but for special workloads like Sparse computation the Sparse Tensor input are meaningful that we cannot use random input directly. @@ -371,7 +372,7 @@ def __init__( target_host=None, hardware_params=None, layout_rewrite_option=None, - task_inputs={}, + task_inputs=None, task_inputs_overwrite=False, task_inputs_save_to_file=False, ): @@ -394,12 +395,20 @@ def __init__( layout_rewrite_option = LayoutRewriteOption.get_target_default(target) task_input_names = [] - for input_name in task_inputs: - register_task_input_buffer( - workload_key, input_name, task_inputs[input_name], task_inputs_overwrite, - task_inputs_save_to_file - ) - task_input_names.append(input_name) + if isinstance(task_inputs, list): + task_input_names = task_inputs + elif isinstance(task_inputs, dict): + for input_name in task_inputs: + register_task_input_buffer( + workload_key, + input_name, + task_inputs[input_name], + task_inputs_overwrite, + task_inputs_save_to_file, + ) + task_input_names.append(input_name) + elif task_inputs is not None: + raise ValueError("task_inputs should be a dict or a list.") self.__init_handle_by_constructor__( _ffi_api.SearchTask, diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index 8aa33e6775f8..2ec3a107bb72 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -243,7 +243,7 @@ def kill_child_processes(parent_pid, sig=signal.SIGTERM): # The maximum length of traceback information -MAX_TRACEBACK_INFO_LEN = 512 +MAX_TRACEBACK_INFO_LEN = 51200 def make_traceback_info(): diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index 8d7eab2eed33..50528df05509 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -130,18 +130,19 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") +# Register the sparse data to special buffer +prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density) task = tvm.auto_scheduler.SearchTask( func=sparse_dense, args=(M, N, K, W_sp_np.data.shape, W_sp_np.indices.shape, W_sp_np.indptr.shape, "float32"), target=target, + task_inputs={ + prefix + "W_data": runtime.ndarray.array(W_sp_np.data), + prefix + "W_indices": runtime.ndarray.array(W_sp_np.indices), + prefix + "W_indptr": runtime.ndarray.array(W_sp_np.indptr), + } ) -# Register the sparse data to special buffer -prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density) -task.add_task_input(prefix + "W_data", runtime.ndarray.array(W_sp_np.data)) -task.add_task_input(prefix + "W_indices", runtime.ndarray.array(W_sp_np.indices)) -task.add_task_input(prefix + "W_indptr", runtime.ndarray.array(W_sp_np.indptr)) - # Inspect the computational graph print("Computational DAG:") print(task.compute_dag) From 1622de06e0082cfc1f27f4a3dea9619d0a0a5ed2 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Mar 2021 14:17:54 +0800 Subject: [PATCH 20/35] Update --- include/tvm/auto_scheduler/search_task.h | 4 ++-- python/tvm/auto_scheduler/measure.py | 6 +++--- python/tvm/auto_scheduler/search_task.py | 2 +- python/tvm/auto_scheduler/utils.py | 2 +- python/tvm/topi/nn/sparse.py | 2 +- tutorials/auto_scheduler/tune_sparse_x86.py | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index f59262d24e4d..c90ef75c25c7 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -121,7 +121,7 @@ class SearchTaskNode : public Object { HardwareParams hardware_params; /*! \brief The layout rewrite option used for measuring programs. */ LayoutRewriteOption layout_rewrite_option; - /*! \brief A map that stores some user defined input data used in program measuring. */ + /*! \brief Names of some user defined input data used in program measuring. */ Array task_inputs; void VisitAttrs(tvm::AttrVisitor* v) { @@ -152,7 +152,7 @@ class SearchTask : public ObjectRef { * \param target_host The target host device of this search task. * \param hardware_params Hardware parameters used in this search task. * \param layout_rewrite_option The layout rewrite option used for measuring programs. - * \param task_inputs A map that stores some user defined input data used in program measuring. + * \param task_inputs Names of some user defined input data used in program measuring. */ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, LayoutRewriteOption layout_rewrite_option, diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index a5b0a93db045..09e332efe06e 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -735,7 +735,7 @@ def _prepare_input_map(args): Note ---- The buffer name is specially designed, and these buffer should be provided in - `SearchTask.add_task_input()`. + `SearchTask(..., task_inputs={...})`. """ # pylint: disable=import-outside-toplevel from tvm import topi # lazily import to avoid recursive dependency @@ -763,7 +763,7 @@ def _timed_eval_func( enable_cpu_cache_flush, verbose, ): - from .search_task import get_task_input_buffer + from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency inp = MeasureInput.deserialize(inp_serialized) task_inputs = inp.task.task_inputs @@ -962,7 +962,7 @@ def _timed_rpc_run( enable_cpu_cache_flush, verbose, ): - from .search_task import get_task_input_buffer + from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency inp = MeasureInput.deserialize(inp_serialized) task_inputs = inp.task.task_inputs diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 32519d1e054e..81369e8d5400 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -303,7 +303,7 @@ def get_task_input_buffer(workload_key, input_name): raise ValueError( "%s not found in TASK_INPUT_BUFFER_TABLE, " % (input_name) - + "should provide with SearchTask.add_task_input()" + + "should provide with `SearchTask(..., task_inputs={...})`" ) diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index 2ec3a107bb72..8aa33e6775f8 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -243,7 +243,7 @@ def kill_child_processes(parent_pid, sig=signal.SIGTERM): # The maximum length of traceback information -MAX_TRACEBACK_INFO_LEN = 51200 +MAX_TRACEBACK_INFO_LEN = 512 def make_traceback_info(): diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index aacd3775ae72..199716f1ee60 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -376,7 +376,7 @@ def try_get_sparse_input(args): Note ---- The buffer name is specially designed, and these buffer should be provided in - `SearchTask.add_task_input()`. + `SearchTask(..., task_inputs={...})`. """ sparse_prefix = sparse_data = sparse_indices = sparse_indptr = None diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index 50528df05509..b283ce6fb931 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -130,7 +130,7 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") -# Register the sparse data to special buffer +# Register the sparse data to task inputs prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density) task = tvm.auto_scheduler.SearchTask( func=sparse_dense, From b6f02cc2e36e526a35fbfc3590abb3b05d65b202 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Mar 2021 15:03:59 +0800 Subject: [PATCH 21/35] Update --- python/tvm/auto_scheduler/measure.py | 10 +++++-- python/tvm/auto_scheduler/utils.py | 3 ++ python/tvm/topi/nn/sparse.py | 3 ++ .../unittest/test_auto_scheduler_measure.py | 29 +++++++++++++++++++ 4 files changed, 43 insertions(+), 2 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 09e332efe06e..3649d4a15323 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -743,11 +743,17 @@ def _prepare_input_map(args): # A dict that maps the input tensor arg to a buffer name tensor_input_map = {} - # Case 0: Check sparse op + # Case 0: Check placeholder name + for arg in args: + if isinstance(arg.op, tvm.te.PlaceholderOp): + if arg.op.name != "placeholder": + tensor_input_map[arg] = arg.op.name + + # Case 1: Check sparse op sparse_input_map = topi.nn.sparse.try_get_sparse_input(args) tensor_input_map.update(sparse_input_map) - # Case 1: Check ... + # Case 2: Check ... # Process any other special buffers here and update them to tensor_input_map return tensor_input_map diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index 8aa33e6775f8..14dc5b8984c3 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -201,6 +201,9 @@ def serialize_args(args): Currently this is mainly used for tvm.tensor.Tensor """ ret = [] + if args is None: + return tuple(ret) + for t in args: if isinstance(t, Tensor): t = ("TENSOR", get_const_tuple(t.shape), t.dtype) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 199716f1ee60..0793cbcdc3a3 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -457,6 +457,9 @@ def _traverse(t): except Exception: return {} + if sparse_data is None or sparse_indices is None or sparse_indptr is None: + return {} + sparse_input_map = {} sparse_input_map[sparse_data] = sparse_prefix + "W_data" sparse_input_map[sparse_indices] = sparse_prefix + "W_indices" diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index cc9d7a41548d..a8c658ed9442 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -19,6 +19,7 @@ import json import multiprocessing +import numpy as np import tvm from tvm import topi from tvm import te, auto_scheduler @@ -355,6 +356,33 @@ def test_measure_target_host(): assert str(recovered_inp.task.target_host) == str(inp.task.target_host) +@tvm.testing.requires_llvm +def test_measure_special_inputs_map_by_name(): + @auto_scheduler.register_workload + def foo(): + X = te.placeholder(shape=[10], dtype="int32") + Index = te.placeholder(shape=[1], dtype="int32", name="Index") + Y = te.compute((1,), lambda i: X[Index[i]]) + return [X, Index, Y] + + # This workload cannot use random input for the `Index` input + task = auto_scheduler.SearchTask( + func=foo, target="llvm", + task_inputs={ + "Index": tvm.nd.array(np.array([5], dtype="int32")), + } + ) + + minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state) + local_builder = auto_scheduler.LocalBuilder() + local_runner = auto_scheduler.LocalRunner(timeout=10) + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = local_runner.run([minp], bress) + assert mress[0].error_no == 0 + + if __name__ == "__main__": test_record_split_reorder_fuse_annotation() test_record_compute_at_root_inline_cache_read_write() @@ -366,3 +394,4 @@ def test_measure_target_host(): test_dag_measure_local_builder_runner() test_measure_local_builder_rpc_runner() test_measure_target_host() + test_measure_special_inputs_map_by_name() From ca92d644f0d9b32cc7219c1e8046f69954ad9410 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Mar 2021 15:10:55 +0800 Subject: [PATCH 22/35] Lint fix --- src/auto_scheduler/search_task.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 9e208ebec362..28cafb1d2706 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -114,8 +114,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, - LayoutRewriteOption layout_rewrite_option, - Array task_inputs) { + LayoutRewriteOption layout_rewrite_option, Array task_inputs) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); From 2273998728ab3b6aa6a8e739ffaad51a5aacefae Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Mar 2021 15:14:50 +0800 Subject: [PATCH 23/35] Lint fix --- .../unittest/test_auto_scheduler_measure.py | 5 ++- .../test_auto_scheduler_search_task.py | 37 +++++++++++++------ 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index a8c658ed9442..400b3eddcfe8 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -367,10 +367,11 @@ def foo(): # This workload cannot use random input for the `Index` input task = auto_scheduler.SearchTask( - func=foo, target="llvm", + func=foo, + target="llvm", task_inputs={ "Index": tvm.nd.array(np.array([5], dtype="int32")), - } + }, ) minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state) diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py index da641b0ce100..dc8fdcbc824b 100644 --- a/tests/python/unittest/test_auto_scheduler_search_task.py +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -43,12 +43,15 @@ def test_search_task_add_task_input(): test_input_1 = tvm.runtime.ndarray.empty((10, 20)) test_input_2 = tvm.runtime.ndarray.empty((30, 40, 50)) task = auto_scheduler.SearchTask( - func="matmul_auto_scheduler_test", args=(N, N, N), target=target, + func="matmul_auto_scheduler_test", + args=(N, N, N), + target=target, task_inputs={ "test_input_0": test_input_0, "test_input_1": test_input_1, "test_input_2": test_input_2, - }, task_inputs_overwrite=True + }, + task_inputs_overwrite=True, ) assert len(task.task_inputs) == 3 @@ -77,10 +80,11 @@ def test_search_task_record(): # Log with 1 task input test_input_0 = tvm.runtime.ndarray.empty((64, 64)) task = auto_scheduler.SearchTask( - func="matmul_auto_scheduler_test", args=(N, N, N), target=target, - task_inputs={ - "test_input_0": test_input_0 - }, task_inputs_overwrite=True + func="matmul_auto_scheduler_test", + args=(N, N, N), + target=target, + task_inputs={"test_input_0": test_input_0}, + task_inputs_overwrite=True, ) task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) @@ -94,11 +98,14 @@ def test_search_task_record(): # Log with multiple task inputs test_input_1 = tvm.runtime.ndarray.empty((64, 64)) task = auto_scheduler.SearchTask( - func="matmul_auto_scheduler_test", args=(N, N, N), target=target, + func="matmul_auto_scheduler_test", + args=(N, N, N), + target=target, task_inputs={ "test_input_0": test_input_0, "test_input_1": test_input_1, - }, task_inputs_overwrite=True + }, + task_inputs_overwrite=True, ) task_record = auto_scheduler._ffi_api.SerializeSearchTask(task) new_task = auto_scheduler._ffi_api.DeserializeSearchTask(task_record) @@ -142,10 +149,13 @@ def test_recover_measure_input_with_task_input(): # Log with 1 task input test_input_0 = tvm.runtime.ndarray.empty((64, 64)) task = auto_scheduler.SearchTask( - func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm", + func=matmul_auto_scheduler_test, + args=(512, 512, 512), + target="llvm", task_inputs={ "test_input_0": test_input_0, - }, task_inputs_overwrite=True + }, + task_inputs_overwrite=True, ) inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) @@ -162,11 +172,14 @@ def test_recover_measure_input_with_task_input(): # Log with multiple task inputs test_input_1 = tvm.runtime.ndarray.empty((64, 64)) task = auto_scheduler.SearchTask( - func=matmul_auto_scheduler_test, args=(512, 512, 512), target="llvm", + func=matmul_auto_scheduler_test, + args=(512, 512, 512), + target="llvm", task_inputs={ "test_input_0": test_input_0, "test_input_1": test_input_1, - }, task_inputs_overwrite=True + }, + task_inputs_overwrite=True, ) inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state) res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) From 034dcab2076e5fe2ac389e50ef29a219590536ac Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Mar 2021 15:20:53 +0800 Subject: [PATCH 24/35] Lint fix --- tutorials/auto_scheduler/tune_sparse_x86.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index b283ce6fb931..1f22ca441dfb 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -140,7 +140,7 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): prefix + "W_data": runtime.ndarray.array(W_sp_np.data), prefix + "W_indices": runtime.ndarray.array(W_sp_np.indices), prefix + "W_indptr": runtime.ndarray.array(W_sp_np.indptr), - } + }, ) # Inspect the computational graph From 35ce552984ba62612654f1297dbe56d44f7469f4 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Mar 2021 15:34:43 +0800 Subject: [PATCH 25/35] Lint fix --- python/tvm/auto_scheduler/measure.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 3649d4a15323..d4607f5a480b 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -769,6 +769,7 @@ def _timed_eval_func( enable_cpu_cache_flush, verbose, ): + # pylint: disable=import-outside-toplevel from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency inp = MeasureInput.deserialize(inp_serialized) @@ -968,6 +969,7 @@ def _timed_rpc_run( enable_cpu_cache_flush, verbose, ): + # pylint: disable=import-outside-toplevel from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency inp = MeasureInput.deserialize(inp_serialized) From 925fd702337398294744d0b50cc28fee5074ec20 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Mar 2021 16:01:15 +0800 Subject: [PATCH 26/35] Update --- tutorials/auto_scheduler/tune_sparse_x86.py | 52 +++++++++++++++++++-- 1 file changed, 49 insertions(+), 3 deletions(-) diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index 1f22ca441dfb..07a6d2ae2f46 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -219,7 +219,8 @@ def apply_func(search_policy, state, stage_id): # * :code:`num_measure_trials` is the number of measurement trials we can use during the search. # We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a # good value for the search to converge. You can do more trials according to your time budget. -# * In addition, we use :code:`RecordToFile` to dump measurement records into a file `matmul.json`. +# * In addition, we use :code:`RecordToFile` to dump measurement records into a file +# `sparse_dense.json`. # The measurement records can be used to query the history best, resume the search, # and do more analyses later. # * see :any:`auto_scheduler.TuningOptions` for more parameters @@ -228,7 +229,7 @@ def apply_func(search_policy, state, stage_id): log_file = "sparse_dense.json" tune_option = auto_scheduler.TuningOptions( - num_measure_trials=2, + num_measure_trials=10, measure_callbacks=[auto_scheduler.RecordToFile(log_file)], verbose=2, ) @@ -250,7 +251,10 @@ def apply_func(search_policy, state, stage_id): # file and apply it. # Run auto-tuning (search) -task.tune(tune_option, search_policy) +# Notice: We do not run the tuning in our webpage server since it takes too long. +# Uncomment the following line to run it by yourself. +# task.tune(tune_option, search_policy) + # Apply the best schedule sch, args = task.apply_best(log_file) @@ -292,3 +296,45 @@ def apply_func(search_policy, state, stage_id): * 1000 ) ) + +###################################################################### +# .. note:: Tuning result example +# +# .. code-block:: c +# +# ---------------------------------------------------------------------- +# Lowered TIR: +# primfn(placeholder_5: handle, placeholder_6: handle, placeholder_7: handle, placeholder_8: handle, placeholder_9: handle, compute_1: handle) -> () +# attr = {"global_symbol": "main", "tir.noalias": True} +# buffers = {placeholder_2: Buffer(placeholder_10: Pointer(float32), float32, [9831, 16, 1], []), +# placeholder_4: Buffer(placeholder_11: Pointer(int32), int32, [33], []), +# placeholder_3: Buffer(placeholder_12: Pointer(float32), float32, [512, 512], []), +# compute: Buffer(compute_2: Pointer(float32), float32, [512, 512], []), +# placeholder_1: Buffer(placeholder_13: Pointer(float32), float32, [512, 512], []), +# placeholder: Buffer(placeholder_14: Pointer(int32), int32, [9831], [])} +# buffer_map = {placeholder_7: placeholder, placeholder_9: placeholder_1, placeholder_6: placeholder_2, compute_1: compute, placeholder_5: placeholder_3, placeholder_8: placeholder_4} { +# for (i0.outer.i1.outer.fused: int32, 0, 1024) "parallel" { +# attr [compute_3: Pointer(float32)] "storage_scope" = "global"; +# allocate(compute_3, float32, [256]) { +# for (nb_j.inner: int32, 0, 2) { +# for (i.inner.init: int32, 0, 8) { +# for (j.init: int32, 0, 16) { +# compute_3[(((i.inner.init*32) + (nb_j.inner*16)) + j.init)] = 0f32 +# } +# } +# for (elem_idx: int32, 0, ((int32*)placeholder_11[(((floormod(i0.outer.i1.outer.fused, 16)*2) + nb_j.inner) + 1)] - (int32*)placeholder_11[((floormod(i0.outer.i1.outer.fused, 16)*2) + nb_j.inner)])) { +# for (i.inner: int32, 0, 8) { +# for (j: int32, 0, 16) { +# compute_3[(((i.inner*32) + (nb_j.inner*16)) + j)] = ((float32*)compute_3[(((i.inner*32) + (nb_j.inner*16)) + j)] + ((float32*)placeholder_10[((((int32*)placeholder_11[((floormod(i0.outer.i1.outer.fused, 16)*2) + nb_j.inner)]*16) + (elem_idx*16)) + j)]*max((float32*)placeholder_12[(((floordiv(i0.outer.i1.outer.fused, 16)*4096) + (i.inner*512)) + (int32*)placeholder_14[((int32*)placeholder_11[((floormod(i0.outer.i1.outer.fused, 16)*2) + nb_j.inner)] + elem_idx)])], 0f32))) +# } +# } +# } +# } +# for (i0.inner: int32, 0, 8) { +# compute_2[ramp((((floordiv(i0.outer.i1.outer.fused, 16)*4096) + (i0.inner*512)) + (floormod(i0.outer.i1.outer.fused, 16)*32)), 1, 32)] = max(((float32x32*)compute_3[ramp((i0.inner*32), 1, 32)] + (float32x32*)placeholder_13[ramp((((floordiv(i0.outer.i1.outer.fused, 16)*4096) + (i0.inner*512)) + (floormod(i0.outer.i1.outer.fused, 16)*32)), 1, 32)]), broadcast(0f32, 32)) +# } +# } +# } +# } +# +# Execution time of this operator: 0.990 ms From 56c01d9daff63e3bc9f9d537e48831e01a0b0462 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 3 Mar 2021 16:10:19 +0800 Subject: [PATCH 27/35] Add example ci_log --- tutorials/auto_scheduler/ci_logs/sparse_dense.json | 2 ++ tutorials/auto_scheduler/tune_sparse_x86.py | 4 +--- 2 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 tutorials/auto_scheduler/ci_logs/sparse_dense.json diff --git a/tutorials/auto_scheduler/ci_logs/sparse_dense.json b/tutorials/auto_scheduler/ci_logs/sparse_dense.json new file mode 100644 index 000000000000..7c1c100124dc --- /dev/null +++ b/tutorials/auto_scheduler/ci_logs/sparse_dense.json @@ -0,0 +1,2 @@ +# Keep a valid schedule for demonstraction. This is used to prevent flasky errors in CI. +{"i": [["[\"sparse_dense\", 512, 512, 512, [9831, 16, 1], [9831], [33], \"float32\"]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1, ["sparse_dense_bsr_512_512_512_16_1_0.60_W_data", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indices", "sparse_dense_bsr_512_512_512_16_1_0.60_W_indptr"]], [[], [["CI", 8], ["CI", 6], ["SP", 5, 0, 512, [1, 8], 1], ["FSP", 9, 0, 2, 1], ["SP", 5, 3, 32, [32], 1], ["FSP", 9, 2, 4, 1], ["RE", 5, [0, 3, 1, 4, 6, 2, 5, 7]], ["RE", 9, [0, 2, 1, 3]], ["CA", 5, 9, 1], ["CI", 4], ["FU", 9, [0, 1]], ["AN", 9, 0, 3], ["PR", 5, 0, "auto_unroll_max_step$0"], ["AN", 9, 2, 2]]]], "r": [[0.000957008], 0, 0.605709, 1614689820], "v": "v0.6"} diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index 07a6d2ae2f46..f13446295ffb 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -128,7 +128,7 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): # - replace "llvm" below with "llvm -mcpu=core-avx2" to enable AVX2 # - replace "llvm" below with "llvm -mcpu=skylake-avx512" to enable AVX-512 -target = tvm.target.Target("llvm -mcpu=core-avx2") +target = tvm.target.Target("llvm") # Register the sparse data to task inputs prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density) @@ -336,5 +336,3 @@ def apply_func(search_policy, state, stage_id): # } # } # } -# -# Execution time of this operator: 0.990 ms From b5a18323de859e61bb03ef5739cbd7660f33da74 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 4 Mar 2021 09:55:52 +0800 Subject: [PATCH 28/35] Update --- tests/python/unittest/test_auto_scheduler_measure.py | 2 +- tests/python/unittest/test_auto_scheduler_search_task.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 400b3eddcfe8..116981028cc9 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -27,7 +27,7 @@ import tvm.testing import pickle -from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul +from test_auto_scheduler_common import matmul_auto_scheduler_test from tvm.auto_scheduler import workload_registry diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py index dc8fdcbc824b..f8de28c628b6 100644 --- a/tests/python/unittest/test_auto_scheduler_search_task.py +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -17,8 +17,6 @@ """Test search policy""" -import random -import multiprocessing import numpy as np import tempfile @@ -32,7 +30,6 @@ zero_rank_compute_auto_scheduler_test, zero_rank_reduce_auto_scheduler_test, ) -import multiprocessing def test_search_task_add_task_input(): From 84b277dd290ff74a718ae04da7ff4165764a2bba Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 4 Mar 2021 18:41:55 +0800 Subject: [PATCH 29/35] retrigger ci --- tests/python/unittest/test_auto_scheduler_search_task.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py index f8de28c628b6..391d9a2ec787 100644 --- a/tests/python/unittest/test_auto_scheduler_search_task.py +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -24,7 +24,6 @@ import tvm.testing from tvm import auto_scheduler from tvm.auto_scheduler.utils import get_const_tuple - from test_auto_scheduler_common import ( matmul_auto_scheduler_test, zero_rank_compute_auto_scheduler_test, From 3bd6b6f558f36ca9221e2bdeb7c25e86a2d4e037 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 5 Mar 2021 10:14:20 +0800 Subject: [PATCH 30/35] Update --- include/tvm/auto_scheduler/search_task.h | 8 ++--- python/tvm/auto_scheduler/measure.py | 28 ++++++++-------- python/tvm/auto_scheduler/search_task.py | 32 ++++++++++++------- src/auto_scheduler/feature.cc | 4 +-- src/auto_scheduler/measure_record.cc | 4 +-- src/auto_scheduler/search_task.cc | 8 ++--- .../test_auto_scheduler_search_task.py | 32 +++++++++---------- 7 files changed, 63 insertions(+), 53 deletions(-) diff --git a/include/tvm/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h index c90ef75c25c7..14bf55abb447 100755 --- a/include/tvm/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -122,7 +122,7 @@ class SearchTaskNode : public Object { /*! \brief The layout rewrite option used for measuring programs. */ LayoutRewriteOption layout_rewrite_option; /*! \brief Names of some user defined input data used in program measuring. */ - Array task_inputs; + Array task_input_names; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("compute_dag", &compute_dag); @@ -131,7 +131,7 @@ class SearchTaskNode : public Object { v->Visit("target_host", &target_host); v->Visit("hardware_params", &hardware_params); v->Visit("layout_rewrite_option", &layout_rewrite_option); - v->Visit("task_inputs", &task_inputs); + v->Visit("task_input_names", &task_input_names); } static constexpr const char* _type_key = "auto_scheduler.SearchTask"; @@ -152,11 +152,11 @@ class SearchTask : public ObjectRef { * \param target_host The target host device of this search task. * \param hardware_params Hardware parameters used in this search task. * \param layout_rewrite_option The layout rewrite option used for measuring programs. - * \param task_inputs Names of some user defined input data used in program measuring. + * \param task_input_names Names of some user defined input data used in program measuring. */ SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, LayoutRewriteOption layout_rewrite_option, - Array task_inputs); + Array task_input_names); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); }; diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index d4607f5a480b..e9c9414fd1bd 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -223,7 +223,7 @@ def recover_measure_input(inp, rebuild_state=False): target_host=task.target_host, hardware_params=task.hardware_params, layout_rewrite_option=task.layout_rewrite_option, - task_inputs=list(task.task_inputs), + task_inputs=list(task.task_input_names), ) if rebuild_state: @@ -721,7 +721,8 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo def _prepare_input_map(args): - """This function deals with special task inputs. + """This function deals with special task inputs. Map the input Tensor of a TVM subgraph + to a specific buffer name in the global buffer map. Parameters ---------- @@ -730,10 +731,11 @@ def _prepare_input_map(args): Returns ------- - A Dict[Tensor, str] that maps the input Tensor to a buffer name. + Dict[Tensor, str] : + Map from the input Tensor to its buffer name. - Note - ---- + Notes + ----- The buffer name is specially designed, and these buffer should be provided in `SearchTask(..., task_inputs={...})`. """ @@ -773,7 +775,7 @@ def _timed_eval_func( from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency inp = MeasureInput.deserialize(inp_serialized) - task_inputs = inp.task.task_inputs + task_input_names = inp.task.task_input_names tic = time.time() error_no = 0 error_msg = None @@ -805,17 +807,17 @@ def _timed_eval_func( random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True) assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" - tensor_input_map = _prepare_input_map(build_res.args) if task_inputs else {} + tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {} args = [] for arg in build_res.args: if arg in tensor_input_map: tensor_name = tensor_input_map[arg] - if tensor_name in task_inputs: + if tensor_name in task_input_names: args.append(get_task_input_buffer(inp.task.workload_key, tensor_name)) else: raise ValueError( "%s not found in task_inputs, " % (tensor_name) - + "should provide with SearchTask.AddTaskInput()" + + "should provide with `SearchTask(..., task_inputs={...})`" ) else: empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) @@ -973,7 +975,7 @@ def _timed_rpc_run( from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency inp = MeasureInput.deserialize(inp_serialized) - task_inputs = inp.task.task_inputs + task_input_names = inp.task.task_input_names tic = time.time() error_no = 0 error_msg = None @@ -1010,17 +1012,17 @@ def _timed_rpc_run( random_fill ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices" - tensor_input_map = _prepare_input_map(build_res.args) if task_inputs else {} + tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {} args = [] for arg in build_res.args: if arg in tensor_input_map: tensor_name = tensor_input_map[arg] - if tensor_name in task_inputs: + if tensor_name in task_input_names: args.append(get_task_input_buffer(inp.task.workload_key, tensor_name)) else: raise ValueError( "%s not found in task_inputs, " % (tensor_name) - + "should provide with SearchTask.AddTaskInput()" + + "should provide with `SearchTask(..., task_inputs={...})`" ) else: empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 81369e8d5400..c6d0d1ac7832 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -164,8 +164,8 @@ def __init__( ) -# The map stores special registered buffer for measurement -# This can be used for sparse workloads when we cannot use random tensors for measurment. +# The map stores special registered buffer for measurement. +# This can be used for sparse workloads when we cannot use random tensors for measurment. # { # "workload_key_0": { # "task_input_0": Tensor(...), @@ -183,7 +183,7 @@ def __init__( def _save_buffer_to_file(buffer_name, buffer_data): """Save the current Tensor buffer to a numpy file. - File name will be: {buffer_name}.{buffer_shape}_{buffer_data_type} + File name will be: {buffer_name}_{buffer_shape}_{buffer_data_type}.npy """ np_data = buffer_data.asnumpy() @@ -191,6 +191,7 @@ def _save_buffer_to_file(buffer_name, buffer_data): for i in np_data.shape: buffer_name += "%d_" % (i) buffer_name += "%s" % (np_data.dtype) + buffer_name += ".npy" np_data.tofile(buffer_name, " ") @@ -204,7 +205,7 @@ def _try_load_buffer_from_file(buffer_name): for file in filelist: if file.startswith(buffer_name) and file.count("."): - meta_info = file.split(".")[-1].split("_") + meta_info = file.split(".")[-2].split("_") shape = [int(i) for i in meta_info[:-1]] dtype = meta_info[-1] buffer_data = np.fromfile(file, dtype=dtype, sep=" ") @@ -235,11 +236,17 @@ def register_task_input_buffer( The input Tensor data. overwrite : bool = False - Whether overwrite the data if a name has already in the global table. + Whether to overwrite the data if a name has already registered. save_to_file : bool = False - Whether record this buffer to a local file. This can be reused to continue the last tuning - process. + Whether to save the data to a local file as well. This can be reused to resume the last + tuning process. + + Returns + ------- + tvm.nd.NDArray + The actual registered Tensor data of this input_name. With `overwrite` set to False, will + return the original one if the name has already registered before. """ global TASK_INPUT_BUFFER_TABLE @@ -284,7 +291,8 @@ def get_task_input_buffer(workload_key, input_name): Returns ------- - The registered input buffer. + tvm.nd.NDArray + The registered input buffer. """ global TASK_INPUT_BUFFER_TABLE @@ -341,9 +349,9 @@ class SearchTask(Object): about it, but for special workloads like Sparse computation the Sparse Tensor input are meaningful that we cannot use random input directly. task_inputs_overwrite : bool = False - Whether overwrite the data if a name has already in the global table. + Whether to overwrite the data if a name has already in the global table. task_inputs_save_to_file : bool = False - Whether record this buffer to a local file. This can be reused to continue the last + Whether to save the data to a local file as well. This can be reused to resume the last tuning process. Examples @@ -506,7 +514,7 @@ def __getstate__(self): "target_host": self.target_host, "hardware_params": self.hardware_params, "layout_rewrite_option": self.layout_rewrite_option, - "task_inputs": self.task_inputs, + "task_input_names": self.task_input_names, } def __setstate__(self, state): @@ -531,7 +539,7 @@ def __setstate__(self, state): state["target_host"], state["hardware_params"], state["layout_rewrite_option"], - state["task_inputs"], + state["task_input_names"], ) diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index dde43a7caacd..c70c9497c063 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1399,7 +1399,7 @@ void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int Array tensors = (*workload_key_to_tensors)(workload_key); task = SearchTask(ComputeDAG(tensors), workload_key, cur_inp->task->target, cur_inp->task->target_host, cur_inp->task->hardware_params, - cur_inp->task->layout_rewrite_option, cur_inp->task->task_inputs); + cur_inp->task->layout_rewrite_option, cur_inp->task->task_input_names); task_id = task_cache.size(); // compute min cost for each task @@ -1468,7 +1468,7 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array& inputs, Array tensors = (*workload_key_to_tensors)(workload_key); task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target, inputs[i]->task->target_host, inputs[i]->task->hardware_params, - inputs[i]->task->layout_rewrite_option, inputs[i]->task->task_inputs); + inputs[i]->task->layout_rewrite_option, inputs[i]->task->task_input_names); } catch (std::exception& e) { // Cannot build ComputeDAG from workload key, the task may have not been registered in // this search round diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 00f717026498..5dafa8d98702 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -171,7 +171,7 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { writer->WriteArrayItem(static_cast(data.layout_rewrite_option)); writer->WriteArraySeperator(); writer->BeginArray(false); - for (const auto& i : data.task_inputs) { + for (const auto& i : data.task_input_names) { writer->WriteArrayItem(std::string(i)); } writer->EndArray(); @@ -211,7 +211,7 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> { s = reader->NextArrayItem(); while (s) { reader->Read(&str_value); - data->task_inputs.push_back(str_value); + data->task_input_names.push_back(str_value); s = reader->NextArrayItem(); } // Process the end of array diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 28cafb1d2706..22c2893141cf 100755 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -114,7 +114,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, - LayoutRewriteOption layout_rewrite_option, Array task_inputs) { + LayoutRewriteOption layout_rewrite_option, Array task_input_names) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -127,7 +127,7 @@ SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target targe HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host); } node->layout_rewrite_option = layout_rewrite_option; - node->task_inputs = std::move(task_inputs); + node->task_input_names = std::move(task_input_names); data_ = std::move(node); } @@ -143,9 +143,9 @@ TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams") TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask") .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target, Target target_host, Optional hardware_params, - int layout_rewrite_option, Array task_inputs) { + int layout_rewrite_option, Array task_input_names) { return SearchTask(compute_dag, workload_key, target, target_host, hardware_params, - LayoutRewriteOption(layout_rewrite_option), task_inputs); + LayoutRewriteOption(layout_rewrite_option), task_input_names); }); } // namespace auto_scheduler diff --git a/tests/python/unittest/test_auto_scheduler_search_task.py b/tests/python/unittest/test_auto_scheduler_search_task.py index 391d9a2ec787..78e85dc213e0 100644 --- a/tests/python/unittest/test_auto_scheduler_search_task.py +++ b/tests/python/unittest/test_auto_scheduler_search_task.py @@ -50,10 +50,10 @@ def test_search_task_add_task_input(): task_inputs_overwrite=True, ) - assert len(task.task_inputs) == 3 - assert task.task_inputs[0] == "test_input_0" - assert task.task_inputs[1] == "test_input_1" - assert task.task_inputs[2] == "test_input_2" + assert len(task.task_input_names) == 3 + assert task.task_input_names[0] == "test_input_0" + assert task.task_input_names[1] == "test_input_1" + assert task.task_input_names[2] == "test_input_2" def test_search_task_record(): @@ -88,8 +88,8 @@ def test_search_task_record(): assert str(task.target) == str(new_task.target) assert str(task.target_host) == str(new_task.target_host) assert task.layout_rewrite_option == new_task.layout_rewrite_option - assert len(new_task.task_inputs) == 1 - assert new_task.task_inputs[0] == "test_input_0" + assert len(new_task.task_input_names) == 1 + assert new_task.task_input_names[0] == "test_input_0" # Log with multiple task inputs test_input_1 = tvm.runtime.ndarray.empty((64, 64)) @@ -109,9 +109,9 @@ def test_search_task_record(): assert str(task.target) == str(new_task.target) assert str(task.target_host) == str(new_task.target_host) assert task.layout_rewrite_option == new_task.layout_rewrite_option - assert len(new_task.task_inputs) == 2 - assert new_task.task_inputs[0] == "test_input_0" - assert new_task.task_inputs[1] == "test_input_1" + assert len(new_task.task_input_names) == 2 + assert new_task.task_input_names[0] == "test_input_0" + assert new_task.task_input_names[1] == "test_input_1" # Log with version 0.5 v5_log = """["[\\\"matmul_auto_scheduler_test\\\", 64, 64, 64]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1]""" @@ -120,7 +120,7 @@ def test_search_task_record(): assert str(task.target) == str(new_task.target) assert str(task.target_host) == str(new_task.target_host) assert task.layout_rewrite_option == new_task.layout_rewrite_option - assert len(new_task.task_inputs) == 0 + assert len(new_task.task_input_names) == 0 def test_recover_measure_input_with_task_input(): @@ -162,8 +162,8 @@ def test_recover_measure_input_with_task_input(): assert str(task.target) == str(new_task.target) assert str(task.target_host) == str(new_task.target_host) assert task.layout_rewrite_option == new_task.layout_rewrite_option - assert len(new_task.task_inputs) == 1 - assert new_task.task_inputs[0] == "test_input_0" + assert len(new_task.task_input_names) == 1 + assert new_task.task_input_names[0] == "test_input_0" # Log with multiple task inputs test_input_1 = tvm.runtime.ndarray.empty((64, 64)) @@ -186,9 +186,9 @@ def test_recover_measure_input_with_task_input(): assert str(task.target) == str(new_task.target) assert str(task.target_host) == str(new_task.target_host) assert task.layout_rewrite_option == new_task.layout_rewrite_option - assert len(new_task.task_inputs) == 2 - assert new_task.task_inputs[0] == "test_input_0" - assert new_task.task_inputs[1] == "test_input_1" + assert len(new_task.task_input_names) == 2 + assert new_task.task_input_names[0] == "test_input_0" + assert new_task.task_input_names[1] == "test_input_1" # Log with version 0.5 v5_log = """{"i": [["[\\\"matmul_auto_scheduler_test\\\", 512, 512, 512]", "llvm -keys=cpu -link-params=0", [6, 64, 64, 0, 0, 0, 0, 0], "", 1], [[], []]], "r": [[0.1], 0, 0.2, 1], "v": "v0.6"}""" @@ -198,7 +198,7 @@ def test_recover_measure_input_with_task_input(): assert str(task.target) == str(new_task.target) assert str(task.target_host) == str(new_task.target_host) assert task.layout_rewrite_option == new_task.layout_rewrite_option - assert len(new_task.task_inputs) == 0 + assert len(new_task.task_input_names) == 0 if __name__ == "__main__": From de4170e39dfaa5d3764d6da44430a672568971db Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 5 Mar 2021 10:39:14 +0800 Subject: [PATCH 31/35] Update --- python/tvm/auto_scheduler/measure.py | 17 ++++++++++++++++- python/tvm/topi/nn/sparse.py | 7 ++++--- tutorials/auto_scheduler/tune_sparse_x86.py | 3 ++- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index e9c9414fd1bd..f3012b344f99 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -36,6 +36,7 @@ import shutil import tempfile import multiprocessing +import logging import tvm._ffi from tvm.runtime import Object, module, ndarray @@ -58,6 +59,8 @@ deserialize_workload_registry_entry, ) +# pylint: disable=invalid-name +logger = logging.getLogger("auto_scheduler") # The time cost for measurements with errors # We use 1e10 instead of sys.float_info.max for better readability in log @@ -731,7 +734,7 @@ def _prepare_input_map(args): Returns ------- - Dict[Tensor, str] : + Dict[Tensor, str] : Map from the input Tensor to its buffer name. Notes @@ -809,11 +812,13 @@ def _timed_eval_func( tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {} args = [] + task_inputs_count = 0 for arg in build_res.args: if arg in tensor_input_map: tensor_name = tensor_input_map[arg] if tensor_name in task_input_names: args.append(get_task_input_buffer(inp.task.workload_key, tensor_name)) + task_inputs_count += 1 else: raise ValueError( "%s not found in task_inputs, " % (tensor_name) @@ -823,6 +828,10 @@ def _timed_eval_func( empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) random_fill(empty_array) args.append(empty_array) + if task_inputs_count != len(task_input_names): + logger.warning( + "task_inputs not fully matched, check if there's any unexpected error" + ) ctx.sync() costs = time_f(*args).results # pylint: disable=broad-except @@ -1014,11 +1023,13 @@ def _timed_rpc_run( tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {} args = [] + task_inputs_count = 0 for arg in build_res.args: if arg in tensor_input_map: tensor_name = tensor_input_map[arg] if tensor_name in task_input_names: args.append(get_task_input_buffer(inp.task.workload_key, tensor_name)) + task_inputs_count += 1 else: raise ValueError( "%s not found in task_inputs, " % (tensor_name) @@ -1028,6 +1039,10 @@ def _timed_rpc_run( empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx) random_fill(empty_array) args.append(empty_array) + if task_inputs_count != len(task_input_names): + logger.warning( + "task_inputs not fully matched, check if there's any unexpected error" + ) ctx.sync() costs = time_f(*args).results diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 0793cbcdc3a3..1718fcc9b336 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -371,10 +371,11 @@ def try_get_sparse_input(args): Returns ------- - A Dict[Tensor, str] that maps the input Tensor to a buffer name. + Dict[Tensor, str] : + Map from the input Tensor to its buffer name. - Note - ---- + Notes + ----- The buffer name is specially designed, and these buffer should be provided in `SearchTask(..., task_inputs={...})`. """ diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index f13446295ffb..ced416f6c500 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -141,6 +141,7 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): prefix + "W_indices": runtime.ndarray.array(W_sp_np.indices), prefix + "W_indptr": runtime.ndarray.array(W_sp_np.indptr), }, + task_inputs_save_to_file=True, ) # Inspect the computational graph @@ -253,7 +254,7 @@ def apply_func(search_policy, state, stage_id): # Run auto-tuning (search) # Notice: We do not run the tuning in our webpage server since it takes too long. # Uncomment the following line to run it by yourself. -# task.tune(tune_option, search_policy) +task.tune(tune_option, search_policy) # Apply the best schedule sch, args = task.apply_best(log_file) From 7e45641f5eadca79f9d2c515d75981ec1cc46c42 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 5 Mar 2021 11:11:16 +0800 Subject: [PATCH 32/35] Update --- python/tvm/auto_scheduler/__init__.py | 1 + python/tvm/auto_scheduler/measure.py | 64 +++++++++++++++++++++--- python/tvm/auto_scheduler/search_task.py | 4 +- python/tvm/topi/nn/sparse.py | 3 +- 4 files changed, 63 insertions(+), 9 deletions(-) diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py index 06ca44d997e5..ff6d82a0242c 100644 --- a/python/tvm/auto_scheduler/__init__.py +++ b/python/tvm/auto_scheduler/__init__.py @@ -41,6 +41,7 @@ LocalRunner, RPCRunner, LocalRPCMeasureContext, + register_task_input_check_func, ) from .measure_record import RecordToFile, RecordReader, load_best_record, load_records, save_records from .relay_integration import ( diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index f3012b344f99..82948f5687a0 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -51,6 +51,7 @@ call_func_with_timeout, check_remote, get_const_tuple, + get_func_name, make_traceback_info, request_remote, ) @@ -723,6 +724,57 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo return results +TASK_INPUT_CHECK_FUNC_REGISTRY = {} + + +def register_task_input_check_func(func_name, f=None, override=False): + """Register a function that checks the input buffer map. + + The input function should take a list of Tensor wich indicate the Input/output Tensor of a TVM + subgraph and return a Map from the input Tensor to its buffer name. + + Parameters + ---------- + func_name : Union[Function, str] + The check function that returns the compute declaration Tensors or its function name. + f : Optional[Function] + The check function to be registered. + override : boolean = False + Whether to override existing entry. + + Examples + -------- + .. code-block:: python + + @auto_scheduler.register_task_input_check_func + def check_task_input_by_placeholder_name(args : List[Tensor]): + tensor_input_map = {} + for arg in args: + if isinstance(arg.op, tvm.te.PlaceholderOp): + if arg.op.name != "placeholder": + tensor_input_map[arg] = arg.op.name + return tensor_input_map + """ + global TASK_INPUT_CHECK_FUNC_REGISTRY + + if callable(func_name): + f = func_name + func_name = get_func_name(f) + if not isinstance(func_name, str): + raise ValueError("expect string function name") + + def register(myf): + """internal register function""" + if func_name in TASK_INPUT_CHECK_FUNC_REGISTRY and not override: + raise RuntimeError("%s has been registered already" % func_name) + TASK_INPUT_CHECK_FUNC_REGISTRY[func_name] = myf + return myf + + if f: + return register(f) + return register + + def _prepare_input_map(args): """This function deals with special task inputs. Map the input Tensor of a TVM subgraph to a specific buffer name in the global buffer map. @@ -745,6 +797,8 @@ def _prepare_input_map(args): # pylint: disable=import-outside-toplevel from tvm import topi # lazily import to avoid recursive dependency + global TASK_INPUT_CHECK_FUNC_REGISTRY + # A dict that maps the input tensor arg to a buffer name tensor_input_map = {} @@ -754,12 +808,10 @@ def _prepare_input_map(args): if arg.op.name != "placeholder": tensor_input_map[arg] = arg.op.name - # Case 1: Check sparse op - sparse_input_map = topi.nn.sparse.try_get_sparse_input(args) - tensor_input_map.update(sparse_input_map) - - # Case 2: Check ... - # Process any other special buffers here and update them to tensor_input_map + # Case 1: Check specific tensor inputs + for func_name in TASK_INPUT_CHECK_FUNC_REGISTRY: + func = TASK_INPUT_CHECK_FUNC_REGISTRY[func_name] + tensor_input_map.update(func(args)) return tensor_input_map diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index c6d0d1ac7832..71a13f8f91a9 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -183,7 +183,7 @@ def __init__( def _save_buffer_to_file(buffer_name, buffer_data): """Save the current Tensor buffer to a numpy file. - File name will be: {buffer_name}_{buffer_shape}_{buffer_data_type}.npy + File name will be: {buffer_name}.{buffer_shape}_{buffer_data_type}.npy """ np_data = buffer_data.asnumpy() @@ -204,7 +204,7 @@ def _try_load_buffer_from_file(buffer_name): filelist = os.listdir() for file in filelist: - if file.startswith(buffer_name) and file.count("."): + if file.startswith(buffer_name + "."): meta_info = file.split(".")[-2].split("_") shape = [int(i) for i in meta_info[:-1]] dtype = meta_info[-1] diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 1718fcc9b336..1bf18df09da3 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -18,7 +18,7 @@ """Sparse operators""" from __future__ import absolute_import import tvm -from tvm import te +from tvm import te, auto_scheduler from ..utils import get_const_tuple @@ -361,6 +361,7 @@ def sparse_dense_alter_layout(_attrs, _inputs, _tinfos, _out_type): return None +@auto_scheduler.register_task_input_check_func def try_get_sparse_input(args): """Analyze the input data from the given args. From 7b47a0633e20a4e7217d56b4bc3da513972eca1a Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 5 Mar 2021 11:41:55 +0800 Subject: [PATCH 33/35] Lint fix --- python/tvm/auto_scheduler/search_task.py | 1 - src/auto_scheduler/feature.cc | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 71a13f8f91a9..57e239cf79e8 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -275,7 +275,6 @@ def register_task_input_buffer( return input_data -@tvm._ffi.register_func("auto_scheduler.search_task.get_task_input_buffer") def get_task_input_buffer(workload_key, input_name): """Get special buffer for measurement. diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index c70c9497c063..e4b4833e9688 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1468,7 +1468,8 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array& inputs, Array tensors = (*workload_key_to_tensors)(workload_key); task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target, inputs[i]->task->target_host, inputs[i]->task->hardware_params, - inputs[i]->task->layout_rewrite_option, inputs[i]->task->task_input_names); + inputs[i]->task->layout_rewrite_option, + inputs[i]->task->task_input_names); } catch (std::exception& e) { // Cannot build ComputeDAG from workload key, the task may have not been registered in // this search round From eeb9b3c9b2adcbd300274d5f0d9d21c7558566c6 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 5 Mar 2021 13:26:48 +0800 Subject: [PATCH 34/35] Lint fix --- src/auto_scheduler/feature.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index e4b4833e9688..d93218c0208c 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1466,10 +1466,10 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array& inputs, // The measure input is incomplete, rebuild task for incomplete measure pairs read from file try { Array tensors = (*workload_key_to_tensors)(workload_key); - task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target, - inputs[i]->task->target_host, inputs[i]->task->hardware_params, - inputs[i]->task->layout_rewrite_option, - inputs[i]->task->task_input_names); + task = + SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target, + inputs[i]->task->target_host, inputs[i]->task->hardware_params, + inputs[i]->task->layout_rewrite_option, inputs[i]->task->task_input_names); } catch (std::exception& e) { // Cannot build ComputeDAG from workload key, the task may have not been registered in // this search round From cf4cb420114f2f24add2aaa55902826d185d4191 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 5 Mar 2021 14:10:15 +0800 Subject: [PATCH 35/35] Lint fix --- python/tvm/auto_scheduler/measure.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 82948f5687a0..959a9c5da82a 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -795,7 +795,6 @@ def _prepare_input_map(args): `SearchTask(..., task_inputs={...})`. """ # pylint: disable=import-outside-toplevel - from tvm import topi # lazily import to avoid recursive dependency global TASK_INPUT_CHECK_FUNC_REGISTRY