Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions python/tvm/topi/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
"""Common utility for topi test"""

import numpy as np
import scipy.signal

import tvm
from tvm import topi
from tvm.testing import assert_allclose
Expand Down Expand Up @@ -108,3 +110,52 @@ def compare_numpy_tvm(inputs, output, target, device, compute, schedule):
arys = [tvm.nd.array(x, device=device) for x in inputs]
func(*(arys + [te_out]))
assert_allclose(te_out.numpy(), output, atol=1e-4, rtol=1e-4)


def _convolve2d(data, weights):
"""2d convolution operator in HW layout.

This is intended to be used as a replacement for
scipy.signals.convolve2d, with wider support for different dtypes.
scipy.signal.convolve2d does not support all TVM-supported
dtypes (e.g. float16). Where possible, this function uses
scipy.signal.convolve2d to take advantage of compiled scipy
routines, falling back to an explicit loop only where needed.

Parameters
----------
data : numpy.ndarray
2-D with shape [in_height, in_width]

weights : numpy.ndarray
2-D with shape [filter_height, filter_width].

Returns
-------
b_np : np.ndarray
2-D with shape [out_height, out_width]

Return value and layout conventions are matched to
``scipy.signal.convolve2d(data, weights, mode="valid")``
"""

try:
return scipy.signal.convolve2d(data, weights, mode="valid")
except ValueError:
pass

weights = np.rot90(weights, k=2)

assert len(data.shape) == len(weights.shape) == 2

dtype = data.dtype
kernel_h, kernel_w = weights.shape

output_shape = [a_dim - w_dim + 1 for a_dim, w_dim in zip(data.shape, weights.shape)]
output = np.zeros(output_shape, dtype=dtype)

for y in range(output_shape[0]):
for x in range(output_shape[1]):
output[y][x] = np.sum(data[y : y + kernel_h, x : x + kernel_w] * weights)

return output
55 changes: 51 additions & 4 deletions python/tvm/topi/testing/conv2d_nchw_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-branches
"""Convolution in python"""
import numpy as np
import scipy.signal
import scipy

from tvm.topi.nn.utils import get_pad_tuple


Expand Down Expand Up @@ -58,21 +59,67 @@ def _conv2d_nchw_python(a_np, w_np, stride, padding):
out_channel = num_filter
out_height = (in_height - kernel_h + pad_h) // stride_h + 1
out_width = (in_width - kernel_w + pad_w) // stride_w + 1
b_np = np.zeros((batch, out_channel, out_height, out_width))
b_np = np.zeros((batch, out_channel, out_height, out_width), dtype=a_np.dtype)
# computation
for n in range(batch):
for f in range(out_channel):
for c in range(in_channel):
if pad_h > 0 or pad_w > 0:
apad = np.zeros((in_height + pad_h, in_width + pad_w))
apad = np.zeros((in_height + pad_h, in_width + pad_w), dtype=a_np.dtype)
apad[pad_top : pad_top + in_height, pad_left : pad_left + in_width] = a_np[n, c]
else:
apad = a_np[n, c]
out = scipy.signal.convolve2d(apad, np.rot90(np.rot90(w_np[f, c])), mode="valid")

out = _conv2d_hw(apad, w_np[f, c])
b_np[n, f] += out[::stride_h, ::stride_w]
return b_np


def _conv2d_hw(apad, w_np_fc):
"""2d convolution operator in HW layout.

This is intended to be used as a subroutine from
_conv2d_nchw_python. Using scipy.signal.convolve2d directly does
not work for all dtypes (e.g. float16). Where possible, this
function uses scipy.signal.convolve2d to take advantage of
compiled scipy routines, falling back to an explicit loop only
where needed

Parameters
----------
a_np : numpy.ndarray
2-D with shape [in_height, in_width]

w_np : numpy.ndarray
2-D with shape [filter_height, filter_width].

Returns
-------
b_np : np.ndarray
2-D with shape [out_height, out_width]
"""

try:
return scipy.signal.convolve2d(apad, np.rot90(np.rot90(w_np_fc)), mode="valid")
except ValueError:
pass

assert len(apad.shape) == len(w_np_fc.shape) == 2

