diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py index 86b0cb6236e4..f8bcab52cb9a 100644 --- a/topi/python/topi/cuda/conv2d_nchw.py +++ b/topi/python/topi/cuda/conv2d_nchw.py @@ -4,7 +4,7 @@ from .. import util from .. import tag -def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L): +def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L): """Schedule conv2d for specific feature_in_out_filter pattern""" # scheduler params ofactor = 16 @@ -36,10 +36,26 @@ def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L): s[temp_S].compute_at(s[Out_L], ic) s[Filter_S].compute_at(s[Out_L], w) + num_thread1 = 512 + thread_xx = tvm.thread_axis((0, num_thread1), "threadIdx.x") + block_xx = tvm.thread_axis("blockIdx.x") + + i = s[temp].fuse(*s[temp].op.axis) + bx, tx = s[temp].split(i, factor=num_thread1) + s[temp].bind(tx, thread_xx) + s[temp].bind(bx, block_xx) + + i = s[temp_R].fuse(*s[temp_R].op.axis) + bx, tx = s[temp_R].split(i, factor=num_thread1) + s[temp_R].bind(tx, thread_xx) + s[temp_R].bind(bx, block_xx) + #schedule temp_S shared mem load - i, ic, h, w = s[temp_S].op.axis - tx, ih = s[temp_S].split(w, nparts=num_thread) + i, ic, h, ow, iw = s[temp_S].op.axis + h = s[temp_S].fuse(h, ow) + _, tx = s[temp_S].split(h, factor=num_thread) s[temp_S].bind(tx, thread_x) + s[temp_S].vectorize(iw) #schedule Filter_S shared mem load i, oc, h, w = s[Filter_S].op.axis @@ -48,7 +64,7 @@ def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L): tx, _ = s[Filter_S].split(w, nparts=num_thread) s[Filter_S].bind(tx, thread_x) -def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag): +def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag): """Schedule conv2d for specific feature_in_out_filter pattern""" if util.get_const_int(Filter_S.shape[0]) == util.get_const_int(Filter_S.shape[1]): num_thread_x = 8 @@ -89,13 +105,28 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag): s[temp_S].compute_at(s[Out_L], oic) s[Filter_S].compute_at(s[Out_L], dw) + num_thread = 512 + thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x") + block_xx = tvm.thread_axis("blockIdx.x") + + i = s[temp].fuse(*s[temp].op.axis) + bx, tx = s[temp].split(i, factor=num_thread) + s[temp].bind(tx, thread_xx) + s[temp].bind(bx, block_xx) + + i = s[temp_R].fuse(*s[temp_R].op.axis) + bx, tx = s[temp_R].split(i, factor=num_thread) + s[temp_R].bind(tx, thread_xx) + s[temp_R].bind(bx, block_xx) + #schedule temp_S shared mem load - i, ic, h, w = s[temp_S].op.axis - _, iic = s[temp_S].split(ic, factor=num_thread_y) - w = s[temp_S].fuse(h, w) - _, iw = s[temp_S].split(w, factor=num_thread_x) - s[temp_S].bind(iic, thread_y) - s[temp_S].bind(iw, thread_x) + i, oic, h, w, iic = s[temp_S].op.axis + oic = s[temp_S].fuse(oic, h, w) + ooic, ioic = s[temp_S].split(oic, factor=num_thread_x) + _, iooic = s[temp_S].split(ooic, factor=num_thread_y) + s[temp_S].bind(ioic, thread_x) + s[temp_S].bind(iooic, thread_y) + s[temp_S].vectorize(iic) i, oc, h, w = s[Filter_S].op.axis _, ioc = s[Filter_S].split(oc, factor=num_thread_y) @@ -104,7 +135,6 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag): s[Filter_S].bind(ii, thread_x) else: # scheduler params - num_thread = 8 vthread = 2 opart2 = 4 ofactor = 64 @@ -112,13 +142,13 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag): ifactor = 8 if flag > 256: wfactor = 14 - sfactor = max(1, ofactor//(opart2*2)) - spart = max(1, (wfactor + vthread-1) // vthread) + num_thread_x = max(1, ofactor//(opart2*2)) + num_thread_y = max(1, (wfactor + vthread-1) // vthread) block_x = tvm.thread_axis("blockIdx.x") block_y = tvm.thread_axis("blockIdx.y") block_z = tvm.thread_axis("blockIdx.z") - thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") - thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y") + thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x") + thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y") thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx") thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy") @@ -140,54 +170,51 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag): s[Out_L].compute_at(s[Out], iiioc) + # schedule Out_L local write + i, oc, h, w = s[Out_L].op.axis + ic, dh, dw = s[Out_L].op.reduce_axis + oic, iic = s[Out_L].split(ic, factor=ifactor) + s[Out_L].reorder(oic, dh, dw, iic, h, w) if util.get_const_int(Filter_S.shape[1]) == 128: - # schedule Out_L local write - i, oc, h, w = s[Out_L].op.axis - ic, dh, dw = s[Out_L].op.reduce_axis - oic, iic = s[Out_L].split(ic, factor=ifactor) - s[Out_L].reorder(oic, dh, dw, iic, h, w) oic = s[Out_L].fuse(dh, oic) s[temp_S].compute_at(s[Out_L], oic) s[Filter_S].compute_at(s[Out_L], oic) - - #schedule temp_S shared mem load - i, ic, h, w = s[temp_S].op.axis - _, iic = s[temp_S].split(ic, factor=sfactor) - _, iw = s[temp_S].split(w, factor=spart) - s[temp_S].bind(iic, thread_x) - s[temp_S].bind(iw, thread_y) - - #schedule Filter_S shared mem load - i, oc, h, w = s[Filter_S].op.axis - _, ioc = s[Filter_S].split(oc, factor=sfactor) - _, ii = s[Filter_S].split(i, factor=spart) - s[Filter_S].bind(ioc, thread_x) - s[Filter_S].bind(ii, thread_y) + num_thread = 512 else: - # schedule Out_L local write - i, oc, h, w = s[Out_L].op.axis - ic, dh, dw = s[Out_L].op.reduce_axis - oic, iic = s[Out_L].split(ic, factor=ifactor) - s[Out_L].reorder(oic, dh, dw, iic, h, w) - s[temp_S].compute_at(s[Out_L], oic) s[Filter_S].compute_at(s[Out_L], dw) + num_thread = 456 + + thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x") + block_xx = tvm.thread_axis("blockIdx.x") - #schedule temp_S shared mem load - i, ic, h, w = s[temp_S].op.axis - _, iic = s[temp_S].split(ic, factor=sfactor) - _, iw = s[temp_S].split(w, factor=spart) - s[temp_S].bind(iic, thread_x) - s[temp_S].bind(iw, thread_y) - - #schedule Filter_S shared mem load - i, oc, h, w = s[Filter_S].op.axis - _, ioc = s[Filter_S].split(oc, factor=sfactor) - _, ii = s[Filter_S].split(i, factor=spart) - s[Filter_S].bind(ioc, thread_x) - s[Filter_S].bind(ii, thread_y) - -def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L): + i = s[temp].fuse(*s[temp].op.axis) + bx, tx = s[temp].split(i, factor=num_thread) + s[temp].bind(tx, thread_xx) + s[temp].bind(bx, block_xx) + + i = s[temp_R].fuse(*s[temp_R].op.axis) + bx, tx = s[temp_R].split(i, factor=num_thread) + s[temp_R].bind(tx, thread_xx) + s[temp_R].bind(bx, block_xx) + + #schedule temp_S shared mem load + i, oic, h, w, iic = s[temp_S].op.axis + oic = s[temp_S].fuse(oic, h, w) + ooic, ioic = s[temp_S].split(oic, factor=num_thread_x) + _, iooic = s[temp_S].split(ooic, factor=num_thread_y) + s[temp_S].bind(ioic, thread_x) + s[temp_S].bind(iooic, thread_y) + s[temp_S].vectorize(iic) + + #schedule Filter_S shared mem load + i, oc, h, w = s[Filter_S].op.axis + _, ioc = s[Filter_S].split(oc, factor=num_thread_x) + _, ii = s[Filter_S].split(i, factor=num_thread_y) + s[Filter_S].bind(ioc, thread_x) + s[Filter_S].bind(ii, thread_y) + +def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): """Schedule conv2d for specific feature_in_out_filter pattern""" if util.get_const_int(Filter.shape[1]) == 256: # scheduler params @@ -262,13 +289,30 @@ def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L): s[temp_S].compute_at(s[Out_L], oic) s[Filter_S].compute_at(s[Out_L], oic) + rfactor = util.get_const_int(Filter.shape[1]) + thread_xx = tvm.thread_axis((0, rfactor), "threadIdx.x") + block_xx = tvm.thread_axis("blockIdx.x") + + i, ic, h, w = s[temp].op.axis + ic = s[temp].fuse(ic, h, w) + oic, iic = s[temp].split(ic, factor=rfactor) + s[temp].bind(iic, thread_xx) + s[temp].bind(oic, block_xx) + + i, h, w, oic, iic = s[temp_R].op.axis + ic = s[temp_R].fuse(oic, iic) + s[temp_R].bind(ic, thread_xx) + h = s[temp_R].fuse(h, w) + s[temp_R].bind(h, block_xx) + #schedule temp_S shared mem load - i, ic, h, w = s[temp_S].op.axis - ic = s[temp_S].fuse(w, h, ic) - oic, iic = s[temp_S].split(ic, factor=num_thread_x) + i, h, w, oc, ic = s[temp_S].op.axis + icc = s[temp_S].fuse(oc, w, h) + oic, iic = s[temp_S].split(icc, factor=num_thread_x) _, ioic = s[temp_S].split(oic, factor=num_thread_y) s[temp_S].bind(iic, thread_x) s[temp_S].bind(ioic, thread_y) + s[temp_S].vectorize(ic) #schedule Filter_S shared mem load i, oc, h, w = s[Filter_S].op.axis @@ -363,9 +407,36 @@ def schedule(temp, Filter, Output): elif block_w % 32 == 0: block_w = 32 - s[temp].compute_inline() + flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1]) + + if flag > 768: + temp_G = s.cache_read(temp, "global", [Output]) + s[temp_G].compute_inline() + i, ic, h, w = s[temp_G].op.axis + oic, iic = s[temp_G].split(ic, factor=4) + s[temp_G].reorder(i, h, w, oic, iic) + temp_R = s.cache_write(temp_G, "global") + temp_S = s.cache_read(temp_R, "shared", [temp_G]) + elif 128 < flag < 512: + temp_G = s.cache_read(temp, "global", [Output]) + s[temp_G].compute_inline() + i, ic, h, w = s[temp_G].op.axis + oic, iic = s[temp_G].split(ic, factor=4) + s[temp_G].reorder(i, oic, h, w, iic) + temp_R = s.cache_write(temp_G, "global") + temp_S = s.cache_read(temp_R, "shared", [temp_G]) + elif util.get_const_int(Filter.shape[3]) == 7: + temp_G = s.cache_read(temp, "global", [Output]) + s[temp_G].compute_inline() + i, ic, h, w = s[temp_G].op.axis + s[temp_G].split(w, factor=4) + temp_R = s.cache_write(temp_G, "global") + temp_S = s.cache_read(temp_R, "shared", [temp_G]) + else: + s[temp].compute_inline() + temp_S = s.cache_read(temp, "shared", [Output]) + temp_R = temp_S - temp_S = s.cache_read(temp, "shared", [Output]) Filter_S = s.cache_read(Filter, "shared", [Output]) if Output.op in s.outputs: @@ -376,14 +447,12 @@ def schedule(temp, Filter, Output): s[Output].set_scope("local") Out_L = Output - flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1]) - if util.get_const_int(Filter.shape[3]) == 7: - conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L) + conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L) elif 128 < flag < 512: - conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag) + conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag) elif flag >= 512: - conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L) + conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L) else: conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)