From 278c20203486cf798157b24496e401ed226e5a68 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Mon, 3 Oct 2022 02:05:25 -0700 Subject: [PATCH 01/18] Rewrite conv2D to tensorize with tensordot --- .../tvm/topi/arm_cpu/mprofile/dsp/conv2d.py | 162 +++++++---------- .../mprofile/dsp/micro_kernel/tensordot.py | 163 ++++++++++++++++++ 2 files changed, 224 insertions(+), 101 deletions(-) create mode 100644 python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py index 470d46b92a7a..2664c52faff5 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py @@ -25,32 +25,14 @@ from tvm.topi.nn.utils import get_pad_tuple from tvm.tir.expr import Mul -from .micro_kernel.gemm import ( - intrin_gemm_MxKxN, - gemm_MxKxN_impl, +from .micro_kernel.tensordot import ( + make_intrin_tensordot, + tensordot_impl, ) -def conv2d_nhwc_dsp(*args, **kwargs): - """Defines the v7e-m DSP instructions of conv2d.""" - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - data, kernel = args[:2] - layout = args[-2] - cfg = autotvm.get_config() - args = [cfg] + args - assert layout == "NHWC" - conv = conv2d_nhwc_dsp_compute(*args) - sched = conv2d_nhwc_dsp_schedule(cfg, [data, kernel, conv]) - return sched, [data, kernel, conv] - - -conv2d_nhwc_dsp.template_key = "dsp" -conv2d_nhwc_dsp.default_data_layout = "NHWC" -conv2d_nhwc_dsp.default_kernel_layout = "HWOI" - - def conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): + print("Activating Conv2D NHWC schedule") """Compute function for v7e-m DSP instructions of conv2d.""" assert isinstance(strides, int) or len(strides) == 2 assert isinstance(dilation, int) or len(dilation) == 2 @@ -66,7 +48,8 @@ def conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_d dilation_h, dilation_w = dilation batch_size, in_height, in_width, in_channels = data.shape - kernel_h, kernel_w, out_channels, _ = kernel.shape + out_channels, kernel_h, kernel_w, _ = kernel.shape + assert kernel.shape[3] == in_channels # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 @@ -81,9 +64,9 @@ def conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_d pad_after = [0, pad_down, pad_right, 0] padded_data = pad(data, pad_before, pad_after, name="padded_data") - rc = te.reduce_axis((0, in_channels), name="rc") ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") + rc = te.reduce_axis((0, in_channels), name="rc") conv = te.compute( (batch_size, out_height, out_width, out_channels), @@ -91,104 +74,81 @@ def conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_d padded_data[ nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc ].astype(out_dtype) - * kernel[ry, rx, ff, rc].astype(out_dtype), + * kernel[ff, ry, rx, rc].astype(out_dtype), axis=[ry, rx, rc], ), name="conv2d", tag="conv2d_nhwc", ) - ########################### - # Config Space Definition # - ########################### - n, oh, ow, co = ( - cfg.axis(batch_size.value), - cfg.axis(out_height.value), - cfg.axis(out_width.value), - cfg.axis(out_channels.value), - ) - kh, kw, ci = ( - cfg.reduce_axis(kernel_h.value), - cfg.reduce_axis(kernel_w.value), - cfg.reduce_axis(in_channels.value), + return conv + + +def _make_tensorization(padded_data, kernel, suffix): + _, padded_h, padded_w, in_channels = padded_data.shape + _, kernel_h, kernel_w, _ = kernel.shape + + data_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="a", dtype=in_dtype) + kernel_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="b", dtype=in_dtype) + + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + kc_i = te.reduce_axis((0, in_channels), name="kc_i") + + output_slice = te.compute((1,), + lambda k: te.sum( + data_slice[kh_i, kw_i, kc_i].astype("int32") + * kernel_slice[kh_i, kw_i, kc_i].astype("int32"), + axis=[kh_i, kw_i, kc_i], + ), + name="c", ) - owo, owi = cfg.define_split("tile_ow", ow, policy="factors", num_outputs=2) - cio, cii = cfg.define_split( - "tile_ci", - ci, - policy="factors", - num_outputs=2, - # TODO: check case with in_channels.value % 4 != 0 with AutoTVM - filter=None if cfg.is_fallback else lambda x: x.size[-1] % 4 == 0, + data_buf = tir.decl_buffer( + data_slice.shape, data_slice.dtype, name="data", offset_factor=1, + strides=[tensor_w * in_channels, in_channels, 1], + ) + kernel_buf = tir.decl_buffer( + kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, + strides=[kernel_w * in_channels, in_channels, 1] ) - coo, coi = cfg.define_split("tile_co", co, policy="factors", num_outputs=2) - - cfg.define_reorder( - "reorder_0_simd", - [n, oh, owo, owi, coo, coi, kh, kw, cio, cii], - policy="candidate", - candidate=[ - [n, oh, kh, kw, owo, coo, cio, owi, coi, cii], - [n, oh, kh, kw, coo, owo, cio, owi, coi, cii], - [n, kh, kw, oh, owo, coo, cio, owi, coi, cii], - [n, kh, kw, oh, coo, owo, cio, owi, coi, cii], - ], + output_buf = tir.decl_buffer( + output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] ) - cfg.define_knob("auto_unroll_max_step", [0, 2, 4, 8, 16, 32]) - cfg.define_knob("unroll_explicit", [0, 1]) + jump = (tensor_w - kernel_w) * in_channels + tensordot_params = (kernel_h, jump, kernel_w * in_channels, suffix) - if cfg.is_fallback: - cfg.fallback_split("tile_ow", [-1, out_width.value]) - cfg.fallback_split("tile_ci", [-1, in_channels.value]) - cfg.fallback_split("tile_co", [-1, out_channels.value]) + intrin_tensordot = make_intrin_tensordot( + output_slice.op, + (data_buf, kernel_buf, output_buf), + tensordot_params + ) - return conv + tensordot_code = tensordot_impl(tensordot_params) + return (intrin_tensordot, tensordot_code) def conv2d_nhwc_dsp_schedule(cfg, outs): """Schedule function for v7e-m DSP instructions of conv2d.""" - sched = te.create_schedule([x.op for x in outs]) + schedule = te.create_schedule([x.op for x in outs]) - def _callback(op): - if "conv2d_nhwc" not in op.tag: + def _callback(operator): + if "conv2d_nhwc" not in operator.tag: return # extract tensors - output = op.output(0) - conv = op - data_vec = conv.input_tensors[0] - kernel = conv.input_tensors[1] # pylint: disable=unused-variable - last = outs[0] # pylint: disable=unused-variable - - source_index_w = output.op.body[0].source[0].a.value.indices[2].a - stride_w = source_index_w.b.value if isinstance(source_index_w, Mul) else 1 - - # tile reduction axes - n, oh, ow, co = sched[conv].op.axis - kh, kw, ci = sched[conv].op.reduce_axis - - M = cfg["tile_ow"].size[-1] - K = cfg["tile_ci"].size[-1] - N = cfg["tile_co"].size[-1] - - owo, owi = cfg["tile_ow"].apply(sched, conv, ow) - cio, cii = cfg["tile_ci"].apply(sched, conv, ci) - coo, coi = cfg["tile_co"].apply(sched, conv, co) - - cfg["reorder_0_simd"].apply(sched, conv, [n, oh, owo, owi, coo, coi, kh, kw, cio, cii]) - - gemm, uniq_id = intrin_gemm_MxKxN(M, K, N, data_vec.dtype, output.dtype, stride_w) - sched[output].tensorize(owi, gemm) - sched[output].pragma(n, "import_c", gemm_MxKxN_impl(M, K, N, uniq_id)) + output = operator.output(0) + padded_data = output.op.input_tensors[0] + kernel = output.op.input_tensors[1] - # this is the scope to attach global config inside this kernel - kernel_scope = n + b_ax, y_ax, x_ax, co_ax = schedule[output].op.axis + kh_ax, kw_ax, ci_ax = schedule[output].op.reduce_axis + schedule[output].reorder(b_ax, y_ax, x_ax, co_ax, kh_ax, kw_ax, ci_ax) - # tune unroll - sched[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) - sched[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) + intrin, code = _make_tensorization(padded_data, kernel, suffix) + schedule[output].tensorize(kh_ax, intrin) + schedule[output].pragma(n, "import_c", code) - traverse_inline(sched, outs[-1].op, _callback) - return sched + traverse_inline(schedule, outs[-1].op, _callback) + return schedule diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py new file mode 100644 index 000000000000..d58d3e536de5 --- /dev/null +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Computes a "jumpy tensordot" operator, which can/should be used to +tensorize ANY aritrarily grouped conv2D, provided the data and kernel +layouts are the optimal ones. When groups=1, the optimal data layout +is NHWC and kernel layout is OHWI. When this is a depthwise convolution, +the optimal data layout is NCHW and kernel layout is OIHW.""" + +import textwrap + +from tvm import te, tir + + +def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix): + """Gets the C function name of the tensorized function.""" + return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}" + + +def make_intrin_tensordot(operator, buffers, tensordot_params): + data, kernel, output = buffers + #tensor_h, jump, tensor_w, suffix = tensordot_params + + def intrin_func(ins, outs): + builder = tir.ir_builder.create() + builder.emit( + tir.call_extern( + "int32", + _get_func_name(data.dtype, *tensordot_params), + outs[0].access_ptr("w"), + ins[0].access_ptr("r"), + ins[1].access_ptr("r"), + ) + ) + return builder.get() + + return te.decl_tensor_intrin( + operator, + intrin_func, + binds={data_slice: data, kernel_slice: kernel, output_slice: output}, + ) + + +def intrin_depthwise_conv2d_tensordot(in_dtype, tensor_w, kernel_h, kernel_w, suffix): + data_slice = te.placeholder((kernel_h, kernel_w), name="a", dtype=in_dtype) + kernel_slice = te.placeholder((kernel_h, kernel_w), name="b", dtype=in_dtype) + + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + + output_slice = te.compute((1,), + lambda k: te.sum( + data_slice[kh_i, kw_i].astype("int32") + * kernel_slice[kh_i, kw_i].astype("int32"), + axis=[kh_i, kw_i], + ), + name="c", + ) + + data_buf = tir.decl_buffer( + data_slice.shape, + data_slice.dtype, + name="data", + offset_factor=1, + strides=[tensor_w, 1], + ) + kernel_buf = tir.decl_buffer( + kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, strides=[kernel_w, 1] + ) + output_buf = tir.decl_buffer( + output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] + ) + + def intrin_func(ins, outs): + builder = tir.ir_builder.create() + builder.emit( + tir.call_extern( + "int32", + _get_func_name(in_dtype, tensor_w, channels, kernel_h, kernel_w, suffix), + outs[0].access_ptr("w"), + ins[0].access_ptr("r"), + ins[1].access_ptr("r"), + ) + ) + return builder.get() + + return te.decl_tensor_intrin( + output_slice.op, + intrin_func, + binds={data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf}, + ) + + +def tensordot_impl(in_dtype, tensor_h, jump, tensor_w, suffix): + assert in_dtype in ["int8", "int16", "int32"] + simd_lanes = num_simd_lanes_per_word(in_dtype) + assert tensor_w % simd_lanes == 0 + assert jump % simd_lanes == 0 + + if in_dtype == "int8": + inner_loop = """ + uint32_t tensor_c20 = __SXTB16(tensor_batch); + uint32_t kernel_c20 = __SXTB16(kernel_batch); + sum = __SMLAD(tensor_c20, kernel_c20, sum); + + uint32_t tensor_c31 = __SXTB16(__ROR(tensor_batch, 8)); + uint32_t kernel_c31 = __SXTB16(__ROR(kernel_batch, 8)); + sum = __SMLAD(tensor_c31, kernel_c31, sum);""" + + elif in_dtype == "int16": + inner_loop = """ + sum = __SMLAD(tensor_batch, kernel_batch, sum);""" + + elif in_dtype == "int32": + inner_loop = """ + sum = __MLA(tensor_batch, kernel_batch, sum);""" + + function_name = _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix) + return textwrap.dedent( + ( + f""" + #include + #include + + #ifdef __cplusplus + extern "C" + #endif + __STATIC_FORCEINLINE int32_t {function_name}( + uint32_t *out, + uint32_t *tensor, + uint32_t *kernel) {{ + + uint32_t sum = 0; + + #pragma GCC unroll {tensor_h} + for (int i = 0; i < {tensor_h}; i++) {{ + #pragma GCC unroll {length // simd_width} + for (int j = 0; j < {length // simd_width}; i++) {{ + uint32_t tensor_batch = *tensor++; + uint32_t kernel_batch = *kernel++; + {inner_loop.trim()} + }} + tensor += {jump // simd_width}; + }} + out[0] = sum; + return 0; + }} + """ + ) + ) From c645428599f742e9a22b3b11dcb1c9c4ad746153 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Mon, 3 Oct 2022 04:34:27 -0700 Subject: [PATCH 02/18] Functional conv2D tensordot implementation --- .../tvm/topi/arm_cpu/mprofile/dsp/conv2d.py | 25 +++++++++------- .../mprofile/dsp/micro_kernel/tensordot.py | 20 ++++++------- python/tvm/topi/utils.py | 29 +++++++++++++++++-- 3 files changed, 52 insertions(+), 22 deletions(-) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py index 2664c52faff5..17c6f2aaccba 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py @@ -17,9 +17,11 @@ # pylint: disable=invalid-name, no-value-for-parameter """Direct implementation of conv2d.""" -from tvm import autotvm +import random +import string + +from tvm import autotvm, te, tir from tvm.autotvm.task import deserialize_args -from tvm import te from tvm.topi.utils import simplify, traverse_inline from tvm.topi.nn.pad import pad from tvm.topi.nn.utils import get_pad_tuple @@ -84,9 +86,12 @@ def conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_d return conv -def _make_tensorization(padded_data, kernel, suffix): +def _make_tensorization(padded_data, kernel): _, padded_h, padded_w, in_channels = padded_data.shape _, kernel_h, kernel_w, _ = kernel.shape + in_dtype = padded_data.dtype + suffix = "".join(random.choices(string.ascii_uppercase, k=8)) + assert in_dtype == kernel.dtype data_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="a", dtype=in_dtype) kernel_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="b", dtype=in_dtype) @@ -106,7 +111,7 @@ def _make_tensorization(padded_data, kernel, suffix): data_buf = tir.decl_buffer( data_slice.shape, data_slice.dtype, name="data", offset_factor=1, - strides=[tensor_w * in_channels, in_channels, 1], + strides=[padded_w * in_channels, in_channels, 1], ) kernel_buf = tir.decl_buffer( kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, @@ -116,16 +121,16 @@ def _make_tensorization(padded_data, kernel, suffix): output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] ) - jump = (tensor_w - kernel_w) * in_channels - tensordot_params = (kernel_h, jump, kernel_w * in_channels, suffix) + jump = (padded_w - kernel_w) * in_channels + tensordot_params = (in_dtype, kernel_h, jump, kernel_w * in_channels, suffix) intrin_tensordot = make_intrin_tensordot( output_slice.op, - (data_buf, kernel_buf, output_buf), + {data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf}, tensordot_params ) - tensordot_code = tensordot_impl(tensordot_params) + tensordot_code = tensordot_impl(*tensordot_params) return (intrin_tensordot, tensordot_code) @@ -146,9 +151,9 @@ def _callback(operator): kh_ax, kw_ax, ci_ax = schedule[output].op.reduce_axis schedule[output].reorder(b_ax, y_ax, x_ax, co_ax, kh_ax, kw_ax, ci_ax) - intrin, code = _make_tensorization(padded_data, kernel, suffix) + intrin, code = _make_tensorization(padded_data, kernel) schedule[output].tensorize(kh_ax, intrin) - schedule[output].pragma(n, "import_c", code) + schedule[output].pragma(b_ax, "import_c", code) traverse_inline(schedule, outs[-1].op, _callback) return schedule diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py index d58d3e536de5..72c8b432d7b5 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py @@ -24,22 +24,22 @@ from tvm import te, tir +from .common import num_simd_lanes_per_word def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix): """Gets the C function name of the tensorized function.""" return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}" -def make_intrin_tensordot(operator, buffers, tensordot_params): - data, kernel, output = buffers - #tensor_h, jump, tensor_w, suffix = tensordot_params +def make_intrin_tensordot(operator, binds, tensordot_params): + #in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params def intrin_func(ins, outs): builder = tir.ir_builder.create() builder.emit( tir.call_extern( "int32", - _get_func_name(data.dtype, *tensordot_params), + _get_func_name(*tensordot_params), outs[0].access_ptr("w"), ins[0].access_ptr("r"), ins[1].access_ptr("r"), @@ -50,7 +50,7 @@ def intrin_func(ins, outs): return te.decl_tensor_intrin( operator, intrin_func, - binds={data_slice: data, kernel_slice: kernel, output_slice: output}, + binds=binds, ) @@ -100,7 +100,7 @@ def intrin_func(ins, outs): return te.decl_tensor_intrin( output_slice.op, intrin_func, - binds={data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf}, + binds=binings, ) @@ -147,13 +147,13 @@ def tensordot_impl(in_dtype, tensor_h, jump, tensor_w, suffix): #pragma GCC unroll {tensor_h} for (int i = 0; i < {tensor_h}; i++) {{ - #pragma GCC unroll {length // simd_width} - for (int j = 0; j < {length // simd_width}; i++) {{ + #pragma GCC unroll {tensor_w // simd_lanes} + for (int j = 0; j < {tensor_w // simd_lanes}; j++) {{ uint32_t tensor_batch = *tensor++; uint32_t kernel_batch = *kernel++; - {inner_loop.trim()} + {inner_loop.strip()} }} - tensor += {jump // simd_width}; + tensor += {jump // simd_lanes}; }} out[0] = sum; return 0; diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index f1c6fb5aa4f4..a8f215dce8bf 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -22,9 +22,8 @@ import numpy as np import tvm -from tvm import te +from tvm import relay, te from tvm.tir import bijective_layout, layout - from . import cpp, tag @@ -432,6 +431,32 @@ def get_shape(src_shape, src_layout, dst_layout): return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices])) +def change_constant_shape(src, src_layout, dst_layout): + """Makes a copy of a Relay constant, reshaping it to a new data layout. + + Parameter + --------- + src : relay.Constant + The Constant to be reformatted. + + src_layout : str + The current layout of the Relay constant. Must be alphabetic (e.g. NHWC + or OIHW, but not NCHW2c). + + dst_layout : str + The desired layout of new the Relay constant. Must be alphabetic (e.g. NHWC + or OIHW, but not NCHW2c). + + Returns + ------- + dst_shape : relay.Constant + A copy of the Constant with the new layout. + """ + assert src_layout.isalpha() and dst_layout.isalpha() + axis_order = [src_layout.index(c) for c in dst_layout] + reshaped = np.transpose(src.data.numpy(), axis_order) + return relay.Constant(tvm.nd.array(reshaped)) + def within_index(b, e, s, i): """Return a boolean value that indicates if i is within the given index. From e3ca2566be19fe4108e872e6f2d5a57539443f47 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Mon, 3 Oct 2022 12:04:08 -0700 Subject: [PATCH 03/18] Add stupid hack to work around TVM bug --- python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py index 17c6f2aaccba..7124228cf463 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py @@ -70,7 +70,7 @@ def conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_d rx = te.reduce_axis((0, kernel_w), name="rx") rc = te.reduce_axis((0, in_channels), name="rc") - conv = te.compute( + return te.compute( (batch_size, out_height, out_width, out_channels), lambda nn, yy, xx, ff: te.sum( padded_data[ @@ -83,8 +83,6 @@ def conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_d tag="conv2d_nhwc", ) - return conv - def _make_tensorization(padded_data, kernel): _, padded_h, padded_w, in_channels = padded_data.shape @@ -109,9 +107,12 @@ def _make_tensorization(padded_data, kernel): name="c", ) + # TVM has a really strange bug where the outer reduction axis (kh_i) having length 1 causes the + # decl_buffer strides check to fail. height_stride is a dark magic workaround for this. + height_stride = in_channels * padded_w if kernel_h > 1 else in_channels data_buf = tir.decl_buffer( data_slice.shape, data_slice.dtype, name="data", offset_factor=1, - strides=[padded_w * in_channels, in_channels, 1], + strides=[height_stride, in_channels, 1], ) kernel_buf = tir.decl_buffer( kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, From 7783acc3b3858040ab618fc8897bd52809bae8d7 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Tue, 4 Oct 2022 02:44:01 -0700 Subject: [PATCH 04/18] Unit testing for conv2d schedule --- .../tvm/topi/arm_cpu/mprofile/dsp/conv2d.py | 48 ++++++++---- .../mprofile/dsp/micro_kernel/tensordot.py | 3 +- .../strategy/arm_cpu/test_conv2d_nhwc.py | 78 +++++++------------ 3 files changed, 62 insertions(+), 67 deletions(-) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py index 7124228cf463..89d98be884e5 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py @@ -34,7 +34,6 @@ def conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): - print("Activating Conv2D NHWC schedule") """Compute function for v7e-m DSP instructions of conv2d.""" assert isinstance(strides, int) or len(strides) == 2 assert isinstance(dilation, int) or len(dilation) == 2 @@ -44,37 +43,53 @@ def conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_d else: stride_h, stride_w = strides + # Dilation prevents us from using DSP instructions, so this schedule can't work (aside from the + # niche case where dilation_h == stride_h and dilation_w == stride_w, which is rare enough we + # probably don't need to support it). if isinstance(dilation, int): dilation_h = dilation_w = dilation else: dilation_h, dilation_w = dilation + assert dilation_h == dilation_w == 1 batch_size, in_height, in_width, in_channels = data.shape out_channels, kernel_h, kernel_w, _ = kernel.shape assert kernel.shape[3] == in_channels - # compute the output shape - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - pad_top, pad_left, pad_down, pad_right = get_pad_tuple( - padding, (dilated_kernel_h, dilated_kernel_w) - ) - out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) - out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + # Compute and apply padding + assert isinstance(padding, tuple) + if len(padding) == 2: + (pad_up, pad_down), (pad_left, pad_right) = padding + else: + pad_up, pad_left, pad_down, pad_right = padding + + if pad_up or pad_left or pad_down or pad_right: + padded_data = pad( + data, + [0, pad_up, pad_left, 0], + [0, pad_down, pad_right, 0], + name="padded_data" + ) + else: + padded_data = data - pad_before = [0, pad_top, pad_left, 0] - pad_after = [0, pad_down, pad_right, 0] - padded_data = pad(data, pad_before, pad_after, name="padded_data") + # Compute output dimensions + output_h = (in_height - kernel_h + pad_up + pad_down) // stride_h + 1 + output_w = (in_width - kernel_w + pad_left + pad_right) // stride_w + 1 + + # Offsets to "prefer" the bottom right corner. This is done to match Tensorflow's convention, + # but does NOT match the other TVM schedules. + y_offset = (in_height + pad_up + pad_down - kernel_h) % stride_h + x_offset = (in_width + pad_left + pad_right - kernel_w) % stride_w ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") rc = te.reduce_axis((0, in_channels), name="rc") - return te.compute( - (batch_size, out_height, out_width, out_channels), + (batch_size, output_h, output_w, out_channels), lambda nn, yy, xx, ff: te.sum( padded_data[ - nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc + nn, y_offset + yy * stride_h + ry * dilation_h, x_offset + xx * stride_w + rx * dilation_w, rc ].astype(out_dtype) * kernel[ff, ry, rx, rc].astype(out_dtype), axis=[ry, rx, rc], @@ -110,8 +125,9 @@ def _make_tensorization(padded_data, kernel): # TVM has a really strange bug where the outer reduction axis (kh_i) having length 1 causes the # decl_buffer strides check to fail. height_stride is a dark magic workaround for this. height_stride = in_channels * padded_w if kernel_h > 1 else in_channels + print(f"Using strides {[height_stride, in_channels, 1]}") data_buf = tir.decl_buffer( - data_slice.shape, data_slice.dtype, name="data", offset_factor=1, + data_slice.shape, data_slice.dtype, name="foofoomcbar", offset_factor=1, strides=[height_stride, in_channels, 1], ) kernel_buf = tir.decl_buffer( diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py index 72c8b432d7b5..8e65b6091903 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py @@ -126,7 +126,8 @@ def tensordot_impl(in_dtype, tensor_h, jump, tensor_w, suffix): elif in_dtype == "int32": inner_loop = """ - sum = __MLA(tensor_batch, kernel_batch, sum);""" + // Compiles to a single MAC instruction + sum += tensor_batch * kernel_batch;""" function_name = _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix) return textwrap.dedent( diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py b/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py index f5ae6f51dbd7..72200bbdffae 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data from tvm.micro.testing.aot_test_utils import AOT_CORSTONE300_RUNNER - +from tvm.topi.utils import change_constant_shape class BasicConv2dTests: @tvm.testing.requires_corstone300 @@ -61,11 +61,7 @@ def test_conv2d( ref_mod = tvm.IRModule.from_expr(relay.Function([input0], out0)) input1 = relay.var("input", relay.TensorType(ishape, dtype)) - - if kernel_layout == "HWOI": - weight1 = relay.const(np.moveaxis(weight_data, 2, -1)) - elif kernel_layout == "HWIO": - weight1 = relay.const(weight_data) + weight1 = change_constant_shape(weight0, "HWIO", kernel_layout) out1 = relay.op.nn.conv2d( input1, @@ -97,57 +93,39 @@ def test_conv2d( ) -class TestConv2d_DSP_HWOI(BasicConv2dTests): - """This test is for conv2d_nhwc_dsp.arm_cpu schedule.""" +class TestConv2d_NHWC_OHWI_DSP(BasicConv2dTests): data_shape, kernel_size, num_filter, strides, padding, dilation = tvm.testing.parameters( - # TODO(mehrdadh): Fails due to https://github.com/apache/tvm/issues/11216 + # Disabled because these kernels are not an integral number of words # ((1, 32, 32, 1), (3, 3), 12, 1, 0, 1), # ((1, 32, 10, 3), (3, 3), 16, 1, 0, 1), - # ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1), 1), - ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 1), - ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), - ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), - ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 2), - ((1, 32, 32, 16), (3, 3), 16, 1, (1, 1, 2, 2), 2), - # from Keyword Spotting model from MLPerfTiny models - # TODO(mehrdad): Fails due to https://github.com/apache/tvm/issues/11216 - # ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1), 1), - # from Visual Wake Word model from MLPerfTiny models - # TODO(mehrdadh): fails due to https://github.com/apache/tvm/issues/11216 # ((1, 96, 96, 3), (3, 3), 8, (2, 2), (0, 0, 1, 1), 1), - # from Image Classification model from MLPerfTiny models - ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, 1), - ((4, 16, 16, 8), (5, 5), 8, 2, (0, 4, 4, 0), 1), - ((4, 16, 16, 8), (5, 5), 16, 2, (0, 4, 4, 0), 1), - ((4, 16, 16, 8), (5, 5), 8, 2, 0, 1), - ((4, 16, 16, 8), (5, 5), 16, 2, 0, 1), - ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1), 1), - ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2), 1), - ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2), 1), - ((1, 16, 16, 8), (3, 3), 16, 2, (0, 1, 2, 3), 1), - ) - dtype = tvm.testing.parameter("int8", "int16") - kernel_layout = tvm.testing.parameter("HWOI") - schedule_name = tvm.testing.parameter("conv2d_nhwc_dsp.arm_cpu") - -class TestConv2d_HWIO(BasicConv2dTests): - """This test is for conv2d_nhwc_spatial_pack.arm_cpu schedule.""" - - data_shape, kernel_size, num_filter, strides, padding, dilation = tvm.testing.parameters( - ((1, 32, 32, 1), (3, 3), 12, 1, 0, 1), - ((1, 32, 10, 3), (3, 3), 16, 1, 0, 1), - ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1), 1), - ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 1), - ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), - ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), - ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 2), - ((1, 32, 32, 16), (3, 3), 16, 1, (1, 1, 2, 2), 2), + # Disabled because while our schedule matches TensorFlow's behavior, it does NOT + # match the x86 schedule behavior (which is different). These schedules have either: + # (in_height + pad_up + pad_down - kernel_h) % stride_h > 0 OR + # (in_width + pad_left + pad_right - kernel_w) % stride_w > 0 + # ((4, 16, 16, 8), (5, 5), 8, 2, (0, 4, 3, 0), 1), + # ((4, 16, 16, 8), (5, 5), 16, 2, (0, 4, 4, 0), 1), + # ((4, 16, 16, 8), (5, 5), 8, 2, 0, 1), + # ((4, 16, 16, 8), (5, 5), 16, 2, 0, 1), + # ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, 1), + # ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, 1) + # ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1), 1), + + ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 1), + ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), + ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), + ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1), 1), + ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1), 1), + ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2), 1), + ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2), 1), + ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), + ((1, 16, 16, 32), (1, 1), 64, 1, 0, 1), ) - dtype = tvm.testing.parameter("int8", "int16") - kernel_layout = tvm.testing.parameter("HWIO") - schedule_name = tvm.testing.parameter("conv2d_nhwc_spatial_pack.arm_cpu") + dtype = tvm.testing.parameter("int8", "int16", "int32") + kernel_layout = tvm.testing.parameter("OHWI") + schedule_name = tvm.testing.parameter("conv2d_nhwc_dsp.arm_cpu") if __name__ == "__main__": From 4b0e4c7d3476602b4288cdfe2e52d0fbb059a392 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Tue, 4 Oct 2022 06:29:26 -0700 Subject: [PATCH 05/18] Connect new implementations to Arm strategy --- python/tvm/relay/op/strategy/arm_cpu.py | 30 ++++++++++++++++----- python/tvm/topi/arm_cpu/conv2d.py | 13 +++++++++ python/tvm/topi/arm_cpu/depthwise_conv2d.py | 12 +++++++++ 3 files changed, 48 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 947beb396ae2..c9a5ac80df1f 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -159,7 +159,17 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): name="conv2d_hwcn.generic", ) elif layout == "NHWC": - if target.features.has_dsp and kernel_layout == "HWOI": + if ( + target.features.has_dsp + and dilation_w == dilation_h == 1 + and kernel_layout == "OHWI" + ): + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_ohwi_dsp), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_ohwi_dsp), + name="conv2d_nhwc_ohwi_dsp.arm_cpu", + ) + elif target.features.has_dsp and kernel_layout == "HWOI": strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_dsp), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_dsp), @@ -199,13 +209,19 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): if layout == "NCHW": assert kernel_layout == "OIHW" or re.match(r"OIHW\d*o", kernel_layout) - # ARM conv2d depthwise schedule if kernel_layout == "OIHW": - strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw), - wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw), - name="depthwise_conv2d_nchw.arm_cpu", - ) + if (target.features.has_dsp and dilation_w == dilation_h == 1): + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw_oihw_dsp), + wrap_topi_schedule(topi.arm_cpu.depthwise_schedule_conv2d_nchw_oihw_dsp), + name="depthwise_conv2d_nchw_oihw_dsp.arm_cpu", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw), + wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw), + name="depthwise_conv2d_nchw.arm_cpu", + ) # TODO: # This schedule has incorrect result on some hardware platforms (like NV Jetson TX2) diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index ab489161a8fa..7ec82ed75328 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -518,3 +518,16 @@ def conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): def schedule_conv2d_nhwc_dsp(cfg, outs): """Create schedule for conv2d_nhwc_dsp""" return conv2d_nhwc_dsp_schedule(cfg, outs) + + +@autotvm.register_topi_compute("conv2d_nhwc_ohwi_dsp.arm_cpu") +def conv2d_nhwc_ohwi_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): + return conv2d_nhwc_ohwi_dsp( + cfg, data, kernel, strides, padding, dilation, out_dtype + ) + + +@autotvm.register_topi_schedule("conv2d_nhwc_ohwi_dsp.arm_cpu") +def schedule_conv2d_nhwc_ohwi_dsp(cfg, outs): + return tensordot_conv2ds_schedule(cfg, outs) + diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index 333db3d5e014..ace5ffdabafe 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -718,3 +718,15 @@ def depthwise_conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out def schedule_depthwise_conv2d_nhwc_dsp(cfg, outs): """Create schedule for conv2d_nhwc_dsp""" return depthwise_conv2d_nhwc_dsp_schedule(cfg, outs) + + +@autotvm.register_topi_compute("depthwise_conv2d_nchw_oihw_dsp.arm_cpu") +def depthwise_conv2d_nchw_oihw_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): + return depthwise_conv2d_nchw_oihw_dsp( + cfg, data, kernel, strides, padding, dilation, out_dtype + ) + + +@autotvm.register_topi_schedule("depthwise_conv2d_nchw_oihw_dsp.arm_cpu") +def schedule_depthwise_conv2d_nchw_oihw_dsp(cfg, outs): + return tensordot_conv2ds_schedule(cfg, outs) From 058cb3485ec39987e08b4dc04aeeab2e74aef496 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Tue, 4 Oct 2022 06:30:34 -0700 Subject: [PATCH 06/18] Separate into new tensordot conv2d schedule --- .../tvm/topi/arm_cpu/mprofile/dsp/conv2d.py | 218 +++++++++-------- .../arm_cpu/mprofile/dsp/tensordot_conv2ds.py | 228 ++++++++++++++++++ 2 files changed, 346 insertions(+), 100 deletions(-) create mode 100644 python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py index 89d98be884e5..470d46b92a7a 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/conv2d.py @@ -17,22 +17,39 @@ # pylint: disable=invalid-name, no-value-for-parameter """Direct implementation of conv2d.""" -import random -import string - -from tvm import autotvm, te, tir +from tvm import autotvm from tvm.autotvm.task import deserialize_args +from tvm import te from tvm.topi.utils import simplify, traverse_inline from tvm.topi.nn.pad import pad from tvm.topi.nn.utils import get_pad_tuple from tvm.tir.expr import Mul -from .micro_kernel.tensordot import ( - make_intrin_tensordot, - tensordot_impl, +from .micro_kernel.gemm import ( + intrin_gemm_MxKxN, + gemm_MxKxN_impl, ) +def conv2d_nhwc_dsp(*args, **kwargs): + """Defines the v7e-m DSP instructions of conv2d.""" + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + data, kernel = args[:2] + layout = args[-2] + cfg = autotvm.get_config() + args = [cfg] + args + assert layout == "NHWC" + conv = conv2d_nhwc_dsp_compute(*args) + sched = conv2d_nhwc_dsp_schedule(cfg, [data, kernel, conv]) + return sched, [data, kernel, conv] + + +conv2d_nhwc_dsp.template_key = "dsp" +conv2d_nhwc_dsp.default_data_layout = "NHWC" +conv2d_nhwc_dsp.default_kernel_layout = "HWOI" + + def conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): """Compute function for v7e-m DSP instructions of conv2d.""" assert isinstance(strides, int) or len(strides) == 2 @@ -43,134 +60,135 @@ def conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_d else: stride_h, stride_w = strides - # Dilation prevents us from using DSP instructions, so this schedule can't work (aside from the - # niche case where dilation_h == stride_h and dilation_w == stride_w, which is rare enough we - # probably don't need to support it). if isinstance(dilation, int): dilation_h = dilation_w = dilation else: dilation_h, dilation_w = dilation - assert dilation_h == dilation_w == 1 batch_size, in_height, in_width, in_channels = data.shape - out_channels, kernel_h, kernel_w, _ = kernel.shape - assert kernel.shape[3] == in_channels + kernel_h, kernel_w, out_channels, _ = kernel.shape - # Compute and apply padding - assert isinstance(padding, tuple) - if len(padding) == 2: - (pad_up, pad_down), (pad_left, pad_right) = padding - else: - pad_up, pad_left, pad_down, pad_right = padding - - if pad_up or pad_left or pad_down or pad_right: - padded_data = pad( - data, - [0, pad_up, pad_left, 0], - [0, pad_down, pad_right, 0], - name="padded_data" - ) - else: - padded_data = data - - # Compute output dimensions - output_h = (in_height - kernel_h + pad_up + pad_down) // stride_h + 1 - output_w = (in_width - kernel_w + pad_left + pad_right) // stride_w + 1 + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) - # Offsets to "prefer" the bottom right corner. This is done to match Tensorflow's convention, - # but does NOT match the other TVM schedules. - y_offset = (in_height + pad_up + pad_down - kernel_h) % stride_h - x_offset = (in_width + pad_left + pad_right - kernel_w) % stride_w + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + padded_data = pad(data, pad_before, pad_after, name="padded_data") + rc = te.reduce_axis((0, in_channels), name="rc") ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") - rc = te.reduce_axis((0, in_channels), name="rc") - return te.compute( - (batch_size, output_h, output_w, out_channels), + + conv = te.compute( + (batch_size, out_height, out_width, out_channels), lambda nn, yy, xx, ff: te.sum( padded_data[ - nn, y_offset + yy * stride_h + ry * dilation_h, x_offset + xx * stride_w + rx * dilation_w, rc + nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc ].astype(out_dtype) - * kernel[ff, ry, rx, rc].astype(out_dtype), + * kernel[ry, rx, ff, rc].astype(out_dtype), axis=[ry, rx, rc], ), name="conv2d", tag="conv2d_nhwc", ) - -def _make_tensorization(padded_data, kernel): - _, padded_h, padded_w, in_channels = padded_data.shape - _, kernel_h, kernel_w, _ = kernel.shape - in_dtype = padded_data.dtype - suffix = "".join(random.choices(string.ascii_uppercase, k=8)) - assert in_dtype == kernel.dtype - - data_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="a", dtype=in_dtype) - kernel_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="b", dtype=in_dtype) - - kh_i = te.reduce_axis((0, kernel_h), name="kh_i") - kw_i = te.reduce_axis((0, kernel_w), name="kw_i") - kc_i = te.reduce_axis((0, in_channels), name="kc_i") - - output_slice = te.compute((1,), - lambda k: te.sum( - data_slice[kh_i, kw_i, kc_i].astype("int32") - * kernel_slice[kh_i, kw_i, kc_i].astype("int32"), - axis=[kh_i, kw_i, kc_i], - ), - name="c", + ########################### + # Config Space Definition # + ########################### + n, oh, ow, co = ( + cfg.axis(batch_size.value), + cfg.axis(out_height.value), + cfg.axis(out_width.value), + cfg.axis(out_channels.value), ) - - # TVM has a really strange bug where the outer reduction axis (kh_i) having length 1 causes the - # decl_buffer strides check to fail. height_stride is a dark magic workaround for this. - height_stride = in_channels * padded_w if kernel_h > 1 else in_channels - print(f"Using strides {[height_stride, in_channels, 1]}") - data_buf = tir.decl_buffer( - data_slice.shape, data_slice.dtype, name="foofoomcbar", offset_factor=1, - strides=[height_stride, in_channels, 1], + kh, kw, ci = ( + cfg.reduce_axis(kernel_h.value), + cfg.reduce_axis(kernel_w.value), + cfg.reduce_axis(in_channels.value), ) - kernel_buf = tir.decl_buffer( - kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, - strides=[kernel_w * in_channels, in_channels, 1] + + owo, owi = cfg.define_split("tile_ow", ow, policy="factors", num_outputs=2) + cio, cii = cfg.define_split( + "tile_ci", + ci, + policy="factors", + num_outputs=2, + # TODO: check case with in_channels.value % 4 != 0 with AutoTVM + filter=None if cfg.is_fallback else lambda x: x.size[-1] % 4 == 0, ) - output_buf = tir.decl_buffer( - output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] + coo, coi = cfg.define_split("tile_co", co, policy="factors", num_outputs=2) + + cfg.define_reorder( + "reorder_0_simd", + [n, oh, owo, owi, coo, coi, kh, kw, cio, cii], + policy="candidate", + candidate=[ + [n, oh, kh, kw, owo, coo, cio, owi, coi, cii], + [n, oh, kh, kw, coo, owo, cio, owi, coi, cii], + [n, kh, kw, oh, owo, coo, cio, owi, coi, cii], + [n, kh, kw, oh, coo, owo, cio, owi, coi, cii], + ], ) - jump = (padded_w - kernel_w) * in_channels - tensordot_params = (in_dtype, kernel_h, jump, kernel_w * in_channels, suffix) + cfg.define_knob("auto_unroll_max_step", [0, 2, 4, 8, 16, 32]) + cfg.define_knob("unroll_explicit", [0, 1]) - intrin_tensordot = make_intrin_tensordot( - output_slice.op, - {data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf}, - tensordot_params - ) + if cfg.is_fallback: + cfg.fallback_split("tile_ow", [-1, out_width.value]) + cfg.fallback_split("tile_ci", [-1, in_channels.value]) + cfg.fallback_split("tile_co", [-1, out_channels.value]) - tensordot_code = tensordot_impl(*tensordot_params) - return (intrin_tensordot, tensordot_code) + return conv def conv2d_nhwc_dsp_schedule(cfg, outs): """Schedule function for v7e-m DSP instructions of conv2d.""" - schedule = te.create_schedule([x.op for x in outs]) + sched = te.create_schedule([x.op for x in outs]) - def _callback(operator): - if "conv2d_nhwc" not in operator.tag: + def _callback(op): + if "conv2d_nhwc" not in op.tag: return # extract tensors - output = operator.output(0) - padded_data = output.op.input_tensors[0] - kernel = output.op.input_tensors[1] + output = op.output(0) + conv = op + data_vec = conv.input_tensors[0] + kernel = conv.input_tensors[1] # pylint: disable=unused-variable + last = outs[0] # pylint: disable=unused-variable + + source_index_w = output.op.body[0].source[0].a.value.indices[2].a + stride_w = source_index_w.b.value if isinstance(source_index_w, Mul) else 1 + + # tile reduction axes + n, oh, ow, co = sched[conv].op.axis + kh, kw, ci = sched[conv].op.reduce_axis + + M = cfg["tile_ow"].size[-1] + K = cfg["tile_ci"].size[-1] + N = cfg["tile_co"].size[-1] + + owo, owi = cfg["tile_ow"].apply(sched, conv, ow) + cio, cii = cfg["tile_ci"].apply(sched, conv, ci) + coo, coi = cfg["tile_co"].apply(sched, conv, co) + + cfg["reorder_0_simd"].apply(sched, conv, [n, oh, owo, owi, coo, coi, kh, kw, cio, cii]) + + gemm, uniq_id = intrin_gemm_MxKxN(M, K, N, data_vec.dtype, output.dtype, stride_w) + sched[output].tensorize(owi, gemm) + sched[output].pragma(n, "import_c", gemm_MxKxN_impl(M, K, N, uniq_id)) - b_ax, y_ax, x_ax, co_ax = schedule[output].op.axis - kh_ax, kw_ax, ci_ax = schedule[output].op.reduce_axis - schedule[output].reorder(b_ax, y_ax, x_ax, co_ax, kh_ax, kw_ax, ci_ax) + # this is the scope to attach global config inside this kernel + kernel_scope = n - intrin, code = _make_tensorization(padded_data, kernel) - schedule[output].tensorize(kh_ax, intrin) - schedule[output].pragma(b_ax, "import_c", code) + # tune unroll + sched[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + sched[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) - traverse_inline(schedule, outs[-1].op, _callback) - return schedule + traverse_inline(sched, outs[-1].op, _callback) + return sched diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py new file mode 100644 index 000000000000..ae6ff030a3b2 --- /dev/null +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py @@ -0,0 +1,228 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-value-for-parameter +"""Direct implementation of conv2d.""" + +import random +import string + +from tvm import te, tir +from tvm.tir import indexdiv, indexmod +from tvm.topi.utils import traverse_inline +from tvm.topi.nn.pad import pad + +from .micro_kernel.tensordot import ( + make_intrin_tensordot, + tensordot_impl, +) + + +# Dilation prevents us from using DSP instructions, so this schedule can't work (aside from the +# niche case where dilation_h == stride_h and dilation_w == stride_w, which is rare enough we +# probably don't need to support it). +def _check_no_dilation(dilation): + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + assert dilation_h == dilation_w == 1 + + +def _unpack_strides(strides): + assert isinstance(strides, int) or len(strides) == 2 + if isinstance(strides, int): + return (strides, strides) + else: + return strides + + +def _unpack_padding(padding): + assert isinstance(padding, tuple) + if len(padding) == 2: + (pad_up, pad_down), (pad_left, pad_right) = padding + else: + pad_up, pad_left, pad_down, pad_right = padding + return pad_up, pad_left, pad_down, pad_right + + +# We only care about tuples here - "VALID" and "SAME" padding will be converted by the importer. +def _pad_if_needed(data, unpacked_padding): + pad_up, pad_left, pad_down, pad_right = unpacked_padding + if pad_up or pad_left or pad_down or pad_right: + return pad( + data, + [0, pad_up, pad_left, 0], + [0, pad_down, pad_right, 0], + name="padded_data" + ) + else: + return data + + +def _compute_output_dim(data_dim, kernel_dim, pad_before, pad_after, stride): + return (data_dim - kernel_dim + pad_before + pad_after) // stride + 1 + + +# Offsets to "prefer" the bottom right corner. This is done to match TensorFlow's convention, but +# does NOT match the other TVM schedules. We violate this convention because it improves accuracy on +# models imported from TensorFlow. +def _compute_offset(data_dim, kernel_dim, pad_before, pad_after, stride): + return (data_dim - kernel_dim + pad_before + pad_after) % stride + + +def conv2d_nhwc_ohwi_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): + stride_h, stride_w = _unpack_strides(strides) + pad_up, pad_left, pad_down, pad_right = _unpack_padding(padding) + _check_no_dilation(dilation) + + batch_size, data_h, data_w, in_channels = data.shape + out_channels, kernel_h, kernel_w, _ = kernel.shape + assert kernel.shape[3] == in_channels + + output_h = _compute_output_dim(data_h, kernel_h, pad_up, pad_down, stride_h) + output_w = _compute_output_dim(data_w, kernel_w, pad_left, pad_right, stride_w) + y_offset = _compute_offset(data_h, kernel_h, pad_up, pad_down, stride_h) + x_offset = _compute_offset(data_w, kernel_w, pad_left, pad_right, stride_w) + + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + rc = te.reduce_axis((0, in_channels), name="rc") + + padded_data = _pad_if_needed(data, unpacked_padding) + return te.compute( + (batch_size, output_h, output_w, out_channels), + lambda n, y, x, c: te.sum( + padded_data[ + n, y_offset + y * stride_h + ry, x_offset + x * stride_w + rxh, rc + ].astype(out_dtype) + * kernel[c, ry, rx, rc].astype(out_dtype), + axis=(ry, rx, rc), + ), + name="conv2d", + tag="conv2d_nhwc_ohwi_dsp", + ) + + +def depthwise_conv2d_nchw_oihw_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): + stride_h, stride_w = _unpack_strides(strides) + pad_up, pad_left, pad_down, pad_right = _unpack_padding(padding) + _check_no_dilation(dilation) + + batch_size, in_channels, data_h, data_w = data.shape + channel_multiplier, _, kernel_h, kernel_w = kernel.shape + assert kernel.shape[1] == in_channels + + output_h = _compute_output_dim(data_h, kernel_h, pad_up, pad_down, stride_h) + output_w = _compute_output_dim(data_w, kernel_w, pad_left, pad_right, stride_w) + y_offset = _compute_offset(data_h, kernel_h, pad_up, pad_down, stride_h) + x_offset = _compute_offset(data_w, kernel_w, pad_left, pad_right, stride_w) + + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + + padded_data = _pad_if_needed(data, unpacked_padding) + return te.compute( + (batch_size, output_h, output_w, out_channels), + lambda n, y, x, c: te.sum( + padded_data[ + n, idxdiv(c, c_mul), y_offset + y * stride_h + ry, x_offset + x * stride_w + rx, + ].astype(out_dtype) + * kernel[idxmod(c, c_mul), idxdiv(c, c_mul), ry, rx].astype(out_dtype), + axis=(ry, rx), + ), + name="depthwise_conv2d", + tag="depthwise_conv2d_nchw_oihw_dsp", + ) + +def _make_tensorization(padded_data, kernel): + _, padded_h, padded_w, in_channels = padded_data.shape + _, kernel_h, kernel_w, _ = kernel.shape + in_dtype = padded_data.dtype + suffix = "".join(random.choices(string.ascii_uppercase, k=8)) + assert in_dtype == kernel.dtype + + data_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="a", dtype=in_dtype) + kernel_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="b", dtype=in_dtype) + + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + kc_i = te.reduce_axis((0, in_channels), name="kc_i") + + output_slice = te.compute((1,), + lambda k: te.sum( + data_slice[kh_i, kw_i, kc_i].astype("int32") + * kernel_slice[kh_i, kw_i, kc_i].astype("int32"), + axis=[kh_i, kw_i, kc_i], + ), + name="c", + ) + + # TVM has a really strange bug where the outer reduction axis (kh_i) having length 1 causes the + # decl_buffer strides check to fail. height_stride is a dark magic workaround for this. + height_stride = in_channels * padded_w if kernel_h > 1 else in_channels + data_buf = tir.decl_buffer( + data_slice.shape, data_slice.dtype, name="foofoomcbar", offset_factor=1, + strides=[height_stride, in_channels, 1], + ) + kernel_buf = tir.decl_buffer( + kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, + strides=[kernel_w * in_channels, in_channels, 1] + ) + output_buf = tir.decl_buffer( + output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] + ) + + jump = (padded_w - kernel_w) * in_channels + tensordot_params = (in_dtype, kernel_h, jump, kernel_w * in_channels, suffix) + + intrin_tensordot = make_intrin_tensordot( + output_slice.op, + {data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf}, + tensordot_params + ) + + tensordot_code = tensordot_impl(*tensordot_params) + return (intrin_tensordot, tensordot_code) + + +def tensordot_conv2ds_schedule(cfg, outs): + schedule = te.create_schedule([x.op for x in outs]) + + def _callback(operator): + if "conv2d" in operator.tag: + output = operator.output(0) + padded_data = output.op.input_tensors[0] + kernel = output.op.input_tensors[1] + + if tag == "conv2d_nhwc_ohwi_dsp": + b_ax, y_ax, x_ax, co_ax = schedule[output].op.axis + kh_ax, kw_ax, ci_ax = schedule[output].op.reduce_axis + schedule[output].reorder(b_ax, y_ax, x_ax, co_ax, kh_ax, kw_ax, ci_ax) + intrin, code = _make_tensorization(padded_data, kernel) + + elif tag == "depthwise_conv2d_nchw_oihw_dsp" + b_ax, y_ax, x_ax, co_ax = schedule[output].op.axis + kh_ax, kw_ax = schedule[output].op.reduce_axis + schedule[output].reorder(b_ax, co_ax, y_ax, x_ax, kh_ax, kw_ax) + intrin, code = _make_tensorization(padded_data, kernel) + + schedule[output].tensorize(kh_ax, intrin) + schedule[output].pragma(b_ax, "import_c", code) + + traverse_inline(schedule, outs[-1].op, _callback) + return schedule From 90b7657071ef7e5164ba71331c5b55c9e3345ef3 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Tue, 4 Oct 2022 06:48:51 -0700 Subject: [PATCH 07/18] Separate testing infrastructure --- python/tvm/topi/arm_cpu/conv2d.py | 5 +- python/tvm/topi/arm_cpu/depthwise_conv2d.py | 6 +- .../strategy/arm_cpu/test_conv2d_nhwc.py | 93 +++++++++++++++---- .../strategy/arm_cpu/test_depthwise_conv2d.py | 10 ++ 4 files changed, 92 insertions(+), 22 deletions(-) diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index 7ec82ed75328..ed1fbf4e83eb 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -37,7 +37,10 @@ conv2d_nhwc_dsp_compute, conv2d_nhwc_dsp_schedule, ) - +from .mprofile.dsp.tensordot_conv2ds import ( + conv2d_nhwc_ohwi_dsp, + tensordot_conv2ds_schedule, +) @autotvm.register_topi_compute("conv2d_nchw_spatial_pack.arm_cpu") def conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype): diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index ace5ffdabafe..d6efbedb8eab 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -27,11 +27,15 @@ from ..nn.utils import get_pad_tuple from .tensor_intrin import smlal_int16_int32 from .arm_utils import is_aarch64_arm - from .mprofile.dsp.depthwise_conv2d import ( depthwise_conv2d_nhwc_dsp_compute, depthwise_conv2d_nhwc_dsp_schedule, ) +from .mprofile.dsp.tensordot_conv2ds import ( + depthwise_conv2d_nchw_oihw_dsp, + tensordot_conv2ds_schedule, +) + @autotvm.register_topi_compute("depthwise_conv2d_nchw.arm_cpu") diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py b/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py index 72200bbdffae..53f9481c41b2 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py @@ -93,39 +93,92 @@ def test_conv2d( ) -class TestConv2d_NHWC_OHWI_DSP(BasicConv2dTests): +class TestConv2d_DSP_HWOI(BasicConv2dTests): + """This test is for conv2d_nhwc_dsp.arm_cpu schedule.""" data_shape, kernel_size, num_filter, strides, padding, dilation = tvm.testing.parameters( - # Disabled because these kernels are not an integral number of words + # TODO(mehrdadh): Fails due to https://github.com/apache/tvm/issues/11216 # ((1, 32, 32, 1), (3, 3), 12, 1, 0, 1), # ((1, 32, 10, 3), (3, 3), 16, 1, 0, 1), + # ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1), 1), + ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 1), + ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), + ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), + ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 2), + ((1, 32, 32, 16), (3, 3), 16, 1, (1, 1, 2, 2), 2), + # from Keyword Spotting model from MLPerfTiny models + # TODO(mehrdad): Fails due to https://github.com/apache/tvm/issues/11216 + # ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1), 1), + # from Visual Wake Word model from MLPerfTiny models + # TODO(mehrdadh): fails due to https://github.com/apache/tvm/issues/11216 # ((1, 96, 96, 3), (3, 3), 8, (2, 2), (0, 0, 1, 1), 1), + # from Image Classification model from MLPerfTiny models + ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, 1), + ((4, 16, 16, 8), (5, 5), 8, 2, (0, 4, 4, 0), 1), + ((4, 16, 16, 8), (5, 5), 16, 2, (0, 4, 4, 0), 1), + ((4, 16, 16, 8), (5, 5), 8, 2, 0, 1), + ((4, 16, 16, 8), (5, 5), 16, 2, 0, 1), + ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1), 1), + ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2), 1), + ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2), 1), + ((1, 16, 16, 8), (3, 3), 16, 2, (0, 1, 2, 3), 1), + ) + dtype = tvm.testing.parameter("int8", "int16") + kernel_layout = tvm.testing.parameter("HWOI") + schedule_name = tvm.testing.parameter("conv2d_nhwc_dsp.arm_cpu") + + +class TestConv2d_HWIO(BasicConv2dTests): + """This test is for conv2d_nhwc_spatial_pack.arm_cpu schedule.""" + + data_shape, kernel_size, num_filter, strides, padding, dilation = tvm.testing.parameters( + ((1, 32, 32, 1), (3, 3), 12, 1, 0, 1), + ((1, 32, 10, 3), (3, 3), 16, 1, 0, 1), + ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1), 1), + ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 1), + ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), + ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), + ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 2), + ((1, 32, 32, 16), (3, 3), 16, 1, (1, 1, 2, 2), 2), + ) + dtype = tvm.testing.parameter("int8", "int16") + kernel_layout = tvm.testing.parameter("HWIO") + schedule_name = tvm.testing.parameter("conv2d_nhwc_spatial_pack.arm_cpu") + + +class TestConv2d_Tensordot(BasicConv2dTests): + data_shape, kernel_size, num_filter, strides, padding = tvm.testing.parameters( + # Disabled because these kernels are not an integral number of words + # ((1, 32, 32, 1), (3, 3), 12, 1, 0), + # ((1, 32, 10, 3), (3, 3), 16, 1, 0), + # ((1, 96, 96, 3), (3, 3), 8, (2, 2), (0, 0, 1, 1)), # Disabled because while our schedule matches TensorFlow's behavior, it does NOT # match the x86 schedule behavior (which is different). These schedules have either: # (in_height + pad_up + pad_down - kernel_h) % stride_h > 0 OR # (in_width + pad_left + pad_right - kernel_w) % stride_w > 0 - # ((4, 16, 16, 8), (5, 5), 8, 2, (0, 4, 3, 0), 1), - # ((4, 16, 16, 8), (5, 5), 16, 2, (0, 4, 4, 0), 1), - # ((4, 16, 16, 8), (5, 5), 8, 2, 0, 1), - # ((4, 16, 16, 8), (5, 5), 16, 2, 0, 1), - # ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, 1), - # ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, 1) - # ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1), 1), - - ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 1), - ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), - ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), - ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1), 1), - ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1), 1), - ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2), 1), - ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2), 1), - ((1, 32, 32, 16), (3, 3), 16, 1, 0, 1), - ((1, 16, 16, 32), (1, 1), 64, 1, 0, 1), + # ((4, 16, 16, 8), (5, 5), 8, 2, (0, 4, 3, 0)), + # ((4, 16, 16, 8), (5, 5), 16, 2, (0, 4, 4, 0)), + # ((4, 16, 16, 8), (5, 5), 8, 2, 0, ), + # ((4, 16, 16, 8), (5, 5), 16, 2, 0, ), + # ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, ), + # ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, ), + # ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1)), + + ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0)), + ((1, 32, 32, 16), (3, 3), 16, 1, 0, ), + ((1, 32, 32, 16), (3, 3), 16, 1, 0, ), + ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1)), + ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1)), + ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2)), + ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2)), + ((1, 32, 32, 16), (3, 3), 16, 1, 0, ), + ((1, 16, 16, 32), (1, 1), 64, 1, 0, ), ) + dilation = tvm.testing.parameter(1) dtype = tvm.testing.parameter("int8", "int16", "int32") kernel_layout = tvm.testing.parameter("OHWI") - schedule_name = tvm.testing.parameter("conv2d_nhwc_dsp.arm_cpu") + schedule_name = tvm.testing.parameter("conv2d_nhwc_ohwi_dsp.arm_cpu") if __name__ == "__main__": diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py index 15ea2a31d864..9d8723ab216a 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py @@ -185,5 +185,15 @@ class TestDepthwiseConv2d_NHWC_HWOI_DSP(BasicDepthwiseConv2dTests): schedule_name = tvm.testing.parameter("depthwise_conv2d_nhwc_dsp.arm_cpu") +class TestDepthwiseConv2d_Tensordot_DSP(BasicDepthwiseConv2dTests): + data_shape, kernel_size, num_filter, strides, padding, dtype = tvm.testing.parameters( + ((1, 48, 48, 8), (3, 3), 8, (1, 1), 1, "int16"), + ) + dilation = tvm.testing.parameter(1) + data_layout = tvm.testing.parameter("NCHW") + kernel_layout = tvm.testing.parameter("OIHW") + schedule_name = tvm.testing.parameter("conv2d_nchw_oihw_dsp.arm_cpu") + + if __name__ == "__main__": tvm.testing.main() From 814bc6c78aca8e2f34a3775d654a05ffa3d6d6ab Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Tue, 4 Oct 2022 13:01:46 -0700 Subject: [PATCH 08/18] Prototype depthwise implementation --- python/tvm/relay/op/strategy/arm_cpu.py | 2 +- python/tvm/topi/arm_cpu/conv2d.py | 4 +- python/tvm/topi/arm_cpu/depthwise_conv2d.py | 4 +- .../arm_cpu/mprofile/dsp/tensordot_conv2ds.py | 124 +++++++++++++----- .../strategy/arm_cpu/test_depthwise_conv2d.py | 4 +- 5 files changed, 97 insertions(+), 41 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index c9a5ac80df1f..47e0921c45a7 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -213,7 +213,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): if (target.features.has_dsp and dilation_w == dilation_h == 1): strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw_oihw_dsp), - wrap_topi_schedule(topi.arm_cpu.depthwise_schedule_conv2d_nchw_oihw_dsp), + wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw_oihw_dsp), name="depthwise_conv2d_nchw_oihw_dsp.arm_cpu", ) else: diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index ed1fbf4e83eb..fcb71d0513d2 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -38,7 +38,7 @@ conv2d_nhwc_dsp_schedule, ) from .mprofile.dsp.tensordot_conv2ds import ( - conv2d_nhwc_ohwi_dsp, + conv2d_nhwc_ohwi_dsp_compute, tensordot_conv2ds_schedule, ) @@ -525,7 +525,7 @@ def schedule_conv2d_nhwc_dsp(cfg, outs): @autotvm.register_topi_compute("conv2d_nhwc_ohwi_dsp.arm_cpu") def conv2d_nhwc_ohwi_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): - return conv2d_nhwc_ohwi_dsp( + return conv2d_nhwc_ohwi_dsp_compute( cfg, data, kernel, strides, padding, dilation, out_dtype ) diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index d6efbedb8eab..d6cc11e34cd6 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -32,7 +32,7 @@ depthwise_conv2d_nhwc_dsp_schedule, ) from .mprofile.dsp.tensordot_conv2ds import ( - depthwise_conv2d_nchw_oihw_dsp, + depthwise_conv2d_nchw_oihw_dsp_compute, tensordot_conv2ds_schedule, ) @@ -726,7 +726,7 @@ def schedule_depthwise_conv2d_nhwc_dsp(cfg, outs): @autotvm.register_topi_compute("depthwise_conv2d_nchw_oihw_dsp.arm_cpu") def depthwise_conv2d_nchw_oihw_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): - return depthwise_conv2d_nchw_oihw_dsp( + return depthwise_conv2d_nchw_oihw_dsp_compute( cfg, data, kernel, strides, padding, dilation, out_dtype ) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py index ae6ff030a3b2..9b4b6de9193d 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py @@ -61,8 +61,7 @@ def _unpack_padding(padding): # We only care about tuples here - "VALID" and "SAME" padding will be converted by the importer. -def _pad_if_needed(data, unpacked_padding): - pad_up, pad_left, pad_down, pad_right = unpacked_padding +def _pad_if_needed(data, pad_up, pad_left, pad_down, pad_right): if pad_up or pad_left or pad_down or pad_right: return pad( data, @@ -85,13 +84,18 @@ def _compute_offset(data_dim, kernel_dim, pad_before, pad_after, stride): return (data_dim - kernel_dim + pad_before + pad_after) % stride +# Prevents re-definition of C functions +def _get_suffix(): + return "".join(random.choices(string.ascii_uppercase, k=8)) + + def conv2d_nhwc_ohwi_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): stride_h, stride_w = _unpack_strides(strides) pad_up, pad_left, pad_down, pad_right = _unpack_padding(padding) _check_no_dilation(dilation) batch_size, data_h, data_w, in_channels = data.shape - out_channels, kernel_h, kernel_w, _ = kernel.shape + output_channels, kernel_h, kernel_w, _ = kernel.shape assert kernel.shape[3] == in_channels output_h = _compute_output_dim(data_h, kernel_h, pad_up, pad_down, stride_h) @@ -103,12 +107,12 @@ def conv2d_nhwc_ohwi_dsp_compute(cfg, data, kernel, strides, padding, dilation, rx = te.reduce_axis((0, kernel_w), name="rx") rc = te.reduce_axis((0, in_channels), name="rc") - padded_data = _pad_if_needed(data, unpacked_padding) + padded_data = _pad_if_needed(data, pad_up, pad_left, pad_down, pad_right) return te.compute( - (batch_size, output_h, output_w, out_channels), + (batch_size, output_h, output_w, output_channels), lambda n, y, x, c: te.sum( padded_data[ - n, y_offset + y * stride_h + ry, x_offset + x * stride_w + rxh, rc + n, y_offset + y * stride_h + ry, x_offset + x * stride_w + rx, rc ].astype(out_dtype) * kernel[c, ry, rx, rc].astype(out_dtype), axis=(ry, rx, rc), @@ -118,14 +122,68 @@ def conv2d_nhwc_ohwi_dsp_compute(cfg, data, kernel, strides, padding, dilation, ) +def _make_conv2d_tensorization(padded_data, kernel): + _, padded_h, padded_w, in_channels = padded_data.shape + _, kernel_h, kernel_w, _ = kernel.shape + in_dtype = padded_data.dtype + suffix = _get_suffix() + assert in_dtype == kernel.dtype + + data_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="a", dtype=in_dtype) + kernel_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="b", dtype=in_dtype) + + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + kc_i = te.reduce_axis((0, in_channels), name="kc_i") + + output_slice = te.compute((1,), + lambda k: te.sum( + data_slice[kh_i, kw_i, kc_i].astype("int32") + * kernel_slice[kh_i, kw_i, kc_i].astype("int32"), + axis=[kh_i, kw_i, kc_i], + ), + name="c", + ) + + # TVM has a really strange bug where the outer reduction axis (kh_i) having length 1 causes the + # decl_buffer strides check to fail. height_stride is a dark magic workaround for this. + height_stride = in_channels * padded_w if kernel_h > 1 else in_channels + data_buf = tir.decl_buffer( + data_slice.shape, data_slice.dtype, name="data", offset_factor=1, + strides=[height_stride, in_channels, 1], + ) + kernel_buf = tir.decl_buffer( + kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, + strides=[kernel_w * in_channels, in_channels, 1] + ) + output_buf = tir.decl_buffer( + output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] + ) + + jump = (padded_w - kernel_w) * in_channels + tensordot_params = (in_dtype, kernel_h, jump, kernel_w * in_channels, suffix) + + intrin_tensordot = make_intrin_tensordot( + output_slice.op, + {data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf}, + tensordot_params + ) + + tensordot_code = tensordot_impl(*tensordot_params) + return (intrin_tensordot, tensordot_code) + + def depthwise_conv2d_nchw_oihw_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): stride_h, stride_w = _unpack_strides(strides) pad_up, pad_left, pad_down, pad_right = _unpack_padding(padding) _check_no_dilation(dilation) batch_size, in_channels, data_h, data_w = data.shape - channel_multiplier, _, kernel_h, kernel_w = kernel.shape - assert kernel.shape[1] == in_channels + _, c_mul, kernel_h, kernel_w = kernel.shape + output_channels = in_channels * c_mul + print(data.shape) + print(kernel.shape) + assert kernel.shape[0] == in_channels output_h = _compute_output_dim(data_h, kernel_h, pad_up, pad_down, stride_h) output_w = _compute_output_dim(data_w, kernel_w, pad_left, pad_right, stride_w) @@ -135,60 +193,57 @@ def depthwise_conv2d_nchw_oihw_dsp_compute(cfg, data, kernel, strides, padding, ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") - padded_data = _pad_if_needed(data, unpacked_padding) + padded_data = _pad_if_needed(data, pad_up, pad_left, pad_down, pad_right) return te.compute( - (batch_size, output_h, output_w, out_channels), + (batch_size, output_h, output_w, output_channels), lambda n, y, x, c: te.sum( padded_data[ - n, idxdiv(c, c_mul), y_offset + y * stride_h + ry, x_offset + x * stride_w + rx, + n, indexdiv(c, c_mul), y_offset + y * stride_h + ry, x_offset + x * stride_w + rx, ].astype(out_dtype) - * kernel[idxmod(c, c_mul), idxdiv(c, c_mul), ry, rx].astype(out_dtype), + * kernel[indexdiv(c, c_mul), indexmod(c, c_mul), ry, rx].astype(out_dtype), axis=(ry, rx), ), name="depthwise_conv2d", tag="depthwise_conv2d_nchw_oihw_dsp", ) -def _make_tensorization(padded_data, kernel): - _, padded_h, padded_w, in_channels = padded_data.shape - _, kernel_h, kernel_w, _ = kernel.shape +def _make_depthwise_conv2d_tensorization(padded_data, kernel): + _, _, padded_h, padded_w = padded_data.shape + _, _, kernel_h, kernel_w = kernel.shape + in_dtype = padded_data.dtype - suffix = "".join(random.choices(string.ascii_uppercase, k=8)) + suffix = _get_suffix() assert in_dtype == kernel.dtype - data_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="a", dtype=in_dtype) - kernel_slice = te.placeholder((kernel_h, kernel_w, in_channels), name="b", dtype=in_dtype) + data_slice = te.placeholder((kernel_h, kernel_w), name="a", dtype=in_dtype) + kernel_slice = te.placeholder((kernel_h, kernel_w), name="b", dtype=in_dtype) kh_i = te.reduce_axis((0, kernel_h), name="kh_i") kw_i = te.reduce_axis((0, kernel_w), name="kw_i") - kc_i = te.reduce_axis((0, in_channels), name="kc_i") output_slice = te.compute((1,), lambda k: te.sum( - data_slice[kh_i, kw_i, kc_i].astype("int32") - * kernel_slice[kh_i, kw_i, kc_i].astype("int32"), - axis=[kh_i, kw_i, kc_i], + data_slice[kh_i, kw_i].astype("int32") + * kernel_slice[kh_i, kw_i].astype("int32"), + axis=[kh_i, kw_i], ), name="c", ) - # TVM has a really strange bug where the outer reduction axis (kh_i) having length 1 causes the - # decl_buffer strides check to fail. height_stride is a dark magic workaround for this. - height_stride = in_channels * padded_w if kernel_h > 1 else in_channels data_buf = tir.decl_buffer( - data_slice.shape, data_slice.dtype, name="foofoomcbar", offset_factor=1, - strides=[height_stride, in_channels, 1], + data_slice.shape, data_slice.dtype, name="data", offset_factor=1, + strides=[padded_w, 1], ) kernel_buf = tir.decl_buffer( kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, - strides=[kernel_w * in_channels, in_channels, 1] + strides=[kernel_w, 1] ) output_buf = tir.decl_buffer( output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] ) - jump = (padded_w - kernel_w) * in_channels - tensordot_params = (in_dtype, kernel_h, jump, kernel_w * in_channels, suffix) + jump = padded_w - kernel_w + tensordot_params = (in_dtype, kernel_h, jump, kernel_w, suffix) intrin_tensordot = make_intrin_tensordot( output_slice.op, @@ -209,17 +264,18 @@ def _callback(operator): padded_data = output.op.input_tensors[0] kernel = output.op.input_tensors[1] - if tag == "conv2d_nhwc_ohwi_dsp": + if operator.tag == "conv2d_nhwc_ohwi_dsp": b_ax, y_ax, x_ax, co_ax = schedule[output].op.axis kh_ax, kw_ax, ci_ax = schedule[output].op.reduce_axis schedule[output].reorder(b_ax, y_ax, x_ax, co_ax, kh_ax, kw_ax, ci_ax) - intrin, code = _make_tensorization(padded_data, kernel) + intrin, code = _make_conv2d_tensorization(padded_data, kernel) - elif tag == "depthwise_conv2d_nchw_oihw_dsp" + elif operator.tag == "depthwise_conv2d_nchw_oihw_dsp": b_ax, y_ax, x_ax, co_ax = schedule[output].op.axis kh_ax, kw_ax = schedule[output].op.reduce_axis schedule[output].reorder(b_ax, co_ax, y_ax, x_ax, kh_ax, kw_ax) - intrin, code = _make_tensorization(padded_data, kernel) + intrin, code = _make_depthwise_conv2d_tensorization(padded_data, kernel) + print(code) schedule[output].tensorize(kh_ax, intrin) schedule[output].pragma(b_ax, "import_c", code) diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py index 9d8723ab216a..7999c354ee46 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py @@ -185,9 +185,9 @@ class TestDepthwiseConv2d_NHWC_HWOI_DSP(BasicDepthwiseConv2dTests): schedule_name = tvm.testing.parameter("depthwise_conv2d_nhwc_dsp.arm_cpu") -class TestDepthwiseConv2d_Tensordot_DSP(BasicDepthwiseConv2dTests): +class TestDepthwiseConv2d_Tensordot(BasicDepthwiseConv2dTests): data_shape, kernel_size, num_filter, strides, padding, dtype = tvm.testing.parameters( - ((1, 48, 48, 8), (3, 3), 8, (1, 1), 1, "int16"), + ((1, 8, 48, 48), (3, 3), 8, (1, 1), 1, "int32"), ) dilation = tvm.testing.parameter(1) data_layout = tvm.testing.parameter("NCHW") From 5def2e199196d5459454531415d27781fc9217f0 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Wed, 5 Oct 2022 01:35:43 -0700 Subject: [PATCH 09/18] Unit testing for depthwise_conv2d --- .../arm_cpu/mprofile/dsp/tensordot_conv2ds.py | 29 ++++++++++--------- .../strategy/arm_cpu/test_depthwise_conv2d.py | 19 +++++++++++- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py index 9b4b6de9193d..826fa117efe6 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py @@ -61,14 +61,18 @@ def _unpack_padding(padding): # We only care about tuples here - "VALID" and "SAME" padding will be converted by the importer. -def _pad_if_needed(data, pad_up, pad_left, pad_down, pad_right): +def _pad_if_needed(data, layout, pad_up, pad_left, pad_down, pad_right): if pad_up or pad_left or pad_down or pad_right: - return pad( - data, - [0, pad_up, pad_left, 0], - [0, pad_down, pad_right, 0], - name="padded_data" - ) + assert len(layout) == 4 + + # We want to pad the "H" and "W" columns, and their position depends on the layout + pad_before, pad_after = [0, 0, 0, 0], [0, 0, 0, 0] + pad_before[layout.index("H")] = pad_up + pad_before[layout.index("W")] = pad_left + pad_after[layout.index("H")] = pad_down + pad_after[layout.index("W")] = pad_right + return pad(data, pad_before, pad_after, name="padded_data") + else: return data @@ -107,7 +111,7 @@ def conv2d_nhwc_ohwi_dsp_compute(cfg, data, kernel, strides, padding, dilation, rx = te.reduce_axis((0, kernel_w), name="rx") rc = te.reduce_axis((0, in_channels), name="rc") - padded_data = _pad_if_needed(data, pad_up, pad_left, pad_down, pad_right) + padded_data = _pad_if_needed(data, "NHWC", pad_up, pad_left, pad_down, pad_right) return te.compute( (batch_size, output_h, output_w, output_channels), lambda n, y, x, c: te.sum( @@ -181,8 +185,6 @@ def depthwise_conv2d_nchw_oihw_dsp_compute(cfg, data, kernel, strides, padding, batch_size, in_channels, data_h, data_w = data.shape _, c_mul, kernel_h, kernel_w = kernel.shape output_channels = in_channels * c_mul - print(data.shape) - print(kernel.shape) assert kernel.shape[0] == in_channels output_h = _compute_output_dim(data_h, kernel_h, pad_up, pad_down, stride_h) @@ -193,10 +195,10 @@ def depthwise_conv2d_nchw_oihw_dsp_compute(cfg, data, kernel, strides, padding, ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") - padded_data = _pad_if_needed(data, pad_up, pad_left, pad_down, pad_right) + padded_data = _pad_if_needed(data, "NCHW", pad_up, pad_left, pad_down, pad_right) return te.compute( - (batch_size, output_h, output_w, output_channels), - lambda n, y, x, c: te.sum( + (batch_size, output_channels, output_h, output_w), + lambda n, c, y, x: te.sum( padded_data[ n, indexdiv(c, c_mul), y_offset + y * stride_h + ry, x_offset + x * stride_w + rx, ].astype(out_dtype) @@ -275,7 +277,6 @@ def _callback(operator): kh_ax, kw_ax = schedule[output].op.reduce_axis schedule[output].reorder(b_ax, co_ax, y_ax, x_ax, kh_ax, kw_ax) intrin, code = _make_depthwise_conv2d_tensorization(padded_data, kernel) - print(code) schedule[output].tensorize(kh_ax, intrin) schedule[output].pragma(b_ax, "import_c", code) diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py index 7999c354ee46..888eb591645e 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py @@ -187,12 +187,29 @@ class TestDepthwiseConv2d_NHWC_HWOI_DSP(BasicDepthwiseConv2dTests): class TestDepthwiseConv2d_Tensordot(BasicDepthwiseConv2dTests): data_shape, kernel_size, num_filter, strides, padding, dtype = tvm.testing.parameters( + # Currently, our schedule requires kernel_w be divisible by the number of simd lanes given + # its dtype. This means 3x3 and 5x5 kernels do not work on int16 or int8 for now. If you had + # to, you could hack around this by padding the data and kernel. ((1, 8, 48, 48), (3, 3), 8, (1, 1), 1, "int32"), + ((1, 16, 48, 48), (3, 3), 16, (2, 2), (1, 1, 0, 0), "int32"), + ((1, 32, 24, 24), (3, 3), 32, (1, 1), 1, "int32"), + ((1, 32, 24, 24), (3, 3), 32, (2, 2), (1, 1, 0, 0), "int32"), + ((1, 64, 12, 12), (3, 3), 64, (1, 1), 1, "int32"), + ((1, 64, 12, 12), (3, 3), 64, (2, 2), (1, 1, 0, 0), "int32"), + ((1, 128, 6, 6), (3, 3), 128, (1, 1), 1, "int32"), + ((1, 128, 6, 6), (3, 3), 128, (2, 2), (1, 1, 0, 0), "int32"), + ((1, 256, 3, 3), (3, 3), 256, (1, 1), 1, "int32"), + ((1, 64, 25, 5), (3, 3), 64, (1, 1), 1, "int32"), + ((1, 8, 24, 24), (5, 5), 8, (1, 1), 1, "int32"), + ((1, 8, 24, 24), (3, 5), 8, (1, 1), 1, "int32"), + + ((1, 8, 48, 48), (3, 2), 8, 1, 0, "int16"), + ((1, 8, 48, 48), (4, 4), 8, 1, 0, "int8"), ) dilation = tvm.testing.parameter(1) data_layout = tvm.testing.parameter("NCHW") kernel_layout = tvm.testing.parameter("OIHW") - schedule_name = tvm.testing.parameter("conv2d_nchw_oihw_dsp.arm_cpu") + schedule_name = tvm.testing.parameter("depthwise_conv2d_nchw_oihw_dsp.arm_cpu") if __name__ == "__main__": From 12898740d16cea973983954ce3ba92399ea7b0e9 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Wed, 5 Oct 2022 02:53:10 -0700 Subject: [PATCH 10/18] Linting and documentation --- python/tvm/relay/op/strategy/arm_cpu.py | 10 +- python/tvm/topi/arm_cpu/conv2d.py | 8 +- python/tvm/topi/arm_cpu/depthwise_conv2d.py | 3 +- .../mprofile/dsp/micro_kernel/tensordot.py | 78 +++----- .../arm_cpu/mprofile/dsp/tensordot_conv2ds.py | 183 +++++++++--------- python/tvm/topi/utils.py | 1 + .../strategy/arm_cpu/test_conv2d_nhwc.py | 45 +++-- .../strategy/arm_cpu/test_depthwise_conv2d.py | 2 +- 8 files changed, 159 insertions(+), 171 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 47e0921c45a7..0b036c0abd4c 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -160,10 +160,10 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): ) elif layout == "NHWC": if ( - target.features.has_dsp - and dilation_w == dilation_h == 1 - and kernel_layout == "OHWI" - ): + target.features.has_dsp + and dilation_w == dilation_h == 1 + and kernel_layout == "OHWI" + ): strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_ohwi_dsp), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_ohwi_dsp), @@ -210,7 +210,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): if layout == "NCHW": assert kernel_layout == "OIHW" or re.match(r"OIHW\d*o", kernel_layout) if kernel_layout == "OIHW": - if (target.features.has_dsp and dilation_w == dilation_h == 1): + if target.features.has_dsp and dilation_w == dilation_h == 1: strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw_oihw_dsp), wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw_oihw_dsp), diff --git a/python/tvm/topi/arm_cpu/conv2d.py b/python/tvm/topi/arm_cpu/conv2d.py index fcb71d0513d2..bb29de8fa27b 100644 --- a/python/tvm/topi/arm_cpu/conv2d.py +++ b/python/tvm/topi/arm_cpu/conv2d.py @@ -42,6 +42,7 @@ tensordot_conv2ds_schedule, ) + @autotvm.register_topi_compute("conv2d_nchw_spatial_pack.arm_cpu") def conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype): """Compute conv2d with NCHW layout""" @@ -525,12 +526,11 @@ def schedule_conv2d_nhwc_dsp(cfg, outs): @autotvm.register_topi_compute("conv2d_nhwc_ohwi_dsp.arm_cpu") def conv2d_nhwc_ohwi_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): - return conv2d_nhwc_ohwi_dsp_compute( - cfg, data, kernel, strides, padding, dilation, out_dtype - ) + """Compute conv2d_nhwc_ohwi with v7e-m DSP instructions and the tensordot kernel.""" + return conv2d_nhwc_ohwi_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype) @autotvm.register_topi_schedule("conv2d_nhwc_ohwi_dsp.arm_cpu") def schedule_conv2d_nhwc_ohwi_dsp(cfg, outs): + """Create schedule for conv2d_nhwc_ohwi.""" return tensordot_conv2ds_schedule(cfg, outs) - diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index d6cc11e34cd6..58cd11e8cc09 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -37,7 +37,6 @@ ) - @autotvm.register_topi_compute("depthwise_conv2d_nchw.arm_cpu") def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype): """Compute depthwise_conv2d with NCHW layout""" @@ -726,6 +725,7 @@ def schedule_depthwise_conv2d_nhwc_dsp(cfg, outs): @autotvm.register_topi_compute("depthwise_conv2d_nchw_oihw_dsp.arm_cpu") def depthwise_conv2d_nchw_oihw_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute depthwise_conv2d_nchw_oihw with v7e-m DSP instructions and the tensordot kernel.""" return depthwise_conv2d_nchw_oihw_dsp_compute( cfg, data, kernel, strides, padding, dilation, out_dtype ) @@ -733,4 +733,5 @@ def depthwise_conv2d_nchw_oihw_dsp(cfg, data, kernel, strides, padding, dilation @autotvm.register_topi_schedule("depthwise_conv2d_nchw_oihw_dsp.arm_cpu") def schedule_depthwise_conv2d_nchw_oihw_dsp(cfg, outs): + """Create schedule for depthwise_conv2d_nchw_oihw.""" return tensordot_conv2ds_schedule(cfg, outs) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py index 8e65b6091903..5305b90e6323 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py @@ -14,11 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Computes a "jumpy tensordot" operator, which can/should be used to -tensorize ANY aritrarily grouped conv2D, provided the data and kernel -layouts are the optimal ones. When groups=1, the optimal data layout -is NHWC and kernel layout is OHWI. When this is a depthwise convolution, -the optimal data layout is NCHW and kernel layout is OIHW.""" +"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators +including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts +are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When +this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW.""" import textwrap @@ -26,62 +25,33 @@ from .common import num_simd_lanes_per_word + def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix): - """Gets the C function name of the tensorized function.""" + """Gets the C function name of the tensordot function.""" return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}" -def make_intrin_tensordot(operator, binds, tensordot_params): - #in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params - - def intrin_func(ins, outs): - builder = tir.ir_builder.create() - builder.emit( - tir.call_extern( - "int32", - _get_func_name(*tensordot_params), - outs[0].access_ptr("w"), - ins[0].access_ptr("r"), - ins[1].access_ptr("r"), - ) - ) - return builder.get() - - return te.decl_tensor_intrin( - operator, - intrin_func, - binds=binds, - ) - - -def intrin_depthwise_conv2d_tensordot(in_dtype, tensor_w, kernel_h, kernel_w, suffix): - data_slice = te.placeholder((kernel_h, kernel_w), name="a", dtype=in_dtype) - kernel_slice = te.placeholder((kernel_h, kernel_w), name="b", dtype=in_dtype) +def make_intrin_tensordot(slices, strides, tensordot_params): + """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here + (as multiple schedules use tensordot and each must build the intrinstic differently) but we can + build part here to simplify the code.""" - kh_i = te.reduce_axis((0, kernel_h), name="kh_i") - kw_i = te.reduce_axis((0, kernel_w), name="kw_i") - - output_slice = te.compute((1,), - lambda k: te.sum( - data_slice[kh_i, kw_i].astype("int32") - * kernel_slice[kh_i, kw_i].astype("int32"), - axis=[kh_i, kw_i], - ), - name="c", - ) + # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params + data, kernel, output = slices + data_strides, kernel_strides = strides data_buf = tir.decl_buffer( - data_slice.shape, - data_slice.dtype, - name="data", - offset_factor=1, - strides=[tensor_w, 1], + data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides ) kernel_buf = tir.decl_buffer( - kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, strides=[kernel_w, 1] + kernel.shape, + kernel.dtype, + name="kernel", + offset_factor=1, + strides=kernel_strides, ) output_buf = tir.decl_buffer( - output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] + output.shape, output.dtype, name="output", offset_factor=1, strides=[1] ) def intrin_func(ins, outs): @@ -89,7 +59,7 @@ def intrin_func(ins, outs): builder.emit( tir.call_extern( "int32", - _get_func_name(in_dtype, tensor_w, channels, kernel_h, kernel_w, suffix), + _get_func_name(*tensordot_params), outs[0].access_ptr("w"), ins[0].access_ptr("r"), ins[1].access_ptr("r"), @@ -98,13 +68,15 @@ def intrin_func(ins, outs): return builder.get() return te.decl_tensor_intrin( - output_slice.op, + output.op, intrin_func, - binds=binings, + binds={data: data_buf, kernel: kernel_buf, output: output_buf}, ) def tensordot_impl(in_dtype, tensor_h, jump, tensor_w, suffix): + """Generates C code for tensordot. The int8 and int16 versions have Arm v7e-m DSP assembly.""" + assert in_dtype in ["int8", "int16", "int32"] simd_lanes = num_simd_lanes_per_word(in_dtype) assert tensor_w % simd_lanes == 0 diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py index 826fa117efe6..a57c306b2f7c 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py @@ -14,13 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, no-value-for-parameter -"""Direct implementation of conv2d.""" +"""Implementations of several conv2d variations, all tensorized using tensordot and optimized for +Cortex-M DSP. Currently contains a standard conv2d and depthwise conv2d implementation, but could be +extended to add a grouped conv2d operator. Due to the way we tensorize, this schedule ONLY works +when the data and kernel layouts are NCHWxc and OIHWxi respectively, where x is the number of +input channels divided by the number of groups.""" import random import string +from typing import Union, Tuple -from tvm import te, tir +from tvm import te from tvm.tir import indexdiv, indexmod from tvm.topi.utils import traverse_inline from tvm.topi.nn.pad import pad @@ -31,27 +35,24 @@ ) -# Dilation prevents us from using DSP instructions, so this schedule can't work (aside from the -# niche case where dilation_h == stride_h and dilation_w == stride_w, which is rare enough we -# probably don't need to support it). -def _check_no_dilation(dilation): - assert isinstance(dilation, int) or len(dilation) == 2 - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - assert dilation_h == dilation_w == 1 +def _unpack_2d_argument(argument: Union[int, Tuple]) -> Tuple: + if isinstance(argument, int): + return (argument, argument) + assert len(argument) == 2 + return argument -def _unpack_strides(strides): - assert isinstance(strides, int) or len(strides) == 2 - if isinstance(strides, int): - return (strides, strides) - else: - return strides +def _check_no_dilation(dilation: Union[int, Tuple]) -> None: + """Takes a dilation argument as an integer or tuple, and makes sure both dimensions are 1. + Dilation prevents us from using DSP instructions, so this schedule can't work (aside from the + niche case where dilation_h == stride_h and dilation_w == stride_w, which is rare enough we + probably don't need to support it).""" + dilation_h, dilation_w = _unpack_2d_argument(dilation) + assert dilation_h == dilation_w == 1 -def _unpack_padding(padding): + +def _unpack_padding(padding: Tuple) -> Tuple: assert isinstance(padding, tuple) if len(padding) == 2: (pad_up, pad_down), (pad_left, pad_right) = padding @@ -60,41 +61,45 @@ def _unpack_padding(padding): return pad_up, pad_left, pad_down, pad_right -# We only care about tuples here - "VALID" and "SAME" padding will be converted by the importer. -def _pad_if_needed(data, layout, pad_up, pad_left, pad_down, pad_right): - if pad_up or pad_left or pad_down or pad_right: - assert len(layout) == 4 - - # We want to pad the "H" and "W" columns, and their position depends on the layout - pad_before, pad_after = [0, 0, 0, 0], [0, 0, 0, 0] - pad_before[layout.index("H")] = pad_up - pad_before[layout.index("W")] = pad_left - pad_after[layout.index("H")] = pad_down - pad_after[layout.index("W")] = pad_right - return pad(data, pad_before, pad_after, name="padded_data") +def _pad_if_needed(data: te.tensor.Tensor, layout: str, padding: Tuple) -> te.tensor.Tensor: + """Performs padding on a te.tensor.Tensor object if necessary. If padding = (0, 0, 0, 0), the + input tensor is returned unmodified. We only care about tuples here - "VALID" and "SAME" padding + will be converted by the importer TFLite importer if present.""" - else: + pad_up, pad_left, pad_down, pad_right = padding + if not any(padding): return data + # We want to pad the "H" and "W" columns, and their position depends on the layout + pad_before, pad_after = [0, 0, 0, 0], [0, 0, 0, 0] + pad_before[layout.index("H")] = pad_up + pad_before[layout.index("W")] = pad_left + pad_after[layout.index("H")] = pad_down + pad_after[layout.index("W")] = pad_right + return pad(data, pad_before, pad_after, name="padded_data") -def _compute_output_dim(data_dim, kernel_dim, pad_before, pad_after, stride): + +def _compute_output_dim(data_dim, kernel_dim, pad_before, pad_after, stride) -> int: return (data_dim - kernel_dim + pad_before + pad_after) // stride + 1 -# Offsets to "prefer" the bottom right corner. This is done to match TensorFlow's convention, but -# does NOT match the other TVM schedules. We violate this convention because it improves accuracy on -# models imported from TensorFlow. -def _compute_offset(data_dim, kernel_dim, pad_before, pad_after, stride): +def _compute_offset(data_dim, kernel_dim, pad_before, pad_after, stride) -> int: + """Computes offsets to "prefer" the bottom right corner. This is done to match TensorFlow's + convention, but it does NOT match the other TVM schedules. We violate this convention because it + improves accuracy on models imported from TensorFlow.""" return (data_dim - kernel_dim + pad_before + pad_after) % stride -# Prevents re-definition of C functions -def _get_suffix(): +def _get_suffix() -> str: + """Returns a random eight-character string to append to C function names. Prevents accidental + re-definition of functions if the same operator appears twice in a Relay graph.""" return "".join(random.choices(string.ascii_uppercase, k=8)) -def conv2d_nhwc_ohwi_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): - stride_h, stride_w = _unpack_strides(strides) +def conv2d_nhwc_ohwi_dsp_compute(_cfg, data, kernel, strides, padding, dilation, out_dtype): + """Standard conv2d schedule that can be tensorized using tensordot.""" + + stride_h, stride_w = _unpack_2d_argument(strides) pad_up, pad_left, pad_down, pad_right = _unpack_padding(padding) _check_no_dilation(dilation) @@ -107,19 +112,19 @@ def conv2d_nhwc_ohwi_dsp_compute(cfg, data, kernel, strides, padding, dilation, y_offset = _compute_offset(data_h, kernel_h, pad_up, pad_down, stride_h) x_offset = _compute_offset(data_w, kernel_w, pad_left, pad_right, stride_w) - ry = te.reduce_axis((0, kernel_h), name="ry") - rx = te.reduce_axis((0, kernel_w), name="rx") - rc = te.reduce_axis((0, in_channels), name="rc") + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + kc_i = te.reduce_axis((0, in_channels), name="rc") - padded_data = _pad_if_needed(data, "NHWC", pad_up, pad_left, pad_down, pad_right) + padded_data = _pad_if_needed(data, "NHWC", (pad_up, pad_left, pad_down, pad_right)) return te.compute( (batch_size, output_h, output_w, output_channels), lambda n, y, x, c: te.sum( padded_data[ - n, y_offset + y * stride_h + ry, x_offset + x * stride_w + rx, rc + n, y_offset + y * stride_h + kh_i, x_offset + x * stride_w + kw_i, kc_i ].astype(out_dtype) - * kernel[c, ry, rx, rc].astype(out_dtype), - axis=(ry, rx, rc), + * kernel[c, kh_i, kw_i, kc_i].astype(out_dtype), + axis=(kh_i, kw_i, kc_i), ), name="conv2d", tag="conv2d_nhwc_ohwi_dsp", @@ -127,7 +132,7 @@ def conv2d_nhwc_ohwi_dsp_compute(cfg, data, kernel, strides, padding, dilation, def _make_conv2d_tensorization(padded_data, kernel): - _, padded_h, padded_w, in_channels = padded_data.shape + _, _, padded_w, in_channels = padded_data.shape _, kernel_h, kernel_w, _ = kernel.shape in_dtype = padded_data.dtype suffix = _get_suffix() @@ -140,7 +145,8 @@ def _make_conv2d_tensorization(padded_data, kernel): kw_i = te.reduce_axis((0, kernel_w), name="kw_i") kc_i = te.reduce_axis((0, in_channels), name="kc_i") - output_slice = te.compute((1,), + output_slice = te.compute( + (1,), lambda k: te.sum( data_slice[kh_i, kw_i, kc_i].astype("int32") * kernel_slice[kh_i, kw_i, kc_i].astype("int32"), @@ -152,33 +158,24 @@ def _make_conv2d_tensorization(padded_data, kernel): # TVM has a really strange bug where the outer reduction axis (kh_i) having length 1 causes the # decl_buffer strides check to fail. height_stride is a dark magic workaround for this. height_stride = in_channels * padded_w if kernel_h > 1 else in_channels - data_buf = tir.decl_buffer( - data_slice.shape, data_slice.dtype, name="data", offset_factor=1, - strides=[height_stride, in_channels, 1], - ) - kernel_buf = tir.decl_buffer( - kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, - strides=[kernel_w * in_channels, in_channels, 1] - ) - output_buf = tir.decl_buffer( - output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] - ) - jump = (padded_w - kernel_w) * in_channels tensordot_params = (in_dtype, kernel_h, jump, kernel_w * in_channels, suffix) - intrin_tensordot = make_intrin_tensordot( - output_slice.op, - {data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf}, - tensordot_params + (data_slice, kernel_slice, output_slice), + ([height_stride, in_channels, 1], [kernel_w * in_channels, in_channels, 1]), + tensordot_params, ) tensordot_code = tensordot_impl(*tensordot_params) return (intrin_tensordot, tensordot_code) -def depthwise_conv2d_nchw_oihw_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): - stride_h, stride_w = _unpack_strides(strides) +def depthwise_conv2d_nchw_oihw_dsp_compute( + _cfg, data, kernel, strides, padding, dilation, out_dtype +): + """Depthwise conv2d schedule that can be tensorized using tensordot.""" + + stride_h, stride_w = _unpack_2d_argument(strides) pad_up, pad_left, pad_down, pad_right = _unpack_padding(padding) _check_no_dilation(dilation) @@ -192,25 +189,29 @@ def depthwise_conv2d_nchw_oihw_dsp_compute(cfg, data, kernel, strides, padding, y_offset = _compute_offset(data_h, kernel_h, pad_up, pad_down, stride_h) x_offset = _compute_offset(data_w, kernel_w, pad_left, pad_right, stride_w) - ry = te.reduce_axis((0, kernel_h), name="ry") - rx = te.reduce_axis((0, kernel_w), name="rx") + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") - padded_data = _pad_if_needed(data, "NCHW", pad_up, pad_left, pad_down, pad_right) + padded_data = _pad_if_needed(data, "NCHW", (pad_up, pad_left, pad_down, pad_right)) return te.compute( (batch_size, output_channels, output_h, output_w), lambda n, c, y, x: te.sum( padded_data[ - n, indexdiv(c, c_mul), y_offset + y * stride_h + ry, x_offset + x * stride_w + rx, + n, + indexdiv(c, c_mul), + y_offset + y * stride_h + kh_i, + x_offset + x * stride_w + kw_i, ].astype(out_dtype) - * kernel[indexdiv(c, c_mul), indexmod(c, c_mul), ry, rx].astype(out_dtype), - axis=(ry, rx), + * kernel[indexdiv(c, c_mul), indexmod(c, c_mul), kh_i, kw_i].astype(out_dtype), + axis=(kh_i, kw_i), ), name="depthwise_conv2d", tag="depthwise_conv2d_nchw_oihw_dsp", ) + def _make_depthwise_conv2d_tensorization(padded_data, kernel): - _, _, padded_h, padded_w = padded_data.shape + _, _, _, padded_w = padded_data.shape _, _, kernel_h, kernel_w = kernel.shape in_dtype = padded_data.dtype @@ -223,41 +224,31 @@ def _make_depthwise_conv2d_tensorization(padded_data, kernel): kh_i = te.reduce_axis((0, kernel_h), name="kh_i") kw_i = te.reduce_axis((0, kernel_w), name="kw_i") - output_slice = te.compute((1,), + output_slice = te.compute( + (1,), lambda k: te.sum( - data_slice[kh_i, kw_i].astype("int32") - * kernel_slice[kh_i, kw_i].astype("int32"), + data_slice[kh_i, kw_i].astype("int32") * kernel_slice[kh_i, kw_i].astype("int32"), axis=[kh_i, kw_i], ), name="c", ) - data_buf = tir.decl_buffer( - data_slice.shape, data_slice.dtype, name="data", offset_factor=1, - strides=[padded_w, 1], - ) - kernel_buf = tir.decl_buffer( - kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, - strides=[kernel_w, 1] - ) - output_buf = tir.decl_buffer( - output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] - ) - jump = padded_w - kernel_w tensordot_params = (in_dtype, kernel_h, jump, kernel_w, suffix) - intrin_tensordot = make_intrin_tensordot( - output_slice.op, - {data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf}, - tensordot_params + (data_slice, kernel_slice, output_slice), + ([padded_w, 1], [kernel_w, 1]), + tensordot_params, ) tensordot_code = tensordot_impl(*tensordot_params) return (intrin_tensordot, tensordot_code) -def tensordot_conv2ds_schedule(cfg, outs): +def tensordot_conv2ds_schedule(_cfg, outs): + """Schedule function using v7e-m DSP instructions for all the conv2d operators in this file. We + use one schedule function for them all, because they are tensorized with the same kernel.""" + schedule = te.create_schedule([x.op for x in outs]) def _callback(operator): diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index a8f215dce8bf..f6ca03d32742 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -457,6 +457,7 @@ def change_constant_shape(src, src_layout, dst_layout): reshaped = np.transpose(src.data.numpy(), axis_order) return relay.Constant(tvm.nd.array(reshaped)) + def within_index(b, e, s, i): """Return a boolean value that indicates if i is within the given index. diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py b/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py index 53f9481c41b2..d829425bc6f7 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py @@ -24,6 +24,7 @@ from tvm.micro.testing.aot_test_utils import AOT_CORSTONE300_RUNNER from tvm.topi.utils import change_constant_shape + class BasicConv2dTests: @tvm.testing.requires_corstone300 def test_conv2d( @@ -152,7 +153,6 @@ class TestConv2d_Tensordot(BasicConv2dTests): # ((1, 32, 32, 1), (3, 3), 12, 1, 0), # ((1, 32, 10, 3), (3, 3), 16, 1, 0), # ((1, 96, 96, 3), (3, 3), 8, (2, 2), (0, 0, 1, 1)), - # Disabled because while our schedule matches TensorFlow's behavior, it does NOT # match the x86 schedule behavior (which is different). These schedules have either: # (in_height + pad_up + pad_down - kernel_h) % stride_h > 0 OR @@ -164,16 +164,39 @@ class TestConv2d_Tensordot(BasicConv2dTests): # ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, ), # ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, ), # ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1)), - - ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0)), - ((1, 32, 32, 16), (3, 3), 16, 1, 0, ), - ((1, 32, 32, 16), (3, 3), 16, 1, 0, ), - ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1)), - ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1)), - ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2)), - ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2)), - ((1, 32, 32, 16), (3, 3), 16, 1, 0, ), - ((1, 16, 16, 32), (1, 1), 64, 1, 0, ), + ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0)), + ( + (1, 32, 32, 16), + (3, 3), + 16, + 1, + 0, + ), + ( + (1, 32, 32, 16), + (3, 3), + 16, + 1, + 0, + ), + ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1)), + ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1)), + ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2)), + ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2)), + ( + (1, 32, 32, 16), + (3, 3), + 16, + 1, + 0, + ), + ( + (1, 16, 16, 32), + (1, 1), + 64, + 1, + 0, + ), ) dilation = tvm.testing.parameter(1) dtype = tvm.testing.parameter("int8", "int16", "int32") diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py index 888eb591645e..36059c798cbb 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py @@ -202,7 +202,7 @@ class TestDepthwiseConv2d_Tensordot(BasicDepthwiseConv2dTests): ((1, 64, 25, 5), (3, 3), 64, (1, 1), 1, "int32"), ((1, 8, 24, 24), (5, 5), 8, (1, 1), 1, "int32"), ((1, 8, 24, 24), (3, 5), 8, (1, 1), 1, "int32"), - + # These "evenly divisible" kernels work on smaller dtypes. ((1, 8, 48, 48), (3, 2), 8, 1, 0, "int16"), ((1, 8, 48, 48), (4, 4), 8, 1, 0, "int8"), ) From deaebc72d0732696fb2688dd33dd85da2ddc6353 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Wed, 5 Oct 2022 05:29:31 -0700 Subject: [PATCH 11/18] Enforce SIMD alignment in strategy --- python/tvm/relay/op/strategy/arm_cpu.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 0b036c0abd4c..92928edd90a6 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -16,6 +16,7 @@ # under the License. """Definition of ARM CPU operator strategy.""" import logging +import math # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import import re @@ -71,6 +72,15 @@ def schedule_pool_arm_cpu(attrs, outs, target): return topi.generic.schedule_pool(outs, layout) +def _is_simd_aligned(dtype, dimensions): + size = math.prod(dimensions) + return ( + (dtype == "int8" and size % 4 == 0) + or (dtype == "int16" and size % 2 == 0) + or (dtype == "int32") + ) + + @conv2d_strategy.register("arm_cpu") def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): """conv2d arm cpu strategy""" @@ -163,6 +173,9 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): target.features.has_dsp and dilation_w == dilation_h == 1 and kernel_layout == "OHWI" + # Check SIMD alignment + and _is_simd_aligned(data.dtype, data.shape[2:]) + and _is_simd_aligned(kernel.dtype, kernel.shape[1:]) ): strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_ohwi_dsp), @@ -210,7 +223,12 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): if layout == "NCHW": assert kernel_layout == "OIHW" or re.match(r"OIHW\d*o", kernel_layout) if kernel_layout == "OIHW": - if target.features.has_dsp and dilation_w == dilation_h == 1: + if ( + target.features.has_dsp + and dilation_w == dilation_h == 1 + and _is_simd_aligned(data.dtype, data.shape[3:]) + and _is_simd_aligned(kernel.dtype, kernel.shape[2:]) + ): strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw_oihw_dsp), wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nchw_oihw_dsp), From 981b1bdb4780b27cc673ba9cfcd29bb56322ba54 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Wed, 5 Oct 2022 05:37:39 -0700 Subject: [PATCH 12/18] Prevent black from butchering our formatting --- .../strategy/arm_cpu/test_conv2d_nhwc.py | 32 +++---------------- 1 file changed, 4 insertions(+), 28 deletions(-) diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py b/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py index d829425bc6f7..e2867edf87dc 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py @@ -165,38 +165,14 @@ class TestConv2d_Tensordot(BasicConv2dTests): # ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, ), # ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1)), ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0)), - ( - (1, 32, 32, 16), - (3, 3), - 16, - 1, - 0, - ), - ( - (1, 32, 32, 16), - (3, 3), - 16, - 1, - 0, - ), + ((1, 32, 32, 16), (3, 3), 16, 1, 0), + ((1, 32, 32, 16), (3, 3), 16, 1, 0), ((1, 49, 10, 1), (10, 4), 64, (2, 2), (4, 1, 5, 1)), ((1, 16, 16, 8), (3, 3), 16, 2, (0, 0, 1, 1)), ((1, 16, 16, 8), (3, 3), 16, 2, (1, 1, 2, 2)), ((1, 16, 16, 8), (5, 5), 16, 2, (3, 3, 2, 2)), - ( - (1, 32, 32, 16), - (3, 3), - 16, - 1, - 0, - ), - ( - (1, 16, 16, 32), - (1, 1), - 64, - 1, - 0, - ), + ((1, 32, 32, 16), (3, 3), 16, 1, 0), + ((1, 16, 16, 32), (1, 1), 64, 1, 0), ) dilation = tvm.testing.parameter(1) dtype = tvm.testing.parameter("int8", "int16", "int32") From 1a21f6cbd4bba16bb8d9af87418ed92a4c934448 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Wed, 5 Oct 2022 13:45:35 -0700 Subject: [PATCH 13/18] Address code review comments --- .../mprofile/dsp/micro_kernel/tensordot.py | 25 ++++++++++++++++--- .../arm_cpu/mprofile/dsp/tensordot_conv2ds.py | 20 +++++++++++---- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py index 5305b90e6323..0fdffc06cf4f 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py @@ -74,10 +74,26 @@ def intrin_func(ins, outs): ) -def tensordot_impl(in_dtype, tensor_h, jump, tensor_w, suffix): - """Generates C code for tensordot. The int8 and int16 versions have Arm v7e-m DSP assembly.""" +def tensordot_impl(in_dtype: str, tensor_h: int, jump: int, tensor_w: int, suffix: str) -> str: + """Generates C code for taking the dot products of two `tensor_h` * `tensor_w` tensors. Also has + a `jump` argument that advances the pointer of one tensor by that many words after each row. The + `jump` and `tensor_w` values must be word-aligned for the input data type, as non-word-aligned + memory access is slow on the Cortex-M series. Depending on the input datatype, the code may + contain DSP instructions for Arm v7e-m. C code contains DSP instructions for Arm v7e-m. See + the below pseudocode for reference: + + tensordot(out_ptr, dat_ptr, ker_ptr) { + sum = 0; + for (i = 0; i < tensor_h; i++) { + for (j = 0; j < tensor_w; j++) { + sum += (*dat_ptr++) * (*ker_ptr++); + } + dat_ptr += jump; + } + *out_ptr = sum; + } + """ - assert in_dtype in ["int8", "int16", "int32"] simd_lanes = num_simd_lanes_per_word(in_dtype) assert tensor_w % simd_lanes == 0 assert jump % simd_lanes == 0 @@ -101,6 +117,9 @@ def tensordot_impl(in_dtype, tensor_h, jump, tensor_w, suffix): // Compiles to a single MAC instruction sum += tensor_batch * kernel_batch;""" + else: + raise ValueError(f"No tensordot implementation exists for dtype '{in_dtype}'!") + function_name = _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix) return textwrap.dedent( ( diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py index a57c306b2f7c..3579daf029cb 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py @@ -79,11 +79,18 @@ def _pad_if_needed(data: te.tensor.Tensor, layout: str, padding: Tuple) -> te.te return pad(data, pad_before, pad_after, name="padded_data") -def _compute_output_dim(data_dim, kernel_dim, pad_before, pad_after, stride) -> int: - return (data_dim - kernel_dim + pad_before + pad_after) // stride + 1 - - -def _compute_offset(data_dim, kernel_dim, pad_before, pad_after, stride) -> int: +def _compute_output_dim( + data_dim: int, kernel_dim: int, pad_before: int, pad_after: int, stride: int +) -> int: + """Computes an output dimension of a convolution, given the data dimension, kernel dimension, + padding, and stride along that axis. Note that when stride > 1, this division will often not + be perfectly even.""" + return (data_dim + pad_before + pad_after - kernel_dim) // stride + 1 + + +def _compute_offset( + data_dim: int, kernel_dim: int, pad_before: int, pad_after: int, stride: int +) -> int: """Computes offsets to "prefer" the bottom right corner. This is done to match TensorFlow's convention, but it does NOT match the other TVM schedules. We violate this convention because it improves accuracy on models imported from TensorFlow.""" @@ -269,6 +276,9 @@ def _callback(operator): schedule[output].reorder(b_ax, co_ax, y_ax, x_ax, kh_ax, kw_ax) intrin, code = _make_depthwise_conv2d_tensorization(padded_data, kernel) + else: + raise ValueError(f"Cannot tensorize {operator.tag} with tensordot!") + schedule[output].tensorize(kh_ax, intrin) schedule[output].pragma(b_ax, "import_c", code) From 5cebb6605da534aa600acd376cf3bc1d30b60b10 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Thu, 6 Oct 2022 11:57:32 -0700 Subject: [PATCH 14/18] Fix alignment strategy bug --- python/tvm/relay/op/strategy/arm_cpu.py | 29 ++++++++++++++++++++----- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 92928edd90a6..2abc29a744d8 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -72,8 +72,23 @@ def schedule_pool_arm_cpu(attrs, outs, target): return topi.generic.schedule_pool(outs, layout) -def _is_simd_aligned(dtype, dimensions): - size = math.prod(dimensions) +def _get_padding_width(padding): + assert isinstance(padding, tuple) + if len(padding) == 2: + _, (pad_left, pad_right) = padding + else: + _pad_up, pad_left, _pad_down, pad_right = padding + return pad_left + pad_right + + +def _is_simd_aligned(dtype, dimensions, padding=None): + if padding: + assert len(dimensions) == len(padding) + padded_dims = (sum(x) for x in zip(dimensions, padding)) + else: + padded_dims = dimensions + + size = math.prod(padded_dims) return ( (dtype == "int8" and size % 4 == 0) or (dtype == "int16" and size % 2 == 0) @@ -169,13 +184,14 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): name="conv2d_hwcn.generic", ) elif layout == "NHWC": + data_width_padding = _get_padding_width(padding) if ( target.features.has_dsp and dilation_w == dilation_h == 1 and kernel_layout == "OHWI" # Check SIMD alignment - and _is_simd_aligned(data.dtype, data.shape[2:]) - and _is_simd_aligned(kernel.dtype, kernel.shape[1:]) + and _is_simd_aligned(data.dtype, data.shape[2:], padding=(data_width_padding, 0)) + and _is_simd_aligned(kernel.dtype, kernel.shape[2:]) ): strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_ohwi_dsp), @@ -223,11 +239,12 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): if layout == "NCHW": assert kernel_layout == "OIHW" or re.match(r"OIHW\d*o", kernel_layout) if kernel_layout == "OIHW": + data_width_padding = _get_padding_width(padding) if ( target.features.has_dsp and dilation_w == dilation_h == 1 - and _is_simd_aligned(data.dtype, data.shape[3:]) - and _is_simd_aligned(kernel.dtype, kernel.shape[2:]) + and _is_simd_aligned(data.dtype, data.shape[3:], padding=(data_width_padding,)) + and _is_simd_aligned(kernel.dtype, kernel.shape[3:]) ): strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nchw_oihw_dsp), From cf18071bd2b24b281dc6188f9c8816a6a2910ae8 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Thu, 6 Oct 2022 12:12:12 -0700 Subject: [PATCH 15/18] Fix linting --- python/tvm/relay/op/strategy/arm_cpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 2abc29a744d8..23826a718cea 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -77,7 +77,7 @@ def _get_padding_width(padding): if len(padding) == 2: _, (pad_left, pad_right) = padding else: - _pad_up, pad_left, _pad_down, pad_right = padding + _, pad_left, _, pad_right = padding return pad_left + pad_right From b7d5b96ae07ed1cbfd4f0b81972bb4b393836e94 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Fri, 7 Oct 2022 04:52:44 -0700 Subject: [PATCH 16/18] Remove unconventional offset behavior --- .../arm_cpu/mprofile/dsp/tensordot_conv2ds.py | 15 +-------------- .../relay/strategy/arm_cpu/test_conv2d_nhwc.py | 18 +++++++----------- 2 files changed, 8 insertions(+), 25 deletions(-) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py index 3579daf029cb..5b493ba07030 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py @@ -88,15 +88,6 @@ def _compute_output_dim( return (data_dim + pad_before + pad_after - kernel_dim) // stride + 1 -def _compute_offset( - data_dim: int, kernel_dim: int, pad_before: int, pad_after: int, stride: int -) -> int: - """Computes offsets to "prefer" the bottom right corner. This is done to match TensorFlow's - convention, but it does NOT match the other TVM schedules. We violate this convention because it - improves accuracy on models imported from TensorFlow.""" - return (data_dim - kernel_dim + pad_before + pad_after) % stride - - def _get_suffix() -> str: """Returns a random eight-character string to append to C function names. Prevents accidental re-definition of functions if the same operator appears twice in a Relay graph.""" @@ -116,8 +107,6 @@ def conv2d_nhwc_ohwi_dsp_compute(_cfg, data, kernel, strides, padding, dilation, output_h = _compute_output_dim(data_h, kernel_h, pad_up, pad_down, stride_h) output_w = _compute_output_dim(data_w, kernel_w, pad_left, pad_right, stride_w) - y_offset = _compute_offset(data_h, kernel_h, pad_up, pad_down, stride_h) - x_offset = _compute_offset(data_w, kernel_w, pad_left, pad_right, stride_w) kh_i = te.reduce_axis((0, kernel_h), name="kh_i") kw_i = te.reduce_axis((0, kernel_w), name="kw_i") @@ -127,9 +116,7 @@ def conv2d_nhwc_ohwi_dsp_compute(_cfg, data, kernel, strides, padding, dilation, return te.compute( (batch_size, output_h, output_w, output_channels), lambda n, y, x, c: te.sum( - padded_data[ - n, y_offset + y * stride_h + kh_i, x_offset + x * stride_w + kw_i, kc_i - ].astype(out_dtype) + padded_data[n, y * stride_h + kh_i, x * stride_w + kw_i, kc_i].astype(out_dtype) * kernel[c, kh_i, kw_i, kc_i].astype(out_dtype), axis=(kh_i, kw_i, kc_i), ), diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py b/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py index e2867edf87dc..f5de3b51b67d 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d_nhwc.py @@ -153,17 +153,13 @@ class TestConv2d_Tensordot(BasicConv2dTests): # ((1, 32, 32, 1), (3, 3), 12, 1, 0), # ((1, 32, 10, 3), (3, 3), 16, 1, 0), # ((1, 96, 96, 3), (3, 3), 8, (2, 2), (0, 0, 1, 1)), - # Disabled because while our schedule matches TensorFlow's behavior, it does NOT - # match the x86 schedule behavior (which is different). These schedules have either: - # (in_height + pad_up + pad_down - kernel_h) % stride_h > 0 OR - # (in_width + pad_left + pad_right - kernel_w) % stride_w > 0 - # ((4, 16, 16, 8), (5, 5), 8, 2, (0, 4, 3, 0)), - # ((4, 16, 16, 8), (5, 5), 16, 2, (0, 4, 4, 0)), - # ((4, 16, 16, 8), (5, 5), 8, 2, 0, ), - # ((4, 16, 16, 8), (5, 5), 16, 2, 0, ), - # ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, ), - # ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0, ), - # ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1)), + ((4, 16, 16, 8), (5, 5), 8, 2, (0, 3, 3, 0)), + ((4, 16, 16, 8), (5, 5), 16, 2, (0, 3, 3, 0)), + ((4, 16, 16, 8), (5, 5), 8, 2, 0), + ((4, 16, 16, 8), (5, 5), 16, 2, 0), + ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0), + ((1, 16, 16, 32), (1, 1), 64, (2, 2), 0), + ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1)), ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0)), ((1, 32, 32, 16), (3, 3), 16, 1, 0), ((1, 32, 32, 16), (3, 3), 16, 1, 0), From 771c919d9f2f8084d786cfc744056238387a5520 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Fri, 7 Oct 2022 07:28:54 -0700 Subject: [PATCH 17/18] Replace math.prod function to support Python 3.7 --- python/tvm/relay/op/strategy/arm_cpu.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 23826a718cea..e56e7ba12e94 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """Definition of ARM CPU operator strategy.""" +from functools import reduce import logging -import math # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import import re @@ -88,7 +88,9 @@ def _is_simd_aligned(dtype, dimensions, padding=None): else: padded_dims = dimensions - size = math.prod(padded_dims) + # Multiply all elements of padded_dims together. We can't use math.prod, as it + # does not exist in Python 3.7. + size = reduce(lambda x, y: x * y, padded_dims) return ( (dtype == "int8" and size % 4 == 0) or (dtype == "int16" and size % 2 == 0) From 1966533449919b356e2178ecf1c4bc86e3bb396a Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Fri, 7 Oct 2022 09:11:53 -0700 Subject: [PATCH 18/18] Fix CI tests --- python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py index 5b493ba07030..ccd0c8e3ef32 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/tensordot_conv2ds.py @@ -180,8 +180,6 @@ def depthwise_conv2d_nchw_oihw_dsp_compute( output_h = _compute_output_dim(data_h, kernel_h, pad_up, pad_down, stride_h) output_w = _compute_output_dim(data_w, kernel_w, pad_left, pad_right, stride_w) - y_offset = _compute_offset(data_h, kernel_h, pad_up, pad_down, stride_h) - x_offset = _compute_offset(data_w, kernel_w, pad_left, pad_right, stride_w) kh_i = te.reduce_axis((0, kernel_h), name="kh_i") kw_i = te.reduce_axis((0, kernel_w), name="kw_i") @@ -193,8 +191,8 @@ def depthwise_conv2d_nchw_oihw_dsp_compute( padded_data[ n, indexdiv(c, c_mul), - y_offset + y * stride_h + kh_i, - x_offset + x * stride_w + kw_i, + y * stride_h + kh_i, + x * stride_w + kw_i, ].astype(out_dtype) * kernel[indexdiv(c, c_mul), indexmod(c, c_mul), kh_i, kw_i].astype(out_dtype), axis=(kh_i, kw_i),