diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index 55f7844319f3..757d8fe674c2 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -1,6 +1,7 @@ """TVM operator upsampling compute.""" from __future__ import absolute_import import topi +from ..util import simplify def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'): @@ -31,9 +32,9 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'): """ if layout == "NCHW": - out_shape = (data.shape[2] * scale, data.shape[3] * scale) + out_shape = (simplify(data.shape[2] * scale), simplify(data.shape[3] * scale)) elif layout == "NHWC": - out_shape = (data.shape[1] * scale, data.shape[2] * scale) + out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale)) else: raise ValueError("not support this layout {} yet".format(layout)) diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py index 3affc30a0722..ec657d490fb6 100644 --- a/topi/tests/python/test_topi_upsampling.py +++ b/topi/tests/python/test_topi_upsampling.py @@ -5,7 +5,7 @@ import topi.testing import math -def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW'): +def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="NEAREST_NEIGHBOR"): if layout == 'NCHW': @@ -22,9 +22,13 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH raise NotImplementedError( 'Layout not supported {} '.format(layout)) - B = topi.nn.upsampling(A, scale, layout=layout) + B = topi.nn.upsampling(A, scale, layout=layout, method=method) - b_np = topi.testing.upsampling_python(a_np, scale, layout) + if method == "BILINEAR": + out_size = (in_height*scale, in_width*scale) + b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout) + else: + b_np = topi.testing.upsampling_python(a_np, scale, layout) def check_device(device): ctx = tvm.context(device, 0) @@ -39,18 +43,27 @@ def check_device(device): f = tvm.build(s, [A, B], device) f(a, b) - np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) for device in ['llvm', 'cuda', 'vulkan', 'nvptx']: check_device(device) def test_upsampling(): - # NCHW + # NEAREST_NEIGHBOR - NCHW verify_upsampling(8, 16, 32, 32, 2) verify_upsampling(12, 32, 64, 64, 3) - # NHWC - verify_upsampling(8, 16, 32, 32, 2, "NHWC") - verify_upsampling(12, 32, 64, 64, 3, "NHWC") + + # NEAREST_NEIGHBOR - NHWC + verify_upsampling(8, 16, 32, 32, 2, layout="NHWC") + verify_upsampling(12, 32, 64, 64, 3, layout="NHWC") + + # BILINEAR - NCHW + verify_upsampling(2, 2, 32, 32, 2, method="BILINEAR") + verify_upsampling(2, 2, 32, 32, 3, method="BILINEAR") + + # BILINEAR - NHWC + verify_upsampling(2, 2, 32, 32, 2, layout="NHWC", method="BILINEAR") + verify_upsampling(2, 2, 32, 32, 3, layout="NHWC", method="BILINEAR") if __name__ == "__main__": test_upsampling()