22 changes: 13 additions & 9 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -236,20 +236,24 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
name="depthwise_conv2d_nhwc.arm_cpu",
)

# Optimized special case depthwiseConv2D operation. Requires a 3x3 kernel, a
# NHWC layout, a HWOI kernel layout (which we rearrange), no dilation, int8 inputs,
# int32 output, the same number of input and output channels, and for that channel
# count to be divisible by 4. Additional work could remove these restrictions.
# Optimized special case depthwiseConv2D operation. Requires an NHWC layout,
# an HWOI kernel layout (which we rearrange to a custom layout), no dilation,
# int8/int16 inputs, an int32 output, and the same number of input and output
# channels. The int8 implementation DOES need the DSP unit (for SXTB16), but it
# is not possible to use the DSP unit to speed up an NHWC depthwise convolution
# (though an NCHW convolution would benefit).

elif (
target.features.has_dsp
and kernel.shape[0] == kernel.shape[1] == 3
and dilation_w == dilation_h == 1
dilation_w == dilation_h == 1
and kernel.shape[3] == 1 # channel_multiplier == 1
and data.dtype == "int8"
and out_type.dtype == "int32"
and data.shape[3] % 4 == 0
and (
(data.shape[3] % 4 == 0 and data.dtype == "int8" and target.features.has_dsp)
or (data.shape[3] % 2 == 0 and data.dtype == "int16")
)
and (padding != "SAME" or data.shape[1] % stride_h == data.shape[2] % stride_w == 0)
# Ideally we should check that kernel is a Relay constant, but strategy functions
# don't have access to the data needed to check this.
):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nhwc_dsp),
38 changes: 37 additions & 1 deletion python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -19,6 +19,8 @@

import logging

import numpy as np

import tvm
from tvm import te
from tvm import relay
@@ -31,6 +33,7 @@
from .conv2d_int8 import is_int8_hw_support
from .arm_utils import get_tiling_B_interleaved_t
from ..generic.conv2d import conv2d_alter_int8_common
from .mprofile.dsp.micro_kernel.common import num_simd_lanes_per_word

logger = logging.getLogger("topi")

@@ -121,7 +124,40 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):

idxd = tvm.tir.indexdiv

# We don't perform layout alteration for NHWC layout with real data types
if topi_tmpl == "depthwise_conv2d_nhwc_dsp.arm_cpu":
assert data_layout == "NHWC" and kernel_layout == "HWOI"

# We are not able to check if inputs[1] (the kernel) is a constant in the
# strategy function, so as a stopgap solution we use an assert here.
assert isinstance(
inputs[1], relay.Constant
), "depthwise_conv2d_nhwc_dsp.arm_cpu requires kernel be a relay Constant"

channels = get_const_tuple(data.shape)[3]
KH, KW, _, _ = get_const_tuple(kernel.shape)
simd_lanes = num_simd_lanes_per_word(data.dtype)

HWOI_kernel_np = inputs[1].data.numpy()

Review thread on this line:

Contributor: You'll need to check that the kernel is a constant and fall back to a different implementation if it is not.

Member Author: I'm not sure how easy it is to check if the kernel is a constant from python/tvm/relay/op/strategy/arm_cpu.py, but you're right that it is a thing we should check. I've added an assertion, though it is a bit of a stopgap solution.

Contributor: Could you add a message to the assert and a comment about what needs to be done to not make it a stopgap.

Member Author: Unfortunately, a clean solution is hard, as the strategy function does not have access to the needed information. When conv2d_alter_op is called, inputs[1] (the kernel) has the form:

    meta[relay.Constant][0] /* ty=Tensor[(3, 3, 3, 8), int16] */

However, when the Relay strategy functions are called, inputs[1] (the kernel) looks like:

    Tensor(shape=[3, 3, 8, 1], op.name=placeholder)

Nowhere inside relay/op/strategy do any of the strategy functions check whether the relevant tensors are constant, so there's not much we can do. I've added comments explaining this, but please let me know if you have ideas for how this could be done.

CHWc_kernel_np = np.zeros((channels // simd_lanes, KH, KW, simd_lanes), dtype=kernel.dtype)
for i in range(channels // simd_lanes):
CHWc_kernel_np[i] = HWOI_kernel_np[:, :, simd_lanes * i : simd_lanes * (i + 1), 0]
reshaped_new_kernel = CHWc_kernel_np.reshape((KH, KW, channels, 1))

# Store the same config for the altered operator (workload)
new_data = data
new_kernel = te.placeholder((KH, KW, channels, 1), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, out_dtype],
"depthwise_conv2d_nhwc_dsp.arm_cpu",
)
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.conv2d(
inputs[0],
relay.Constant(tvm.nd.array(reshaped_new_kernel)),
**new_attrs,
)

# Only microTVM does layout alteration for NHWC layout with real data types
if data_layout == "NHWC" and data_dtype not in ["uint8", "int8"]:
return None

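
The new depthwise_conv2d_nhwc_dsp branch above packs the HWOI kernel into a channel-blocked layout (blocks of simd_lanes channels, each block stored HWc-contiguously) and then reshapes the buffer back to the original (KH, KW, channels, 1) shape, so Relay's type checking still sees an HWOI constant while the memory order is what the DSP micro-kernel expects. The following is a minimal standalone numpy sketch of that rearrangement, not part of the patch; the 3x3 kernel with 4 channels of int16 data (so simd_lanes == 2) is a hypothetical example.

import numpy as np

KH, KW, channels, simd_lanes = 3, 3, 4, 2

# Label each entry with 100*h + 10*w + c so the final memory order is easy to read.
hwoi_kernel = np.zeros((KH, KW, channels, 1), dtype="int16")
for h in range(KH):
    for w in range(KW):
        for c in range(channels):
            hwoi_kernel[h, w, c, 0] = 100 * h + 10 * w + c

# Group the channels into blocks of simd_lanes, keeping each block contiguous in memory.
chwc_kernel = np.zeros((channels // simd_lanes, KH, KW, simd_lanes), dtype=hwoi_kernel.dtype)
for i in range(channels // simd_lanes):
    chwc_kernel[i] = hwoi_kernel[:, :, simd_lanes * i : simd_lanes * (i + 1), 0]

# Reinterpret the rearranged buffer with the original HWOI shape; the reshape does not
# reorder memory, so both buffers flatten to the same sequence of values.
reshaped_new_kernel = chwc_kernel.reshape((KH, KW, channels, 1))
assert (chwc_kernel.ravel() == reshaped_new_kernel.ravel()).all()

# First six halfwords: channel pair (0, 1) for kernel positions (0,0), (0,1), (0,2).
print(chwc_kernel.ravel()[:6])  # [ 0  1 10 11 20 21]

Under the same assumptions, the int8 case works identically with simd_lanes == 4, which is why the strategy above requires the channel count to be divisible by 4 for int8 and by 2 for int16.
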
126 changes: 24 additions & 102 deletions python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py
@@ -19,84 +19,15 @@
import random
import string

from tvm import te
from tvm.topi.utils import traverse_inline, get_const_tuple
from tvm import te, topi
from tvm.topi.utils import traverse_inline
from tvm.topi.nn.pad import pad
from tvm import tir

from .micro_kernel.quad_channel_convolve import (
intrin_quad_channel_convolve,
quad_channel_convolve_impl,
from .micro_kernel.multi_channel_convolve import (
intrin_multi_channel_convolve,
multi_channel_convolve_impl,
)

# For depthwise_conv2d, kernels are normally given in HWOI format,
# which, when input_channels == output_channels, we will call HWC.
# This is bad, as we want "related" parts of the kernel to be next
# to each other, so we can use __SMLAD later.
#
# Consider a 3x3 int8 kernel with no bias vector, with eight
# channels. Let us specify entries in the kernel as H_W_C - i.e.
# where 0_2_3 represents the rightmost position in the first row
# of channel 4/8 (4 because of zero indexing). Each [ ] represents
# a 32-bit integer. We currently store the kernel as:
#
# 0 ................................31
# [ 0_0_0 || 0_0_1 || 0_0_2 || 0_0_3 ] [ 0_0_4 || 0_0_5 || 0_0_6 || 0_0_7 ]
# [ 0_1_0 || 0_1_1 || 0_1_2 || 0_1_3 ] [ 0_1_4 || 0_1_5 || 0_1_6 || 0_1_7 ]
# [ 0_2_0 || 0_2_1 || 0_2_2 || 0_2_3 ] [ 0_2_4 || 0_2_5 || 0_2_6 || 0_2_7 ]
# [ 1_0_0 || 1_0_1 || 1_0_2 || 1_0_3 ] [ 1_0_4 || 1_0_5 || 1_0_6 || 1_0_7 ]
# [ 1_1_0 || 1_1_1 || 1_1_2 || 1_1_3 ] [ 1_1_4 || 1_1_5 || 1_1_6 || 1_1_7 ]
# [ 1_2_0 || 1_2_1 || 1_2_2 || 1_2_3 ] [ 1_2_4 || 1_2_5 || 1_2_6 || 1_2_7 ]
# [ 2_0_0 || 2_0_1 || 2_0_2 || 2_0_3 ] [ 2_0_4 || 2_0_5 || 2_0_6 || 2_0_7 ]
# [ 2_1_0 || 2_1_1 || 2_1_2 || 2_1_3 ] [ 2_1_4 || 2_1_5 || 2_1_6 || 2_1_7 ]
# [ 2_2_0 || 2_2_1 || 2_2_2 || 2_2_3 ] [ 2_2_4 || 2_2_5 || 2_2_6 || 2_2_7 ]
#
# Let 0x00 be all zeros. We rearrange into:
#
# 0 ................................31
# [ 0_0_0 || 0_0_1 || 0_1_0 || 0_1_1 ] [ 0_0_2 || 0_0_3 || 0_1_2 || 0_1_3 ]
# [ 0_2_0 || 0_2_1 || 1_0_0 || 1_0_1 ] [ 0_2_2 || 0_2_3 || 1_0_2 || 1_0_3 ]
# [ 1_1_0 || 1_1_1 || 1_2_0 || 1_2_1 ] [ 1_1_2 || 1_1_3 || 1_2_2 || 1_2_3 ]
# [ 2_0_0 || 2_0_1 || 2_1_0 || 2_1_1 ] [ 2_0_2 || 2_0_3 || 2_1_2 || 2_1_3 ]
# [ 2_2_0 || 2_2_1 || 0x000 || 0x000 ] [ 2_2_2 || 2_2_3 || 0x000 || 0x000 ]
# [ 0_0_4 || 0_0_5 || 0_1_4 || 0_1_5 ] [ 0_0_6 || 0_0_7 || 0_1_6 || 0_1_7 ]
# [ 0_2_4 || 0_2_5 || 1_0_4 || 1_0_5 ] [ 0_2_6 || 0_2_7 || 1_0_6 || 1_0_7 ]
# [ 1_1_4 || 1_1_5 || 1_2_4 || 1_2_5 ] [ 1_1_6 || 1_1_7 || 1_2_6 || 1_2_7 ]
# [ 2_0_4 || 2_0_5 || 2_1_4 || 2_1_5 ] [ 2_0_6 || 2_0_7 || 2_1_6 || 2_1_7 ]
# [ 2_2_4 || 2_2_5 || 0x000 || 0x000 ] [ 2_2_6 || 2_2_7 || 0x000 || 0x000 ]
#
# This saves us six operations compared to the original ordering, as we
# do not need halfword packing instructions.
#
# This kernel re-arranging function will be used for 3x3 kernels (as that
# is all this DSP implementation currently supports) but would work with
# any M*N kernel such that M*N is odd.


def _rearrange_kernel(kernel):
# Kernel must be HWC format.
kernel_h, kernel_w, channels, _ = get_const_tuple(kernel.shape)
assert channels % 4 == 0

# This restriction could be removed by only using tir.if_then_else to add padding
# zeros if (kernel_w * kernel_h) % 2 == 1, and filling completely otherwise.
assert (kernel_w * kernel_h) % 2 == 1

def fcompute(c_o, pos, c_i):
channel = (2 * (pos % 2)) + (c_i % 2) + (4 * c_o)
true_pos_index = 2 * (pos // 2) + (c_i // 2)

return tir.if_then_else(
true_pos_index < (kernel_h * kernel_w),
kernel[true_pos_index // kernel_w, true_pos_index % kernel_w, channel, 0],
tir.const(0, "int8"),
)

return te.compute(
(channels // 4, kernel_h * kernel_w + 1, 4),
fcompute,
name="packed_kernel",
)
from .micro_kernel.common import num_simd_lanes_per_word


def depthwise_conv2d_nhwc_dsp_compute(_cfg, data, kernel, strides, padding, dilation, out_dtype):
@@ -120,10 +51,7 @@ def depthwise_conv2d_nhwc_dsp_compute(_cfg, data, kernel, strides, padding, dilation, out_dtype):

batch_size, height, width, channels = data.shape
kernel_h, kernel_w, _, _ = kernel.shape

# We require that the number of channels be divisible by 4. This restriction could
# be removed with strip mining if people cared.
assert channels % 4 == 0
simd_lanes = num_simd_lanes_per_word(data.dtype)

# We don't support different numbers of input and output channels.
assert channels == kernel.shape[2]
@@ -133,11 +61,6 @@ def depthwise_conv2d_nhwc_dsp_compute(_cfg, data, kernel, strides, padding, dilation, out_dtype):
# round until we compute activations.
assert out_dtype == "int32"

# This can pretty easily be generalized in the future. Likely worth doing, and this
# function was written to make doing so easy. Should only require adding more calls
# to QUAD_CHANNEL_REARRANGE_SUM.
assert kernel_w == kernel_h == 3

# Padding the data requires COPYING THE ENTIRE INPUT TENSOR, which
# is slow and bad. We should really implement a strip mining
# routine to avoid this, but TVM has terrible support for that.
@@ -188,18 +111,14 @@ def depthwise_conv2d_nhwc_dsp_compute(_cfg, data, kernel, strides, padding, dilation, out_dtype):
raise RuntimeError()
_, padded_h, padded_w, _ = padded_data.shape

packed_kernel = _rearrange_kernel(kernel)
kh_i = te.reduce_axis((0, kernel_h), name="kh_i")
kw_i = te.reduce_axis((0, kernel_w), name="kw_i")
reshaped_kernel = topi.reshape(kernel, (channels // simd_lanes, kernel_h, kernel_w, simd_lanes))
return te.compute(
(batch_size, output_h, output_w, channels),
lambda h, i, j, k: te.sum(
padded_data[h, (i * stride_h) + kh_i, (j * stride_w) + kw_i, k].astype("int32")
* packed_kernel[
k // 4,
(2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2),
(2 * ((kh_i + kw_i) % 2)) + (k % 2),
].astype("int32"),
* reshaped_kernel[k // simd_lanes, kh_i, kw_i, k % simd_lanes].astype("int32"),
axis=(kh_i, kw_i),
),
name="depthwise_conv2d",
@@ -212,33 +131,36 @@ def depthwise_conv2d_nhwc_dsp_schedule(_cfg, outs):
"""Schedule function for v7e-m DSP instructions of conv2d."""
schedule = te.create_schedule([x.op for x in outs])

def _callback(op):
if "depthwise_conv2d_nhwc" not in op.tag:
def _callback(operator):
if "depthwise_conv2d_nhwc" not in operator.tag:
return

# extract tensors
output = op.output(0)
output = operator.output(0)
padded_data = output.op.input_tensors[0]
packed_kernel = output.op.input_tensors[1]
kernel = packed_kernel.op.input_tensors[0]
reshaped_kernel = output.op.input_tensors[1]
in_dtype = padded_data.dtype

_, _, padded_w, channels = padded_data.shape
kernel_h, kernel_w, _, _ = kernel.shape
_, padded_h, padded_w, channels = padded_data.shape
_, kernel_h, kernel_w, _ = reshaped_kernel.shape
suffix = "".join(random.choices(string.ascii_uppercase, k=8))

b_ax, y_ax, x_ax, c_ax = schedule[output].op.axis
ky_ax, kx_ax = schedule[output].op.reduce_axis
c_ax_o, c_ax_i = schedule[output].split(c_ax, factor=4)
simd_lanes = num_simd_lanes_per_word(in_dtype)
c_ax_o, c_ax_i = schedule[output].split(c_ax, factor=simd_lanes)
schedule[output].reorder(b_ax, c_ax_o, y_ax, x_ax, ky_ax, kx_ax, c_ax_i)

quad_channel_convolve = intrin_quad_channel_convolve(
padded_w, channels, kernel_h, kernel_w, suffix
multi_channel_convolve = intrin_multi_channel_convolve(
in_dtype, padded_h, padded_w, channels, kernel_h, kernel_w, suffix
)
schedule[output].tensorize(ky_ax, quad_channel_convolve)
schedule[output].tensorize(ky_ax, multi_channel_convolve)
schedule[output].pragma(
b_ax,
"import_c",
quad_channel_convolve_impl(padded_w, channels, kernel_h, kernel_w, suffix),
multi_channel_convolve_impl(
in_dtype, padded_h, padded_w, channels, kernel_h, kernel_w, suffix
),
)

traverse_inline(schedule, outs[-1].op, _callback)
15 changes: 15 additions & 0 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/common.py
@@ -29,3 +29,18 @@
#include <tvm/runtime/crt/error_codes.h>

"""

MICRO_WORD_LENGTH_BITS = 32


def num_simd_lanes_per_word(dtype: str) -> int:
"""Takes a dtype, and returns how many of that dtype fit into a single microcontroller word.

>>> num_simd_lanes_per_word("int8")
4
>>> num_simd_lanes_per_word("int16")
2
"""
assert dtype.startswith("int")
dtype_width = int(dtype[3:])
return MICRO_WORD_LENGTH_BITS // dtype_width