From ab0158320e72c3b60a0d1c923b92bf1703ec34d3 Mon Sep 17 00:00:00 2001
From: masa
Date: Sun, 5 Sep 2021 18:41:01 +0900
Subject: [PATCH 1/4] use reduction schedule for adaptive pool

---
 python/tvm/topi/cuda/pooling.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/tvm/topi/cuda/pooling.py b/python/tvm/topi/cuda/pooling.py
index f2a6aadb659f..d8caefc76ae2 100644
--- a/python/tvm/topi/cuda/pooling.py
+++ b/python/tvm/topi/cuda/pooling.py
@@ -70,6 +70,7 @@ def _schedule(Pool):
     scheduled_ops = []

     def traverse(OP):
+        nonlocal s
         """Internal traverse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
         if tag.is_broadcast(OP.tag):
@@ -80,8 +81,12 @@ def traverse(OP):
                     traverse(tensor.op)
         # schedule global_pool
         elif OP.tag.startswith("adaptive_pool"):
-            Pool = OP.output(0)
-            _schedule(Pool)
+            # Pool = OP.output(0)
+            # _schedule(Pool)
+            from .reduction import _schedule_reduce
+            from .injective import schedule_injective_from_existing
+            _schedule_reduce(OP, s)
+            s = schedule_injective_from_existing(s, outs[0])
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)


From f81e5bb5ab46beed26aaf3a3bbd9414f039f71e9 Mon Sep 17 00:00:00 2001
From: masa
Date: Mon, 6 Sep 2021 18:27:32 +0900
Subject: [PATCH 2/4] improve adaptive pool schedule

---
 python/tvm/topi/cuda/pooling.py | 37 ++++++++++++++++++++++++++-------
 1 file changed, 30 insertions(+), 7 deletions(-)

diff --git a/python/tvm/topi/cuda/pooling.py b/python/tvm/topi/cuda/pooling.py
index d8caefc76ae2..e3a832470a6a 100644
--- a/python/tvm/topi/cuda/pooling.py
+++ b/python/tvm/topi/cuda/pooling.py
@@ -20,6 +20,8 @@
 from tvm import te
 from .. import tag
 from ..utils import traverse_inline
+from .reduction import _schedule_reduce
+from .injective import schedule_injective_from_existing


 def schedule_adaptive_pool(outs, layout="NCHW"):
@@ -55,17 +57,33 @@ def _schedule(Pool):
         by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread)
         if layout == "NHWC":
             bx, tx = s[Out].split(s[Out].op.axis[3], factor=num_thread)
+            fused_hw_axis = s[Out].fuse(*s[Out].op.axis[1:3])
+            fused_hw_size = Out.shape[1] * Out.shape[2]
         else:
             bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread)
+            fused_hw_axis = s[Out].fuse(*s[Out].op.axis[2:4])
+            fused_hw_size = Out.shape[2] * Out.shape[3]
+
         s[Out].reorder(by, bx, ty, tx)
         s[Out].bind(ty, thread_y)
         s[Out].bind(tx, thread_x)
         s[Out].bind(by, block_y)
         s[Out].bind(bx, block_x)
+
+        max_threads = tvm.tir.floordiv(
+            tvm.tir.min(
+                fused_hw_size, int(tvm.target.Target.current(allow_none=False).max_num_threads)
+            ),
+            (num_thread * num_thread),
+        )
+        bz, tz = s[Out].split(fused_hw_axis, factor=max_threads)
+        s[Out].bind(bz, te.thread_axis("blockIdx.z"))
+        s[Out].bind(tz, te.thread_axis("threadIdx.z"))
+
         if Pool.op in s.outputs:
             s[OL].compute_at(s[Out], tx)
         else:
-            s[Pool].compute_at(s[Out], tx)
+            s[Pool].compute_at(s[Out], tz)

     scheduled_ops = []

@@ -81,12 +99,17 @@ def traverse(OP):
                     traverse(tensor.op)
         # schedule global_pool
         elif OP.tag.startswith("adaptive_pool"):
-            # Pool = OP.output(0)
-            # _schedule(Pool)
-            from .reduction import _schedule_reduce
-            from .injective import schedule_injective_from_existing
-            _schedule_reduce(OP, s)
-            s = schedule_injective_from_existing(s, outs[0])
+            Pool = OP.output(0)
+            oshape = Pool.shape
+            if (layout == "NCHW" and oshape[2] == 1 and oshape[3] == 1) or (
+                layout == "NHWC" and oshape[1] == 1 and oshape[2] == 1
+            ):
+                _schedule_reduce(OP, s)
+                if OP.tag == "adaptive_pool_sum":
+                    schedule_injective_from_existing(s, outs[0])  # the final division
+            else:
+                print("foo")
+                _schedule(Pool)
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)


From 87a10ec5de6a3bc6fc400fc0811ce2fc116895ac Mon Sep 17 00:00:00 2001
From: masa
Date: Mon, 6 Sep 2021 18:37:07 +0900
Subject: [PATCH 3/4] simplify

---
 python/tvm/topi/cuda/pooling.py | 43 ++++++---------------------------
 1 file changed, 8 insertions(+), 35 deletions(-)

diff --git a/python/tvm/topi/cuda/pooling.py b/python/tvm/topi/cuda/pooling.py
index e3a832470a6a..1a4f67552a7c 100644
--- a/python/tvm/topi/cuda/pooling.py
+++ b/python/tvm/topi/cuda/pooling.py
@@ -41,12 +41,7 @@ def schedule_adaptive_pool(outs, layout="NCHW"):
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])

-    def _schedule(Pool):
-        num_thread = 8
-        block_x = te.thread_axis("blockIdx.x")
-        block_y = te.thread_axis("blockIdx.y")
-        thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
-        thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
+    def _schedule_non_global(Pool):
         if Pool.op in s.outputs:
             Out = Pool
             OL = s.cache_write(Pool, "local")
@@ -54,41 +49,20 @@ def _schedule(Pool):
             Out = outs[0].op.output(0)
             s[Pool].set_scope("local")

-        by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread)
-        if layout == "NHWC":
-            bx, tx = s[Out].split(s[Out].op.axis[3], factor=num_thread)
-            fused_hw_axis = s[Out].fuse(*s[Out].op.axis[1:3])
-            fused_hw_size = Out.shape[1] * Out.shape[2]
-        else:
-            bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread)
-            fused_hw_axis = s[Out].fuse(*s[Out].op.axis[2:4])
-            fused_hw_size = Out.shape[2] * Out.shape[3]
-
-        s[Out].reorder(by, bx, ty, tx)
-        s[Out].bind(ty, thread_y)
-        s[Out].bind(tx, thread_x)
-        s[Out].bind(by, block_y)
-        s[Out].bind(bx, block_x)
-
-        max_threads = tvm.tir.floordiv(
-            tvm.tir.min(
-                fused_hw_size, int(tvm.target.Target.current(allow_none=False).max_num_threads)
-            ),
-            (num_thread * num_thread),
-        )
-        bz, tz = s[Out].split(fused_hw_axis, factor=max_threads)
-        s[Out].bind(bz, te.thread_axis("blockIdx.z"))
-        s[Out].bind(tz, te.thread_axis("threadIdx.z"))
+        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+        fused_axis = s[Out].fuse(*s[Out].op.axis)
+        bx, tx = s[Out].split(fused_axis, factor=max_threads)
+        s[Out].bind(bx, te.thread_axis("blockIdx.x"))
+        s[Out].bind(tx, te.thread_axis("threadIdx.x"))

         if Pool.op in s.outputs:
             s[OL].compute_at(s[Out], tx)
         else:
-            s[Pool].compute_at(s[Out], tz)
+            s[Pool].compute_at(s[Out], tx)

     scheduled_ops = []

     def traverse(OP):
-        nonlocal s
         """Internal traverse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
         if tag.is_broadcast(OP.tag):
