From 9e47c62e31e8df1b5ae75ce896388303c0e18ded Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 11 Mar 2021 18:25:43 +0800 Subject: [PATCH 01/25] Add sparse dense end to end model tuning support --- python/tvm/auto_scheduler/measure.py | 6 +- .../tvm/auto_scheduler/relay_integration.py | 16 +++++ python/tvm/auto_scheduler/search_task.py | 11 +++- python/tvm/relay/analysis/sparse_dense.py | 22 +++++++ python/tvm/topi/nn/sparse.py | 2 +- tutorials/auto_scheduler/tune_network_x86.py | 65 ++++++++++++++++++- tutorials/auto_scheduler/tune_sparse_x86.py | 2 +- 7 files changed, 114 insertions(+), 10 deletions(-) diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 959a9c5da82a..2a8222664e9a 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -775,7 +775,7 @@ def register(myf): return register -def _prepare_input_map(args): +def prepare_input_map(args): """This function deals with special task inputs. Map the input Tensor of a TVM subgraph to a specific buffer name in the global buffer map. @@ -861,7 +861,7 @@ def _timed_eval_func( random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True) assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" - tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {} + tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {} args = [] task_inputs_count = 0 for arg in build_res.args: @@ -1072,7 +1072,7 @@ def _timed_rpc_run( random_fill ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices" - tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {} + tensor_input_map = prepare_input_map(build_res.args) if task_input_names else {} args = [] task_inputs_count = 0 for arg in build_res.args: diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 68f53125c7ae..a6adf139820e 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -137,6 +137,11 @@ def extract_tasks( # When auto scheduler is used in end to end network, try to apply layout rewrite # to improve the overall performance layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True), + task_inputs=( + env.wkl_key_to_input_names[wkl_key] + if wkl_key in env.wkl_key_to_input_names + else None + ), ) ) weights.append(weight) @@ -161,6 +166,7 @@ def __init__(self, tracing_mode): self.tracing_mode = tracing_mode self.relay_disable_build_cache = "false" self.wkl_key_to_weight = {} + self.wkl_key_to_input_names = {} def __enter__(self): TracingEnvironment.current = self @@ -181,6 +187,10 @@ def add_workload_key(self, workload_key): self.wkl_key_to_weight[workload_key] = 0 self.wkl_key_to_weight[workload_key] += 1 + def add_workload_input_names(self, workload_key, input_names): + """""" + self.wkl_key_to_input_names[workload_key] = input_names + @tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite") def enter_layout_rewrite(): @@ -269,6 +279,9 @@ def auto_schedule_topi(outs): None in the tracing mode so that the fallback topi schedule will be used. 
""" # pylint: disable=import-outside-toplevel + from tvm.auto_scheduler.measure import ( + prepare_input_map, + ) # lazily import to avoid recursive dependency io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs) if not io_tensors: # The compute includes dynamic shapes which are not supported yet. @@ -300,6 +313,9 @@ def auto_schedule_topi(outs): # in the task extraction mode if has_complex_op or env.tracing_mode == TracingMode.EXTRACT_TASK: env.add_workload_key(key) + input_map = prepare_input_map(io_tensors) + if input_map: + env.add_workload_input_names(key, list(input_map.values())) elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: # in prepare_layout_rewrite mode if ( diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 57e239cf79e8..ad093272d4e9 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -175,6 +175,9 @@ def __init__( # "task_input_2": Tensor(...), # "task_input_3": Tensor(...) # }, +# "default": { +# "task_input_4": Tensor(...), +# }, # ... # } TASK_INPUT_BUFFER_TABLE = {} @@ -299,13 +302,17 @@ def get_task_input_buffer(workload_key, input_name): TASK_INPUT_BUFFER_TABLE[workload_key] = {} input_table = TASK_INPUT_BUFFER_TABLE[workload_key] - if input_name not in input_table.keys(): + if input_name not in input_table: # Try to load buffer data from local file tensor_from_file = _try_load_buffer_from_file(input_name) if tensor_from_file: input_table[input_name] = tensor_from_file - if input_name in input_table.keys(): + # Then check for the default table + if input_name not in input_table: + input_table = TASK_INPUT_BUFFER_TABLE["default"] + + if input_name in input_table: return input_table[input_name] raise ValueError( diff --git a/python/tvm/relay/analysis/sparse_dense.py b/python/tvm/relay/analysis/sparse_dense.py index d521748f2311..caffa6525093 100644 --- a/python/tvm/relay/analysis/sparse_dense.py +++ b/python/tvm/relay/analysis/sparse_dense.py @@ -73,6 +73,9 @@ def process_params(expr, params, block_size, sparsity_threshold): ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]] return names of qualified dense weight and the shape in BSR format """ + + from tvm.auto_scheduler.search_task import register_task_input_buffer # layzily load + memo = SparseAnalysisResult(weight_name=[], weight_shape=[]) weight_names = _search_dense_op_weight(expr) for name in weight_names: @@ -89,6 +92,25 @@ def process_params(expr, params, block_size, sparsity_threshold): + list(sparse_weight.indices.shape) + list(sparse_weight.indptr.shape) ) + + prefix = "sparse_dense_bsr_%d_%d_%d_%d_%.2f_" % ( + w_np.shape[0], + w_np.shape[1], + block_size[0], + block_size[1], + 1 - sparsity, + ) + + register_task_input_buffer( + "default", prefix + "W_data", tvm.runtime.ndarray.array(sparse_weight.data) + ) + register_task_input_buffer( + "default", prefix + "W_indices", tvm.runtime.ndarray.array(sparse_weight.indices) + ) + register_task_input_buffer( + "default", prefix + "W_indptr", tvm.runtime.ndarray.array(sparse_weight.indptr) + ) + params[name + ".data"] = tvm.nd.array(sparse_weight.data) params[name + ".indices"] = tvm.nd.array(sparse_weight.indices) params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 1bf18df09da3..04798f92b112 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -426,7 +426,7 @@ def 
_process_inputs(input_tensors, m, n, prefix_init): density *= i density /= k * n density = density.value - sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_" % (prefix_init, m, n, k, bs_r, bs_c, density) + sparse_prefix = "%s_%d_%d_%d_%d_%.2f_" % (prefix_init, n, k, bs_r, bs_c, density) visited = set() diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index 8526abbbe6ca..9faaa1a895c6 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -17,7 +17,8 @@ """ Auto-scheduling a Neural Network for x86 CPU ============================================ -**Author**: `Lianmin Zheng `_ +**Author**: `Lianmin Zheng `_, \ + `Chengfan Jia `_ Auto-tuning for specific devices and workloads is critical for getting the best performance. This is a tutorial on how to tune a whole neural @@ -44,10 +45,13 @@ __name__ == "__main__":` block. """ +import itertools import numpy as np +import scipy.sparse as sp import tvm from tvm import relay, auto_scheduler +from tvm.relay import data_dep_optimization as ddo import tvm.relay.testing from tvm.contrib import graph_runtime @@ -66,6 +70,46 @@ # You can use :ref:`ConvertLayout ` pass to do the layout conversion in TVM. +def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype="float32"): + Y = np.zeros((M, N), dtype=dtype) + assert M % BS_R == 0 + assert N % BS_C == 0 + nnz = int(density * M * N) + num_blocks = int(nnz / (BS_R * BS_C)) + 1 + candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C)))) + assert candidate_blocks.shape[0] == M // BS_R * N // BS_C + chosen_blocks = candidate_blocks[ + np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) + ] + for i in range(len(chosen_blocks)): + r, c = chosen_blocks[i] + Y[r : r + BS_R, c : c + BS_C] = np.random.uniform(-0.1, 0.1, (BS_R, BS_C)) + s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C)) + assert s.data.shape == (num_blocks, BS_R, BS_C) + assert s.data.size >= nnz + assert s.indices.shape == (num_blocks,) + assert s.indptr.shape == (M // BS_R + 1,) + return s.todense() + + +def random_sparse_params(func, params, density, BS_R, BS_C): + def deepcopy(param_dic): + ret = {} + for k, v in param_dic.items(): + ret[k] = tvm.nd.array(v.asnumpy()) + return ret + + new_params = deepcopy(params) + dense_weight_names = relay.analysis.sparse_dense._search_dense_op_weight(func) + for item in dense_weight_names: + name = str(item) + shape = new_params[name].shape + if shape[0] % BS_R == 0 and shape[1] % BS_C == 0: + new_w = random_bsr_matrix(shape[0], shape[1], BS_R, BS_C, density) + new_params[name] = tvm.nd.array(new_w) + return new_params + + def get_network(name, batch_size, layout="NHWC", dtype="float32"): """Get the symbol definition and random weight of a network""" @@ -126,6 +170,21 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs ) mod = tvm.IRModule.from_expr(net) + elif name == "mlp": + mod, params = relay.testing.mlp.get_workload( + batch_size=batch_size, dtype=dtype, image_shape=image_shape, num_classes=1000 + ) + elif name == "mlp-sparse": + bs_r = 1 + sparsity = 0.85 + + mod, params = relay.testing.mlp.get_workload( + batch_size=batch_size, dtype=dtype, image_shape=image_shape, num_classes=1000 + ) + mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params) + params = random_sparse_params(mod, params, BS_R=bs_r, BS_C=1, density=1 - sparsity) + mod, params = 
ddo.bsr_dense.convert(mod, params, (bs_r, 1), sparsity_threshold=0.8) + mod = tvm.IRModule.from_expr(mod) return mod, params, input_shape, output_shape @@ -183,7 +242,7 @@ def run_tuning(): print("Begin tuning...") tuner = auto_scheduler.TaskScheduler(tasks, task_weights) tune_option = auto_scheduler.TuningOptions( - num_measure_trials=200, # change this to 20000 to achieve the best performance + num_measure_trials=20, # change this to 20000 to achieve the best performance runner=auto_scheduler.LocalRunner(repeat=10, enable_cpu_cache_flush=True), measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) @@ -194,7 +253,7 @@ def run_tuning(): # We do not run the tuning in our webpage server since it takes too long. # Uncomment the following line to run it by yourself. -# run_tuning() +run_tuning() ###################################################################### diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index ced416f6c500..735531ea77e4 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -131,7 +131,7 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): target = tvm.target.Target("llvm") # Register the sparse data to task inputs -prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density) +prefix = "sparse_dense_bsr_%d_%d_%d_%d_%.2f_" % (K, N, BS_R, BS_C, density) task = tvm.auto_scheduler.SearchTask( func=sparse_dense, args=(M, N, K, W_sp_np.data.shape, W_sp_np.indices.shape, W_sp_np.indptr.shape, "float32"), From e477927901aa082dcc86c294faed201879b54ae4 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 11 Mar 2021 20:48:44 +0800 Subject: [PATCH 02/25] Update --- .../tvm/auto_scheduler/relay_integration.py | 15 +++- python/tvm/auto_scheduler/search_task.py | 6 +- python/tvm/relay/analysis/sparse_dense.py | 13 ++-- python/tvm/topi/nn/sparse.py | 27 +++++++ tutorials/auto_scheduler/tune_network_x86.py | 73 +++++++------------ tutorials/auto_scheduler/tune_sparse_x86.py | 33 ++------- 6 files changed, 81 insertions(+), 86 deletions(-) diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index a6adf139820e..a95b74b4d0b9 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -176,19 +176,28 @@ def __exit__(self, exc_type, exc_val, exc_tb): TracingEnvironment.current = None def add_workload_key(self, workload_key): - """Add the workload key of a search task + """Add the workload key of a search task. Parameters ---------- workload_key: str - The workload key of a task + The workload key of a task. """ if workload_key not in self.wkl_key_to_weight: self.wkl_key_to_weight[workload_key] = 0 self.wkl_key_to_weight[workload_key] += 1 def add_workload_input_names(self, workload_key, input_names): - """""" + """Add special task inputs to this workload. + + Parameters + ---------- + workload_key : str + The workload key of a task. + + input_names : List[str] + A list of input names. + """ self.wkl_key_to_input_names[workload_key] = input_names diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index ad093272d4e9..c5c2b5b44451 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -175,9 +175,6 @@ def __init__( # "task_input_2": Tensor(...), # "task_input_3": Tensor(...) 
# }, -# "default": { -# "task_input_4": Tensor(...), -# }, # ... # } TASK_INPUT_BUFFER_TABLE = {} @@ -308,7 +305,8 @@ def get_task_input_buffer(workload_key, input_name): if tensor_from_file: input_table[input_name] = tensor_from_file - # Then check for the default table + # Then check for the default table, the input names extracted from a relay model will be + # stored here for we're not able to get the workload_key at that time if input_name not in input_table: input_table = TASK_INPUT_BUFFER_TABLE["default"] diff --git a/python/tvm/relay/analysis/sparse_dense.py b/python/tvm/relay/analysis/sparse_dense.py index caffa6525093..23929f45917d 100644 --- a/python/tvm/relay/analysis/sparse_dense.py +++ b/python/tvm/relay/analysis/sparse_dense.py @@ -74,7 +74,10 @@ def process_params(expr, params, block_size, sparsity_threshold): return names of qualified dense weight and the shape in BSR format """ - from tvm.auto_scheduler.search_task import register_task_input_buffer # layzily load + # pylint: disable=import-outside-toplevel + from tvm.auto_scheduler.search_task import ( + register_task_input_buffer, + ) # lazily import to avoid recursive dependency memo = SparseAnalysisResult(weight_name=[], weight_shape=[]) weight_names = _search_dense_op_weight(expr) @@ -92,6 +95,9 @@ def process_params(expr, params, block_size, sparsity_threshold): + list(sparse_weight.indices.shape) + list(sparse_weight.indptr.shape) ) + params[name + ".data"] = tvm.nd.array(sparse_weight.data) + params[name + ".indices"] = tvm.nd.array(sparse_weight.indices) + params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr) prefix = "sparse_dense_bsr_%d_%d_%d_%d_%.2f_" % ( w_np.shape[0], @@ -100,7 +106,6 @@ def process_params(expr, params, block_size, sparsity_threshold): block_size[1], 1 - sparsity, ) - register_task_input_buffer( "default", prefix + "W_data", tvm.runtime.ndarray.array(sparse_weight.data) ) @@ -110,10 +115,6 @@ def process_params(expr, params, block_size, sparsity_threshold): register_task_input_buffer( "default", prefix + "W_indptr", tvm.runtime.ndarray.array(sparse_weight.indptr) ) - - params[name + ".data"] = tvm.nd.array(sparse_weight.data) - params[name + ".indices"] = tvm.nd.array(sparse_weight.indices) - params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr) ret = SparseAnalysisResult( weight_name=tvm.runtime.convert(memo.weight_name), weight_shape=tvm.runtime.convert(memo.weight_shape), diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 04798f92b112..029a16a6b836 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -468,3 +468,30 @@ def _traverse(t): sparse_input_map[sparse_indptr] = sparse_prefix + "W_indptr" return sparse_input_map + + +def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): + """ + """ + import numpy as np + import itertools + import scipy.sparse as sp + + Y = np.zeros((M, N), dtype=dtype) + assert M % BS_R == 0 + assert N % BS_C == 0 + nnz = int(density * M * N) + num_blocks = int(nnz / (BS_R * BS_C)) + 1 + candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C)))) + assert candidate_blocks.shape[0] == M // BS_R * N // BS_C + chosen_blocks = candidate_blocks[ + np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) + ] + for i in range(len(chosen_blocks)): + r, c = chosen_blocks[i] + Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C) + s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C)) + assert s.data.shape == (num_blocks, BS_R, BS_C) + assert 
s.indices.shape == (num_blocks,) + assert s.indptr.shape == (M // BS_R + 1,) + return s diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index 9faaa1a895c6..f2e5490a5ce5 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -45,13 +45,12 @@ __name__ == "__main__":` block. """ -import itertools import numpy as np -import scipy.sparse as sp import tvm from tvm import relay, auto_scheduler from tvm.relay import data_dep_optimization as ddo +from tvm.topi.nn.sparse import random_bsr_matrix import tvm.relay.testing from tvm.contrib import graph_runtime @@ -70,46 +69,6 @@ # You can use :ref:`ConvertLayout ` pass to do the layout conversion in TVM. -def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype="float32"): - Y = np.zeros((M, N), dtype=dtype) - assert M % BS_R == 0 - assert N % BS_C == 0 - nnz = int(density * M * N) - num_blocks = int(nnz / (BS_R * BS_C)) + 1 - candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C)))) - assert candidate_blocks.shape[0] == M // BS_R * N // BS_C - chosen_blocks = candidate_blocks[ - np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) - ] - for i in range(len(chosen_blocks)): - r, c = chosen_blocks[i] - Y[r : r + BS_R, c : c + BS_C] = np.random.uniform(-0.1, 0.1, (BS_R, BS_C)) - s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C)) - assert s.data.shape == (num_blocks, BS_R, BS_C) - assert s.data.size >= nnz - assert s.indices.shape == (num_blocks,) - assert s.indptr.shape == (M // BS_R + 1,) - return s.todense() - - -def random_sparse_params(func, params, density, BS_R, BS_C): - def deepcopy(param_dic): - ret = {} - for k, v in param_dic.items(): - ret[k] = tvm.nd.array(v.asnumpy()) - return ret - - new_params = deepcopy(params) - dense_weight_names = relay.analysis.sparse_dense._search_dense_op_weight(func) - for item in dense_weight_names: - name = str(item) - shape = new_params[name].shape - if shape[0] % BS_R == 0 and shape[1] % BS_C == 0: - new_w = random_bsr_matrix(shape[0], shape[1], BS_R, BS_C, density) - new_params[name] = tvm.nd.array(new_w) - return new_params - - def get_network(name, batch_size, layout="NHWC", dtype="float32"): """Get the symbol definition and random weight of a network""" @@ -175,6 +134,29 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): batch_size=batch_size, dtype=dtype, image_shape=image_shape, num_classes=1000 ) elif name == "mlp-sparse": + # This is a test workload that manually transforms a dense model to sparse + # Check `tutorials/frontend/deploy_sparse.py` for more examples on how to import a + # pretrained model + + def random_sparse_params(func, params, density, BS_R, BS_C): + def deepcopy(param_dic): + ret = {} + for k, v in param_dic.items(): + ret[k] = tvm.nd.array(v.asnumpy()) + return ret + + new_params = deepcopy(params) + dense_weight_names = relay.analysis.sparse_dense._search_dense_op_weight(func) + for item in dense_weight_names: + name = str(item) + shape = new_params[name].shape + if shape[0] % BS_R == 0 and shape[1] % BS_C == 0: + new_w = random_bsr_matrix( + shape[0], shape[1], BS_R, BS_C, density, "float32" + ).todense() + new_params[name] = tvm.nd.array(new_w) + return new_params + bs_r = 1 sparsity = 0.85 @@ -192,7 +174,8 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): # Define the neural network and compilation target. 
# If the target machine supports avx512 instructions, replace the # "llvm -mcpu=core-avx2" with "llvm -mcpu=skylake-avx512" -network = "resnet-50" +# network = "resnet-50" +network = "mlp-sparse" batch_size = 1 layout = "NHWC" target = tvm.target.Target("llvm -mcpu=core-avx2") @@ -242,7 +225,7 @@ def run_tuning(): print("Begin tuning...") tuner = auto_scheduler.TaskScheduler(tasks, task_weights) tune_option = auto_scheduler.TuningOptions( - num_measure_trials=20, # change this to 20000 to achieve the best performance + num_measure_trials=200, # change this to 20000 to achieve the best performance runner=auto_scheduler.LocalRunner(repeat=10, enable_cpu_cache_flush=True), measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) @@ -253,7 +236,7 @@ def run_tuning(): # We do not run the tuning in our webpage server since it takes too long. # Uncomment the following line to run it by yourself. -run_tuning() +# run_tuning() ###################################################################### diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index 735531ea77e4..bafd13d41a4a 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -36,15 +36,13 @@ """ import os -import itertools import numpy as np import tvm from tvm import te, auto_scheduler, runtime, topi from tvm.auto_scheduler import _ffi_api from tvm.topi.utils import get_const_tuple - -import scipy.sparse as sp +from tvm.topi.nn.sparse import random_bsr_matrix ###################################################################### # Define the computation @@ -53,29 +51,6 @@ # The function should return the list of input/output tensors. # From these tensors, the auto-scheduler can get the whole computational graph. -# We use this function to generate a random bsr matrix -def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): - import itertools - - Y = np.zeros((M, N), dtype=dtype) - assert M % BS_R == 0 - assert N % BS_C == 0 - nnz = int(density * M * N) - num_blocks = int(nnz / (BS_R * BS_C)) + 1 - candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C)))) - assert candidate_blocks.shape[0] == M // BS_R * N // BS_C - chosen_blocks = candidate_blocks[ - np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) - ] - for i in range(len(chosen_blocks)): - r, c = chosen_blocks[i] - Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C) - s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C)) - assert s.data.shape == (num_blocks, BS_R, BS_C) - assert s.indices.shape == (num_blocks,) - assert s.indptr.shape == (M // BS_R + 1,) - return s - @auto_scheduler.register_workload def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): @@ -104,7 +79,9 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): # See the `tvm.auto_scheduler.measure.py` for more details. 
# Define the basic shapes of this sparse computation -M = K = N = 512 +M = 128 +K = 256 +N = 512 BS_R = 16 BS_C = 1 density = 0.6 @@ -131,7 +108,7 @@ def sparse_dense(M, N, K, w_data_shape, w_indices_shape, w_indptr_shape, dtype): target = tvm.target.Target("llvm") # Register the sparse data to task inputs -prefix = "sparse_dense_bsr_%d_%d_%d_%d_%.2f_" % (K, N, BS_R, BS_C, density) +prefix = "sparse_dense_bsr_%d_%d_%d_%d_%.2f_" % (N, K, BS_R, BS_C, density) task = tvm.auto_scheduler.SearchTask( func=sparse_dense, args=(M, N, K, W_sp_np.data.shape, W_sp_np.indices.shape, W_sp_np.indptr.shape, "float32"), From 73b0346f92c64e3262aa1e7acd4b9ad50f4bd09b Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 11 Mar 2021 20:53:51 +0800 Subject: [PATCH 03/25] Update --- python/tvm/topi/nn/sparse.py | 2 -- tutorials/auto_scheduler/tune_network_x86.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 029a16a6b836..dd84e97331de 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -471,8 +471,6 @@ def _traverse(t): def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): - """ - """ import numpy as np import itertools import scipy.sparse as sp diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index f2e5490a5ce5..8ea042aa338e 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -174,8 +174,7 @@ def deepcopy(param_dic): # Define the neural network and compilation target. # If the target machine supports avx512 instructions, replace the # "llvm -mcpu=core-avx2" with "llvm -mcpu=skylake-avx512" -# network = "resnet-50" -network = "mlp-sparse" +network = "resnet-50" batch_size = 1 layout = "NHWC" target = tvm.target.Target("llvm -mcpu=core-avx2") From da3fb507cedfcc6c9eb500caafd776ed5a0238cb Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 11 Mar 2021 21:01:22 +0800 Subject: [PATCH 04/25] Lint fix --- python/tvm/topi/nn/sparse.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index dd84e97331de..8e3ed627c3c7 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -470,26 +470,33 @@ def _traverse(t): return sparse_input_map -def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): +def random_bsr_matrix(m, n, bs_r, bs_c, density, dtype): + """Generate a random sparse matrix in bsr format. 
+ + Returns + ------- + scipy.sparse.bsr_matrix + """ + # pylint: disable=import-outside-toplevel import numpy as np import itertools import scipy.sparse as sp - Y = np.zeros((M, N), dtype=dtype) - assert M % BS_R == 0 - assert N % BS_C == 0 - nnz = int(density * M * N) - num_blocks = int(nnz / (BS_R * BS_C)) + 1 - candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C)))) - assert candidate_blocks.shape[0] == M // BS_R * N // BS_C + y = np.zeros((m, n), dtype=dtype) + assert m % bs_r == 0 + assert n % bs_c == 0 + nnz = int(density * m * n) + num_blocks = int(nnz / (bs_r * bs_c)) + 1 + candidate_blocks = np.asarray(list(itertools.product(range(0, m, bs_r), range(0, n, bs_c)))) + assert candidate_blocks.shape[0] == m // bs_r * n // bs_c chosen_blocks = candidate_blocks[ np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) ] for i in range(len(chosen_blocks)): r, c = chosen_blocks[i] - Y[r : r + BS_R, c : c + BS_C] = np.random.randn(BS_R, BS_C) - s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C)) - assert s.data.shape == (num_blocks, BS_R, BS_C) + y[r : r + bs_r, c : c + bs_c] = np.random.randn(bs_r, bs_c) + s = sp.bsr_matrix(y, blocksize=(bs_r, bs_c)) + assert s.data.shape == (num_blocks, bs_r, bs_c) assert s.indices.shape == (num_blocks,) - assert s.indptr.shape == (M // BS_R + 1,) + assert s.indptr.shape == (m // bs_r + 1,) return s From b319eeb33b7ea44c56c947fba425c0652a70cdae Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 11 Mar 2021 21:08:20 +0800 Subject: [PATCH 05/25] Lint fix --- python/tvm/topi/nn/sparse.py | 3 +-- tutorials/auto_scheduler/tune_network_x86.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 8e3ed627c3c7..272c3a0f86f0 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -492,8 +492,7 @@ def random_bsr_matrix(m, n, bs_r, bs_c, density, dtype): chosen_blocks = candidate_blocks[ np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) ] - for i in range(len(chosen_blocks)): - r, c = chosen_blocks[i] + for (r, c) in chosen_blocks: y[r : r + bs_r, c : c + bs_c] = np.random.randn(bs_r, bs_c) s = sp.bsr_matrix(y, blocksize=(bs_r, bs_c)) assert s.data.shape == (num_blocks, bs_r, bs_c) diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index 8ea042aa338e..a1280c0dcdc6 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -136,7 +136,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): elif name == "mlp-sparse": # This is a test workload that manually transforms a dense model to sparse # Check `tutorials/frontend/deploy_sparse.py` for more examples on how to import a - # pretrained model + # pretrained model. 
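
The patches up to this point wire sparse weights through auto-scheduler's special task inputs:
process_params registers each BSR weight's data/indices/indptr into the "default" buffer table
under a "sparse_dense_bsr_*" name prefix, prepare_input_map maps placeholders carrying that
prefix back to registered buffers at measurement time, and get_task_input_buffer falls through
to the "default" table when a workload key has no entry of its own. A minimal sketch of that
flow, using only the APIs shown above; the shapes, density, and workload key below are
illustrative, not taken from the patches:

    import tvm
    from tvm.auto_scheduler.search_task import (
        register_task_input_buffer,
        get_task_input_buffer,
    )
    # At this point in the series random_bsr_matrix lives in tvm.topi.nn.sparse;
    # a later patch moves it to tvm.topi.sparse.utils.
    from tvm.topi.nn.sparse import random_bsr_matrix

    # Illustrative sizes only: N x K weight, 16x1 BSR blocks, 15% density.
    N, K, BS_R, BS_C, density = 512, 256, 16, 1, 0.15
    w_sp = random_bsr_matrix(N, K, BS_R, BS_C, density, "float32")

    # Same name-prefix format used by topi.nn.sparse._process_inputs and
    # relay.analysis.sparse_dense.process_params.
    prefix = "sparse_dense_bsr_%d_%d_%d_%d_%.2f_" % (N, K, BS_R, BS_C, density)
    register_task_input_buffer("default", prefix + "W_data", tvm.nd.array(w_sp.data))
    register_task_input_buffer("default", prefix + "W_indices", tvm.nd.array(w_sp.indices))
    register_task_input_buffer("default", prefix + "W_indptr", tvm.nd.array(w_sp.indptr))

    # At measurement time the placeholder names recorded by prepare_input_map are resolved
    # here; with no per-task entry, the lookup falls back to the "default" table.
    # "hypothetical_workload_key" is a made-up key for illustration.
    w_data = get_task_input_buffer("hypothetical_workload_key", prefix + "W_data")
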
def random_sparse_params(func, params, density, BS_R, BS_C): def deepcopy(param_dic): From c62079c8fee1010148003074a53a7c8e195625eb Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Thu, 11 Mar 2021 21:11:40 +0800 Subject: [PATCH 06/25] Lint fix --- python/tvm/topi/nn/sparse.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 272c3a0f86f0..d1e40c6fbe67 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -492,6 +492,7 @@ def random_bsr_matrix(m, n, bs_r, bs_c, density, dtype): chosen_blocks = candidate_blocks[ np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) ] + # pylint: disable=invalid-name for (r, c) in chosen_blocks: y[r : r + bs_r, c : c + bs_c] = np.random.randn(bs_r, bs_c) s = sp.bsr_matrix(y, blocksize=(bs_r, bs_c)) From 0206afe7e767f8aed0a59593e57564be3df7ee41 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 12 Mar 2021 11:29:06 +0800 Subject: [PATCH 07/25] Add sparse tuning for arm network --- python/tvm/relay/op/strategy/arm_cpu.py | 9 ++++++ python/tvm/topi/nn/conv2d.py | 2 ++ python/tvm/topi/nn/dense.py | 2 ++ tutorials/auto_scheduler/tune_network_arm.py | 30 +++++++++++++------- 4 files changed, 32 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 985124e305ee..9bcde9840876 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -20,6 +20,7 @@ import logging from tvm import topi +from tvm.auto_scheduler import is_auto_scheduler_enabled from ....target import arm_isa from .generic import * from .. import op as _op @@ -127,6 +128,14 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): name="conv2d_hwcn.generic", ) elif layout == "NHWC": + if is_auto_scheduler_enabled(): + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), + naive_schedule, + name="conv2d_nhwc.arm_cpu", + plevel=100, + ) + channels = data.shape[3] if "SMLAD" in isa and (channels % 4) == 0 and kernel_layout == "HWOI": strategy.add_implementation( diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 80f87f86736c..4be8b9eb310f 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -410,6 +410,8 @@ def conv2d_nhwc( else: dilation_h, dilation_w = dilation + print("Conv2d", Filter.shape) + if auto_scheduler_rewritten_layout: # Infer shape for the rewritten layout kernel_h, kernel_w, channel, num_filter = auto_scheduler.get_shape_from_rewritten_layout( diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py index e8ec476b86a5..af2042f5ab59 100644 --- a/python/tvm/topi/nn/dense.py +++ b/python/tvm/topi/nn/dense.py @@ -53,6 +53,8 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo out_dtype = data.dtype batch, in_dim = data.shape + print("Dense", weight.shape) + if auto_scheduler_rewritten_layout: # Infer shape for the rewritten layout out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout( diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py index c4add79450e9..78b7954a4b7b 100644 --- a/tutorials/auto_scheduler/tune_network_arm.py +++ b/tutorials/auto_scheduler/tune_network_arm.py @@ -45,6 +45,7 @@ """ import numpy as np +import os import tvm from tvm import relay, auto_scheduler @@ -215,15 +216,19 @@ def get_network(name, batch_size, 
layout="NHWC", dtype="float32"): # This target is used for cross compilation. You can query it by :code:`gcc -v` on your device. # FIXME(tmoreau89, merrymercy): We leave '-device=arm_cpu' out of the target string # because we're sharing x86 op strategy. -target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+neon") +# target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+neon") -# Also replace this with the device key in your tracker -device_key = "rasp4b-64" +target = tvm.target.arm_cpu("pixel2") + +# Also replace this with the device key, rpc host and rpc port in your tracker +device_key = "pixel2" +rpc_host = "11.164.101.214" +rpc_port = 9190 # Set this to True if you use ndk tools for cross compiling # And also set the environment variable below to point to the cross compiler -use_ndk = False -# os.environ["TVM_NDK_CC"] = "/usr/bin/aarch64-linux-gnu-g++" +use_ndk = True +os.environ["TVM_NDK_CC"] = "/Users/jcf/Workspace/tvm_workspace/arm/android-ndk-r21d/build/tools/android-toolchain-arm64/bin/aarch64-linux-android-g++" #### TUNING OPTION #### network = "mobilenet" @@ -279,11 +284,14 @@ def tune_and_evaluate(): print("Begin tuning...") tuner = auto_scheduler.TaskScheduler(tasks, task_weights) tune_option = auto_scheduler.TuningOptions( - num_measure_trials=200, # change this to 20000 to achieve the best performance + num_measure_trials=23, # change this to 20000 to achieve the best performance + builder=auto_scheduler.LocalBuilder( + build_func="ndk" if use_ndk else "default" + ), runner=auto_scheduler.RPCRunner( device_key, - host="0.0.0.0", - port=9191, + host=rpc_host, + port=rpc_port, timeout=30, repeat=1, min_repeat_ms=200, @@ -292,7 +300,7 @@ def tune_and_evaluate(): measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) - tuner.tune(tune_option) + # tuner.tune(tune_option) # Compile with the history best print("Compile...") @@ -315,7 +323,7 @@ def tune_and_evaluate(): # Upload module to device print("Upload...") - remote = auto_scheduler.utils.request_remote(device_key, "0.0.0.0", 9191, timeout=10000) + remote = auto_scheduler.utils.request_remote(device_key, rpc_host, rpc_port, timeout=10000) remote.upload(tmp.relpath(filename)) rlib = remote.load_module(filename) @@ -338,7 +346,7 @@ def tune_and_evaluate(): # or device tracker running. # Uncomment the following line to run it by yourself. -# tune_and_evaluate() +tune_and_evaluate() ###################################################################### From 7a708dc63f957990dd9eb106c7aecdda60e6f30c Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 12 Mar 2021 11:55:46 +0800 Subject: [PATCH 08/25] Add sparse support for arm network --- tutorials/auto_scheduler/tune_network_arm.py | 66 +++++++++++++++++--- tutorials/auto_scheduler/tune_network_x86.py | 28 +++++---- 2 files changed, 74 insertions(+), 20 deletions(-) diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py index 78b7954a4b7b..0db6b26ef489 100644 --- a/tutorials/auto_scheduler/tune_network_arm.py +++ b/tutorials/auto_scheduler/tune_network_arm.py @@ -17,7 +17,9 @@ """ Auto-scheduling a Neural Network for ARM CPU ============================================= -**Author**: `Thierry Moreau >`_ +**Author**: `Thierry Moreau _`, \ + `Lianmin Zheng _`, \ + `Chengfan Jia `_ Auto-tuning for specific devices and workloads is critical for getting the best performance. 
This is a tutorial on how to tune a whole neural @@ -49,6 +51,8 @@ import tvm from tvm import relay, auto_scheduler +from tvm.relay import data_dep_optimization as ddo +from tvm.topi.nn.sparse import random_bsr_matrix import tvm.relay.testing from tvm.contrib import graph_runtime from tvm.contrib.utils import tempdir @@ -68,7 +72,7 @@ # You can use :ref:`ConvertLayout ` pass to do the layout conversion in TVM. -def get_network(name, batch_size, layout="NHWC", dtype="float32"): +def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=False): """Get the symbol definition and random weight of a network""" # auto-scheduler prefers NHWC layout @@ -128,6 +132,46 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs ) mod = tvm.IRModule.from_expr(net) + elif name == "mlp": + mod, params = relay.testing.mlp.get_workload( + batch_size=batch_size, dtype=dtype, image_shape=image_shape, num_classes=1000 + ) + else: + raise ValueError("Network not found.") + + if use_sparse: + # This is a test workload that manually transforms a dense model to sparse + # Check `tutorials/frontend/deploy_sparse.py` for more examples on how to import a + # pretrained model. + + def random_sparse_dense_params(func, params, density, BS_R, BS_C): + def deepcopy(param_dic): + ret = {} + for k, v in param_dic.items(): + ret[k] = tvm.nd.array(v.asnumpy()) + return ret + + new_params = deepcopy(params) + dense_weight_names = relay.analysis.sparse_dense._search_dense_op_weight(func) + for item in dense_weight_names: + name = str(item) + shape = new_params[name].shape + if shape[0] % BS_R == 0 and shape[1] % BS_C == 0: + new_w = random_bsr_matrix( + shape[0], shape[1], BS_R, BS_C, density, "float32" + ).todense() + new_params[name] = tvm.nd.array(new_w) + return new_params + + bs_r = 1 + sparsity = 0.85 + + # Currently we only support to conver dense matmul to sparse dense matmul + mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params) + params = random_sparse_dense_params(mod, params, BS_R=bs_r, BS_C=1, density=1 - sparsity) + mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, 1), sparsity_threshold=0.8) + + mod = tvm.IRModule.from_expr(mod) return mod, params, input_shape, output_shape @@ -228,10 +272,13 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): # Set this to True if you use ndk tools for cross compiling # And also set the environment variable below to point to the cross compiler use_ndk = True -os.environ["TVM_NDK_CC"] = "/Users/jcf/Workspace/tvm_workspace/arm/android-ndk-r21d/build/tools/android-toolchain-arm64/bin/aarch64-linux-android-g++" +os.environ[ + "TVM_NDK_CC" +] = "/Users/jcf/Workspace/tvm_workspace/arm/android-ndk-r21d/build/tools/android-toolchain-arm64/bin/aarch64-linux-android-g++" #### TUNING OPTION #### network = "mobilenet" +use_sparse = False batch_size = 1 layout = "NHWC" dtype = "float32" @@ -249,8 +296,11 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): # The task scheduler will just optimize this objective. 
# Extract tasks from the network +print("Get model...") +mod, params, input_shape, output_shape = get_network( + network, batch_size, layout, dtype=dtype, use_sparse=use_sparse +) print("Extract tasks...") -mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype) tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) for idx, task in enumerate(tasks): @@ -284,10 +334,8 @@ def tune_and_evaluate(): print("Begin tuning...") tuner = auto_scheduler.TaskScheduler(tasks, task_weights) tune_option = auto_scheduler.TuningOptions( - num_measure_trials=23, # change this to 20000 to achieve the best performance - builder=auto_scheduler.LocalBuilder( - build_func="ndk" if use_ndk else "default" - ), + num_measure_trials=len(tasks), # change this to 20000 to achieve the best performance + builder=auto_scheduler.LocalBuilder(build_func="ndk" if use_ndk else "default"), runner=auto_scheduler.RPCRunner( device_key, host=rpc_host, @@ -300,7 +348,7 @@ def tune_and_evaluate(): measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) - # tuner.tune(tune_option) + tuner.tune(tune_option) # Compile with the history best print("Compile...") diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index a1280c0dcdc6..9a2de695ccd1 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -69,7 +69,7 @@ # You can use :ref:`ConvertLayout ` pass to do the layout conversion in TVM. -def get_network(name, batch_size, layout="NHWC", dtype="float32"): +def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=False): """Get the symbol definition and random weight of a network""" # auto-scheduler prefers NHWC layout @@ -133,12 +133,15 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"): mod, params = relay.testing.mlp.get_workload( batch_size=batch_size, dtype=dtype, image_shape=image_shape, num_classes=1000 ) - elif name == "mlp-sparse": + else: + raise ValueError("Network not found.") + + if use_sparse: # This is a test workload that manually transforms a dense model to sparse # Check `tutorials/frontend/deploy_sparse.py` for more examples on how to import a # pretrained model. - def random_sparse_params(func, params, density, BS_R, BS_C): + def random_sparse_dense_params(func, params, density, BS_R, BS_C): def deepcopy(param_dic): ret = {} for k, v in param_dic.items(): @@ -160,12 +163,11 @@ def deepcopy(param_dic): bs_r = 1 sparsity = 0.85 - mod, params = relay.testing.mlp.get_workload( - batch_size=batch_size, dtype=dtype, image_shape=image_shape, num_classes=1000 - ) + # Currently we only support to conver dense matmul to sparse dense matmul mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params) - params = random_sparse_params(mod, params, BS_R=bs_r, BS_C=1, density=1 - sparsity) + params = random_sparse_dense_params(mod, params, BS_R=bs_r, BS_C=1, density=1 - sparsity) mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, 1), sparsity_threshold=0.8) + mod = tvm.IRModule.from_expr(mod) return mod, params, input_shape, output_shape @@ -174,7 +176,8 @@ def deepcopy(param_dic): # Define the neural network and compilation target. 
# If the target machine supports avx512 instructions, replace the # "llvm -mcpu=core-avx2" with "llvm -mcpu=skylake-avx512" -network = "resnet-50" +network = "mobilenet" +use_sparse = True batch_size = 1 layout = "NHWC" target = tvm.target.Target("llvm -mcpu=core-avx2") @@ -193,8 +196,11 @@ def deepcopy(param_dic): # The task scheduler will just optimize this objective. # Extract tasks from the network +print("Get model...") +mod, params, input_shape, output_shape = get_network( + network, batch_size, layout, dtype=dtype, use_sparse=use_sparse +) print("Extract tasks...") -mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype) tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) for idx, task in enumerate(tasks): @@ -224,7 +230,7 @@ def run_tuning(): print("Begin tuning...") tuner = auto_scheduler.TaskScheduler(tasks, task_weights) tune_option = auto_scheduler.TuningOptions( - num_measure_trials=200, # change this to 20000 to achieve the best performance + num_measure_trials=len(tasks), # change this to 20000 to achieve the best performance runner=auto_scheduler.LocalRunner(repeat=10, enable_cpu_cache_flush=True), measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) @@ -235,7 +241,7 @@ def run_tuning(): # We do not run the tuning in our webpage server since it takes too long. # Uncomment the following line to run it by yourself. -# run_tuning() +run_tuning() ###################################################################### From 5d0cc86a65c7ae773a4d4f8142d0ab0b23634d68 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 12 Mar 2021 12:57:03 +0800 Subject: [PATCH 09/25] Update --- .../tvm/auto_scheduler/relay_integration.py | 1 + python/tvm/topi/nn/conv2d.py | 2 -- python/tvm/topi/nn/dense.py | 2 -- tutorials/auto_scheduler/tune_network_arm.py | 20 ++++++++----------- tutorials/auto_scheduler/tune_network_x86.py | 6 +++--- 5 files changed, 12 insertions(+), 19 deletions(-) diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 80399841179b..e931fc6e298d 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -147,6 +147,7 @@ def extract_tasks( if wkl_key in env.wkl_key_to_input_names else None ), + task_inputs_save_to_file=True, ) ) weights.append(weight) diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 4be8b9eb310f..80f87f86736c 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -410,8 +410,6 @@ def conv2d_nhwc( else: dilation_h, dilation_w = dilation - print("Conv2d", Filter.shape) - if auto_scheduler_rewritten_layout: # Infer shape for the rewritten layout kernel_h, kernel_w, channel, num_filter = auto_scheduler.get_shape_from_rewritten_layout( diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py index af2042f5ab59..e8ec476b86a5 100644 --- a/python/tvm/topi/nn/dense.py +++ b/python/tvm/topi/nn/dense.py @@ -53,8 +53,6 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo out_dtype = data.dtype batch, in_dim = data.shape - print("Dense", weight.shape) - if auto_scheduler_rewritten_layout: # Infer shape for the rewritten layout out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout( diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py index 0db6b26ef489..e70406682521 100644 --- a/tutorials/auto_scheduler/tune_network_arm.py 
+++ b/tutorials/auto_scheduler/tune_network_arm.py @@ -260,21 +260,17 @@ def deepcopy(param_dic): # This target is used for cross compilation. You can query it by :code:`gcc -v` on your device. # FIXME(tmoreau89, merrymercy): We leave '-device=arm_cpu' out of the target string # because we're sharing x86 op strategy. -# target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+neon") - -target = tvm.target.arm_cpu("pixel2") +target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+neon") # Also replace this with the device key, rpc host and rpc port in your tracker -device_key = "pixel2" -rpc_host = "11.164.101.214" -rpc_port = 9190 +device_key = "rasp4b-64" +rpc_host = "0.0.0.0" +rpc_port = 9191 # Set this to True if you use ndk tools for cross compiling # And also set the environment variable below to point to the cross compiler -use_ndk = True -os.environ[ - "TVM_NDK_CC" -] = "/Users/jcf/Workspace/tvm_workspace/arm/android-ndk-r21d/build/tools/android-toolchain-arm64/bin/aarch64-linux-android-g++" +use_ndk = False +# os.environ["TVM_NDK_CC"] = "/usr/bin/aarch64-linux-gnu-g++" #### TUNING OPTION #### network = "mobilenet" @@ -334,7 +330,7 @@ def tune_and_evaluate(): print("Begin tuning...") tuner = auto_scheduler.TaskScheduler(tasks, task_weights) tune_option = auto_scheduler.TuningOptions( - num_measure_trials=len(tasks), # change this to 20000 to achieve the best performance + num_measure_trials=200, # change this to 20000 to achieve the best performance builder=auto_scheduler.LocalBuilder(build_func="ndk" if use_ndk else "default"), runner=auto_scheduler.RPCRunner( device_key, @@ -394,7 +390,7 @@ def tune_and_evaluate(): # or device tracker running. # Uncomment the following line to run it by yourself. -tune_and_evaluate() +# tune_and_evaluate() ###################################################################### diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index 9a2de695ccd1..88a9f267c3b5 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -176,8 +176,8 @@ def deepcopy(param_dic): # Define the neural network and compilation target. # If the target machine supports avx512 instructions, replace the # "llvm -mcpu=core-avx2" with "llvm -mcpu=skylake-avx512" -network = "mobilenet" -use_sparse = True +network = "resnet-50" +use_sparse = False batch_size = 1 layout = "NHWC" target = tvm.target.Target("llvm -mcpu=core-avx2") @@ -241,7 +241,7 @@ def run_tuning(): # We do not run the tuning in our webpage server since it takes too long. # Uncomment the following line to run it by yourself. 
-run_tuning() +# run_tuning() ###################################################################### From 77941d10c6a5f535e50d23499ddcad11d9cee77c Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 12 Mar 2021 12:59:29 +0800 Subject: [PATCH 10/25] Update --- tutorials/auto_scheduler/tune_network_x86.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index 88a9f267c3b5..4d2e6182c32f 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -230,7 +230,7 @@ def run_tuning(): print("Begin tuning...") tuner = auto_scheduler.TaskScheduler(tasks, task_weights) tune_option = auto_scheduler.TuningOptions( - num_measure_trials=len(tasks), # change this to 20000 to achieve the best performance + num_measure_trials=200, # change this to 20000 to achieve the best performance runner=auto_scheduler.LocalRunner(repeat=10, enable_cpu_cache_flush=True), measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) From cb47cda8b11a01d385e53ea76b37a061c067fa53 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 12 Mar 2021 13:13:23 +0800 Subject: [PATCH 11/25] Update --- python/tvm/relay/op/strategy/arm_cpu.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 9bcde9840876..b32bdfb66207 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -128,14 +128,6 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): name="conv2d_hwcn.generic", ) elif layout == "NHWC": - if is_auto_scheduler_enabled(): - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), - naive_schedule, - name="conv2d_nhwc.arm_cpu", - plevel=100, - ) - channels = data.shape[3] if "SMLAD" in isa and (channels % 4) == 0 and kernel_layout == "HWOI": strategy.add_implementation( @@ -144,6 +136,14 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): name="conv2d_direct_simd.micro_dev", ) elif kernel_layout == "HWIO": + if is_auto_scheduler_enabled(): + strategy.add_implementation( + wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), + naive_schedule, + name="conv2d_nhwc.arm_cpu", + plevel=100, + ) + is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm() has_dot_prod = topi.arm_cpu.arm_utils.is_dotprod_available() if has_dot_prod and data.dtype in ["int8", "uint8"]: From 5379674bdaabba947cfcd019b3e642c9c6ca5f83 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 12 Mar 2021 17:03:26 +0800 Subject: [PATCH 12/25] Bug fix for tflite frontend dense with layout rewrite --- python/tvm/relay/frontend/tflite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 1b593ad8dea3..9deb8cd35b33 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1872,7 +1872,7 @@ def convert_fully_connected(self, op): out_dtype="int32", ) else: - out = _op.nn.dense(in_expr, weight_expr) + out = _op.nn.dense(in_expr, weight_expr, units=weight_shape[0]) # if we have bias if len(input_tensors) == 3: From 8f4fc1d7d04ac4fab693eff231e5359d9175ad18 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 12 Mar 2021 17:41:48 +0800 Subject: [PATCH 13/25] Move the random_bsr_matrix to sparse.utils --- 
python/tvm/topi/nn/sparse.py | 32 ------------- python/tvm/topi/sparse/utils.py | 49 ++++++++++++++++++++ tutorials/auto_scheduler/tune_network_arm.py | 2 +- tutorials/auto_scheduler/tune_network_x86.py | 2 +- tutorials/auto_scheduler/tune_sparse_x86.py | 2 +- 5 files changed, 52 insertions(+), 35 deletions(-) create mode 100644 python/tvm/topi/sparse/utils.py diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 3dc006620dc0..f5737d087fc7 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -470,38 +470,6 @@ def _traverse(t): return sparse_input_map -def random_bsr_matrix(m, n, bs_r, bs_c, density, dtype): - """Generate a random sparse matrix in bsr format. - - Returns - ------- - scipy.sparse.bsr_matrix - """ - # pylint: disable=import-outside-toplevel - import numpy as np - import itertools - import scipy.sparse as sp - - y = np.zeros((m, n), dtype=dtype) - assert m % bs_r == 0 - assert n % bs_c == 0 - nnz = int(density * m * n) - num_blocks = int(nnz / (bs_r * bs_c)) + 1 - candidate_blocks = np.asarray(list(itertools.product(range(0, m, bs_r), range(0, n, bs_c)))) - assert candidate_blocks.shape[0] == m // bs_r * n // bs_c - chosen_blocks = candidate_blocks[ - np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) - ] - # pylint: disable=invalid-name - for (r, c) in chosen_blocks: - y[r : r + bs_r, c : c + bs_c] = np.random.randn(bs_r, bs_c) - s = sp.bsr_matrix(y, blocksize=(bs_r, bs_c)) - assert s.data.shape == (num_blocks, bs_r, bs_c) - assert s.indices.shape == (num_blocks,) - assert s.indptr.shape == (m // bs_r + 1,) - return s - - def sparse_add(dense_data, sparse_data, sparse_indices, sparse_indptr): """ Computes sparse-dense addition diff --git a/python/tvm/topi/sparse/utils.py b/python/tvm/topi/sparse/utils.py new file mode 100644 index 000000000000..a1db6fc12623 --- /dev/null +++ b/python/tvm/topi/sparse/utils.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Some utils for Sparse operation.""" + + +def random_bsr_matrix(m, n, bs_r, bs_c, density, dtype): + """Generate a random sparse matrix in bsr format. 
+ + Returns + ------- + scipy.sparse.bsr_matrix + """ + # pylint: disable=import-outside-toplevel + import numpy as np + import itertools + import scipy.sparse as sp + + y = np.zeros((m, n), dtype=dtype) + assert m % bs_r == 0 + assert n % bs_c == 0 + nnz = int(density * m * n) + num_blocks = int(nnz / (bs_r * bs_c)) + 1 + candidate_blocks = np.asarray(list(itertools.product(range(0, m, bs_r), range(0, n, bs_c)))) + assert candidate_blocks.shape[0] == m // bs_r * n // bs_c + chosen_blocks = candidate_blocks[ + np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False) + ] + # pylint: disable=invalid-name + for (r, c) in chosen_blocks: + y[r : r + bs_r, c : c + bs_c] = np.random.randn(bs_r, bs_c) + s = sp.bsr_matrix(y, blocksize=(bs_r, bs_c)) + assert s.data.shape == (num_blocks, bs_r, bs_c) + assert s.indices.shape == (num_blocks,) + assert s.indptr.shape == (m // bs_r + 1,) + return s diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py index e70406682521..418015b93475 100644 --- a/tutorials/auto_scheduler/tune_network_arm.py +++ b/tutorials/auto_scheduler/tune_network_arm.py @@ -52,7 +52,7 @@ import tvm from tvm import relay, auto_scheduler from tvm.relay import data_dep_optimization as ddo -from tvm.topi.nn.sparse import random_bsr_matrix +from tvm.topi.sparse.utils import random_bsr_matrix import tvm.relay.testing from tvm.contrib import graph_runtime from tvm.contrib.utils import tempdir diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index 4d2e6182c32f..04ae9cc28dc9 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -50,7 +50,7 @@ import tvm from tvm import relay, auto_scheduler from tvm.relay import data_dep_optimization as ddo -from tvm.topi.nn.sparse import random_bsr_matrix +from tvm.topi.sparse.utils import random_bsr_matrix import tvm.relay.testing from tvm.contrib import graph_runtime diff --git a/tutorials/auto_scheduler/tune_sparse_x86.py b/tutorials/auto_scheduler/tune_sparse_x86.py index bafd13d41a4a..ad3646dfc19d 100644 --- a/tutorials/auto_scheduler/tune_sparse_x86.py +++ b/tutorials/auto_scheduler/tune_sparse_x86.py @@ -42,7 +42,7 @@ from tvm import te, auto_scheduler, runtime, topi from tvm.auto_scheduler import _ffi_api from tvm.topi.utils import get_const_tuple -from tvm.topi.nn.sparse import random_bsr_matrix +from tvm.topi.sparse.utils import random_bsr_matrix ###################################################################### # Define the computation From 46498fb9c79e7c2ddcaf6322c27275444e168ab1 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 12 Mar 2021 18:28:48 +0800 Subject: [PATCH 14/25] Update --- python/tvm/relay/op/strategy/arm_cpu.py | 28 ++++++++++++++++++++++++- python/tvm/relay/op/strategy/x86.py | 2 +- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index b32bdfb66207..468a837ce479 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -143,6 +143,30 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): name="conv2d_nhwc.arm_cpu", plevel=100, ) + judge_winograd_auto_scheduler = False + if len(kernel.shape) == 4: + kernel_h, kernel_w, _, co = get_const_tuple(kernel.shape) + judge_winograd_auto_scheduler = ( + "float" in data.dtype + and "float" in kernel.dtype + and kernel_h == 3 + and kernel_w == 3 + and 
stride_h == 1 + and stride_w == 1 + and dilation_h == 1 + and dilation_w == 1 + and 64 < co < 512 + ) + # register auto-scheduler implementations + if judge_winograd_auto_scheduler: + strategy.add_implementation( + wrap_compute_conv2d( + topi.nn.conv2d_winograd_nhwc, need_auto_scheduler_layout=True + ), + naive_schedule, # this implementation should never be picked by autotvm + name="conv2d_nhwc.winograd.arm_cpu", + plevel=101, + ) is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm() has_dot_prod = topi.arm_cpu.arm_utils.is_dotprod_available() @@ -207,7 +231,9 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): elif layout == "NHWC": assert kernel_layout == "HWOI" strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc), + wrap_compute_conv2d( + topi.arm_cpu.compute_depthwise_conv2d_nhwc, need_auto_scheduler_layout=True + ), wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.arm_cpu", ) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 1f37a4f8e98c..6b65dca495d9 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -196,7 +196,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): "depthwise_conv2d NHWC layout is not optimized for x86 with autotvm." ) strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_auto_scheduler_layout=True), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.generic", ) From 03c455c4e3c499026d2f2bdf984e06d80472a54d Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 12 Mar 2021 18:29:18 +0800 Subject: [PATCH 15/25] Update --- python/tvm/relay/op/strategy/arm_cpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 468a837ce479..e3a0c1977536 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -155,7 +155,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): and stride_w == 1 and dilation_h == 1 and dilation_w == 1 - and 64 < co < 512 + and 64 <= co < 512 ) # register auto-scheduler implementations if judge_winograd_auto_scheduler: From 373020b778eed2ee54cdd1cab3cc8ab9cb393af1 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 12 Mar 2021 22:18:06 +0800 Subject: [PATCH 16/25] Bug fix --- python/tvm/relay/op/strategy/arm_cpu.py | 4 +--- python/tvm/relay/op/strategy/x86.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index e3a0c1977536..cd83d7918a18 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -231,9 +231,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): elif layout == "NHWC": assert kernel_layout == "HWOI" strategy.add_implementation( - wrap_compute_conv2d( - topi.arm_cpu.compute_depthwise_conv2d_nhwc, need_auto_scheduler_layout=True - ), + wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc), wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.arm_cpu", ) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 6b65dca495d9..1f37a4f8e98c 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ 
b/python/tvm/relay/op/strategy/x86.py @@ -196,7 +196,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): "depthwise_conv2d NHWC layout is not optimized for x86 with autotvm." ) strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_auto_scheduler_layout=True), + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.generic", ) From 5325b521b16f35fcfcba18c3aab75cba4780cd9c Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 24 Mar 2021 10:12:47 +0800 Subject: [PATCH 17/25] Remove the modification of ARM CPU --- python/tvm/relay/op/strategy/arm_cpu.py | 33 ------------------------- 1 file changed, 33 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index cd83d7918a18..985124e305ee 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -20,7 +20,6 @@ import logging from tvm import topi -from tvm.auto_scheduler import is_auto_scheduler_enabled from ....target import arm_isa from .generic import * from .. import op as _op @@ -136,38 +135,6 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): name="conv2d_direct_simd.micro_dev", ) elif kernel_layout == "HWIO": - if is_auto_scheduler_enabled(): - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), - naive_schedule, - name="conv2d_nhwc.arm_cpu", - plevel=100, - ) - judge_winograd_auto_scheduler = False - if len(kernel.shape) == 4: - kernel_h, kernel_w, _, co = get_const_tuple(kernel.shape) - judge_winograd_auto_scheduler = ( - "float" in data.dtype - and "float" in kernel.dtype - and kernel_h == 3 - and kernel_w == 3 - and stride_h == 1 - and stride_w == 1 - and dilation_h == 1 - and dilation_w == 1 - and 64 <= co < 512 - ) - # register auto-scheduler implementations - if judge_winograd_auto_scheduler: - strategy.add_implementation( - wrap_compute_conv2d( - topi.nn.conv2d_winograd_nhwc, need_auto_scheduler_layout=True - ), - naive_schedule, # this implementation should never be picked by autotvm - name="conv2d_nhwc.winograd.arm_cpu", - plevel=101, - ) - is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm() has_dot_prod = topi.arm_cpu.arm_utils.is_dotprod_available() if has_dot_prod and data.dtype in ["int8", "uint8"]: From c03231955471be7508fcbb7b4a54a518f6d28e7a Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 24 Mar 2021 10:39:09 +0800 Subject: [PATCH 18/25] Update --- python/tvm/topi/sparse/utils.py | 21 ++++++++++++ tutorials/auto_scheduler/tune_network_arm.py | 35 ++++---------------- tutorials/auto_scheduler/tune_network_x86.py | 35 ++++---------------- 3 files changed, 33 insertions(+), 58 deletions(-) diff --git a/python/tvm/topi/sparse/utils.py b/python/tvm/topi/sparse/utils.py index a1db6fc12623..39edeb01d70f 100644 --- a/python/tvm/topi/sparse/utils.py +++ b/python/tvm/topi/sparse/utils.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. 
"""Some utils for Sparse operation.""" +import tvm +from tvm import relay +from tvm.relay import data_dep_optimization as ddo def random_bsr_matrix(m, n, bs_r, bs_c, density, dtype): @@ -47,3 +50,21 @@ def random_bsr_matrix(m, n, bs_r, bs_c, density, dtype): assert s.indices.shape == (num_blocks,) assert s.indptr.shape == (m // bs_r + 1,) return s + + +def random_sparse_dense_params(func, params, density, BS_R, BS_C): + def deepcopy(param_dic): + ret = {} + for k, v in param_dic.items(): + ret[k] = tvm.nd.array(v.asnumpy()) + return ret + + new_params = deepcopy(params) + dense_weight_names = relay.analysis.sparse_dense._search_dense_op_weight(func) + for item in dense_weight_names: + name = str(item) + shape = new_params[name].shape + if shape[0] % BS_R == 0 and shape[1] % BS_C == 0: + new_w = random_bsr_matrix(shape[0], shape[1], BS_R, BS_C, density, "float32").todense() + new_params[name] = tvm.nd.array(new_w) + return new_params diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py index 418015b93475..74976813921a 100644 --- a/tutorials/auto_scheduler/tune_network_arm.py +++ b/tutorials/auto_scheduler/tune_network_arm.py @@ -52,7 +52,7 @@ import tvm from tvm import relay, auto_scheduler from tvm.relay import data_dep_optimization as ddo -from tvm.topi.sparse.utils import random_bsr_matrix +from tvm.topi.sparse.utils import random_sparse_dense_params import tvm.relay.testing from tvm.contrib import graph_runtime from tvm.contrib.utils import tempdir @@ -140,37 +140,14 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=Fal raise ValueError("Network not found.") if use_sparse: - # This is a test workload that manually transforms a dense model to sparse - # Check `tutorials/frontend/deploy_sparse.py` for more examples on how to import a - # pretrained model. 
-
-        def random_sparse_dense_params(func, params, density, BS_R, BS_C):
-            def deepcopy(param_dic):
-                ret = {}
-                for k, v in param_dic.items():
-                    ret[k] = tvm.nd.array(v.asnumpy())
-                return ret
-
-            new_params = deepcopy(params)
-            dense_weight_names = relay.analysis.sparse_dense._search_dense_op_weight(func)
-            for item in dense_weight_names:
-                name = str(item)
-                shape = new_params[name].shape
-                if shape[0] % BS_R == 0 and shape[1] % BS_C == 0:
-                    new_w = random_bsr_matrix(
-                        shape[0], shape[1], BS_R, BS_C, density, "float32"
-                    ).todense()
-                    new_params[name] = tvm.nd.array(new_w)
-            return new_params
-
         bs_r = 1
+        bs_c = 1
         sparsity = 0.85
-
-        # Currently we only support converting dense matmul to sparse dense matmul
         mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params)
-        params = random_sparse_dense_params(mod, params, BS_R=bs_r, BS_C=1, density=1 - sparsity)
-        mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, 1), sparsity_threshold=0.8)
-
+        # This is a test workload that manually transforms a dense model to sparse
+        params = random_sparse_dense_params(mod, params, BS_R=bs_r, BS_C=bs_c, density=1 - sparsity)
+        # Currently we only support converting dense matmul to sparse dense matmul
+        mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, bs_c), sparsity_threshold=0.8)
         mod = tvm.IRModule.from_expr(mod)

     return mod, params, input_shape, output_shape

diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py
index 04ae9cc28dc9..9ca6193b8475 100644
--- a/tutorials/auto_scheduler/tune_network_x86.py
+++ b/tutorials/auto_scheduler/tune_network_x86.py
@@ -50,7 +50,7 @@
 import tvm
 from tvm import relay, auto_scheduler
 from tvm.relay import data_dep_optimization as ddo
-from tvm.topi.sparse.utils import random_bsr_matrix
+from tvm.topi.sparse.utils import random_sparse_dense_params
 import tvm.relay.testing
 from tvm.contrib import graph_runtime

@@ -137,37 +137,14 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=Fal
         raise ValueError("Network not found.")

     if use_sparse:
-        # This is a test workload that manually transforms a dense model to sparse
-        # Check `tutorials/frontend/deploy_sparse.py` for more examples on how to import a
-        # pretrained model.
-
-        def random_sparse_dense_params(func, params, density, BS_R, BS_C):
-            def deepcopy(param_dic):
-                ret = {}
-                for k, v in param_dic.items():
-                    ret[k] = tvm.nd.array(v.asnumpy())
-                return ret
-
-            new_params = deepcopy(params)
-            dense_weight_names = relay.analysis.sparse_dense._search_dense_op_weight(func)
-            for item in dense_weight_names:
-                name = str(item)
-                shape = new_params[name].shape
-                if shape[0] % BS_R == 0 and shape[1] % BS_C == 0:
-                    new_w = random_bsr_matrix(
-                        shape[0], shape[1], BS_R, BS_C, density, "float32"
-                    ).todense()
-                    new_params[name] = tvm.nd.array(new_w)
-            return new_params
-
         bs_r = 1
+        bs_c = 1
         sparsity = 0.85
-
-        # Currently we only support converting dense matmul to sparse dense matmul
         mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params)
-        params = random_sparse_dense_params(mod, params, BS_R=bs_r, BS_C=1, density=1 - sparsity)
-        mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, 1), sparsity_threshold=0.8)
-
+        # This is a test workload that manually transforms a dense model to sparse
+        params = random_sparse_dense_params(mod, params, BS_R=bs_r, BS_C=bs_c, density=1 - sparsity)
+        # Currently we only support converting dense matmul to sparse dense matmul
+        mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, bs_c), sparsity_threshold=0.8)
         mod = tvm.IRModule.from_expr(mod)

     return mod, params, input_shape, output_shape

From 98de8d82a668c92e55c129b50c6faecf464052cb Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Wed, 24 Mar 2021 11:01:14 +0800
Subject: [PATCH 19/25] Update

---
 python/tvm/topi/sparse/utils.py              | 26 +++++++++++++++++---
 tutorials/auto_scheduler/tune_network_arm.py |  2 +-
 tutorials/auto_scheduler/tune_network_x86.py |  2 +-
 3 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/python/tvm/topi/sparse/utils.py b/python/tvm/topi/sparse/utils.py
index 39edeb01d70f..8c08026e0907 100644
--- a/python/tvm/topi/sparse/utils.py
+++ b/python/tvm/topi/sparse/utils.py
@@ -52,7 +52,27 @@ def random_bsr_matrix(m, n, bs_r, bs_c, density, dtype):
     return s


-def random_sparse_dense_params(func, params, density, BS_R, BS_C):
+def random_sparse_dense_params(func, params, bs_r, bs_c, density):
+    """Replace the dense parameters with random sparse parameters. Mainly used for testing.
+
+    Parameters
+    ----------
+    func : tvm.relay.Expr
+        The Expr that will be optimized to use sparse operators.
+    params : Dict[String, tvm.nd.array]
+        Parameters of the Expr.
+    bs_r : int
+        The row of BSR matrix block.
+    bs_c : int
+        The column of BSR matrix block.
+    density : float
+        The density of the random sparse parameters.
+
+    Returns
+    -------
+    Dict[String, tvm.nd.array]
+        The generated random parameters.
+    """
     def deepcopy(param_dic):
         ret = {}
         for k, v in param_dic.items():
@@ -64,7 +64,7 @@ def deepcopy(param_dic):
     for item in dense_weight_names:
         name = str(item)
         shape = new_params[name].shape
-        if shape[0] % BS_R == 0 and shape[1] % BS_C == 0:
-            new_w = random_bsr_matrix(shape[0], shape[1], BS_R, BS_C, density, "float32").todense()
+        if shape[0] % bs_r == 0 and shape[1] % bs_c == 0:
+            new_w = random_bsr_matrix(shape[0], shape[1], bs_r, bs_c, density, "float32").todense()
             new_params[name] = tvm.nd.array(new_w)
     return new_params
diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py
index 74976813921a..a23314ed2822 100644
--- a/tutorials/auto_scheduler/tune_network_arm.py
+++ b/tutorials/auto_scheduler/tune_network_arm.py
@@ -145,7 +145,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=Fal
         sparsity = 0.85
         mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params)
         # This is a test workload that manually transforms a dense model to sparse
-        params = random_sparse_dense_params(mod, params, BS_R=bs_r, BS_C=bs_c, density=1 - sparsity)
+        params = random_sparse_dense_params(mod, params, bs_r=bs_r, bs_c=bs_c, density=1 - sparsity)
         # Currently we only support converting dense matmul to sparse dense matmul
         mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, bs_c), sparsity_threshold=0.8)
         mod = tvm.IRModule.from_expr(mod)
diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py
index 9ca6193b8475..1783b10710ad 100644
--- a/tutorials/auto_scheduler/tune_network_x86.py
+++ b/tutorials/auto_scheduler/tune_network_x86.py
@@ -142,7 +142,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=Fal
         sparsity = 0.85
         mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params)
         # This is a test workload that manually transforms a dense model to sparse
-        params = random_sparse_dense_params(mod, params, BS_R=bs_r, BS_C=bs_c, density=1 - sparsity)
+        params = random_sparse_dense_params(mod, params, bs_r=bs_r, bs_c=bs_c, density=1 - sparsity)
         # Currently we only support converting dense matmul to sparse dense matmul
         mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, bs_c), sparsity_threshold=0.8)
         mod = tvm.IRModule.from_expr(mod)

From 2f030eb8b4bce1beeb9901f6aeb28574a6c59e9f Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Wed, 24 Mar 2021 11:47:36 +0800
Subject: [PATCH 20/25] Update

---
 python/tvm/topi/sparse/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/tvm/topi/sparse/utils.py b/python/tvm/topi/sparse/utils.py
index 8c08026e0907..59bdd406e00b 100644
--- a/python/tvm/topi/sparse/utils.py
+++ b/python/tvm/topi/sparse/utils.py
@@ -73,6 +73,7 @@ def random_sparse_dense_params(func, params, bs_r, bs_c, density):
         Dict[String, tvm.nd.array]
             The generated random parameters.
""" + def deepcopy(param_dic): ret = {} for k, v in param_dic.items(): From a02e98d8d0a3a919df0229d17470ad15c80bfaca Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 24 Mar 2021 13:13:00 +0800 Subject: [PATCH 21/25] Update --- python/tvm/topi/sparse/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/topi/sparse/utils.py b/python/tvm/topi/sparse/utils.py index 59bdd406e00b..07a99aa3a736 100644 --- a/python/tvm/topi/sparse/utils.py +++ b/python/tvm/topi/sparse/utils.py @@ -17,7 +17,6 @@ """Some utils for Sparse operation.""" import tvm from tvm import relay -from tvm.relay import data_dep_optimization as ddo def random_bsr_matrix(m, n, bs_r, bs_c, density, dtype): From 6c09f8d7b5e84a9338efc2a6e31d67419d097f29 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 26 Mar 2021 16:56:59 +0800 Subject: [PATCH 22/25] Update --- python/tvm/topi/sparse/utils.py | 29 ++++++++++++++++++++ tutorials/auto_scheduler/tune_network_arm.py | 12 ++------ tutorials/auto_scheduler/tune_network_x86.py | 12 ++------ 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/python/tvm/topi/sparse/utils.py b/python/tvm/topi/sparse/utils.py index 07a99aa3a736..e8e69f73f8cb 100644 --- a/python/tvm/topi/sparse/utils.py +++ b/python/tvm/topi/sparse/utils.py @@ -88,3 +88,32 @@ def deepcopy(param_dic): new_w = random_bsr_matrix(shape[0], shape[1], bs_r, bs_c, density, "float32").todense() new_params[name] = tvm.nd.array(new_w) return new_params + + +def convert_model_dense_to_sparse(mod, params, random_params=False, bs_r=1, bs_c=1, sparsity=0.85): + """Convert a dense model to sparse model. + + Parameters + ---------- + mod : tvm.Module + The dense model. + params : Dict[Srting, tvm.nd.array] + Parameters of the dense model. + random_params : Bool = False + True to replace the parameters of the dense model with some random sparse tensors. + This is used for testing. + bs_r : int + The row of BSR matrix block. + bs_c : int + The column of BSR matrix block. + sparsity : float + The sparsity of the random sparse parameters. 
+    """
+    mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params)
+    if random_params:
+        # Manually replace the parameters of the dense model with sparse tensors
+        params = random_sparse_dense_params(mod, params, bs_r=bs_r, bs_c=bs_c, density=1 - sparsity)
+    # Currently we only support converting dense matmul to sparse dense matmul
+    mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, bs_c), sparsity_threshold=0.8)
+
+    return tvm.IRModule.from_expr(mod), params
diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py
index a23314ed2822..1cd74307561a 100644
--- a/tutorials/auto_scheduler/tune_network_arm.py
+++ b/tutorials/auto_scheduler/tune_network_arm.py
@@ -52,7 +52,6 @@
 import tvm
 from tvm import relay, auto_scheduler
 from tvm.relay import data_dep_optimization as ddo
-from tvm.topi.sparse.utils import random_sparse_dense_params
 import tvm.relay.testing
 from tvm.contrib import graph_runtime
 from tvm.contrib.utils import tempdir
@@ -140,15 +139,8 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=Fal
         raise ValueError("Network not found.")

     if use_sparse:
-        bs_r = 1
-        bs_c = 1
-        sparsity = 0.85
-        mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params)
-        # This is a test workload that manually transforms a dense model to sparse
-        params = random_sparse_dense_params(mod, params, bs_r=bs_r, bs_c=bs_c, density=1 - sparsity)
-        # Currently we only support converting dense matmul to sparse dense matmul
-        mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, bs_c), sparsity_threshold=0.8)
-        mod = tvm.IRModule.from_expr(mod)
+        from tvm.topi.sparse.utils import convert_model_dense_to_sparse
+        mod, params = convert_model_dense_to_sparse(mod, params, True)

     return mod, params, input_shape, output_shape

diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py
index 1783b10710ad..a4c413e1d851 100644
--- a/tutorials/auto_scheduler/tune_network_x86.py
+++ b/tutorials/auto_scheduler/tune_network_x86.py
@@ -50,7 +50,6 @@
 import tvm
 from tvm import relay, auto_scheduler
 from tvm.relay import data_dep_optimization as ddo
-from tvm.topi.sparse.utils import random_sparse_dense_params
 import tvm.relay.testing
 from tvm.contrib import graph_runtime

@@ -137,15 +136,8 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=Fal
         raise ValueError("Network not found.")

     if use_sparse:
-        bs_r = 1
-        bs_c = 1
-        sparsity = 0.85
-        mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params)
-        # This is a test workload that manually transforms a dense model to sparse
-        params = random_sparse_dense_params(mod, params, bs_r=bs_r, bs_c=bs_c, density=1 - sparsity)
-        # Currently we only support converting dense matmul to sparse dense matmul
-        mod, params = ddo.bsr_dense.convert(mod, params, (bs_r, bs_c), sparsity_threshold=0.8)
-        mod = tvm.IRModule.from_expr(mod)
+        from tvm.topi.sparse.utils import convert_model_dense_to_sparse
+        mod, params = convert_model_dense_to_sparse(mod, params, True)

     return mod, params, input_shape, output_shape

From 2b404c85fcbb5ba534ebaaf39c407d07b873cd39 Mon Sep 17 00:00:00 2001
From: "chengfan.jcf"
Date: Fri, 26 Mar 2021 22:00:29 +0800
Subject: [PATCH 23/25] Update

---
 tutorials/auto_scheduler/tune_network_arm.py | 1 +
 tutorials/auto_scheduler/tune_network_x86.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py
index
1cd74307561a..3e6035473304 100644 --- a/tutorials/auto_scheduler/tune_network_arm.py +++ b/tutorials/auto_scheduler/tune_network_arm.py @@ -140,6 +140,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=Fal if use_sparse: from tvm.topi.sparse.utils import convert_model_dense_to_sparse + mod, params = convert_model_dense_to_sparse(mod, params, True) return mod, params, input_shape, output_shape diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index a4c413e1d851..88ababfc458a 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -137,6 +137,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=Fal if use_sparse: from tvm.topi.sparse.utils import convert_model_dense_to_sparse + mod, params = convert_model_dense_to_sparse(mod, params, True) return mod, params, input_shape, output_shape From 52a18556999f586e31bd8b125bc629baf82d5ced Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Fri, 26 Mar 2021 22:03:57 +0800 Subject: [PATCH 24/25] Update --- python/tvm/topi/sparse/utils.py | 9 ++++++++- tutorials/auto_scheduler/tune_network_arm.py | 2 +- tutorials/auto_scheduler/tune_network_x86.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/sparse/utils.py b/python/tvm/topi/sparse/utils.py index e8e69f73f8cb..e91f1c02e8ce 100644 --- a/python/tvm/topi/sparse/utils.py +++ b/python/tvm/topi/sparse/utils.py @@ -101,13 +101,20 @@ def convert_model_dense_to_sparse(mod, params, random_params=False, bs_r=1, bs_c Parameters of the dense model. random_params : Bool = False True to replace the parameters of the dense model with some random sparse tensors. - This is used for testing. + This is mainly used for testing. bs_r : int The row of BSR matrix block. bs_c : int The column of BSR matrix block. sparsity : float The sparsity of the random sparse parameters. + + Returns + ------- + tvm.Module + The updated sparse model. + Dict[Srting, tvm.nd.array] + The updated parameters. 
""" mod, params = ddo.simplify_fc_transpose.convert(mod["main"], params) if random_params: diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py index 3e6035473304..5a8407eb8675 100644 --- a/tutorials/auto_scheduler/tune_network_arm.py +++ b/tutorials/auto_scheduler/tune_network_arm.py @@ -141,7 +141,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=Fal if use_sparse: from tvm.topi.sparse.utils import convert_model_dense_to_sparse - mod, params = convert_model_dense_to_sparse(mod, params, True) + mod, params = convert_model_dense_to_sparse(mod, params, random_params=True) return mod, params, input_shape, output_shape diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index 88ababfc458a..2839db8646d0 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -138,7 +138,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=Fal if use_sparse: from tvm.topi.sparse.utils import convert_model_dense_to_sparse - mod, params = convert_model_dense_to_sparse(mod, params, True) + mod, params = convert_model_dense_to_sparse(mod, params, random_params=True) return mod, params, input_shape, output_shape From 2a92eeb2fe36a42b993d52c37f78d17807e7f2fa Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Tue, 30 Mar 2021 08:49:47 +0800 Subject: [PATCH 25/25] Lintfix --- python/tvm/topi/sparse/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/sparse/utils.py b/python/tvm/topi/sparse/utils.py index e91f1c02e8ce..43bc6e021429 100644 --- a/python/tvm/topi/sparse/utils.py +++ b/python/tvm/topi/sparse/utils.py @@ -108,7 +108,7 @@ def convert_model_dense_to_sparse(mod, params, random_params=False, bs_r=1, bs_c The column of BSR matrix block. sparsity : float The sparsity of the random sparse parameters. - + Returns ------- tvm.Module