diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 77cba5fa2ff1..d28044c3845d 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -1066,12 +1066,16 @@ struct SparseTransposeAttrs : public tvm::AttrsNode<SparseTransposeAttrs> {
 /*! \brief Attributes for sparse_conv2d operator */
 struct SparseConv2DAttrs : public tvm::AttrsNode<SparseConv2DAttrs> {
   std::string layout;
+  Array<IndexExpr> kernel_size;
 
   TVM_DECLARE_ATTRS(SparseConv2DAttrs, "relay.attrs.SparseConv2DAttrs") {
     TVM_ATTR_FIELD(layout).set_default("NHWC").describe(
         "Dimension ordering of input data. Can be 'NCHW', 'NHWC'"
         "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
         "dimensions respectively.");
+    TVM_ATTR_FIELD(kernel_size)
+        .set_default(Array<IndexExpr>{1, 1})
+        .describe("Kernel size for SparseConv2D, 1x1 or 3x3.");
   }
 };
 
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index db4ff26857bd..eab6822b63b8 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -254,13 +254,14 @@ def ref_input(self):
 
     @ref_input.setter
     def ref_input(self, val):
-        warnings.warn(
-            "You are specifying fixed input for tuning the operator. "
-            "Be sure your input always fits the operator. Some "
-            "operators may conduct layout transformation during tuning, "
-            "thus can lead to unexpected behaviors. ",
-            RuntimeWarning,
-        )
+        if val is not None:
+            warnings.warn(
+                "You are specifying fixed input for tuning the operator. "
+                "Be sure your input always fits the operator. Some "
+                "operators may conduct layout transformation during tuning, "
+                "which can lead to unexpected behaviors. ",
+                RuntimeWarning,
+            )
         self._ref_input = val
 
     def set_task(self, task):
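Note on the measure_methods.py change above: the warning now fires only when a real tensor is pinned, so clearing `ref_input` no longer warns spuriously. A minimal sketch of the intended behavior (the runner construction and shapes are illustrative, not part of this patch):

    import warnings
    import numpy as np
    from tvm.autotvm.measure.measure_methods import LocalRunner

    runner = LocalRunner(number=4)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        runner.ref_input = [np.zeros((8, 8), "float32")]  # pins a fixed input -> warns
        runner.ref_input = None                           # clears it -> now silent
    assert len(caught) == 1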
diff --git a/python/tvm/relay/analysis/sparse_conv2d.py b/python/tvm/relay/analysis/sparse_conv2d.py
index 11278bddca33..1862ded831f6 100644
--- a/python/tvm/relay/analysis/sparse_conv2d.py
+++ b/python/tvm/relay/analysis/sparse_conv2d.py
@@ -54,7 +54,9 @@ def _search_conv2d_op_weight(expr):
     return _ffi_api.search_conv2d_op_weight(expr)
 
 
-def process_params(expr, params, block_size, sparsity_threshold, layout):
+def process_params(
+    expr, params, block_size, sparsity_threshold, layout, kernel_size, reg_task_input=True
+):
     """Process parameters of conv2d from dense to sparse.
 
     Parameters
@@ -86,14 +88,18 @@ def process_params(expr, params, block_size, sparsity_threshold, layout):
     for name in weight_names:
         name = str(name)
         w_np = params[name].numpy()
-        # currently only support conv2d_1*1
-        if not (
-            (w_np.shape[0] == 1 and w_np.shape[1] == 1)
-            or (w_np.shape[2] == 1 and w_np.shape[3] == 1)
-        ):
+
+        if layout == "NHWC":  # HWIO
+            weight_kernel = (w_np.shape[0], w_np.shape[1])
+        elif layout == "NCHW":  # OIHW
+            weight_kernel = (w_np.shape[2], w_np.shape[3])
+        if weight_kernel[0] != weight_kernel[1]:
             continue
-        sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
-        if sparsity >= sparsity_threshold:
+
+        if weight_kernel[0] == kernel_size == 1:
+            sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
+            if sparsity < sparsity_threshold:
+                continue
             if layout == "NHWC":
                 w_np = w_np.squeeze().T
             elif layout == "NCHW":
@@ -108,19 +114,31 @@ def process_params(expr, params, block_size, sparsity_threshold, layout):
                 )
             else:
                 sparse_weight_data = sparse_weight.data
+        elif weight_kernel[0] == kernel_size == 3:
+            if layout == "NHWC":  # HWIO
+                w_np = w_np.reshape((-1, w_np.shape[-1])).T
+            elif layout == "NCHW":  # OIHW
+                w_np = w_np.reshape((w_np.shape[0], -1))
+            sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size)
+            if 1 - (sparse_weight.nnz / w_np.size) < sparsity_threshold:
+                continue
+            sparse_weight_data = sparse_weight.data
+        else:
+            continue
 
-            # remove dense weight
-            del params[name]
-            memo.weight_name.append(name)
-            memo.weight_shape.append(
-                list(sparse_weight_data.shape)
-                + list(sparse_weight.indices.shape)
-                + list(sparse_weight.indptr.shape)
-            )
-            params[name + ".data"] = tvm.nd.array(sparse_weight_data)
-            params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
-            params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)
-
+        # remove dense weight
+        del params[name]
+        memo.weight_name.append(name)
+        memo.weight_shape.append(
+            list(sparse_weight_data.shape)
+            + list(sparse_weight.indices.shape)
+            + list(sparse_weight.indptr.shape)
+        )
+        params[name + ".data"] = tvm.nd.array(sparse_weight_data)
+        params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
+        params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)
+
+        if reg_task_input:
             prefix = "sparse_conv2d_bsr_%d_%d_%d_%d_%d_%d_" % (
                 w_np.shape[0],
                 w_np.shape[1],
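For reference, the 3x3 branch added above flattens the spatial and input-channel axes into one matrix dimension before the BSR conversion. A standalone scipy sketch of the same transformation (shapes, blocksize, and the fake sparsification are illustrative):

    import numpy as np
    import scipy.sparse as sp

    w = np.random.rand(3, 3, 64, 128).astype("float32")  # HWIO kernel
    w[w < 0.7] = 0.0                                     # fake sparsity for the demo
    flat = w.reshape((-1, w.shape[-1])).T                # (out_chan, 9 * in_chan)
    bsr = sp.bsr_matrix(flat, blocksize=(8, 1))
    sparsity = 1.0 - bsr.nnz / flat.size                 # same check as process_params
    print(bsr.data.shape, bsr.indices.shape, bsr.indptr.shape, sparsity)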
diff --git a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py
index 6913a428b2ac..20e01da1493e 100644
--- a/python/tvm/relay/data_dep_optimization/bsr_conv2d.py
+++ b/python/tvm/relay/data_dep_optimization/bsr_conv2d.py
@@ -23,8 +23,8 @@
 from .utils import _run_opt_pass
 
 
-def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"):
-    """Convert a dense func and according parameters to block sparse
+def convert(func, params, blocksize, sparsity_threshold, layout="NHWC", kernel_size=1):
+    """Convert a conv2d func and its corresponding parameters to block sparse
 
     Parameters
     ----------
@@ -49,10 +49,46 @@ def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"):
     params: Dict[Srting, tvm.nd.array]
         New params with BSR matrix for mutated Expr
     """
-    weight_info = process_params(func, params, blocksize, sparsity_threshold, layout)
+    weight_info = process_params(func, params, blocksize, sparsity_threshold, layout, kernel_size)
     new_func = _run_opt_pass(
         func,
-        relay.transform.Conv2dToSparse(weight_info.weight_name, weight_info.weight_shape, layout),
+        relay.transform.Conv2dToSparse(
+            weight_info.weight_name, weight_info.weight_shape, layout, kernel_size
+        ),
     )
     return new_func, params
+
+
+def convert2(func, params, blocksize, sparsity_threshold, layout, kernel_size):
+    """Convert a frozen conv2d func to block sparse
+
+    Parameters
+    ----------
+    func : relay.Expr
+        Expr will be optimized to sparse operation, with params frozen
+    params : Dict[String, tvm.nd.array]
+        Parameters of the Expr (not used in this pass)
+    blocksize : Tuple(int, int)
+        Blocksize for BSR matrix
+    sparsity_threshold : float
+        Minimal sparsity requirement for converting.
+        If weight sparsity is lower than this threshold,
+        the dense operation will be kept.
+    layout : str
+        layout of network
+    kernel_size : int
+        kernel size of the conv2d, for filtering
+
+    Returns
+    -------
+    new_func: relay.Expr
+        Mutated Expr with sparse operations
+
+    params: Dict[String, tvm.nd.array]
+        New params with BSR matrix for mutated Expr (not modified)
+    """
+    new_func = _run_opt_pass(
+        func, relay.transform.Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold)
+    )
+    return new_func, params
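A usage sketch for `convert2` (it mirrors the new NHWC test at the end of this patch; blocksize, threshold, and shapes are illustrative):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.build_module import bind_params_by_name
    from tvm.topi.sparse.utils import random_bsr_matrix

    data = relay.var("data", shape=(1, 32, 32, 64), dtype="float32")
    w = relay.var("weight", shape=(3, 3, 64, 128), dtype="float32")
    y = relay.nn.conv2d(data, w, channels=128, kernel_size=3, padding=1,
                        data_layout="NHWC", kernel_layout="HWIO")
    func = relay.Function(relay.analysis.free_vars(y), y)
    w_np = np.array(random_bsr_matrix(128, 64 * 9, 16, 1, 0.1, "float32").todense())
    params = {"weight": tvm.nd.array(w_np.T.reshape(3, 3, 64, 128))}

    # convert2 reads weights out of the frozen graph itself, so bind them first.
    func = bind_params_by_name(func, params)
    sparse_func, new_params = relay.data_dep_optimization.bsr_conv2d.convert2(
        func, {}, (16, 1), 0.2, "NHWC", 3
    )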
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index a9ccc5aa2d24..a9e485866381 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -198,7 +198,11 @@ def compute_sparse_transpose(attrs, inputs, out_type):
 @reg.register_compute("nn.sparse_conv2d")
 def compute_sparse_conv2d(attrs, inputs, out_type):
     """Compute definition of sparse_conv2d"""
-    return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"])]
+    return [
+        topi.nn.sparse_conv2d(
+            inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"], attrs["kernel_size"]
+        )
+    ]
 
 
 reg.register_strategy("nn.sparse_conv2d", strategy.sparse_conv2d_strategy)
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index a6e141f2753b..1c8d1b478cb1 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -565,6 +565,31 @@ def sparse_dense_strategy_cpu(attrs, inputs, out_type, target):
     return strategy
 
 
+@sparse_conv2d_strategy.register("cpu")
+def sparse_conv2d_strategy_cpu(attrs, inputs, out_type, target):
+    """sparse conv2d x86 strategy"""
+    strategy = _op.OpStrategy()
+    if attrs["kernel_size"][0] == 1:
+        strategy.add_implementation(
+            wrap_compute_sparse_conv2d(topi.nn.sparse_conv2d),
+            wrap_topi_schedule(topi.generic.schedule_sparse_conv2d),
+            name="sparse_conv2d.generic",
+        )
+    elif attrs["kernel_size"][0] == 3:
+        if attrs["layout"] == "NHWC":
+            strategy.add_implementation(
+                wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nhwc),
+                wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nhwc),
+                name="conv3x3_spNHWC.x86",
+            )
+        elif attrs["layout"] == "NCHW":
+            strategy.add_implementation(
+                wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nchw),
+                wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nchw),
+            )
+    return strategy
+
+
 @roi_align_strategy.register("cpu")
 def roi_align_strategy_cpu(attrs, inputs, out_type, target):
     """roi_align x86 strategy"""
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 6294e7acea15..9a7857a01fe6 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1093,7 +1093,7 @@ def DenseToSparse(weight_name, weight_shape):
     return _ffi_api.DenseToSparse(weight_name, weight_shape)
 
 
-def Conv2dToSparse(weight_name, weight_shape, layout):
+def Conv2dToSparse(weight_name, weight_shape, layout, kernel_size):
     """
     Rewrite qualified ```nn.conv2d operation``` to ```nn.sparse_conv2d```
 
@@ -1113,7 +1113,33 @@ def Conv2dToSparse(weight_name, weight_shape, layout):
     ret : tvm.transform.Pass
         The registered DenseToSparse pass.
     """
-    return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout)
+    return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout, kernel_size)
+
+
+def Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold):
+    """
+    Rewrite frozen ```nn.conv2d``` operation to ```nn.sparse_conv2d```
+
+    Parameters
+    ----------
+    layout : str
+        layout of data
+
+    kernel_size : int
+        kernel size of conv2d
+
+    blocksize : Tuple(int, int)
+        blocksize for BSR matrix
+
+    sparsity_threshold : float
+        minimal sparsity requirement for converting
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered Conv2dToSparse2 pass.
+    """
+    return _ffi_api.Conv2dToSparse2(layout, kernel_size, *blocksize, sparsity_threshold)
 
 
 def SimplifyFCTranspose(target_weight_name):
diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py
index 948847e60d92..e577104c3ddc 100644
--- a/python/tvm/topi/nn/sparse.py
+++ b/python/tvm/topi/nn/sparse.py
@@ -566,7 +566,9 @@ def _compute_block(i, nb_j, j, h, w):  # pylint: disable=C0103
     )
 
 
-def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC"):
+def sparse_conv2d(
+    dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC", kernel_size=1
+):
     """
     Computes sparse-conv2d(1*1) of ``data`` and
     ``(weight_data, weight_indices, weight_indptr)``
@@ -598,14 +600,15 @@ def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout
         4-D with shape [M, H, W, N] (layout=NHWC)
         4-D with shape [M, N, H ,W] (layout=NCHW)
     """
-    if layout == "NHWC":
-        return _sparse_conv2d_bsr_compute_nhwc(
-            dense_data, sparse_data, sparse_indices, sparse_indptr
-        )
-    elif layout == "NCHW":
-        return _sparse_conv2d_bsr_compute_nchw(
-            dense_data, sparse_data, sparse_indices, sparse_indptr
-        )
+    if kernel_size == 1:
+        if layout == "NHWC":
+            return _sparse_conv2d_bsr_compute_nhwc(
+                dense_data, sparse_data, sparse_indices, sparse_indptr
+            )
+        elif layout == "NCHW":
+            return _sparse_conv2d_bsr_compute_nchw(
+                dense_data, sparse_data, sparse_indices, sparse_indptr
+            )
     else:
         raise ValueError("Unsupport Layout %s" % layout)
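As a mental model for the 1x1 path that the `sparse_conv2d` dispatch above keeps: a 1x1 convolution is a matrix multiply over the channel axis, so the BSR weight applies directly to the flattened activations. A scipy sketch under assumed NHWC shapes (not part of the patch):

    import numpy as np
    import scipy.sparse as sp

    x = np.random.rand(1, 8, 8, 64).astype("float32")               # NHWC input
    w = sp.random(128, 64, density=0.2, format="bsr", dtype="float32")
    out = np.asarray(w @ x.reshape(-1, 64).T)                       # 1x1 conv == matmul
    out = out.T.reshape(1, 8, 8, 128)                               # back to NHWC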
"""sparse_dense schedule on x86""" -from tvm import te +from functools import partial, reduce +from tvm import te, tir, autotvm +from ..transform import reshape from ..utils import traverse_inline, get_const_int from .utils import get_fp32_len @@ -60,3 +62,161 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + + +@autotvm.register_topi_compute("conv3x3_spNHWC.x86") +def spconv2d_3x3_nhwc(cfg, data, wdat, wind, wptr, layout="NHWC"): + """Sparse Conv2d 3x3 compute (NHWC).""" + assert layout == "NHWC" + nsamples, imh, imw, chanin = [i.value for i in data.shape] + nelems, bsrr, bsrc = [i.value for i in wdat.shape] + chanout = (wptr.shape[0].value - 1) * bsrr + + imglen, chanlen = nsamples * imh * imw, 9 * chanin + cfg.define_split("tile_y", imglen, num_outputs=3) + cfg.define_split("tile_x", chanout // bsrr, num_outputs=2) + cfg.add_flop(imglen * (nelems * bsrc * bsrr * 2 - chanout)) + if cfg.is_fallback: + cfg["tile_y"] = autotvm.task.space.SplitEntity([-1, 160, 8]) + cfg["tile_x"] = autotvm.task.space.SplitEntity([-1, 4]) + + idxsplit = lambda x, y: reduce(lambda a, b: a[:-1] + [a[-1] % b, a[-1] // b], y, [x]) + + @partial(te.compute, (imglen, chanlen), name="Im2Col") + def im2col(row, col): + j_w, j_h, j_n = idxsplit(row, [imw, imh]) + j_c, k_w, k_h = idxsplit(col, [chanin, 3]) + i_h, i_w = j_h + k_h - 1, j_w + k_w - 1 + return tir.if_then_else( + tir.all(i_h >= 0, i_h < imh, i_w >= 0, i_w < imw), data[j_n, i_h, i_w, j_c], 0 + ) + + @partial(te.compute, (imglen, chanout // bsrr, bsrr, bsrc), name="CC") + def matmul(drow, wrow, brow, bcol): + row_start, row_end = wptr[wrow], wptr[wrow + 1] + elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx") + elem = row_start + elem_idx + return te.sum( + im2col[drow, wind[elem] * bsrc + bcol] * wdat[elem, brow, bcol], axis=elem_idx + ) + + sum_bsrc = te.reduce_axis((0, bsrc), name="k") + ret = te.compute( + (imglen, chanout), + lambda y, x: te.sum(matmul[y, x // bsrr, x % bsrr, sum_bsrc], axis=sum_bsrc), + name="C", + tag="conv3x3_spNHWC", + ) + return reshape(ret, (nsamples, imh, imw, chanout)) + + +@autotvm.register_topi_schedule("conv3x3_spNHWC.x86") +def schedule_spconv2d_3x3_nhwc(cfg, outs): + """Sparse Conv2d 3x3 schedule (NHWC).""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "conv3x3_spNHWC": + (matmul,) = op.input_tensors + # wptr, wind, im2col, wdat + _, _, im2col, _ = matmul.op.input_tensors + (data,) = im2col.op.input_tensors + bsrr = matmul.shape[-2].value + chanin = data.shape[-1].value + + mm_y, mm_x = s[op].op.axis + y_t, y_o, y_i = cfg["tile_y"].apply(s, op, mm_y) + x_o, x_i = s[op].split(mm_x, factor=bsrr) + x_t, x_o = cfg["tile_x"].apply(s, op, x_o) + (sum_ax,) = s[op].op.reduce_axis + s[op].reorder(y_t, x_t, y_o, x_o, y_i, x_i, sum_ax) + s[op].unroll(sum_ax) + s[op].vectorize(x_i) + s[op].unroll(y_i) + + s[matmul].compute_at(s[op], x_o) + y_i, x_i, bsrr, bsrc = s[matmul].op.axis + (sum_ax,) = s[matmul].op.reduce_axis + s[matmul].reorder(x_i, sum_ax, y_i, bsrr, bsrc) + s[matmul].unroll(bsrc) + s[matmul].vectorize(bsrr) + s[matmul].unroll(y_i) + + s[im2col].compute_at(s[op], y_o) + y_i, sum_ax = s[im2col].op.axis + _, k_i = s[im2col].split(sum_ax, factor=chanin) + s[im2col].vectorize(k_i) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv3x3_spNCHW.x86") +def spconv2d_3x3_nchw(cfg, data, wdat, wind, wptr, layout="NCHW"): + """Sparse Conv2d 3x3 
+
+
+@autotvm.register_topi_compute("conv3x3_spNCHW.x86")
+def spconv2d_3x3_nchw(cfg, data, wdat, wind, wptr, layout="NCHW"):
+    """Sparse Conv2d 3x3 compute (NCHW)."""
+    nsamples, chanin, imgh, imgw = [i.value for i in data.shape]
+    nelems, veclen, bsrc = [i.value for i in wdat.shape]
+    chanout = (wptr.shape[0].value - 1) * veclen
+    assert bsrc == 1 and layout == "NCHW"
+
+    cfg.add_flop(nsamples * imgh * imgw * (nelems * veclen * bsrc * 2 - chanout))
+    cfg.define_split("tile_hw", imgh * imgw, num_outputs=3)
+    cfg.define_split("tile_ckk", chanin * 9, num_outputs=3)
+
+    @partial(te.compute, (nsamples, chanin * 3 * 3, imgh * imgw), name="im2col")
+    def im2col(nsamples, ckk, imglen):
+        j_h, j_w = imglen // imgw, imglen % imgw
+        i_c, k_h, k_w = ckk // 9, ckk // 3 % 3, ckk % 3
+        i_h, i_w = j_h + k_h - 1, j_w + k_w - 1
+        return tir.if_then_else(
+            tir.all(i_h >= 0, i_h < imgh, i_w >= 0, i_w < imgw), data[nsamples, i_c, i_h, i_w], 0
+        )
+
+    @partial(
+        te.compute,
+        (nsamples, chanout // veclen, veclen, bsrc, imgh * imgw),
+        name="CC",
+        tag="conv3x3_spNCHW",
+    )
+    def matmul(nsamples, f_o, f_i, bsrk, imglen):
+        row_start, row_end = wptr[f_o], wptr[f_o + 1]
+        elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx")
+        elem = row_start + elem_idx
+        return te.sum(
+            im2col[nsamples, wind[elem] * bsrc + bsrk, imglen] * wdat[elem, f_i, bsrk],
+            axis=elem_idx,
+        )
+
+    return reshape(matmul, [nsamples, chanout, imgh, imgw])
+
+
+@autotvm.register_topi_schedule("conv3x3_spNCHW.x86")
+def schedule_spconv2d_3x3_nchw(cfg, outs):
+    """Sparse Conv2d 3x3 schedule (NCHW)."""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "conv3x3_spNCHW":
+            # wptr, wind, im2col, wdat
+            _, _, im2col, _ = op.input_tensors
+
+            n_samples, f_o, f_i, b_c, imglen = s[op].op.axis
+            (sum_ax,) = s[op].op.reduce_axis
+            hw1, hw2, hw3 = cfg["tile_hw"].apply(s, op, imglen)
+            s[op].reorder(n_samples, hw1, f_o, hw2, sum_ax, f_i, b_c, hw3)
+            s[op].unroll(f_i)
+            s[op].unroll(b_c)
+            s[op].vectorize(hw3)
+
+            s[im2col].compute_at(s[op], hw1)
+            n_samples, ckk, imglen = s[im2col].op.axis
+            ckk1, ckk2, ckk3 = cfg["tile_ckk"].apply(s, im2col, ckk)
+            hw2, hw3 = s[im2col].split(imglen, factor=cfg["tile_hw"].size[-1])
+            s[im2col].reorder(n_samples, ckk1, ckk2, hw2, ckk3, hw3)
+            s[im2col].unroll(ckk3)
+            s[im2col].vectorize(hw3)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
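Both TE kernels above implement conv2d as im2col followed by a block-sparse matmul. A numpy model of the NCHW decomposition, self-checked against a naive convolution (all shapes illustrative):

    import numpy as np

    N, C, H, W, F = 1, 4, 5, 5, 8                      # pad=1, stride=1, 3x3
    x = np.random.rand(N, C, H, W).astype("float32")
    wgt = np.random.rand(F, C, 3, 3).astype("float32")

    xp = np.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1)))
    cols = np.empty((N, C * 9, H * W), "float32")      # ckk = c * 9 + kh * 3 + kw
    for kh in range(3):
        for kw in range(3):
            patch = xp[:, :, kh:kh + H, kw:kw + W].reshape(N, C, H * W)
            cols[:, np.arange(C) * 9 + kh * 3 + kw, :] = patch
    out = np.einsum("fk,nkl->nfl", wgt.reshape(F, -1), cols).reshape(N, F, H, W)

    ref = np.zeros((N, F, H, W), "float32")
    for f in range(F):
        for c in range(C):
            for kh in range(3):
                for kw in range(3):
                    ref[:, f] += wgt[f, c, kh, kw] * xp[:, c, kh:kh + H, kw:kw + W]
    assert np.allclose(out, ref, atol=1e-4)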
diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc
index 32b0811b48ac..7d21005cb4db 100644
--- a/src/relay/op/nn/sparse.cc
+++ b/src/relay/op/nn/sparse.cc
@@ -274,10 +274,11 @@ bool SparseConv2dRel(const Array<Type>& types, int num_inputs, const Attrs& attr
 }
 
 Expr MakeSparseConv2d(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr,
-                      std::string layout) {
+                      std::string layout, Array<IndexExpr> kernel_size) {
   static const Op& op = Op::Get("nn.sparse_conv2d");
   auto attrs = make_object<SparseConv2DAttrs>();
   attrs->layout = std::move(layout);
+  attrs->kernel_size = std::move(kernel_size);
   return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
 }
 
diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc
index 6e4c03b0fcbc..3f2c25e988f9 100644
--- a/src/relay/transforms/convert_sparse_conv2d.cc
+++ b/src/relay/transforms/convert_sparse_conv2d.cc
@@ -73,10 +73,12 @@ TVM_REGISTER_GLOBAL("relay.analysis.search_conv2d_op_weight").set_body_typed(Sea
 class Conv2dToSparseConv2dMutator : public ExprRewriter {
  public:
   Conv2dToSparseConv2dMutator(const Array<ObjectRef>& weight_name,
-                              const Array<Array<PrimExpr>>& weight_shape, const String& layout)
+                              const Array<Array<PrimExpr>>& weight_shape, const String& layout,
+                              int kernel_size)
       : conv2d_op_(Op::Get("nn.conv2d")), sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")) {
     ICHECK_EQ(weight_name.size(), weight_shape.size());
     layout_ = layout;
+    kernel_size_ = kernel_size;
     for (size_t i = 0; i < weight_name.size(); ++i) {
       ICHECK(weight_name[i]->IsInstance<runtime::StringObj>());
       std::string k = weight_name[i].as<runtime::StringObj>()->data;
@@ -112,6 +114,7 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter {
       Var weight_indptr(prefix + ".indptr", ws_indptr_type);
       auto attrs = make_object<SparseConv2DAttrs>();
       attrs->layout = std::move(layout_);
+      attrs->kernel_size = Array<IndexExpr>{kernel_size_, kernel_size_};
       return Call(sparse_conv2d_op_, {data, weight_data, weight_indices, weight_indptr},
                   Attrs(attrs));
     }
@@ -126,22 +129,168 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter {
   const Op& sparse_conv2d_op_;
   std::unordered_map<std::string, std::vector<int>> target_weights_;
   String layout_;
+  int kernel_size_;
 };  // class Conv2dToSparseConv2dAlter
 
 Expr Conv2dToSparse(const Expr& e, const Array<ObjectRef>& weight_name,
-                    const Array<Array<PrimExpr>>& weight_shape, const String& layout) {
-  auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout);
+                    const Array<Array<PrimExpr>>& weight_shape, const String& layout,
+                    int kernel_size) {
+  auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout, kernel_size);
+  return PostOrderRewrite(e, &rewriter);
+}
+
+template <typename elemTy, size_t... Is>
+auto unpack_to_tuple_internal(elemTy* arr, std::index_sequence<Is...>) {
+  return std::make_tuple(arr[Is]...);
+}
+
+template <size_t N, typename elemTy>
+auto unpack_to_tuple(elemTy* arr) {
+  return unpack_to_tuple_internal(arr, std::make_index_sequence<N>{});
+}
+
+struct Range {
+  size_t dim;
+  explicit Range(size_t d) : dim(d) {}
+
+  struct iterpoint {
+    size_t val, lim;
+    iterpoint(size_t v1, size_t v2) : val(v1), lim(v2) {}
+
+    size_t operator*() const { return val; }
+
+    iterpoint operator/(const iterpoint& rhs) const {
+      return iterpoint(val * rhs.lim + rhs.val, lim * rhs.lim);
+    }
+  };
+
+  struct iterator {
+    size_t val, lim;
+    iterator(size_t v1, size_t v2) : val(v1), lim(v2) {}
+
+    bool operator!=(const iterator& rhs) const { return val != rhs.val; }
+
+    void operator++() { ++val; }
+
+    iterpoint operator*() const { return iterpoint(val, lim); }
+  };
+
+  iterator begin() { return iterator(0, dim); }
+
+  iterator end() { return iterator(dim, dim); }
+};
+
+// Mutate ```nn.conv2d``` to ```nn.sparse_conv2d```
+class Conv2dToSparseConv2dMutator2 : public ExprRewriter {
+ public:
+  Conv2dToSparseConv2dMutator2(const String& layout, int kernel_size, int blockH, int blockW,
+                               double sparse_thresh)
+      : sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")),
+        dev_cpu0_{DLDeviceType::kDLCPU, 0},
+        layout_(layout),
+        kernel_size_(kernel_size),
+        blockH_(blockH),
+        blockW_(blockW),
+        sparse_thresh_(sparse_thresh) {}
+
+  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+    // check op type & attrs
+    const auto pre_attrs = pre->attrs.as<Conv2DAttrs>();
+    if (!pre_attrs || pre_attrs->data_layout != layout_ ||
+        pre_attrs->strides[0].as<IntImmNode>()->value != 1 ||
+        pre_attrs->kernel_size[0].as<IntImmNode>()->value != kernel_size_)
+      return post;
+    // check constant weight
+    const auto pre_weight_node = pre->args[1].as<ConstantNode>();
+    if (!pre_weight_node) return post;
+
+    // check weight dtype & shape
+    auto&& pre_weight = pre_weight_node->data;
+    auto dtype = pre_weight.DataType(), itype = runtime::DataType::Int(32);
+    ICHECK(dtype.code() == DataType::kFloat && dtype.bits() == 32);  // float32 only
+    auto pre_weight_shape = unpack_to_tuple<4>(pre_weight.Shape().data());
+    int O, I, H, W;
+    if (layout_ == "NCHW") {
+      std::tie(O, I, H, W) = pre_weight_shape;
+    } else {  // NHWC
+      std::tie(H, W, I, O) = pre_weight_shape;
+    }
+    int CO = O, CI = H * W * I;
+
+    // copy to vector
+    std::vector<float> pre_weight_data(CO * CI);
+    pre_weight.CopyToBytes(pre_weight_data.data(), pre_weight_data.size() * sizeof(float));
+    if (layout_ == "NHWC") {
+      std::vector<float> tmp(pre_weight_data.size());
+      for (auto i : Range(CO))
+        for (auto j : Range(CI)) tmp[*(i / j)] = pre_weight_data[*(j / i)];
+      std::swap(tmp, pre_weight_data);
+    }
+    // convert to BSR
+    std::vector<float> wdata, block(blockH_ * blockW_);
+    std::vector<int32_t> windices, windptr;
+    for (auto bh : Range(CO / blockH_)) {
+      windptr.push_back(windices.size());
+      for (auto bw : Range(CI / blockW_)) {
+        int cntnnz = 0;
+        for (auto i : Range(blockH_))
+          for (auto j : Range(blockW_)) {
+            auto tmp = pre_weight_data[*(bh / i / bw / j)];
+            if (tmp) cntnnz++;
+            block[*(i / j)] = tmp;
+          }
+        if (cntnnz) {
+          wdata.insert(wdata.end(), block.begin(), block.end());
+          windices.push_back(*bw);
+        }
+      }
+    }
+    windptr.push_back(windices.size());
+    double sprate = 1 - 1.0 * wdata.size() / pre_weight_data.size();
+    if (sprate < sparse_thresh_) return post;
+
+    // construct return data
+    int nnz = windices.size();
+    auto weight_data = runtime::NDArray::Empty({nnz, blockH_, blockW_}, dtype, dev_cpu0_);
+    auto weight_indices = runtime::NDArray::Empty({nnz}, itype, dev_cpu0_);
+    auto weight_indptr = runtime::NDArray::Empty({CO / blockH_ + 1}, itype, dev_cpu0_);
+    weight_data.CopyFromBytes(wdata.data(), wdata.size() * sizeof(float));
+    weight_indices.CopyFromBytes(windices.data(), windices.size() * sizeof(int32_t));
+    weight_indptr.CopyFromBytes(windptr.data(), windptr.size() * sizeof(int32_t));
+
+    // construct return call
+    auto args = runtime::Array<Expr>{post.as<CallNode>()->args[0], Constant(weight_data),
+                                     Constant(weight_indices), Constant(weight_indptr)};
+    auto attrs = make_object<SparseConv2DAttrs>();
+    attrs->layout = layout_;
+    attrs->kernel_size = Array<IndexExpr>{kernel_size_, kernel_size_};
+    return Call(sparse_conv2d_op_, args, Attrs(attrs));
+  }
+
+ private:
+  const Op& sparse_conv2d_op_;
+  DLDevice dev_cpu0_;
+  String layout_;
+  int kernel_size_, blockH_, blockW_;
+  double sparse_thresh_;
+};  // class Conv2dToSparseConv2dMutator2
+
+Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, int blockH, int blockW,
+                     double sparse_thresh) {
+  auto rewriter = Conv2dToSparseConv2dMutator2(layout, kernel_size, blockH, blockW, sparse_thresh);
   return PostOrderRewrite(e, &rewriter);
 }
 
 namespace transform {
+// Convert a model with separate weight info (already sparsified).
 Pass Conv2dToSparse(const Array<ObjectRef>& weight_name, const Array<Array<PrimExpr>>& weight_shape,
-                    const String& layout) {
+                    const String& layout, int kernel_size) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
         // Remove FreeVar warnings
-        auto f0 = Downcast<Function>(Conv2dToSparse(f, weight_name, weight_shape, layout));
+        auto f0 =
+            Downcast<Function>(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size));
         Array<Var> sparse_params = FreeVars(f0);
         auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs);
         Array<Var> params = FreeVars(f1);
@@ -155,6 +304,20 @@ Pass Conv2dToSparse(const Array<ObjectRef>& weight_name, const Array<Array<Prim
 
 TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse").set_body_typed(Conv2dToSparse);
 
+Pass Conv2dToSparse2(const String& layout, int kernel_size, int blockH, int blockW,
+                     double sparse_thresh) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        auto f0 = Downcast<Function>(
+            Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh));
+        return f0;
+      };
+  return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2", {"DeadCodeElimination"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse2").set_body_typed(Conv2dToSparse2);
+
 }  // namespace transform
 
 }  // namespace relay
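The `Range`/`iterpoint` helpers in the mutator above give names to mixed-radix index composition: `i / j` joins two loop coordinates row-major (value `i.val * j.lim + j.val`, limit `i.lim * j.lim`), so `*(bh / i / bw / j)` is the flat offset of element (bh*blockH + i, bw*blockW + j) in the (CO, CI) weight matrix. A small Python model of that operator (hypothetical class name, for illustration only):

    class IterPoint:
        def __init__(self, val, lim):
            self.val, self.lim = val, lim

        def __truediv__(self, rhs):  # row-major join, as in the C++ iterpoint
            return IterPoint(self.val * rhs.lim + rhs.val, self.lim * rhs.lim)

    bh, i = IterPoint(2, 8), IterPoint(5, 16)  # block row 2 of 8, row 5 inside the block
    assert (bh / i).val == 2 * 16 + 5          # flat row 37 of 128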
diff --git a/tests/python/relay/test_sparse_conv2d_convert.py b/tests/python/relay/test_sparse_conv2d_convert.py
index 0af78fc033ac..045462475ee1 100644
--- a/tests/python/relay/test_sparse_conv2d_convert.py
+++ b/tests/python/relay/test_sparse_conv2d_convert.py
@@ -25,6 +25,7 @@
 from tvm.ir import IRModule
 from tvm import relay
 from tvm.topi.sparse.utils import random_bsr_matrix
+from tvm.relay.build_module import bind_params_by_name
 
 
 def run_func(func, params, x):
@@ -100,6 +101,68 @@ def test_bsr_sparse_conv2d_nhwc():
     np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5)
 
 
+def test_bsr_sparse_conv2d_3x3_nchw():
+    data = relay.var("data", shape=(1, 64, 32, 32), dtype="float32")
+    x = relay.nn.relu(data)
+    w = relay.var("weight", shape=(128, 64, 3, 3), dtype="float32")
+    y = relay.nn.conv2d(
+        x, w, channels=128, kernel_size=3, padding=1, data_layout="NCHW", kernel_layout="OIHW"
+    )
+    z = relay.nn.relu(y)
+    func = relay.Function(relay.analysis.free_vars(z), z)
+
+    params = {
+        "weight": tvm.nd.array(
+            np.array(random_bsr_matrix(128, 64 * 9, 16, 1, 0.1, "float32").todense()).reshape(
+                128, 64, 3, 3
+            )
+        )
+    }
+
+    x_np = np.random.randn(1, 64, 32, 32).astype("float32")
+    # dense output
+    dense_output = run_func(func, params, x_np)
+    # sparse
+    func = bind_params_by_name(func, params)
+    sparse_func, params = relay.data_dep_optimization.bsr_conv2d.convert2(
+        func, {}, (16, 1), 0.2, "NCHW", 3
+    )
+    sparse_output = run_func(sparse_func, params, x_np)
+    np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5)
+
+
+def test_bsr_sparse_conv2d_3x3_nhwc():
+    data = relay.var("data", shape=(1, 32, 32, 64), dtype="float32")
+    x = relay.nn.relu(data)
+    w = relay.var("weight", shape=(3, 3, 64, 128), dtype="float32")
+    y = relay.nn.conv2d(
+        x, w, channels=128, kernel_size=3, padding=1, data_layout="NHWC", kernel_layout="HWIO"
+    )
+    z = relay.nn.relu(y)
+    func = relay.Function(relay.analysis.free_vars(z), z)
+
+    params = {
+        "weight": tvm.nd.array(
+            np.array(random_bsr_matrix(128, 64 * 9, 16, 1, 0.1, "float32").todense()).T.reshape(
+                3, 3, 64, 128
+            )
+        )
+    }
+
+    x_np = np.random.randn(1, 32, 32, 64).astype("float32")
+    # dense output
+    dense_output = run_func(func, params, x_np)
+    # sparse
+    func = bind_params_by_name(func, params)
+    sparse_func, params = relay.data_dep_optimization.bsr_conv2d.convert2(
+        func, {}, (16, 1), 0.2, "NHWC", 3
+    )
+    sparse_output = run_func(sparse_func, params, x_np)
+    np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5)
+
+
 if __name__ == "__main__":
     test_bsr_sparse_conv2d_nhwc()
     test_bsr_sparse_conv2d_nchw()
+    test_bsr_sparse_conv2d_3x3_nhwc()
+    test_bsr_sparse_conv2d_3x3_nchw()