From 2b586d25c2637eadc59365bd01817278d1651202 Mon Sep 17 00:00:00 2001 From: wetliu Date: Sun, 13 Aug 2017 19:54:20 -0700 Subject: [PATCH 1/8] rename the nchw and pass the unit test; going to do it for nhwc depthwise --- topi/python/topi/cuda/__init__.py | 3 +- ...conv2d_map.py => depthwise_conv2d_nchw.py} | 8 +- .../python/topi/cuda/depthwise_conv2d_nhwc.py | 74 +++++++++++++++++ topi/python/topi/nn/conv.py | 83 ++++++++++++++++++- .../python/test_topi_depthwise_conv2d_map.py | 12 +-- 5 files changed, 166 insertions(+), 14 deletions(-) rename topi/python/topi/cuda/{depthwise_conv2d_map.py => depthwise_conv2d_nchw.py} (95%) create mode 100644 topi/python/topi/cuda/depthwise_conv2d_nhwc.py diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 6456441d6038..76ac193004da 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -3,4 +3,5 @@ from __future__ import absolute_import as _abs from .conv2d_hwcn_map import schedule_conv2d_hwcn_map -from .depthwise_conv2d_map import schedule_depthwise_conv2d_map +from .depthwise_conv2d_nhwc import schedule_depthwise_conv2d_nhwc +from .depthwise_conv2d_nchw import schedule_depthwise_conv2d_nchw diff --git a/topi/python/topi/cuda/depthwise_conv2d_map.py b/topi/python/topi/cuda/depthwise_conv2d_nchw.py similarity index 95% rename from topi/python/topi/cuda/depthwise_conv2d_map.py rename to topi/python/topi/cuda/depthwise_conv2d_nchw.py index e1900d678c3d..1e11223881cb 100644 --- a/topi/python/topi/cuda/depthwise_conv2d_map.py +++ b/topi/python/topi/cuda/depthwise_conv2d_nchw.py @@ -1,10 +1,10 @@ # pylint: disable=invalid-name -"""Schedule for depthwise_conv2d with auto fusion""" +"""Schedule for depthwise_conv2d nchw with auto fusion""" import tvm from ..nn.util import get_const_tuple -def schedule_depthwise_conv2d_map(op): - """Schedule for depthwise_conv2d map ops. +def schedule_depthwise_conv2d_nchw(op): + """Schedule for depthwise_conv2d nchw forward ops. This include scale-shift and relu. @@ -106,7 +106,7 @@ def traverse(OP): if tensor.op.input_tensors: traverse(tensor.op) # schedule depthwise_conv2d - if OP.tag == 'depthwise_conv2d': + if OP.tag == 'depthwise_conv2d_nchw': PaddedInput = OP.input_tensors[0] Filter = OP.input_tensors[1] DepthwiseConv2d = OP.output(0) diff --git a/topi/python/topi/cuda/depthwise_conv2d_nhwc.py b/topi/python/topi/cuda/depthwise_conv2d_nhwc.py new file mode 100644 index 000000000000..8bec0cdef244 --- /dev/null +++ b/topi/python/topi/cuda/depthwise_conv2d_nhwc.py @@ -0,0 +1,74 @@ +# pylint: disable=invalid-name +"""Schedule for depthwise_conv2d nhwc with auto fusion""" +import tvm +from ..nn.util import get_const_tuple + +def schedule_depthwise_conv2d_nhwc(op): + """Schedule for depthwise_conv2d nhwc forward ops. + + This include scale-shift and relu. + + Parameters + ---------- + op: Operation + The symbolic description of the operation, should be depthwise_conv2d or + depthwise_conv2d followed by a sequence of one-to-one-mapping operators. + + Returns + ------- + s: Schedule + The computation schedule for the op. 
+ """ + s = tvm.create_schedule(op) + def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d): + + temp = PaddedInput + Filter = Filter + Output = DepthwiseConv2d + + s = tvm.create_schedule(Output.op) + s[temp].compute_inline() + + FS = s.cache_read(Filter, "shared", [Output]) + + block_x = tvm.thread_axis("blockIdx.x") + thread_x = tvm.thread_axis("threadIdx.x") + + b, h, w, c = s[Output].op.axis + + h_val = tvm.ir_pass.Simplify(Output.shape[1]).value + b_val = tvm.ir_pass.Simplify(Output.shape[0]).value + ic_val = tvm.ir_pass.Simplify(temp.shape[3]).value + xoc, xic = s[Output].split(c, factor=ic_val) + s[Output].reorder(xoc, b, h, w, xic) + xo, yo, xi, yi = s[Output].tile(h, w, x_factor=2, y_factor=2) + fused = s[Output].fuse(yo, xo) + fused = s[Output].fuse(fused, b) + fused = s[Output].fuse(fused, xoc) + + s[Output].bind(fused, block_x) + s[Output].bind(xic, thread_x) + + yi, xi, ci, fi = s[FS].op.axis + s[FS].compute_at(s[Output], fused) + fused = s[FS].fuse(fi,ci) + s[FS].bind(fused, thread_x) + return s + + def traverse(OP): + # inline all one-to-one-mapping operators except the last stage (output) + if OP.tag == 'ewise' or OP.tag == 'scale_shift': + if OP not in s.outputs: + s[OP].compute_inline() + for tensor in OP.input_tensors: + if tensor.op.input_tensors: + traverse(tensor.op) + # schedule depthwise_conv2d + if OP.tag == 'depthwise_conv2d_nhwc': + PaddedInput = OP.input_tensors[0] + Filter = OP.input_tensors[1] + DepthwiseConv2d = OP.output(0) + schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d) + + traverse(op) + return s diff --git a/topi/python/topi/nn/conv.py b/topi/python/topi/nn/conv.py index 4c233aa1568a..81c4a0c17324 100644 --- a/topi/python/topi/nn/conv.py +++ b/topi/python/topi/nn/conv.py @@ -71,10 +71,9 @@ def conv2d_hwcn(Input, Filter, stride, padding): name='Conv2dOutput') return Output - -@tvm.tag_scope(tag="depthwise_conv2d") -def depthwise_conv2d(Input, Filter, Stride, padding): - """Depthwise convolution operator. +@tvm.tag_scope(tag="depthwise_conv2d_nchw") +def depthwise_conv2d_nchw(Input, Filter, Stride, padding): + """Depthwise convolution nchw forward operator. Parameters ---------- @@ -141,3 +140,79 @@ def depthwise_conv2d(Input, Filter, Stride, padding): axis=[di, dj]), name='DepthwiseConv2d') return Output + +@tvm.tag_scope(tag="depthwise_conv2d_nhwc") +def depthwise_conv2d_nhwc(Input, Filter, Stride, padding): + """Depthwise convolution nhwc forward operator. 
+ + Parameters + ---------- + Input : tvm.Tensor + 4-D with shape [batch, in_height, in_width, in_channel] + + Filter : tvm.Tensor + 4-D with shape [filter_height, filter_width, in_channel, channel_multiplier] + + Stride : tvm.Tensor + 1-D of size 2 + + padding : str + 'VALID' or 'SAME' + + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, out_height, out_width, out_channel] + """ + in_shape = get_const_tuple(Input.shape) + batch = in_shape[0] + in_height = in_shape[1] + in_width = in_shape[2] + in_channel = in_shape[3] + + filter_shape = get_const_tuple(Filter.shape) + filter_height = filter_shape[0] + filter_width = filter_shape[1] + filter_channel = filter_shape[2] + channel_multiplier = filter_shape[3] + + stride_h = Stride.asnumpy()[0] + stride_w = Stride.asnumpy()[1] + + # calculate output shape + if padding == 'VALID': + out_channel = in_channel * channel_multiplier + out_height = (in_height - filter_height) // stride_h + 1 + out_width = (in_width - filter_width) // stride_w + 1 + pad_along_height = 0 + pad_along_width = 0 + if padding == 'SAME': + out_channel = in_channel * channel_multiplier + out_height = np.int(np.ceil(float(in_height) / float(stride_h))) + out_width = np.int(np.ceil(float(in_width) / float(stride_w))) + pad_along_height = np.int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) + pad_along_width = np.int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) + + height_after_pad = in_height + pad_along_height + width_after_pad = in_width + pad_along_width + pad_top = np.int(np.ceil(float(pad_along_height) / 2)) + pad_left = np.int(np.ceil(float(pad_along_width) / 2)) + + # padding stage + PaddedInput = tvm.compute( + (batch, height_after_pad, width_after_pad, in_channel), + lambda b, i, j, c: tvm.select( + tvm.all(i >= pad_top, i - pad_top < in_height, j >= pad_left, j - pad_left < in_width), + Input[b, i - pad_top, j - pad_left, c], tvm.const(0.0)), + name="PaddedInput") + + # depthconv stage + di = tvm.reduce_axis((0, filter_height), name='di') + dj = tvm.reduce_axis((0, filter_width), name='dj') + Output = tvm.compute( + (batch, out_height, out_width, out_channel), + lambda b, i, j, c: tvm.sum( + PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier] * Filter[di, dj, c/channel_multiplier, c%channel_multiplier], + axis=[di, dj]), + name='DepthwiseConv2d') + return Output diff --git a/topi/tests/python/test_topi_depthwise_conv2d_map.py b/topi/tests/python/test_topi_depthwise_conv2d_map.py index 4069bd43f2da..cb182c5072c3 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d_map.py +++ b/topi/tests/python/test_topi_depthwise_conv2d_map.py @@ -3,7 +3,7 @@ import numpy as np from scipy import signal from topi.nn.util import get_const_tuple -from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map +from topi.cuda.depthwise_conv2d_nchw import schedule_depthwise_conv2d_nchw def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding): in_width = in_height @@ -17,13 +17,13 @@ def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_mul Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') # declare - DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, Stride, padding) + DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, Stride, padding) ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift) 
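    # Note on the op chain declared here: scale_shift applies the inference-time batch-norm
    # (per-channel scale and shift) to the conv output and relu follows; the CUDA schedules
    # below inline these elementwise stages into the depthwise_conv2d kernel (see the
    # 'ewise' / 'scale_shift' branch of traverse() in the cuda schedule files).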
Relu = topi.nn.relu(ScaleShift) # schedule - s1 = schedule_depthwise_conv2d_map(DepthwiseConv2d.op) - s2 = schedule_depthwise_conv2d_map(ScaleShift.op) - s3 = schedule_depthwise_conv2d_map(Relu.op) + s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d.op) + s2 = schedule_depthwise_conv2d_nchw(ScaleShift.op) + s3 = schedule_depthwise_conv2d_nchw(Relu.op) def depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np): out_shape = get_const_tuple(DepthwiseConv2d.shape) @@ -97,6 +97,8 @@ def check_device(device): check_device("metal") + + def test_depthwise_conv2d_map(): depthwise_conv2d_map_with_workload(1, 728, 64, 1, 3, 1, "SAME") depthwise_conv2d_map_with_workload(1, 728, 32, 1, 3, 1, "SAME") From b65d03f01afaa09aa52b0dc3635f4b369faa6183 Mon Sep 17 00:00:00 2001 From: wetliu Date: Sun, 13 Aug 2017 22:35:45 -0700 Subject: [PATCH 2/8] bug with fusion --- .../python/topi/cuda/depthwise_conv2d_nhwc.py | 8 +- .../python/test_topi_depthwise_conv2d_map.py | 119 ++++++++++++++++-- 2 files changed, 111 insertions(+), 16 deletions(-) diff --git a/topi/python/topi/cuda/depthwise_conv2d_nhwc.py b/topi/python/topi/cuda/depthwise_conv2d_nhwc.py index 8bec0cdef244..34eb560cfc9f 100644 --- a/topi/python/topi/cuda/depthwise_conv2d_nhwc.py +++ b/topi/python/topi/cuda/depthwise_conv2d_nhwc.py @@ -20,13 +20,8 @@ def schedule_depthwise_conv2d_nhwc(op): The computation schedule for the op. """ s = tvm.create_schedule(op) - def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d): + def schedule_depthwise_conv2d(temp, Filter, Output): - temp = PaddedInput - Filter = Filter - Output = DepthwiseConv2d - - s = tvm.create_schedule(Output.op) s[temp].compute_inline() FS = s.cache_read(Filter, "shared", [Output]) @@ -53,7 +48,6 @@ def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d): s[FS].compute_at(s[Output], fused) fused = s[FS].fuse(fi,ci) s[FS].bind(fused, thread_x) - return s def traverse(OP): # inline all one-to-one-mapping operators except the last stage (output) diff --git a/topi/tests/python/test_topi_depthwise_conv2d_map.py b/topi/tests/python/test_topi_depthwise_conv2d_map.py index cb182c5072c3..2be1544090af 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d_map.py +++ b/topi/tests/python/test_topi_depthwise_conv2d_map.py @@ -4,8 +4,9 @@ from scipy import signal from topi.nn.util import get_const_tuple from topi.cuda.depthwise_conv2d_nchw import schedule_depthwise_conv2d_nchw +from topi.cuda.depthwise_conv2d_nhwc import schedule_depthwise_conv2d_nhwc -def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding): +def depthwise_conv2d_map_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding): in_width = in_height filter_channel = in_channel filter_width = filter_height @@ -96,18 +97,118 @@ def check_device(device): check_device("cuda") check_device("metal") +def depthwise_conv2d_map_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding): + in_width = in_height + filter_channel = in_channel + filter_width = filter_height + stride_w = stride_h + # placeholder + Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input') + Filter = tvm.placeholder((filter_height, filter_width, filter_channel, channel_multiplier), name='Filter') + Stride = tvm.nd.array(np.array([stride_h, stride_w])) + Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') + Shift = 
tvm.placeholder((in_channel * channel_multiplier,), name='Shift') + # declare + DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, Stride, padding) + ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift) + Relu = topi.nn.relu(ScaleShift) + # schedule + s1 = schedule_depthwise_conv2d_nhwc(DepthwiseConv2d.op) + s2 = schedule_depthwise_conv2d_nhwc(ScaleShift.op) + s3 = schedule_depthwise_conv2d_nhwc(Relu.op) + + def depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np): + out_shape = get_const_tuple(DepthwiseConv2d.shape) + out_height = out_shape[1] + out_width = out_shape[2] + out_channel = out_shape[3] + depthwise_conv2d_scipy = np.zeros((batch, out_height, out_width, out_channel), dtype=DepthwiseConv2d.dtype) + scale_shift_scipy = np.zeros((batch, out_height, out_width, out_channel), dtype=ScaleShift.dtype) + relu_scipy = np.zeros((batch, out_height, out_width, out_channel), dtype=Relu.dtype) + if padding == 'SAME': + pad_top_tvm = np.int(np.ceil(float(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) / 2)) + pad_left_tvm = np.int(np.ceil(float(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) / 2)) + pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2)) + pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2)) + index_h = pad_top_scipy - pad_top_tvm + index_w = pad_left_scipy - pad_left_tvm + for i in range(batch): + for j in range(out_channel): + depthwise_conv2d_scipy[i,:,:,j] = signal.convolve2d(input_np[i,:,:,j//channel_multiplier], + np.rot90(filter_np[:,:,j//channel_multiplier,j%channel_multiplier], 2), + mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w] + if padding == 'VALID': + for i in range(batch): + for j in range(out_channel): + depthwise_conv2d_scipy[i,:,:,j] = signal.convolve2d(input_np[i,:,:,j//channel_multiplier], + np.rot90(filter_np[:,:,j//channel_multiplier,j%channel_multiplier], 2), + mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_height + 1):stride_w] + for c in range(out_channel): + scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c] + relu_scipy[:,:,:,:] = np.maximum(scale_shift_scipy[:,:,:,:], 0) + return depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy + def check_device(device): + if not tvm.module.enabled(device): + print("Skip because %s is not enabled" % device) + return + ctx = tvm.context(device, 0) + # build the kernels + f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device) + #f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device) + #f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device) + # prepare data + input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype) + filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype) + input_tvm = tvm.nd.array(input_np, ctx) + filter_tvm = tvm.nd.array(filter_np, ctx) + scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype) + shift_np = np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype) + scale_tvm = tvm.nd.array(scale_np, ctx) + shift_tvm = tvm.nd.array(shift_np, ctx) + depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx) + scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx) + relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx) + # launch 
kernel 1 (depthwise_conv2d) + timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1) + tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean + # launch kernel 2 (depthwise_conv2d + scale_shift) + #timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1) + #tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean + # launch kernel 3 (depthwise_conv2d + scale_shift + relu) + #timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1) + #tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean + # correctness with scipy + depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy = depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np) + np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5) + #np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5) + #np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) + + check_device("opencl") + check_device("cuda") + check_device("metal") def test_depthwise_conv2d_map(): - depthwise_conv2d_map_with_workload(1, 728, 64, 1, 3, 1, "SAME") - depthwise_conv2d_map_with_workload(1, 728, 32, 1, 3, 1, "SAME") - depthwise_conv2d_map_with_workload(4, 256, 64, 2, 5, 2, "SAME") - depthwise_conv2d_map_with_workload(4, 256, 32, 2, 5, 2, "SAME") - depthwise_conv2d_map_with_workload(1, 728, 64, 1, 3, 1, "VALID") - depthwise_conv2d_map_with_workload(1, 728, 32, 1, 3, 1, "VALID") - depthwise_conv2d_map_with_workload(4, 256, 64, 2, 5, 2, "VALID") - depthwise_conv2d_map_with_workload(4, 256, 32, 2, 5, 2, "VALID") + + depthwise_conv2d_map_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME") + """ + depthwise_conv2d_map_with_workload_nchw(1, 728, 32, 1, 3, 1, "SAME") + depthwise_conv2d_map_with_workload_nchw(4, 256, 64, 2, 5, 2, "SAME") + depthwise_conv2d_map_with_workload_nchw(4, 256, 32, 2, 5, 2, "SAME") + depthwise_conv2d_map_with_workload_nchw(1, 728, 64, 1, 3, 1, "VALID") + depthwise_conv2d_map_with_workload_nchw(1, 728, 32, 1, 3, 1, "VALID") + depthwise_conv2d_map_with_workload_nchw(4, 256, 64, 2, 5, 2, "VALID") + depthwise_conv2d_map_with_workload_nchw(4, 256, 32, 2, 5, 2, "VALID") + """ + depthwise_conv2d_map_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME") + depthwise_conv2d_map_with_workload_nhwc(1, 728, 32, 1, 3, 1, "SAME") + depthwise_conv2d_map_with_workload_nhwc(4, 256, 64, 2, 5, 2, "SAME") + depthwise_conv2d_map_with_workload_nhwc(4, 256, 32, 2, 5, 2, "SAME") + depthwise_conv2d_map_with_workload_nhwc(1, 728, 64, 1, 3, 1, "VALID") + depthwise_conv2d_map_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID") + depthwise_conv2d_map_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID") + depthwise_conv2d_map_with_workload_nhwc(4, 256, 32, 2, 5, 2, "VALID") if __name__ == "__main__": From 68f1292443a012308ae1e41374f3c1267c273446 Mon Sep 17 00:00:00 2001 From: wetliu Date: Mon, 14 Aug 2017 14:16:03 -0700 Subject: [PATCH 3/8] nchw works fine; nhwc float32 problem remains --- .../python/topi/cuda/depthwise_conv2d_nchw.py | 2 +- .../python/topi/cuda/depthwise_conv2d_nhwc.py | 2 +- topi/python/topi/nn/mapping.py | 26 +++++++++++++++++-- .../python/test_topi_depthwise_conv2d_map.py | 26 +++++++++---------- 4 files changed, 39 insertions(+), 17 deletions(-) diff --git a/topi/python/topi/cuda/depthwise_conv2d_nchw.py b/topi/python/topi/cuda/depthwise_conv2d_nchw.py index ef2c507e6388..39136bc24ce7 100644 --- a/topi/python/topi/cuda/depthwise_conv2d_nchw.py +++ b/topi/python/topi/cuda/depthwise_conv2d_nchw.py @@ -99,7 +99,7 @@ def 
schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d): def traverse(OP): # inline all one-to-one-mapping operators except the last stage (output) - if OP.tag == 'ewise' or OP.tag == 'scale_shift': + if OP.tag == 'ewise' or OP.tag == 'scale_shift_nchw': if OP not in s.outputs: s[OP].compute_inline() for tensor in OP.input_tensors: diff --git a/topi/python/topi/cuda/depthwise_conv2d_nhwc.py b/topi/python/topi/cuda/depthwise_conv2d_nhwc.py index 34eb560cfc9f..bbe8e94dbae9 100644 --- a/topi/python/topi/cuda/depthwise_conv2d_nhwc.py +++ b/topi/python/topi/cuda/depthwise_conv2d_nhwc.py @@ -51,7 +51,7 @@ def schedule_depthwise_conv2d(temp, Filter, Output): def traverse(OP): # inline all one-to-one-mapping operators except the last stage (output) - if OP.tag == 'ewise' or OP.tag == 'scale_shift': + if OP.tag == 'ewise' or OP.tag == 'scale_shift_nhwc': if OP not in s.outputs: s[OP].compute_inline() for tensor in OP.input_tensors: diff --git a/topi/python/topi/nn/mapping.py b/topi/python/topi/nn/mapping.py index 6affd68f0406..ee5ec421bcd0 100644 --- a/topi/python/topi/nn/mapping.py +++ b/topi/python/topi/nn/mapping.py @@ -3,8 +3,8 @@ from __future__ import absolute_import as _abs import tvm -@tvm.tag_scope(tag="scale_shift") -def scale_shift(Input, Scale, Shift): +@tvm.tag_scope(tag="scale_shift_nchw") +def scale_shift_nchw(Input, Scale, Shift): """Batch normalization operator in inference. Parameters @@ -24,3 +24,25 @@ def scale_shift(Input, Scale, Shift): Output tensor, layout is NCHW """ return tvm.compute(Input.shape, lambda b, c, i, j: Input[b, c, i, j] * Scale[c] + Shift[c], name='ScaleShift') + +@tvm.tag_scope(tag="scale_shift_nhwc") +def scale_shift_nhwc(Input, Scale, Shift): + """Batch normalization operator in inference. + + Parameters + ---------- + Input : tvm.Tensor + Input tensor, layout is NHWC + + Scale : tvm.Tensor + Scale tensor, 1-D of size channel number + + Shift : tvm.Tensor + Shift tensor, 1-D of size channel number + + Returns + ------- + Output : tvm.Tensor + Output tensor, layout is NHWC + """ + return tvm.compute(Input.shape, lambda b, i, j, c: Input[b, i, j, c] * Scale[c] + Shift[c], name='ScaleShift') diff --git a/topi/tests/python/test_topi_depthwise_conv2d_map.py b/topi/tests/python/test_topi_depthwise_conv2d_map.py index 5a19c78cecfa..373e59aa1c3c 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d_map.py +++ b/topi/tests/python/test_topi_depthwise_conv2d_map.py @@ -19,7 +19,7 @@ def depthwise_conv2d_map_with_workload_nchw(batch, in_channel, in_height, channe Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') # declare DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, Stride, padding) - ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift) + ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift) Relu = topi.nn.relu(ScaleShift) # schedule s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d.op) @@ -110,7 +110,7 @@ def depthwise_conv2d_map_with_workload_nhwc(batch, in_channel, in_height, channe Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') # declare DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, Stride, padding) - ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift) + ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift) Relu = topi.nn.relu(ScaleShift) # schedule s1 = schedule_depthwise_conv2d_nhwc(DepthwiseConv2d.op) @@ -155,8 +155,9 @@ def check_device(device): ctx = tvm.context(device, 0) # build the kernels f1 = 
tvm.build(s1, [Input, Filter, DepthwiseConv2d], device) - #f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device) - #f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device) + print(tvm.lower(s2, [Input, Filter, Scale, Shift, ScaleShift], simple_mode=True)) + f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device) + f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device) # prepare data input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype) filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype) @@ -173,16 +174,16 @@ def check_device(device): timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1) tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean # launch kernel 2 (depthwise_conv2d + scale_shift) - #timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1) - #tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean + timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1) + tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean # launch kernel 3 (depthwise_conv2d + scale_shift + relu) - #timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1) - #tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean + timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1) + tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean # correctness with scipy depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy = depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np) np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5) - #np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5) - #np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) + np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5) + np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) check_device("opencl") check_device("cuda") @@ -190,9 +191,8 @@ def check_device(device): def test_depthwise_conv2d_map(): - + print("testing nchw layout") depthwise_conv2d_map_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME") - """ depthwise_conv2d_map_with_workload_nchw(1, 728, 32, 1, 3, 1, "SAME") depthwise_conv2d_map_with_workload_nchw(4, 256, 64, 2, 5, 2, "SAME") depthwise_conv2d_map_with_workload_nchw(4, 256, 32, 2, 5, 2, "SAME") @@ -200,7 +200,7 @@ def test_depthwise_conv2d_map(): depthwise_conv2d_map_with_workload_nchw(1, 728, 32, 1, 3, 1, "VALID") depthwise_conv2d_map_with_workload_nchw(4, 256, 64, 2, 5, 2, "VALID") depthwise_conv2d_map_with_workload_nchw(4, 256, 32, 2, 5, 2, "VALID") - """ + print("testing nhwc layout") depthwise_conv2d_map_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME") depthwise_conv2d_map_with_workload_nhwc(1, 728, 32, 1, 3, 1, "SAME") depthwise_conv2d_map_with_workload_nhwc(4, 256, 64, 2, 5, 2, "SAME") From a9a297506649a0d71c60a94acf661254ac5bc08d Mon Sep 17 00:00:00 2001 From: wetliu Date: Mon, 14 Aug 2017 16:25:21 -0700 Subject: [PATCH 4/8] still cannot bind them together --- .../python/topi/cuda/depthwise_conv2d_nhwc.py | 20 +++++++++++-------- .../python/test_topi_depthwise_conv2d_map.py | 6 ++++-- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/topi/python/topi/cuda/depthwise_conv2d_nhwc.py b/topi/python/topi/cuda/depthwise_conv2d_nhwc.py index bbe8e94dbae9..e49a5e547322 100644 --- a/topi/python/topi/cuda/depthwise_conv2d_nhwc.py +++ 
b/topi/python/topi/cuda/depthwise_conv2d_nhwc.py @@ -20,12 +20,16 @@ def schedule_depthwise_conv2d_nhwc(op): The computation schedule for the op. """ s = tvm.create_schedule(op) - def schedule_depthwise_conv2d(temp, Filter, Output): + def schedule_depthwise_conv2d(temp, Filter, DepthwiseConv2d): s[temp].compute_inline() - - FS = s.cache_read(Filter, "shared", [Output]) - + if DepthwiseConv2d.op in s.outputs: + Output = DepthwiseConv2d + else: + Output = op.output(0) + #s[DepthwiseConv2d].set_scope("local") + #FS = s.cache_read(Filter, "shared", [DepthwiseConv2d]) + block_x = tvm.thread_axis("blockIdx.x") thread_x = tvm.thread_axis("threadIdx.x") @@ -44,10 +48,10 @@ def schedule_depthwise_conv2d(temp, Filter, Output): s[Output].bind(fused, block_x) s[Output].bind(xic, thread_x) - yi, xi, ci, fi = s[FS].op.axis - s[FS].compute_at(s[Output], fused) - fused = s[FS].fuse(fi,ci) - s[FS].bind(fused, thread_x) + #yi, xi, ci, fi = s[FS].op.axis + #s[FS].compute_at(s[Output], fused) + #fused = s[FS].fuse(fi,ci) + #s[FS].bind(fused, thread_x) def traverse(OP): # inline all one-to-one-mapping operators except the last stage (output) diff --git a/topi/tests/python/test_topi_depthwise_conv2d_map.py b/topi/tests/python/test_topi_depthwise_conv2d_map.py index 373e59aa1c3c..98880c4fa0bf 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d_map.py +++ b/topi/tests/python/test_topi_depthwise_conv2d_map.py @@ -63,7 +63,7 @@ def check_device(device): return ctx = tvm.context(device, 0) # build the kernels - f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device) + f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device) f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device) f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device) # prepare data @@ -155,7 +155,7 @@ def check_device(device): ctx = tvm.context(device, 0) # build the kernels f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device) - print(tvm.lower(s2, [Input, Filter, Scale, Shift, ScaleShift], simple_mode=True)) + print(tvm.lower(s2, [Input, Filter, Scale, Shift, ScaleShift],simple_mode=True)) f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device) f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device) # prepare data @@ -193,6 +193,7 @@ def check_device(device): def test_depthwise_conv2d_map(): print("testing nchw layout") depthwise_conv2d_map_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME") + """ depthwise_conv2d_map_with_workload_nchw(1, 728, 32, 1, 3, 1, "SAME") depthwise_conv2d_map_with_workload_nchw(4, 256, 64, 2, 5, 2, "SAME") depthwise_conv2d_map_with_workload_nchw(4, 256, 32, 2, 5, 2, "SAME") @@ -200,6 +201,7 @@ def test_depthwise_conv2d_map(): depthwise_conv2d_map_with_workload_nchw(1, 728, 32, 1, 3, 1, "VALID") depthwise_conv2d_map_with_workload_nchw(4, 256, 64, 2, 5, 2, "VALID") depthwise_conv2d_map_with_workload_nchw(4, 256, 32, 2, 5, 2, "VALID") + """ print("testing nhwc layout") depthwise_conv2d_map_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME") depthwise_conv2d_map_with_workload_nhwc(1, 728, 32, 1, 3, 1, "SAME") From f6108c44848ccdf9608fb9bbda7429fafbd72e68 Mon Sep 17 00:00:00 2001 From: wetliu Date: Mon, 14 Aug 2017 20:20:38 -0700 Subject: [PATCH 5/8] fusion works --- .../python/topi/cuda/depthwise_conv2d_nhwc.py | 20 ++++++++++++------- .../python/test_topi_depthwise_conv2d_map.py | 6 ++---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/topi/python/topi/cuda/depthwise_conv2d_nhwc.py b/topi/python/topi/cuda/depthwise_conv2d_nhwc.py index 
e49a5e547322..b5780fa0f398 100644 --- a/topi/python/topi/cuda/depthwise_conv2d_nhwc.py +++ b/topi/python/topi/cuda/depthwise_conv2d_nhwc.py @@ -23,12 +23,13 @@ def schedule_depthwise_conv2d_nhwc(op): def schedule_depthwise_conv2d(temp, Filter, DepthwiseConv2d): s[temp].compute_inline() + FS = s.cache_read(Filter, "shared", [DepthwiseConv2d]) if DepthwiseConv2d.op in s.outputs: Output = DepthwiseConv2d + CL = s.cache_write(DepthwiseConv2d, "local") else: Output = op.output(0) - #s[DepthwiseConv2d].set_scope("local") - #FS = s.cache_read(Filter, "shared", [DepthwiseConv2d]) + s[DepthwiseConv2d].set_scope("local") block_x = tvm.thread_axis("blockIdx.x") thread_x = tvm.thread_axis("threadIdx.x") @@ -47,11 +48,16 @@ def schedule_depthwise_conv2d(temp, Filter, DepthwiseConv2d): s[Output].bind(fused, block_x) s[Output].bind(xic, thread_x) - - #yi, xi, ci, fi = s[FS].op.axis - #s[FS].compute_at(s[Output], fused) - #fused = s[FS].fuse(fi,ci) - #s[FS].bind(fused, thread_x) + + if DepthwiseConv2d.op in s.outputs: + s[CL].compute_at(s[Output], xic) + else: + s[DepthwiseConv2d].compute_at(s[Output], xic) + + yi, xi, ci, fi = s[FS].op.axis + s[FS].compute_at(s[Output], fused) + fused = s[FS].fuse(fi,ci) + s[FS].bind(fused, thread_x) def traverse(OP): # inline all one-to-one-mapping operators except the last stage (output) diff --git a/topi/tests/python/test_topi_depthwise_conv2d_map.py b/topi/tests/python/test_topi_depthwise_conv2d_map.py index 98880c4fa0bf..cad2a766e57b 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d_map.py +++ b/topi/tests/python/test_topi_depthwise_conv2d_map.py @@ -155,7 +155,6 @@ def check_device(device): ctx = tvm.context(device, 0) # build the kernels f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device) - print(tvm.lower(s2, [Input, Filter, Scale, Shift, ScaleShift],simple_mode=True)) f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device) f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device) # prepare data @@ -191,9 +190,9 @@ def check_device(device): def test_depthwise_conv2d_map(): + print("testing nchw layout") - depthwise_conv2d_map_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME") - """ + depthwise_conv2d_map_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME") depthwise_conv2d_map_with_workload_nchw(1, 728, 32, 1, 3, 1, "SAME") depthwise_conv2d_map_with_workload_nchw(4, 256, 64, 2, 5, 2, "SAME") depthwise_conv2d_map_with_workload_nchw(4, 256, 32, 2, 5, 2, "SAME") @@ -201,7 +200,6 @@ def test_depthwise_conv2d_map(): depthwise_conv2d_map_with_workload_nchw(1, 728, 32, 1, 3, 1, "VALID") depthwise_conv2d_map_with_workload_nchw(4, 256, 64, 2, 5, 2, "VALID") depthwise_conv2d_map_with_workload_nchw(4, 256, 32, 2, 5, 2, "VALID") - """ print("testing nhwc layout") depthwise_conv2d_map_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME") depthwise_conv2d_map_with_workload_nhwc(1, 728, 32, 1, 3, 1, "SAME") From 83703c182cd1e00af8d21669ba07324ae30a4e36 Mon Sep 17 00:00:00 2001 From: weitang Date: Tue, 15 Aug 2017 18:18:51 +0000 Subject: [PATCH 6/8] syntax fix --- topi/python/topi/cuda/depthwise_conv2d.py | 2 +- topi/python/topi/testing/depthwise_conv2d_python.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py index a2e10af9bdbb..73bfb989b698 100644 --- a/topi/python/topi/cuda/depthwise_conv2d.py +++ b/topi/python/topi/cuda/depthwise_conv2d.py @@ -165,7 +165,7 @@ def _schedule(temp, Filter, DepthwiseConv2d): yi, xi, ci, fi = s[FS].op.axis 
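        # FS is the filter tile staged in shared memory via cache_read (as introduced for the
        # nhwc schedule earlier in this series); it is computed at the fused block-level axis
        # of Output, and its innermost axes are fused and bound to threadIdx.x so the threads
        # of a block cooperatively load the filter.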
s[FS].compute_at(s[Output], fused) - fused = s[FS].fuse(fi,ci) + fused = s[FS].fuse(fi, ci) s[FS].bind(fused, thread_x) def traverse(OP): diff --git a/topi/python/topi/testing/depthwise_conv2d_python.py b/topi/python/topi/testing/depthwise_conv2d_python.py index c6e75b265d9e..84784f97c2b8 100644 --- a/topi/python/topi/testing/depthwise_conv2d_python.py +++ b/topi/python/topi/testing/depthwise_conv2d_python.py @@ -100,7 +100,7 @@ def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding): out_channel = in_channel * channel_multiplier out_height = np.int(np.ceil(float(in_height) / float(stride_h))) out_width = np.int(np.ceil(float(in_width) / float(stride_w))) - output_np = np.zeros((batch, out_height, out_width,out_channel)) + output_np = np.zeros((batch, out_height, out_width, out_channel)) pad_along_height = np.int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) pad_along_width = np.int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) pad_top_tvm = np.int(np.ceil(float(pad_along_height) / 2)) @@ -111,7 +111,7 @@ def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding): index_w = pad_left_scipy - pad_left_tvm for i in range(batch): for j in range(out_channel): - output_np[i, :, :, j] = signal.convolve2d(input_np[i, :, :, j//channel_multiplier], \ + output_np[i, :, :, j] = signal.convolve2d(input_np[i, :, :, j//channel_multiplier], \ np.rot90(filter_np[:, :, j//channel_multiplier, j%channel_multiplier], 2), \ mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w] From c1ae53d352ddcf81b257dc7ed344dc24b3a3b2f3 Mon Sep 17 00:00:00 2001 From: weitang Date: Tue, 15 Aug 2017 20:52:15 +0000 Subject: [PATCH 7/8] all bugs fixed; test cases pass --- topi/include/topi/nn.h | 35 ++++++++++++++++++- topi/python/topi/nn/convolution.py | 10 +++--- topi/recipe/conv/depthwise_conv2d_test.py | 4 +-- .../python/test_topi_depthwise_conv2d.py | 1 - 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index 50f592933c28..326d96a99bc4 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -287,7 +287,7 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I, int stride_h = 1, int stride_w = 1, std::string name = "tensor", - std::string tag = kDepthwiseConv2d_nchw) { + std::string tag = kDepthwiseConv2dNCHW) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; @@ -313,6 +313,39 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I, return tvm::compute(output_shape, l, name, tag); } +inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I, + const tvm::Tensor& W, + int pad_h = 0, + int pad_w = 0, + int stride_h = 1, + int stride_w = 1, + std::string name = "tensor", + std::string tag = kDepthwiseConv2d_nhwc) { + CHECK_EQ(4, I->shape.size()); + CHECK_EQ(4, W->shape.size()); + auto pH = I->shape[1]; + auto pW = I->shape[2]; + auto pCM = W->shape[1]; // channel_multiplier + tvm::Array output_shape{ + I->shape[0], // B + (I->shape[1] - W->shape[1] + 2 * pad_h) / stride_h + 1, // H + (I->shape[2] - W->shape[2] + 2 * pad_w) / stride_w + 1, // W + W->shape[3], // O + }; + auto i = tvm::reduce_axis(tvm::Range{0, I->shape[3]}, "i"); + auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[0]}, "kh"); + auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[1]}, "kw"); + auto T = (pad_h == 0 && pad_w == 0) + ? 
I + : pad(I, {tvm::Expr(0), tvm::Expr(0), pad_h, pad_w}); + auto l = [&](tvm::Var b, tvm::Var h, tvm::Var w, tvm::Var o) { + return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, i / pCM) * + W(kh, kw, i / pCM, o % pCM), + {kh, kw, i}); + }; + return tvm::compute(output_shape, l, name, tag); +} + /*! * \brief Creates an operation that performs a 2-D group convolution with * an NGCHW-layout diff --git a/topi/python/topi/nn/convolution.py b/topi/python/topi/nn/convolution.py index e043987a1abf..cd201bdabf8f 100644 --- a/topi/python/topi/nn/convolution.py +++ b/topi/python/topi/nn/convolution.py @@ -152,7 +152,7 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding): (PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] * Filter[c/channel_multiplier, c%channel_multiplier, di, dj]), axis=[di, dj]), - name='DepthwiseConv2d', tag="depthwise_conv2d") + name='DepthwiseConv2d', tag="depthwise_conv2d_nchw") return Output def depthwise_conv2d_nhwc(Input, Filter, stride, padding): @@ -188,8 +188,8 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding): out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1) # padding stage - pad_before = [0, 0, pad_top, pad_left] - pad_after = [0, 0, pad_down, pad_right] + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") # depthconv stage di = tvm.reduce_axis((0, filter_height), name='di') @@ -198,7 +198,7 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding): (batch, out_height, out_width, out_channel), lambda b, i, j, c: tvm.sum( (PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier] * - Filter[di, dj, c/channel_multiplier, c%channel_multiplie]), + Filter[di, dj, c/channel_multiplier, c%channel_multiplier]), axis=[di, dj]), - name='DepthwiseConv2d', tag="depthwise_conv2d") + name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc") return Output diff --git a/topi/recipe/conv/depthwise_conv2d_test.py b/topi/recipe/conv/depthwise_conv2d_test.py index 2eabbf61aece..0f3a14dde2aa 100644 --- a/topi/recipe/conv/depthwise_conv2d_test.py +++ b/topi/recipe/conv/depthwise_conv2d_test.py @@ -13,7 +13,7 @@ @tvm.register_func def tvm_callback_cuda_compile(code): - ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"]) + ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_37"]) # 37 for k80(ec2 instance) return ptx def write_code(code, fname): @@ -60,7 +60,6 @@ def test_depthwise_conv2d_nchw(): s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d) s2 = schedule_depthwise_conv2d_nchw(ScaleShift) s3 = schedule_depthwise_conv2d_nchw(Relu) - input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype) filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype) scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype) @@ -80,6 +79,7 @@ def check_device(device): filter_tvm = tvm.nd.array(filter_np, ctx) scale_tvm = tvm.nd.array(scale_np, ctx) shift_tvm = tvm.nd.array(shift_np, ctx) + depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),dtype=DepthwiseConv2d.dtype), ctx) scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx) relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx) diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py 
b/topi/tests/python/test_topi_depthwise_conv2d.py index 38708804b10e..5b2c633c0a26 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d.py +++ b/topi/tests/python/test_topi_depthwise_conv2d.py @@ -25,7 +25,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu s2 = schedule_depthwise_conv2d_nchw(ScaleShift) s3 = schedule_depthwise_conv2d_nchw(Relu) - print(tvm.lower(s2, [Input, Filter, Scale, Shift, ScaleShift], simple_mode=True)) input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype) filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype) scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype) From 183e145f045240f9e406dc13ce64bc7018c9959b Mon Sep 17 00:00:00 2001 From: weitang Date: Tue, 15 Aug 2017 21:45:52 +0000 Subject: [PATCH 8/8] minor fix on nn.h --- topi/include/topi/nn.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index 326d96a99bc4..cfca85d1b704 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -320,7 +320,7 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I, int stride_h = 1, int stride_w = 1, std::string name = "tensor", - std::string tag = kDepthwiseConv2d_nhwc) { + std::string tag = kDepthwiseConv2dNHWC) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[1]; @@ -337,7 +337,7 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I, auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[1]}, "kw"); auto T = (pad_h == 0 && pad_w == 0) ? I - : pad(I, {tvm::Expr(0), tvm::Expr(0), pad_h, pad_w}); + : pad(I, {tvm::Expr(0), pad_h, pad_w, tvm::Expr(0)}); auto l = [&](tvm::Var b, tvm::Var h, tvm::Var w, tvm::Var o) { return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, i / pCM) * W(kh, kw, i / pCM, o % pCM),
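
For reference, below is a minimal end-to-end usage sketch of the NHWC path added in this series. It only restates what the tests above already exercise: declare depthwise_conv2d_nhwc, fold in scale_shift_nhwc and relu, build one fused CUDA kernel with schedule_depthwise_conv2d_nhwc, and time it. The import path, the tvm.nd.array Stride convention, and the op-based schedule call follow the earlier test patches in this series; the shapes and the "cuda" target are illustrative only.

import numpy as np
import tvm
import topi
from topi.nn.util import get_const_tuple
from topi.cuda.depthwise_conv2d_nhwc import schedule_depthwise_conv2d_nhwc

# illustrative workload (NHWC layout)
batch, in_size, in_channel, channel_multiplier, kernel, stride = 1, 32, 64, 1, 3, 1

# placeholders, as in the tests above
Input = tvm.placeholder((batch, in_size, in_size, in_channel), name='Input')
Filter = tvm.placeholder((kernel, kernel, in_channel, channel_multiplier), name='Filter')
Stride = tvm.nd.array(np.array([stride, stride]))
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')

# declare the op chain: depthwise conv -> scale_shift -> relu
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, Stride, 'SAME')
ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)

# one schedule rooted at the last stage; the elementwise stages are inlined into the conv kernel
s = schedule_depthwise_conv2d_nhwc(Relu.op)
f = tvm.build(s, [Input, Filter, Scale, Shift, Relu], 'cuda')

# run and time it, mirroring the checks in the tests
ctx = tvm.context('cuda', 0)
input_tvm = tvm.nd.array(np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype), ctx)
filter_tvm = tvm.nd.array(np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype), ctx)
scale_tvm = tvm.nd.array(np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype), ctx)
shift_tvm = tvm.nd.array(np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype), ctx)
relu_tvm = tvm.nd.array(np.zeros(get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
timer = f.time_evaluator(f.entry_name, ctx, number=1)
tcost = timer(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean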