diff --git a/cmake/modules/contrib/CMSISNN.cmake b/cmake/modules/contrib/CMSISNN.cmake index 73ecd5916df3..eef12fdd778e 100644 --- a/cmake/modules/contrib/CMSISNN.cmake +++ b/cmake/modules/contrib/CMSISNN.cmake @@ -18,6 +18,8 @@ if(USE_CMSISNN) add_definitions(-DTVM_USE_CMSISNN) message(STATUS "Build with CMSIS-NN support") - tvm_file_glob(GLOB RELAY_CONTRIB_CMSISNN_SRCS src/relay/backend/contrib/cmsisnn/*.cc) + tvm_file_glob(GLOB RELAY_CONTRIB_CMSISNN_SRCS + src/relay/backend/contrib/cmsisnn/*.cc + src/relay/backend/contrib/constant_transforms.cc) list(APPEND COMPILER_SRCS ${RELAY_CONTRIB_CMSISNN_SRCS}) endif(USE_CMSISNN) diff --git a/cmake/modules/contrib/EthosN.cmake b/cmake/modules/contrib/EthosN.cmake index dbf5549180aa..b230acfc380d 100644 --- a/cmake/modules/contrib/EthosN.cmake +++ b/cmake/modules/contrib/EthosN.cmake @@ -35,7 +35,8 @@ if(NOT USE_ETHOSN STREQUAL "OFF") list(APPEND RUNTIME_SRCS ${ETHOSN_RUNTIME_CONTRIB_SRC}) tvm_file_glob(GLOB COMPILER_ETHOSN_SRCS - src/relay/backend/contrib/ethosn/*) + src/relay/backend/contrib/ethosn/* + src/relay/backend/contrib/constant_transforms.cc) list(APPEND COMPILER_SRCS ${COMPILER_ETHOSN_SRCS}) list(APPEND TVM_LINKER_LIBS ${ETHOSN_COMPILER_LIBRARY} diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc index 297e6b7acea3..e08b61c457f9 100644 --- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -31,6 +31,7 @@ #include "../../../op/make_op.h" #include "../../../qnn/utils.h" #include "../../../transforms/pattern_utils.h" +#include "../constant_transforms.h" #include "convolutions.h" namespace tvm { @@ -64,22 +65,9 @@ class GenerateConstantsMutator : public MixedModeMutator { attrs->out_dtype = std::move(conv2d_attrs->out_dtype); *new_attrs = tvm::Attrs{attrs}; - std::string kernel_layout = conv2d_attrs->kernel_layout.c_str(); - int pos_o = kernel_layout.find("O"); - int pos_h = kernel_layout.find("H"); - int pos_w = kernel_layout.find("W"); - int pos_i = kernel_layout.find("I"); - - IRModule kernel_module; - auto func_body = MakeTranspose( - kernel_expr, {Integer(pos_o), Integer(pos_h), Integer(pos_w), Integer(pos_i)}); - auto kernel_func = - Function(FreeVars(func_body), func_body, Type(), FreeTypeVars(func_body, kernel_module)); - GlobalVar kernel_var("main"); - kernel_module->Add(kernel_var, kernel_func); - kernel_module = relay::transform::FoldConstant()(kernel_module); - kernel_func = Downcast(kernel_module->Lookup("main")); - return kernel_func->body; + Constant conv2d_kernel = Downcast(kernel_expr); + conv2d_kernel = TransposeWeights(conv2d_kernel, conv2d_attrs->kernel_layout, "OHWI"); + return conv2d_kernel; } /*! * \brief Performs weight transpose and substitutes existing constants in the composite diff --git a/src/relay/backend/contrib/constant_transforms.cc b/src/relay/backend/contrib/constant_transforms.cc new file mode 100644 index 000000000000..6041d37451aa --- /dev/null +++ b/src/relay/backend/contrib/constant_transforms.cc @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "constant_transforms.h" + +#include + +#include "../../transforms/pattern_utils.h" +#include "../../transforms/simplify_expr.h" + +/*! + * \file src/relay/backend/contrib/constant_transforms.cc + * \brief Transforms applied to constant operations during codegen for BYOC backends. + */ + +namespace tvm { +namespace relay { +namespace contrib { + +Expr FoldConstantExpr(const Expr& expr, bool fold_qnn) { + auto mod = IRModule::FromExpr(expr); + mod = transform::FoldConstant(fold_qnn)(mod); + auto entry_func = Downcast(mod->Lookup("main")); + return expr.as() == nullptr ? entry_func->body : entry_func; +} + +Constant TransposeWeights(const Constant& data, const std::string& source_layout, + const std::string& target_layout) { + Array transpose_matrix; + for (const char& c : target_layout) { + int pos = source_layout.find(c); + transpose_matrix.push_back(pos); + } + Expr transpose = MakeTranspose(data, transpose_matrix); + transpose = InferType(FoldConstantExpr(transpose)); + Constant transposed_data = Downcast(transpose); + return transposed_data; +} + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/constant_transforms.h b/src/relay/backend/contrib/constant_transforms.h new file mode 100644 index 000000000000..39a9dc1d53d4 --- /dev/null +++ b/src/relay/backend/contrib/constant_transforms.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/constant_transforms.h + * \brief Transforms applied to constant operations during codegen for BYOC backends. + */ + +#ifndef TVM_RELAY_BACKEND_CONTRIB_CONSTANT_TRANSFORMS_H_ +#define TVM_RELAY_BACKEND_CONTRIB_CONSTANT_TRANSFORMS_H_ + +#include + +#include + +namespace tvm { +namespace relay { +namespace contrib { + +/*! + * \brief Apply constant folding on an expression. + * + * \param expr The expression to fold. + * \param fold_qnn Whether to fold constants for QNN operations. + * \returns The new folded expression. + */ +Expr FoldConstantExpr(const Expr& expr, bool fold_qnn = true); + +/*! + *\brief Transpose weights from `source_layout` to `target_layout` + * + * \param data The constant expression to transpose. + * \param source_layout The current layout of the constant e.g. "OHWI". + * \param target_layout The target layout of the constant e.g. "HWIO". + */ +Constant TransposeWeights(const Constant& data, const std::string& source_layout, + const std::string& target_layout); + +} // namespace contrib +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_CONTRIB_CONSTANT_TRANSFORMS_H_ diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc index 46420775ae5b..d2281f782615 100644 --- a/src/relay/backend/contrib/ethosn/codegen.cc +++ b/src/relay/backend/contrib/ethosn/codegen.cc @@ -412,8 +412,8 @@ EthosnError ConstructNetworkVisitor::MakeFullyConnectedLayer(const Call& call, return err; } - auto weights = AddConstant(network_, params.weights_info, params.raw_weights).tensor; - auto bias = AddConstant(network_, params.bias_info, params.raw_bias).tensor; + auto weights = AddConstant(network_, params.weights_info, params.raw_weights->data).tensor; + auto bias = AddConstant(network_, params.bias_info, params.raw_bias->data).tensor; try { auto input = AddReshape(network_, *operand_table_[call->args[0]][0], params.input_info.m_Dimensions) diff --git a/src/relay/backend/contrib/ethosn/convert_equivalent.cc b/src/relay/backend/contrib/ethosn/convert_equivalent.cc index 7f4e1a3c5045..14d94192c84e 100644 --- a/src/relay/backend/contrib/ethosn/convert_equivalent.cc +++ b/src/relay/backend/contrib/ethosn/convert_equivalent.cc @@ -32,6 +32,7 @@ #include "../../../qnn/utils.h" #include "../../../transforms/pattern_utils.h" #include "../../../transforms/simplify_expr.h" +#include "../constant_transforms.h" #include "ethosn_api.h" namespace tvm { diff --git a/src/relay/backend/contrib/ethosn/ethosn_api.cc b/src/relay/backend/contrib/ethosn/ethosn_api.cc index dbcdecd8f382..c0f8767a8c65 100644 --- a/src/relay/backend/contrib/ethosn/ethosn_api.cc +++ b/src/relay/backend/contrib/ethosn/ethosn_api.cc @@ -41,6 +41,7 @@ #include "../../../op/make_op.h" #include "../../../transforms/pattern_utils.h" #include "../../../transforms/simplify_expr.h" +#include "../constant_transforms.h" #include "ethosn_support_library/Support.hpp" #include "ethosn_support_library/SupportQueries.hpp" #include "tvm/relay/qnn/attrs.h" @@ -197,7 +198,10 @@ EthosnError EthosnAPI::QnnFullyConnected(const Expr& expr, FullyConnectedParams* sl::QuantizationInfo output_q_info; err += Tvm2Npu(input_zero_point, input_scale, &data_q_info); err += Tvm2Npu(kernel_zero_point, kernel_scale, &weights_q_info); - err += Tvm2Npu(0, data_q_info.GetScale() * weights_q_info.GetScale(), &bias_q_info); + std::valarray bias_scales = data_q_info.GetScale() * weights_q_info.GetScales(); + const int bias_zero_point = 0; + const unsigned int bias_axis = 3; + err += Tvm2Npu(bias_zero_point, bias_scales, bias_axis, &bias_q_info); err += Tvm2Npu(output_zero_point, output_scale, &output_q_info); // Create fc info @@ -213,27 +217,30 @@ EthosnError EthosnAPI::QnnFullyConnected(const Expr& expr, FullyConnectedParams* data_data_type, sl::DataFormat::NHWC, data_q_info); // Create weights info - const auto* weights_dtype = dense->args[1]->checked_type().as(); + Constant weights_data = Downcast(dense->args[1]); + weights_data = TransposeWeights(weights_data, "OI", "IO"); + const auto* weights_ttype = weights_data->checked_type().as(); sl::TensorShape weights_tensor_shape; sl::DataType weights_data_type; sl::DataFormat weights_data_format; // Ignore the error here because weights don't have a batch axis - Tvm2Npu(weights_dtype->shape, &weights_tensor_shape); - err += Tvm2Npu(weights_dtype->dtype, &weights_data_type); + Tvm2Npu(weights_ttype->shape, &weights_tensor_shape); + err += Tvm2Npu(weights_ttype->dtype, &weights_data_type); err += Tvm2Npu("HWIO", &weights_data_format); - params->weights_info = sl::TensorInfo({1, 1, weights_tensor_shape[1], weights_tensor_shape[0]}, + // Weights tensor shape is 1, 1, I, O + params->weights_info = sl::TensorInfo({1, 1, weights_tensor_shape[0], weights_tensor_shape[1]}, weights_data_type, weights_data_format, weights_q_info); - params->raw_weights = dense->args[1].as()->data->data; + params->raw_weights = weights_data->data; // Create bias info params->bias_info = - sl::TensorInfo({1, 1, 1, weights_tensor_shape[0]}, sl::DataType::INT32_QUANTIZED, + sl::TensorInfo({1, 1, 1, weights_tensor_shape[1]}, sl::DataType::INT32_QUANTIZED, sl::DataFormat::NHWC, bias_q_info); - params->raw_bias = bias_add->args[1].as()->data->data; + params->raw_bias = bias_add->args[1].as()->data; sl::TensorInfo output_tensor_info; err += Tvm2Npu(requantize->checked_type(), &output_tensor_info); - output_tensor_info.m_Dimensions = {data_tensor_shape[0], 1, 1, weights_tensor_shape[0]}; + output_tensor_info.m_Dimensions = {data_tensor_shape[0], 1, 1, weights_tensor_shape[1]}; output_tensor_info.m_QuantizationInfo = output_q_info; params->output_info = output_tensor_info; @@ -449,21 +456,6 @@ EthosnError EthosnAPI::Mean(const Expr& expr, MeanParams* params) { return err; } -Constant TransposeWeights(const Constant& data, const std::string& input_layout) { - int pos_h = input_layout.find("H"); - int pos_w = input_layout.find("W"); - int pos_i = input_layout.find("I"); - int pos_o = input_layout.find("O"); - - // Currently the expected target layout is HWIO only. - Array target_shape = {pos_h, pos_w, pos_i, pos_o}; - - Expr transpose = MakeTranspose(data, target_shape); - transpose = InferType(FoldConstantExpr(transpose)); - Constant transposed_data = Downcast(transpose); - return transposed_data; -} - EthosnError EthosnAPI::QnnConv2dTranspose(const Expr& expr, QnnConv2dTransposeParams* params) { Call requantize = Downcast(expr); Call bias; @@ -530,7 +522,7 @@ EthosnError EthosnAPI::QnnConv2dTranspose(const Expr& expr, QnnConv2dTransposePa // Create weights info Constant weights_data = Downcast(conv2d_transpose->args[1]); if (conv_attr->kernel_layout != "HWIO") { - weights_data = TransposeWeights(weights_data, conv_attr->kernel_layout); + weights_data = TransposeWeights(weights_data, conv_attr->kernel_layout, "HWIO"); } const auto* weights_ttype = weights_data->checked_type().as(); sl::TensorShape weights_tensor_shape; @@ -1080,13 +1072,6 @@ EthosnError EthosnAPI::AsConstant(const Expr& expr, T* out) { return EthosnError(); } -Expr FoldConstantExpr(const Expr& expr, bool fold_qnn) { - auto mod = IRModule::FromExpr(expr); - mod = transform::FoldConstant(fold_qnn)(mod); - auto entry_func = Downcast(mod->Lookup("main")); - return expr.as() == nullptr ? entry_func->body : entry_func; -} - } // namespace ethosn } // namespace contrib } // namespace relay diff --git a/src/relay/backend/contrib/ethosn/ethosn_api.h b/src/relay/backend/contrib/ethosn/ethosn_api.h index 3d704f2757c6..d640a02312ec 100644 --- a/src/relay/backend/contrib/ethosn/ethosn_api.h +++ b/src/relay/backend/contrib/ethosn/ethosn_api.h @@ -66,8 +66,8 @@ struct FullyConnectedParams { sl::TensorInfo weights_info; sl::TensorInfo bias_info; sl::TensorInfo output_info; - void* raw_weights = nullptr; - void* raw_bias = nullptr; + runtime::NDArray raw_weights; + runtime::NDArray raw_bias; }; struct MaxPool2DParams { @@ -324,15 +324,6 @@ class EthosnAPI { static EthosnError AsConstant(const Expr& expr, std::valarray* out); }; -/*! - * \brief Apply constant folding on an expression. - * - * \param expr The expression to fold. - * \param fold_qnn Whether to fold constants for QNN operations. - * \returns The new folded expression. - */ -Expr FoldConstantExpr(const Expr& expr, bool fold_qnn = true); - } // namespace ethosn } // namespace contrib } // namespace relay diff --git a/tests/python/contrib/test_ethosn/test_fullyconnected.py b/tests/python/contrib/test_ethosn/test_fullyconnected.py index d38b2528c7bb..e84464f90217 100644 --- a/tests/python/contrib/test_ethosn/test_fullyconnected.py +++ b/tests/python/contrib/test_ethosn/test_fullyconnected.py @@ -19,9 +19,11 @@ import numpy as np import pytest + import tvm from tvm import relay from tvm.testing import requires_ethosn + from . import infrastructure as tei @@ -30,7 +32,11 @@ def _get_model( ): """Return a model an any parameters it may have""" a = relay.var("a", shape=shape, dtype=dtype) - weights_array = tvm.nd.array(np.ones(weight_shape, dtype)) + weights_array = tvm.nd.array( + np.random.randint( + np.iinfo(dtype).min, high=np.iinfo(dtype).max, size=weight_shape, dtype=dtype + ) + ) weights = relay.const(weights_array, dtype) dense = relay.qnn.op.dense( a, @@ -66,26 +72,24 @@ def _get_model( ((1, 1280), 1000), ], ) -@pytest.mark.parametrize( - "dtype,input_zp,input_sc,kernel_zp,kernel_sc", - [ - ("uint8", 71, 0.580, 176, 1.498), - ("uint8", 166, 1.724, 138, 0.180), - ("int8", 71, 0.580, 0, 1.498), - ("int8", 120, 1.724, 0, 0.180), - ], -) -def test_fullyconnected(shape, out_channels, dtype, input_zp, input_sc, kernel_zp, kernel_sc): +@pytest.mark.parametrize("dtype", ["uint8", "int8"]) +def test_fullyconnected(shape, out_channels, dtype): """Compare Fully Connected output with TVM.""" np.random.seed(0) + iinfo = np.iinfo(dtype) + data_min = iinfo.min + data_max = iinfo.max + inputs = { - "a": tvm.nd.array( - np.random.randint(np.iinfo(dtype).min, np.iinfo(dtype).max + 1, size=shape, dtype=dtype) - ), + "a": tvm.nd.array(np.random.randint(data_min, data_max + 1, size=shape, dtype=dtype)), } - outputs = [] + + input_zp = np.random.randint(data_min, data_max) + input_sc = np.random.random() * 2 + kernel_zp = np.random.randint(data_min, data_max) + kernel_sc = np.random.random() * 2 output_zp, output_sc = tei.get_conv2d_qnn_params( dtype, input_zp, @@ -96,18 +100,18 @@ def test_fullyconnected(shape, out_channels, dtype, input_zp, input_sc, kernel_z shape[1], 1, ) + model, params = _get_model( + shape, + (out_channels, shape[1]), + input_zp, + input_sc, + kernel_zp, + kernel_sc, + output_zp, + output_sc, + dtype, + ) for npu in [False, True]: - model, params = _get_model( - shape, - (out_channels, shape[1]), - input_zp, - input_sc, - kernel_zp, - kernel_sc, - output_zp, - output_sc, - dtype, - ) mod = tei.make_module(model, params) outputs.append(tei.build_and_run(mod, inputs, 1, params, npu=npu)) tei.verify(outputs, dtype, 1) diff --git a/tests/python/contrib/test_ethosn/test_networks.py b/tests/python/contrib/test_ethosn/test_networks.py index 54ca44805171..5bd133ba20bb 100644 --- a/tests/python/contrib/test_ethosn/test_networks.py +++ b/tests/python/contrib/test_ethosn/test_networks.py @@ -145,7 +145,7 @@ def test_resnet_50_int8(): # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. _compile_hash = { - "6b130a99397715156d5fb833809a92d2", + "f16dc9caa8e696bc5da8a5c6a644eb72", "6e5fcbab831607b9da1039aff4e56871", "41acecca37b2735bd580f6ec38d8c2e0", } @@ -190,7 +190,7 @@ def test_inception_v4(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"2eeae331898f8e94c74868e190077837"} + _compile_hash = {"c00c119506b34c8e87f81aa009b42431"} _test_image_network( model_url="https://storage.googleapis.com/download.tensorflow.org/" "models/inception_v4_299_quant_20181026.tgz",