Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion topi/python/topi/generic/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ def schedule_global_pool(outs):
"""
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_binarize_pack(outs):
"""Schedule for binarize_pack
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from .softmax import *
from .conv2d_transpose import *
from .bnn import *
from .upsampling import *
28 changes: 28 additions & 0 deletions topi/python/topi/nn/upsampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""TVM operator upsampling compute."""
from __future__ import absolute_import
import tvm


def upsampling(data, scale):
"""Perform nearest neighbor upsampling on the data.
Bilinear upsampling is not supported.

Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, in_height, in_width]

scale: int
upsampling scaling factor

Returns
-------
output : tvm.Tensor
4-D with shape [batch, channel, in_height*scale, in_width*scale]
"""
batch, channel, height, width = data.shape
out_height = height * scale
out_width = width * scale

return tvm.compute((batch, channel, out_height, out_width), \
lambda n, c, h, w: data[n, c, h/scale, w/scale])
1 change: 1 addition & 0 deletions topi/python/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
from .softmax_python import softmax_python, log_softmax_python
from .upsampling_python import upsampling_python
15 changes: 15 additions & 0 deletions topi/python/topi/testing/upsampling_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Upsampling in python"""
import numpy as np

def upsample_nearest(arr, scale):
return arr.repeat(scale, axis=0).repeat(scale, axis=1)

def upsampling_python(data, scale):
ishape = data.shape
oshape = (ishape[0], ishape[1], ishape[2]*scale, ishape[3]*scale)
output_np = np.zeros(oshape, dtype=data.dtype)
for b in range(oshape[0]):
for c in range(oshape[1]):
output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale)
return output_np
39 changes: 39 additions & 0 deletions topi/tests/python/test_topi_upsampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Test code for upsampling"""
import numpy as np
import tvm
import topi
import math

def verify_upsampling(batch, in_channel, in_height, in_width, scale):
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
B = topi.nn.upsampling(A, scale)
out_shape = (batch, in_channel, in_height*scale, in_width*scale)
dtype = A.dtype

a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
b_np = topi.testing.upsampling_python(a_np, scale)

def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)

np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

for device in ['llvm', 'cuda']:
check_device(device)

def test_upsampling():
verify_upsampling(8, 16, 32, 32, 2)
verify_upsampling(12, 32, 64, 64, 3)

if __name__ == "__main__":
test_upsampling()