From 08e901f62d9680b3971e2d5727f723e89b134e7f Mon Sep 17 00:00:00 2001
From: Luke Hutton
Date: Thu, 1 Sep 2022 14:12:46 +0000
Subject: [PATCH 1/2] [ETHOSN] Add support for transpose convolution

Adds support for offloading transpose convolution with an optional bias
to the NPU.

Co-authored-by: Samuel Panijel
Co-authored-by: Leo Blonk

Change-Id: I2534fc45a1498679b4701f41b8f7a1f79dda4e79
---
 python/tvm/relay/op/contrib/ethosn.py         |  18 ++
 src/relay/backend/contrib/ethosn/codegen.cc   |  39 +++
 .../backend/contrib/ethosn/codegen_ethosn.h   |   1 +
 .../contrib/ethosn/convert_equivalent.cc      |  15 +-
 .../backend/contrib/ethosn/ethosn_api.cc      | 126 ++++++++++
 src/relay/backend/contrib/ethosn/ethosn_api.h |  23 ++
 .../contrib/test_ethosn/infrastructure.py     |  43 ++++
 .../python/contrib/test_ethosn/test_conv2d.py |  21 +-
 .../test_ethosn/test_conv2d_transpose.py      | 234 ++++++++++++++++++
 9 files changed, 487 insertions(+), 33 deletions(-)
 create mode 100644 tests/python/contrib/test_ethosn/test_conv2d_transpose.py

diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py
index a4e9d9647c95..b475e3af2a95 100644
--- a/python/tvm/relay/op/contrib/ethosn.py
+++ b/python/tvm/relay/op/contrib/ethosn.py
@@ -233,6 +233,16 @@ def qnn_add_pattern():
 
         return input_is_left | input_is_right | two_inputs
 
+    def qnn_conv2d_transpose_pattern():
+        pattern = is_op("qnn.conv2d_transpose")(
+            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+        ).has_attr({"data_layout": "NHWC"})
+        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
+        pattern = is_op("qnn.requantize")(
+            pattern, is_constant(), is_constant(), is_constant(), is_constant()
+        )
+        return pattern
+
     def check_conv2d(extract):
         """Check if a conv2d is supported by Ethos-N."""
         if not ethosn_available():
@@ -261,6 +271,13 @@ def check_mean(extract):
 
         return _ethosn.mean(extract)
 
+    def check_conv2d_transpose(extract):
+        """Check if mean is supported by Ethos-N."""
+        if not ethosn_available():
+            return False
+
+        return _ethosn.conv2d_transpose(extract)
+
     def check_sigmoid(extract):
         """Check if a sigmoid is supported by Ethos-N."""
         if not ethosn_available():
@@ -326,6 +343,7 @@ def check_add(extract):
         ("ethos-n.qnn_mul", qnn_mul_pattern(), check_mul),
         ("ethos-n.qnn_add", qnn_add_pattern(), check_add),
         ("ethos-n.qnn_conv2d", qnn_conv_pattern(), check_conv2d),
+        ("ethos-n.qnn_conv2d_transpose", qnn_conv2d_transpose_pattern(), check_conv2d_transpose),
         ("ethos-n.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_avg_pool2d),
         ("ethos-n.qnn_sigmoid", qnn_sigmoid_pattern(), check_sigmoid),
         ("ethos-n.qnn_fc", qnn_fc_pattern(), check_fc),
diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc
index 69672a143585..c7109b754d2b 100644
--- a/src/relay/backend/contrib/ethosn/codegen.cc
+++ b/src/relay/backend/contrib/ethosn/codegen.cc
@@ -125,6 +125,10 @@ void InferTensorsVisitor::InferCall(const CallNode* cn) {
     LeakyReLUParams params;
     err += EthosnAPI::LeakyReLU(cn->op.as<FunctionNode>()->body, &params);
     tensor_table_[cn->args[0]] = {params.input_info};
+  } else if (IsEthosnFunc(call, "ethos-n.qnn_conv2d_transpose")) {
+    QnnConv2dTransposeParams params;
+    err += EthosnAPI::QnnConv2dTranspose(cn->op.as<FunctionNode>()->body, &params);
+    tensor_table_[cn->args[0]] = {params.input_info};
   } else if (IsEthosnOp(call, "qnn.concatenate")) {
     ConcatenateParams params;
     err = EthosnAPI::Concatenate(call, &params);
@@ -311,6 +315,9 @@ sl::TensorsAndId ConstructNetworkVisitor::HandleCall(const CallNode* cn) {
   } else if (IsEthosnFunc(call, "ethos-n.qnn_leaky_relu")) {
     if ((err = MakeLeakyReLULayer(call, &tensor))) ReportFatalError(call, err);
     return MakeOps(tensor);
+  } else if (IsEthosnFunc(call, "ethos-n.qnn_conv2d_transpose")) {
+    if ((err = MakeConv2DTransposeLayer(call, &tensor))) ReportFatalError(call, err);
+    return MakeOps(tensor);
   } else if (IsEthosnOp(call, "qnn.concatenate")) {
     if ((err = MakeConcatenateLayer(call, &tensor))) ReportFatalError(call, err);
     return MakeOps(tensor);
@@ -537,6 +544,24 @@ EthosnError ConstructNetworkVisitor::MakeLeakyReLULayer(const Call& call,
   return EthosnError();
 }
 
+EthosnError ConstructNetworkVisitor::MakeConv2DTransposeLayer(const Call& call,
+                                                              sl::TensorAndId<sl::Operand>* out) {
+  QnnConv2dTransposeParams params;
+  if (auto err = EthosnAPI::QnnConv2dTranspose(call->op.as<FunctionNode>()->body, &params)) {
+    return err;
+  }
+
+  auto activation = operand_table_[call->args[0]][0];
+  auto weights = AddConstant(network_, params.weights_info, params.raw_weights->data).tensor;
+  auto bias = AddConstant(network_, params.bias_info, params.raw_bias->data).tensor;
+  try {
+    *out = AddTransposeConvolution(network_, *activation, *bias, *weights, params.conv_info);
+  } catch (const sl::NotSupportedException& e) {
+    return EthosnError(e.what());
+  }
+  return EthosnError();
+}
+
 EthosnError ConstructNetworkVisitor::MakeConcatenateLayer(const Call& call,
                                                           sl::TensorsAndId* out) {
   ConcatenateParams params;
@@ -913,6 +938,20 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.leaky_relu")
       err += EthosnError(reason);
     });
 
+TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d_transpose")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      Call call = args[0];
+      QnnConv2dTransposeParams params;
+      auto err = EthosnAPI::QnnConv2dTranspose(call, &params);
+      err += EthosnCompiler::SupportedSetup();
+      char reason[kReasonMaxLength];
+      reason[0] = '\0';
+      *rv = !err && EthosnCompiler::GetSupported()->IsTransposeConvolutionSupported(
+                        params.bias_info, params.weights_info, params.conv_info, params.input_info,
+                        &params.output_info, reason, sizeof(reason));
+      err += EthosnError(reason);
+    });
+
 TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate")
     .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
       Call call = args[0];
diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h
index 863a032cafba..a653b0b8dc97 100644
--- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h
+++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h
@@ -206,6 +206,7 @@ class ConstructNetworkVisitor : public MixedModeVisitor, private ErrorReportingPass {
   EthosnError MakeSigmoidLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
   EthosnError MakeMeanLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
   EthosnError MakeTanhLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
+  EthosnError MakeConv2DTransposeLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
   EthosnError MakeConcatenateLayer(const Call& call, sl::TensorsAndId* out);
   EthosnError MakeSplitLayer(const Call& call, sl::TensorsAndId* outs);
   EthosnError MakeDepthToSpaceLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
diff --git a/src/relay/backend/contrib/ethosn/convert_equivalent.cc b/src/relay/backend/contrib/ethosn/convert_equivalent.cc
index 12b5a12afb35..91c924b1b04f 100644
--- a/src/relay/backend/contrib/ethosn/convert_equivalent.cc
+++ b/src/relay/backend/contrib/ethosn/convert_equivalent.cc
@@ -32,26 +32,13 @@
 #include "../../../qnn/utils.h"
 #include "../../../transforms/pattern_utils.h"
 #include "../../../transforms/simplify_expr.h"
+#include "ethosn_api.h"
 
 namespace tvm {
 namespace relay {
 namespace contrib {
 namespace ethosn {
 
-/*!
- * \brief Apply constant folding on an expression.
- *
- * \param expr The expression to fold.
- * \param fold_qnn Whether to fold constants for QNN operations.
- * \returns The new folded expression.
- */
-Expr FoldConstantExpr(const Expr& expr, bool fold_qnn = true) {
-  auto mod = IRModule::FromExpr(expr);
-  mod = transform::FoldConstant(fold_qnn)(mod);
-  auto entry_func = Downcast<Function>(mod->Lookup("main"));
-  return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
-}
-
 /*!
  * \brief Converts qnn.mul to mathematically equivalent
  * qnn.conv2d depthwise operation.
diff --git a/src/relay/backend/contrib/ethosn/ethosn_api.cc b/src/relay/backend/contrib/ethosn/ethosn_api.cc
index 4f01c924cf6e..ce57cc23419a 100644
--- a/src/relay/backend/contrib/ethosn/ethosn_api.cc
+++ b/src/relay/backend/contrib/ethosn/ethosn_api.cc
@@ -23,6 +23,7 @@
 
 #include "ethosn_api.h"
 
+#include 
 #include 
 #include 
 #include 
@@ -37,6 +38,9 @@
 #include 
 #include 
 
+#include "../../../op/make_op.h"
+#include "../../../transforms/pattern_utils.h"
+#include "../../../transforms/simplify_expr.h"
 #include "ethosn_support_library/Support.hpp"
 #include "ethosn_support_library/SupportQueries.hpp"
 #include "tvm/relay/qnn/attrs.h"
@@ -445,6 +449,121 @@ EthosnError EthosnAPI::Mean(const Expr& expr, MeanParams* params) {
   return err;
 }
 
+Constant TransposeWeights(const Constant& data, const std::string& input_layout) {
+  int pos_h = input_layout.find("H");
+  int pos_w = input_layout.find("W");
+  int pos_i = input_layout.find("I");
+  int pos_o = input_layout.find("O");
+
+  // Currently the expected target layout is HWIO only.
+  Array<Integer> target_shape = {pos_h, pos_w, pos_i, pos_o};
+
+  Expr transpose = MakeTranspose(data, target_shape);
+  transpose = InferType(FoldConstantExpr(transpose));
+  Constant transposed_data = Downcast<Constant>(transpose);
+  return transposed_data;
+}
+
+EthosnError EthosnAPI::QnnConv2dTranspose(const Expr& expr, QnnConv2dTransposeParams* params) {
+  Call requantize = Downcast<Call>(expr);
+  Call bias;
+  Call conv2d_transpose;
+  if (requantize->args[0]->IsInstance<CallNode>() &&
+      Downcast<Call>(requantize->args[0])->op == Op::Get("nn.bias_add")) {
+    bias = Downcast<Call>(requantize->args[0]);
+    conv2d_transpose = Downcast<Call>(bias->args[0]);
+  } else {
+    conv2d_transpose = Downcast<Call>(requantize->args[0]);
+  }
+  const auto& conv_attr = conv2d_transpose->attrs.as<Conv2DTransposeAttrs>();
+  ICHECK(conv_attr) << "Expected type Conv2DTransposeAttrs but was "
+                    << conv2d_transpose->attrs->GetTypeKey();
+
+  int input_zero_point;
+  int kernel_zero_point;
+  int output_zero_point;
+  std::valarray<float> input_scale;
+  std::valarray<float> kernel_scale;
+  float output_scale;
+  unsigned int qaxis = conv_attr->kernel_layout.find("O");
+
+  EthosnError err = AsConstant(conv2d_transpose->args[2], &input_zero_point);
+  err += AsConstant(conv2d_transpose->args[3], &kernel_zero_point);
+  err += AsConstant(requantize->args[4], &output_zero_point);
+  err += AsConstant(conv2d_transpose->args[4], &input_scale);
+  err += AsConstant(conv2d_transpose->args[5], &kernel_scale);
+  err += AsConstant(requantize->args[3], &output_scale);
+
+  // Convert quantization params
+  sl::QuantizationInfo input_q_info;
+  sl::QuantizationInfo weights_q_info;
+  sl::QuantizationInfo bias_q_info;
+  sl::QuantizationInfo output_q_info;
+  err += Tvm2Npu(input_zero_point, input_scale, qaxis, &input_q_info);
+  err += Tvm2Npu(kernel_zero_point, kernel_scale, qaxis, &weights_q_info);
+  std::valarray<float> bias_scales = input_q_info.GetScales() * weights_q_info.GetScales();
+  err += Tvm2Npu(0, bias_scales, 3, &bias_q_info);
+  err += Tvm2Npu(output_zero_point, output_scale, &output_q_info);
+
+  // Convert convolution attributes
+  sl::Padding padding;
+  err += Tvm2Npu(conv_attr->padding, &padding);
+  sl::Stride stride;
+  err += Tvm2Npu(conv_attr->strides, &stride);
+  // Dilation is not supported
+  std::array<uint32_t, 2> dilation = {1, 1};
+  AsArray(conv_attr->dilation, &dilation);
+  if (conv_attr->dilation.size() != 2 || dilation[0] != 1 || dilation[1] != 1) {
+    err +=
+        EthosnError(ErrStrm() << "dilation=" << conv_attr->dilation << ", dilation must = [1, 1]");
+  }
+
+  // Create convolution info
+  params->conv_info = sl::ConvolutionInfo(padding, stride, output_q_info);
+
+  // Create input info
+  sl::TensorInfo input_tensor_info;
+  err += Tvm2Npu(conv2d_transpose->args[0]->checked_type(), &input_tensor_info);
+  input_tensor_info.m_QuantizationInfo = input_q_info;
+  params->input_info = input_tensor_info;
+
+  // Create weights info
+  Constant weights_data = Downcast<Constant>(conv2d_transpose->args[1]);
+  if (conv_attr->kernel_layout != "HWIO") {
+    weights_data = TransposeWeights(weights_data, conv_attr->kernel_layout);
+  }
+  const auto* weights_ttype = weights_data->checked_type().as<TensorTypeNode>();
+  sl::TensorShape weights_tensor_shape;
+  sl::DataType weights_data_type;
+  sl::DataFormat weights_data_format;
+  // Ignore the error here because weights don't have a batch axis
+  Tvm2Npu(weights_ttype->shape, &weights_tensor_shape);
+  err += Tvm2Npu(weights_ttype->dtype, &weights_data_type);
+  err += Tvm2Npu("HWIO", &weights_data_format);
+  params->weights_info =
+      sl::TensorInfo(weights_tensor_shape, weights_data_type, weights_data_format, weights_q_info);
+
+  params->raw_weights = weights_data->data;
+
+  // Create bias info
+  unsigned int out_channels = Downcast<IntImm>(conv_attr->channels)->value;
+  params->bias_info = sl::TensorInfo({1, 1, 1, out_channels}, sl::DataType::INT32_QUANTIZED,
+                                     sl::DataFormat::NHWC, bias_q_info);
+  if (bias.defined()) {
+    params->raw_bias = Downcast<Constant>(bias->args[1])->data;
+  } else {
+    params->raw_bias = MakeConstantZeros(tvm::DataType::Int(32), {1, 1, 1, out_channels})->data;
+  }
+
+  // Create output info
+  sl::TensorInfo output_tensor_info;
+  err += Tvm2Npu(requantize->checked_type(), &output_tensor_info);
+  output_tensor_info.m_QuantizationInfo = output_q_info;
+  params->output_info = output_tensor_info;
+
+  return err;
+}
+
 EthosnError EthosnAPI::Tanh(const Expr& expr, TanhParams* params) {
   Call quantize = Downcast<Call>(expr);
   Call tanh = Downcast<Call>(quantize->args[0]);
@@ -925,6 +1044,13 @@ EthosnError EthosnAPI::AsConstant(const Expr& expr, T* out) {
   return EthosnError();
 }
 
+Expr FoldConstantExpr(const Expr& expr, bool fold_qnn) {
+  auto mod = IRModule::FromExpr(expr);
+  mod = transform::FoldConstant(fold_qnn)(mod);
+  auto entry_func = Downcast<Function>(mod->Lookup("main"));
+  return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
+}
+
 }  // namespace ethosn
 }  // namespace contrib
 }  // namespace relay
diff --git a/src/relay/backend/contrib/ethosn/ethosn_api.h b/src/relay/backend/contrib/ethosn/ethosn_api.h
index afe4736bfc40..167106c3d06d 100644
--- a/src/relay/backend/contrib/ethosn/ethosn_api.h
+++ b/src/relay/backend/contrib/ethosn/ethosn_api.h
@@ -24,6 +24,7 @@
 #ifndef TVM_RELAY_BACKEND_CONTRIB_ETHOSN_ETHOSN_API_H_
 #define TVM_RELAY_BACKEND_CONTRIB_ETHOSN_ETHOSN_API_H_
 
+#include 
 #include 
 #include 
 #include 
@@ -115,6 +116,16 @@ struct LeakyReLUParams {
   sl::TensorInfo output_info;
 };
 
+struct QnnConv2dTransposeParams {
+  sl::ConvolutionInfo conv_info;
+  sl::TensorInfo input_info;
+  sl::TensorInfo weights_info;
+  sl::TensorInfo bias_info;
+  sl::TensorInfo output_info;
+  runtime::NDArray raw_weights;
+  runtime::NDArray raw_bias;
+};
+
 struct ConcatenateParams {
   sl::QuantizationInfo qInfo;
   sl::ConcatenationInfo concat_info = sl::ConcatenationInfo(1, qInfo);
@@ -237,6 +248,9 @@ class EthosnAPI {
   static EthosnError Tanh(const Expr& expr, TanhParams* params);
   /*! \brief Extract the Support Library leaky relu params from an ethos-n leaky relu Relu call. */
   static EthosnError LeakyReLU(const Expr& expr, LeakyReLUParams* params);
+  /*! \brief Extract the Support Library transpose params from a Relay
+   * ethos-n.qnn_conv2d_transpose func */
+  static EthosnError QnnConv2dTranspose(const Expr& expr, QnnConv2dTransposeParams* params);
   /*! \brief Extract the Support Library concatenate params from a Relay qnn.concatenate call */
   static EthosnError Concatenate(const Expr& expr, ConcatenateParams* params);
   /*! \brief Extract the Support Library split params from a Relay split call */
@@ -294,6 +308,15 @@ class EthosnAPI {
   static EthosnError AsConstant(const Expr& expr, std::valarray<float>* out);
 };
 
+/*!
+ * \brief Apply constant folding on an expression.
+ *
+ * \param expr The expression to fold.
+ * \param fold_qnn Whether to fold constants for QNN operations.
+ * \returns The new folded expression.
+ */
+Expr FoldConstantExpr(const Expr& expr, bool fold_qnn = true);
+
 }  // namespace ethosn
 }  // namespace contrib
 }  // namespace relay
diff --git a/tests/python/contrib/test_ethosn/infrastructure.py b/tests/python/contrib/test_ethosn/infrastructure.py
index c658b33747c3..6b019686968e 100644
--- a/tests/python/contrib/test_ethosn/infrastructure.py
+++ b/tests/python/contrib/test_ethosn/infrastructure.py
@@ -21,6 +21,9 @@
 from hashlib import md5
 from itertools import zip_longest, combinations
 import os
+from typing import Tuple
+import math
+
 import numpy as np
 from PIL import Image
 
@@ -28,6 +31,7 @@
 from tvm import relay
 from tvm.contrib import utils, graph_executor, download
 from tvm.relay.op.contrib import partition_for_ethosn
+
 from . import _infrastructure
 
 
@@ -340,5 +344,44 @@ def get_conv2d_qnn_params(
     return output_zp, output_sc
 
 
+def get_same_padding(
+    data: Tuple[int, int],
+    kernel: Tuple[int, int],
+    dilation: Tuple[int, int],
+    stride: Tuple[int, int],
+) -> Tuple[int, int, int, int]:
+    """
+    Get the padding values required for 'SAME' padding.
+
+    Parameters
+    ----------
+    data : Tuple[int, int]
+        The height and width of the data respectively.
+    kernel : Tuple[int, int]
+        The height and width of the kernel respectively.
+    dilation : Tuple[int, int]
+        The dilation of the kernel.
+    stride : Tuple[int, int]
+        The stride of the kernel.
+
+    Returns
+    -------
+    Tuple[int, int, int, int]
+        The padding values for top, left, bottom and right respectively.
+ """ + dilated_kernel_h = dilation[0] * (kernel[0] - 1) + 1 + dilated_kernel_w = dilation[1] * (kernel[1] - 1) + 1 + out = int(math.ceil(float(data[0]) / float(stride[0]))) + pad = max(0, (out - 1) * stride[0] + dilated_kernel_h - data[0]) + pad_top = pad // 2 + pad_bottom = pad - pad_top + + out = int(math.ceil(float(data[1]) / float(stride[1]))) + pad = max(0, (out - 1) * stride[1] + dilated_kernel_w - data[1]) + pad_left = pad // 2 + pad_right = pad - pad_left + return (pad_top, pad_left, pad_bottom, pad_right) + + def get_ethosn_variant(): return os.getenv("ETHOSN_VARIANT_CONFIG", default="Ethos-N78_1TOPS_2PLE_RATIO") diff --git a/tests/python/contrib/test_ethosn/test_conv2d.py b/tests/python/contrib/test_ethosn/test_conv2d.py index 4026f8267d72..a6ce73656bfc 100644 --- a/tests/python/contrib/test_ethosn/test_conv2d.py +++ b/tests/python/contrib/test_ethosn/test_conv2d.py @@ -17,8 +17,6 @@ """Arm(R) Ethos(TM)-N integration conv2d tests""" -import math - import numpy as np import pytest @@ -29,21 +27,6 @@ from . import infrastructure as tei -def _get_same_padding(data, kernel, dilation, stride): - dilated_kernel_h = dilation[0] * (kernel[0] - 1) + 1 - dilated_kernel_w = dilation[1] * (kernel[1] - 1) + 1 - out = int(math.ceil(float(data[0]) / float(stride[0]))) - pad = max(0, (out - 1) * stride[0] + dilated_kernel_h - data[0]) - pad_top = pad // 2 - pad_bottom = pad - pad_top - - out = int(math.ceil(float(data[1]) / float(stride[1]))) - pad = max(0, (out - 1) * stride[1] + dilated_kernel_w - data[1]) - pad_left = pad // 2 - pad_right = pad - pad_left - return [pad_top, pad_left, pad_bottom, pad_right] - - def _get_model( shape, kernel_h, @@ -65,7 +48,7 @@ def _get_model( """Return a model and any parameters it may have""" a = relay.var("a", shape=shape, dtype=dtype) if pad in ("op", "both"): - p = _get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) + p = tei.get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) a = relay.nn.pad( a, pad_width=[(0, 0), (p[0], p[2]), (p[1], p[3]), (0, 0)], @@ -74,7 +57,7 @@ def _get_model( ) shape = (shape[0], shape[1] + p[0] + p[2], shape[2] + p[1] + p[3], shape[3]) - p = _get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) + p = tei.get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) if weight_format == "HWIO": weight_shape = (kernel_h, kernel_w, shape[3] // groups, out_channels) else: diff --git a/tests/python/contrib/test_ethosn/test_conv2d_transpose.py b/tests/python/contrib/test_ethosn/test_conv2d_transpose.py new file mode 100644 index 000000000000..84aa7e969b30 --- /dev/null +++ b/tests/python/contrib/test_ethosn/test_conv2d_transpose.py @@ -0,0 +1,234 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+"""Arm(R) Ethos(TM)-N integration conv2d_transpose tests"""
+
+import pytest
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.testing import requires_ethosn
+from . import infrastructure as tei
+
+
+def _get_model(
+    shape,
+    kernel_h,
+    kernel_w,
+    input_zp,
+    input_sc,
+    kernel_zp,
+    kernel_sc,
+    output_zp,
+    output_sc,
+    stride,
+    dilation,
+    groups,
+    kernel_layout,
+    dtype,
+    out_channels,
+    bias,
+):
+    """Return a model and any parameters it may have"""
+    a = relay.var("a", shape=shape, dtype=dtype)
+    p = tei.get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, stride)
+    weight_shape = (shape[3], out_channels // groups, kernel_h, kernel_w)
+
+    weight_data = tvm.nd.array(
+        np.random.randint(
+            np.iinfo(dtype).min,
+            high=(np.iinfo(dtype).max + 1),
+            size=weight_shape,
+            dtype=dtype,
+        )
+    )
+    weights = relay.const(weight_data, dtype)
+    op = relay.qnn.op.conv2d_transpose(
+        a,
+        weights,
+        input_zero_point=relay.const(input_zp, "int32"),
+        input_scale=relay.const(input_sc, "float32"),
+        kernel_zero_point=relay.const(kernel_zp, "int32"),
+        kernel_scale=relay.const(kernel_sc, "float32"),
+        kernel_size=(kernel_h, kernel_w),
+        padding=p,
+        strides=stride,
+        dilation=dilation,
+        data_layout="NHWC",
+        kernel_layout=kernel_layout,
+        out_dtype="int32",
+        channels=out_channels,
+        groups=groups,
+    )
+    if bias:
+        bias_data = tvm.nd.array(
+            np.random.randint(
+                np.iinfo(dtype).min,
+                high=np.iinfo(dtype).max + 1,
+                size=(out_channels,),
+                dtype="int32",
+            )
+        )
+        biasc = relay.const(bias_data, "int32")
+        op = relay.nn.bias_add(op, biasc, axis=3)
+
+    if isinstance(kernel_sc, tvm.runtime.ndarray.NDArray):
+        req_input_sc = [sc * input_sc for sc in kernel_sc.numpy()]
+    else:
+        req_input_sc = input_sc * kernel_sc
+
+    op = relay.qnn.op.requantize(
+        op,
+        input_zero_point=relay.const(input_zp, "int32"),
+        input_scale=relay.const(req_input_sc, "float32"),
+        output_zero_point=relay.const(output_zp, "int32"),
+        output_scale=relay.const(output_sc, "float32"),
+        axis=3,
+        rounding="UPWARD",
+        out_dtype=dtype,
+    )
+    params = {"w": weight_data}
+    if bias:
+        params["b"] = bias_data
+    return op, params
+
+
+@requires_ethosn
+@pytest.mark.parametrize("dtype", ["uint8", "int8"])
+@pytest.mark.parametrize(
+    "ifm_shape,strides,kernel_size,out_channels,bias",
+    [
+        ((1, 2, 2, 1), (2, 2), (1, 1), 1, False),
+        ((1, 2, 2, 5), (2, 2), (3, 5), 4, False),
+        ((1, 7, 7, 4), (2, 2), (7, 9), 8, True),
+    ],
+)
+def test_conv2d_transpose(ifm_shape, strides, kernel_size, out_channels, dtype, bias):
+    """Check transpose convolution output with TVM."""
+    np.random.seed(0)
+
+    kernel_layout = "IOHW"
+    dilation = (1, 1)
+    groups = 1
+
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+
+    input_zp = np.random.randint(data_min, data_max)
+    input_sc = np.random.random() * 2
+    kernel_zp = np.random.randint(data_min, data_max)
+    kernel_sc = np.random.random() * 4
+    output_zp, output_sc = tei.get_conv2d_qnn_params(
+        dtype, input_zp, input_sc, kernel_zp, kernel_sc, ifm_shape[1], ifm_shape[2], ifm_shape[3]
+    )
+
+    model, params = _get_model(
+        shape=ifm_shape,
+        kernel_h=kernel_size[0],
+        kernel_w=kernel_size[1],
+        input_zp=input_zp,
+        input_sc=input_sc,
+        kernel_zp=kernel_zp,
+        kernel_sc=kernel_sc,
+        output_zp=output_zp,
+        output_sc=output_sc,
+        stride=strides,
+        dilation=dilation,
+        groups=groups,
+        kernel_layout=kernel_layout,
+        dtype=dtype,
+        out_channels=out_channels,
+        bias=bias,
+    )
+
+    outputs = []
+    inputs = {
+        "a": tvm.nd.array(np.random.randint(data_min, data_max + 1, size=ifm_shape, dtype=dtype))
+    }
+
+    for npu in [False, True]:
+        mod = tei.make_module(model, params)
+        outputs.append(tei.build_and_run(mod, inputs, 1, params, npu=npu))
+
+    tei.verify(outputs, dtype, 1)
+
+
+@requires_ethosn
+@pytest.mark.parametrize("dtype", ["uint8", "int8"])
+@pytest.mark.parametrize(
+    "shape, stride, dilation, groups, err_msg",
+    [
+        (
+            (1, 4, 4, 4),
+            (1, 1, 1),
+            (1, 1),
+            1,
+            "stride size=3, stride size must = 2",
+        ),
+        (
+            (1, 4, 4, 4),
+            (2, 2),
+            (2, 2),
+            2,
+            "dilation=[2, 2], dilation must = [1, 1]",
+        ),
+        (
+            (2, 4, 4, 4),
+            (1, 1),
+            (1, 1),
+            1,
+            "batch size=2, batch size must = 1",
+        ),
+    ],
+)
+def test_conv2d_transpose_failure(
+    shape,
+    stride,
+    dilation,
+    groups,
+    err_msg,
+    dtype,
+):
+    """
+    Test transpose_conv2d error messages.
+    """
+    np.random.seed(0)
+    out_channels = 8
+
+    model, _ = _get_model(
+        shape=shape,
+        kernel_h=1,
+        kernel_w=1,
+        input_zp=0,
+        input_sc=1,
+        kernel_zp=0,
+        kernel_sc=1,
+        output_zp=0,
+        output_sc=1,
+        stride=stride,
+        dilation=dilation,
+        groups=groups,
+        kernel_layout="IOHW",
+        dtype=dtype,
+        out_channels=out_channels,
+        bias=False,
+    )
+    model = tei.make_ethosn_composite(model, "ethos-n.qnn_conv2d_transpose")
+    mod = tei.make_ethosn_partition(model)
+    tei.test_error(mod, {}, err_msg)

From c198125ac3168d164fb62b061c6cc2b84468bfa1 Mon Sep 17 00:00:00 2001
From: Luke Hutton
Date: Mon, 5 Sep 2022 08:38:40 +0000
Subject: [PATCH 2/2] Update docstring

Change-Id: I6e8d1d9d05e5870dad4f9735f3c0ed81458449f3
---
 python/tvm/relay/op/contrib/ethosn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py
index b475e3af2a95..5129ed9ffaef 100644
--- a/python/tvm/relay/op/contrib/ethosn.py
+++ b/python/tvm/relay/op/contrib/ethosn.py
@@ -272,7 +272,7 @@ def check_mean(extract):
         return _ethosn.mean(extract)
 
     def check_conv2d_transpose(extract):
-        """Check if mean is supported by Ethos-N."""
+        """Check if conv2d_transpose is supported by Ethos-N."""
         if not ethosn_available():
             return False
 
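
For context, the `get_same_padding` helper that this series moves into `infrastructure.py` computes TensorFlow-style "SAME" padding per axis. A minimal sketch of the arithmetic for the height axis of the `(1, 7, 7, 4)` test case above (not part of the patch, just the formula worked through by hand):

```python
import math

# Height axis of the (1, 7, 7, 4) case with a (7, 9) kernel: 7 input rows,
# kernel_h = 7, dilation 1, stride 2, as computed by get_same_padding.
data, kernel, dilation, stride = 7, 7, 1, 2
dilated_kernel = dilation * (kernel - 1) + 1               # 7
out = math.ceil(data / stride)                             # ceil(7 / 2) = 4
pad = max(0, (out - 1) * stride + dilated_kernel - data)   # 3 * 2 + 7 - 7 = 6
pad_top, pad_bottom = pad // 2, pad - pad // 2             # (3, 3)
print(pad_top, pad_bottom)
```

The same computation runs on the width axis, which is why the helper returns a `(top, left, bottom, right)` 4-tuple.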
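
To exercise the new composite end to end, a graph matching the registered pattern (`qnn.conv2d_transpose` with constant weights and NHWC data layout, optionally `nn.bias_add`, followed by `qnn.requantize`) can be built and partitioned as sketched below. This is illustrative only: it assumes a TVM build with the Ethos-N support library enabled (otherwise `check_conv2d_transpose` rejects the match and the graph stays on the host), `partition_for_ethosn` may require variant options depending on the TVM version, and the shapes and quantization parameters are arbitrary.

```python
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import partition_for_ethosn

dtype = "int8"
a = relay.var("a", shape=(1, 4, 4, 8), dtype=dtype)
# IOHW weights: 8 input channels, 2 output channels, 3x3 kernel.
weights = relay.const(np.zeros((8, 2, 3, 3), dtype=dtype), dtype)
op = relay.qnn.op.conv2d_transpose(
    a,
    weights,
    input_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(0.5, "float32"),
    kernel_zero_point=relay.const(0, "int32"),
    kernel_scale=relay.const(0.5, "float32"),
    kernel_size=(3, 3),
    channels=2,
    strides=(2, 2),
    padding=(0, 0, 0, 0),
    data_layout="NHWC",
    kernel_layout="IOHW",
    out_dtype="int32",
)
op = relay.qnn.op.requantize(
    op,
    input_scale=relay.const(0.25, "float32"),  # input_sc * kernel_sc
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.1, "float32"),
    output_zero_point=relay.const(0, "int32"),
    out_dtype=dtype,
)
mod = tvm.IRModule.from_expr(relay.Function([a], op))
mod = partition_for_ethosn(mod)
# The matched region should now appear as a function tagged with the
# "ethos-n.qnn_conv2d_transpose" composite attribute.
print(mod)
```

Note that the backend transposes non-HWIO weights to HWIO at codegen time via `TransposeWeights`, so IOHW weights as produced by frameworks like PyTorch are accepted directly.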