diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index fc20d9a9bd15..215ef43ecee0 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -21,8 +21,8 @@ Im2ColPack = namedtuple('Im2ColPack', ['vp', 'vq', 'ba', 'bc', 'unroll']) -# workloads of resnet18 on imagenet _WORKLOADS = [ + # workloads of resnet18 on imagenet Workload(224, 224, 3, 64, 7, 7, 3, 3, 2, 2), Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1), Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1), @@ -35,6 +35,17 @@ Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2), Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2), Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1), + # workloads of mobile net on imagenet + Workload(224, 224, 3, 32, 3, 3, 1, 1, 2, 2), + Workload(112, 112, 32, 64, 1, 1, 0, 0, 1, 1), + Workload(56, 56, 64, 128, 1, 1, 0, 0, 1, 1), + Workload(56, 56, 128, 128, 1, 1, 0, 0, 1, 1), + Workload(28, 28, 128, 256, 1, 1, 0, 0, 1, 1), + Workload(28, 28, 256, 256, 1, 1, 0, 0, 1, 1), + Workload(14, 14, 256, 512, 1, 1, 0, 0, 1, 1), + Workload(14, 14, 512, 512, 1, 1, 0, 0, 1, 1), + Workload(7, 7, 512, 1024, 1, 1, 0, 0, 1, 1), + Workload(7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1), ] # platform specific schedule diff --git a/topi/python/topi/rasp/__init__.py b/topi/python/topi/rasp/__init__.py index f0f605eeba9d..2ac059128c60 100644 --- a/topi/python/topi/rasp/__init__.py +++ b/topi/python/topi/rasp/__init__.py @@ -3,3 +3,4 @@ from __future__ import absolute_import as _abs from .conv2d import * +from .depthwise_conv2d import * diff --git a/topi/python/topi/rasp/conv2d.py b/topi/python/topi/rasp/conv2d.py index f48bd4a16ee1..e208a67d8a20 100644 --- a/topi/python/topi/rasp/conv2d.py +++ b/topi/python/topi/rasp/conv2d.py @@ -23,6 +23,17 @@ Im2ColPack(7, 4, 1, 16, True), Im2ColPack(7, 4, 1, 8, False), Im2ColPack(7, 4, 1, 16, False), + + SpatialPack(2, 2, 4, 28, 1, True), + SpatialPack(1, 4, 8, 14, 1, False), + SpatialPack(1, 2, 16, 8, 1, True), + SpatialPack(1, 4, 8, 8, 8, True), + SpatialPack(2, 2, 8, 1, 1, False), + SpatialPack(1, 4, 8, 4, 8, False), + SpatialPack(2, 2, 8, 1, 4, False), + SpatialPack(2, 2, 8, 1, 8, False), + SpatialPack(1, 1, 16, 1, 4, False), + SpatialPack(1, 1, 4, 1, 4, True), ] def _schedule_conv2d(wkl): diff --git a/topi/python/topi/rasp/depthwise_conv2d.py b/topi/python/topi/rasp/depthwise_conv2d.py new file mode 100644 index 000000000000..1446556dc207 --- /dev/null +++ b/topi/python/topi/rasp/depthwise_conv2d.py @@ -0,0 +1,64 @@ +# pylint: disable=invalid-name,unused-variable +"""Schedule for depthwise_conv2d with auto fusion""" +import tvm +from .. import tag + +def _schedule(s, data, data_pad, kernel, output, last): + A, B, C = data, kernel, output + A0 = data_pad + C0 = last + + _, c, h, w = s[C].op.axis + dh, dw = s[C].op.reduce_axis + + oh, ow, ih, iw = s[C].tile(h, w, 2, 4) + s[C].reorder(oh, ow, dh, dw, ih, iw) + s[C].unroll(ih) + s[C].vectorize(iw) + + s[C].parallel(c) + s[C].pragma(c, "parallel_launch_point") + s[C].pragma(c, "parallel_stride_pattern") + s[C].pragma(c, "parallel_barrier_when_finish") + return s + + + +def schedule_depthwise_conv2d(outs): + """Schedule for depthwise_conv2d nchw forward. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of depthwise_conv2d + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for depthwise_conv2d nchw. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + + def traverse(op): + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + for tensor in op.input_tensors: + if tensor.op.input_tensors: + traverse(tensor.op) + # schedule depthwise_conv2d + if op.tag == 'depthwise_conv2d_nchw': + output = op.output(0) + kernel = op.input_tensors[1] + data = op.input_tensors[0] + data_pad = None + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + _schedule(s, data, data_pad, kernel, output, outs[0]) + + traverse(outs[0].op) + return s