From 656e5aa72161dfcd3b19c62c78bd4bf098ab4524 Mon Sep 17 00:00:00 2001
From: Leyuan Wang
Date: Fri, 6 Oct 2017 09:31:13 +0000
Subject: [PATCH 1/2] conv2d tweaked for better end-to-end performance

---
 topi/python/topi/cuda/conv2d_nchw.py | 50 +++++++++++++++++-----------
 1 file changed, 30 insertions(+), 20 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
index 4987f8d6fef2..e62f1e27d4cb 100644
--- a/topi/python/topi/cuda/conv2d_nchw.py
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -1,4 +1,4 @@
-#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches
+#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-argument, too-many-branches
 """Schedule for conv2d_nchw with auto fusion"""
 import tvm
 from .. import util
@@ -10,7 +10,7 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
     ofactor = 16
     hfactor = 2
     ow_size = util.get_const_int(Out.shape[3])
-    num_thread = ow_size * hfactor
+    num_thread = ow_size*hfactor
     vthread = ofactor
     block_x = tvm.thread_axis("blockIdx.x")
     thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
@@ -22,6 +22,7 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
     s[Out].reorder(ooc, oh, ioc, ih, w)
     oc = s[Out].fuse(ooc, oh)
     w = s[Out].fuse(w, ih)
+
     s[Out].bind(w, thread_x)
     s[Out].bind(ioc, thread_xz)
     s[Out].bind(oc, block_x)
@@ -66,9 +67,20 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
 def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
     """Schedule conv2d for specific feature_in_out_filter pattern"""
     if util.get_const_int(Filter_S.shape[0]) == util.get_const_int(Filter_S.shape[1]):
-        num_thread_x = 8
+        mark = util.get_const_int(Out.shape[2]) * util.get_const_int(Out.shape[3])
+        num_thread_x = 0
+        if mark % 8 == 0 and mark % 7 == 0:
+            num_thread_x = 8
+            vthread_x = 7
+        else:
+            for i in range(5, mark):
+                if mark % i == 0 and num_thread_x == 0:
+                    vthread_x = i
+                    mark = mark // i
+                if mark % i == 0 and vthread_x > 0:
+                    num_thread_x = i
+                    break
         num_thread_y = 8
-        vthread_x = 7
         vthread_y = 2
 
         ifactor = 8
@@ -80,20 +92,20 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
         thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy")
 
         i, oc, h, w = s[Out].op.axis
-        oh, ih = s[Out].split(h, nparts=vthread_x)
-        w = s[Out].fuse(ih, w)
+        w = s[Out].fuse(h, w)
+        ow, iw = s[Out].split(w, factor=num_thread_x*vthread_x)
         ooc, ioc = s[Out].split(oc, factor=num_thread_y*vthread_y)
-        ow, iw = s[Out].split(w, factor=num_thread_x)
+        oiw, iiw = s[Out].split(iw, nparts=vthread_x)
         oioc, iioc = s[Out].split(ioc, nparts=vthread_y)
-        s[Out].reorder(i, ooc, oh, oioc, ow, iioc, iw)
-        s[Out].bind(iw, thread_x)
+        s[Out].reorder(i, ooc, ow, oioc, oiw, iioc, iiw)
+        s[Out].bind(iiw, thread_x)
         s[Out].bind(iioc, thread_y)
-        s[Out].bind(ow, thread_xz)
+        s[Out].bind(oiw, thread_xz)
         s[Out].bind(oioc, thread_yz)
-        s[Out].bind(oh, block_x)
+        s[Out].bind(ow, block_x)
         s[Out].bind(ooc, block_y)
 
-        s[Out_L].compute_at(s[Out], iw)
+        s[Out_L].compute_at(s[Out], iiw)
 
         # schedule Out_L local write
         i, oc, h, w = s[Out_L].op.axis
@@ -260,9 +272,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
 
     else:
         # scheduler params
-        vthread_x = min(8, util.get_const_int(Out.shape[2]))
+        vthread_x = util.get_const_int(Out.shape[2])
         num_thread_x = 16
-        num_thread_y = min(8, util.get_const_int(Out.shape[3]))
+        num_thread_y = util.get_const_int(Out.shape[3])
         ofactor = 8
         block_x = tvm.thread_axis("blockIdx.x")
         thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
@@ -271,12 +283,10 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
 
         i, oc, h, w = s[Out].op.axis
         ooc, ioc = s[Out].split(oc, factor=num_thread_x)
-        oh, ih = s[Out].split(h, factor=vthread_x)
-        ow, iw = s[Out].split(w, factor=num_thread_y)
-        s[Out].reorder(i, ooc, oh, ih, ow, iw, ioc)
+        s[Out].reorder(i, ooc, h, w, ioc)
         s[Out].bind(ioc, thread_x)
-        s[Out].bind(iw, thread_y)
-        s[Out].bind(ih, thread_xz)
+        s[Out].bind(w, thread_y)
+        s[Out].bind(h, thread_xz)
         s[Out].bind(ooc, block_x)
 
         s[Out_L].compute_at(s[Out], ioc)
@@ -390,7 +400,7 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
         s[Filter_S].bind(ii, thread_y)
 
 def schedule_conv2d_small_batch(outs):
-    """Create schedule for tensors or return error if batch size is larger than 1"""
+    """Create schedule for tensors or return error if batch size is larager than 1"""
     s = tvm.create_schedule([x.op for x in outs])
 
     def schedule(temp, Filter, Output):

From 96844ee9516d3d1cb7abaefc8b4632f86b939187 Mon Sep 17 00:00:00 2001
From: Leyuan Wang
Date: Fri, 6 Oct 2017 09:33:35 +0000
Subject: [PATCH 2/2] syntax changed

---
 topi/python/topi/cuda/conv2d_nchw.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
index e62f1e27d4cb..8e0f22781c1d 100644
--- a/topi/python/topi/cuda/conv2d_nchw.py
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -1,4 +1,4 @@
-#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-argument, too-many-branches
+#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches
 """Schedule for conv2d_nchw with auto fusion"""
 import tvm
 from .. import util
@@ -10,7 +10,7 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
     ofactor = 16
     hfactor = 2
     ow_size = util.get_const_int(Out.shape[3])
-    num_thread = ow_size*hfactor
+    num_thread = ow_size * hfactor
     vthread = ofactor
     block_x = tvm.thread_axis("blockIdx.x")
     thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
@@ -22,7 +22,6 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
     s[Out].reorder(ooc, oh, ioc, ih, w)
     oc = s[Out].fuse(ooc, oh)
     w = s[Out].fuse(w, ih)
-
     s[Out].bind(w, thread_x)
     s[Out].bind(ioc, thread_xz)
     s[Out].bind(oc, block_x)
@@ -400,7 +399,7 @@ def conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L):
         s[Filter_S].bind(ii, thread_y)
 
 def schedule_conv2d_small_batch(outs):
-    """Create schedule for tensors or return error if batch size is larager than 1"""
+    """Create schedule for tensors or return error if batch size is larger than 1"""
     s = tvm.create_schedule([x.op for x in outs])
 
     def schedule(temp, Filter, Output):
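Note on the factor search introduced in PATCH 1/2: the old hard-coded split in
conv2d_56_64_128 assumed the 56x56 feature maps (h split into 7 virtual threads,
w tiled by 8 real threads), so the patch instead derives both factors from the
actual output shape at schedule time. Below is a minimal standalone sketch of
that selection logic, for illustration only; the helper name pick_launch_factors,
the example shapes, and the vthread_x = 0 pre-initialization (added so the sketch
never reads an unbound variable on shapes the loop cannot factor) are not part of
the patch.

    def pick_launch_factors(out_h, out_w):
        """Mirror the (num_thread_x, vthread_x) selection from PATCH 1/2."""
        mark = out_h * out_w
        num_thread_x = 0
        vthread_x = 0
        if mark % 8 == 0 and mark % 7 == 0:
            # fast path: the original 8-thread x 7-vthread tiling still fits
            num_thread_x = 8
            vthread_x = 7
        else:
            # fallback: factor h*w into two divisors, each at least 5
            for i in range(5, mark):
                if mark % i == 0 and num_thread_x == 0:
                    vthread_x = i       # first divisor -> virtual threads
                    mark = mark // i
                if mark % i == 0 and vthread_x > 0:
                    num_thread_x = i    # second divisor -> real threads
                    break
        return num_thread_x, vthread_x

    print(pick_launch_factors(56, 56))  # (8, 7): 3136 divides by both 8 and 7
    print(pick_launch_factors(30, 30))  # (5, 5): fallback factors 900 = 5 * 180

The two factors then tile the fused output axis in the schedule itself:
w = s[Out].fuse(h, w) followed by s[Out].split(w, factor=num_thread_x*vthread_x),
so the thread launch shape matches the feature-map size instead of assuming it.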