dtype = apad.dtype
in_height, in_width = apad.shape
kernel_h, kernel_w = w_np_fc.shape

output_shape = [a_dim - w_dim + 1 for a_dim, w_dim in zip(apad.shape, w_np_fc.shape)]
output = np.zeros(output_shape, dtype=apad.dtype)

for y in range(output_shape[0]):
for x in range(output_shape[1]):
output[y][x] = np.sum(apad[y : y + kernel_h, x : x + kernel_w] * w_np_fc)

return output


def conv2d_nchw_python(a_np, w_np, stride, padding, groups=1):
"""Convolution operator in NCHW layout.

Expand Down
118 changes: 34 additions & 84 deletions python/tvm/topi/testing/depthwise_conv2d_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
# pylint: disable=invalid-name, unused-variable, line-too-long
"""Depthwise convolution in python"""
import numpy as np
from scipy import signal

from tvm.topi.nn.utils import get_pad_tuple
from .common import _convolve2d


def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding):
Expand Down Expand Up @@ -49,42 +51,29 @@ def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding):
else:
stride_h, stride_w = stride

# calculate output shape
if padding == "VALID":
out_channel = in_channel * channel_multiplier
out_height = (in_height - filter_height) // stride_h + 1
out_width = (in_width - filter_width) // stride_w + 1
output_np = np.zeros((batch, out_channel, out_height, out_width))
for i in range(batch):
for j in range(out_channel):
output_np[i, j, :, :] = signal.convolve2d(
input_np[i, j // channel_multiplier, :, :],
np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], 2),
mode="valid",
)[
0 : (in_height - filter_height + 1) : stride_h,
0 : (in_width - filter_width + 1) : stride_w,
]
elif padding == "SAME":
out_channel = in_channel * channel_multiplier
out_height = int(np.ceil(float(in_height) / float(stride_h)))
out_width = int(np.ceil(float(in_width) / float(stride_w)))
output_np = np.zeros((batch, out_channel, out_height, out_width))
pad_along_height = int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0))
pad_along_width = int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0))
pad_top_tvm = int(np.ceil(float(pad_along_height) / 2))
pad_left_tvm = int(np.ceil(float(pad_along_width) / 2))
pad_top_scipy = int(np.ceil(float(filter_height - 1) / 2))
pad_left_scipy = int(np.ceil(float(filter_width - 1) / 2))
index_h = pad_top_scipy - pad_top_tvm
index_w = pad_left_scipy - pad_left_tvm
for i in range(batch):
for j in range(out_channel):
output_np[i, j, :, :] = signal.convolve2d(
input_np[i, j // channel_multiplier, :, :],
np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], 2),
mode="same",
)[index_h:in_height:stride_h, index_w:in_width:stride_w]
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (filter_height, filter_width))
pad_h = pad_top + pad_bottom
pad_w = pad_left + pad_right

out_channel = in_channel * channel_multiplier
out_height = (in_height - filter_height + pad_h) // stride_h + 1
out_width = (in_width - filter_width + pad_w) // stride_w + 1
output_np = np.zeros((batch, out_channel, out_height, out_width))

for i in range(batch):
for j in range(out_channel):
apad = input_np[i, j // channel_multiplier, :, :]
if pad_h or pad_w:
apad = np.pad(apad, [(pad_top, pad_bottom), (pad_left, pad_right)])

conv = _convolve2d(
apad,
np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], k=2),
)
output_np[i, j, :, :] = conv[
::stride_h,
::stride_w,
]

return output_np

Expand Down Expand Up @@ -139,15 +128,17 @@ def depthwise_conv2d_python_nchwc(input_np, filter_np, stride, padding):
# Perform conv2d
output_np = depthwise_conv2d_python_nchw(input_nchw, filter_nchw, stride, padding)

# Transform back
# Transform back to NCHWc

# pylint: disable=unpacking-non-sequence
batch_size, out_channel, out_height, out_width = output_np.shape
return output_np.reshape(
(batch_size, out_channel_chunk, out_channel_block, out_height, out_width)
).transpose(0, 1, 3, 4, 2)


def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding):
"""Depthwise convolution operator in nchw layout.
"""Depthwise convolution operator in nhwc layout.

Parameters
----------
Expand All @@ -168,48 +159,7 @@ def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding):
output_np : np.ndarray
4-D with shape [batch, out_height, out_width, out_channel]
"""
batch, in_height, in_width, in_channel = input_np.shape
filter_height, filter_width, _, channel_multiplier = filter_np.shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride

