Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 134 additions & 65 deletions topi/python/topi/cuda/conv2d_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .. import util
from .. import tag

def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
# scheduler params
ofactor = 16
Expand Down Expand Up @@ -36,10 +36,26 @@ def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
s[temp_S].compute_at(s[Out_L], ic)
s[Filter_S].compute_at(s[Out_L], w)

num_thread1 = 512
thread_xx = tvm.thread_axis((0, num_thread1), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")

i = s[temp].fuse(*s[temp].op.axis)
bx, tx = s[temp].split(i, factor=num_thread1)
s[temp].bind(tx, thread_xx)
s[temp].bind(bx, block_xx)

i = s[temp_R].fuse(*s[temp_R].op.axis)
bx, tx = s[temp_R].split(i, factor=num_thread1)
s[temp_R].bind(tx, thread_xx)
s[temp_R].bind(bx, block_xx)

#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
tx, ih = s[temp_S].split(w, nparts=num_thread)
i, ic, h, ow, iw = s[temp_S].op.axis
h = s[temp_S].fuse(h, ow)
_, tx = s[temp_S].split(h, factor=num_thread)
s[temp_S].bind(tx, thread_x)
s[temp_S].vectorize(iw)

#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
Expand All @@ -48,7 +64,7 @@ def conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L):
tx, _ = s[Filter_S].split(w, nparts=num_thread)
s[Filter_S].bind(tx, thread_x)

def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
if util.get_const_int(Filter_S.shape[0]) == util.get_const_int(Filter_S.shape[1]):
num_thread_x = 8
Expand Down Expand Up @@ -89,13 +105,28 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], dw)

num_thread = 512
thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")

i = s[temp].fuse(*s[temp].op.axis)
bx, tx = s[temp].split(i, factor=num_thread)
s[temp].bind(tx, thread_xx)
s[temp].bind(bx, block_xx)

i = s[temp_R].fuse(*s[temp_R].op.axis)
bx, tx = s[temp_R].split(i, factor=num_thread)
s[temp_R].bind(tx, thread_xx)
s[temp_R].bind(bx, block_xx)

#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
_, iic = s[temp_S].split(ic, factor=num_thread_y)
w = s[temp_S].fuse(h, w)
_, iw = s[temp_S].split(w, factor=num_thread_x)
s[temp_S].bind(iic, thread_y)
s[temp_S].bind(iw, thread_x)
i, oic, h, w, iic = s[temp_S].op.axis
oic = s[temp_S].fuse(oic, h, w)
ooic, ioic = s[temp_S].split(oic, factor=num_thread_x)
_, iooic = s[temp_S].split(ooic, factor=num_thread_y)
s[temp_S].bind(ioic, thread_x)
s[temp_S].bind(iooic, thread_y)
s[temp_S].vectorize(iic)

i, oc, h, w = s[Filter_S].op.axis
_, ioc = s[Filter_S].split(oc, factor=num_thread_y)
Expand All @@ -104,21 +135,20 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):
s[Filter_S].bind(ii, thread_x)
else:
# scheduler params
num_thread = 8
vthread = 2
opart2 = 4
ofactor = 64
wfactor = 28
ifactor = 8
if flag > 256:
wfactor = 14
sfactor = max(1, ofactor//(opart2*2))
spart = max(1, (wfactor + vthread-1) // vthread)
num_thread_x = max(1, ofactor//(opart2*2))
num_thread_y = max(1, (wfactor + vthread-1) // vthread)
block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
block_z = tvm.thread_axis("blockIdx.z")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")

Expand All @@ -140,54 +170,51 @@ def conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag):

s[Out_L].compute_at(s[Out], iiioc)

# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w)
if util.get_const_int(Filter_S.shape[1]) == 128:
# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w)
oic = s[Out_L].fuse(dh, oic)
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic)

#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
_, iic = s[temp_S].split(ic, factor=sfactor)
_, iw = s[temp_S].split(w, factor=spart)
s[temp_S].bind(iic, thread_x)
s[temp_S].bind(iw, thread_y)

#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
_, ioc = s[Filter_S].split(oc, factor=sfactor)
_, ii = s[Filter_S].split(i, factor=spart)
s[Filter_S].bind(ioc, thread_x)
s[Filter_S].bind(ii, thread_y)
num_thread = 512
else:
# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w)

s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], dw)
num_thread = 456

thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")

#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
_, iic = s[temp_S].split(ic, factor=sfactor)
_, iw = s[temp_S].split(w, factor=spart)
s[temp_S].bind(iic, thread_x)
s[temp_S].bind(iw, thread_y)

#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
_, ioc = s[Filter_S].split(oc, factor=sfactor)
_, ii = s[Filter_S].split(i, factor=spart)
s[Filter_S].bind(ioc, thread_x)
s[Filter_S].bind(ii, thread_y)

