37 changes: 36 additions & 1 deletion python/tvm/relay/op/strategy/hexagon.py
@@ -30,7 +30,7 @@ def batch_matmul_strategy_hexagon(attrs, inputs, out_type, target):
"""batch_matmul strategy for Hexagon"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_matmul(topi.nn.batch_matmul),
wrap_compute_batch_matmul(topi.nn.batch_matmul, need_out_dtype=True),
wrap_topi_schedule(topi.hexagon.schedule_batch_matmul),
name="batch_matmul.hexagon",
)
@@ -187,3 +187,38 @@ def schedule_reduce_hexagon(attrs, outs, target):
"""Schedule reduction ops for Hexagon"""
with target:
return topi.hexagon.schedule_reduce(outs)


@conv2d_NCHWc_strategy.register("hexagon")
def conv2d_NCHWc_strategy_hexagon(attrs, inputs, out_type, target):
"""conv2d_NCHWc_ hexagon strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(
topi.hexagon.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.hexagon.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.hexagon",
)
return strategy


@dense_pack_strategy.register("hexagon")
def dense_pack_strategy_hexagon(attrs, inputs, out_type, target):
"""dense_pack hexagon strategy"""
strategy = _op.OpStrategy()

if (
inputs[0].dtype == "uint8"
and inputs[1].dtype == "uint8"
and out_type.dtype == "int32"
and attrs["weight_layout"] == "NC32n4c"
):
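# A plevel above the default (10) makes this vrmpy implementation win over the generic dense_pack schedule.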
strategy.add_implementation(
wrap_compute_dense(topi.hexagon.dense.dense_u8u8i32_vrmpy_compute),
wrap_topi_schedule(topi.hexagon.dense.dense_u8u8i32_vrmpy_schedule),
name="dense_uint8.hexagon",
plevel=12,
)

return strategy
11 changes: 10 additions & 1 deletion python/tvm/topi/generic/conv2d.py
@@ -139,7 +139,16 @@ def schedule_conv_NCHWc_cpu_common_int8(
More details - https://software.intel.com/en-us/articles/
lower-numerical-precision-deep-learning-inference-and-training
"""
reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
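# ``cfg`` may be a plain Python dict of concrete values (as passed by the Hexagon schedule
# in topi/hexagon/conv2d.py) or an AutoTVM config entity, so accept both forms.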
if isinstance(cfg["tile_ow"], int):
reg_n = cfg["tile_ow"]
else:
reg_n = cfg["tile_ow"].size[-1]

if isinstance(cfg["unroll_kw"], (int, bool)):
unroll_kw = cfg["unroll_kw"]
else:
unroll_kw = cfg["unroll_kw"].val

_, _, _, _, ic_bn = get_const_tuple(data_vec.shape)
_, _, _, _, oc_bn = get_const_tuple(conv_out.shape)

2 changes: 2 additions & 0 deletions python/tvm/topi/hexagon/__init__.py
@@ -29,3 +29,5 @@
from .resize2d import *
from .tensor_intrin import *
from .qnn import *
from .dense_alter_op import *
from .conv2d_alter_op import *
49 changes: 48 additions & 1 deletion python/tvm/topi/hexagon/conv2d.py
@@ -14,11 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name
"""Schedule for conv2d"""

import tvm
from tvm import te
from .. import nn
from ..utils import traverse_inline
from .tensor_intrin import dot_vrmpy
from ..generic import conv2d as conv2d_generic


def schedule_conv2d_nhwc(outs):
@@ -86,3 +90,46 @@ def _callback(op):

traverse_inline(s, outs[0].op, _callback)
return s


def conv2d_NCHWc_int8(
data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32"
):
"""Compute definition for int8 conv2d in NCHWc layout"""
n_elems = int(kernel.shape[-1])
return nn.conv2d_NCHWc_int8(
data, kernel, stride, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems
)


def schedule_conv2d_NCHWc_int8(outs):
"""Schedule for int8 conv2d in NCHWc layout using vrmpy tensorization"""
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if "conv2d_NCHWc_int8" in op.tag:
conv_out = op.output(0)
kernel_vec = conv_out.op.input_tensors[1]
data_vec = conv_out.op.input_tensors[0]
out_width = conv_out.shape[3]
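# Choose the largest output-width tile (at most 31) that evenly divides out_width.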

reg_n = 1
for n in range(31, 0, -1):
if out_width % n == 0:
reg_n = n
break

cfg = {"tile_ow": reg_n, "unroll_kw": False}
args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]]
intrin = dot_vrmpy(data_vec.dtype, kernel_vec.dtype)
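# vrmpy accumulates groups of 4 int8 products into each of 32 int32 lanes,
# hence int32_lanes=32 and int8_elems=4 below.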

conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(
*args,
int32_lanes=32,
int8_elems=4,
intrin=intrin,
inline_fused=True,
)

traverse_inline(s, outs[0].op, _callback)
return s
111 changes: 111 additions & 0 deletions python/tvm/topi/hexagon/conv2d_alter_op.py
@@ -0,0 +1,111 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2d alter op functions for Hexagon"""

from tvm import relay
from ..utils import get_const_tuple
from .. import nn
from ..nn import conv2d_alter_layout
from ..generic.conv2d import conv2d_alter_int8_common


@conv2d_alter_layout.register("hexagon")
def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
"""Convert nn.conv2d into nn.contrib_conv2d_nchwc if vrmpy is applicable."""
new_attrs = {k: attrs[k] for k in attrs.keys()}

data_layout = attrs["data_layout"]
kernel_layout = attrs["kernel_layout"]
data_tensor, kernel_tensor = tinfos
out_channel, in_channel, _, _ = get_const_tuple(kernel_tensor.shape)

if (
"int8" in data_tensor.dtype
and "int8" in kernel_tensor.dtype
and out_channel % 32 == 0
and in_channel % 4 == 0
and data_layout == "NCHW"
and kernel_layout == "OIHW"
):
out_channel, in_channel, _, _ = get_const_tuple(kernel_tensor.shape)

n_elems = 4
oc_bn = 32
ic_bn = min(in_channel, 32)

new_attrs = {k: attrs[k] for k in attrs.keys()}

new_attrs["channels"] = out_channel
new_attrs["data_layout"] = "NCHW%dc" % ic_bn
new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems)
new_attrs["out_layout"] = "NCHW%dc" % oc_bn

return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)

return None


@nn.conv2d_legalize.register("hexagon")
def _conv2d_legalize(attrs, inputs, arg_types):
"""Legalize conv2d op for vrmpy tensorization.

If the inputs are signed or unsigned int8, the input and output channels are padded to be
multiples of 4 and 32, respectively.

If the input data types are (int8, int8), they are converted to (uint8, int8) and
the vector-by-vector variant of vrmpy is applied.
If the input data types are (uint8, uint8), the more efficient vector-by-scalar variant of vrmpy
is applied.

Unlike the nn.dense case (see dense_alter_op.py), we do not convert (uint8, int8) to
(uint8, uint8). That would introduce another convolution by a constant (128 or 1) filter
to compensate for the dtype legalization. In the nn.dense case, such a compensation factor
is just a sum over the K axis.

Review comment (PR author): cc @ibsidorenko @tkonolige @nverke on this. We can convert a u8 * s8 convolution to u8 * u8 like below:

W'_u8 = W_s8 + 128
X_u8 * W_s8 = X_u8 * (W'_u8 - 128)
            = X_u8 * W'_u8 - X_u8 * 128

Here, X_u8 * 128 is a convolution of X_u8 by a constant filter. We can factor out 128 to end up with a filter where all elements are 1. So what we need is a windowed sum, or "sum pooling", op - without it, I think we need to do a full-blown convolution. This is why I don't use legalization for conv2d. Let me know if you have a better idea.
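
A minimal numpy sketch of this identity (illustrative, not part of the diff; it uses a plain matmul, i.e. the nn.dense case, where the compensation term is just a sum over the K axis):

import numpy as np

X = np.random.randint(0, 256, (8, 16)).astype(np.int64)     # uint8 activations
W = np.random.randint(-128, 128, (4, 16)).astype(np.int64)  # int8 weights, shape (out, K)

W_shifted = W + 128                                          # shifted into the uint8 range [0, 255]
lhs = X @ W.T                                                # the original u8 * s8 product
rhs = X @ W_shifted.T - 128 * X.sum(axis=1, keepdims=True)   # u8 * u8 product plus compensation
assert (lhs == rhs).all()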

"""
data_layout = attrs["data_layout"]
kernel_layout = attrs["kernel_layout"]

output_tensor = arg_types[2]

data, kernel = inputs

if data_layout != "NCHW" or kernel_layout != "OIHW":
return None

data_tensor, kernel_tensor = arg_types[0], arg_types[1]

if "int8" in data_tensor.dtype and "int8" in data_tensor.dtype:
output_tensor = arg_types[2]
data, kernel = inputs
desired_data_dtype = "uint8"
in_channel_vector_length = 4
out_channel_vector_length = 32

return conv2d_alter_int8_common(
data,
data_tensor,
kernel,
kernel_tensor,
output_tensor,
attrs,
desired_data_dtype,
in_channel_vector_length,
out_channel_vector_length,
)

return None
73 changes: 72 additions & 1 deletion python/tvm/topi/hexagon/dense.py
@@ -14,10 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name
"""Schedule for dense operator"""

import tvm
from tvm.topi.utils import traverse_inline
from tvm import te
from .. import tag
from .tensor_intrin import dot_vrmpy


def schedule_dense(outs):
@@ -38,3 +42,70 @@ def schedule_dense(outs):
s = tvm.te.create_schedule([x.op for x in outs])
tvm.te.schedule.AutoInlineInjective(s)
return s


def dense_u8u8i32_vrmpy_compute(X, packed_w, bias, out_dtype):
"""Compute for uint8 x uint8 -> int32 dense using vrmpy"""
assert X.dtype == "uint8" and packed_w.dtype == "uint8" and out_dtype == "int32"
m, k = X.shape
n_o, _, n_i, _ = packed_w.shape
assert n_i == 32
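# packed_w holds the weight in NC32n4c layout: (N // 32, K // 4, 32, 4).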
ak = te.reduce_axis((0, k), name="k")

C = te.compute(
(m, n_o * n_i),
lambda i, j: te.sum(
X[i, ak].astype("int32")
* packed_w[tvm.tir.indexdiv(j, 32), tvm.tir.indexdiv(ak, 4), j % 32, ak % 4].astype(
"int32"
),
axis=ak,
),
tag="dense_u8u8i32_vrmpy",
name="compute",
)

if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST)

return C


def dense_u8u8i32_vrmpy_schedule(outs):
"""Schedule for vrmpy dense"""
s = te.create_schedule([x.op for x in outs])
# O: The output of the fused op
O = outs[0]

def _schedule_dense(s, C, O):
(a_k,) = C.op.reduce_axis
a_y = C.op.axis[-2]
a_yo, a_yi = s[C].split(a_y, factor=32)
a_xo, a_xi = s[C].split(C.op.axis[-1], factor=32)
a_ko, a_ki = s[C].split(a_k, factor=4)

s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki)
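# Tensorize the innermost 32 (output) x 4 (reduction) block, the shape vrmpy consumes.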

pc = dot_vrmpy("uint8", "uint8")
s[C].tensorize(a_xi, pc)
s[C].parallel(s[C].fuse(a_yo, a_xo))

if C != O:
a_y = O.op.axis[-2]
a_yo, a_yi = s[O].split(a_y, factor=32)
a_xo, a_xi = s[O].split(O.op.axis[-1], factor=32)

s[O].reorder(a_yo, a_xo, a_yi, a_xi)
s[O].vectorize(a_xi)
s[C].compute_at(s[O], a_yi)
s[O].parallel(s[O].fuse(a_yo, a_xo))

def _callback(op):
if "u8u8i32_vrmpy" in op.tag:
# C: The output of GEMM
C = op.output(0)
_schedule_dense(s, C, O)

traverse_inline(s, outs[0].op, _callback)

return s
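
For reference, a small numpy sketch (illustrative, not part of the diff) of the NC32n4c weight packing that dense_u8u8i32_vrmpy_compute above indexes into:

import numpy as np

n, k = 64, 128                                          # N % 32 == 0, K % 4 == 0
W = np.random.randint(0, 256, (n, k), dtype=np.uint8)   # dense weight, shape (N, K)

# NC32n4c packing: shape (N // 32, K // 4, 32, 4), with
# packed_w[jo, ko, ji, ki] == W[jo * 32 + ji, ko * 4 + ki]
packed_w = W.reshape(n // 32, 32, k // 4, 4).transpose(0, 2, 1, 3)

# The compute reads packed_w[j // 32, ak // 4, j % 32, ak % 4], which is exactly W[j, ak],
# so the int32 result matches a plain X @ W.T.
j, ak = 7, 13
assert packed_w[j // 32, ak // 4, j % 32, ak % 4] == W[j, ak]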