From 8edb1e518e7bea38491b3d1e71b867977240827f Mon Sep 17 00:00:00 2001
From: Andrey Malyshev
Date: Wed, 7 Dec 2022 09:04:03 +0200
Subject: [PATCH 1/4] [Adreno] Add global pooling schedule

The parallelism opportunities in case of global pooling are limited by the
number of channels, so the schedule needs to be changed to get parallelism
over the reduction axis by using rfactor.
---
 python/tvm/relay/op/strategy/adreno.py        |   7 +
 python/tvm/topi/adreno/pooling.py             | 113 +++++++++++++++
 .../relay/opencl_texture/test_pool_texture.py | 135 ++++++++++++++++++
 3 files changed, 255 insertions(+)
 create mode 100644 tests/python/relay/opencl_texture/test_pool_texture.py

diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py
index 21252215fc28..3325018a16e2 100644
--- a/python/tvm/relay/op/strategy/adreno.py
+++ b/python/tvm/relay/op/strategy/adreno.py
@@ -215,6 +215,13 @@ def schedule_reduce_adreno(attrs, outs, target):
         return topi.adreno.schedule_reduce(outs)
 
 
+@schedule_adaptive_pool.register(["adreno"])
+def schedule_adaptive_pool_cuda(attrs, outs, target):
+    """schedule adaptive pooling ops for cuda"""
+    with target:
+        return topi.adreno.schedule_adaptive_pool(outs, attrs.layout)
+
+
 @concatenate_strategy.register(["adreno"])
 def concatenate_strategy_adreno(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
diff --git a/python/tvm/topi/adreno/pooling.py b/python/tvm/topi/adreno/pooling.py
index 49f103c04a2f..150bfe3dc110 100644
--- a/python/tvm/topi/adreno/pooling.py
+++ b/python/tvm/topi/adreno/pooling.py
@@ -19,6 +19,119 @@
 import tvm
 from tvm import te
 from .. import tag
+from ..utils import traverse_inline
+from .reduction import _schedule_reduce_adreno
+from ..cuda.reduction import _schedule_reduce
+from .injective import schedule_injective_from_existing
+from .utils import get_div
+
+
+def schedule_adaptive_pool(outs, layout="NCHW"):
+    """Schedule for adaptive_pool.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of adaptive_pool
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for adaptive_pool.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _schedule_global(Pool, layout):
+        # e.g. the latest pool op is global max pool, a non-latest one is global avg pooling
+        # OL - an Expr that will be used for rfactor
+        # Out - programming of the parallelism on the global level
+        # shared is not required, local could be enough but shared scope gives quite a significant
+        # perf boost
+        if Pool.op in s.outputs:
+            Out = Pool
+            OL = s.cache_write(Pool, "shared")
+        else:
+            Out = outs[0].op.output(0)
+            s[Pool].set_scope("shared")
+            OL = Pool
+
+        PaddedInput = Pool.op.input_tensors[0]
+
+        # detect axis for later reorder and binding of batch/chennel to blocks and
+        # spatial to threads
+        if layout == "NCHW" or layout == "NCHW4c":
+            channel_index = 1
+            height_index = 2
+            width_index = 3
+        else:
+            channel_index = 3
+            height_index = 1
+            width_index = 2
+
+        if isinstance(PaddedInput.op, tvm.te.ComputeOp):
+            s[PaddedInput].compute_inline()
+
+        fused_reduce = s[OL].fuse(
+            *[s[OL].op.reduce_axis[i] for i in range(len(s[OL].op.reduce_axis))]
+        )
+
+        spatial = PaddedInput.shape[height_index].value * PaddedInput.shape[width_index].value
+        max_threads = spatial // 25 if spatial > 25 else 1
+        max_threads = 256 if max_threads > 256 else max_threads
+        num_thread = get_div(spatial, max_threads)
+
+        thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
+
+        _, ki = s[OL].split(fused_reduce, factor=num_thread)
+        data_out_rf = s.rfactor(OL, ki)
+        s[data_out_rf].compute_at(s[OL], s[OL].op.reduce_axis[0])
+        s[OL].bind(s[OL].op.reduce_axis[0], thread_y)
+
+        naxis = s[Out].op.axis[0]
+        caxis = s[Out].op.axis[channel_index]
+        haxis = s[Out].op.axis[height_index]
+        waxis = s[Out].op.axis[width_index]
+
+        if layout == "NCHW4c" or layout == "NHWC4c":
+            texture_axis = s[Out].op.axis[-1]
+            s[Out].reorder(naxis, caxis, haxis, waxis, texture_axis)
+            s[Out].vectorize(texture_axis)
+        else:
+            texture_axis = None
+            s[Out].reorder(naxis, caxis, haxis, waxis)
+
+        bx = s[Out].fuse(naxis, caxis)
+        tx = s[Out].fuse(haxis, waxis)
+
+        s[Out].bind(bx, te.thread_axis("blockIdx.x"))
+        s[Out].bind(tx, te.thread_axis("threadIdx.x"))
+
+        s[OL].compute_at(s[Out], tx)
+
+    scheduled_ops = []
+
+    def traverse(OP):
+        """Internal traverse function"""
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_injective(OP.tag):
+            if OP not in s.outputs:
+                s[OP].compute_inline()
+            for tensor in OP.input_tensors:
+                if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
+                    traverse(tensor.op)
+        # schedule global_pool
+        elif OP.tag.startswith("adaptive_pool"):
+            Pool = OP.output(0)
+            _schedule_global(Pool, layout)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % OP.tag)
+
+        scheduled_ops.append(OP)
+
+    traverse(outs[0].op)
+    return s
 
 
 def schedule_pool(outs, layout):
diff --git a/tests/python/relay/opencl_texture/test_pool_texture.py b/tests/python/relay/opencl_texture/test_pool_texture.py
new file mode 100644
index 000000000000..faeb121c800c
--- /dev/null
+++ b/tests/python/relay/opencl_texture/test_pool_texture.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import relay
+from utils.adreno_utils import build_run_compare
+
+
+dtype = tvm.testing.parameter("float32")
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_pool2d_nchw_wide(remote, target, dtype):
+    """
+    Use case of NCHW global pooling with big spatial values
+    """
+    input_shape = (1, 32, 160, 160)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_avg_pool2d(A)
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_pool2d_nchw4c_wide(remote, target, dtype):
+    """
+    Use case of blocked NCHW4c global pooling with big spatial values
+    """
+    input_shape = (1, 8, 160, 160, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_avg_pool2d(A, layout="NCHW4c")
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_pool2d_nchw_deep(remote, target, dtype):
+    """
+    Use case of NCHW deep global pooling
+    """
+    input_shape = (1, 2048, 20, 20)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_avg_pool2d(A)
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_pool2d_nchw4c_deep(remote, target, dtype):
+    """
+    Use case of blocked NCHW4c deep global pooling
+    """
+    input_shape = (1, 512, 20, 20, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_avg_pool2d(A, layout="NCHW4c")
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_pool2d_nhwc(remote, target, dtype):
+    """
+    Use case of NHWC global pooling with big spatial values
+    """
+    input_shape = (1, 160, 160, 32)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_avg_pool2d(A, layout="NHWC")
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_pool2d_nhwc4c(remote, target, dtype):
+    """
+    Use case of blocked NHWC4c global pooling with big spatial values
+    """
+    input_shape = (1, 160, 160, 8, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_avg_pool2d(A, layout="NHWC4c")
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_max_pool2d_nchw_wide(remote, target, dtype):
+    """
+    Use case of NCHW global max pooling with big spatial values
+    """
+    input_shape = (1, 32, 160, 160)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_max_pool2d(A)
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_global_max_pool2d_nchw4c_wide(remote, target, dtype):
+    """
+    Use case of blocked NCHW4c global max pooling with big spatial values
+    """
+    input_shape = (1, 8, 160, 160, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    C = relay.nn.global_max_pool2d(A, layout="NCHW4c")
+    mod = relay.Function([A], C)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)

From be7bc2a24888f01b901cd36416da9c2da32b1f4e Mon Sep 17 00:00:00 2001
From: Andrey Malyshev
Date: Wed, 7 Dec 2022 10:20:09 +0200
Subject: [PATCH 2/4] address pylint hits

---
 python/tvm/topi/adreno/pooling.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/python/tvm/topi/adreno/pooling.py b/python/tvm/topi/adreno/pooling.py
index 150bfe3dc110..af6a6e5d6cf1 100644
--- a/python/tvm/topi/adreno/pooling.py
+++ b/python/tvm/topi/adreno/pooling.py
@@ -19,10 +19,6 @@
 import tvm
 from tvm import te
 from .. import tag
-from ..utils import traverse_inline
-from .reduction import _schedule_reduce_adreno
-from ..cuda.reduction import _schedule_reduce
-from .injective import schedule_injective_from_existing
 from .utils import get_div
 
 
@@ -61,7 +57,7 @@ def _schedule_global(Pool, layout):
 
         # detect axis for later reorder and binding of batch/chennel to blocks and
         # spatial to threads
-        if layout == "NCHW" or layout == "NCHW4c":
+        if layout in ("NCHW", "NCHW4c"):
             channel_index = 1
             height_index = 2
             width_index = 3
@@ -94,7 +90,7 @@ def _schedule_global(Pool, layout):
         haxis = s[Out].op.axis[height_index]
         waxis = s[Out].op.axis[width_index]
 
-        if layout == "NCHW4c" or layout == "NHWC4c":
+        if layout in ("NHWC4c", "NCHW4c"):
            texture_axis = s[Out].op.axis[-1]
            s[Out].reorder(naxis, caxis, haxis, waxis, texture_axis)
            s[Out].vectorize(texture_axis)

From 1a195e4d5759f357ed190bce996d483e60759380 Mon Sep 17 00:00:00 2001
From: Andrey Malyshev
Date: Wed, 7 Dec 2022 18:25:19 +0200
Subject: [PATCH 3/4] address PR comments

---
 python/tvm/relay/op/strategy/adreno.py | 4 ++--
 python/tvm/topi/adreno/pooling.py      | 9 +++++----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py
index 3325018a16e2..b606ab05d701 100644
--- a/python/tvm/relay/op/strategy/adreno.py
+++ b/python/tvm/relay/op/strategy/adreno.py
@@ -216,8 +216,8 @@ def schedule_reduce_adreno(attrs, outs, target):
 
 
 @schedule_adaptive_pool.register(["adreno"])
-def schedule_adaptive_pool_cuda(attrs, outs, target):
-    """schedule adaptive pooling ops for cuda"""
+def schedule_adaptive_pool_adreno(attrs, outs, target):
+    """schedule adaptive pooling ops for adreno"""
     with target:
         return topi.adreno.schedule_adaptive_pool(outs, attrs.layout)
 
diff --git a/python/tvm/topi/adreno/pooling.py b/python/tvm/topi/adreno/pooling.py
index af6a6e5d6cf1..f775b000bd7d 100644
--- a/python/tvm/topi/adreno/pooling.py
+++ b/python/tvm/topi/adreno/pooling.py
@@ -55,7 +55,7 @@ def _schedule_global(Pool, layout):
 
         PaddedInput = Pool.op.input_tensors[0]
 
-        # detect axis for later reorder and binding of batch/chennel to blocks and
+        # detect axis for later reorder and binding of batch/channel to blocks and
         # spatial to threads
         if layout in ("NCHW", "NCHW4c"):
             channel_index = 1
             height_index = 2
             width_index = 3
@@ -69,11 +69,12 @@ def _schedule_global(Pool, layout):
         if isinstance(PaddedInput.op, tvm.te.ComputeOp):
             s[PaddedInput].compute_inline()
 
-        fused_reduce = s[OL].fuse(
-            *[s[OL].op.reduce_axis[i] for i in range(len(s[OL].op.reduce_axis))]
-        )
+        fused_reduce = s[OL].fuse(*s[OL].op.reduce_axis)
 
         spatial = PaddedInput.shape[height_index].value * PaddedInput.shape[width_index].value
+        # the values below were selected empirically, assuming that each thread should have some
+        # work to do (currently 25-49 elements) and that the number of threads should not exceed
+        # a threshold selected as 256 after performance experiments on Adreno 660
         max_threads = spatial // 25 if spatial > 25 else 1
         max_threads = 256 if max_threads > 256 else max_threads
         num_thread = get_div(spatial, max_threads)

From ea60d4bf2b2fb41c9bf1441ed9f57d55ef3fdb5f Mon Sep 17 00:00:00 2001
From: Andrey Malyshev
Date: Thu, 8 Dec 2022 09:02:30 +0200
Subject: [PATCH 4/4] switch spatial axis to blk binding

---
 python/tvm/topi/adreno/pooling.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/python/tvm/topi/adreno/pooling.py b/python/tvm/topi/adreno/pooling.py
index f775b000bd7d..f02af0c01fd2 100644
--- a/python/tvm/topi/adreno/pooling.py
+++ b/python/tvm/topi/adreno/pooling.py
@@ -99,13 +99,10 @@ def _schedule_global(Pool, layout):
             texture_axis = None
             s[Out].reorder(naxis, caxis, haxis, waxis)
 
-        bx = s[Out].fuse(naxis, caxis)
-        tx = s[Out].fuse(haxis, waxis)
-
+        bx = s[Out].fuse(naxis, caxis, haxis, waxis)
         s[Out].bind(bx, te.thread_axis("blockIdx.x"))
-        s[Out].bind(tx, te.thread_axis("threadIdx.x"))
 
-        s[OL].compute_at(s[Out], tx)
+        s[OL].compute_at(s[Out], bx)
 
     scheduled_ops = []
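As an illustration of the scheduling idea this series relies on (not taken from the
patches themselves): when global pooling leaves only a handful of output elements, the
reduction axis is split and rfactor-ed so the spatial reduction itself runs across
threads. The sketch below is a minimal, standalone TE example of that cross-thread
reduction pattern; the input shape (1, 32, 160, 160), the fixed split factor of 64 and
the tensor names are illustrative assumptions, whereas the real schedule derives
num_thread via get_div and uses a shared-scope cache stage as shown in patch 1.

import tvm
from tvm import te

# Global average pooling written as a plain TE reduction: one output element per channel.
n, c, h, w = 1, 32, 160, 160
data = te.placeholder((n, c, h, w), name="data")
rh = te.reduce_axis((0, h), name="rh")
rw = te.reduce_axis((0, w), name="rw")
pool_sum = te.compute(
    (n, c, 1, 1),
    lambda b, ch, i, j: te.sum(data[b, ch, rh, rw], axis=[rh, rw]),
    name="pool_sum",
)

s = te.create_schedule(pool_sum.op)

# Only n * c output elements exist, so parallelize the reduction instead:
# fuse the spatial reduction axes, split off an inner chunk and turn it into
# an rfactor stage that computes partial sums in parallel.
fused = s[pool_sum].fuse(*s[pool_sum].op.reduce_axis)
_, ki = s[pool_sum].split(fused, factor=64)  # illustrative factor only
rf = s.rfactor(pool_sum, ki)

# The partial sums are combined with a cross-thread reduction over threadIdx.y,
# while the few remaining output elements are fused and bound to blockIdx.x,
# mirroring the binding layout that patch 4 converges on.
s[rf].compute_at(s[pool_sum], s[pool_sum].op.reduce_axis[0])
s[pool_sum].bind(s[pool_sum].op.reduce_axis[0], te.thread_axis("threadIdx.y"))
bx = s[pool_sum].fuse(*s[pool_sum].op.axis)
s[pool_sum].bind(bx, te.thread_axis("blockIdx.x"))

print(tvm.lower(s, [data, pool_sum], simple_mode=True))

Lowering this prints an IR with one block per output element and 64 threads cooperating
on the 160 x 160 spatial reduction, which is the shape of parallelism the new adreno
schedule builds for both the plain and the blocked (NCHW4c/NHWC4c) layouts.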