diff --git a/topi/python/topi/cuda/depthwise_conv2d_map.py b/topi/python/topi/cuda/depthwise_conv2d_map.py index e1900d678c3d..95a5ee827e11 100644 --- a/topi/python/topi/cuda/depthwise_conv2d_map.py +++ b/topi/python/topi/cuda/depthwise_conv2d_map.py @@ -1,7 +1,7 @@ # pylint: disable=invalid-name """Schedule for depthwise_conv2d with auto fusion""" import tvm -from ..nn.util import get_const_tuple +from ..util import get_const_tuple def schedule_depthwise_conv2d_map(op): """Schedule for depthwise_conv2d map ops. diff --git a/topi/python/topi/nn/conv.py b/topi/python/topi/nn/conv.py index 4c233aa1568a..768387e54d0d 100644 --- a/topi/python/topi/nn/conv.py +++ b/topi/python/topi/nn/conv.py @@ -3,7 +3,7 @@ from __future__ import absolute_import as _abs import tvm import numpy as np -from .util import get_const_tuple +from ..util import get_const_tuple @tvm.tag_scope(tag="conv2d_hwcn") diff --git a/topi/python/topi/nn/util.py b/topi/python/topi/nn/util.py deleted file mode 100644 index 207a245109ea..000000000000 --- a/topi/python/topi/nn/util.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Common topi utilities""" -from __future__ import absolute_import as _abs -import tvm - -def get_const_tuple(in_tuple): - """Verifies input tuple is IntImm, returns tuple of int. - - Parameters - ---------- - in_tuple : tuple of tvm.expr.IntImm - The input. - - Returns - ------- - out_tuple : tuple of int - The output. - """ - out_tuple = () - for elem in in_tuple: - if not isinstance(elem, tvm.expr.IntImm): - raise ValueError("Element of input tuple should be IntImm") - out_tuple = out_tuple + (elem.value, ) - return out_tuple diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py new file mode 100644 index 000000000000..859e3d6caa2f --- /dev/null +++ b/topi/python/topi/util.py @@ -0,0 +1,43 @@ +"""Common topi utilities""" +from __future__ import absolute_import as _abs +import tvm + +def get_const_int(expr): + """Verifies expr is integer and get the constant value. + + Parameters + ---------- + expr : + The input expression. + + Returns + ------- + out_tuple : tuple of int + The output. + """ + if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + expr = tvm.ir_pass.Simplfy(expr) + if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + raise ValueError("Expect value to be constant int") + return expr.value + + +def get_const_tuple(in_tuple): + """Verifies input tuple is IntImm, returns tuple of int. + + Parameters + ---------- + in_tuple : tuple of Expr + The input. + + Returns + ------- + out_tuple : tuple of int + The output. + """ + out_tuple = () + for elem in in_tuple: + if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)): + raise ValueError("Element of input tuple should be const int") + out_tuple = out_tuple + (elem.value, ) + return out_tuple diff --git a/topi/tests/python/test_topi_basic.py b/topi/tests/python/test_topi_basic.py index 748231aad979..067bc5adab41 100644 --- a/topi/tests/python/test_topi_basic.py +++ b/topi/tests/python/test_topi_basic.py @@ -1,5 +1,13 @@ import tvm import topi +from topi import util + + +def test_util(): + x = tvm.const(100) + assert util.get_const_int(x) == 100 + assert util.get_const_tuple((x, x)) == (100, 100) + def test_ewise(): m = tvm.var('m') @@ -19,4 +27,5 @@ def test_apply(func, name): if __name__ == "__main__": + test_util() test_ewise() diff --git a/topi/tests/python/test_topi_conv2d_hwcn_map.py b/topi/tests/python/test_topi_conv2d_hwcn_map.py index 993e5713cfe4..820d859847a8 100644 --- a/topi/tests/python/test_topi_conv2d_hwcn_map.py +++ b/topi/tests/python/test_topi_conv2d_hwcn_map.py @@ -3,7 +3,7 @@ import numpy as np import tvm import topi -from topi.nn.util import get_const_tuple +from topi.util import get_const_tuple def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, stride, padding): diff --git a/topi/tests/python/test_topi_depthwise_conv2d_map.py b/topi/tests/python/test_topi_depthwise_conv2d_map.py index 4069bd43f2da..22cc0654b0e6 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d_map.py +++ b/topi/tests/python/test_topi_depthwise_conv2d_map.py @@ -2,7 +2,7 @@ import topi import numpy as np from scipy import signal -from topi.nn.util import get_const_tuple +from topi.util import get_const_tuple from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):