# calculate output shape
if padding == "VALID":
out_channel = in_channel * channel_multiplier
out_height = (in_height - filter_height) // stride_h + 1
out_width = (in_width - filter_width) // stride_w + 1
output_np = np.zeros((batch, out_height, out_width, out_channel))
for i in range(batch):
for j in range(out_channel):
output_np[i, :, :, j] = signal.convolve2d(
input_np[i, :, :, j // channel_multiplier],
np.rot90(filter_np[:, :, j // channel_multiplier, j % channel_multiplier], 2),
mode="valid",
)[
0 : (in_height - filter_height + 1) : stride_h,
0 : (in_width - filter_width + 1) : stride_w,
]
if padding == "SAME":
out_channel = in_channel * channel_multiplier
out_height = int(np.ceil(float(in_height) / float(stride_h)))
out_width = int(np.ceil(float(in_width) / float(stride_w)))
output_np = np.zeros((batch, out_height, out_width, out_channel))
pad_along_height = int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0))
pad_along_width = int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0))
pad_top_tvm = int(np.ceil(float(pad_along_height) / 2))
pad_left_tvm = int(np.ceil(float(pad_along_width) / 2))
pad_top_scipy = int(np.ceil(float(filter_height - 1) / 2))
pad_left_scipy = int(np.ceil(float(filter_width - 1) / 2))
index_h = pad_top_scipy - pad_top_tvm
index_w = pad_left_scipy - pad_left_tvm
for i in range(batch):
for j in range(out_channel):
output_np[i, :, :, j] = signal.convolve2d(
input_np[i, :, :, j // channel_multiplier],
np.rot90(filter_np[:, :, j // channel_multiplier, j % channel_multiplier], 2),
mode="same",
)[index_h:in_height:stride_h, index_w:in_width:stride_w]

return output_np
input_nchw = input_np.transpose(0, 3, 1, 2)
filter_nchw = filter_np.transpose(2, 3, 0, 1)
output_nchw = depthwise_conv2d_python_nchw(input_nchw, filter_nchw, stride, padding)
return output_nchw.transpose(0, 2, 3, 1)
35 changes: 23 additions & 12 deletions python/tvm/topi/testing/dilate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np


def dilate_python(input_np, strides, dilation_value=0.0):
def dilate_python(input_np, strides, dilation_value=0.0, out_dtype=None):
"""Dilate operation.

Parameters
Expand All @@ -33,23 +33,34 @@ def dilate_python(input_np, strides, dilation_value=0.0):
dilation_value : int/float, optional
Value used to dilate the input.

out_dtype : Option[str]
The datatype of the dilated array. If unspecified, will use
the same dtype as the input array.

Returns
-------
output_np : numpy.ndarray
n-D, the same layout as Input.

"""
n = len(input_np.shape)
assert len(strides) == n, "Input dimension and strides size dismatch : %d vs %d" % (
n,
assert len(input_np.shape) == len(
strides
), "Input dimension and strides size dismatch : %d vs %d" % (
len(input_np.shape),
len(strides),
)
output_size = ()
no_zero = ()
for i in range(n):
output_size += ((input_np.shape[i] - 1) * strides[i] + 1,)
no_zero += ((range(0, output_size[i], strides[i])),)
output_np = np.ones(shape=output_size)
output_np = dilation_value * output_np
output_np[np.ix_(*no_zero)] = input_np

if out_dtype is None:
out_dtype = input_np.dtype

output_size = [
(input_dim - 1) * stride + 1 for input_dim, stride in zip(input_np.shape, strides)
]
non_zero_elements = np.ix_(
*[range(0, output_dim, stride) for output_dim, stride in zip(output_size, strides)]
)

output_np = np.full(shape=output_size, fill_value=dilation_value, dtype=out_dtype)
output_np[non_zero_elements] = input_np

return output_np
Loading