diff --git a/python/tvm/expr.py b/python/tvm/expr.py index a84eeb75ef53..b1d5b51b47a8 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -50,6 +50,12 @@ def __truediv__(self, other): def __rtruediv__(self, other): return self.__rdiv__(other) + def __floordiv__(self, other): + return self.__div__(other) + + def __rfloordiv__(self, other): + return self.__rdiv__(other) + def __mod__(self, other): return _make.Mod(self, other) diff --git a/python/tvm/make.py b/python/tvm/make.py index 1e8fae678084..49f698f4f663 100644 --- a/python/tvm/make.py +++ b/python/tvm/make.py @@ -52,10 +52,11 @@ def static_cast(dtype, expr): """ target_type = TVMType(dtype) src_type = TVMType(expr.dtype) - if target_type.type_code == src_type.type_code\ - and src_type.lanes == 1\ - and target_type.lanes > 1: - return Broadcast(expr, target_type.lanes) + if target_type.type_code == src_type.type_code and src_type.bits == target_type.bits: + if src_type.lanes == target_type.lanes: + return expr + elif src_type.lanes == 1 and target_type.lanes > 1: + return Broadcast(expr, target_type.lanes) return Cast(dtype, expr) diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py index b33819c7e471..ef0c7467c051 100644 --- a/topi/python/topi/__init__.py +++ b/topi/python/topi/__init__.py @@ -15,3 +15,4 @@ from . import nn from . import cuda from . import testing +from . import util diff --git a/topi/python/topi/nn/convolution.py b/topi/python/topi/nn/convolution.py index 8d82d4d8ab3a..6293e02e343a 100644 --- a/topi/python/topi/nn/convolution.py +++ b/topi/python/topi/nn/convolution.py @@ -1,12 +1,11 @@ -# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +# pylint: disable=invalid-name, unused-variable, too-many-locals """Convolution operators""" from __future__ import absolute_import as _abs import tvm -import numpy as np -from ..util import get_const_tuple +from ..util import simplify +from .pad import pad, _spatial2d_pad_option -@tvm.tag_scope(tag="conv2d_nchw") def conv2d_nchw(Input, Filter, stride, padding): """Convolution operator in HWCN layout. 
@@ -31,45 +30,33 @@ def conv2d_nchw(Input, Filter, stride, padding): """ assert isinstance(stride, int) or len(stride) == 2 assert isinstance(padding, int) or padding in ['VALID', 'SAME'] - batch, in_channel, in_height, in_width = get_const_tuple(Input.shape) - num_filter, channel, kernel_h, kernel_w = get_const_tuple(Filter.shape) + batch, in_channel, in_height, in_width = Input.shape + num_filter, channel, kernel_h, kernel_w = Filter.shape if isinstance(stride, int): stride_h = stride_w = stride else: stride_h, stride_w = stride - # compute the padding size - if isinstance(padding, int): - pad_h = pad_w = padding * 2 - elif padding == 'VALID': - pad_h = 0 - pad_w = 0 - else: # 'SAME' - pad_h = kernel_h - 1 - pad_w = kernel_w - 1 - pad_top = int(np.ceil(float(pad_h) / 2)) - pad_left = int(np.ceil(float(pad_w) / 2)) + pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option( + padding, (kernel_h, kernel_w)) # compute the output shape out_channel = num_filter - out_height = (in_height - kernel_h + pad_h) // stride_h + 1 - out_width = (in_width - kernel_w + pad_w) // stride_w + 1 + out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1) # compute graph - temp = tvm.compute( - (batch, in_channel, in_height + pad_h, in_width + pad_w), - lambda nn, cc, yy, xx: tvm.select( - tvm.all(yy >= pad_top, yy - pad_top < in_height, - xx >= pad_left, xx - pad_left < in_width), - Input[nn, cc, yy - pad_top, xx - pad_left], tvm.const(0.)), - name='temp') + pad_before = [0, 0, pad_top, pad_left] + pad_after = [0, 0, pad_down, pad_right] + temp = pad(Input, pad_before, pad_after, name="pad_temp") rc = tvm.reduce_axis((0, in_channel), name='rc') ry = tvm.reduce_axis((0, kernel_h), name='ry') rx = tvm.reduce_axis((0, kernel_w), name='rx') + return tvm.compute( (batch, out_channel, out_height, out_width), lambda nn, ff, yy, xx: tvm.sum( temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx] * Filter[ff, rc, ry, rx], - axis=[rc, ry, rx])) + axis=[rc, ry, rx]), tag="conv2d_nchw") + -@tvm.tag_scope(tag="conv2d_hwcn") def conv2d_hwcn(Input, Filter, stride, padding): """Convolution operator in HWCN layout. 
@@ -93,36 +80,22 @@ def conv2d_hwcn(Input, Filter, stride, padding): 4-D with shape [out_height, out_width, out_channel, batch] """ assert isinstance(stride, int) or len(stride) == 2 - assert isinstance(padding, int) or padding in ['VALID', 'SAME'] - in_height, in_width, in_channel, batch = get_const_tuple(Input.shape) - kernel_h, kernel_w, channel, num_filter = get_const_tuple(Filter.shape) + in_height, in_width, in_channel, batch = Input.shape + kernel_h, kernel_w, channel, num_filter = Filter.shape if isinstance(stride, int): stride_h = stride_w = stride else: stride_h, stride_w = stride - # compute the padding size - if isinstance(padding, int): - pad_h = pad_w = padding * 2 - elif padding == 'VALID': - pad_h = 0 - pad_w = 0 - else: # 'SAME' - pad_h = kernel_h - 1 - pad_w = kernel_w - 1 - pad_top = int(np.ceil(float(pad_h) / 2)) - pad_left = int(np.ceil(float(pad_w) / 2)) + + pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option( + padding, (kernel_h, kernel_w)) # compute the output shape out_channel = num_filter - out_height = (in_height - kernel_h + pad_h) // stride_h + 1 - out_width = (in_width - kernel_w + pad_w) // stride_w + 1 - # compute graph - PaddedInput = tvm.compute( - (in_height + pad_h, in_width + pad_w, in_channel, batch), - lambda yy, xx, cc, nn: tvm.select( - tvm.all(yy >= pad_top, yy - pad_top < in_height, - xx >= pad_left, xx - pad_left < in_width), - Input[yy - pad_top, xx - pad_left, cc, nn], tvm.const(0.)), - name='PaddedInput') + out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1) + pad_before = [pad_top, pad_left, 0, 0] + pad_after = [pad_down, pad_right, 0, 0] + PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") rc = tvm.reduce_axis((0, in_channel), name='rc') ry = tvm.reduce_axis((0, kernel_h), name='ry') rx = tvm.reduce_axis((0, kernel_w), name='rx') @@ -131,12 +104,11 @@ def conv2d_hwcn(Input, Filter, stride, padding): lambda yy, xx, ff, nn: tvm.sum( PaddedInput[yy * stride_h + ry, xx * stride_w + rx, rc, nn] * Filter[ry, rx, rc, ff], axis=[ry, rx, rc]), - name='Conv2dOutput') + name="Conv2dOutput", tag="conv2d_hwcn") return Output -@tvm.tag_scope(tag="depthwise_conv2d") -def depthwise_conv2d(Input, Filter, Stride, padding): +def depthwise_conv2d(Input, Filter, stride, padding): """Depthwise convolution operator. 
Parameters @@ -147,8 +119,8 @@ def depthwise_conv2d(Input, Filter, Stride, padding): Filter : tvm.Tensor 4-D with shape [in_channel, channel_multiplier, filter_height, filter_width] - Stride : tvm.Tensor - 1-D of size 2 + stride : tuple of two ints + The spatial stride along height and width padding : str 'VALID' or 'SAME' @@ -158,49 +130,28 @@ def depthwise_conv2d(Input, Filter, Stride, padding): Output : tvm.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - in_shape = get_const_tuple(Input.shape) - batch = in_shape[0] - in_channel = in_shape[1] - in_height = in_shape[2] - in_width = in_shape[3] - filter_shape = get_const_tuple(Filter.shape) - filter_channel = filter_shape[0] - channel_multiplier = filter_shape[1] - filter_height = filter_shape[2] - filter_width = filter_shape[3] - stride_h = Stride.asnumpy()[0] - stride_w = Stride.asnumpy()[1] - # calculate output shape - if padding == 'VALID': - out_channel = in_channel * channel_multiplier - out_height = (in_height - filter_height) // stride_h + 1 - out_width = (in_width - filter_width) // stride_w + 1 - pad_along_height = 0 - pad_along_width = 0 - if padding == 'SAME': - out_channel = in_channel * channel_multiplier - out_height = np.int(np.ceil(float(in_height) / float(stride_h))) - out_width = np.int(np.ceil(float(in_width) / float(stride_w))) - pad_along_height = np.int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) - pad_along_width = np.int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) - height_after_pad = in_height + pad_along_height - width_after_pad = in_width + pad_along_width - pad_top = np.int(np.ceil(float(pad_along_height) / 2)) - pad_left = np.int(np.ceil(float(pad_along_width) / 2)) + batch, in_channel, in_height, in_width = Input.shape + filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape + stride_h, stride_w = stride + + pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option( + padding, (filter_height, filter_width)) + out_channel = simplify(in_channel * channel_multiplier) + out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1) + # padding stage - PaddedInput = tvm.compute( - (batch, in_channel, height_after_pad, width_after_pad), - lambda b, c, i, j: tvm.select( - tvm.all(i >= pad_top, i - pad_top < in_height, j >= pad_left, j - pad_left < in_width), - Input[b, c, i - pad_top, j - pad_left], tvm.const(0.0)), - name="PaddedInput") + pad_before = [0, 0, pad_top, pad_left] + pad_after = [0, 0, pad_down, pad_right] + PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") # depthconv stage di = tvm.reduce_axis((0, filter_height), name='di') dj = tvm.reduce_axis((0, filter_width), name='dj') Output = tvm.compute( (batch, out_channel, out_height, out_width), lambda b, c, i, j: tvm.sum( - PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] * Filter[c/channel_multiplier, c%channel_multiplier, di, dj], + (PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] * + Filter[c/channel_multiplier, c%channel_multiplier, di, dj]), axis=[di, dj]), - name='DepthwiseConv2d') + name='DepthwiseConv2d', tag="depthwise_conv2d") return Output diff --git a/topi/python/topi/nn/dilate.py b/topi/python/topi/nn/dilate.py index 6ba52b02ef52..7d65712ec30c 100644 --- a/topi/python/topi/nn/dilate.py +++ b/topi/python/topi/nn/dilate.py @@ -6,35 +6,39 @@ @tvm.tag_scope(tag="dilation") -def 
dilate(Input, strides): - """Dilate Input with zeros. +def dilate(data, strides, name="DilatedInput"): + """Dilate data with zeros. Parameters ---------- - Input : tvm.Tensor + data : tvm.Tensor n-D, can be any layout. strides : list / tuple of n ints Dilation stride on each dimension, 1 means no dilation. + name : str, optional + The name prefix for the generated operators + Returns ------- Output : tvm.Tensor - n-D, the same layout as Input. + n-D, the same layout as data. """ - n = len(Input.shape) - assert len(strides) == n, \ - "Input dimension and strides size dismatch : %d vs %d" %(n, len(strides)) - output_size = () - for i in range(n): - output_size += (tvm.ir_pass.Simplify((Input.shape[i]-1)*strides[i]+1),) - - def _dilate(data, *indices): + n = len(data.shape) + if len(strides) != n: + raise ValueError("data dimension and strides size mismatch : %d vs %d" % ( + n, len(strides))) + + out_shape = tuple( + tvm.ir_pass.Simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n)) + + def _dilate(*indices): not_zero = [] index_tuple = [] for i in range(n): if not util.equal_const_int(strides[i], 1): - index_tuple.append(indices[i]/strides[i]) + index_tuple.append(indices[i] / strides[i]) not_zero.append((indices[i] % strides[i]).equal(0)) else: index_tuple.append(indices[i]) @@ -43,9 +47,4 @@ def _dilate(data, *indices): return tvm.select(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype)) return data(*index_tuple) - Output = tvm.compute( - output_size, - lambda *indices: _dilate(Input, *indices), - name='DilatedInput') - - return Output + return tvm.compute(out_shape, _dilate, name=name) diff --git a/topi/python/topi/nn/pad.py b/topi/python/topi/nn/pad.py new file mode 100644 index 000000000000..978947a8ecaa --- /dev/null +++ b/topi/python/topi/nn/pad.py @@ -0,0 +1,104 @@ +"""Pad the data with a constant value.""" +from __future__ import absolute_import as _abs +import tvm +from ..util import equal_const_int + + +def _spatial2d_pad_option(padding, kernel): + """Common code to get the pad option + + Parameters + ---------- + padding : int or tuple of two ints or str + Padding size, or ['VALID', 'SAME'] + + kernel : tuple of int + Conv kernel size + + Returns + ------- + pad_top : int + Padding size on top + + pad_left : int + Padding size on left + + pad_down : int + Padding size on the bottom. + + pad_right : int + Padding size on right. + """ + # compute the padding size + if isinstance(padding, (tuple, list)): + pad_h = padding[0] * 2 + pad_w = padding[1] * 2 + elif isinstance(padding, int): + pad_h = pad_w = padding * 2 + elif padding == "VALID": + pad_h = 0 + pad_w = 0 + elif padding == "SAME": + pad_h = kernel[0] - 1 + pad_w = kernel[1] - 1 + else: + raise ValueError("Unknown padding option %s" % padding) + pad_top = (pad_h + 1) // 2 + pad_left = (pad_w + 1) // 2 + return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left + + +@tvm.tag_scope(tag="pad") +def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"): + """Pad Input with a constant value. + + Parameters + ---------- + data : tvm.Tensor + n-D input, can be any layout. + + pad_before : list / tuple of n ints + Pad width on each dimension to pad before the axis begins. + + pad_after : list / tuple of n ints, optional + Pad width on each dimension to pad after the axis ends. + + pad_value : float, optional + The value to be padded. + + name : str, optional + The name prefix for the generated operators + + Returns + ------- + Output : tvm.Tensor + n-D, the same layout as Input. 
+ """ + n = len(data.shape) + pad_after = pad_after if pad_after else pad_before + if len(pad_before) != n: + raise ValueError("Input dimension and pad_before dismatch : %d vs %d" % ( + n, len(pad_before))) + if len(pad_after) != n: + raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % ( + n, len(pad_before))) + out_shape = tuple( + tvm.ir_pass.Simplify( + (data.shape[i] + pad_before[i] + pad_after[i])) for i in range(n)) + pad_value = (pad_value if isinstance(pad_value, tvm.expr.Expr) + else tvm.const(pad_value, data.dtype)) + def _pad(*indices): + not_zero = [] + index_tuple = [] + for i in range(n): + if equal_const_int(pad_before[i], 0) and equal_const_int(pad_after[i], 0): + index_tuple.append(indices[i]) + else: + index_tuple.append(indices[i] - pad_before[i]) + not_zero.append(indices[i] >= pad_before[i]) + not_zero.append(indices[i] < data.shape[i] + pad_before[i]) + if not_zero: + not_zero = tvm.all(*not_zero) + return tvm.select(not_zero, data(*index_tuple), pad_value) + return data(*index_tuple) + return tvm.compute(out_shape, _pad, name=name) diff --git a/topi/python/topi/nn/pooling.py b/topi/python/topi/nn/pooling.py index 9f8322da3b21..b5b44b630e21 100644 --- a/topi/python/topi/nn/pooling.py +++ b/topi/python/topi/nn/pooling.py @@ -1,9 +1,10 @@ """TVM operator pooling compute.""" from __future__ import absolute_import import tvm +from .. import util +from .pad import pad, _spatial2d_pad_option -@tvm.tag_scope(tag='max_pool') -def max_pool(data, kernel, stride, pad): +def max_pool(data, kernel, stride, padding): """Perform max pooling on the data Parameters @@ -17,7 +18,7 @@ def max_pool(data, kernel, stride, pad): stride : list/tuple of two ints Stride size, or [stride_height, stride_width] - pad : list/tuple of two ints + paddding : list/tuple of two ints Pad size, or [pad_height, pad_width] Returns @@ -26,29 +27,27 @@ def max_pool(data, kernel, stride, pad): 4-D with shape [batch, channel, out_height, out_width] """ assert len(data.shape) == 4, "only support 4-dim pooling" - assert len(stride.shape) == 2, "only support 2-dim stride" - assert len(pad.shape) == 2, "only support 2-dim pad" + assert len(stride) == 2, "only support 2-dim stride" kernel_height, kernel_width = kernel stride_height, stride_width = stride - pad_height, pad_width = pad batch, channel, height, width = data.shape - padded_height = height + 2*pad_height - padded_width = width + 2*pad_width - out_height = (height + 2*pad_height - kernl_height) / stride_height + 1 - out_width = (width + 2*pad_width - kernel_width) / stride_width + 1 + pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option( + padding, (kernel_height, kernel_width)) + pad_before = [0, 0, pad_top, pad_left] + pad_after = [0, 0, pad_down, pad_right] + temp = pad(data, pad_before, pad_after, name="pad_temp", + pad_value=tvm.min_value("float32")) + out_height = util.simplify((height - kernel_height + pad_top + pad_down) // stride_height + 1) + out_width = util.simplify((width - kernel_width + pad_left + pad_right) // stride_width + 1) dheight = tvm.reduce_axis((0, kernel_height)) dwidth = tvm.reduce_axis((0, kernel_width)) - temp = tvm.compute((batch, channel, padded_height, padded_width), lambda i, c, h, w: \ - tvm.select( - tvm.make.Or(tvm.make.Or((h < pad_height), (h >= height + pad_height)), - tvm.make.Or((w < pad_width), (w >= width + pad_width))), - tvm.min_value('float32'), - data[i, c, h - pad_height, w - pad_width]), name='temp') - - return tvm.compute((batch, channel, out_height, out_width), lambda i, c, h, 
w: \ - tvm.max(temp[i, c, h*stride_height+dheight, w*stride_width+dwidth], axis=[dheight, dwidth])) + return tvm.compute( + (batch, channel, out_height, out_width), + lambda i, c, h, w: + tvm.max(temp[i, c, h*stride_height+dheight, w*stride_width+dwidth], axis=[dheight, dwidth]), + tag="max_pool") @tvm.tag_scope(tag='global_avg_pool') diff --git a/topi/python/topi/nn/softmax.py b/topi/python/topi/nn/softmax.py index 56553441d6bc..8394b1afb0ab 100644 --- a/topi/python/topi/nn/softmax.py +++ b/topi/python/topi/nn/softmax.py @@ -19,9 +19,8 @@ def softmax(x): assert len(x.shape) == 2, "only support 2-dim softmax" m, n = x.shape k = tvm.reduce_axis((0, n), name='k') - max_elem = tvm.compute((m, ), lambda i: \ - tvm.max(x[i, k]), axis=k) - expsum = tvm.compute((m, ), lambda i: \ - tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k)) - return tvm.compute(x.shape, lambda i, j: \ - tvm.exp(x[i, j] - max_elem[i]) / expsum[i]) + max_elem = tvm.compute((m, ), lambda i: tvm.max(x[i, k], axis=k)) + expsum = tvm.compute( + (m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k)) + return tvm.compute( + x.shape, lambda i, j: tvm.exp(x[i, j] - max_elem[i]) / expsum[i]) diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index ee53c1815be7..80312755013c 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -63,3 +63,19 @@ def get_const_tuple(in_tuple): raise ValueError("Element of input tuple should be const int") out_tuple = out_tuple + (elem.value, ) return out_tuple + + +def simplify(expr): + """Simplify the expression if it is Expr, directly return if it is int. + + Parameters + ---------- + expr : Expr or int + The input. + + Returns + ------- + out : Expr or int + The simplified output + """ + return tvm.ir_pass.Simplify(expr) if isinstance(expr, tvm.expr.Expr) else expr diff --git a/topi/recipe/conv/depthwise_conv2d_test.py b/topi/recipe/conv/depthwise_conv2d_test.py index 55d34ab74e4b..223570593d25 100644 --- a/topi/recipe/conv/depthwise_conv2d_test.py +++ b/topi/recipe/conv/depthwise_conv2d_test.py @@ -49,7 +49,7 @@ def test_depthwise_conv2d(): # Placeholder Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input') Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter') - Stride = tvm.nd.array(np.array([stride_h, stride_w])) + Stride = [stride_h, stride_w] Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') # Declare diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py index 962914b17124..bb41066921fc 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d.py +++ b/topi/tests/python/test_topi_depthwise_conv2d.py @@ -13,7 +13,7 @@ def depthwise_conv2d_with_workload(batch, in_channel, in_height, channel_multipl # placeholder Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input') Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter') - Stride = tvm.nd.array(np.array([stride_h, stride_w])) + Stride = [stride_h, stride_w] Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') # declare diff --git a/topi/tests/python/test_topi_dilate.py b/topi/tests/python/test_topi_dilate.py index 0d2014535ca4..778c0ba5e9c4 100644 --- a/topi/tests/python/test_topi_dilate.py +++ 
b/topi/tests/python/test_topi_dilate.py @@ -14,9 +14,7 @@ def _test_dilate(input_size, strides): input_np = np.random.uniform(size=input_size).astype(Input.dtype) output_np = topi.testing.dilate_python(input_np, strides) input_tvm = tvm.nd.array(input_np, ctx=ctx) - output_size = () - for i in range(len(input_size)): - output_size += (tvm.ir_pass.Simplify(Output.shape[i]).value,) + output_size = topi.util.get_const_tuple(Output.shape) output_tvm = tvm.nd.array(np.zeros(shape=output_size).astype(Output.dtype), ctx=ctx) f = tvm.build(schedule, [Input, Output], target) f(input_tvm, output_tvm)
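For reviewers, a minimal usage sketch of the new padding helpers added in topi/python/topi/nn/pad.py. This is illustrative only: it assumes the patch is applied on the TVM Python API of this era, and the shapes and variable names below are not part of the patch.

# Resolve 'SAME'/'VALID'/int padding into per-side amounts, then build a pad stage.
import tvm
from topi.nn.pad import pad, _spatial2d_pad_option

# 'SAME' with a 3x3 kernel pads kernel-1 = 2 in total per spatial axis,
# split into (pad_top, pad_left, pad_down, pad_right) = (1, 1, 1, 1).
print(_spatial2d_pad_option("SAME", (3, 3)))   # expected: (1, 1, 1, 1)
print(_spatial2d_pad_option("VALID", (3, 3)))  # expected: (0, 0, 0, 0)
print(_spatial2d_pad_option(2, (3, 3)))        # expected: (2, 2, 2, 2)

# pad() emits a single compute stage tagged "pad"; indices outside the
# original extent read pad_value (0.0 by default).
data = tvm.placeholder((1, 16, 32, 32), name="data")
padded = pad(data, pad_before=[0, 0, 1, 1], pad_after=[0, 0, 1, 1])
print(padded.shape)  # expected: [1, 16, 34, 34]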
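As the updated tests show, callers of depthwise_conv2d now pass the stride as a plain Python list/tuple instead of a tvm.nd.array. A rough end-to-end sketch under the same assumptions; it further assumes depthwise_conv2d is reachable as topi.nn.depthwise_conv2d and that a default schedule is good enough for an llvm build, with shapes and names chosen only for illustration.

import numpy as np
import tvm
import topi

Input = tvm.placeholder((1, 32, 28, 28), name='Input')
Filter = tvm.placeholder((32, 1, 3, 3), name='Filter')
# stride is now a plain list; padding is resolved through _spatial2d_pad_option
Output = topi.nn.depthwise_conv2d(Input, Filter, [1, 1], 'SAME')

s = tvm.create_schedule(Output.op)
f = tvm.build(s, [Input, Filter, Output], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(1, 32, 28, 28)).astype(Input.dtype), ctx)
w = tvm.nd.array(np.random.uniform(size=(32, 1, 3, 3)).astype(Filter.dtype), ctx)
# get_const_tuple now replaces the manual Simplify loop, as in the dilate test above
out_shape = topi.util.get_const_tuple(Output.shape)
b = tvm.nd.array(np.zeros(out_shape, dtype=Output.dtype), ctx)
f(a, w, b)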