diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index 7ec40666347d..252d92dfe16b 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -4,6 +4,6 @@
 
 from .conv2d_nchw import schedule_conv2d_nchw
 from .conv2d_hwcn import schedule_conv2d_hwcn
-from .depthwise_conv2d_map import schedule_depthwise_conv2d_map
+from .depthwise_conv2d import schedule_depthwise_conv2d
 from .reduction import schedule_reduce
 from .broadcast import schedule_broadcast_to
diff --git a/topi/python/topi/cuda/depthwise_conv2d_map.py b/topi/python/topi/cuda/depthwise_conv2d.py
similarity index 69%
rename from topi/python/topi/cuda/depthwise_conv2d_map.py
rename to topi/python/topi/cuda/depthwise_conv2d.py
index 95a5ee827e11..b7e186be555f 100644
--- a/topi/python/topi/cuda/depthwise_conv2d_map.py
+++ b/topi/python/topi/cuda/depthwise_conv2d.py
@@ -3,25 +3,24 @@
 import tvm
 from ..util import get_const_tuple
 
 
-def schedule_depthwise_conv2d_map(op):
-    """Schedule for depthwise_conv2d map ops.
-    This include scale-shift and relu.
+def schedule_depthwise_conv2d(outs):
+    """Schedule for depthwise_conv2d.
 
     Parameters
     ----------
-    op: Operation
-        The symbolic description of the operation, should be depthwise_conv2d or
-        depthwise_conv2d followed by a sequence of one-to-one-mapping operators.
+    outs: Array of Tensor
+        The computation graph description of depthwise_conv2d
+        in the format of an array of tensors.
 
     Returns
     -------
     s: Schedule
-        The computation schedule for the op.
+        The computation schedule for depthwise_conv2d.
     """
-    s = tvm.create_schedule(op)
-    def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d):
-        """Schedule for depthwise_conv2d declared in topi.nn.conv"""
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    def _schedule(PaddedInput, Filter, DepthwiseConv2d):
         out_shape = get_const_tuple(DepthwiseConv2d.shape)
         out_height = out_shape[2]
         out_width = out_shape[3]
@@ -35,27 +34,27 @@ def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d):
             Output = DepthwiseConv2d
             CL = s.cache_write(DepthwiseConv2d, "local")
         else:
-            Output = op.output(0)
+            Output = outs[0].op.output(0)
             s[DepthwiseConv2d].set_scope("local")
         # schedule parameters
-        num_thread = 8
+        num_thread_x = 8
+        num_thread_y = 8
         num_vthread_x = 1
         num_vthread_y = 1
         blocking_h = out_height
         blocking_w = out_width
-        if out_height % 48 == 0:
-            blocking_h = 48
-        elif out_height % 32 == 0:
+        if out_height % 32 == 0:
             blocking_h = 32
-        if out_width % 48 == 0:
-            blocking_w = 48
-            num_vthread_y = 3
-        elif out_width % 32 == 0:
+            num_thread_x = 2
+            num_vthread_x = 2
+        if out_width % 32 == 0:
             blocking_w = 32
+            num_thread_y = 16
+            num_vthread_y = 2
         block_x = tvm.thread_axis("blockIdx.x")
         block_y = tvm.thread_axis("blockIdx.y")
-        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
+        thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
+        thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
         thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
         thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy")
         # split and bind
@@ -65,10 +64,10 @@ def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d):
         s[Output].bind(bx, block_x)
         by1, y1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
         tvx, vxi = s[Output].split(y1i, nparts=num_vthread_x)
-        tx, xi = s[Output].split(vxi, nparts=num_thread)
+        tx, xi = s[Output].split(vxi, nparts=num_thread_x)
         by2, y2i = s[Output].split(Output.op.axis[3], factor=blocking_w)
         tvy, vyi = s[Output].split(y2i, nparts=num_vthread_y)
-        ty, yi = s[Output].split(vyi, nparts=num_thread)
+        ty, yi = s[Output].split(vyi, nparts=num_thread_y)
         s[Output].reorder(by1, by2, tvx, tvy, tx, ty, xi, yi)
         by = s[Output].fuse(by1, by2)
         s[Output].bind(tvx, thread_vx)
@@ -85,21 +84,21 @@ def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d):
         s[DepthwiseConv2d].compute_at(s[Output], ty)
         # input's shared memory load
         s[IS].compute_at(s[Output], by)
-        tx, xi = s[IS].split(IS.op.axis[2], nparts=num_thread)
-        ty, yi = s[IS].split(IS.op.axis[3], nparts=num_thread)
+        tx, xi = s[IS].split(IS.op.axis[2], nparts=num_thread_x)
+        ty, yi = s[IS].split(IS.op.axis[3], nparts=num_thread_y)
         s[IS].bind(tx, thread_x)
         s[IS].bind(ty, thread_y)
         # filter's shared memory load
         s[FS].compute_at(s[Output], by)
         s[FS].reorder(FS.op.axis[2], FS.op.axis[3], FS.op.axis[1])
-        tx, xi = s[FS].split(FS.op.axis[2], nparts=num_thread)
-        ty, yi = s[FS].split(FS.op.axis[3], nparts=num_thread)
+        tx, xi = s[FS].split(FS.op.axis[2], nparts=num_thread_x)
+        ty, yi = s[FS].split(FS.op.axis[3], nparts=num_thread_y)
         s[FS].bind(tx, thread_x)
         s[FS].bind(ty, thread_y)
 
     def traverse(OP):
         # inline all one-to-one-mapping operators except the last stage (output)
-        if OP.tag == 'ewise' or OP.tag == 'scale_shift':
+        if 'ewise' in OP.tag or 'bcast' in OP.tag:
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
@@ -110,7 +109,7 @@ def traverse(OP):
             PaddedInput = OP.input_tensors[0]
             Filter = OP.input_tensors[1]
             DepthwiseConv2d = OP.output(0)
-            schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d)
+            _schedule(PaddedInput, Filter, DepthwiseConv2d)
 
-    traverse(op)
+    traverse(outs[0].op)
     return s
diff --git a/topi/python/topi/nn/mapping.py b/topi/python/topi/nn/mapping.py
index 6affd68f0406..5e2347b8c797 100644
--- a/topi/python/topi/nn/mapping.py
+++ b/topi/python/topi/nn/mapping.py
@@ -3,7 +3,7 @@
 from __future__ import absolute_import as _abs
 import tvm
 
-@tvm.tag_scope(tag="scale_shift")
+@tvm.tag_scope(tag="bcast_scale_shift")
 def scale_shift(Input, Scale, Shift):
     """Batch normalization operator in inference.
 
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 61fe2a60df91..55b4af8961dc 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -6,4 +6,5 @@
 
 from .conv2d_hwcn_python import conv2d_hwcn_python
 from .conv2d_nchw_python import conv2d_nchw_python
+from .depthwise_conv2d_python import depthwise_conv2d_python
 from .dilate_python import dilate_python
diff --git a/topi/python/topi/testing/depthwise_conv2d_python.py b/topi/python/topi/testing/depthwise_conv2d_python.py
new file mode 100644
index 000000000000..8aecaf2abbd5
--- /dev/null
+++ b/topi/python/topi/testing/depthwise_conv2d_python.py
@@ -0,0 +1,62 @@
+# pylint: disable=invalid-name, unused-variable, line-too-long
+"""Depthwise convolution in python"""
+import numpy as np
+from scipy import signal
+
+
+def depthwise_conv2d_python(input_np, filter_np, stride, padding):
+    """Depthwise convolution operator in NCHW layout.
+
+    Parameters
+    ----------
+    input_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    filter_np : numpy.ndarray
+        4-D with shape [in_channel, channel_multiplier, filter_height, filter_width]
+
+    stride : list / tuple of 2 ints
+        [stride_height, stride_width]
+
+    padding : str
+        'VALID' or 'SAME'
+
+    Returns
+    -------
+    output_np : np.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    batch, in_channel, in_height, in_width = input_np.shape
+    _, channel_multiplier, filter_height, filter_width = filter_np.shape
+    stride_h, stride_w = stride
+    # calculate output shape
+    if padding == 'VALID':
+        out_channel = in_channel * channel_multiplier
+        out_height = (in_height - filter_height) // stride_h + 1
+        out_width = (in_width - filter_width) // stride_w + 1
+        output_np = np.zeros((batch, out_channel, out_height, out_width))
+        for i in range(batch):
+            for j in range(out_channel):
+                output_np[i, j, :, :] = signal.convolve2d(input_np[i, j//channel_multiplier, :, :], \
+                    np.rot90(filter_np[j//channel_multiplier, j%channel_multiplier, :, :], 2), \
+                    mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_width + 1):stride_w]
+    if padding == 'SAME':
+        out_channel = in_channel * channel_multiplier
+        out_height = np.int(np.ceil(float(in_height) / float(stride_h)))
+        out_width = np.int(np.ceil(float(in_width) / float(stride_w)))
+        output_np = np.zeros((batch, out_channel, out_height, out_width))
+        pad_along_height = np.int(max((out_height - 1) * stride_h + filter_height - in_height, 0))
+        pad_along_width = np.int(max((out_width - 1) * stride_w + filter_width - in_width, 0))
+        pad_top_tvm = np.int(np.ceil(float(pad_along_height) / 2))
+        pad_left_tvm = np.int(np.ceil(float(pad_along_width) / 2))
+        pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2))
+        pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2))
+        index_h = pad_top_scipy - pad_top_tvm
+        index_w = pad_left_scipy - pad_left_tvm
+        for i in range(batch):
+            for j in range(out_channel):
+                output_np[i, j, :, :] = signal.convolve2d(input_np[i, j//channel_multiplier, :, :], \
+                    np.rot90(filter_np[j//channel_multiplier, j%channel_multiplier, :, :], 2), \
+                    mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
+
+    return output_np
diff --git a/topi/recipe/conv/depthwise_conv2d_map_test.py b/topi/recipe/conv/depthwise_conv2d_test.py
similarity index 51%
rename from topi/recipe/conv/depthwise_conv2d_map_test.py
rename to topi/recipe/conv/depthwise_conv2d_test.py
index b41244d25fe7..55d34ab74e4b 100644
--- a/topi/recipe/conv/depthwise_conv2d_map_test.py
+++ b/topi/recipe/conv/depthwise_conv2d_test.py
@@ -5,10 +5,10 @@
 from tvm.contrib import nvcc
 
 import topi
-from topi.nn.util import get_const_tuple
-from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map
+from topi.util import get_const_tuple
+from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d
 
-TASK = "depthwise_conv2d_map"
+TASK = "depthwise_conv2d"
 USE_MANUAL_CODE = False
 
 @tvm.register_func
@@ -29,20 +29,20 @@ def tvm_callback_cuda_postproc(code):
         code = open("perf/%s_manual.cu" % TASK).read()
     return code
 
-def test_depthwise_conv2d_map():
+def test_depthwise_conv2d():
     """You may test different settings."""
-    batch = 2
+    batch = 1
     in_channel = 256
-    in_height = 32
-    in_width = 32
+    in_height = 96
+    in_width = 96
 
     filter_channel = in_channel
-    channel_multiplier = 2
-    filter_height = 5
-    filter_width = 5
+    channel_multiplier = 1
+    filter_height = 3
+    filter_width = 3
 
-    stride_h = 2
-    stride_w = 2
+    stride_h = 1
+    stride_w = 1
 
     padding = 'SAME' # or 'VALID'
 
@@ -57,40 +57,14 @@
     ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
     Relu = topi.nn.relu(ScaleShift)
     # Schedule
-    s1 = schedule_depthwise_conv2d_map(DepthwiseConv2d.op)
-    s2 = schedule_depthwise_conv2d_map(ScaleShift.op)
-    s3 = schedule_depthwise_conv2d_map(Relu.op)
+    s1 = schedule_depthwise_conv2d(DepthwiseConv2d)
+    s2 = schedule_depthwise_conv2d(ScaleShift)
+    s3 = schedule_depthwise_conv2d(Relu)
 
-    def depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np):
-        out_shape = get_const_tuple(DepthwiseConv2d.shape)
-        out_channel = out_shape[1]
-        out_height = out_shape[2]
-        out_width = out_shape[3]
-        depthwise_conv2d_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=DepthwiseConv2d.dtype)
-        scale_shift_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=ScaleShift.dtype)
-        relu_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=Relu.dtype)
-        if padding == 'SAME':
-            pad_top_tvm = np.int(np.ceil(float(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) / 2))
-            pad_left_tvm = np.int(np.ceil(float(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) / 2))
-            pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2))
-            pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2))
-            index_h = pad_top_scipy - pad_top_tvm
-            index_w = pad_left_scipy - pad_left_tvm
-            for i in range(batch):
-                for j in range(out_channel):
-                    depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j//channel_multiplier,:,:],
-                        np.rot90(filter_np[j//channel_multiplier,j%channel_multiplier,:,:], 2),
-                        mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
-        if padding == 'VALID':
-            for i in range(batch):
-                for j in range(out_channel):
-                    depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j//channel_multiplier,:,:],
-                        np.rot90(filter_np[j//channel_multiplier,j%channel_multiplier,:,:], 2),
-                        mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_height + 1):stride_w]
-        for c in range(out_channel):
-            scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
-        relu_scipy[:,:,:,:] = np.maximum(scale_shift_scipy[:,:,:,:], 0)
-        return depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy
+    input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
+    filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
+    scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
+    shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype)
 
     def check_device(device):
         if not tvm.module.enabled(device):
@@ -102,35 +76,36 @@ def check_device(device):
         f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
         f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
         # Prepare data
-        input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
-        filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
         input_tvm = tvm.nd.array(input_np, ctx)
         filter_tvm = tvm.nd.array(filter_np, ctx)
-        scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
-        shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype)
         scale_tvm = tvm.nd.array(scale_np, ctx)
         shift_tvm = tvm.nd.array(shift_np, ctx)
-        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
+        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),dtype=DepthwiseConv2d.dtype), ctx)
         scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
         relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
         # Measure time cost of kernel 1 (depthwise_conv2d)
-        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=10000)
+        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1000)
         tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
         # Measure time cost of kernel 2 (depthwise_conv2d + scale_shift)
-        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=10000)
+        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1000)
         tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
         # Measure time cost of kernel 3 (depthwise_conv2d + scale_shift + relu)
-        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=10000)
+        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1000)
         tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
         print("Input shape = " + str(get_const_tuple(Input.shape)))
         print("Filter shape = " + str(get_const_tuple(Filter.shape)))
         print("Stride = (%d, %d)" % (stride_h, stride_w))
         print("padding = %s\n" % padding)
         print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
-        print("average time cost of 10000 runs (depthwise_conv2d) = %g sec" % tcost_1)
-        print("average time cost of 10000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
-        print("average time cost of 10000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
-        depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy = depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np)
+        print("average time cost of 1000 runs (depthwise_conv2d) = %g sec" % tcost_1)
+        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
+        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
+        # correctness
+        depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
+        scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
+        for c in range(in_channel * channel_multiplier):
+            scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
+        relu_scipy = np.maximum(scale_shift_scipy, 0)
         np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
         np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
         np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
@@ -138,10 +113,10 @@ def check_device(device):
 
     with tvm.build_config(auto_unroll_max_step=32,
                           auto_unroll_min_depth=0,
-                          unroll_explicit=True,
+                          unroll_explicit=False,
                           detect_global_barrier=False,
                           restricted_func=True):
         check_device("cuda")
 
 if __name__ == "__main__":
-    test_depthwise_conv2d_map()
+    test_depthwise_conv2d()
diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py
new file mode 100644
index 000000000000..962914b17124
--- /dev/null
+++ b/topi/tests/python/test_topi_depthwise_conv2d.py
@@ -0,0 +1,86 @@
+import tvm
+import topi
+import numpy as np
+from scipy import signal
+from topi.util import get_const_tuple
+from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d
+
+def depthwise_conv2d_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
+    in_width = in_height
+    filter_channel = in_channel
+    filter_width = filter_height
+    stride_w = stride_h
+    # placeholder
+    Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
+    Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
+    Stride = tvm.nd.array(np.array([stride_h, stride_w]))
+    Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
+    Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
+    # declare
+    DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, Stride, padding)
+    ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
+    Relu = topi.nn.relu(ScaleShift)
+    # schedule
+    s1 = schedule_depthwise_conv2d(DepthwiseConv2d)
+    s2 = schedule_depthwise_conv2d(ScaleShift)
+    s3 = schedule_depthwise_conv2d(Relu)
+
+    input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
+    filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
+    scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype)
+    shift_np = np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        ctx = tvm.context(device, 0)
+        # build the kernels
+        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
+        f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
+        f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
+        # prepare data
+        input_tvm = tvm.nd.array(input_np, ctx)
+        filter_tvm = tvm.nd.array(filter_np, ctx)
+        scale_tvm = tvm.nd.array(scale_np, ctx)
+        shift_tvm = tvm.nd.array(shift_np, ctx)
+        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
+        scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
+        relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
+        # launch kernel 1 (depthwise_conv2d)
+        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
+        tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
+        # launch kernel 2 (depthwise_conv2d + scale_shift)
+        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1)
+        tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
+        # launch kernel 3 (depthwise_conv2d + scale_shift + relu)
+        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
+        tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
+        # correctness with scipy
+        depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
+        scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
+        for c in range(in_channel * channel_multiplier):
+            scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
+        relu_scipy = np.maximum(scale_shift_scipy, 0)
+        np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
+        np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
+        np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
+
+    check_device("opencl")
+    check_device("cuda")
+    check_device("metal")
+
+
+def test_depthwise_conv2d():
+    depthwise_conv2d_with_workload(1, 728, 64, 1, 3, 1, "SAME")
+    depthwise_conv2d_with_workload(1, 728, 32, 1, 3, 1, "SAME")
+    depthwise_conv2d_with_workload(4, 256, 64, 2, 5, 2, "SAME")
+    depthwise_conv2d_with_workload(4, 256, 32, 2, 5, 2, "SAME")
+    depthwise_conv2d_with_workload(1, 728, 64, 1, 3, 1, "VALID")
+    depthwise_conv2d_with_workload(1, 728, 32, 1, 3, 1, "VALID")
+    depthwise_conv2d_with_workload(4, 256, 64, 2, 5, 2, "VALID")
+    depthwise_conv2d_with_workload(4, 256, 32, 2, 5, 2, "VALID")
+
+
+if __name__ == "__main__":
+    test_depthwise_conv2d()
diff --git a/topi/tests/python/test_topi_depthwise_conv2d_map.py b/topi/tests/python/test_topi_depthwise_conv2d_map.py
deleted file mode 100644
index 22cc0654b0e6..000000000000
--- a/topi/tests/python/test_topi_depthwise_conv2d_map.py
+++ /dev/null
@@ -1,112 +0,0 @@
-import tvm
-import topi
-import numpy as np
-from scipy import signal
-from topi.util import get_const_tuple
-from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map
-
-def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
-    in_width = in_height
-    filter_channel = in_channel
-    filter_width = filter_height
-    stride_w = stride_h
-    # placeholder
-    Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
-    Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
-    Stride = tvm.nd.array(np.array([stride_h, stride_w]))
-    Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
-    Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
-    # declare
-    DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, Stride, padding)
-    ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
-    Relu = topi.nn.relu(ScaleShift)
-    # schedule
-    s1 = schedule_depthwise_conv2d_map(DepthwiseConv2d.op)
-    s2 = schedule_depthwise_conv2d_map(ScaleShift.op)
-    s3 = schedule_depthwise_conv2d_map(Relu.op)
-
-    def depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np):
-        out_shape = get_const_tuple(DepthwiseConv2d.shape)
-        out_channel = out_shape[1]
-        out_height = out_shape[2]
-        out_width = out_shape[3]
-        depthwise_conv2d_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=DepthwiseConv2d.dtype)
-        scale_shift_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=ScaleShift.dtype)
-        relu_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=Relu.dtype)
-        if padding == 'SAME':
-            pad_top_tvm = np.int(np.ceil(float(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) / 2))
-            pad_left_tvm = np.int(np.ceil(float(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) / 2))
-            pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2))
-            pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2))
-            index_h = pad_top_scipy - pad_top_tvm
-            index_w = pad_left_scipy - pad_left_tvm
-            for i in range(batch):
-                for j in range(out_channel):
-                    depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j//channel_multiplier,:,:],
-                        np.rot90(filter_np[j//channel_multiplier,j%channel_multiplier,:,:], 2),
-                        mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
-        if padding == 'VALID':
-            for i in range(batch):
-                for j in range(out_channel):
-                    depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j//channel_multiplier,:,:],
-                        np.rot90(filter_np[j//channel_multiplier,j%channel_multiplier,:,:], 2),
-                        mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_height + 1):stride_w]
-        for c in range(out_channel):
-            scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
-        relu_scipy[:,:,:,:] = np.maximum(scale_shift_scipy[:,:,:,:], 0)
-        return depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy
-
-    def check_device(device):
-        if not tvm.module.enabled(device):
-            print("Skip because %s is not enabled" % device)
-            return
-        ctx = tvm.context(device, 0)
-        # build the kernels
-        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
-        f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
-        f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
-        # prepare data
-        input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
-        filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
-        input_tvm = tvm.nd.array(input_np, ctx)
-        filter_tvm = tvm.nd.array(filter_np, ctx)
-        scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype)
-        shift_np = np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype)
-        scale_tvm = tvm.nd.array(scale_np, ctx)
-        shift_tvm = tvm.nd.array(shift_np, ctx)
-        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
-        scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
-        relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
-        # launch kernel 1 (depthwise_conv2d)
-        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
-        tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
-        # launch kernel 2 (depthwise_conv2d + scale_shift)
-        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1)
-        tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
-        # launch kernel 3 (depthwise_conv2d + scale_shift + relu)
-        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1)
-        tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
-        # correctness with scipy
-        depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy = depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np)
-        np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
-        np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
-        np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
-
-    check_device("opencl")
-    check_device("cuda")
-    check_device("metal")
-
-
-def test_depthwise_conv2d_map():
-    depthwise_conv2d_map_with_workload(1, 728, 64, 1, 3, 1, "SAME")
-    depthwise_conv2d_map_with_workload(1, 728, 32, 1, 3, 1, "SAME")
-    depthwise_conv2d_map_with_workload(4, 256, 64, 2, 5, 2, "SAME")
-    depthwise_conv2d_map_with_workload(4, 256, 32, 2, 5, 2, "SAME")
-    depthwise_conv2d_map_with_workload(1, 728, 64, 1, 3, 1, "VALID")
-    depthwise_conv2d_map_with_workload(1, 728, 32, 1, 3, 1, "VALID")
-    depthwise_conv2d_map_with_workload(4, 256, 64, 2, 5, 2, "VALID")
-    depthwise_conv2d_map_with_workload(4, 256, 32, 2, 5, 2, "VALID")
-
-
-if __name__ == "__main__":
-    test_depthwise_conv2d_map()
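
For reference, a minimal usage sketch of the renamed scheduler together with the new python reference checker, distilled from the tests added in this patch. The workload shape, stride, and the CUDA target below are illustrative assumptions, not anything mandated by the patch; only APIs that already appear in the files above are used.

import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d

# Hypothetical workload: 1x256x96x96 input, 3x3 depthwise filter, stride 1, SAME padding.
Input = tvm.placeholder((1, 256, 96, 96), name='Input')
Filter = tvm.placeholder((256, 1, 3, 3), name='Filter')
Stride = tvm.nd.array(np.array([1, 1]))
DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, Stride, 'SAME')

# The schedule now takes the output tensor(s) rather than an Operation.
s = schedule_depthwise_conv2d(DepthwiseConv2d)

if tvm.module.enabled("cuda"):
    f = tvm.build(s, [Input, Filter, DepthwiseConv2d], "cuda")
    ctx = tvm.context("cuda", 0)
    input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
    filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
    out_tvm = tvm.nd.array(np.zeros(get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
    f(tvm.nd.array(input_np, ctx), tvm.nd.array(filter_np, ctx), out_tvm)
    # Cross-check the GPU result against the new numpy/scipy reference implementation.
    out_np = topi.testing.depthwise_conv2d_python(input_np, filter_np, stride=[1, 1], padding='SAME')
    np.testing.assert_allclose(out_tvm.asnumpy(), out_np, rtol=1e-5)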