diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index 80366e09f402..884989920c60 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -5,3 +5,4 @@ from .mapping import * from .ewise import * from .conv import * +from .dilate import * diff --git a/topi/python/topi/nn/dilate.py b/topi/python/topi/nn/dilate.py new file mode 100644 index 000000000000..2b3b2f424aa6 --- /dev/null +++ b/topi/python/topi/nn/dilate.py @@ -0,0 +1,44 @@ +# pylint: disable=invalid-name +"""Dilation operators""" +from __future__ import absolute_import as _abs +import tvm + + +@tvm.tag_scope(tag="dilation") +def dilate(Input, strides): + """Dilate Input with zeros. + + Parameters + ---------- + Input : tvm.Tensor + n-D, can be any layout. + + strides : list / tuple of n ints + Dilation stride on each dimension, 1 means no dilation. + + Returns + ------- + Output : tvm.Tensor + n-D, the same layout as Input. + """ + n = len(Input.shape) + assert len(strides) == n, \ + "Input dimension and strides size dismatch : %d vs %d" %(n, len(strides)) + output_size = () + for i in range(n): + output_size += (tvm.ir_pass.Simplify((Input.shape[i]-1)*strides[i]+1),) + + def _dilate(data, *indices): + not_zero = (indices[0]%strides[0]).equal(0) + index_tuple = () + for i in range(n): + index_tuple += (indices[i]/strides[i],) + not_zero = tvm.all(not_zero, (indices[i]%strides[i]).equal(0)) + return tvm.select(not_zero, data[index_tuple], tvm.const(0.0, data.dtype)) + + Output = tvm.compute( + (output_size), + lambda *indices: _dilate(Input, *indices), + name='DilatedInput') + + return Output diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 63bc8eb7215a..61fe2a60df91 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -6,3 +6,4 @@ from .conv2d_hwcn_python import conv2d_hwcn_python from .conv2d_nchw_python import conv2d_nchw_python +from .dilate_python import dilate_python diff --git a/topi/python/topi/testing/dilate_python.py b/topi/python/topi/testing/dilate_python.py new file mode 100644 index 000000000000..89ac2c109fb6 --- /dev/null +++ b/topi/python/topi/testing/dilate_python.py @@ -0,0 +1,33 @@ +# pylint: disable=invalid-name +"""Dilate operation in python""" +import numpy as np + + +def dilate_python(input_np, strides): + """Dilate operation. + + Parameters + ---------- + input_np : numpy.ndarray + n-D, can be any layout. + + strides : list / tuple of n ints + Dilation stride on each dimension, 1 means no dilation. + + Returns + ------- + output_np : numpy.ndarray + n-D, the same layout as Input. + """ + n = len(input_np.shape) + assert len(strides) == n, \ + "Input dimension and strides size dismatch : %d vs %d" %(n, len(strides)) + output_size = () + no_zero = () + for i in range(n): + output_size += ((input_np.shape[i]-1)*strides[i]+1,) + no_zero += ((range(0, output_size[i], strides[i])),) + output_np = np.zeros(shape=output_size) + output_np[np.ix_(*no_zero)] = input_np + + return output_np diff --git a/topi/tests/python/test_topi_dilate.py b/topi/tests/python/test_topi_dilate.py new file mode 100644 index 000000000000..0d2014535ca4 --- /dev/null +++ b/topi/tests/python/test_topi_dilate.py @@ -0,0 +1,36 @@ +import tvm +import topi +import numpy as np + + +def test_dilate(): + target = 'llvm' + ctx = tvm.cpu(0) + + def _test_dilate(input_size, strides): + Input = tvm.placeholder((input_size)) + Output = topi.nn.dilate(Input, strides) + schedule = tvm.create_schedule(Output.op) + input_np = np.random.uniform(size=input_size).astype(Input.dtype) + output_np = topi.testing.dilate_python(input_np, strides) + input_tvm = tvm.nd.array(input_np, ctx=ctx) + output_size = () + for i in range(len(input_size)): + output_size += (tvm.ir_pass.Simplify(Output.shape[i]).value,) + output_tvm = tvm.nd.array(np.zeros(shape=output_size).astype(Output.dtype), ctx=ctx) + f = tvm.build(schedule, [Input, Output], target) + f(input_tvm, output_tvm) + np.testing.assert_allclose(output_tvm.asnumpy(), output_np, rtol=1e-5) + + _test_dilate((32,), (2,)) + _test_dilate((32,32), (2,2)) + _test_dilate((1,3,32,32), (1,1,1,1)) + _test_dilate((1,3,32,32), (2,2,2,2)) + _test_dilate((1,32,32,3,3), (1,1,1,1,1)) + _test_dilate((1,32,32,3,3), (2,2,2,2,2)) + _test_dilate((1,32,32,32,3,3), (1,1,1,2,2,2)) + _test_dilate((1,32,32,32,3,3), (2,2,2,1,1,1)) + + +if __name__ == "__main__": + test_dilate()