From 23619e7149582345e343f7e4a93cf68bccf8ba6b Mon Sep 17 00:00:00 2001
From: Ashutosh Parkhi
Date: Wed, 30 Nov 2022 14:20:23 +0000
Subject: [PATCH] [CodegenC] Explicit forward function declarations

armclang 6.19 does not support implicit function declarations.
This commit adds support for generating forward function declarations
in the C file generated for __tvm_main__. All non-pure extern functions
called from __tvm_main__ are now declared explicitly in that file.

Change-Id: I03b12e6c844911bd7edb6e42ddd2b17f066bd0fa
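As an illustration, the generated entry-point file now places a prototype
ahead of every outgoing call. A minimal sketch of the emitted C, assuming a
hypothetical operator name (the real symbols depend on the compiled model):

    // tvm target: c -keys=cpu
    #include "tvm/runtime/c_runtime_api.h"

    #ifdef __cplusplus
    extern "C"
    #endif
    TVM_DLL int32_t tvmgen_default_fused_add(void* input, void* output);  // forward declaration

    #ifdef __cplusplus
    extern "C"
    #endif
    TVM_DLL int32_t tvmgen_default___tvm_main__(void* input, void* output) {
      return tvmgen_default_fused_add(input, output);  // call site now sees a prototype
    }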
---
 python/tvm/relay/testing/tflite.py            |  10 +
 .../backend/contrib/cmsisnn/tir_to_runtime.cc |   7 +-
 .../example_target_hooks/tir_to_runtime.cc    |   3 +-
 .../backend/contrib/uma/tir_to_runtime.cc     |   7 +-
 src/target/source/codegen_c.cc                |   5 +-
 src/target/source/codegen_c.h                 |  13 +-
 src/target/source/codegen_c_host.cc           |  57 +++-
 src/target/source/codegen_c_host.h            |  11 +-
 src/target/source/codegen_cuda.cc             |   2 +-
 src/target/source/codegen_cuda.h              |   2 +-
 src/target/source/codegen_opencl.cc           |   2 +-
 src/target/source/codegen_opencl.h            |   2 +-
 src/target/source/codegen_source_base.h       |   2 +
 src/target/source/codegen_vhls.cc             |   2 +-
 src/target/source/codegen_vhls.h              |   2 +-
 tests/python/relay/aot/corstone300.mk         |   2 +-
 .../aot/test_crt_forward_declarations.py      | 275 ++++++++++++++++++
 17 files changed, 372 insertions(+), 32 deletions(-)
 create mode 100644 tests/python/relay/aot/test_crt_forward_declarations.py

diff --git a/python/tvm/relay/testing/tflite.py b/python/tvm/relay/testing/tflite.py
index df40130cebaf..c45b76c77369 100644
--- a/python/tvm/relay/testing/tflite.py
+++ b/python/tvm/relay/testing/tflite.py
@@ -61,6 +61,16 @@ def conv2d_single_function(ifm_tensor):

         return conv2d_single_function

+    def load_from_file(self, model_file, shapes):
+        """Loads a TFLite model from a .tflite file."""
+        for i, shape in enumerate(shapes):
+            input_name = "input_" + str(i)
+            self.shape_dict.update({input_name: shape})
+            self.dtype_dict.update({input_name: self.dtype})
+
+        with open(model_file, "rb") as f:
+            self.serial_model = f.read()
+
     def create_tflite_model(self, tfl_function, shapes, ranges=None):
         """Creates TFLite serial graph"""
         tensor_specs = []
diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
index 420e8618a4f9..1d53373ba833 100644
--- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -35,10 +35,10 @@ namespace cmsisnn {

 class CodeGenCMSISNN : public codegen::CodeGenCHost {
  public:
-  void Init(bool output_ssa, bool emit_asserts, std::string target_str) {
+  void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str) {
     std::unordered_set<std::string> devices;
     devices.insert("cmsis-nn");
-    CodeGenCHost::Init(output_ssa, emit_asserts, target_str, devices);
+    CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices);
   }

   /*!
@@ -491,9 +491,10 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
 runtime::Module TIRToRuntime(IRModule mod, Target target) {
   bool output_ssa = false;
   bool emit_asserts = false;
+  bool emit_fwd_func_decl = false;
   CodeGenCMSISNN codegen;
   Array<String> function_names;
-  codegen.Init(output_ssa, emit_asserts, target->str());
+  codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str());
   std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
   for (auto kv : mod->functions) {
diff --git a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
index 9ad434b88c60..0db8d06c3143 100644
--- a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
@@ -47,10 +47,11 @@ class CodeGenExampleTargetHook : public codegen::CodeGenCHost {
 runtime::Module TIRToRuntime(IRModule mod, Target target) {
   bool output_ssa = false;
   bool emit_asserts = false;
+  bool emit_fwd_func_decl = false;
   CodeGenExampleTargetHook codegen;
   Array<String> function_names;
   std::unordered_set<std::string> devices;
-  codegen.Init(output_ssa, emit_asserts, target->str(), devices);
+  codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
   for (auto kv : mod->functions) {
     auto prim_func = Downcast<PrimFunc>(kv.second);
     auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
diff --git a/src/relay/backend/contrib/uma/tir_to_runtime.cc b/src/relay/backend/contrib/uma/tir_to_runtime.cc
index 4b5cd4332476..3b58fda54b52 100644
--- a/src/relay/backend/contrib/uma/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/uma/tir_to_runtime.cc
@@ -37,7 +37,7 @@ class UMACodegen : public codegen::CodeGenCHost {
  public:
   explicit UMACodegen(String target_str) : target_str_(target_str) {}

-  void Init(bool output_ssa, bool emit_asserts) {
+  void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl) {
     auto includes_pf =
         tvm::runtime::Registry::Get("relay.ext.uma.codegen_c_includes_" + target_str_);
     if (includes_pf) {
@@ -46,7 +46,7 @@ class UMACodegen : public codegen::CodeGenCHost {
     }
     std::unordered_set<std::string> devices;
     devices.insert(target_str_);
-    CodeGenCHost::Init(output_ssa, emit_asserts, target_str_, devices);
+    CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str_, devices);
   }

   /*!
@@ -63,9 +63,10 @@ class UMACodegen : public codegen::CodeGenCHost {
 runtime::Module TIRToRuntime(IRModule mod, Target target) {
   bool output_ssa = false;
   bool emit_asserts = false;
+  bool emit_fwd_func_decl = false;
   UMACodegen codegen(target->kind->name);
   Array<String> function_names;
-  codegen.Init(output_ssa, emit_asserts);
+  codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl);
   for (auto kv : mod->functions) {
     auto prim_func = Downcast<PrimFunc>(kv.second);
     auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 66c92181c126..6bf14424bf38 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -85,7 +85,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
       << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
   bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

-  this->PrintFuncPrefix();
+  this->PrintFuncPrefix(stream);
   this->PrintExtraAttrs(f);
   this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
@@ -127,7 +127,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
   this->stream << "}\n\n";
 }

-void CodeGenC::PrintFuncPrefix() { stream << "void"; }
+void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void"; }

 void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}

@@ -540,6 +540,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
     ICHECK_GE(op->args.size(), 1U);
     auto func = Downcast<StringImm>(op->args[0]);
     this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
+    this->GenerateForwardFunctionDeclarations(func->value, op->args);
   } else if (op_attr_global_symbol_.count(call_op)) {
     // call extern if the op itself have a global symbol.
     this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 0af24dfdc066..be715ad3a049 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -75,7 +75,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
    * \brief Finalize the compilation and return the code.
    * \return The code.
    */
-  std::string Finish();
+  virtual std::string Finish();
   /*!
    * \brief Print the Stmt n to CodeGenC->stream
    * \param n The statement to be printed.
@@ -99,10 +99,11 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   // The following parts are overloadable print operations.
   /*!
    * \brief Print the function header before the argument list
+   * \param os The output stream
    *
    *  Example: stream << "void";
    */
-  virtual void PrintFuncPrefix();  // NOLINT(*)
+  virtual void PrintFuncPrefix(std::ostream& os);  // NOLINT(*)
   /*!
    * \brief Print extra function attributes
    *
@@ -230,6 +231,14 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
    */
   virtual bool IsScopePartOfType() const { return true; }

+  /*!
+   * \brief Generate forward function declarations.
+   * \param global_symbol The symbol of the target function.
+   * \param args The arguments to the function.
+   */
+  virtual void GenerateForwardFunctionDeclarations(String global_symbol,
+                                                   const Array<PrimExpr>& args) {}
   /*!
    * \brief Print external function call.
    * \param ret_type The return type.
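The base-class hook is deliberately a no-op, so existing CodeGenC subclasses
are unaffected unless they opt in. As a hedged sketch of an override (it
assumes TVM's internal header src/target/source/codegen_c.h is on the include
path, prints empty parameter lists instead of real prototypes, and relies on
a Finish() override such as CodeGenCHost's below to actually emit
fwd_decl_stream):

    #include "codegen_c.h"

    namespace tvm {
    namespace codegen {

    // Illustrative subclass: route a declaration for every extern callee into
    // fwd_decl_stream. CodeGenCHost implements the real version with typed
    // parameters and de-duplication against already-defined functions.
    class CodeGenFwdDeclDemo : public CodeGenC {
     public:
      void GenerateForwardFunctionDeclarations(String global_symbol,
                                               const Array<PrimExpr>& args) override {
        this->PrintFuncPrefix(fwd_decl_stream);  // base prefix: "void"
        fwd_decl_stream << " " << global_symbol << "();\n";
      }
    };

    }  // namespace codegen
    }  // namespace tvm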
diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc
index a47158d37883..0a89198c985a 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -44,9 +44,10 @@ namespace codegen {

 CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_module_ctx"); }

-void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string target_str,
-                        const std::unordered_set<std::string>& devices) {
+void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl,
+                        std::string target_str, const std::unordered_set<std::string>& devices) {
   emit_asserts_ = emit_asserts;
+  emit_fwd_func_decl_ = emit_fwd_func_decl;
   declared_globals_.clear();
   decl_stream << "// tvm target: " << target_str << "\n";
   decl_stream << "#define TVM_EXPORTS\n";
@@ -73,17 +74,18 @@ void CodeGenCHost::InitGlobalContext() {

 void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; }

-void CodeGenCHost::AddFunction(const PrimFunc& f) {
+void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
   auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
   ICHECK(global_symbol.defined())
       << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute";
   function_names_.push_back(global_symbol.value());

+  emit_fwd_func_decl_ = emit_fwd_func_decl;
   CodeGenC::AddFunction(f);
   if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
     function_names_.push_back(runtime::symbol::tvm_module_main);
     stream << "// CodegenC: NOTE: Auto-generated entry function\n";
-    PrintFuncPrefix();
+    PrintFuncPrefix(stream);
     stream << " " << tvm::runtime::symbol::tvm_module_main
            << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
            << "int* out_ret_tcode, void* resource_handle) {\n";
@@ -93,11 +95,33 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) {
   }
 }

-void CodeGenCHost::PrintFuncPrefix() {  // NOLINT(*)
-  stream << "#ifdef __cplusplus\n"
-         << "extern \"C\"\n"
-         << "#endif\n"
-         << "TVM_DLL int32_t";
+void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
+                                                       const Array<PrimExpr>& args) {
+  if (!emit_fwd_func_decl_) {
+    return;
+  }
+  for (auto& func_already_defined : GetFunctionNames()) {
+    if (global_symbol == func_already_defined) {
+      return;
+    }
+  }
+  this->PrintFuncPrefix(fwd_decl_stream);
+  fwd_decl_stream << " " << global_symbol << "(";
+  for (size_t i = 1; i < args.size(); ++i) {
+    CodeGenSourceBase::PrintType(GetType(args[i]), fwd_decl_stream);
+    fwd_decl_stream << " ";
+    this->PrintExpr(args[i], fwd_decl_stream);
+    if (i < args.size() - 1) {
+      fwd_decl_stream << ", ";
+    }
+  }
+  fwd_decl_stream << ");\n";
+}
+
+void CodeGenCHost::PrintFuncPrefix(std::ostream& os) {  // NOLINT(*)
+  os << "#ifdef __cplusplus\n"
+     << "extern \"C\"\n"
+     << "#endif\n"
+     << "TVM_DLL int32_t";
 }

 void CodeGenCHost::PrintFinalReturn() {  // NOLINT(*)
@@ -105,6 +129,15 @@ void CodeGenCHost::PrintFinalReturn() {  // NOLINT(*)
   stream << "return 0;\n";
 }

+std::string CodeGenCHost::Finish() {  // NOLINT(*)
+  std::string ret = decl_stream.str();
+  if (emit_fwd_func_decl_) {
+    ret += fwd_decl_stream.str();
+  }
+  ret += stream.str();
+  return ret;
+}
+
 void CodeGenCHost::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
   int lanes = t.lanes();
   if (t.is_handle()) {
@@ -391,6 +424,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
   bool emit_asserts = false;
+  bool emit_fwd_func_decl = true;
   std::unordered_set<std::string> devices;
   if (mod->GetAttr<Map<GlobalVar, String>>("device_contexts") != nullptr) {
@@ -402,7 +436,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
   }

   CodeGenCHost cg;
-  cg.Init(output_ssa, emit_asserts, target->str(), devices);
+  cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
   cg.SetConstantsByteAlignment(target->GetAttr<Integer>("constants-byte-alignment").value_or(16));

   PrimFunc aot_executor_fn;
@@ -438,7 +472,8 @@ runtime::Module BuildCHost(IRModule mod, Target target) {

   // Add __tvm_main__
   if (aot_executor_fn.defined()) {
-    cg.AddFunction(aot_executor_fn);
+    emit_fwd_func_decl = true;
+    cg.AddFunction(aot_executor_fn, emit_fwd_func_decl);
   }

   // NOTE: it's possible that kRuntime attr is not attached when the mod was built with tvm.build().
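The Finish() override assembles the file in declaration-before-use order:
decl_stream, then fwd_decl_stream (only when emit_fwd_func_decl_ is set),
then stream with the definitions. A self-contained C++ sketch of that
three-stream layout, independent of TVM:

    #include <iostream>
    #include <sstream>

    // Standalone illustration: global declarations first, then forward
    // function declarations, then definitions, so every call site in the
    // concatenated output is preceded by a prototype.
    int main() {
      std::ostringstream decl_stream, fwd_decl_stream, stream;
      decl_stream << "#include <stdint.h>\n";
      fwd_decl_stream << "int32_t op(int32_t x);\n";          // forward declaration
      stream << "int32_t entry(void) { return op(1); }\n";    // call site
      stream << "int32_t op(int32_t x) { return x + 1; }\n";  // definition
      std::cout << decl_stream.str() << fwd_decl_stream.str() << stream.str();
      return 0;
    }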
diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h
index 84c27b91bac3..6bae574627d5 100644
--- a/src/target/source/codegen_c_host.h
+++ b/src/target/source/codegen_c_host.h
@@ -40,11 +40,12 @@ namespace codegen {
 class CodeGenCHost : public CodeGenC {
  public:
   CodeGenCHost();
-  void Init(bool output_ssa, bool emit_asserts, std::string target_str,
+  void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str,
             const std::unordered_set<std::string>& devices);

   void InitGlobalContext();
-  void AddFunction(const PrimFunc& f);
+  void AddFunction(const PrimFunc& f, bool emit_fwd_func_decl = false);
+  std::string Finish() final;
   /*!
    * \brief Add functions from the (unordered) range to the current module in a deterministic
    * order. This helps with debugging.
@@ -55,7 +56,7 @@ class CodeGenCHost : public CodeGenC {
   void DefineModuleName();

   void PrintType(DataType t, std::ostream& os) final;  // NOLINT(*)
-  void PrintFuncPrefix() final;                        // NOLINT(*)
+  void PrintFuncPrefix(std::ostream& os) final;        // NOLINT(*)
   void PrintFinalReturn() final;                       // NOLINT(*)

   // overload visitor functions
@@ -68,6 +69,8 @@ class CodeGenCHost : public CodeGenC {

   void VisitStmt_(const AssertStmtNode* op) final;  // NOLINT(*)

+  virtual void GenerateForwardFunctionDeclarations(String global_symbol,
+                                                   const Array<PrimExpr>& args);  // NOLINT(*)
   Array<String> GetFunctionNames() { return function_names_; }

  private:
@@ -87,6 +90,8 @@ class CodeGenCHost : public CodeGenC {
   Array<String> function_names_;
   /*! \brief whether to emit asserts in the resulting C code */
   bool emit_asserts_;
+  /*! \brief whether to emit forward function declarations in the resulting C code */
+  bool emit_fwd_func_decl_;

   FunctionInfo GetFunctionInfo(const CallNode* op, bool has_resource_handle);
   std::string GetPackedName(const CallNode* op);
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 3ae74cc16da4..436e85247ffe 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -48,7 +48,7 @@ void CodeGenCUDA::Init(bool output_ssa) {
   ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
 }

-void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; }
+void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ void"; }

 class ThreadIdxExtractor : public tir::StmtVisitor {
  private:
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 673753c470ae..0fef15c7a7f3 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -45,7 +45,7 @@ class CodeGenCUDA final : public CodeGenC {
     return (enable_fp16_ || enable_bf16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
   }
   // override behavior
-  void PrintFuncPrefix() final;
+  void PrintFuncPrefix(std::ostream& os) final;
   void PrintExtraAttrs(const PrimFunc& f) final;
   void VisitStmt_(const ForNode* op) final;
   void PrintStorageSync(const CallNode* op) final;
diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc
index cd898043eeb5..6e5a9db4d37c 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -87,7 +87,7 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) {
   }
 }

-void CodeGenOpenCL::PrintFuncPrefix() { stream << "__kernel void"; }
+void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel void"; }

 void CodeGenOpenCL::PreFunctionBody(const PrimFunc& f) {
   for (Var arg : f->params) {
diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h
index af6de1531017..bf3046f0d8df 100644
--- a/src/target/source/codegen_opencl.h
+++ b/src/target/source/codegen_opencl.h
@@ -41,7 +41,7 @@ class CodeGenOpenCL final : public CodeGenC {
   // override print thread tag.
   void InitFuncState(const PrimFunc& f) final;
-  void PrintFuncPrefix() final;                  // NOLINT(*)
+  void PrintFuncPrefix(std::ostream& os) final;  // NOLINT(*)
   void PreFunctionBody(const PrimFunc& f) final;  // NOLINT(*)
   void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h
index 2fd0abcd68a6..8191ad43aa99 100644
--- a/src/target/source/codegen_source_base.h
+++ b/src/target/source/codegen_source_base.h
@@ -120,6 +120,8 @@ class CodeGenSourceBase {
   std::ostringstream decl_stream;
   /*! \brief the stream to be printed */
   std::ostringstream stream;
+  /*! \brief the forward declaration stream */
+  std::ostringstream fwd_decl_stream;
   /*! \brief name of each variable */
   std::unordered_map<const tir::VarNode*, std::string> var_idmap_;
   /*! \brief NameSupply for allocation */
diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc
index 4091b64f4524..3ae3fb773d7f 100644
--- a/src/target/source/codegen_vhls.cc
+++ b/src/target/source/codegen_vhls.cc
@@ -80,7 +80,7 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) {
   }
 }

-void CodeGenVivadoHLS::PrintFuncPrefix() { stream << "extern \"C\" void"; }
+void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" void"; }

 void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) {
   for (size_t i = 0; i < f->params.size(); ++i) {
diff --git a/src/target/source/codegen_vhls.h b/src/target/source/codegen_vhls.h
index b9bec516bae9..32ddce1b3a30 100644
--- a/src/target/source/codegen_vhls.h
+++ b/src/target/source/codegen_vhls.h
@@ -40,7 +40,7 @@ class CodeGenVivadoHLS final : public CodeGenC {
   void Init(bool output_ssa);
   void PrintType(DataType t, std::ostream& os);
-  void PrintFuncPrefix() final;
+  void PrintFuncPrefix(std::ostream& os) final;
   void PreFunctionBody(const PrimFunc& f) final;
   void VisitExpr_(const MinNode* op, std::ostream& os) final;
   void VisitExpr_(const MaxNode* op, std::ostream& os) final;
diff --git a/tests/python/relay/aot/corstone300.mk b/tests/python/relay/aot/corstone300.mk
index 1361dbbc1946..cb1db5ea9995 100644
--- a/tests/python/relay/aot/corstone300.mk
+++ b/tests/python/relay/aot/corstone300.mk
@@ -42,7 +42,7 @@ ETHOSU_PATH=/opt/arm/ethosu
 DRIVER_PATH=${ETHOSU_PATH}/core_driver
 CMSIS_PATH=${ETHOSU_PATH}/cmsis
 PLATFORM_PATH=${ETHOSU_PATH}/core_platform/targets/corstone-300
-PKG_COMPILE_OPTS = -g -Wall -O2 -Wno-incompatible-pointer-types -Wno-format -mcpu=${MCPU}${MCPU_FLAGS} -mthumb -mfloat-abi=${MFLOAT_ABI} -std=gnu99
+PKG_COMPILE_OPTS = -g -Wall -O2 -Wno-incompatible-pointer-types -Wno-format -Werror-implicit-function-declaration -mcpu=${MCPU}${MCPU_FLAGS} -mthumb -mfloat-abi=${MFLOAT_ABI} -std=gnu99
 CMAKE = /opt/arm/cmake/bin/cmake
 CC = arm-none-eabi-gcc
 AR = arm-none-eabi-ar
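The new -Werror-implicit-function-declaration flag makes the very problem this
patch fixes fail the Corstone-300 test builds instead of merely warning. As a
two-line reminder of what it rejects (helper() is hypothetical and deliberately
undeclared):

    /* Under C89/gnu99 a call with no visible prototype is an "implicit
     * function declaration"; armclang 6.19 rejects it outright, and this
     * flag makes arm-none-eabi-gcc fail the build the same way. */
    int main(void) { return helper(7); }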
diff --git a/tests/python/relay/aot/test_crt_forward_declarations.py b/tests/python/relay/aot/test_crt_forward_declarations.py
new file mode 100644
index 000000000000..7454f85ed153
--- /dev/null
+++ b/tests/python/relay/aot/test_crt_forward_declarations.py
@@ -0,0 +1,275 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test forward function declarations generated by CodegenCHost."""
+
+from collections import OrderedDict
+import pytest
+import numpy as np
+
+import tvm.testing
+from tvm import relay
+from tvm.contrib.download import download_testdata
+from tvm.relay.op.contrib import cmsisnn
+from tvm.testing.aot import AOTTestModel, compile_models, generate_ref_data
+from tvm.micro.testing.aot_test_utils import (
+    AOT_CORSTONE300_RUNNER,
+    AOT_USMP_CORSTONE300_RUNNER,
+    parametrize_aot_options,
+    AOTTestRunner,
+)
+
+
+def get_range_for_dtype_str(dtype):
+    """
+    Produces the min and max for a given data type.
+
+    Parameters
+    ----------
+    dtype : str
+        a type string (e.g., int8)
+
+    Returns
+    -------
+    type_info.min : int
+        the minimum of the range
+    type_info.max : int
+        the maximum of the range
+    """
+
+    try:
+        type_info = np.iinfo(dtype)
+    except ValueError:
+        type_info = np.finfo(dtype)
+    return type_info.min, type_info.max
+
+
+def _change_ndarray_layout(arr, src_layout, dst_layout):
+    """Makes a copy of an ndarray, reshaping it to a new data layout.
+
+    Parameters
+    ----------
+    arr : numpy.ndarray
+        The ndarray to be reformatted.
+
+    src_layout : str
+        The current layout of the Relay constant. Must be alphabetic (e.g. NHWC
+        or OIHW, but not NCHW2c).
+
+    dst_layout : str
+        The desired layout of the new Relay constant. Must be alphabetic (e.g. NHWC
+        or OIHW, but not NCHW2c).
+
+    Returns
+    -------
+    dst_array : numpy.ndarray
+        A copy of the ndarray with the new layout.
+    """
+    assert src_layout.isalpha() and dst_layout.isalpha()
+    axis_order = [src_layout.index(c) for c in dst_layout]
+    return np.transpose(arr, axis_order)
+
+
+@tvm.testing.requires_package("tflite")
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("test_runner", [AOT_CORSTONE300_RUNNER, AOT_USMP_CORSTONE300_RUNNER])
+def test_external_calls(test_runner):
+    """Downloads a small network and partitions it for CMSIS-NN to test forward
+    declarations for the external calls made from __tvm_main__."""
+    # download the model
+    base_url = (
+        "https://github.com/ARM-software/ML-zoo/raw/"
+        "48a22ee22325d15d2371a6df24eb7d67e21dcc97"
+        "/models/keyword_spotting/cnn_small/tflite_int8"
+    )
+    file_to_download = "cnn_s_quantized.tflite"
+    file_saved = "cnn_s_quantized_15Dec2021.tflite"
+    model_file = download_testdata("{}/{}".format(base_url, file_to_download), file_saved)
+
+    # convert the tflite network into a relay model
+    # pylint: disable=import-outside-toplevel
+    from tvm.relay.testing.tflite import TFLiteModel
+
+    input_shape = (1, 490)
+    dtype = "int8"
+    tfl_model = TFLiteModel(dtype)
+    tfl_model.load_from_file(model_file, [input_shape])
+    relay_mod, relay_params = tfl_model.convert_to_relay()
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(relay_mod, relay_params)
+
+    # obtain the executor factory after relay compilation.
+    input_map, output_map, output_tolerance = tfl_model.generate_reference_data()
+    interface_api = "c"
+    use_unpacked_api = True
+    compiled_models = compile_models(
+        AOTTestModel(
+            module=cmsisnn_mod,
+            inputs=input_map,
+            outputs=output_map,
+            params=None,
+            output_tolerance=output_tolerance,
+        ),
+        interface_api,
+        use_unpacked_api,
+        pass_config=test_runner.pass_config,
+    )
+
+    # validate the frequency of function appearances in the host C file
+    # after forward declarations were emitted.
+    lib_mod = compiled_models[0].executor_factory.lib.imported_modules[0]
+    main_source = lib_mod.get_source()
+    assert (
+        main_source.count("TVMBackendAllocWorkspace") == 3
+        or main_source.count("TVMBackendAllocWorkspace") == 0
+    )
+    assert main_source.count("tvmgen_default_fused_reshape") == 2
+    assert main_source.count("tvmgen_default_cmsis_nn_main") == 12
+    cmsisnn_source = lib_mod.imported_modules[0].get_source()
+    assert cmsisnn_source.count("arm_convolve_wrapper") == 1
+    assert cmsisnn_source.count("arm_fully_connected") == 3
+    assert cmsisnn_source.count("arm_softmax") == 1
+
+
+@parametrize_aot_options
+def test_internal_calls(interface_api, use_unpacked_api, test_runner):
+    """Tests purely internal function calls; no forward declarations are expected here."""
+    dtype = "float32"
+    groups = 32
+    weight_shape = 1
+    ishape = (1, 32, 14, 14)
+    wshape = (32, weight_shape, 3, 3)
+    pass_config = {"tir.usmp.enable": True}
+    test_runner = AOTTestRunner(
+        makefile=test_runner.makefile,
+        prologue=test_runner.prologue,
+        epilogue=test_runner.epilogue,
+        includes=test_runner.includes,
+        parameters=test_runner.parameters,
+        pass_config=pass_config,
+    )
+
+    data0 = relay.var("data", shape=ishape, dtype=dtype)
+    weight0 = relay.var("weight", shape=wshape, dtype=dtype)
+    out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=groups)
+    main_f = relay.Function([data0, weight0], out)
+    mod = tvm.IRModule()
+    mod["main"] = main_f
+    mod = tvm.relay.transform.InferType()(mod)
+
+    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+    w1_data = np.random.uniform(0, 1, wshape).astype(dtype)
+
+    inputs = OrderedDict([("data", i_data), ("weight", w1_data)])
+
+    output_list = generate_ref_data(mod, inputs)
+    compiled_models = compile_models(
+        models=AOTTestModel(module=mod, inputs=inputs, outputs=output_list),
+        interface_api=interface_api,
+        use_unpacked_api=use_unpacked_api,
+        pass_config=test_runner.pass_config,
+    )
+
+    lib_mod = compiled_models[0].executor_factory.lib.imported_modules[0]
+    main_source = lib_mod.get_source()
+    assert main_source.count("tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 2
+    assert main_source.count("tvmgen_default_fused_layout_transform") == 6
+
+
+@tvm.testing.requires_corstone300
+def test_tensorized_calls():
+    """Tests a subgraph with a mix of internal and tensorized calls."""
+    data_shape, kernel_size, num_filter, groups, strides, padding, dilation = (
+        (1, 32, 32, 16),
+        (3, 3),
+        16,
+        1,
+        1,
+        (0, 2, 2, 0),
+        1,
+    )
+    in_dtype = "int8"
+    data_layout = "NHWC"
+    kernel_layout = "HWOI"
+    ref_kernel_layout = "HWIO"
+    out_layout = "NHWC"
+    schedule_name = "conv2d_nhwc_dsp.arm_cpu"
+
+    ref_input_data = np.random.randint(low=-128, high=127, size=data_shape, dtype=in_dtype)
+    ref_input_var = relay.var("input", relay.TensorType(data_shape, in_dtype))  # NHWC layout
+    kernel_shape = (*kernel_size, data_shape[-1] // groups, num_filter)  # HWIO layout
+    ref_kernel_data = np.random.randint(low=-10, high=10, size=kernel_shape, dtype=in_dtype)
+
+    ref_relay_op = relay.op.nn.conv2d(
+        ref_input_var,
+        relay.const(_change_ndarray_layout(ref_kernel_data, "HWIO", ref_kernel_layout)),
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        groups=groups,
+        dilation=(dilation, dilation),
+        data_layout="NHWC",
+        kernel_layout=ref_kernel_layout,
+        out_dtype="int32",
+        out_layout="NHWC",
+    )
+    ref_module = tvm.IRModule.from_expr(relay.Function([ref_input_var], ref_relay_op))
+    ref_outputs = generate_ref_data(ref_module, {"input": ref_input_data})
+
+    # reshape the output dictionary to match out_layout
+    assert len(ref_outputs) == 1
+    output_tensor_name, output_tensor = next(iter(ref_outputs.items()))
+    ref_outputs[output_tensor_name] = _change_ndarray_layout(output_tensor, "NHWC", out_layout)
+
+    test_input_data = _change_ndarray_layout(ref_input_data, "NHWC", data_layout)
+    test_input_var = relay.var("input", relay.TensorType(test_input_data.shape, in_dtype))
+    test_kernel_data = _change_ndarray_layout(ref_kernel_data, "HWIO", kernel_layout)
+
+    test_relay_op = relay.op.nn.conv2d(
+        test_input_var,
+        relay.const(test_kernel_data),
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        groups=groups,
+        dilation=(dilation, dilation),
+        data_layout=data_layout,
+        kernel_layout=kernel_layout,
+        out_dtype="int32",
+        out_layout=out_layout,
+    )
+    test_function = relay.Function([test_input_var], test_relay_op)
+    test_model = AOTTestModel(
+        module=tvm.IRModule.from_expr(test_function),
+        inputs={"input": test_input_data},
+        outputs=ref_outputs,
+    )
+    compiled_models = compile_models(
+        test_model,
+        interface_api="c",
+        use_unpacked_api=True,
+        pass_config=AOT_CORSTONE300_RUNNER.pass_config,
+        target="c -keys=arm_cpu -mcpu=cortex-m7",
+        schedule_name=schedule_name,
+    )
+
+    lib_mod = compiled_models[0].executor_factory.lib.imported_modules[0]
+    main_source = lib_mod.get_source()
+    assert main_source.count("tvmgen_default_fused_nn_conv2d") == 2
+    assert main_source.count("gemm_") == 13
+
+
+if __name__ == "__main__":
+    tvm.testing.main()