35 changes: 34 additions & 1 deletion topi/include/topi/nn.h
@@ -287,7 +287,7 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
int stride_h = 1,
int stride_w = 1,
std::string name = "tensor",
std::string tag = kDepthwiseConv2d) {
std::string tag = kDepthwiseConv2dNCHW) {
CHECK_EQ(4, I->shape.size());
CHECK_EQ(4, W->shape.size());
auto pH = I->shape[2];
@@ -313,6 +313,39 @@
return tvm::compute(output_shape, l, name, tag);
}

inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
const tvm::Tensor& W,
int pad_h = 0,
int pad_w = 0,
int stride_h = 1,
int stride_w = 1,
std::string name = "tensor",
std::string tag = kDepthwiseConv2dNHWC) {
CHECK_EQ(4, I->shape.size());
CHECK_EQ(4, W->shape.size());
auto pH = I->shape[1];
auto pW = I->shape[2];
auto pCM = W->shape[3]; // channel_multiplier
tvm::Array<tvm::Expr> output_shape{
I->shape[0], // B
(I->shape[1] - W->shape[0] + 2 * pad_h) / stride_h + 1, // H
(I->shape[2] - W->shape[1] + 2 * pad_w) / stride_w + 1, // W
I->shape[3] * pCM, // C = in_channel * channel_multiplier
};
auto kh = tvm::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
auto T = (pad_h == 0 && pad_w == 0)
? I
: pad(I, {tvm::Expr(0), pad_h, pad_w, tvm::Expr(0)});
auto l = [&](tvm::Var b, tvm::Var h, tvm::Var w, tvm::Var o) {
  // Output channel o reads input channel o / pCM with filter copy o % pCM;
  // depthwise convolution reduces over the kernel window only, not channels.
  return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, o / pCM) *
                  W(kh, kw, o / pCM, o % pCM),
                  {kh, kw});
};
return tvm::compute(output_shape, l, name, tag);
}
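For intuition, the lambda above decomposes each NHWC output channel `o` into an input channel `o / pCM` and a filter copy `o % pCM`, so the output carries `in_channel * channel_multiplier` channels. A tiny Python sketch of that index mapping (channel counts are illustrative assumptions):

```python
# Illustrative only: the channel decomposition used by the NHWC depthwise kernel.
in_channel, channel_multiplier = 3, 2
out_channel = in_channel * channel_multiplier  # 6
for o in range(out_channel):
    in_c, m = o // channel_multiplier, o % channel_multiplier
    print("output channel %d <- input channel %d, filter copy %d" % (o, in_c, m))
```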

/*!
* \brief Creates an operation that performs a 2-D group convolution with
* an NGCHW-layout
3 changes: 2 additions & 1 deletion topi/include/topi/tags.h
@@ -13,7 +13,8 @@ constexpr auto kBroadcast = "bcast";
constexpr auto kMatMult = "matmult";
constexpr auto kConv2dNCHW = "conv2d_nchw";
constexpr auto kConv2dHWCN = "conv2d_hwcn";
constexpr auto kDepthwiseConv2d = "depthwise_conv2d";
constexpr auto kDepthwiseConv2dNCHW = "depthwise_conv2d_nchw";
constexpr auto kDepthwiseConv2dNHWC = "depthwise_conv2d_nhwc";
constexpr auto kGroupConv2d = "group_conv2d";

} // namespace topi
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/__init__.py
@@ -4,6 +4,6 @@

from .conv2d_nchw import schedule_conv2d_nchw
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d import schedule_depthwise_conv2d
from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
from .reduction import schedule_reduce
from .broadcast import schedule_broadcast_to
80 changes: 75 additions & 5 deletions topi/python/topi/cuda/depthwise_conv2d.py
@@ -3,9 +3,8 @@
import tvm
from ..util import get_const_tuple


def schedule_depthwise_conv2d(outs):
"""Schedule for depthwise_conv2d.
def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for depthwise_conv2d nchw forward.

Parameters
----------
@@ -16,7 +15,7 @@ def schedule_depthwise_conv2d(outs):
Returns
-------
s: Schedule
The computation schedule for depthwise_conv2d.
The computation schedule for depthwise_conv2d nchw.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
@@ -105,7 +104,78 @@ def traverse(OP):
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule depthwise_conv2d
if OP.tag == 'depthwise_conv2d':
if OP.tag == 'depthwise_conv2d_nchw':
PaddedInput = OP.input_tensors[0]
Filter = OP.input_tensors[1]
DepthwiseConv2d = OP.output(0)
_schedule(PaddedInput, Filter, DepthwiseConv2d)

traverse(outs[0].op)
return s

def schedule_depthwise_conv2d_nhwc(outs):
"""Schedule for depthwise_conv2d nhwc forward.

Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise_conv2d
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for depthwise_conv2d nhwc.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(temp, Filter, DepthwiseConv2d):

s[temp].compute_inline()
FS = s.cache_read(Filter, "shared", [DepthwiseConv2d])
if DepthwiseConv2d.op in s.outputs:
Output = DepthwiseConv2d
CL = s.cache_write(DepthwiseConv2d, "local")
else:
Output = outs[0].op.output(0)
s[DepthwiseConv2d].set_scope("local")

block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")

b, h, w, c = s[Output].op.axis

ic_val = tvm.ir_pass.Simplify(temp.shape[3]).value
xoc, xic = s[Output].split(c, factor=ic_val)
s[Output].reorder(xoc, b, h, w, xic)
xo, yo, _, _ = s[Output].tile(h, w, x_factor=2, y_factor=2)
fused = s[Output].fuse(yo, xo)
fused = s[Output].fuse(fused, b)
fused = s[Output].fuse(fused, xoc)

s[Output].bind(fused, block_x)
s[Output].bind(xic, thread_x)

if DepthwiseConv2d.op in s.outputs:
s[CL].compute_at(s[Output], xic)
else:
s[DepthwiseConv2d].compute_at(s[Output], xic)

_, _, ci, fi = s[FS].op.axis
s[FS].compute_at(s[Output], fused)
fused = s[FS].fuse(fi, ci)
s[FS].bind(fused, thread_x)

def traverse(OP):
# inline all one-to-one-mapping operators except the last stage (output)
if 'ewise' in OP.tag or 'bcast' in OP.tag:
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule depthwise_conv2d
if OP.tag == 'depthwise_conv2d_nhwc':
PaddedInput = OP.input_tensors[0]
Filter = OP.input_tensors[1]
DepthwiseConv2d = OP.output(0)
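Taken together, a minimal sketch of driving the new NHWC schedule end to end (shapes are illustrative assumptions based on this PR, and a CUDA-capable device is assumed):

```python
import tvm
import topi
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nhwc

# Declare an NHWC depthwise convolution and schedule it for CUDA.
Input = tvm.placeholder((1, 32, 32, 16), name='Input')    # NHWC
Filter = tvm.placeholder((3, 3, 16, 2), name='Filter')    # HWCM
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter,
                                                stride=[1, 1], padding='SAME')
s = schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
f = tvm.build(s, [Input, Filter, DepthwiseConv2d], target='cuda')
```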
55 changes: 51 additions & 4 deletions topi/python/topi/nn/convolution.py
@@ -107,9 +107,8 @@ def conv2d_hwcn(Input, Filter, stride, padding):
name="Conv2dOutput", tag="conv2d_hwcn")
return Output


def depthwise_conv2d(Input, Filter, stride, padding):
"""Depthwise convolution operator.
def depthwise_conv2d_nchw(Input, Filter, stride, padding):
"""Depthwise convolution nchw forward operator.

Parameters
----------
@@ -153,5 +152,53 @@ def depthwise_conv2d(Input, Filter, stride, padding):
(PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] *
Filter[c/channel_multiplier, c%channel_multiplier, di, dj]),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d")
name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
return Output

def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
"""Depthwise convolution nhwc forward operator.

Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_height, in_width, in_channel]

Filter : tvm.Tensor
4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]

stride : list / tuple of 2 ints
    [stride_height, stride_width]

padding : str
'VALID' or 'SAME'

Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""
batch, in_height, in_width, in_channel = Input.shape
filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
stride_h, stride_w = stride

pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option(
padding, (filter_height, filter_width))
out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1)

# padding stage
pad_before = [0, pad_top, pad_left, 0]
pad_after = [0, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
(batch, out_height, out_width, out_channel),
lambda b, i, j, c: tvm.sum(
(PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier] *
Filter[di, dj, c/channel_multiplier, c%channel_multiplier]),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
return Output
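As a quick check of the shape rules above (concrete numbers are assumptions for illustration): with 'SAME' padding the spatial extent shrinks only by the stride, while channels expand by the multiplier:

```python
import tvm
import topi

Input = tvm.placeholder((1, 32, 32, 16), name='Input')   # NHWC
Filter = tvm.placeholder((3, 3, 16, 2), name='Filter')   # HWCM
Output = topi.nn.depthwise_conv2d_nhwc(Input, Filter, stride=[2, 2], padding='SAME')
# pad_along_height = (16 - 1)*2 + 3 - 32 = 1, so
# out_height = (32 - 3 + 1) // 2 + 1 = 16, and out_channel = 16 * 2 = 32.
# Output has shape (1, 16, 16, 32).
```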
26 changes: 24 additions & 2 deletions topi/python/topi/nn/mapping.py
@@ -3,8 +3,8 @@
from __future__ import absolute_import as _abs
import tvm

@tvm.tag_scope(tag="bcast_scale_shift")
def scale_shift(Input, Scale, Shift):
@tvm.tag_scope(tag="bcast_scale_shift_nchw")
def scale_shift_nchw(Input, Scale, Shift):
"""Batch normalization operator in inference.

Parameters
@@ -24,3 +24,25 @@ def scale_shift(Input, Scale, Shift):
Output tensor, layout is NCHW
"""
return tvm.compute(Input.shape, lambda b, c, i, j: Input[b, c, i, j] * Scale[c] + Shift[c], name='ScaleShift')

@tvm.tag_scope(tag="bcast_scale_shift_nhwc")
def scale_shift_nhwc(Input, Scale, Shift):
"""Batch normalization operator in inference.

Parameters
----------
Input : tvm.Tensor
Input tensor, layout is NHWC

Scale : tvm.Tensor
Scale tensor, 1-D of size channel number

Shift : tvm.Tensor
Shift tensor, 1-D of size channel number

Returns
-------
Output : tvm.Tensor
Output tensor, layout is NHWC
"""
return tvm.compute(Input.shape, lambda b, i, j, c: Input[b, i, j, c] * Scale[c] + Shift[c], name='ScaleShift')
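A short sketch of composing the new NHWC variant (placeholder shapes are illustrative); since the stage is tagged `bcast_scale_shift_nhwc`, schedules that inline 'bcast'-tagged one-to-one stages, such as the depthwise schedules above, can fuse it into the preceding convolution:

```python
import tvm
import topi

Conv = tvm.placeholder((1, 32, 32, 32), name='Conv')   # e.g. NHWC depthwise output
Scale = tvm.placeholder((32,), name='Scale')
Shift = tvm.placeholder((32,), name='Shift')
ScaleShift = topi.nn.scale_shift_nhwc(Conv, Scale, Shift)
# ScaleShift.op.tag == 'bcast_scale_shift_nhwc'
```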
2 changes: 1 addition & 1 deletion topi/python/topi/testing/__init__.py
@@ -6,5 +6,5 @@

from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
60 changes: 58 additions & 2 deletions topi/python/topi/testing/depthwise_conv2d_python.py
@@ -3,8 +3,7 @@
import numpy as np
from scipy import signal


def depthwise_conv2d_python(input_np, filter_np, stride, padding):
def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding):
"""Depthwise convolution operator in NCHW layout.

Parameters
@@ -60,3 +59,60 @@ def depthwise_conv2d_python(input_np, filter_np, stride, padding):
mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]

return output_np

def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding):
"""Depthwise convolution operator in nchw layout.

Parameters
----------
input_np : numpy.ndarray
4-D with shape [batch, in_height, in_width, in_channel]

filter_np : numpy.ndarray
4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]

stride : list / tuple of 2 ints
[stride_height, stride_width]

padding : str
'VALID' or 'SAME'

Returns
-------
output_np : np.ndarray
4-D with shape [batch, out_height, out_width, out_channel]
"""
batch, in_height, in_width, in_channel = input_np.shape
filter_height, filter_width, _, channel_multiplier = filter_np.shape
stride_h, stride_w = stride
# calculate output shape
if padding == 'VALID':
out_channel = in_channel * channel_multiplier
out_height = (in_height - filter_height) // stride_h + 1
out_width = (in_width - filter_width) // stride_w + 1
output_np = np.zeros((batch, out_height, out_width, out_channel))
for i in range(batch):
for j in range(out_channel):
output_np[i, :, :, j] = signal.convolve2d(input_np[i, :, :, j//channel_multiplier], \
np.rot90(filter_np[:, :, j//channel_multiplier, j%channel_multiplier], 2), \
mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_width + 1):stride_w]
if padding == 'SAME':
out_channel = in_channel * channel_multiplier
out_height = np.int(np.ceil(float(in_height) / float(stride_h)))
out_width = np.int(np.ceil(float(in_width) / float(stride_w)))
output_np = np.zeros((batch, out_height, out_width, out_channel))
pad_along_height = np.int(max((out_height - 1) * stride_h + filter_height - in_height, 0))
pad_along_width = np.int(max((out_width - 1) * stride_w + filter_width - in_width, 0))
pad_top_tvm = np.int(np.ceil(float(pad_along_height) / 2))
pad_left_tvm = np.int(np.ceil(float(pad_along_width) / 2))
pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2))
pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2))
index_h = pad_top_scipy - pad_top_tvm
index_w = pad_left_scipy - pad_left_tvm
for i in range(batch):
for j in range(out_channel):
output_np[i, :, :, j] = signal.convolve2d(input_np[i, :, :, j//channel_multiplier], \
np.rot90(filter_np[:, :, j//channel_multiplier, j%channel_multiplier], 2), \
mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]

return output_np
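Finally, a minimal sketch of exercising the NumPy reference for correctness checks (shapes are illustrative assumptions):

```python
import numpy as np
import topi.testing

input_np = np.random.uniform(size=(1, 32, 32, 16)).astype('float32')
filter_np = np.random.uniform(size=(3, 3, 16, 2)).astype('float32')
output_np = topi.testing.depthwise_conv2d_python_nhwc(
    input_np, filter_np, stride=[1, 1], padding='SAME')
# With 'SAME' padding and unit stride the spatial size is preserved,
# while channels expand by the multiplier.
assert output_np.shape == (1, 32, 32, 32)
```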