1 change: 1 addition & 0 deletions python/tvm/topi/hexagon/qnn/__init__.py
@@ -26,3 +26,4 @@

from .quantize import quantize_compute, tir_quantize_schedule
from .nn import *
from .qdepthwise_conv2d_slice import qdepthwise_conv2d_compute, qdepthwise_conv2d_schedule
217 changes: 217 additions & 0 deletions python/tvm/topi/hexagon/qnn/qdepthwise_conv2d_slice.py
@@ -0,0 +1,217 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, unused-argument, too-many-locals
"""
Please note the following assumptions made by the implementation:
1) The input must be padded in advance to account for 'padding'. In addition,
both input and output must be padded as per the physical buffer layout.
2) 'padding' is ignored. It must be handled outside of the sliced op.
3) The weights are expected to be laid out as per the physical weight layout.

The initial compute for quantized depthwise conv2d is as follows
where cm = channel_multiplier; assumed to be 1,
zp_a = Activation_zero_point,
zp_w = Weight_zero_point,
Qa = Quantized Activation,
Qw = Quantized Weights.

a) Qc(n, oh, ow, oc) = (Sigma(r, s) (Qw(r, s, oc%cm, oc/cm) - zp_w)
* (Qa(n, oh + r, ow + s, oc/cm) - zp_a))
* scale_value
where scale_value = (activation_scale * weight_scale) / output_scale

This can be written as

b) Qc(n, oh, ow, oc) = (t1 - t2 - t3 + t4) * scale_value

where t1 = Sigma(r, s) Qw(r, s, oc%cm, oc/cm) * Qa(n, oh + r, ow + s, oc/cm)
t2 = Sigma(r, s) zp_w * Qa(n, oh + r, ow + s, oc/cm)
t3 = Sigma(r, s) zp_a * Qw(r, s, oc%cm, oc/cm)
t4 = Sigma(r, s) zp_a * zp_w
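
For example, with toy scalar values over a single (r, s) tap (purely illustrative):
Qw = 7, Qa = 9, zp_w = 3, zp_a = 2 gives (Qw - zp_w) * (Qa - zp_a) = 4 * 7 = 28, and
t1 - t2 - t3 + t4 = 63 - 27 - 14 + 6 = 28, i.e. the two forms agree.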

c) Qc(n, oh, ow, oc) = saturate((((t1 - t2 - t3 + t4) * fixed_scale_value) >> rsh) + zp_o)

where fixed_scale_value, rsh are the fixed-point representation of scale_value
and zp_o = Output_zero_point.


Compute and schedule for quantized depthwise conv2d slice op"""

import typing
import tvm
from tvm import te
from ..utils import get_layout_transform_fn, get_fixed_point_value, saturate


def qdepthwise_conv2d_compute(
activations: te.Tensor,
weights: te.Tensor,
out_shape: typing.Tuple,
stride: typing.Tuple,
dilation: typing.Tuple,
dtype: str,
# quantization params:
activation_zero_point,
activation_scale,
weight_zero_point,
weight_scale,
output_zero_point,
output_scale,
):
"""Compute for quantized depthwise conv2d"""
filt_shape = weights.shape
ob, oh, ow, oc = out_shape

if dtype == "uint8":
temp_dtype = "int32"
big_dtype = "int64"
elif dtype == "int8":
temp_dtype = "int32"
big_dtype = "int64"
else:
raise RuntimeError(f"Unsupported output dtype, {odtype}'")

reduce_height = tvm.te.reduce_axis((0, filt_shape[0]), name="reduce_height")
reduce_width = tvm.te.reduce_axis((0, filt_shape[1]), name="reduce_width")
stride_height, stride_width = stride
dilation_height, dilation_width = dilation

scale_value = (activation_scale * weight_scale) / output_scale
fixed_scale_value, rsh = get_fixed_point_value(scale_value, "int16")

t1 = tvm.te.compute(
out_shape,
lambda n, h, w, c: tvm.te.sum(
(
(
activations[
n,
h * stride_height + reduce_height * dilation_height,
w * stride_width + reduce_width * dilation_width,
c,
].astype(temp_dtype)
)
* (weights[reduce_height, reduce_width, 0, c].astype(temp_dtype))
).astype(temp_dtype),
axis=[reduce_height, reduce_width],
),
name="t1",
)

t2 = tvm.te.compute(
out_shape,
lambda n, h, w, c: tvm.te.sum(
(
(
activations[
n,
h * stride_height + reduce_height * dilation_height,
w * stride_width + reduce_width * dilation_width,
c,
].astype(temp_dtype)
)
* weight_zero_point
).astype(temp_dtype),
axis=[reduce_height, reduce_width],
),
name="t2",
)

t3 = tvm.te.compute(
(oc,),
lambda c: tvm.te.sum(
(
((weights[reduce_height, reduce_width, 0, c].astype(temp_dtype)))
* activation_zero_point
).astype(temp_dtype),
axis=[reduce_height, reduce_width],
),
name="t3",
)

    # t4 = Sigma(r, s) zp_a * zp_w, i.e. the zero points times the filter extent
    t4 = activation_zero_point * weight_zero_point * filt_shape[0] * filt_shape[1]

output = tvm.te.compute(
out_shape,
lambda n, h, w, c: saturate(
(
(
(
((t1[n, h, w, c]).astype(big_dtype) - t2[n, h, w, c] - t3[c] + t4)
* fixed_scale_value
)
>> rsh
)
+ (output_zero_point).astype(big_dtype)
),
dtype,
).astype(dtype),
name="output",
)

return output


def qdepthwise_conv2d_schedule(
outs: te.Tensor,
ins: typing.List[te.Tensor],
transform_activation_layout: str,
transform_weights: str,
):
"""
Schedule for quantized depthwise conv2d for input layout nhwc-8h8w32c
assert len(ins) == 2, "This schedule expects only 2 inputs - Activations and Weights
"""
source_expr = ins + [outs]
prim_func = tvm.te.create_prim_func(source_expr)
sch = tvm.tir.Schedule(prim_func)

compute = sch.get_block("output")
compute1 = sch.get_block("t1")

transform_layout_fn = get_layout_transform_fn(transform_activation_layout)
transform_layout_weights = get_layout_transform_fn(transform_weights)

# Apply layout_transform for activation
sch.transform_layout(compute1, ins[0].name, transform_layout_fn)

# Apply layout_transform for weights
sch.transform_layout(compute1, ins[1].name, transform_layout_weights)

# Apply layout_transform for output
sch.transform_layout(compute, outs.name, transform_layout_fn)

# This returns the original 6d loop
batch, height, width, channel, reduce_height, reduce_width = sch.get_loops(compute1)
h_outer, h_inner = sch.split(height, [None, 8])
w_outer, w_inner = sch.split(width, [None, 8])
c_outer, c_inner = sch.split(channel, [None, 32])
sch.reorder(
batch,
h_outer,
w_outer,
c_outer,
h_inner,
reduce_height,
reduce_width,
w_inner,
c_inner,
)

sch.decompose_reduction(compute1, reduce_height)
# wi_ci = sch.fuse(w_inner,c_inner)
# sch.vectorize(wi_ci)
return sch
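
A minimal usage sketch of the new op (shapes, strides, and quantization parameters below are illustrative assumptions rather than values taken from this patch; the activation is assumed to be pre-padded, and the layout strings are assumed to be "nhwc-8h8w32c-2d" for activations/output and "ohwi32o-1d" for weights):

import tvm
from tvm import te
from tvm.topi.hexagon.qnn import qdepthwise_conv2d_compute, qdepthwise_conv2d_schedule

# Hypothetical 1x10x10x64 uint8 activation and a 3x3 depthwise filter (channel_multiplier = 1).
act = te.placeholder((1, 10, 10, 64), dtype="uint8", name="activations")
wgt = te.placeholder((3, 3, 1, 64), dtype="uint8", name="weights")

# With stride (1, 1), dilation (1, 1), and padding handled outside the slice op,
# the output shape is (1, 8, 8, 64).
out = qdepthwise_conv2d_compute(
    act, wgt, (1, 8, 8, 64), (1, 1), (1, 1), "uint8",
    activation_zero_point=2, activation_scale=0.05,
    weight_zero_point=3, weight_scale=0.01,
    output_zero_point=1, output_scale=0.1,
)

# The layout strings are resolved through get_layout_transform_fn (see utils.py below).
sch = qdepthwise_conv2d_schedule(out, [act, wgt], "nhwc-8h8w32c-2d", "ohwi32o-1d")
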
5 changes: 3 additions & 2 deletions python/tvm/topi/hexagon/slice_ops/dwconv2d.py
@@ -85,7 +85,7 @@ def dwconv2d_schedule(
outs: te.Tensor,
ins: typing.List[te.Tensor],
transform_activation_layout: str,
transform_weights: typing.Callable,
transform_weights: str,
) -> tvm.tir.Schedule:
"""STIR schedule definition for the compute defined above by dwconv2d_compute.
- Auto-generated prim_func before applying schedule primitives for reference
@@ -128,11 +128,12 @@ def main(InputTensor: T.Buffer[(1, 16, 8, 32), "float16"], Weights: T.Buffer[(3,
sch = tvm.tir.Schedule(prim_func)
compute = sch.get_block("Output")
transform_layout_fn = get_layout_transform_fn(transform_activation_layout)
transform_layout_weights = get_layout_transform_fn(transform_weights)
# Apply layout_transform for activation
sch.transform_layout(compute, ins[0].name, transform_layout_fn)

# Apply layout_transform for weights
sch.transform_layout(compute, ins[1].name, transform_weights)
sch.transform_layout(compute, ins[1].name, transform_layout_weights)

# Apply layout_transform for output
sch.transform_layout(compute, outs.name, transform_layout_fn)
19 changes: 19 additions & 0 deletions python/tvm/topi/hexagon/utils.py
@@ -127,6 +127,10 @@ def iohw_16i32o2i_1d(height, width, in_channel, out_channel):
]


def ohwi32o_1d(height, width, in_channel, out_channel):
    """Index map for the 'ohwi32o-1d' weight layout: output channels split into chunks of 32."""
    return [out_channel // 32, height, width, in_channel, out_channel % 32]


def get_layout_transform_fn(layout):
"""Return index map function as per the layout string"""
if layout == "nhwc-8h2w32c2w-2d":
@@ -167,6 +171,8 @@ def get_layout_transform_fn(layout):
return nhwc_8h8w32c_2d
if layout == "n11c-2048c-2d":
return n11c_2048c_2d
if layout == "ohwi32o-1d":
return ohwi32o_1d
raise RuntimeError(f"Unexpected layout '{layout}'")


@@ -235,6 +241,19 @@ def get_fixed_point_value(flp: float, dtype: str = "int16") -> Tuple[int, int]:
best scaling factor for 'int16' type that can be used to convert the floating-point value to
fixed-point with the least amount of precision loss.


Here is a more rigorous explanation of the above, for non-negative scale values, which are of
interest. M < 2, so M * 2^(E-Bias+x) < 2 ^ (E-Bias+x+1) [Note: LHS is a fraction, RHS int]
=> round(M * 2^(E-Bias+x)) <= 2 ^ (E-Bias+x+1) [Note the "<=", not "<"]
We want x s.t. round(M * 2^(E-Bias+x)) <= 2^15 - 1
We know round(M * 2^(E-Bias+x)) <= 2^(E-Bias+x+1)
It will be sufficient to choose x s.t. 2^(E-Bias+x+1) <= 2^15 - 1
That is, max x s.t. 2^(E-Bias+x+1) < 2^15
E-Bias+x+1 < 15
E-Bias+x+1 <= 14
Max x will make E-Bias+x+1 = 14
x = 13 - E + Bias
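
For example (illustrative numbers): scale_value = 0.5 is stored as M = 1.0 with unbiased
exponent E - Bias = -1, so x = 13 - E + Bias = 14 and the fixed-point pair would be
(round(0.5 * 2^14), 14) = (8192, 14), which fits within the int16 limit of 2^15 - 1.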

@masahi (Member) commented on Oct 28, 2022:
cc @ibsidorenko - I'm curious how the requantize operation done in QC "slice ops" (such as this PR) compares to the one done by QNN canonicalization.


Additional notes on various floating-point values:
------------------------------------------------
1) Denormalized values: causes assertion failure. The problem with the denormalized values