diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py index fda964eb6639..8fff0025cff1 100644 --- a/topi/python/topi/x86/depthwise_conv2d.py +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -20,7 +20,7 @@ import tvm from tvm import te from tvm import autotvm -from tvm.autotvm.task.space import SplitEntity +from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity from ..nn.pad import pad from ..util import get_const_tuple from ..nn.util import get_pad_tuple @@ -67,6 +67,7 @@ def _fallback_schedule(cfg, wkl): cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n]) + cfg["unroll_kw"] = OtherOptionEntity(False) def depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype): """Compute depthwise conv2d with NCHW layout.""" @@ -133,6 +134,7 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, cfg.define_split("tile_ic", in_channel, num_outputs=2) cfg.define_split("tile_oc", out_channel, num_outputs=2) cfg.define_split("tile_ow", out_width, num_outputs=2, filter=lambda y: y.size[-1] <= 64) + cfg.define_knob("unroll_kw", [True, False]) # get workload and related schedule config wkl = _get_workload( @@ -199,6 +201,8 @@ def _callback(op): def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out, output): tile_ow, oc_bn = cfg["tile_ow"].size[-1], cfg["tile_oc"].size[-1] + unroll_kw = cfg["unroll_kw"].val + # schedule pad if isinstance(s[data_vec].op, tvm.te.ComputeOp) \ and "pad" in data_vec.op.tag: @@ -228,6 +232,8 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out _, ic_chunk, oh, ow, ic_block = s[CC].op.axis kh, kw = s[CC].op.reduce_axis s[CC].reorder(ic_chunk, oh, kh, kw, ow, ic_block) + if unroll_kw: + s[CC].unroll(kw) s[CC].vectorize(ic_block) s[CC].unroll(ow)