Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion topi/python/topi/x86/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import tvm
from tvm import te
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from ..nn.pad import pad
from ..util import get_const_tuple
from ..nn.util import get_pad_tuple
Expand Down Expand Up @@ -67,6 +67,7 @@ def _fallback_schedule(cfg, wkl):
cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
cfg["unroll_kw"] = OtherOptionEntity(False)

def depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype):
"""Compute depthwise conv2d with NCHW layout."""
Expand Down Expand Up @@ -133,6 +134,7 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
cfg.define_split("tile_ic", in_channel, num_outputs=2)
cfg.define_split("tile_oc", out_channel, num_outputs=2)
cfg.define_split("tile_ow", out_width, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
cfg.define_knob("unroll_kw", [True, False])

# get workload and related schedule config
wkl = _get_workload(
Expand Down Expand Up @@ -199,6 +201,8 @@ def _callback(op):

def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out, output):
tile_ow, oc_bn = cfg["tile_ow"].size[-1], cfg["tile_oc"].size[-1]
unroll_kw = cfg["unroll_kw"].val

# schedule pad
if isinstance(s[data_vec].op, tvm.te.ComputeOp) \
and "pad" in data_vec.op.tag:
Expand Down Expand Up @@ -228,6 +232,8 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out
_, ic_chunk, oh, ow, ic_block = s[CC].op.axis
kh, kw = s[CC].op.reduce_axis
s[CC].reorder(ic_chunk, oh, kh, kw, ow, ic_block)
if unroll_kw:
s[CC].unroll(kw)
s[CC].vectorize(ic_block)
s[CC].unroll(ow)

Expand Down