python/tvm/topi/cuda/pooling.py: 36 changes (19 additions, 17 deletions)
@@ -20,6 +20,8 @@
 from tvm import te
 from .. import tag
 from ..utils import traverse_inline
+from .reduction import _schedule_reduce
+from .injective import schedule_injective_from_existing


 def schedule_adaptive_pool(outs, layout="NCHW"):
@@ -39,29 +41,20 @@ def schedule_adaptive_pool(outs, layout="NCHW"):
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])

-    def _schedule(Pool):
-        num_thread = 8
-        block_x = te.thread_axis("blockIdx.x")
-        block_y = te.thread_axis("blockIdx.y")
-        thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
-        thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
+    def _schedule_non_global(Pool):
         if Pool.op in s.outputs:
             Out = Pool
             OL = s.cache_write(Pool, "local")
         else:
             Out = outs[0].op.output(0)
             s[Pool].set_scope("local")

-        by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread)
-        if layout == "NHWC":
-            bx, tx = s[Out].split(s[Out].op.axis[3], factor=num_thread)
-        else:
-            bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread)
-        s[Out].reorder(by, bx, ty, tx)
-        s[Out].bind(ty, thread_y)
-        s[Out].bind(tx, thread_x)
-        s[Out].bind(by, block_y)
-        s[Out].bind(bx, block_x)
+        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+        fused_axis = s[Out].fuse(*s[Out].op.axis)
+        bx, tx = s[Out].split(fused_axis, factor=max_threads)
+        s[Out].bind(bx, te.thread_axis("blockIdx.x"))
+        s[Out].bind(tx, te.thread_axis("threadIdx.x"))
+
         if Pool.op in s.outputs:
             s[OL].compute_at(s[Out], tx)
         else:
Expand All @@ -81,7 +74,16 @@ def traverse(OP):
# schedule global_pool
elif OP.tag.startswith("adaptive_pool"):
Pool = OP.output(0)
_schedule(Pool)
oshape = Pool.shape
if (layout == "NCHW" and oshape[2] == 1 and oshape[3] == 1) or (
layout == "NHWC" and oshape[1] == 1 and oshape[2] == 1
):
_schedule_reduce(OP, s)
if OP != outs[0].op:
# the final division for adaptive pool or fused elemwise ops
schedule_injective_from_existing(s, outs[0])
else:
_schedule_non_global(Pool)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)

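For context, a minimal usage sketch of the updated schedule (not part of the PR): it assumes a CUDA-enabled TVM build and uses the public topi.nn.adaptive_pool and topi.cuda.schedule_adaptive_pool APIs; the tensor shape and names are illustrative. With a (1, 1) output size the new branch dispatches to _schedule_reduce (plus schedule_injective_from_existing when elementwise ops are fused on top); any other output size takes _schedule_non_global, which now fuses all output axes and splits by the target's max_num_threads instead of the old fixed 8x8 thread tiling.

# Usage sketch; assumes a CUDA-enabled TVM build, shapes are illustrative.
import tvm
from tvm import te, topi

data = te.placeholder((1, 64, 56, 56), name="data")
# (1, 1) output size -> global pooling, handled by the _schedule_reduce branch.
pool = topi.nn.adaptive_pool(data, (1, 1), pool_type="avg", layout="NCHW")

with tvm.target.Target("cuda"):
    # The schedule reads max_num_threads from the current target in the
    # non-global branch, so it is built inside a target scope.
    s = topi.cuda.schedule_adaptive_pool([pool], layout="NCHW")
    func = tvm.build(s, [data, pool], "cuda")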