Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Definition of ARM CPU operator strategy."""
from functools import reduce
import logging

# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
Expand Down Expand Up @@ -71,6 +72,32 @@ def schedule_pool_arm_cpu(attrs, outs, target):
return topi.generic.schedule_pool(outs, layout)


def _get_padding_width(padding):
assert isinstance(padding, tuple)
if len(padding) == 2:
_, (pad_left, pad_right) = padding
else:
_, pad_left, _, pad_right = padding
return pad_left + pad_right


def _is_simd_aligned(dtype, dimensions, padding=None):
if padding:
assert len(dimensions) == len(padding)
padded_dims = (sum(x) for x in zip(dimensions, padding))
else:
padded_dims = dimensions

# Multiply all elements of padded_dims together. We can't use math.prod, as it
# does not exist in Python 3.7.
size = reduce(lambda x, y: x * y, padded_dims)
return (
(dtype == "int8" and size % 4 == 0)
or (dtype == "int16" and size % 2 == 0)
or (dtype == "int32")
)


@conv2d_strategy.register("arm_cpu")
def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
"""conv2d arm cpu strategy"""
Expand Down Expand Up @@ -159,7 +186,21 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
name="conv2d_hwcn.generic",
)
elif layout == "NHWC":
if target.features.has_dsp and kernel_layout == "HWOI":
data_width_padding = _get_padding_width(padding)
if (
target.features.has_dsp
and dilation_w == dilation_h == 1
and kernel_layout == "OHWI"
# Check SIMD alignment
and _is_simd_aligned(data.dtype, data.shape[2:], padding=(data_width_padding, 0))
and _is_simd_aligned(kernel.dtype, kernel.shape[2:])
):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_ohwi_dsp),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_ohwi_dsp),
name="conv2d_nhwc_ohwi_dsp.arm_cpu",
)
elif target.features.has_dsp and kernel_layout == "HWOI":
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_dsp),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_dsp),
Expand Down Expand Up @@ -199,13 +240,25 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW" or re.match(r"OIHW\d*o", kernel_layout)
# ARM conv2d depthwise schedule
if kernel_layout == "OIHW":
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.arm_cpu",
)
data_width_padding = _get_padding_width(padding)
if (
target.features.has_dsp
and dilation_w == dilation_h == 1
and _is_simd_aligned(data.dtype, data.shape[3:], padding=(data_width_padding,))
and _is_simd_aligned(kernel.dtype, kernel.shape[3:])
):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw_oihw_dsp),
wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw_oihw_dsp),
name="depthwise_conv2d_nchw_oihw_dsp.arm_cpu",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.arm_cpu",
)

# TODO:
# This schedule has incorrect result on some hardware platforms (like NV Jetson TX2)
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/topi/arm_cpu/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
conv2d_nhwc_dsp_compute,
conv2d_nhwc_dsp_schedule,
)
from .mprofile.dsp.tensordot_conv2ds import (
conv2d_nhwc_ohwi_dsp_compute,
tensordot_conv2ds_schedule,
)


@autotvm.register_topi_compute("conv2d_nchw_spatial_pack.arm_cpu")
Expand Down Expand Up @@ -518,3 +522,15 @@ def conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype):
def schedule_conv2d_nhwc_dsp(cfg, outs):
"""Create schedule for conv2d_nhwc_dsp"""
return conv2d_nhwc_dsp_schedule(cfg, outs)


@autotvm.register_topi_compute("conv2d_nhwc_ohwi_dsp.arm_cpu")
def conv2d_nhwc_ohwi_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""Compute conv2d_nhwc_ohwi with v7e-m DSP instructions and the tensordot kernel."""
return conv2d_nhwc_ohwi_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype)


@autotvm.register_topi_schedule("conv2d_nhwc_ohwi_dsp.arm_cpu")
def schedule_conv2d_nhwc_ohwi_dsp(cfg, outs):
"""Create schedule for conv2d_nhwc_ohwi."""
return tensordot_conv2ds_schedule(cfg, outs)
19 changes: 18 additions & 1 deletion python/tvm/topi/arm_cpu/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
from ..nn.utils import get_pad_tuple
from .tensor_intrin import smlal_int16_int32
from .arm_utils import is_aarch64_arm

from .mprofile.dsp.depthwise_conv2d import (
depthwise_conv2d_nhwc_dsp_compute,
depthwise_conv2d_nhwc_dsp_schedule,
)
from .mprofile.dsp.tensordot_conv2ds import (
depthwise_conv2d_nchw_oihw_dsp_compute,
tensordot_conv2ds_schedule,
)


@autotvm.register_topi_compute("depthwise_conv2d_nchw.arm_cpu")
Expand Down Expand Up @@ -718,3 +721,17 @@ def depthwise_conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out
def schedule_depthwise_conv2d_nhwc_dsp(cfg, outs):
"""Create schedule for conv2d_nhwc_dsp"""
return depthwise_conv2d_nhwc_dsp_schedule(cfg, outs)


