diff --git a/python/tvm/topi/cuda/pooling.py b/python/tvm/topi/cuda/pooling.py
index f2a6aadb659f..7ddbea27c174 100644
--- a/python/tvm/topi/cuda/pooling.py
+++ b/python/tvm/topi/cuda/pooling.py
@@ -20,6 +20,8 @@
 from tvm import te
 from .. import tag
 from ..utils import traverse_inline
+from .reduction import _schedule_reduce
+from .injective import schedule_injective_from_existing
 
 
 def schedule_adaptive_pool(outs, layout="NCHW"):
@@ -39,12 +41,7 @@ def schedule_adaptive_pool(outs, layout="NCHW"):
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
 
-    def _schedule(Pool):
-        num_thread = 8
-        block_x = te.thread_axis("blockIdx.x")
-        block_y = te.thread_axis("blockIdx.y")
-        thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
-        thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
+    def _schedule_non_global(Pool):
         if Pool.op in s.outputs:
             Out = Pool
             OL = s.cache_write(Pool, "local")
@@ -52,16 +49,12 @@ def _schedule(Pool):
             Out = outs[0].op.output(0)
             s[Pool].set_scope("local")
 
-        by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread)
-        if layout == "NHWC":
-            bx, tx = s[Out].split(s[Out].op.axis[3], factor=num_thread)
-        else:
-            bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread)
-        s[Out].reorder(by, bx, ty, tx)
-        s[Out].bind(ty, thread_y)
-        s[Out].bind(tx, thread_x)
-        s[Out].bind(by, block_y)
-        s[Out].bind(bx, block_x)
+        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+        fused_axis = s[Out].fuse(*s[Out].op.axis)
+        bx, tx = s[Out].split(fused_axis, factor=max_threads)
+        s[Out].bind(bx, te.thread_axis("blockIdx.x"))
+        s[Out].bind(tx, te.thread_axis("threadIdx.x"))
+
         if Pool.op in s.outputs:
             s[OL].compute_at(s[Out], tx)
         else:
@@ -81,7 +74,16 @@ def traverse(OP):
         # schedule global_pool
         elif OP.tag.startswith("adaptive_pool"):
             Pool = OP.output(0)
-            _schedule(Pool)
+            oshape = Pool.shape
+            if (layout == "NCHW" and oshape[2] == 1 and oshape[3] == 1) or (
+                layout == "NHWC" and oshape[1] == 1 and oshape[2] == 1
+            ):
+                _schedule_reduce(OP, s)
+                if OP != outs[0].op:
+                    # the final division for adaptive pool or fused elemwise ops
+                    schedule_injective_from_existing(s, outs[0])
+            else:
+                _schedule_non_global(Pool)
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
 
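
A quick way to exercise both code paths touched by this patch is to build a global (1x1) and a non-global adaptive pool under a CUDA target and run the schedule over each. The snippet below is only an illustrative sketch, not part of the patch: the input shape, pool sizes, and variable names are arbitrary, and it assumes a CUDA-enabled build of TVM so that tvm.build can lower the schedules.

import tvm
from tvm import te, topi

# NCHW input; the shape is chosen arbitrarily for illustration.
data = te.placeholder((1, 64, 56, 56), name="data")

with tvm.target.Target("cuda"):
    # output_size (1, 1) is the global-pool case: with this patch it dispatches
    # to _schedule_reduce and, when the avg-pool division is the final op,
    # to schedule_injective_from_existing.
    gpool = topi.nn.adaptive_pool(data, (1, 1), pool_type="avg", layout="NCHW")
    s_global = topi.cuda.schedule_adaptive_pool([gpool], layout="NCHW")

    # A non-trivial output size takes the renamed _schedule_non_global path,
    # which now fuses all output axes and splits by the target's max_num_threads.
    apool = topi.nn.adaptive_pool(data, (7, 7), pool_type="avg", layout="NCHW")
    s_non_global = topi.cuda.schedule_adaptive_pool([apool], layout="NCHW")

    # Both schedules should lower and build for CUDA (requires a CUDA toolchain).
    tvm.build(s_global, [data, gpool], target="cuda")
    tvm.build(s_non_global, [data, apool], target="cuda")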