Merged
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/__init__.py
@@ -4,6 +4,6 @@

from .conv2d_nchw import schedule_conv2d_nchw
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d_map import schedule_depthwise_conv2d_map
from .depthwise_conv2d import schedule_depthwise_conv2d
from .reduction import schedule_reduce
from .broadcast import schedule_broadcast_to
topi/python/topi/cuda/depthwise_conv2d.py (renamed from depthwise_conv2d_map.py)
@@ -3,25 +3,24 @@
import tvm
from ..util import get_const_tuple

def schedule_depthwise_conv2d_map(op):
"""Schedule for depthwise_conv2d map ops.

This include scale-shift and relu.
def schedule_depthwise_conv2d(outs):
"""Schedule for depthwise_conv2d.

Parameters
----------
op: Operation
The symbolic description of the operation, should be depthwise_conv2d or
depthwise_conv2d followed by a sequence of one-to-one-mapping operators.
outs: Array of Tensor
The computation graph description of depthwise_conv2d
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for the op.
The computation schedule for depthwise_conv2d.
"""
s = tvm.create_schedule(op)
def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d):
"""Schedule for depthwise_conv2d declared in topi.nn.conv"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(PaddedInput, Filter, DepthwiseConv2d):
out_shape = get_const_tuple(DepthwiseConv2d.shape)
out_height = out_shape[2]
out_width = out_shape[3]
@@ -35,27 +34,27 @@ def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d):
Output = DepthwiseConv2d
CL = s.cache_write(DepthwiseConv2d, "local")
else:
Output = op.output(0)
Output = outs[0].op.output(0)
s[DepthwiseConv2d].set_scope("local")
# schedule parameters
num_thread = 8
num_thread_x = 8
num_thread_y = 8
num_vthread_x = 1
num_vthread_y = 1
blocking_h = out_height
blocking_w = out_width
if out_height % 48 == 0:
blocking_h = 48
elif out_height % 32 == 0:
if out_height % 32 == 0:
blocking_h = 32
if out_width % 48 == 0:
blocking_w = 48
num_vthread_y = 3
elif out_width % 32 == 0:
num_thread_x = 2
num_vthread_x = 2
if out_width % 32 == 0:
blocking_w = 32
num_thread_y = 16
num_vthread_y = 2
block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy")
# split and bind
@@ -65,10 +64,10 @@ def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d):
s[Output].bind(bx, block_x)
by1, y1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
tvx, vxi = s[Output].split(y1i, nparts=num_vthread_x)
tx, xi = s[Output].split(vxi, nparts=num_thread)
tx, xi = s[Output].split(vxi, nparts=num_thread_x)
by2, y2i = s[Output].split(Output.op.axis[3], factor=blocking_w)
tvy, vyi = s[Output].split(y2i, nparts=num_vthread_y)
ty, yi = s[Output].split(vyi, nparts=num_thread)
ty, yi = s[Output].split(vyi, nparts=num_thread_y)
s[Output].reorder(by1, by2, tvx, tvy, tx, ty, xi, yi)
by = s[Output].fuse(by1, by2)
s[Output].bind(tvx, thread_vx)
@@ -85,21 +84,21 @@ def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d):
s[DepthwiseConv2d].compute_at(s[Output], ty)
# input's shared memory load
s[IS].compute_at(s[Output], by)
tx, xi = s[IS].split(IS.op.axis[2], nparts=num_thread)
ty, yi = s[IS].split(IS.op.axis[3], nparts=num_thread)
tx, xi = s[IS].split(IS.op.axis[2], nparts=num_thread_x)
ty, yi = s[IS].split(IS.op.axis[3], nparts=num_thread_y)
s[IS].bind(tx, thread_x)
s[IS].bind(ty, thread_y)
# filter's shared memory load
s[FS].compute_at(s[Output], by)
s[FS].reorder(FS.op.axis[2], FS.op.axis[3], FS.op.axis[1])
tx, xi = s[FS].split(FS.op.axis[2], nparts=num_thread)
ty, yi = s[FS].split(FS.op.axis[3], nparts=num_thread)
tx, xi = s[FS].split(FS.op.axis[2], nparts=num_thread_x)
ty, yi = s[FS].split(FS.op.axis[3], nparts=num_thread_y)
s[FS].bind(tx, thread_x)
s[FS].bind(ty, thread_y)

def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if OP.tag == 'ewise' or OP.tag == 'scale_shift':
if 'ewise' in OP.tag or 'bcast' in OP.tag:
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
@@ -110,7 +109,7 @@ def traverse(OP):
PaddedInput = OP.input_tensors[0]
Filter = OP.input_tensors[1]
DepthwiseConv2d = OP.output(0)
schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d)
_schedule(PaddedInput, Filter, DepthwiseConv2d)

traverse(op)
traverse(outs[0].op)
return s
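
For context, a minimal usage sketch of the new interface: the caller passes the final output tensor(s) rather than an op, and traverse() walks back to the depthwise_conv2d stage while inlining the fused elementwise/broadcast tail. Shapes here are illustrative, and the topi.nn.depthwise_conv2d call is an assumption; its exact argument form is in the collapsed test setup further down.

import numpy as np
import tvm
import topi
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d

Input = tvm.placeholder((1, 256, 96, 96), name='Input')
Filter = tvm.placeholder((256, 1, 3, 3), name='Filter')
Stride = tvm.nd.array(np.array([1, 1]))
# Assumed declaration for illustration; see the test below for the real setup.
DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, Stride, 'SAME')
Relu = topi.nn.relu(DepthwiseConv2d)

# New interface: pass the output tensor(s); the old one took the op itself.
s = schedule_depthwise_conv2d(Relu)
f = tvm.build(s, [Input, Filter, Relu], "cuda")
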
2 changes: 1 addition & 1 deletion topi/python/topi/nn/mapping.py
@@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs
import tvm

@tvm.tag_scope(tag="scale_shift")
@tvm.tag_scope(tag="bcast_scale_shift")
def scale_shift(Input, Scale, Shift):
"""Batch normalization operator in inference.

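
The tag rename from scale_shift to bcast_scale_shift is what lets the schedule above pick this op up through its generic substring check ('bcast' in OP.tag) and inline it into the depthwise_conv2d output stage; a tiny standalone sketch of that matching logic (names illustrative):

op_tag = "bcast_scale_shift"                            # tag set by @tvm.tag_scope above
inline_it = ('ewise' in op_tag) or ('bcast' in op_tag)  # True, so scale_shift gets fused
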
1 change: 1 addition & 0 deletions topi/python/topi/testing/__init__.py
@@ -6,4 +6,5 @@

from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python
from .dilate_python import dilate_python
62 changes: 62 additions & 0 deletions topi/python/topi/testing/depthwise_conv2d_python.py
@@ -0,0 +1,62 @@
# pylint: disable=invalid-name, unused-variable, line-too-long
"""Depthwise convolution in python"""
import numpy as np
from scipy import signal


def depthwise_conv2d_python(input_np, filter_np, stride, padding):
"""Depthwise convolution operator in NCHW layout.

Parameters
----------
input_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]

filter_np : numpy.ndarray
4-D with shape [in_channel, channel_multiplier, filter_height, filter_width]

stride : list / tuple of 2 ints
[stride_height, stride_width]

padding : str
'VALID' or 'SAME'

Returns
-------
output_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = input_np.shape
_, channel_multiplier, filter_height, filter_width = filter_np.shape
stride_h, stride_w = stride
# calculate output shape
if padding == 'VALID':
out_channel = in_channel * channel_multiplier
out_height = (in_height - filter_height) // stride_h + 1
out_width = (in_width - filter_width) // stride_w + 1
output_np = np.zeros((batch, out_channel, out_height, out_width))
for i in range(batch):
for j in range(out_channel):
output_np[i, j, :, :] = signal.convolve2d(input_np[i, j//channel_multiplier, :, :], \
np.rot90(filter_np[j//channel_multiplier, j%channel_multiplier, :, :], 2), \
mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_width + 1):stride_w]
if padding == 'SAME':
out_channel = in_channel * channel_multiplier
out_height = np.int(np.ceil(float(in_height) / float(stride_h)))
out_width = np.int(np.ceil(float(in_width) / float(stride_w)))
output_np = np.zeros((batch, out_channel, out_height, out_width))
pad_along_height = np.int(max((out_height - 1) * stride_h + filter_height - in_height, 0))
pad_along_width = np.int(max((out_width - 1) * stride_w + filter_width - in_width, 0))
pad_top_tvm = np.int(np.ceil(float(pad_along_height) / 2))
pad_left_tvm = np.int(np.ceil(float(pad_along_width) / 2))
pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2))
pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2))
index_h = pad_top_scipy - pad_top_tvm
index_w = pad_left_scipy - pad_left_tvm
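# scipy's mode='same' output is aligned to the kernel center, while SAME padding
# is aligned to the strided output grid; index_h/index_w offset the scipy result
# so both samplings start from the same input position.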
for i in range(batch):
for j in range(out_channel):
output_np[i, j, :, :] = signal.convolve2d(input_np[i, j//channel_multiplier, :, :], \
np.rot90(filter_np[j//channel_multiplier, j%channel_multiplier, :, :], 2), \
mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]

return output_np
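
A minimal usage sketch of this reference implementation (shapes illustrative, mirroring the correctness check in the test below):

import numpy as np
from topi.testing import depthwise_conv2d_python

input_np = np.random.uniform(size=(1, 256, 96, 96)).astype('float32')
filter_np = np.random.uniform(size=(256, 1, 3, 3)).astype('float32')
out_np = depthwise_conv2d_python(input_np, filter_np, stride=[1, 1], padding='SAME')
# With SAME padding and stride 1 the spatial size is preserved:
# out_height = ceil(96 / 1) = 96, so out_np.shape == (1, 256, 96, 96)
print(out_np.shape)
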
@@ -5,10 +5,10 @@
from tvm.contrib import nvcc

import topi
from topi.nn.util import get_const_tuple
from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map
from topi.util import get_const_tuple
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d

TASK = "depthwise_conv2d_map"
TASK = "depthwise_conv2d"
USE_MANUAL_CODE = False

@tvm.register_func
@@ -29,20 +29,20 @@ def tvm_callback_cuda_postproc(code):
code = open("perf/%s_manual.cu" % TASK).read()
return code

def test_depthwise_conv2d_map():
def test_depthwise_conv2d():
"""You may test different settings."""
batch = 2
batch = 1
in_channel = 256
in_height = 32
in_width = 32
in_height = 96
in_width = 96

filter_channel = in_channel
channel_multiplier = 2
filter_height = 5
filter_width = 5
channel_multiplier = 1
filter_height = 3
filter_width = 3

stride_h = 2
stride_w = 2
stride_h = 1
stride_w = 1

padding = 'SAME' # or 'VALID'

@@ -57,40 +57,14 @@ def test_depthwise_conv2d_map():
ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# Schedule
s1 = schedule_depthwise_conv2d_map(DepthwiseConv2d.op)
s2 = schedule_depthwise_conv2d_map(ScaleShift.op)
s3 = schedule_depthwise_conv2d_map(Relu.op)
s1 = schedule_depthwise_conv2d(DepthwiseConv2d)
s2 = schedule_depthwise_conv2d(ScaleShift)
s3 = schedule_depthwise_conv2d(Relu)

def depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np):
out_shape = get_const_tuple(DepthwiseConv2d.shape)
out_channel = out_shape[1]
out_height = out_shape[2]
out_width = out_shape[3]
depthwise_conv2d_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=DepthwiseConv2d.dtype)
scale_shift_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=ScaleShift.dtype)
relu_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=Relu.dtype)
if padding == 'SAME':
pad_top_tvm = np.int(np.ceil(float(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) / 2))
pad_left_tvm = np.int(np.ceil(float(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) / 2))
pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2))
pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2))
index_h = pad_top_scipy - pad_top_tvm
index_w = pad_left_scipy - pad_left_tvm
for i in range(batch):
for j in range(out_channel):
depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j//channel_multiplier,:,:],
np.rot90(filter_np[j//channel_multiplier,j%channel_multiplier,:,:], 2),
mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
if padding == 'VALID':
for i in range(batch):
for j in range(out_channel):
depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j//channel_multiplier,:,:],
np.rot90(filter_np[j//channel_multiplier,j%channel_multiplier,:,:], 2),
mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_height + 1):stride_w]
for c in range(out_channel):
scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
relu_scipy[:,:,:,:] = np.maximum(scale_shift_scipy[:,:,:,:], 0)
return depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy
input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype)

def check_device(device):
if not tvm.module.enabled(device):
@@ -102,46 +76,47 @@ def check_device(device):
f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
# Prepare data
input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
input_tvm = tvm.nd.array(input_np, ctx)
filter_tvm = tvm.nd.array(filter_np, ctx)
scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype)
scale_tvm = tvm.nd.array(scale_np, ctx)
shift_tvm = tvm.nd.array(shift_np, ctx)
depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),dtype=DepthwiseConv2d.dtype), ctx)
scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
# Measure time cost of kernel 1 (depthwise_conv2d)
timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=10000)
timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1000)
tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
# Measure time cost of kernel 2 (depthwise_conv2d + scale_shift)
timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=10000)
timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1000)
tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
# Measure time cost of kernel 3 (depthwise_conv2d + scale_shift + relu)
timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=10000)
timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1000)
tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
print("Input shape = " + str(get_const_tuple(Input.shape)))
print("Filter shape = " + str(get_const_tuple(Filter.shape)))
print("Stride = (%d, %d)" % (stride_h, stride_w))
print("padding = %s\n" % padding)
print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
print("average time cost of 10000 runs (depthwise_conv2d) = %g sec" % tcost_1)
print("average time cost of 10000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
print("average time cost of 10000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy = depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np)
print("average time cost of 1000 runs (depthwise_conv2d) = %g sec" % tcost_1)
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
# correctness
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
for c in range(in_channel * channel_multiplier):
scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
relu_scipy = np.maximum(scale_shift_scipy, 0)
np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
print("success")

with tvm.build_config(auto_unroll_max_step=32,
auto_unroll_min_depth=0,
unroll_explicit=True,
unroll_explicit=False,
detect_global_barrier=False,
restricted_func=True):
check_device("cuda")

if __name__ == "__main__":
test_depthwise_conv2d_map()
test_depthwise_conv2d()