diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index d606213a5270..6f641e99f7dd 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -177,7 +177,6 @@ def schedule_global_pool(outs): """ return _default_schedule(outs, False) - @tvm.target.generic_func def schedule_binarize_pack(outs): """Schedule for binarize_pack diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index 3cdf3122e78e..918f399f503c 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -14,3 +14,4 @@ from .softmax import * from .conv2d_transpose import * from .bnn import * +from .upsampling import * diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py new file mode 100644 index 000000000000..e1234741e286 --- /dev/null +++ b/topi/python/topi/nn/upsampling.py @@ -0,0 +1,28 @@ +"""TVM operator upsampling compute.""" +from __future__ import absolute_import +import tvm + + +def upsampling(data, scale): + """Perform nearest neighbor upsampling on the data. + Bilinear upsampling is not supported. + + Parameters + ---------- + data : tvm.Tensor + 4-D with shape [batch, channel, in_height, in_width] + + scale: int + upsampling scaling factor + + Returns + ------- + output : tvm.Tensor + 4-D with shape [batch, channel, in_height*scale, in_width*scale] + """ + batch, channel, height, width = data.shape + out_height = height * scale + out_width = width * scale + + return tvm.compute((batch, channel, out_height, out_width), \ + lambda n, c, h, w: data[n, c, h/scale, w/scale]) diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 3a43a04437a1..6a1b361e3097 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -10,3 +10,4 @@ from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .dilate_python import dilate_python from .softmax_python import softmax_python, log_softmax_python +from .upsampling_python import upsampling_python diff --git a/topi/python/topi/testing/upsampling_python.py b/topi/python/topi/testing/upsampling_python.py new file mode 100644 index 000000000000..328c7a5a0bc1 --- /dev/null +++ b/topi/python/topi/testing/upsampling_python.py @@ -0,0 +1,15 @@ +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""Upsampling in python""" +import numpy as np + +def upsample_nearest(arr, scale): + return arr.repeat(scale, axis=0).repeat(scale, axis=1) + +def upsampling_python(data, scale): + ishape = data.shape + oshape = (ishape[0], ishape[1], ishape[2]*scale, ishape[3]*scale) + output_np = np.zeros(oshape, dtype=data.dtype) + for b in range(oshape[0]): + for c in range(oshape[1]): + output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale) + return output_np diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py new file mode 100644 index 000000000000..08b8f987694d --- /dev/null +++ b/topi/tests/python/test_topi_upsampling.py @@ -0,0 +1,39 @@ +"""Test code for upsampling""" +import numpy as np +import tvm +import topi +import math + +def verify_upsampling(batch, in_channel, in_height, in_width, scale): + A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') + B = topi.nn.upsampling(A, scale) + out_shape = (batch, in_channel, in_height*scale, in_width*scale) + dtype = A.dtype + + a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype) + b_np = topi.testing.upsampling_python(a_np, scale) + + def check_device(device): + if not tvm.module.enabled(device): + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(B) + ctx = tvm.context(device, 0) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx) + f = tvm.build(s, [A, B], device) + f(a, b) + + np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in ['llvm', 'cuda']: + check_device(device) + +def test_upsampling(): + verify_upsampling(8, 16, 32, 32, 2) + verify_upsampling(12, 32, 64, 64, 3) + +if __name__ == "__main__": + test_upsampling()