diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index 6456441d6038..65753bc088f3 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -2,5 +2,6 @@
 """CUDA specific declaration and schedules."""
 from __future__ import absolute_import as _abs
 
-from .conv2d_hwcn_map import schedule_conv2d_hwcn_map
+from .conv2d_nchw import schedule_conv2d_nchw
+from .conv2d_hwcn import schedule_conv2d_hwcn
 from .depthwise_conv2d_map import schedule_depthwise_conv2d_map
diff --git a/topi/python/topi/cuda/conv2d_hwcn.py b/topi/python/topi/cuda/conv2d_hwcn.py
new file mode 100644
index 000000000000..210660a230f5
--- /dev/null
+++ b/topi/python/topi/cuda/conv2d_hwcn.py
@@ -0,0 +1,119 @@
+# pylint: disable=invalid-name, too-many-locals, too-many-statements
+"""Schedule for conv2d_hwcn with auto fusion"""
+import tvm
+
+
+def schedule_conv2d_hwcn(outs):
+    """Schedule for conv2d_hwcn and any element-wise operations.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of conv2d_hwcn in the format
+        of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv2d_hwcn.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    sch = tvm.create_schedule([x.op for x in outs])
+    def schedule(Apad, W, B):
+        """Schedule conv2d_hwcn"""
+        sch[Apad].compute_inline()
+        AA = sch.cache_read(Apad, "shared", [B])
+        WW = sch.cache_read(W, "shared", [B])
+        AL = sch.cache_read(AA, "local", [B])
+        WL = sch.cache_read(WW, "local", [B])
+
+        if B.op in sch.outputs:
+            Out = B
+            BL = sch.cache_write(Out, "local")
+        else:
+            Out = sch.outputs[0].output(0)
+            sch[B].set_scope("local")
+            BL = B
+
+        tile = 8
+        num_thread = 8
+        block_factor = tile * num_thread
+        step = 8
+        vthread = 2
+
+        block_x = tvm.thread_axis("blockIdx.x")
+        block_y = tvm.thread_axis("blockIdx.y")
+        block_z = tvm.thread_axis("blockIdx.z")
+        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
+        thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
+        thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
+
+        hi, wi, fi, ni = sch[Out].op.axis
+        bz = sch[Out].fuse(hi, wi)
+        by, fi = sch[Out].split(fi, factor=block_factor)
+        bx, ni = sch[Out].split(ni, factor=block_factor)
+        tyz, fi = sch[Out].split(fi, nparts=vthread)
+        txz, ni = sch[Out].split(ni, nparts=vthread)
+        ty, fi = sch[Out].split(fi, nparts=num_thread)
+        tx, ni = sch[Out].split(ni, nparts=num_thread)
+        sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
+        sch[Out].bind(bz, block_z)
+        sch[Out].bind(by, block_y)
+        sch[Out].bind(bx, block_x)
+        sch[Out].bind(tyz, thread_yz)
+        sch[Out].bind(txz, thread_xz)
+        sch[Out].bind(ty, thread_y)
+        sch[Out].bind(tx, thread_x)
+
+        # Schedule BL local write
+        sch[BL].compute_at(sch[Out], tx)
+        yi, xi, fi, ni = sch[BL].op.axis
+        ry, rx, rc = sch[BL].op.reduce_axis
+        rco, rci = sch[BL].split(rc, factor=step)
+        sch[BL].reorder(rco, ry, rx, rci, fi, ni)
+        fuse_index = sch[BL].fuse(ry, rx)
+        fuse_index = sch[BL].fuse(fuse_index, rco)
+        rx = fuse_index
+
+        sch[AA].compute_at(sch[BL], rx)
+        sch[WW].compute_at(sch[BL], rx)
+        sch[AL].compute_at(sch[BL], rci)
+        sch[WL].compute_at(sch[BL], rci)
+        # Schedule for A's shared memory load
+        yi, xi, ci, ni = sch[AA].op.axis
+        ty, ci = sch[AA].split(ci, nparts=num_thread)
+        tx, ni = sch[AA].split(ni, nparts=num_thread)
+        _, ni = sch[AA].split(ni, factor=4)
+        sch[AA].reorder(ty, tx, yi, xi, ci, ni)
+        sch[AA].bind(ty, thread_y)
+        sch[AA].bind(tx, thread_x)
+        sch[AA].vectorize(ni)
+        # Schedule for W's shared memory load
+        yi, xi, ci, fi = sch[WW].op.axis
+        ty, ci = sch[WW].split(ci, nparts=num_thread)
+        tx, fi = sch[WW].split(fi, nparts=num_thread)
+        _, fi = sch[WW].split(fi, factor=4)
+        sch[WW].reorder(ty, tx, yi, xi, ci, fi)
+        sch[WW].bind(ty, thread_y)
+        sch[WW].bind(tx, thread_x)
+        sch[WW].vectorize(fi)
+
+    def traverse(operator):
+        """Traverse operators from computation graph"""
+        if operator.tag == 'ewise' or operator.tag == 'scale_shift':
+            if operator not in sch.outputs:
+                sch[operator].compute_inline()
+            for tensor in operator.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        elif operator.tag == 'conv2d_hwcn':
+            Apad = operator.input_tensors[0]
+            W = operator.input_tensors[1]
+            B = operator.output(0)
+            schedule(Apad, W, B)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % operator.tag)
+
+    traverse(outs[0].op)
+    return sch
diff --git a/topi/python/topi/cuda/conv2d_hwcn_map.py b/topi/python/topi/cuda/conv2d_hwcn_map.py
deleted file mode 100644
index 7b932523b720..000000000000
--- a/topi/python/topi/cuda/conv2d_hwcn_map.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# pylint: disable=invalid-name
-"""Schedule for conv2d_hwcn with auto fusion"""
-import tvm
-
-
-def _schedule_conv2d_hwcn(op, sch):
-    assert len(op.input_tensors) == 2
-    Apad = op.input_tensors[0]
-    W = op.input_tensors[1]
-    B = op.output(0)
-
-    sch[Apad].compute_inline()
-    AA = sch.cache_read(Apad, "shared", [B])
-    WW = sch.cache_read(W, "shared", [B])
-    AL = sch.cache_read(AA, "local", [B])
-    WL = sch.cache_read(WW, "local", [B])
-
-    if op in sch.outputs:
-        Out = op.output(0)
-        BL = sch.cache_write(Out, "local")
-    else:
-        Out = sch.outputs[0].output(0)
-        sch[B].set_scope("local")
-        BL = B
-
-    tile = 8
-    num_thread = 8
-    block_factor = tile * num_thread
-    step = 8
-    vthread = 2
-
-    block_x = tvm.thread_axis("blockIdx.x")
-    block_y = tvm.thread_axis("blockIdx.y")
-    block_z = tvm.thread_axis("blockIdx.z")
-    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-    thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
-    thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
-    thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
-
-    hi, wi, fi, ni = sch[Out].op.axis
-    bz = sch[Out].fuse(hi, wi)
-    by, fi = sch[Out].split(fi, factor=block_factor)
-    bx, ni = sch[Out].split(ni, factor=block_factor)
-    tyz, fi = sch[Out].split(fi, nparts=vthread)
-    txz, ni = sch[Out].split(ni, nparts=vthread)
-    ty, fi = sch[Out].split(fi, nparts=num_thread)
-    tx, ni = sch[Out].split(ni, nparts=num_thread)
-    sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
-    sch[Out].bind(bz, block_z)
-    sch[Out].bind(by, block_y)
-    sch[Out].bind(bx, block_x)
-    sch[Out].bind(tyz, thread_yz)
-    sch[Out].bind(txz, thread_xz)
-    sch[Out].bind(ty, thread_y)
-    sch[Out].bind(tx, thread_x)
-
-    # Schedule BL local write
-    sch[BL].compute_at(sch[Out], tx)
-    yi, xi, fi, ni = sch[BL].op.axis
-    ry, rx, rc = sch[BL].op.reduce_axis
-    rco, rci = sch[BL].split(rc, factor=step)
-    sch[BL].reorder(rco, ry, rx, rci, fi, ni)
-    fuse_index = sch[BL].fuse(ry, rx)
-    fuse_index = sch[BL].fuse(fuse_index, rco)
-    rx = fuse_index
-
-    sch[AA].compute_at(sch[BL], rx)
-    sch[WW].compute_at(sch[BL], rx)
-    sch[AL].compute_at(sch[BL], rci)
-    sch[WL].compute_at(sch[BL], rci)
-    # Schedule for A's shared memory load
-    yi, xi, ci, ni = sch[AA].op.axis
-    ty, ci = sch[AA].split(ci, nparts=num_thread)
-    tx, ni = sch[AA].split(ni, nparts=num_thread)
-    _, ni = sch[AA].split(ni, factor=4)
-    sch[AA].reorder(ty, tx, yi, xi, ci, ni)
-    sch[AA].bind(ty, thread_y)
-    sch[AA].bind(tx, thread_x)
-    sch[AA].vectorize(ni)
-    # Schedule for W's shared memory load
-    yi, xi, ci, fi = sch[WW].op.axis
-    ty, ci = sch[WW].split(ci, nparts=num_thread)
-    tx, fi = sch[WW].split(fi, nparts=num_thread)
-    _, fi = sch[WW].split(fi, factor=4)
-    sch[WW].reorder(ty, tx, yi, xi, ci, fi)
-    sch[WW].bind(ty, thread_y)
-    sch[WW].bind(tx, thread_x)
-    sch[WW].vectorize(fi)
-
-    return sch
-
-
-def schedule_conv2d_hwcn_map(op):
-    """Schedule for conv2d_hwcn map ops.
-
-    Parameters
-    ----------
-    op: tvm.tensor.Operation
-        The symbolic description of the operation, should be conv2d_hwcn or
-        conv2d_hwcn followed by a sequence of one-to-one-mapping operators.
-
-    Returns
-    -------
-    sch: Schedule
-        The computation schedule for the op.
-    """
-    def traverse(operator):
-        if operator.tag == 'ewise' or operator.tag == 'scale_shift':
-            if operator not in sch.outputs:
-                sch[operator].compute_inline()
-            for tensor in operator.input_tensors:
-                if tensor.op.input_tensors:
-                    traverse(tensor.op)
-        elif operator.tag == 'conv2d_hwcn':
-            _schedule_conv2d_hwcn(operator, sch)
-        else:
-            raise RuntimeError("Unsupported operator: %s" % operator.tag)
-
-    sch = tvm.create_schedule(op)
-    traverse(op)
-    return sch
diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
new file mode 100644
index 000000000000..83abfde9107e
--- /dev/null
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -0,0 +1,137 @@
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements
+"""Schedule for conv2d_nchw with auto fusion"""
+import tvm
+from .. import util
+
+
+def schedule_conv2d_small_batch(outs):
+    """Create schedule for tensors or raise an error if batch size is larger than 1"""
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def schedule(temp, Filter, Output):
+        """Schedule conv2d_nchw"""
+        block_h = util.get_const_int(Output.shape[3])
+        block_w = util.get_const_int(temp.shape[1])
+        if block_h % 48 == 0:
+            block_h = 48
+        elif block_h % 32 == 0:
+            block_h = 32
+        if block_w % 48 == 0:
+            block_w = 48
+        elif block_w % 32 == 0:
+            block_w = 32
+
+        s[temp].compute_inline()
+
+        temp_S = s.cache_read(temp, "shared", [Output])
+        Filter_S = s.cache_read(Filter, "shared", [Output])
+
+        if Output.op in s.outputs:
+            Out = Output
+            Out_L = s.cache_write(Out, "local")
+        else:
+            Out = outs[0].op.output(0)
+            s[Output].set_scope("local")
+            Out_L = Output
+
+        # scheduler params
+        num_thread = 8
+        vthread = 2
+        out_filter = min(64, util.get_const_int(Filter.shape[0]))
+        in_filter = util.get_const_int(Filter.shape[1])
+        opart2 = out_filter//8
+        ofactor = out_filter
+        wfactor = block_h
+        ifactor = in_filter//4
+        sfactor = max(1, ofactor//(opart2*2))
+        spart = (wfactor + vthread-1) // vthread
+        block_x = tvm.thread_axis("blockIdx.x")
+        block_y = tvm.thread_axis("blockIdx.y")
+        block_z = tvm.thread_axis("blockIdx.z")
+        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
+        thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
+        thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
+
+        i, oc, h, w = s[Out].op.axis
+        ooc, ioc = s[Out].split(oc, factor=ofactor)
+        ow, iw = s[Out].split(w, factor=wfactor)
+        ow = s[Out].fuse(ow, h)
+        oioc, iioc = s[Out].split(ioc, nparts=vthread)
+        oiw, iiw = s[Out].split(iw, nparts=vthread)
+        oiioc, iiioc = s[Out].split(iioc, nparts=opart2)
+        s[Out].reorder(i, ooc, ow, oioc, oiw, oiioc, iiw, iiioc)
+        s[Out].bind(iiioc, thread_x)
+        s[Out].bind(iiw, thread_y)
+        s[Out].bind(oiioc, thread_xz)
+        s[Out].bind(oiw, thread_yz)
+        s[Out].bind(oioc, block_x)
+        s[Out].bind(ow, block_y)
+        s[Out].bind(ooc, block_z)
+
+        s[Out_L].compute_at(s[Out], iiioc)
+
+        # schedule Out_L local write
+        i, oc, h, w = s[Out_L].op.axis
+        ic, dh, dw = s[Out_L].op.reduce_axis
+        oic, iic = s[Out_L].split(ic, factor=ifactor)
+        s[Out_L].reorder(oic, dh, dw, iic, h, w)
+        fuse_index = s[Out_L].fuse(dw, dh)
+        fuse_index = s[Out_L].fuse(fuse_index, oic)
+        dw = fuse_index
+
+        s[temp_S].compute_at(s[Out_L], dw)
+        s[Filter_S].compute_at(s[Out_L], dw)
+
+        # schedule temp_S shared mem load
+        i, ic, h, w = s[temp_S].op.axis
+        _, iic = s[temp_S].split(ic, factor=sfactor)
+        _, iw = s[temp_S].split(w, factor=spart)
+        s[temp_S].bind(iic, thread_x)
+        s[temp_S].bind(iw, thread_y)
+
+        # schedule Filter_S shared mem load
+        i, oc, h, w = s[Filter_S].op.axis
+        _, ioc = s[Filter_S].split(oc, factor=sfactor)
+        _, ii = s[Filter_S].split(i, factor=spart)
+        s[Filter_S].bind(ioc, thread_x)
+        s[Filter_S].bind(ii, thread_y)
+
+    def traverse(OP):
+        """Traverse operators from computation graph"""
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if 'ewise' in OP.tag or 'bcast' in OP.tag:
+            if OP not in s.outputs:
+                s[OP].compute_inline()
+            for tensor in OP.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        # schedule conv2d
+        if 'conv2d_nchw' in OP.tag:
+            temp = OP.input_tensors[0]
+            Filter = OP.input_tensors[1]
+            Output = OP.output(0)
+            schedule(temp, Filter, Output)
+
+    traverse(outs[0].op)
+    return s
+
+def schedule_conv2d_nchw(outs):
+    """Schedule for conv2d_nchw and any element-wise operations.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of conv2d_nchw
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv2d_nchw.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    batch_size = util.get_const_int(outs[0].op.output(0).shape[0])
+    if batch_size > 1:
+        raise RuntimeError("Batch size: %d is too large for this schedule" % batch_size)
+    return schedule_conv2d_small_batch(outs)
diff --git a/topi/python/topi/nn/conv.py b/topi/python/topi/nn/conv.py
index 768387e54d0d..504af1aa24b5 100644
--- a/topi/python/topi/nn/conv.py
+++ b/topi/python/topi/nn/conv.py
@@ -1,4 +1,4 @@
-# pylint: disable=invalid-name, line-too-long, unused-variable
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
 """Convolution operators"""
 from __future__ import absolute_import as _abs
 import tvm
@@ -6,6 +6,69 @@
 from ..util import get_const_tuple
 
 
+@tvm.tag_scope(tag="conv2d_nchw")
+def conv2d_nchw(Input, Filter, stride, padding):
+    """Convolution operator in NCHW layout.
+
+    Parameters
+    ----------
+    Input : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    Filter : tvm.Tensor
+        4-D with shape [num_filter, in_channel, filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(padding, int) or padding in ['VALID', 'SAME']
+    batch, in_channel, in_height, in_width = get_const_tuple(Input.shape)
+    num_filter, channel, kernel_h, kernel_w = get_const_tuple(Filter.shape)
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+    # compute the padding size
+    if isinstance(padding, int):
+        pad_h = pad_w = padding * 2
+    elif padding == 'VALID':
+        pad_h = 0
+        pad_w = 0
+    else: # 'SAME'
+        pad_h = kernel_h - 1
+        pad_w = kernel_w - 1
+    pad_top = int(np.ceil(float(pad_h) / 2))
+    pad_left = int(np.ceil(float(pad_w) / 2))
+    # compute the output shape
+    out_channel = num_filter
+    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
+    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
+    # compute graph
+    temp = tvm.compute(
+        (batch, in_channel, in_height + pad_h, in_width + pad_w),
+        lambda nn, cc, yy, xx: tvm.select(
+            tvm.all(yy >= pad_top, yy - pad_top < in_height,
+                    xx >= pad_left, xx - pad_left < in_width),
+            Input[nn, cc, yy - pad_top, xx - pad_left], tvm.const(0.)),
+        name='temp')
+    rc = tvm.reduce_axis((0, in_channel), name='rc')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+    return tvm.compute(
+        (batch, out_channel, out_height, out_width),
+        lambda nn, ff, yy, xx: tvm.sum(
+            temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx] * Filter[ff, rc, ry, rx],
+            axis=[rc, ry, rx]))
+
 @tvm.tag_scope(tag="conv2d_hwcn")
 def conv2d_hwcn(Input, Filter, stride, padding):
     """Convolution operator in HWCN layout.
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 28100658fc05..63bc8eb7215a 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -5,3 +5,4 @@
 from __future__ import absolute_import as _abs
 
 from .conv2d_hwcn_python import conv2d_hwcn_python
+from .conv2d_nchw_python import conv2d_nchw_python
diff --git a/topi/python/topi/testing/conv2d_hwcn_python.py b/topi/python/topi/testing/conv2d_hwcn_python.py
index e240cfb722ae..c84efce5e777 100644
--- a/topi/python/topi/testing/conv2d_hwcn_python.py
+++ b/topi/python/topi/testing/conv2d_hwcn_python.py
@@ -1,4 +1,4 @@
-# pylint: disable=invalid-name, line-too-long, unused-variable
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
 """Convolution in python"""
 import numpy as np
 import scipy.signal
diff --git a/topi/python/topi/testing/conv2d_nchw_python.py b/topi/python/topi/testing/conv2d_nchw_python.py
new file mode 100644
index 000000000000..169605faaf45
--- /dev/null
+++ b/topi/python/topi/testing/conv2d_nchw_python.py
@@ -0,0 +1,64 @@
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Convolution in python"""
+import numpy as np
+import scipy.signal
+
+
+def conv2d_nchw_python(a_np, w_np, stride, padding):
+    """Convolution operator in NCHW layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    w_np : numpy.ndarray
+        4-D with shape [num_filter, in_channel, filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    Returns
+    -------
+    b_np : numpy.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    batch, in_channel, in_height, in_width = a_np.shape
+    num_filter, _, kernel_h, kernel_w = w_np.shape
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+    if isinstance(padding, int):
+        pad_h = pad_w = padding * 2
+    elif padding == 'VALID':
+        pad_h = 0
+        pad_w = 0
+    else: # 'SAME'
+        pad_h = kernel_h - 1
+        pad_w = kernel_w - 1
+    pad_top = int(np.ceil(float(pad_h) / 2))
+    pad_bottom = pad_h - pad_top
+    pad_left = int(np.ceil(float(pad_w) / 2))
+    pad_right = pad_w - pad_left
+    # compute the output shape
+    out_channel = num_filter
+    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
+    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
+    b_np = np.zeros((batch, out_channel, out_height, out_width))
+    # computation
+    for n in range(batch):
+        for f in range(out_channel):
+            for c in range(in_channel):
+                if pad_h > 0:
+                    apad = np.zeros((in_height + pad_h, in_width + pad_w))
+                    apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = a_np[n, c]
+                else:
+                    apad = a_np[n, c]
+                out = scipy.signal.convolve2d(
+                    apad, np.rot90(np.rot90(w_np[f, c])), mode='valid')
+                b_np[n, f] += out[::stride_h, ::stride_w]
+    return b_np
diff --git a/topi/recipe/conv/test_conv2d_hwcn_map.py b/topi/recipe/conv/test_conv2d_hwcn_map.py
index 553c93aa74a2..a6b9017a74eb 100644
--- a/topi/recipe/conv/test_conv2d_hwcn_map.py
+++ b/topi/recipe/conv/test_conv2d_hwcn_map.py
@@ -12,7 +12,7 @@
 
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
+    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_37"])
     return ptx
 
 def write_code(code, fname):
@@ -43,8 +43,8 @@ def test_conv2d_hwcn_map():
     W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
     B = topi.nn.conv2d_hwcn(A, W, stride, padding)
     C = topi.nn.relu(B)
-    s1 = topi.cuda.schedule_conv2d_hwcn_map(B.op)
-    s2 = topi.cuda.schedule_conv2d_hwcn_map(C.op)
+    s1 = topi.cuda.schedule_conv2d_hwcn([B])
+    s2 = topi.cuda.schedule_conv2d_hwcn([C])
 
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
diff --git a/topi/tests/python/test_topi_conv2d_hwcn_map.py b/topi/tests/python/test_topi_conv2d_hwcn_map.py
index 820d859847a8..f2a8ea14db71 100644
--- a/topi/tests/python/test_topi_conv2d_hwcn_map.py
+++ b/topi/tests/python/test_topi_conv2d_hwcn_map.py
@@ -13,8 +13,8 @@ def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, strid
     W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
     B = topi.nn.conv2d_hwcn(A, W, stride, padding)
     C = topi.nn.relu(B)
-    s1 = topi.cuda.schedule_conv2d_hwcn_map(B.op)
-    s2 = topi.cuda.schedule_conv2d_hwcn_map(C.op)
+    s1 = topi.cuda.schedule_conv2d_hwcn([B])
+    s2 = topi.cuda.schedule_conv2d_hwcn([C])
 
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
diff --git a/topi/tests/python/test_topi_conv2d_nchw.py b/topi/tests/python/test_topi_conv2d_nchw.py
new file mode 100644
index 000000000000..a40f10ce3d7e
--- /dev/null
+++ b/topi/tests/python/test_topi_conv2d_nchw.py
@@ -0,0 +1,61 @@
+"""Example code to do convolution."""
+import os
+import numpy as np
+import tvm
+import topi
+from topi.util import get_const_tuple
+
+
+def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+    in_height = in_width = in_size
+
+    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
+    W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
+    B = topi.nn.conv2d_nchw(A, W, stride, padding)
+    C = topi.nn.relu(B)
+    s1 = topi.cuda.schedule_conv2d_nchw([B])
+    s2 = topi.cuda.schedule_conv2d_nchw([C])
+
+    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+    w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
+    b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
+    c_np = np.maximum(b_np, 0)
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        with tvm.build_config(auto_unroll_max_step=32,
+                              auto_unroll_min_depth=0,
+                              unroll_explicit=False):
+            func1 = tvm.build(s1, [A, W, B], device)
+            func2 = tvm.build(s2, [A, W, C], device)
+            func1(a, w, b)
+            func2(a, w, c)
+            np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+            np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+    for device in ['cuda', 'opencl', 'metal']:
+        check_device(device)
+
+
+def test_conv2d_nchw():
+    verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
+    verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
+    verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1)
+    verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0)
+    verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
+    verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1)
+    verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0)
+    verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
+    verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
+    verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
+    verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
+
+if __name__ == "__main__":
+    test_conv2d_nchw()
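
# A minimal sketch, in plain Python, of the 'SAME'/'VALID' padding and output-shape
# arithmetic that conv2d_nchw and conv2d_nchw_python in the patch both implement.
# The helper name expected_output_hw is hypothetical and exists only for this
# illustration; the sample shapes match configurations exercised by test_conv2d_nchw.

def expected_output_hw(in_size, kernel, stride, padding):
    """Output height/width for a square input and square kernel, mirroring the patch."""
    if isinstance(padding, int):
        pad = padding * 2          # the patch uses pad_h = pad_w = padding * 2
    elif padding == 'VALID':
        pad = 0
    else:                          # 'SAME'
        pad = kernel - 1
    return (in_size - kernel + pad) // stride + 1

# e.g. batch=1, in_channel=64, in_size=56, num_filter=128, kernel=3, stride=2, padding=1
assert expected_output_hw(56, 3, 2, 1) == 28        # stride-2 3x3 conv: 56 -> 28
assert expected_output_hw(56, 1, 1, 0) == 56        # 1x1 conv, stride 1, no padding: 56 -> 56
assert expected_output_hw(7, 3, 1, 1) == 7          # 3x3 conv, stride 1, padding 1 keeps 7
assert expected_output_hw(7, 3, 1, 'SAME') == expected_output_hw(7, 3, 1, 1)  # 'SAME' == pad 1 for 3x3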