7 changes: 7 additions & 0 deletions python/tvm/relay/op/strategy/adreno.py
@@ -215,6 +215,13 @@ def schedule_reduce_adreno(attrs, outs, target):
return topi.adreno.schedule_reduce(outs)


@schedule_adaptive_pool.register(["adreno"])
def schedule_adaptive_pool_adreno(attrs, outs, target):
"""schedule adaptive pooling ops for adreno"""
with target:
return topi.adreno.schedule_adaptive_pool(outs, attrs.layout)


@concatenate_strategy.register(["adreno"])
def concatenate_strategy_adreno(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
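
The registration above routes adaptive/global pooling ops to the new Adreno schedule when compiling for an Adreno OpenCL target. A minimal sketch of how this path can be exercised end-to-end (not part of the patch; the shape, dtype and opt_level are illustrative assumptions):

import tvm
from tvm import relay

# Build a global_avg_pool2d for the Adreno OpenCL target; lowering will pick
# schedule_adaptive_pool_adreno registered above.
data = relay.var("data", shape=(1, 32, 160, 160), dtype="float32")
pool = relay.nn.global_avg_pool2d(data)
mod = tvm.IRModule.from_expr(relay.Function([data], pool))

target = tvm.target.Target("opencl -device=adreno", host="llvm")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)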
107 changes: 107 additions & 0 deletions python/tvm/topi/adreno/pooling.py
@@ -19,6 +19,113 @@
import tvm
from tvm import te
from .. import tag
from .utils import get_div


def schedule_adaptive_pool(outs, layout="NCHW"):
"""Schedule for adaptive_pool.

Parameters
----------
outs: Array of Tensor
The computation graph description of adaptive_pool
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for adaptive_pool.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _schedule_global(Pool, layout):
        # The pooling stage is not always the last op in the graph: global max pool usually
        # is the output itself, while global avg pool is followed by a division stage.
        # OL  - the stage that will be split via rfactor
        # Out - the stage carrying the block-level parallelism
        # The shared scope is not strictly required (local would be enough), but it gives a
        # quite significant performance boost.
if Pool.op in s.outputs:
Out = Pool
OL = s.cache_write(Pool, "shared")
else:
Out = outs[0].op.output(0)
s[Pool].set_scope("shared")
OL = Pool

PaddedInput = Pool.op.input_tensors[0]

        # detect axes for the later reorder and for binding batch/channel to blocks and
        # spatial dimensions to threads
if layout in ("NCHW", "NCHW4c"):
channel_index = 1
height_index = 2
width_index = 3
else:
channel_index = 3
height_index = 1
width_index = 2

if isinstance(PaddedInput.op, tvm.te.ComputeOp):
s[PaddedInput].compute_inline()

fused_reduce = s[OL].fuse(*s[OL].op.reduce_axis)

spatial = PaddedInput.shape[height_index].value * PaddedInput.shape[width_index].value
        # The values below were selected empirically: there should be some work in each
        # thread (currently 25-49 reduction elements) and the number of threads should not
        # exceed 256, a threshold chosen for performance after experiments on Adreno 660.
max_threads = spatial // 25 if spatial > 25 else 1
max_threads = 256 if max_threads > 256 else max_threads
num_thread = get_div(spatial, max_threads)

thread_y = te.thread_axis((0, num_thread), "threadIdx.y")

_, ki = s[OL].split(fused_reduce, factor=num_thread)
data_out_rf = s.rfactor(OL, ki)
s[data_out_rf].compute_at(s[OL], s[OL].op.reduce_axis[0])
s[OL].bind(s[OL].op.reduce_axis[0], thread_y)

naxis = s[Out].op.axis[0]
caxis = s[Out].op.axis[channel_index]
haxis = s[Out].op.axis[height_index]
waxis = s[Out].op.axis[width_index]

if layout in ("NHWC4c", "NCHW4c"):
texture_axis = s[Out].op.axis[-1]
s[Out].reorder(naxis, caxis, haxis, waxis, texture_axis)
s[Out].vectorize(texture_axis)
else:
texture_axis = None
s[Out].reorder(naxis, caxis, haxis, waxis)

bx = s[Out].fuse(naxis, caxis, haxis, waxis)
s[Out].bind(bx, te.thread_axis("blockIdx.x"))

s[OL].compute_at(s[Out], bx)

scheduled_ops = []

def traverse(OP):
"""Internal traverse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_injective(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith("adaptive_pool"):
Pool = OP.output(0)
_schedule_global(Pool, layout)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)

scheduled_ops.append(OP)

traverse(outs[0].op)
return s


def schedule_pool(outs, layout):
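
For reference, the number of reduction threads chosen by schedule_adaptive_pool depends only on the spatial size of the input. The sketch below re-derives that selection; the local get_div is only assumed to mirror topi.adreno.utils.get_div (largest divisor of a value not exceeding a limit) and is included so the snippet runs standalone:

def get_div(value, limit):
    # assumed behaviour of topi.adreno.utils.get_div: the largest divisor of
    # `value` that does not exceed `limit`
    for cand in range(min(value, limit), 0, -1):
        if value % cand == 0:
            return cand
    return 1


def pick_num_thread(height, width):
    spatial = height * width
    max_threads = spatial // 25 if spatial > 25 else 1  # at least 25 elements of work per thread
    max_threads = min(max_threads, 256)                 # cap found empirically on Adreno 660
    return get_div(spatial, max_threads)


print(pick_num_thread(160, 160))  # 25600 elements -> 256 threads
print(pick_num_thread(20, 20))    # 400 elements -> 16 threads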
135 changes: 135 additions & 0 deletions tests/python/relay/opencl_texture/test_pool_texture.py
@@ -0,0 +1,135 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import tvm.testing
from tvm import relay
from utils.adreno_utils import build_run_compare


dtype = tvm.testing.parameter("float32")


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_global_pool2d_nchw_wide(remote, target, dtype):
"""
    Use case of NCHW global pooling with big spatial values
"""
input_shape = (1, 32, 160, 160)
A = relay.var("data", shape=input_shape, dtype=dtype)
C = relay.nn.global_avg_pool2d(A)
mod = relay.Function([A], C)

build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_global_pool2d_nchw4c_wide(remote, target, dtype):
"""
    Use case of blocked NCHW4c global pooling with big spatial values
"""
input_shape = (1, 8, 160, 160, 4)
A = relay.var("data", shape=input_shape, dtype=dtype)
C = relay.nn.global_avg_pool2d(A, layout="NCHW4c")
mod = relay.Function([A], C)

build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_global_pool2d_nchw_deep(remote, target, dtype):
"""
Use case of NCHW deep global pooling
"""
input_shape = (1, 2048, 20, 20)
A = relay.var("data", shape=input_shape, dtype=dtype)
C = relay.nn.global_avg_pool2d(A)
mod = relay.Function([A], C)

build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_global_pool2d_nchw4c_deep(remote, target, dtype):
"""
Use case of blocked NCHW4c deep global pooling
"""
input_shape = (1, 512, 20, 20, 4)
A = relay.var("data", shape=input_shape, dtype=dtype)
C = relay.nn.global_avg_pool2d(A, layout="NCHW4c")
mod = relay.Function([A], C)

build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_global_pool2d_nhwc(remote, target, dtype):
"""
    Use case of NHWC global pooling with big spatial values
"""
input_shape = (1, 160, 160, 32)
A = relay.var("data", shape=input_shape, dtype=dtype)
C = relay.nn.global_avg_pool2d(A, layout="NHWC")
mod = relay.Function([A], C)

build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_global_pool2d_nhwc4c(remote, target, dtype):
"""
    Use case of blocked NHWC4c global pooling with big spatial values
"""
input_shape = (1, 160, 160, 8, 4)
A = relay.var("data", shape=input_shape, dtype=dtype)
C = relay.nn.global_avg_pool2d(A, layout="NHWC4c")
mod = relay.Function([A], C)

build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_global_max_pool2d_nchw_wide(remote, target, dtype):
"""
    Use case of NCHW global max pooling with big spatial values
"""
input_shape = (1, 32, 160, 160)
A = relay.var("data", shape=input_shape, dtype=dtype)
C = relay.nn.global_max_pool2d(A)
mod = relay.Function([A], C)

build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_global_max_pool2d_nchw4c_wide(remote, target, dtype):
"""
    Use case of blocked NCHW4c global max pooling with big spatial values
"""
input_shape = (1, 8, 160, 160, 4)
A = relay.var("data", shape=input_shape, dtype=dtype)
C = relay.nn.global_max_pool2d(A, layout="NCHW4c")
mod = relay.Function([A], C)

build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)