diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index 682642745e4d..cfca85d1b704 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -287,7 +287,7 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I, int stride_h = 1, int stride_w = 1, std::string name = "tensor", - std::string tag = kDepthwiseConv2d) { + std::string tag = kDepthwiseConv2dNCHW) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; @@ -313,6 +313,39 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I, return tvm::compute(output_shape, l, name, tag); } +inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I, + const tvm::Tensor& W, + int pad_h = 0, + int pad_w = 0, + int stride_h = 1, + int stride_w = 1, + std::string name = "tensor", + std::string tag = kDepthwiseConv2dNHWC) { + CHECK_EQ(4, I->shape.size()); + CHECK_EQ(4, W->shape.size()); + auto pH = I->shape[1]; + auto pW = I->shape[2]; + auto pCM = W->shape[1]; // channel_multiplier + tvm::Array output_shape{ + I->shape[0], // B + (I->shape[1] - W->shape[1] + 2 * pad_h) / stride_h + 1, // H + (I->shape[2] - W->shape[2] + 2 * pad_w) / stride_w + 1, // W + W->shape[3], // O + }; + auto i = tvm::reduce_axis(tvm::Range{0, I->shape[3]}, "i"); + auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[0]}, "kh"); + auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[1]}, "kw"); + auto T = (pad_h == 0 && pad_w == 0) + ? I + : pad(I, {tvm::Expr(0), pad_h, pad_w, tvm::Expr(0)}); + auto l = [&](tvm::Var b, tvm::Var h, tvm::Var w, tvm::Var o) { + return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, i / pCM) * + W(kh, kw, i / pCM, o % pCM), + {kh, kw, i}); + }; + return tvm::compute(output_shape, l, name, tag); +} + /*! * \brief Creates an operation that performs a 2-D group convolution with * an NGCHW-layout diff --git a/topi/include/topi/tags.h b/topi/include/topi/tags.h index 9a8081b9b72d..c0738f0117dc 100644 --- a/topi/include/topi/tags.h +++ b/topi/include/topi/tags.h @@ -13,7 +13,8 @@ constexpr auto kBroadcast = "bcast"; constexpr auto kMatMult = "matmult"; constexpr auto kConv2dNCHW = "conv2d_nchw"; constexpr auto kConv2dHWCN = "conv2d_hwcn"; -constexpr auto kDepthwiseConv2d = "depthwise_conv2d"; +constexpr auto kDepthwiseConv2dNCHW = "depthwise_conv2d_nchw"; +constexpr auto kDepthwiseConv2dNHWC = "depthwise_conv2d_nhwc"; constexpr auto kGroupConv2d = "group_conv2d"; } // namespace topi diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 252d92dfe16b..39bd50686fa3 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -4,6 +4,6 @@ from .conv2d_nchw import schedule_conv2d_nchw from .conv2d_hwcn import schedule_conv2d_hwcn -from .depthwise_conv2d import schedule_depthwise_conv2d +from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc from .reduction import schedule_reduce from .broadcast import schedule_broadcast_to diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py index b7e186be555f..352c04c5e8a8 100644 --- a/topi/python/topi/cuda/depthwise_conv2d.py +++ b/topi/python/topi/cuda/depthwise_conv2d.py @@ -3,9 +3,8 @@ import tvm from ..util import get_const_tuple - -def schedule_depthwise_conv2d(outs): - """Schedule for depthwise_conv2d. +def schedule_depthwise_conv2d_nchw(outs): + """Schedule for depthwise_conv2d nchw forward. Parameters ---------- @@ -16,7 +15,7 @@ def schedule_depthwise_conv2d(outs): Returns ------- s: Schedule - The computation schedule for depthwise_conv2d. + The computation schedule for depthwise_conv2d nchw. """ outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs s = tvm.create_schedule([x.op for x in outs]) @@ -105,7 +104,78 @@ def traverse(OP): if tensor.op.input_tensors: traverse(tensor.op) # schedule depthwise_conv2d - if OP.tag == 'depthwise_conv2d': + if OP.tag == 'depthwise_conv2d_nchw': + PaddedInput = OP.input_tensors[0] + Filter = OP.input_tensors[1] + DepthwiseConv2d = OP.output(0) + _schedule(PaddedInput, Filter, DepthwiseConv2d) + + traverse(outs[0].op) + return s + +def schedule_depthwise_conv2d_nhwc(outs): + """Schedule for depthwise_conv2d nhwc forward. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of depthwise_conv2d + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for depthwise_conv2d nhwc. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + def _schedule(temp, Filter, DepthwiseConv2d): + + s[temp].compute_inline() + FS = s.cache_read(Filter, "shared", [DepthwiseConv2d]) + if DepthwiseConv2d.op in s.outputs: + Output = DepthwiseConv2d + CL = s.cache_write(DepthwiseConv2d, "local") + else: + Output = outs[0].op.output(0) + s[DepthwiseConv2d].set_scope("local") + + block_x = tvm.thread_axis("blockIdx.x") + thread_x = tvm.thread_axis("threadIdx.x") + + b, h, w, c = s[Output].op.axis + + ic_val = tvm.ir_pass.Simplify(temp.shape[3]).value + xoc, xic = s[Output].split(c, factor=ic_val) + s[Output].reorder(xoc, b, h, w, xic) + xo, yo, _, _ = s[Output].tile(h, w, x_factor=2, y_factor=2) + fused = s[Output].fuse(yo, xo) + fused = s[Output].fuse(fused, b) + fused = s[Output].fuse(fused, xoc) + + s[Output].bind(fused, block_x) + s[Output].bind(xic, thread_x) + + if DepthwiseConv2d.op in s.outputs: + s[CL].compute_at(s[Output], xic) + else: + s[DepthwiseConv2d].compute_at(s[Output], xic) + + _, _, ci, fi = s[FS].op.axis + s[FS].compute_at(s[Output], fused) + fused = s[FS].fuse(fi, ci) + s[FS].bind(fused, thread_x) + + def traverse(OP): + # inline all one-to-one-mapping operators except the last stage (output) + if 'ewise' in OP.tag or 'bcast' in OP.tag: + if OP not in s.outputs: + s[OP].compute_inline() + for tensor in OP.input_tensors: + if tensor.op.input_tensors: + traverse(tensor.op) + # schedule depthwise_conv2d + if OP.tag == 'depthwise_conv2d_nhwc': PaddedInput = OP.input_tensors[0] Filter = OP.input_tensors[1] DepthwiseConv2d = OP.output(0) diff --git a/topi/python/topi/nn/convolution.py b/topi/python/topi/nn/convolution.py index 6293e02e343a..cd201bdabf8f 100644 --- a/topi/python/topi/nn/convolution.py +++ b/topi/python/topi/nn/convolution.py @@ -107,9 +107,8 @@ def conv2d_hwcn(Input, Filter, stride, padding): name="Conv2dOutput", tag="conv2d_hwcn") return Output - -def depthwise_conv2d(Input, Filter, stride, padding): - """Depthwise convolution operator. +def depthwise_conv2d_nchw(Input, Filter, stride, padding): + """Depthwise convolution nchw forward operator. Parameters ---------- @@ -153,5 +152,53 @@ def depthwise_conv2d(Input, Filter, stride, padding): (PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] * Filter[c/channel_multiplier, c%channel_multiplier, di, dj]), axis=[di, dj]), - name='DepthwiseConv2d', tag="depthwise_conv2d") + name='DepthwiseConv2d', tag="depthwise_conv2d_nchw") + return Output + +def depthwise_conv2d_nhwc(Input, Filter, stride, padding): + """Depthwise convolution nhwc forward operator. + + Parameters + ---------- + Input : tvm.Tensor + 4-D with shape [batch, in_height, in_width, in_channel] + + Filter : tvm.Tensor + 4-D with shape [filter_height, filter_width, in_channel, channel_multiplier] + + Stride : tvm.Tensor + 1-D of size 2 + + padding : str + 'VALID' or 'SAME' + + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, out_height, out_width, out_channel] + """ + batch, in_height, in_width, in_channel = Input.shape + filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape + stride_h, stride_w = stride + + pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option( + padding, (filter_height, filter_width)) + out_channel = simplify(in_channel * channel_multiplier) + out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1) + + # padding stage + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") + # depthconv stage + di = tvm.reduce_axis((0, filter_height), name='di') + dj = tvm.reduce_axis((0, filter_width), name='dj') + Output = tvm.compute( + (batch, out_height, out_width, out_channel), + lambda b, i, j, c: tvm.sum( + (PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier] * + Filter[di, dj, c/channel_multiplier, c%channel_multiplier]), + axis=[di, dj]), + name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc") return Output diff --git a/topi/python/topi/nn/mapping.py b/topi/python/topi/nn/mapping.py index 5e2347b8c797..5c4599d91241 100644 --- a/topi/python/topi/nn/mapping.py +++ b/topi/python/topi/nn/mapping.py @@ -3,8 +3,8 @@ from __future__ import absolute_import as _abs import tvm -@tvm.tag_scope(tag="bcast_scale_shift") -def scale_shift(Input, Scale, Shift): +@tvm.tag_scope(tag="bcast_scale_shift_nchw") +def scale_shift_nchw(Input, Scale, Shift): """Batch normalization operator in inference. Parameters @@ -24,3 +24,25 @@ def scale_shift(Input, Scale, Shift): Output tensor, layout is NCHW """ return tvm.compute(Input.shape, lambda b, c, i, j: Input[b, c, i, j] * Scale[c] + Shift[c], name='ScaleShift') + +@tvm.tag_scope(tag="bcast_scale_shift_nhwc") +def scale_shift_nhwc(Input, Scale, Shift): + """Batch normalization operator in inference. + + Parameters + ---------- + Input : tvm.Tensor + Input tensor, layout is NHWC + + Scale : tvm.Tensor + Scale tensor, 1-D of size channel number + + Shift : tvm.Tensor + Shift tensor, 1-D of size channel number + + Returns + ------- + Output : tvm.Tensor + Output tensor, layout is NHWC + """ + return tvm.compute(Input.shape, lambda b, i, j, c: Input[b, i, j, c] * Scale[c] + Shift[c], name='ScaleShift') diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 55b4af8961dc..1cd4c54dfece 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -6,5 +6,5 @@ from .conv2d_hwcn_python import conv2d_hwcn_python from .conv2d_nchw_python import conv2d_nchw_python -from .depthwise_conv2d_python import depthwise_conv2d_python +from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .dilate_python import dilate_python diff --git a/topi/python/topi/testing/depthwise_conv2d_python.py b/topi/python/topi/testing/depthwise_conv2d_python.py index 8aecaf2abbd5..84784f97c2b8 100644 --- a/topi/python/topi/testing/depthwise_conv2d_python.py +++ b/topi/python/topi/testing/depthwise_conv2d_python.py @@ -3,8 +3,7 @@ import numpy as np from scipy import signal - -def depthwise_conv2d_python(input_np, filter_np, stride, padding): +def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding): """Depthwise convolution operator in NCHW layout. Parameters @@ -60,3 +59,60 @@ def depthwise_conv2d_python(input_np, filter_np, stride, padding): mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w] return output_np + +def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding): + """Depthwise convolution operator in nchw layout. + + Parameters + ---------- + input_np : numpy.ndarray + 4-D with shape [batch, in_height, in_width, in_channel] + + filter_np : numpy.ndarray + 4-D with shape [filter_height, filter_width, in_channel, channel_multiplier] + + stride : list / tuple of 2 ints + [stride_height, stride_width] + + padding : str + 'VALID' or 'SAME' + + Returns + ------- + output_np : np.ndarray + 4-D with shape [batch, out_height, out_width, out_channel] + """ + batch, in_height, in_width, in_channel = input_np.shape + filter_height, filter_width, _, channel_multiplier = filter_np.shape + stride_h, stride_w = stride + # calculate output shape + if padding == 'VALID': + out_channel = in_channel * channel_multiplier + out_height = (in_height - filter_height) // stride_h + 1 + out_width = (in_width - filter_width) // stride_w + 1 + output_np = np.zeros((batch, out_height, out_width, out_channel)) + for i in range(batch): + for j in range(out_channel): + output_np[i, :, :, j] = signal.convolve2d(input_np[i, :, :, j//channel_multiplier], \ + np.rot90(filter_np[:, :, j//channel_multiplier, j%channel_multiplier], 2), \ + mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_height + 1):stride_w] + if padding == 'SAME': + out_channel = in_channel * channel_multiplier + out_height = np.int(np.ceil(float(in_height) / float(stride_h))) + out_width = np.int(np.ceil(float(in_width) / float(stride_w))) + output_np = np.zeros((batch, out_height, out_width, out_channel)) + pad_along_height = np.int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) + pad_along_width = np.int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) + pad_top_tvm = np.int(np.ceil(float(pad_along_height) / 2)) + pad_left_tvm = np.int(np.ceil(float(pad_along_width) / 2)) + pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2)) + pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2)) + index_h = pad_top_scipy - pad_top_tvm + index_w = pad_left_scipy - pad_left_tvm + for i in range(batch): + for j in range(out_channel): + output_np[i, :, :, j] = signal.convolve2d(input_np[i, :, :, j//channel_multiplier], \ + np.rot90(filter_np[:, :, j//channel_multiplier, j%channel_multiplier], 2), \ + mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w] + + return output_np diff --git a/topi/recipe/conv/depthwise_conv2d_test.py b/topi/recipe/conv/depthwise_conv2d_test.py index 223570593d25..0f3a14dde2aa 100644 --- a/topi/recipe/conv/depthwise_conv2d_test.py +++ b/topi/recipe/conv/depthwise_conv2d_test.py @@ -6,14 +6,14 @@ import topi from topi.util import get_const_tuple -from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d +from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc TASK = "depthwise_conv2d" USE_MANUAL_CODE = False @tvm.register_func def tvm_callback_cuda_compile(code): - ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"]) + ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_37"]) # 37 for k80(ec2 instance) return ptx def write_code(code, fname): @@ -29,7 +29,7 @@ def tvm_callback_cuda_postproc(code): code = open("perf/%s_manual.cu" % TASK).read() return code -def test_depthwise_conv2d(): +def test_depthwise_conv2d_nchw(): """You may test different settings.""" batch = 1 in_channel = 256 @@ -53,14 +53,13 @@ def test_depthwise_conv2d(): Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') # Declare - DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, Stride, padding) - ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift) + DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, Stride, padding) + ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift) Relu = topi.nn.relu(ScaleShift) # Schedule - s1 = schedule_depthwise_conv2d(DepthwiseConv2d) - s2 = schedule_depthwise_conv2d(ScaleShift) - s3 = schedule_depthwise_conv2d(Relu) - + s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d) + s2 = schedule_depthwise_conv2d_nchw(ScaleShift) + s3 = schedule_depthwise_conv2d_nchw(Relu) input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype) filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype) scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype) @@ -80,6 +79,7 @@ def check_device(device): filter_tvm = tvm.nd.array(filter_np, ctx) scale_tvm = tvm.nd.array(scale_np, ctx) shift_tvm = tvm.nd.array(shift_np, ctx) + depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),dtype=DepthwiseConv2d.dtype), ctx) scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx) relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx) @@ -101,7 +101,7 @@ def check_device(device): print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2) print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3) # correctness - depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python(input_np, filter_np, stride=[stride_h, stride_w], padding=padding) + depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(input_np, filter_np, stride=[stride_h, stride_w], padding=padding) scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape)) for c in range(in_channel * channel_multiplier): scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c] @@ -118,5 +118,95 @@ def check_device(device): restricted_func=True): check_device("cuda") +def test_depthwise_conv2d_nhwc(): + """You may test different settings.""" + batch = 1 + in_channel = 256 + in_height = 96 + in_width = 96 + + filter_channel = in_channel + channel_multiplier = 1 + filter_height = 3 + filter_width = 3 + + stride_h = 1 + stride_w = 1 + + padding = 'SAME' # or 'VALID' + + # Placeholder + Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input') + Filter = tvm.placeholder((filter_height, filter_width,filter_channel, channel_multiplier), name='Filter') + Stride = [stride_h, stride_w] + Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') + Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') + # Declare + DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, Stride, padding) + ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift) + Relu = topi.nn.relu(ScaleShift) + # Schedule + s1 = schedule_depthwise_conv2d_nhwc(DepthwiseConv2d) + s2 = schedule_depthwise_conv2d_nhwc(ScaleShift) + s3 = schedule_depthwise_conv2d_nhwc(Relu) + + input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype) + filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype) + scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype) + shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype) + + def check_device(device): + if not tvm.module.enabled(device): + print("Skip because %s is not enabled" % device) + return + ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) + # Build the kernel + f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device) + f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device) + f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device) + # Prepare data + input_tvm = tvm.nd.array(input_np, ctx) + filter_tvm = tvm.nd.array(filter_np, ctx) + scale_tvm = tvm.nd.array(scale_np, ctx) + shift_tvm = tvm.nd.array(shift_np, ctx) + depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),dtype=DepthwiseConv2d.dtype), ctx) + scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx) + relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx) + # Measure time cost of kernel 1 (depthwise_conv2d) + timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1000) + tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean + # Measure time cost of kernel 2 (depthwise_conv2d + scale_shift) + timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1000) + tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean + # Measure time cost of kernel 3 (depthwise_conv2d + scale_shift + relu) + timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1000) + tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean + print("Input shape = " + str(get_const_tuple(Input.shape))) + print("Filter shape = " + str(get_const_tuple(Filter.shape))) + print("Stride = (%d, %d)" % (stride_h, stride_w)) + print("padding = %s\n" % padding) + print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape))) + print("average time cost of 1000 runs (depthwise_conv2d) = %g sec" % tcost_1) + print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2) + print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3) + # correctness + depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(input_np, filter_np, stride=[stride_h, stride_w], padding=padding) + scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape)) + for c in range(in_channel * channel_multiplier): + scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c] + relu_scipy = np.maximum(scale_shift_scipy, 0) + np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5) + np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5) + np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) + print("success") + + with tvm.build_config(auto_unroll_max_step=32, + auto_unroll_min_depth=0, + unroll_explicit=False, + detect_global_barrier=False, + restricted_func=True): + check_device("cuda") + if __name__ == "__main__": - test_depthwise_conv2d() + test_depthwise_conv2d_nchw() + test_depthwise_conv2d_nhwc() diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py index bb41066921fc..5b2c633c0a26 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d.py +++ b/topi/tests/python/test_topi_depthwise_conv2d.py @@ -3,9 +3,9 @@ import numpy as np from scipy import signal from topi.util import get_const_tuple -from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d +from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc -def depthwise_conv2d_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding): +def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding): in_width = in_height filter_channel = in_channel filter_width = filter_height @@ -17,13 +17,13 @@ def depthwise_conv2d_with_workload(batch, in_channel, in_height, channel_multipl Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') # declare - DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, Stride, padding) - ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift) + DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, Stride, padding) + ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift) Relu = topi.nn.relu(ScaleShift) # schedule - s1 = schedule_depthwise_conv2d(DepthwiseConv2d) - s2 = schedule_depthwise_conv2d(ScaleShift) - s3 = schedule_depthwise_conv2d(Relu) + s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d) + s2 = schedule_depthwise_conv2d_nchw(ScaleShift) + s3 = schedule_depthwise_conv2d_nchw(Relu) input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype) filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype) @@ -57,7 +57,7 @@ def check_device(device): timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1) tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean # correctness with scipy - depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python(input_np, filter_np, stride=[stride_h, stride_w], padding=padding) + depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(input_np, filter_np, stride=[stride_h, stride_w], padding=padding) scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape)) for c in range(in_channel * channel_multiplier): scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c] @@ -70,17 +70,91 @@ def check_device(device): check_device("cuda") check_device("metal") +def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding): + in_width = in_height + filter_channel = in_channel + filter_width = filter_height + stride_w = stride_h + # placeholder + Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input') + Filter = tvm.placeholder((filter_height, filter_width,filter_channel, channel_multiplier), name='Filter') + Stride = [stride_h, stride_w] + Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') + Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') + # declare + DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, Stride, padding) + ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift) + Relu = topi.nn.relu(ScaleShift) + # schedule + s1 = schedule_depthwise_conv2d_nhwc(DepthwiseConv2d) + s2 = schedule_depthwise_conv2d_nhwc(ScaleShift) + s3 = schedule_depthwise_conv2d_nhwc(Relu) -def test_depthwise_conv2d(): - depthwise_conv2d_with_workload(1, 728, 64, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload(1, 728, 32, 1, 3, 1, "SAME") - depthwise_conv2d_with_workload(4, 256, 64, 2, 5, 2, "SAME") - depthwise_conv2d_with_workload(4, 256, 32, 2, 5, 2, "SAME") - depthwise_conv2d_with_workload(1, 728, 64, 1, 3, 1, "VALID") - depthwise_conv2d_with_workload(1, 728, 32, 1, 3, 1, "VALID") - depthwise_conv2d_with_workload(4, 256, 64, 2, 5, 2, "VALID") - depthwise_conv2d_with_workload(4, 256, 32, 2, 5, 2, "VALID") + input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype) + filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype) + scale_np = np.random.uniform(size=get_const_tuple(Scale.shape)).astype(Scale.dtype) + shift_np = np.random.uniform(size=get_const_tuple(Shift.shape)).astype(Shift.dtype) + + def check_device(device): + if not tvm.module.enabled(device): + print("Skip because %s is not enabled" % device) + return + ctx = tvm.context(device, 0) + # build the kernels + f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device) + f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device) + f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device) + # prepare data + input_tvm = tvm.nd.array(input_np, ctx) + filter_tvm = tvm.nd.array(filter_np, ctx) + scale_tvm = tvm.nd.array(scale_np, ctx) + shift_tvm = tvm.nd.array(shift_np, ctx) + depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx) + scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx) + relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx) + # launch kernel 1 (depthwise_conv2d) + timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1) + tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean + # launch kernel 2 (depthwise_conv2d + scale_shift) + timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1) + tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean + # launch kernel 3 (depthwise_conv2d + scale_shift + relu) + timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1) + tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean + # correctness with scipy + depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(input_np, filter_np, stride=[stride_h, stride_w], padding=padding) + scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape)) + for c in range(in_channel * channel_multiplier): + scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c] + relu_scipy = np.maximum(scale_shift_scipy, 0) + np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5) + np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5) + np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) + check_device("opencl") + check_device("cuda") + check_device("metal") + + +def test_depthwise_conv2d(): + print("testing nchw") + depthwise_conv2d_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME") + depthwise_conv2d_with_workload_nchw(1, 728, 32, 1, 3, 1, "SAME") + depthwise_conv2d_with_workload_nchw(4, 256, 64, 2, 5, 2, "SAME") + depthwise_conv2d_with_workload_nchw(4, 256, 32, 2, 5, 2, "SAME") + depthwise_conv2d_with_workload_nchw(1, 728, 64, 1, 3, 1, "VALID") + depthwise_conv2d_with_workload_nchw(1, 728, 32, 1, 3, 1, "VALID") + depthwise_conv2d_with_workload_nchw(4, 256, 64, 2, 5, 2, "VALID") + depthwise_conv2d_with_workload_nchw(4, 256, 32, 2, 5, 2, "VALID") + print("testing nhwc") + depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME") + depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "SAME") + depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "SAME") + depthwise_conv2d_with_workload_nhwc(4, 256, 32, 2, 5, 2, "SAME") + depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "VALID") + depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID") + depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID") + depthwise_conv2d_with_workload_nhwc(4, 256, 32, 2, 5, 2, "VALID") if __name__ == "__main__": test_depthwise_conv2d()