@autotvm.register_topi_compute("depthwise_conv2d_nchw_oihw_dsp.arm_cpu")
def depthwise_conv2d_nchw_oihw_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""Compute depthwise_conv2d_nchw_oihw with v7e-m DSP instructions and the tensordot kernel."""
return depthwise_conv2d_nchw_oihw_dsp_compute(
cfg, data, kernel, strides, padding, dilation, out_dtype
)


@autotvm.register_topi_schedule("depthwise_conv2d_nchw_oihw_dsp.arm_cpu")
def schedule_depthwise_conv2d_nchw_oihw_dsp(cfg, outs):
"""Create schedule for depthwise_conv2d_nchw_oihw."""
return tensordot_conv2ds_schedule(cfg, outs)
155 changes: 155 additions & 0 deletions python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""

import textwrap

from tvm import te, tir

from .common import num_simd_lanes_per_word


def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
"""Gets the C function name of the tensordot function."""
return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"


def make_intrin_tensordot(slices, strides, tensordot_params):
"""Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
(as multiple schedules use tensordot and each must build the intrinstic differently) but we can
build part here to simplify the code."""

# in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
data, kernel, output = slices
data_strides, kernel_strides = strides

data_buf = tir.decl_buffer(
data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
)
kernel_buf = tir.decl_buffer(
kernel.shape,
kernel.dtype,
name="kernel",
offset_factor=1,
strides=kernel_strides,
)
output_buf = tir.decl_buffer(
output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
)

def intrin_func(ins, outs):
builder = tir.ir_builder.create()
builder.emit(
tir.call_extern(
"int32",
_get_func_name(*tensordot_params),
outs[0].access_ptr("w"),
ins[0].access_ptr("r"),
ins[1].access_ptr("r"),
)
)
return builder.get()

return te.decl_tensor_intrin(
output.op,
intrin_func,
binds={data: data_buf, kernel: kernel_buf, output: output_buf},
)


def tensordot_impl(in_dtype: str, tensor_h: int, jump: int, tensor_w: int, suffix: str) -> str:
"""Generates C code for taking the dot products of two `tensor_h` * `tensor_w` tensors. Also has
a `jump` argument that advances the pointer of one tensor by that many words after each row. The
`jump` and `tensor_w` values must be word-aligned for the input data type, as non-word-aligned
memory access is slow on the Cortex-M series. Depending on the input datatype, the code may
contain DSP instructions for Arm v7e-m. C code contains DSP instructions for Arm v7e-m. See
the below pseudocode for reference:

tensordot(out_ptr, dat_ptr, ker_ptr) {
sum = 0;
for (i = 0; i < tensor_h; i++) {
for (j = 0; j < tensor_w; j++) {
sum += (*dat_ptr++) * (*ker_ptr++);
}
dat_ptr += jump;
}
*out_ptr = sum;
}
"""

simd_lanes = num_simd_lanes_per_word(in_dtype)
assert tensor_w % simd_lanes == 0
assert jump % simd_lanes == 0

if in_dtype == "int8":
inner_loop = """
uint32_t tensor_c20 = __SXTB16(tensor_batch);
uint32_t kernel_c20 = __SXTB16(kernel_batch);
sum = __SMLAD(tensor_c20, kernel_c20, sum);

uint32_t tensor_c31 = __SXTB16(__ROR(tensor_batch, 8));
uint32_t kernel_c31 = __SXTB16(__ROR(kernel_batch, 8));
sum = __SMLAD(tensor_c31, kernel_c31, sum);"""

elif in_dtype == "int16":
inner_loop = """
sum = __SMLAD(tensor_batch, kernel_batch, sum);"""

elif in_dtype == "int32":
inner_loop = """
// Compiles to a single MAC instruction
sum += tensor_batch * kernel_batch;"""

else:
raise ValueError(f"No tensordot implementation exists for dtype '{in_dtype}'!")

function_name = _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix)
return textwrap.dedent(
(
f"""
#include <stdint.h>
#include <arm_nnsupportfunctions.h>

#ifdef __cplusplus
extern "C"
#endif
__STATIC_FORCEINLINE int32_t {function_name}(
uint32_t *out,
uint32_t *tensor,
uint32_t *kernel) {{

uint32_t sum = 0;

#pragma GCC unroll {tensor_h}
for (int i = 0; i < {tensor_h}; i++) {{
#pragma GCC unroll {tensor_w // simd_lanes}
for (int j = 0; j < {tensor_w // simd_lanes}; j++) {{
uint32_t tensor_batch = *tensor++;
uint32_t kernel_batch = *kernel++;
{inner_loop.strip()}
}}
tensor += {jump // simd_lanes};
}}
out[0] = sum;
return 0;
}}
"""
)
)
Loading