Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/depthwise_conv2d_map.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=invalid-name
"""Schedule for depthwise_conv2d with auto fusion"""
import tvm
from ..nn.util import get_const_tuple
from ..util import get_const_tuple

def schedule_depthwise_conv2d_map(op):
"""Schedule for depthwise_conv2d map ops.
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/nn/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
import numpy as np
from .util import get_const_tuple
from ..util import get_const_tuple


@tvm.tag_scope(tag="conv2d_hwcn")
Expand Down
23 changes: 0 additions & 23 deletions topi/python/topi/nn/util.py

This file was deleted.

43 changes: 43 additions & 0 deletions topi/python/topi/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Common topi utilities"""
from __future__ import absolute_import as _abs
import tvm

def get_const_int(expr):
"""Verifies expr is integer and get the constant value.

Parameters
----------
expr :
The input expression.

Returns
-------
out_tuple : tuple of int
The output.
"""
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
expr = tvm.ir_pass.Simplfy(expr)
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
raise ValueError("Expect value to be constant int")
return expr.value


def get_const_tuple(in_tuple):
"""Verifies input tuple is IntImm, returns tuple of int.

Parameters
----------
in_tuple : tuple of Expr
The input.

Returns
-------
out_tuple : tuple of int
The output.
"""
out_tuple = ()
for elem in in_tuple:
if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)):
raise ValueError("Element of input tuple should be const int")
out_tuple = out_tuple + (elem.value, )
return out_tuple
9 changes: 9 additions & 0 deletions topi/tests/python/test_topi_basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import tvm
import topi
from topi import util


def test_util():
x = tvm.const(100)
assert util.get_const_int(x) == 100
assert util.get_const_tuple((x, x)) == (100, 100)


def test_ewise():
m = tvm.var('m')
Expand All @@ -19,4 +27,5 @@ def test_apply(func, name):


if __name__ == "__main__":
test_util()
test_ewise()
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_conv2d_hwcn_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import tvm
import topi
from topi.nn.util import get_const_tuple
from topi.util import get_const_tuple


def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, stride, padding):
Expand Down
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_depthwise_conv2d_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import topi
import numpy as np
from scipy import signal
from topi.nn.util import get_const_tuple
from topi.util import get_const_tuple
from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map

def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
Expand Down