From 53531ecdc4363ead3ca34010a96e0bef9d8db668 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Fri, 22 Sep 2017 00:26:04 +0000
Subject: [PATCH 1/4] conv2d layout change and packing added for the last workload

---
 topi/python/topi/cuda/conv2d_nchw.py | 45 ++++++++++++++++++++++------
 1 file changed, 36 insertions(+), 9 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
index 86b0cb6236e4..729e753d365e 100644
--- a/topi/python/topi/cuda/conv2d_nchw.py
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -187,7 +187,7 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
         s[Filter_S].bind(ioc, thread_x)
         s[Filter_S].bind(ii, thread_y)
 
-def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L):
+def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
     """Schedule conv2d for specific feature_in_out_filter pattern"""
     if util.get_const_int(Filter.shape[1]) == 256:
         # scheduler params
@@ -262,13 +262,30 @@ def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L):
         s[temp_S].compute_at(s[Out_L], oic)
         s[Filter_S].compute_at(s[Out_L], oic)
 
+        rfactor = util.get_const_int(Filter.shape[1])
+        thread_xx = tvm.thread_axis((0, rfactor), "threadIdx.x")
+        block_xx = tvm.thread_axis("blockIdx.x")
+
+        i, ic, h, w = s[temp].op.axis
+        ic = s[temp].fuse(ic, h, w)
+        oic, iic = s[temp].split(ic, factor = rfactor)
+        s[temp].bind(iic, thread_xx)
+        s[temp].bind(oic, block_xx)
+
+        i, h, w, oic, iic = s[temp_R].op.axis
+        ic = s[temp_R].fuse(oic, iic)
+        s[temp_R].bind(ic, thread_xx)
+        h = s[temp_R].fuse(h, w)
+        s[temp_R].bind(h, block_xx)
+
         #schedule temp_S shared mem load
-        i, ic, h, w = s[temp_S].op.axis
-        ic = s[temp_S].fuse(w, h, ic)
-        oic, iic = s[temp_S].split(ic, factor=num_thread_x)
+        i, h, w, oc, ic = s[temp_S].op.axis
+        icc = s[temp_S].fuse(oc, w, h)
+        oic, iic = s[temp_S].split(icc, factor=num_thread_x)
         _, ioic = s[temp_S].split(oic, factor=num_thread_y)
         s[temp_S].bind(iic, thread_x)
         s[temp_S].bind(ioic, thread_y)
+        s[temp_S].vectorize(ic)
 
         #schedule Filter_S shared mem load
         i, oc, h, w = s[Filter_S].op.axis
@@ -363,9 +380,21 @@ def schedule(temp, Filter, Output):
         elif block_w % 32 == 0:
             block_w = 32
 
-    s[temp].compute_inline()
+    flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1])
+
+    if flag > 768:
+        temp_G = s.cache_read(temp, "global", [Output])
+        s[temp_G].compute_inline()
+        i, ic, h, w = s[temp_G].op.axis
+        oic, iic = s[temp_G].split(ic, factor=4)
+        s[temp_G].reorder(i, h, w, oic, iic)
+        temp_R = s.cache_write(temp_G, "global")
+        temp_S = s.cache_read(temp_R, "shared", [temp_G])
+    else:
+        s[temp].compute_inline()
+        temp_S = s.cache_read(temp, "shared", [Output])
+        temp_R = temp_S
 
-    temp_S = s.cache_read(temp, "shared", [Output])
     Filter_S = s.cache_read(Filter, "shared", [Output])
 
     if Output.op in s.outputs:
@@ -376,14 +405,12 @@ def schedule(temp, Filter, Output):
         s[Output].set_scope("local")
         Out_L = Output
 
-    flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1])
-
     if util.get_const_int(Filter.shape[3]) == 7:
         conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L)
     elif 128 < flag < 512:
         conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag)
     elif flag >= 512:
-        conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L)
+        conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L)
     else:
         conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)
 
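Note (commentary between patches, not applied by git am): the new flag > 768
branch in schedule() introduces the packing idiom the rest of this series
reuses. temp is re-read into global scope and inlined, its channel axis is
split by 4 and moved innermost, and a cache_write/cache_read pair materializes
the repacked layout in global memory and stages it into shared memory, where
conv2d_14_256_256() vectorizes the length-4 axis. A minimal sketch of the same
stage chain on stand-in tensors; the shapes and the data/out stages below are
assumptions for illustration, not code from conv2d_nchw.py, and the sketch
only constructs the schedule:

    import tvm

    data = tvm.placeholder((1, 256, 14, 14), name="data")
    temp = tvm.compute(data.shape, lambda i, c, y, x: data[i, c, y, x], name="temp")
    out = tvm.compute(temp.shape, lambda i, c, y, x: temp[i, c, y, x] * 2.0, name="out")

    s = tvm.create_schedule(out.op)

    temp_G = s.cache_read(temp, "global", [out])       # global-scope copy of temp
    s[temp_G].compute_inline()                         # fold the copy into its consumer
    i, c, h, w = s[temp_G].op.axis
    oc, ic = s[temp_G].split(c, factor=4)              # pack channels in groups of 4
    s[temp_G].reorder(i, h, w, oc, ic)                 # NCHW -> N, H, W, C//4, 4 order
    temp_R = s.cache_write(temp_G, "global")           # materialize the packed layout
    temp_S = s.cache_read(temp_R, "shared", [temp_G])  # stage it into shared memory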
From 98c64107b4c3b0ca751716fc6160984c9ca46398 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Fri, 22 Sep 2017 22:41:58 +0000
Subject: [PATCH 2/4] packing added for other workloads

---
 topi/python/topi/cuda/conv2d_nchw.py | 120 ++++++++++++++++-----------
 1 file changed, 70 insertions(+), 50 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
index 729e753d365e..6dbd53ea31d8 100644
--- a/topi/python/topi/cuda/conv2d_nchw.py
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -48,7 +48,7 @@ def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
     tx, _ = s[Filter_S].split(w, nparts=num_thread)
     s[Filter_S].bind(tx, thread_x)
 
-def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
+def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
     """Schedule conv2d for specific feature_in_out_filter pattern"""
     if util.get_const_int(Filter_S.shape[0]) == util.get_const_int(Filter_S.shape[1]):
         num_thread_x = 8
@@ -89,13 +89,28 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
         s[temp_S].compute_at(s[Out_L], oic)
         s[Filter_S].compute_at(s[Out_L], dw)
 
+        num_thread = 512
+        thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
+        block_xx = tvm.thread_axis("blockIdx.x")
+
+        i = s[temp].fuse(*s[temp].op.axis)
+        bx, tx = s[temp].split(i, factor=num_thread)
+        s[temp].bind(tx, thread_xx)
+        s[temp].bind(bx, block_xx)
+
+        i = s[temp_R].fuse(*s[temp_R].op.axis)
+        bx, tx = s[temp_R].split(i, factor=num_thread)
+        s[temp_R].bind(tx, thread_xx)
+        s[temp_R].bind(bx, block_xx)
+
         #schedule temp_S shared mem load
-        i, ic, h, w = s[temp_S].op.axis
-        _, iic = s[temp_S].split(ic, factor=num_thread_y)
-        w = s[temp_S].fuse(h, w)
-        _, iw = s[temp_S].split(w, factor=num_thread_x)
-        s[temp_S].bind(iic, thread_y)
-        s[temp_S].bind(iw, thread_x)
+        i, oic, h, w, iic = s[temp_S].op.axis
+        oic = s[temp_S].fuse(oic, h, w)
+        ooic, ioic = s[temp_S].split(oic, factor=num_thread_x)
+        _, iooic = s[temp_S].split(ooic, factor=num_thread_y)
+        s[temp_S].bind(ioic, thread_x)
+        s[temp_S].bind(iooic, thread_y)
+        s[temp_S].vectorize(iic)
 
         i, oc, h, w = s[Filter_S].op.axis
         _, ioc = s[Filter_S].split(oc, factor=num_thread_y)
@@ -104,7 +119,6 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
         s[Filter_S].bind(ii, thread_x)
     else:
         # scheduler params
-        num_thread = 8
         vthread = 2
         opart2 = 4
         ofactor = 64
@@ -112,13 +126,13 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
         ifactor = 8
         if flag > 256:
             wfactor = 14
-        sfactor = max(1, ofactor//(opart2*2))
-        spart = max(1, (wfactor + vthread-1) // vthread)
+        num_thread_x = max(1, ofactor//(opart2*2))
+        num_thread_y = max(1, (wfactor + vthread-1) // vthread)
         block_x = tvm.thread_axis("blockIdx.x")
         block_y = tvm.thread_axis("blockIdx.y")
         block_z = tvm.thread_axis("blockIdx.z")
-        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
+        thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
+        thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
         thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
         thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
 
@@ -140,52 +154,49 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
 
         s[Out_L].compute_at(s[Out], iiioc)
 
+        # schedule Out_L local write
+        i, oc, h, w = s[Out_L].op.axis
+        ic, dh, dw = s[Out_L].op.reduce_axis
+        oic, iic = s[Out_L].split(ic, factor=ifactor)
+        s[Out_L].reorder(oic, dh, dw, iic, h, w)
         if util.get_const_int(Filter_S.shape[1]) == 128:
-            # schedule Out_L local write
-            i, oc, h, w = s[Out_L].op.axis
-            ic, dh, dw = s[Out_L].op.reduce_axis
-            oic, iic = s[Out_L].split(ic, factor=ifactor)
-            s[Out_L].reorder(oic, dh, dw, iic, h, w)
             oic = s[Out_L].fuse(dh, oic)
             s[temp_S].compute_at(s[Out_L], oic)
             s[Filter_S].compute_at(s[Out_L], oic)
-
-            #schedule temp_S shared mem load
-            i, ic, h, w = s[temp_S].op.axis
-            _, iic = s[temp_S].split(ic, factor=sfactor)
-            _, iw = s[temp_S].split(w, factor=spart)
-            s[temp_S].bind(iic, thread_x)
-            s[temp_S].bind(iw, thread_y)
-
-            #schedule Filter_S shared mem load
-            i, oc, h, w = s[Filter_S].op.axis
-            _, ioc = s[Filter_S].split(oc, factor=sfactor)
-            _, ii = s[Filter_S].split(i, factor=spart)
-            s[Filter_S].bind(ioc, thread_x)
-            s[Filter_S].bind(ii, thread_y)
+            num_thread = 512
         else:
-            # schedule Out_L local write
-            i, oc, h, w = s[Out_L].op.axis
-            ic, dh, dw = s[Out_L].op.reduce_axis
-            oic, iic = s[Out_L].split(ic, factor=ifactor)
-            s[Out_L].reorder(oic, dh, dw, iic, h, w)
-            s[temp_S].compute_at(s[Out_L], oic)
             s[Filter_S].compute_at(s[Out_L], dw)
+            num_thread = 456
 
-            #schedule temp_S shared mem load
-            i, ic, h, w = s[temp_S].op.axis
-            _, iic = s[temp_S].split(ic, factor=sfactor)
-            _, iw = s[temp_S].split(w, factor=spart)
-            s[temp_S].bind(iic, thread_x)
-            s[temp_S].bind(iw, thread_y)
+        thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
+        block_xx = tvm.thread_axis("blockIdx.x")
+
+        i = s[temp].fuse(*s[temp].op.axis)
+        bx, tx = s[temp].split(i, factor=num_thread)
+        s[temp].bind(tx, thread_xx)
+        s[temp].bind(bx, block_xx)
 
-            #schedule Filter_S shared mem load
-            i, oc, h, w = s[Filter_S].op.axis
-            _, ioc = s[Filter_S].split(oc, factor=sfactor)
-            _, ii = s[Filter_S].split(i, factor=spart)
-            s[Filter_S].bind(ioc, thread_x)
-            s[Filter_S].bind(ii, thread_y)
+        i = s[temp_R].fuse(*s[temp_R].op.axis)
+        bx, tx = s[temp_R].split(i, factor=num_thread)
+        s[temp_R].bind(tx, thread_xx)
+        s[temp_R].bind(bx, block_xx)
+
+        #schedule temp_S shared mem load
+        i, oic, h, w, iic = s[temp_S].op.axis
+        oic = s[temp_S].fuse(oic, h, w)
+        ooic, ioic = s[temp_S].split(oic, factor=num_thread_x)
+        _, iooic = s[temp_S].split(ooic, factor=num_thread_y)
+        s[temp_S].bind(ioic, thread_x)
+        s[temp_S].bind(iooic, thread_y)
+        s[temp_S].vectorize(iic)
+
+        #schedule Filter_S shared mem load
+        i, oc, h, w = s[Filter_S].op.axis
+        _, ioc = s[Filter_S].split(oc, factor=num_thread_x)
+        _, ii = s[Filter_S].split(i, factor=num_thread_y)
+        s[Filter_S].bind(ioc, thread_x)
+        s[Filter_S].bind(ii, thread_y)
 
 def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
     """Schedule conv2d for specific feature_in_out_filter pattern"""
@@ -390,6 +401,15 @@ def schedule(temp, Filter, Output):
         s[temp_G].reorder(i, h, w, oic, iic)
         temp_R = s.cache_write(temp_G, "global")
         temp_S = s.cache_read(temp_R, "shared", [temp_G])
+    elif 128 < flag < 512:
+        temp_G = s.cache_read(temp, "global", [Output])
+        s[temp_G].compute_inline()
+        i, ic, h, w = s[temp_G].op.axis
+        oic, iic = s[temp_G].split(ic, factor=4)
+        s[temp_G].reorder(i, oic, h, w, iic)
+        temp_R = s.cache_write(temp_G, "global")
+        temp_S = s.cache_read(temp_R, "shared", [temp_G])
+
     else:
         s[temp].compute_inline()
         temp_S = s.cache_read(temp, "shared", [Output])
@@ -408,7 +428,7 @@ def schedule(temp, Filter, Output):
     if util.get_const_int(Filter.shape[3]) == 7:
         conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L)
     elif 128 < flag < 512:
-        conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag)
+        conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
     elif flag >= 512:
         conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L)
     else:
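Note: patch 2 extends the packing chain to the 128 < flag < 512 workloads and
binds the temp and temp_R copy stages with a flat fuse/split/bind pattern, one
element per thread per block (512 or 456 threads, matching the scheduler
params of each branch). A sketch of that copy-stage pattern in isolation; the
placeholder shape and stage names are assumptions for illustration:

    import tvm

    a = tvm.placeholder((1, 64, 56, 56), name="a")
    temp = tvm.compute(a.shape, lambda *idx: a(*idx), name="temp")

    s = tvm.create_schedule(temp.op)

    num_thread = 512
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    block_x = tvm.thread_axis("blockIdx.x")

    # Collapse the whole iteration space into one axis, then stripe it
    # across blocks of num_thread threads: a flat one-element-per-thread copy.
    fused = s[temp].fuse(*s[temp].op.axis)
    bx, tx = s[temp].split(fused, factor=num_thread)
    s[temp].bind(bx, block_x)
    s[temp].bind(tx, thread_x)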
From 79940f39c144e1a2f9a5ed75cc85177464081c90 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Sat, 23 Sep 2017 06:41:31 +0000
Subject: [PATCH 3/4] conv2d added packing for first workload

---
 topi/python/topi/cuda/conv2d_nchw.py | 32 +++++++++++++++++++++++-----
 1 file changed, 27 insertions(+), 5 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
index 6dbd53ea31d8..0024859fd892 100644
--- a/topi/python/topi/cuda/conv2d_nchw.py
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -4,7 +4,7 @@
 from .. import util
 from .. import tag
 
-def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
+def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
     """Schedule conv2d for specific feature_in_out_filter pattern"""
     # scheduler params
     ofactor = 16
@@ -36,10 +36,26 @@ def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
     s[temp_S].compute_at(s[Out_L], ic)
     s[Filter_S].compute_at(s[Out_L], w)
 
+    num_thread1 = 512
+    thread_xx = tvm.thread_axis((0, num_thread1), "threadIdx.x")
+    block_xx = tvm.thread_axis("blockIdx.x")
+
+    i = s[temp].fuse(*s[temp].op.axis)
+    bx, tx = s[temp].split(i, factor=num_thread1)
+    s[temp].bind(tx, thread_xx)
+    s[temp].bind(bx, block_xx)
+
+    i = s[temp_R].fuse(*s[temp_R].op.axis)
+    bx, tx = s[temp_R].split(i, factor=num_thread1)
+    s[temp_R].bind(tx, thread_xx)
+    s[temp_R].bind(bx, block_xx)
+
     #schedule temp_S shared mem load
-    i, ic, h, w = s[temp_S].op.axis
-    tx, ih = s[temp_S].split(w, nparts=num_thread)
+    i, ic, h, ow, iw = s[temp_S].op.axis
+    h = s[temp_S].fuse(h, ow)
+    _, tx = s[temp_S].split(h, factor=num_thread)
     s[temp_S].bind(tx, thread_x)
+    s[temp_S].vectorize(iw)
 
     #schedule Filter_S shared mem load
     i, oc, h, w = s[Filter_S].op.axis
@@ -409,7 +425,13 @@ def schedule(temp, Filter, Output):
         s[temp_G].reorder(i, oic, h, w, iic)
         temp_R = s.cache_write(temp_G, "global")
         temp_S = s.cache_read(temp_R, "shared", [temp_G])
-
+    elif util.get_const_int(Filter.shape[3]) == 7:
+        temp_G = s.cache_read(temp, "global", [Output])
+        s[temp_G].compute_inline()
+        i, ic, h, w = s[temp_G].op.axis
+        ow, iw = s[temp_G].split(w, factor=4)
+        temp_R = s.cache_write(temp_G, "global")
+        temp_S = s.cache_read(temp_R, "shared", [temp_G])
     else:
         s[temp].compute_inline()
         temp_S = s.cache_read(temp, "shared", [Output])
@@ -426,7 +448,7 @@ def schedule(temp, Filter, Output):
         Out_L = Output
 
     if util.get_const_int(Filter.shape[3]) == 7:
-        conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L)
+        conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L)
     elif 128 < flag < 512:
         conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
     elif flag >= 512:
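Note: for the 7x7 first-layer workload, the point of splitting w by 4 is that
temp_S then carries a contiguous innermost axis of length 4, so the shared
memory load in conv2d_224_3_64() can issue vectorized (float4) accesses. A
sketch of such a cooperative shared load with a vectorized pack axis; the 5-D
packed placeholder, its shape, and all names below are assumptions for
illustration:

    import tvm

    num_thread = 8
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    block_x = tvm.thread_axis("blockIdx.x")

    packed = tvm.placeholder((1, 3, 230, 58, 4), name="packed")  # N, C, H, W//4, 4
    out = tvm.compute(packed.shape, lambda *idx: packed(*idx) + 1.0, name="out")

    s = tvm.create_schedule(out.op)
    packed_S = s.cache_read(packed, "shared", [out])

    # Give the consumer a block/thread structure and attach the shared load.
    bx, tx = s[out].split(s[out].fuse(*s[out].op.axis), factor=num_thread)
    s[out].bind(bx, block_x)
    s[out].bind(tx, thread_x)
    s[packed_S].compute_at(s[out], bx)

    # Cooperative load: threads stripe the fused h/ow extent, and the
    # innermost pack axis of length 4 is emitted as one vector access.
    i, c, h, ow, iw = s[packed_S].op.axis
    hw = s[packed_S].fuse(h, ow)
    _, ltx = s[packed_S].split(hw, factor=num_thread)
    s[packed_S].bind(ltx, thread_x)
    s[packed_S].vectorize(iw)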
From 1661d70ddc84dba0cd31c0adf0276e9fe7987b46 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Sat, 23 Sep 2017 22:53:27 +0000
Subject: [PATCH 4/4] fix pylint error

---
 topi/python/topi/cuda/conv2d_nchw.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py
index 0024859fd892..f8bcab52cb9a 100644
--- a/topi/python/topi/cuda/conv2d_nchw.py
+++ b/topi/python/topi/cuda/conv2d_nchw.py
@@ -295,7 +295,7 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
 
         i, ic, h, w = s[temp].op.axis
         ic = s[temp].fuse(ic, h, w)
-        oic, iic = s[temp].split(ic, factor = rfactor)
+        oic, iic = s[temp].split(ic, factor=rfactor)
         s[temp].bind(iic, thread_xx)
         s[temp].bind(oic, block_xx)
 
@@ -429,10 +429,10 @@ def schedule(temp, Filter, Output):
         temp_G = s.cache_read(temp, "global", [Output])
         s[temp_G].compute_inline()
         i, ic, h, w = s[temp_G].op.axis
-        ow, iw = s[temp_G].split(w, factor=4)
+        s[temp_G].split(w, factor=4)
         temp_R = s.cache_write(temp_G, "global")
         temp_S = s.cache_read(temp_R, "shared", [temp_G])
-    else: 
+    else:
         s[temp].compute_inline()
         temp_S = s.cache_read(temp, "shared", [Output])
         temp_R = temp_S
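Note on the pylint fix: Stage.split mutates the stage's loop nest and returns
the new IterVars, so dropping the unused ow, iw bindings silences the
unused-variable warning while leaving the packed layout unchanged. A tiny
illustration (tensor names assumed):

    import tvm

    t = tvm.placeholder((1, 3, 224, 224), name="t")
    u = tvm.compute(t.shape, lambda *idx: t(*idx), name="u")

    s = tvm.create_schedule(u.op)
    _, _, _, w = s[u].op.axis
    s[u].split(w, factor=4)  # the loop is still split; return value unused
    print(tvm.lower(s, [t, u], simple_mode=True))  # shows w.outer / w.inner loops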