@@ -108,8 +82,7 @@ def traverse(OP):
                 if OP.tag == "adaptive_pool_sum":
                     schedule_injective_from_existing(s, outs[0])  # the final division
             else:
-                print("foo")
-                _schedule(Pool)
+                _schedule_non_global(Pool)
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)


From 5c68288b839b8a285c34f5c6232e067d50f8b662 Mon Sep 17 00:00:00 2001
From: masa
Date: Mon, 6 Sep 2021 18:49:33 +0900
Subject: [PATCH 4/4] fixed global pooling with fused ops

---
 python/tvm/topi/cuda/pooling.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/tvm/topi/cuda/pooling.py b/python/tvm/topi/cuda/pooling.py
index 1a4f67552a7c..7ddbea27c174 100644
--- a/python/tvm/topi/cuda/pooling.py
+++ b/python/tvm/topi/cuda/pooling.py
@@ -79,8 +79,9 @@ def traverse(OP):
                 layout == "NHWC" and oshape[1] == 1 and oshape[2] == 1
             ):
                 _schedule_reduce(OP, s)
-                if OP.tag == "adaptive_pool_sum":
-                    schedule_injective_from_existing(s, outs[0])  # the final division
+                if OP != outs[0].op:
+                    # the final division for adaptive pool or fused elemwise ops
+                    schedule_injective_from_existing(s, outs[0])
             else:
                 _schedule_non_global(Pool)
         else: