1 change: 1 addition & 0 deletions python/tvm/topi/hexagon/qnn/__init__.py
@@ -27,3 +27,4 @@
from .quantize import quantize_compute, tir_quantize_schedule
from .nn import *
from .qdepthwise_conv2d_slice import qdepthwise_conv2d_compute, qdepthwise_conv2d_schedule
from .adaptive_avg_pool1d import *
120 changes: 120 additions & 0 deletions python/tvm/topi/hexagon/qnn/adaptive_avg_pool1d.py
@@ -0,0 +1,120 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

""" Compute and schedule for adaptive_avg_pool1d slice op

Following are a few notes and assumptions made by the implementation:

Assumptions:
1) The input is in NCW layout. DistilBERT is the only model that calls
nn.adaptive_avg_pool1d, and the only layout it uses is 'NCW'.
2) The op takes output_size as an argument and
only handles the specialized case where output_size is 1.
The argument output_size is used as the value of output_width.
3) Both the input and output dtypes are uint8/int8, and
quantization parameters are provided to the op.
4) The input is assumed to always be a multiple of the fixed chunk 32c64w.

Notes:
1) If input width is used as output width, there can be two cases:
a. If the quantization parameters of the input and output are the same,
the input can be returned as the output, so the op becomes a no-op.
b. If the quantization parameters of the input and output are different,
the op essentially becomes a requantize op.
2) If output_size is a value other than 1 or input_width,
adaptive_avg_pool1d may use a different stride and kernel for each output element.
In that case the kernel is not known at compile time, so the generic
implementation nn.adaptive_avg_pool1d() should be used.
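
For reference, the requantize in case 1b reduces to roughly:
out_q = round((in_q - input_zero_point) * input_scale / output_scale) + output_zero_point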
"""

from tvm import te
from tvm import tir
from ..utils import get_layout_transform_fn, get_fixed_point_value, saturate


def adaptive_avg_pool1d(
data: te.Tensor,
output_size: list,
odtype: str,
input_zero_point: int,
input_scale: float,
output_zero_point: int,
output_scale: float,
):
"""adaptive_avg_pool1d compute"""
n, c, inw = data.shape

out_width = output_size[0]
oshape = (n, c, out_width)

# The kernel width is the same as the input width since output_width is assumed to be 1
if out_width == 1:
kw_r = inw
else:
raise RuntimeError(f"Unsupported output_size '{out_width}'")

if odtype == "uint8":
temp_dtype = "uint32"
elif odtype == "int8":
temp_dtype = "int32"
else:
raise RuntimeError(f"Unsupported output dtype, {odtype}'")

scale_with_area = input_scale / (output_scale * int(kw_r))
scale_fixed_point, rsh = get_fixed_point_value(scale_with_area, "int16")
corr = (output_zero_point << rsh) - input_zero_point * kw_r * scale_fixed_point
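# Derivation (reference only): the requantized average in float is
#   out_q = (sum_q - input_zero_point * kw_r) * input_scale / (kw_r * output_scale) + output_zero_point
# Assuming get_fixed_point_value returns (m, rsh) such that m / 2**rsh ~= scale_with_area,
# scaling by 2**rsh gives the integer form used below:
#   out_q = ((sum_q * m) + corr) >> rsh, with corr = (output_zero_point << rsh) - input_zero_point * kw_r * m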

rw_r = te.reduce_axis((0, kw_r), name="rw_r")

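# Since out_width is fixed to 1, the output index `w` below is always 0 and the
# reduction over rw_r spans the entire input width (kw_r == inw).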
sum_compute = te.compute(
oshape,
lambda n, c, w: te.sum(data[n, c, w + rw_r].astype(temp_dtype), axis=[rw_r]),
name="sum",
)

avg_compute = te.compute(
oshape,
lambda n, c, w: saturate(
((sum_compute[n, c, w] * scale_fixed_point) + corr) >> rsh, odtype
).astype(odtype),
name="adaptive_avg_1d",
)
return avg_compute


def stir_schedule_ncw_32c64w(outs, ins, input_layout: str):
"""Schedule for input layout ncw-32c64w and output layout ncw"""
func = te.create_prim_func([ins, outs])
s = tir.Schedule(func)

sum_block = s.get_block("sum")

# The input is a multiple of the fixed chunk but the output is NxCx1,
# hence transform_layout is applied only on the input
input_transformed_layout = get_layout_transform_fn(input_layout)
s.transform_layout(sum_block, buffer=("read", 0), index_map=input_transformed_layout)

return s


def tir_adaptive_avg_pool1d_schedule(outs, ins, output_layout: str, input_layout: str):
"""STIR based schedule"""
if output_layout == "ncw":
return stir_schedule_ncw_32c64w(outs, ins, input_layout)
raise RuntimeError(f"Unexpected layout '{output_layout}'")
7 changes: 7 additions & 0 deletions python/tvm/topi/hexagon/utils.py
@@ -131,6 +131,11 @@ def ohwi32o_1d(height, width, in_channel, out_channel):
return [out_channel // 32, height, width, in_channel, out_channel % 32]


def ncw_32c64w_2d(n, c, w):
"""Return index map for ncw_32c64w 2d layout"""
return [n, c // 32, w // 64, te.AXIS_SEPARATOR, c % 32, w % 64]


def get_layout_transform_fn(layout):
"""Return index map function as per the layout string"""
if layout == "nhwc-8h2w32c2w-2d":
@@ -173,6 +178,8 @@ def get_layout_transform_fn(layout):
return n11c_2048c_2d
if layout == "ohwi32o-1d":
return ohwi32o_1d
if layout == "ncw-32c64w-2d":
return ncw_32c64w_2d
raise RuntimeError(f"Unexpected layout '{layout}'")


9 changes: 9 additions & 0 deletions tests/python/contrib/test_hexagon/infrastructure.py
@@ -268,6 +268,15 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):

raise RuntimeError(f"Unexpected new_layout '{new_layout}'")

if current_layout == "ncw":
if new_layout == "ncw":
return arr_np
if new_layout in ["ncw-32c64w-2d"]:
n, c, w = arr_np.shape
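# Split C into 32-channel chunks and W into 64-wide chunks, then move the chunk
# axes ahead of the in-chunk axes; the resulting shape is [n, c // 32, w // 64, 32, 64].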
return arr_np.reshape([n, c // 32, 32, w // 64, 64]).transpose(0, 1, 3, 2, 4)

raise RuntimeError(f"Unexpected new_layout '{new_layout}'")

raise RuntimeError(f"Unexpected current_layout '{current_layout}'")


185 changes: 185 additions & 0 deletions tests/python/contrib/test_hexagon/topi/test_adaptive_avg_pool1d.py
@@ -0,0 +1,185 @@
# Licensed to the Apache Software Foundation (ASF) under one
I have added lint for tests in this sub-directory. Let's wait for that PR to merge to avoid conflicts. After it merges, you will see lint issues with this test, so I recommend fixing them in the meantime:
#13271

# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Test code for specialized case of adaptive_avg_pool1d."""

import numpy as np

import tvm
from tvm import te
from tvm.topi.testing import adaptive_pool
import tvm.topi.hexagon.qnn as s1
from tvm.contrib.hexagon import allocate_hexagon_array
from ..infrastructure import transform_numpy, quantize_np


SCALE_M_VAL = None
ZERO_POINT_M_VAL = None
SCALE_VAL = None
ZERO_POINT_VAL = None


class TestAdaptivePool1D:
"""Test specialized case of adaptive_avg_pool1d."""

(input_shape,) = tvm.testing.parameters(
([1, 128, 128],),
([1, 64, 64],),
([1, 64, 128],),
([1, 32, 64],),
([1, 128, 768],),
)

# The fixed chunk layout is set as ncw-32c64w-2d for now.
# The adaptive_avg_pool1d implementation only handles the specialized case
# where output_size is 1, as it appears in the quantized DistilBERT model.
# Since the output size won't be a multiple of the fixed chunk,
# output_layout is ncw.
# For optimization, this might change later.
input_layout, output_layout, pool_type, layout, output_size, dtype, = tvm.testing.parameters(
(
"ncw-32c64w-2d",
"ncw",
"avg",
"NCW",
[1],
"uint8",
)
)

@tvm.testing.fixture
def expected_output_np(
self,
input_np,
output_size,
pool_type,
layout,
):
"""Generate expected output."""
out_width = output_size[0]

ref_np = adaptive_pool(
input_np,
out_width,
pool_type,
layout,
)
return ref_np

@tvm.testing.fixture
def input_np(self, input_shape, dtype):
if dtype in ("uint8", "int8"):
dtype = "float32"
return np.random.random(input_shape).astype(dtype)

@tvm.testing.fixture
def quantize_input_np(self, input_np, dtype):
if dtype in ("uint8", "int8"):
global ZERO_POINT_VAL, SCALE_VAL
input_np_quantized, SCALE_VAL, ZERO_POINT_VAL = quantize_np(input_np, dtype)
return input_np_quantized

raise RuntimeError(f"Unsupported data type '{dtype}'")

@tvm.testing.fixture
def transformed_input_np(self, quantize_input_np, input_layout, layout, dtype):
if dtype in ("uint8", "int8"):
return transform_numpy(quantize_input_np, layout.lower(), input_layout)

raise RuntimeError(f"Unsupported data type '{dtype}'")

@tvm.testing.fixture
def quantize_expected_output_np(self, expected_output_np, dtype):
"""Generate expected output."""
if dtype in ("uint8", "int8"):
global ZERO_POINT_M_VAL, SCALE_M_VAL
out_ref_quantized, SCALE_M_VAL, ZERO_POINT_M_VAL = quantize_np(
expected_output_np, dtype
)

# Since output_layout is ncw, no transformation is needed.
return out_ref_quantized

raise RuntimeError(f"Unsupported data type '{dtype}'")

@tvm.testing.requires_hexagon
def test_pool1d(
self,
dtype,
output_size,
input_layout,
output_layout,
input_shape,
transformed_input_np,
quantize_expected_output_np,
hexagon_session,
):
"""Test adaptive_avg_pool1d."""
target_hexagon = tvm.target.hexagon("v69")
a_tensor = te.placeholder(input_shape, name="a_tensor", dtype=dtype)

m_tensor = s1.adaptive_avg_pool1d(
a_tensor,
output_size,
dtype,
ZERO_POINT_VAL,
SCALE_VAL,
ZERO_POINT_M_VAL,
SCALE_M_VAL,
)

tir_schedule = s1.tir_adaptive_avg_pool1d_schedule(
m_tensor, a_tensor, output_layout, input_layout
)

sch = tir_schedule.mod

with tvm.transform.PassContext(opt_level=3):
func = tvm.build(
sch,
[a_tensor, m_tensor],
tvm.target.Target(target_hexagon, host=target_hexagon),
name="adaptive_pool1d",
)

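# The separator at position 3 corresponds to te.AXIS_SEPARATOR in the
# ncw-32c64w-2d index map, splitting the 6-d buffer into a 2-d allocation.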
input_axis_separator = [3]

a_data_nd = allocate_hexagon_array(
hexagon_session.device,
data=transformed_input_np,
dtype=dtype,
axis_separators=input_axis_separator,
mem_scope="global.vtcm",
)

m_data_nd = allocate_hexagon_array(
hexagon_session.device,
quantize_expected_output_np.shape,
dtype=dtype,
)

mod = hexagon_session.load_module(func)
mod(a_data_nd, m_data_nd)

# Convert nd to np
m_data_np = m_data_nd.numpy()

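# atol=2 allows for the small rounding difference expected between the
# fixed-point kernel and the float reference after quantization.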
np.testing.assert_allclose(quantize_expected_output_np, m_data_np, atol=2)


if __name__ == "__main__":
tvm.testing.main()