def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L):
i = s[temp].fuse(*s[temp].op.axis)
bx, tx = s[temp].split(i, factor=num_thread)
s[temp].bind(tx, thread_xx)
s[temp].bind(bx, block_xx)

i = s[temp_R].fuse(*s[temp_R].op.axis)
bx, tx = s[temp_R].split(i, factor=num_thread)
s[temp_R].bind(tx, thread_xx)
s[temp_R].bind(bx, block_xx)

#schedule temp_S shared mem load
i, oic, h, w, iic = s[temp_S].op.axis
oic = s[temp_S].fuse(oic, h, w)
ooic, ioic = s[temp_S].split(oic, factor=num_thread_x)
_, iooic = s[temp_S].split(ooic, factor=num_thread_y)
s[temp_S].bind(ioic, thread_x)
s[temp_S].bind(iooic, thread_y)
s[temp_S].vectorize(iic)

#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
_, ioc = s[Filter_S].split(oc, factor=num_thread_x)
_, ii = s[Filter_S].split(i, factor=num_thread_y)
s[Filter_S].bind(ioc, thread_x)
s[Filter_S].bind(ii, thread_y)

def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
if util.get_const_int(Filter.shape[1]) == 256:
# scheduler params
Expand Down Expand Up @@ -262,13 +289,30 @@ def conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L):
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic)

rfactor = util.get_const_int(Filter.shape[1])
thread_xx = tvm.thread_axis((0, rfactor), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")

i, ic, h, w = s[temp].op.axis
ic = s[temp].fuse(ic, h, w)
oic, iic = s[temp].split(ic, factor=rfactor)
s[temp].bind(iic, thread_xx)
s[temp].bind(oic, block_xx)

i, h, w, oic, iic = s[temp_R].op.axis
ic = s[temp_R].fuse(oic, iic)
s[temp_R].bind(ic, thread_xx)
h = s[temp_R].fuse(h, w)
s[temp_R].bind(h, block_xx)

#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
ic = s[temp_S].fuse(w, h, ic)
oic, iic = s[temp_S].split(ic, factor=num_thread_x)
i, h, w, oc, ic = s[temp_S].op.axis
icc = s[temp_S].fuse(oc, w, h)
oic, iic = s[temp_S].split(icc, factor=num_thread_x)
_, ioic = s[temp_S].split(oic, factor=num_thread_y)
s[temp_S].bind(iic, thread_x)
s[temp_S].bind(ioic, thread_y)
s[temp_S].vectorize(ic)

#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
Expand Down Expand Up @@ -363,9 +407,36 @@ def schedule(temp, Filter, Output):
elif block_w % 32 == 0:
block_w = 32

s[temp].compute_inline()
flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1])

if flag > 768:
temp_G = s.cache_read(temp, "global", [Output])
s[temp_G].compute_inline()
i, ic, h, w = s[temp_G].op.axis
oic, iic = s[temp_G].split(ic, factor=4)
s[temp_G].reorder(i, h, w, oic, iic)
temp_R = s.cache_write(temp_G, "global")
temp_S = s.cache_read(temp_R, "shared", [temp_G])
elif 128 < flag < 512:
temp_G = s.cache_read(temp, "global", [Output])
s[temp_G].compute_inline()
i, ic, h, w = s[temp_G].op.axis
oic, iic = s[temp_G].split(ic, factor=4)
s[temp_G].reorder(i, oic, h, w, iic)
temp_R = s.cache_write(temp_G, "global")
temp_S = s.cache_read(temp_R, "shared", [temp_G])
elif util.get_const_int(Filter.shape[3]) == 7:
temp_G = s.cache_read(temp, "global", [Output])
s[temp_G].compute_inline()
i, ic, h, w = s[temp_G].op.axis
s[temp_G].split(w, factor=4)
temp_R = s.cache_write(temp_G, "global")
temp_S = s.cache_read(temp_R, "shared", [temp_G])
else:
s[temp].compute_inline()
temp_S = s.cache_read(temp, "shared", [Output])
temp_R = temp_S

temp_S = s.cache_read(temp, "shared", [Output])
Filter_S = s.cache_read(Filter, "shared", [Output])

if Output.op in s.outputs:
Expand All @@ -376,14 +447,12 @@ def schedule(temp, Filter, Output):
s[Output].set_scope("local")
Out_L = Output

flag = util.get_const_int(Filter.shape[0])+util.get_const_int(Filter.shape[1])

if util.get_const_int(Filter.shape[3]) == 7:
conv2d_224_3_64(s, temp_S, Filter_S, Out, Out_L)
conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L)
elif 128 < flag < 512:
conv2d_56_64_128(s, temp_S, Filter_S, Out, Out_L, flag)
conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag)
elif flag >= 512:
conv2d_14_256_256(s, Filter, temp_S, Filter_S, Out, Out_L)
conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L)
else:
conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)

Expand Down