From 6dc3595023a3b9714b20c565379ed9aa784a8a4c Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 26 Nov 2021 20:05:35 +0000 Subject: [PATCH 1/7] [microNPU] Move the compilation to use Target Hooks. This commits moves the current compilation flow to use target hooks, so that the generated TIR is provided to unified module to for unified optimizations. Change-Id: Ib3239a04ab201748e7f1b1ffa503cfe2aa7ccb7b --- include/tvm/tir/transform.h | 5 + .../relay/backend/contrib/ethosu/codegen.py | 98 +++++++----- .../backend/contrib/ethosu/tir/compiler.py | 3 +- .../backend/contrib/ethosu/tir/passes.py | 31 ++++ .../contrib/ethosu/tir_to_cs_translator.py | 22 +-- .../tvm/relay/backend/contrib/ethosu/util.py | 30 ++++ python/tvm/runtime/object_generic.py | 8 +- src/relay/backend/contrib/ethosu/codegen.cc | 136 +++++++++++++++++ .../backend/contrib/ethosu/source_module.cc | 123 +++++++-------- src/relay/backend/contrib/ethosu/utils.cc | 75 ++++++++++ src/relay/backend/contrib/ethosu/utils.h | 96 ++++++++++++ src/relay/backend/te_compiler.cc | 7 + src/relay/backend/te_compiler.h | 1 + src/target/target_kind.cc | 1 - src/tir/transforms/lower_tvm_builtin.cc | 5 + src/tir/transforms/make_unpacked_api.cc | 20 +-- .../contrib/test_ethosu/test_codegen.py | 141 +++++++----------- 17 files changed, 583 insertions(+), 219 deletions(-) create mode 100644 src/relay/backend/contrib/ethosu/codegen.cc create mode 100644 src/relay/backend/contrib/ethosu/utils.cc create mode 100644 src/relay/backend/contrib/ethosu/utils.h diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 7922e978c381..02543886a982 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -287,6 +287,11 @@ TVM_DLL Pass LowerThreadAllreduce(); */ TVM_DLL Pass InferFragment(); +/*! + * \brief This annotation for nodes to be disabled for builtin lowering + */ +static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin"; + /*! * \brief Lower builtin intrinsics. * \return The pass. diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 5fe51b4cbda0..9fabba3f19f7 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Codegen for Arm(R) Ethos(TM)-U""" +"""Codegen for Arm(R) Ethos(TM)-U NPU""" + import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir @@ -24,25 +25,7 @@ from tvm.relay.backend.contrib.ethosu import util -@tvm._ffi.register_func("relay.ext.ethos-u") -def ethosu_compiler(external_function): - """The entry-point to a compile a external relay function of - NPU compatible operators to generated command stream. - Such generated command stream would be used to create c-source r - runtime module that interfaces with NPU driver. - """ - assert isinstance(external_function, tvm.ir.function.BaseFunc) - func_name = external_function.attrs["global_symbol"] - # There should only be a single input - assert len(external_function.params) == 1 - input_size = util.calculate_size_bytes(external_function.params[0]) - output_size = util.calculate_size_bytes(external_function.body) - cmms, encoded_constants, scratch_size = _compile(external_function) - ethosu_runtime = tvm._ffi.get_global_func("runtime.module.ethos-u.create") - return ethosu_runtime(func_name, cmms, encoded_constants, scratch_size, input_size, output_size) - - -@tvm._ffi.register_func("relay.ext.ethos-u.constant_updater") +@tvm._ffi.register_func("relay.ext.ethosu.constant_updater") def constant_updater(expr, symbol): # pylint: disable=unused-argument """ The constant updater process happen after lowering in the core compiler. @@ -52,25 +35,25 @@ def constant_updater(expr, symbol): # pylint: disable=unused-argument return dict() -def _compile(ext_func): +@tvm._ffi.register_func("relay.ext.ethosu.relay_to_tir_func") +def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: """ - This is the main wrapper that accepts an external - relay function and runs all the passes to lower it down - to command stream + This is hook for python-based lowering of relay function + that gets offloaded to Ethos-U. + Parameters ---------- - ext_func : tvm.relay.function.Function - The partitioned relay function + ext_func : relay.Function + This is the partitioned relay function + Returns ------- - cs : str - An hex string of the bytes of command stream - encoded_constants : str - An hex string of the bytes that includes concat'd - encoded weights, encoded biases and scales. - scratch_size : int - The size of the scratch buffer needed. + primfunc : tir.PrimFunc + This returns the scheduled PrimFunc """ + assert len(ext_func.params) == 1 + input_size = util.calculate_size_bytes(ext_func.params[0]) + output_size = util.calculate_size_bytes(ext_func.body) mod = tvm.IRModule() mod["main"] = ext_func mod = LegalizeEthosU()(mod) @@ -80,5 +63,50 @@ def _compile(ext_func): # that can perform scheduling based on user inputs such as # scratch memory size. tir_mod, params = lower_to_tir(mod["main"], copy_constants()) - cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate(tir_mod, params) - return cmms, encoded_constants, scratch_size + + for idx in params.keys(): + params[idx] = tvm.nd.array(params[idx]) + + primfunc = tir_mod["main"] + primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"]) + primfunc = primfunc.with_attr("ethos-u.constants", params) + primfunc = primfunc.with_attr("ethos-u.input_size", input_size) + primfunc = primfunc.with_attr("ethos-u.output_size", output_size) + return primfunc + + +@tvm._ffi.register_func("relay.ext.ethosu.primfunc_to_artifact") +def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact: + """ + This is hook for python-based lowering of TIR PrimFunc + that has undergone unified optimization to Compilation + Artifact destined for the microNPU. + + Parameters + ---------- + primfunc : tir.PrimFunc + TIR PrimFuncthat has undergone unified optimization + + Returns + ------- + CompilationArtifact + This is a structure that holds the binary artifacts + for the microNPU + """ + symbol = str(primfunc.attrs["global_symbol"]) + params = primfunc.attrs["ethos-u.constants"] + input_size = primfunc.attrs["ethos-u.input_size"] + output_size = primfunc.attrs["ethos-u.output_size"] + tir_mod = tvm.IRModule() + tir_mod[symbol] = primfunc + + params_with_int_keys = dict() + for idx in params.keys(): + params_with_int_keys[int(idx)] = params[idx].numpy() + + cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate( + tir_mod, params_with_int_keys + ) + return util.CompilationArtifact( + cmms, encoded_constants, scratch_size, input_size, output_size, symbol + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index b68a5ad14a6f..b3ffecb2ec22 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -21,7 +21,7 @@ from tvm.relay.expr_functor import ExprMutator from tvm.driver.build_module import schedule_to_module -from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants +from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants, AnnotateAllocates from .scheduler import schedule @@ -88,6 +88,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod, const_dict = EncodeConstants(const_dict)(mod) mod = tvm.tir.transform.StorageRewrite()(mod) mod = tvm.tir.transform.RemoveNoOp()(mod) + mod = AnnotateAllocates()(mod) return mod, const_dict diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index cb46ba319edd..41a6832c5953 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -488,3 +488,34 @@ def _encode_constants(mod): return new_func, new_const_dict return _encode_constants + + +# This need to be kept in sync with kDisableLowerTVMBuiltin in include/tvm/tir/transform.h +DISABLE_LOWER_BUILTIN = "disable_lower_builtin" + + +def AnnotateAllocates(): + """ + This is pass to annotate all allocate + nodes of the PrimFuncs of the microNPU + to be not lowered to built-ins. + """ + + def _post_transform(allocate): + return tvm.tir.Allocate( + buffer_var=allocate.buffer_var, + dtype=allocate.dtype, + extents=allocate.extents, + condition=allocate.condition, + body=allocate.body, + annotations={DISABLE_LOWER_BUILTIN: True}, + ) + + def _ftransform(f, mod, ctx): + return f.with_body( + tvm.tir.stmt_functor.ir_transform(f.body, None, _post_transform, ["tir.Allocate"]) + ) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.ethosu.annotate_allocates" + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 4e84febe5e48..2912f46bf697 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -152,16 +152,16 @@ def extract_buffer_info( primfunc = mod.functions.items()[0][1] for idx, const_data in param_dict.items(): param = primfunc.params[idx] - buffer_info[primfunc.buffer_map[param].data] = BufferInfo( + buffer_info[param] = BufferInfo( const_data, const_data.shape, const_data.dtype, BufferType.constant ) for param in primfunc.params: - if primfunc.buffer_map[param].data not in buffer_info.keys(): - buffer_info[primfunc.buffer_map[param].data] = BufferInfo( + if param not in buffer_info.keys(): + buffer_info[param] = BufferInfo( + None, + None, None, - primfunc.buffer_map[param].shape, - primfunc.buffer_map[param].dtype, BufferType.input_or_output, ) @@ -223,7 +223,7 @@ def replace_npu_fm_with_address(npu_fm): def replace_npu_address_range_with_address(npu_addr_range): assert isinstance(npu_addr_range.address, tvm.tir.Load) buffer = npu_addr_range.address.buffer_var - assert buffer in buffer_addresses.keys() + assert buffer in buffer_addresses.keys(), f"searching for buffer : {buffer}, but not found" address, buffer_type = buffer_addresses[buffer] return vapi.NpuAddressRange(_REGION_MAP[buffer_type], address, npu_addr_range.length) @@ -269,17 +269,17 @@ def classify_io(buffer): size_in_bytes = util.round_up(size_in_bytes, 16) constant_tensor = np.append(constant_tensor, np.resize(info.values, size_in_bytes)) else: - size_in_bytes = int( - (np.iinfo(np.dtype(info.dtype)).bits // 8) * np.prod(list(info.shape)) - ) - # Every memory address the NPU access have to be 16 byte aligned - size_in_bytes = util.round_up(size_in_bytes, 16) if info.btype == BufferType.input_or_output: buffer_type = classify_io(_buffer) assert buffer_type in (BufferType.input, BufferType.output) address = 0 buffer_addresses[_buffer] = (address, buffer_type) else: + size_in_bytes = int( + (np.iinfo(np.dtype(info.dtype)).bits // 8) * np.prod(list(info.shape)) + ) + # Every memory address the NPU access have to be 16 byte aligned + size_in_bytes = util.round_up(size_in_bytes, 16) assert info.btype == BufferType.scratch address = scratch_size scratch_size += size_in_bytes diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 589ab21b3998..45a82d5932d6 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -28,6 +28,9 @@ import tvm # type: ignore from tvm import relay +from tvm._ffi import register_object +from tvm.runtime import Object +from . import _ffi_api class QConv2DArgs(Enum): @@ -209,3 +212,30 @@ def calculate_size_bytes(expr): element_size = type_info.bits // 8 elements = np.prod(list(expr.checked_type.shape)) return element_size * elements + + +@register_object("relay.ext.ethos-u.CompilationArtifact") +class CompilationArtifact(Object): + """ + This is a structure to hold binary artifacts + for the microNPU. + """ + + def __init__( + self, + command_stream: str, + encoded_constants: str, + scratch_size: int, + input_size: int, + output_size: int, + function_name: str, + ): + self.__init_handle_by_constructor__( + _ffi_api.CompilationArtifact, # type: ignore # pylint: disable=no-member + command_stream, + encoded_constants, + scratch_size, + input_size, + output_size, + function_name, + ) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 974523d1eb1a..7a55d3ef244e 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -68,9 +68,13 @@ def convert_to_object(value, span=None): if isinstance(value, dict): vlist = [] for item in value.items(): - if not isinstance(item[0], ObjectTypes) and not isinstance(item[0], string_types): + if ( + not isinstance(item[0], ObjectTypes) + and not isinstance(item[0], string_types) + and not isinstance(item[0], Number) + ): raise ValueError("key of map must already been a container type") - vlist.append(item[0]) + vlist.append(convert_to_object(item[0])) vlist.append(convert_to_object(item[1])) return _ffi_api.Map(*vlist) if isinstance(value, ObjectGeneric): diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc new file mode 100644 index 000000000000..d7f9510fc463 --- /dev/null +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/contrib/ethosu/codegen.cc + * + * \brief This file contains the target hooks for Arm(R) Ethos(TM)-U NPU + * Codegen. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../../op/make_op.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace ethosu { + +/*! + * \brief This mutator lowers each external + * relay function to a TIR PrimFunc + */ +class RelayToTIRMutator : public MixedModeMutator { + public: + explicit RelayToTIRMutator(IRModule ir_module) : ir_module_(ir_module) {} + + IRModule operator()() { + GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); + Function main_func = Downcast(ir_module_->Lookup(main_global_var)); + + // Copy everything across and mutate the body + Function mutated_main = + Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, + main_func->type_params, main_func->attrs, main_func->span); + + ir_module_->Update(main_global_var, mutated_main); + ir_module_ = WithAttr(ir_module_, "device_contexts", device_contexts_); + return ir_module_; + } + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + Call call = Downcast(post); + if (call->op->IsInstance()) { + Function func = Downcast(call->op); + auto codegen_name = func->GetAttr(attr::kCompiler); + if (codegen_name.defined() && codegen_name == "ethos-u") { + auto relay_to_tir_func_pf = + tvm::runtime::Registry::Get("relay.ext.ethosu.relay_to_tir_func"); + ICHECK(relay_to_tir_func_pf); + tir::PrimFunc prim_func = (*relay_to_tir_func_pf)(func); + prim_func = WithAttr(prim_func, tvm::attr::kTarget, Target("ethos-u")); + String symbol_name = prim_func->GetAttr(tvm::attr::kGlobalSymbol).value(); + GlobalVar gv(symbol_name); + Array args = call->args; + gv->checked_type_ = func->checked_type(); + ir_module_->Update(gv, prim_func); + device_contexts_.Set(gv, codegen_name.value()); + return Call(gv, args, call->attrs, call->type_args); + } + } + return post; + } + + private: + IRModule ir_module_; + Map device_contexts_; +}; + +tvm::transform::Pass RelayToTIR() { + runtime::TypedPackedFunc pass_func = + [=](IRModule ir_module, transform::PassContext pass_context) { + return RelayToTIRMutator(ir_module)(); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "relay.contrib.ethosu.RelayToTIR", {}); +} + +/*! + * \brief This function lowers the IRModule with PrimFunc + * with the target of the microNPU to a C-source runtime module + */ +runtime::Module TIRToRuntime(IRModule mod, Target target) { + Array compile_artifacts; + for (const auto& kv : mod->functions) { + const tir::PrimFunc& prim_func = Downcast(kv.second); + Optional> params = + prim_func->GetAttr>("ethos-u.constants"); + ICHECK(params) << "microNPU params should be present"; + auto primfunc_to_artifact_pf = + tvm::runtime::Registry::Get("relay.ext.ethosu.primfunc_to_artifact"); + ICHECK(primfunc_to_artifact_pf); + CompilationArtifact ca = (*primfunc_to_artifact_pf)(prim_func); + compile_artifacts.push_back(ca); + } + auto ca_to_runtime = tvm::runtime::Registry::Get("runtime.module.ethos-u.create"); + return (*ca_to_runtime)(compile_artifacts); +} + +TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU) + .set_attr("use_device_api", Bool(true)) + .set_attr("RelayToTIR", RelayToTIR()) + .set_attr("TIRToRuntime", TIRToRuntime); + +} // namespace ethosu +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc index b7b359ab4735..6515df6245ae 100644 --- a/src/relay/backend/contrib/ethosu/source_module.cc +++ b/src/relay/backend/contrib/ethosu/source_module.cc @@ -41,34 +41,29 @@ #include #include "../../../../runtime/file_utils.h" +#include "utils.h" namespace tvm { namespace runtime { +using CompilationArtifact = relay::contrib::ethosu::CompilationArtifact; + // The runtime.Module that contains the host-side c code // required for invoking the NPU with the command stream class EthosUModuleNode : public ModuleNode { public: - /*! - * \brief The ethos runtime module. - * - * \param func_name_ name of the should be codegen'd function - * \param cmms_hex_ command stream for the NPU in hex - * \param weights_bias_hex_ the encoded biases and weights for the NPU in hex - * \param scratch_size_ the size of the scratch memory required for command stream - * \param input_size_ the size (in bytes) for the input tensor - * \param output_size_ the size (in bytes) for the output tensor - */ - explicit EthosUModuleNode(const String& func_name_, const String& cmms_hex_, - const String& weights_bias_hex_, const Integer& scratch_size_, - const Integer& input_size_, const Integer& output_size_) { - func_name = func_name_; - cmms_hex = std::move(cmms_hex_); - weights_bias_hex = std::move(weights_bias_hex_); - scratch_size = scratch_size_->value; - input_size = input_size_->value; - output_size = output_size_->value; - c_source = GenerateSource(); + explicit EthosUModuleNode(Array compilation_artifacts) + : compilation_artifacts_(compilation_artifacts) { + c_source += "#include \n"; + c_source += "#include \n"; + c_source += "#include \n"; + c_source += "#include \n"; + c_source += "\n"; + for (const CompilationArtifact& compilation_artifact : compilation_artifacts) { + c_source += GenerateSource(compilation_artifact); + c_source += "\n"; + c_source += "\n"; + } } /*! @@ -79,7 +74,6 @@ class EthosUModuleNode : public ModuleNode { */ void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); - LOG(INFO) << "format=" << fmt << ";;\n"; ICHECK_EQ(fmt, "c") << "Can only save to format=" << "c"; std::ofstream out(file_name); @@ -89,7 +83,7 @@ class EthosUModuleNode : public ModuleNode { std::string GetSource(const std::string& format) final { return c_source; } - std::string GetCS() { return cmms_hex; } + Array GetArtifacts() { return compilation_artifacts_; } /*! * \brief Get a PackedFunc from the module. @@ -102,7 +96,11 @@ class EthosUModuleNode : public ModuleNode { PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "get_func_names") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = Array{this->func_name}; + Array func_names; + for (const CompilationArtifact& ca : compilation_artifacts_) { + func_names.push_back(ca->function_name); + } + *rv = func_names; }); } return PackedFunc(); @@ -110,21 +108,14 @@ class EthosUModuleNode : public ModuleNode { const char* type_key() const override { return "c"; } - static Module Create(String func_name, String cmms_hex, String weights_bias_hex, - Integer scratch_size, Integer input_size, Integer output_size) { - auto n = make_object(func_name, cmms_hex, weights_bias_hex, scratch_size, - input_size, output_size); + static Module Create(Array compilation_artifacts) { + auto n = make_object(compilation_artifacts); return Module(n); } private: - String c_source; - String func_name; - String cmms_hex; - String weights_bias_hex; - size_t scratch_size; - size_t input_size; - size_t output_size; + std::string c_source; + Array compilation_artifacts_; int indent_{0}; /*! @@ -151,10 +142,10 @@ class EthosUModuleNode : public ModuleNode { * \return string of code that updates the base_addrs array with the base address of the given * array */ - std::string SetBaseAddress(int index, std::string name) { + std::string SetBaseAddress(int index, std::string name, std::string size) { std::stringstream ss; ss << " base_addrs[" << index << "] = (uintptr_t)(" << name << ");\n"; - ss << " base_addrs_size[" << index << "] = " << name << "_size;\n"; + ss << " base_addrs_size[" << index << "] = " << size << ";\n"; return ss.str(); } @@ -211,43 +202,39 @@ class EthosUModuleNode : public ModuleNode { * * \return string of code that offloads a subgraph to the NPU */ - std::string GenerateSource() { - std::string func_no_dashes = func_name; + std::string GenerateSource(relay::contrib::ethosu::CompilationArtifact compilation_artifact) { + std::string func_no_dashes = compilation_artifact->function_name; std::replace(func_no_dashes.begin(), func_no_dashes.end(), '-', '_'); std::stringstream ss; - ss << "#include \n"; - ss << "#include \n"; - ss << "#include \n"; - ss << "#include \n"; - ss << "\n"; - size_t weights_size = (weights_bias_hex.size() / 2); - ss << "static const size_t weights_size = " << std::to_string(weights_size) << ";\n"; - ss << "static const size_t scratch_size = " << std::to_string(scratch_size) << ";\n"; + size_t weights_size = (compilation_artifact->encoded_constants.size() / 2); + size_t scratch_size = compilation_artifact->scratch_size; ss << "// Update linker script to place .rodata.tvm in memory that can be accessed by the " "NPU\n"; if (weights_size > 0) { - ss << "__attribute__((section(\".rodata.tvm\"), aligned(16))) static int8_t weights[" - << weights_size << "] = \""; - ss << GetHexString(weights_bias_hex); + ss << "__attribute__((section(\".rodata.tvm\"), aligned(16))) static int8_t " + << func_no_dashes << "_weights[" << weights_size << "] = \""; + ss << GetHexString(compilation_artifact->encoded_constants); ss << "\";\n"; } else { - ss << "static int8_t* weights = NULL;\n"; + ss << "static int8_t* " << func_no_dashes << "_weights = NULL;\n"; } - ss << "__attribute__((section(\".rodata.tvm\"), aligned(16))) static int8_t cms_data_data[" - << cmms_hex.size() / 2 << "] = \""; - ss << GetHexString(cmms_hex); + ss << "__attribute__((section(\".rodata.tvm\"), aligned(16))) static int8_t " << func_no_dashes + << "_cms_data_data[" << compilation_artifact->command_stream.size() / 2 << "] = \""; + ss << GetHexString(compilation_artifact->command_stream); ss << "\";\n"; - ss << "static const size_t cms_data_size = sizeof(cms_data_data);\n"; ss << "\n"; PrintExternCPrefix(ss); ss << "static int32_t " << func_no_dashes + "_(int8_t* in0, " << "size_t in0_size, int8_t* out0, size_t out0_size, void* resource_handle) {\n"; ss << " int num_tensors = 5;\n"; - ss << " void* cms_data = (void*)(cms_data_data);\n"; + ss << " void* cms_data = (void*)(" << func_no_dashes << "_cms_data_data);\n"; ss << " int64_t device_type = kDLCPU;\n"; ss << " int64_t device_id = 0;\n"; + ss << " const size_t weights_size = " << std::to_string(weights_size) << ";\n"; + ss << " const size_t scratch_size = " << std::to_string(scratch_size) << ";\n"; + ss << " const size_t cms_data_size = sizeof(" << func_no_dashes << "_cms_data_data);\n"; if (scratch_size > 0) { ss << " int8_t* scratch = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, " "(uint64_t)scratch_size, 0, 16);\n"; @@ -257,11 +244,11 @@ class EthosUModuleNode : public ModuleNode { ss << " size_t base_addrs_size[num_tensors];\n"; ss << " uint64_t base_addrs[num_tensors];\n"; ss << "\n"; - ss << SetBaseAddress(0, "weights"); - ss << SetBaseAddress(1, "scratch"); - ss << SetBaseAddress(2, "scratch"); - ss << SetBaseAddress(3, "in0"); - ss << SetBaseAddress(4, "out0"); + ss << SetBaseAddress(0, func_no_dashes + "_weights", "weights_size"); + ss << SetBaseAddress(1, "scratch", "scratch_size"); + ss << SetBaseAddress(2, "scratch", "scratch_size"); + ss << SetBaseAddress(3, "in0", "in0_size"); + ss << SetBaseAddress(4, "out0", "out0_size"); ss << "\n"; ss << " int32_t result = TVMEthosULaunch(resource_handle, cms_data, cms_data_size, " "base_addrs, base_addrs_size, num_tensors);\n"; @@ -277,8 +264,8 @@ class EthosUModuleNode : public ModuleNode { ss << "// Wrapper function is provided to allow for easier debugging\n"; ss << "inline static int32_t " + func_no_dashes + "_wrapper_(void* input, void* output, void* resource_handle) {\n"; - ss << " size_t input_data_size = " << input_size << ";\n"; - ss << " size_t output_data_size = " << output_size << ";\n"; + ss << " size_t input_data_size = " << compilation_artifact->input_size << ";\n"; + ss << " size_t output_data_size = " << compilation_artifact->output_size << ";\n"; ss << " return " + func_no_dashes + "_((int8_t*)input, input_data_size, (int8_t*)output, output_data_size, " + "resource_handle);\n"; @@ -286,7 +273,7 @@ class EthosUModuleNode : public ModuleNode { PrintExternCPostfix(ss); ss << "\n"; PrintExternCPrefix(ss); - PrintRuntimeFunctionHeader(ss, func_name); + PrintRuntimeFunctionHeader(ss, func_no_dashes); EnterScope(); PrintIndents(ss); ss << "return " << func_no_dashes << "_wrapper_(input, output, resource_handle);\n"; @@ -313,14 +300,12 @@ inline EthosUModuleNode* EthosUModule::operator->() { } TVM_REGISTER_GLOBAL("runtime.module.ethos-u.create") - .set_body_typed([](String func_name, String cmms_hex, String weights_bias_hex, - Integer scratch_size, Integer input_size, Integer output_size) { - return EthosUModuleNode::Create(func_name, cmms_hex, weights_bias_hex, scratch_size, - input_size, output_size); + .set_body_typed([](Array compilation_artifacts) { + return EthosUModuleNode::Create(compilation_artifacts); }); -TVM_REGISTER_GLOBAL("runtime.module.ethos-u.getcs").set_body_typed([](EthosUModule mod) { - return mod->GetCS(); +TVM_REGISTER_GLOBAL("runtime.module.ethos-u.get_artifacts").set_body_typed([](EthosUModule mod) { + return mod->GetArtifacts(); }); } // namespace runtime diff --git a/src/relay/backend/contrib/ethosu/utils.cc b/src/relay/backend/contrib/ethosu/utils.cc new file mode 100644 index 000000000000..7e6c1c2ac840 --- /dev/null +++ b/src/relay/backend/contrib/ethosu/utils.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/contrib/ethosu/utils.cc + * \brief Utilities for microNPU codegen + */ + +#include "utils.h" + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relay { +namespace contrib { +namespace ethosu { + +CompilationArtifact::CompilationArtifact(String command_stream, String encoded_constants, + Integer scratch_size, Integer input_size, + Integer output_size, String function_name) { + auto compilation_artifact_node = make_object(); + compilation_artifact_node->command_stream = command_stream; + compilation_artifact_node->encoded_constants = encoded_constants; + compilation_artifact_node->scratch_size = scratch_size; + compilation_artifact_node->input_size = input_size; + compilation_artifact_node->output_size = output_size; + compilation_artifact_node->function_name = function_name; + data_ = std::move(compilation_artifact_node); +} + +TVM_REGISTER_NODE_TYPE(CompilationArtifactNode); +TVM_REGISTER_GLOBAL("relay.ext.ethos-u.CompilationArtifact") + .set_body_typed([](String command_stream, String encoded_constants, Integer scratch_size, + Integer input_size, Integer output_size, String function_name) { + return CompilationArtifact(command_stream, encoded_constants, scratch_size, input_size, + output_size, function_name); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "CompilationArtifactNode(\n" + << "command_stream=" << node->command_stream + << ",\n encoded_constants=" << node->encoded_constants + << ",\n scratch_size=" << node->scratch_size + << ",\n input_size=" << node->input_size + << ",\n output_size=" << node->output_size + << ",\n function_name=" << node->function_name << ")"; + }); + +} // namespace ethosu +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/ethosu/utils.h b/src/relay/backend/contrib/ethosu/utils.h new file mode 100644 index 000000000000..5e9e337c3f17 --- /dev/null +++ b/src/relay/backend/contrib/ethosu/utils.h @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/contrib/ethosu/utils.h + * \brief Utilities for microNPU codegen + */ + +#ifndef TVM_RELAY_BACKEND_CONTRIB_ETHOSU_UTILS_H_ +#define TVM_RELAY_BACKEND_CONTRIB_ETHOSU_UTILS_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { +namespace contrib { +namespace ethosu { + +/*! + * \brief Captures all the binary artifactes required to create + * the C-source runtime module + */ +struct CompilationArtifactNode : public Object { + /*! \brief The binary command stream (CS) in hex format */ + String command_stream; + /*! \brief The encoded biases and weights in hex format */ + String encoded_constants; + /*! \brief The intermediary scratch area required for the execution of the CS */ + Integer scratch_size; + /*! \brief The size of the input tensor in bytes */ + Integer input_size; + /*! \brief The size of the output tensor in bytes */ + Integer output_size; + /*! \brief The name of the function */ + String function_name; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("command_stream", &command_stream); + v->Visit("encoded_constants", &encoded_constants); + v->Visit("scratch_size", &scratch_size); + v->Visit("input_size", &input_size); + v->Visit("output_size", &output_size); + v->Visit("function_name", &function_name); + } + + bool SEqualReduce(const CompilationArtifactNode* other, SEqualReducer equal) const { + return equal(command_stream, other->command_stream) && + equal(encoded_constants, other->encoded_constants) && + equal(scratch_size, other->scratch_size) && equal(input_size, other->input_size) && + equal(output_size, other->output_size) && equal(function_name, other->function_name); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(command_stream); + hash_reduce(encoded_constants); + hash_reduce(scratch_size); + hash_reduce(input_size); + hash_reduce(output_size); + hash_reduce(function_name); + } + + static constexpr const char* _type_key = "relay.ext.ethos-u.CompilationArtifact"; + TVM_DECLARE_FINAL_OBJECT_INFO(CompilationArtifactNode, Object); +}; + +class CompilationArtifact : public ObjectRef { + public: + TVM_DLL CompilationArtifact(String command_stream, String encoded_constants, Integer scratch_size, + Integer input_size, Integer output_size, String function_name); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CompilationArtifact, ObjectRef, CompilationArtifactNode); +}; + +} // namespace ethosu +} // namespace contrib +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_CONTRIB_ETHOSU_UTILS_H_ diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index b339828b0cd4..662655f5a033 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -181,6 +181,9 @@ class TECompilerImpl : public TECompilerNode { } Map GetDeviceContexts() { return device_contexts_; } + void SetDeviceContexts(const Map& device_contexts) { + device_contexts_ = device_contexts; + } void Clear() final { cache_.clear(); } @@ -953,6 +956,10 @@ void UpdateFunctionMetadata(BaseFunc func, IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn) { TECompiler compiler; + auto device_contexts = module->GetAttr>("device_contexts"); + if (device_contexts) { + compiler->SetDeviceContexts(device_contexts.value()); + } auto updated_module = LowerTensorExpr(module_name, compiler, process_fn)(module); diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index cb36718df120..ddcb80db3406 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -115,6 +115,7 @@ class TECompilerNode : public Object { * annotated) */ virtual Map GetDeviceContexts() = 0; + virtual void SetDeviceContexts(const Map& device_contexts) = 0; virtual std::unordered_map GetOpWeights() = 0; diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index b44ea251204d..b92f57ad7974 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -401,7 +401,6 @@ TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option>("devices"); -TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU).set_attr("use_device_api", Bool(true)); /********** Registry **********/ diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 3343e1062e57..a5ecf4ba8296 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -117,6 +117,11 @@ class BuiltinLower : public StmtExprMutator { // and less than runtime::kMaxStackAlloca heuristic // they are not serviced with TVMBackendWorkspaceAlloc calls // to be placed on stack. + if (op->annotations.count(transform::kDisableLowerTVMBuiltin)) { + if (Downcast(op->annotations[transform::kDisableLowerTVMBuiltin])) { + return stmt; + } + } if (device_type_.defined()) { if (const auto* dev_type = device_type_.as()) { auto storage_scope = Downcast(op->buffer_var->type_annotation)->storage_scope; diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 6e8793fbd367..169983a525df 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -64,31 +64,21 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { // Collect variables and buffers to map between Array args; std::vector> var_def; - std::vector> buffer_def; + bool buffer_map_found = false; for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { Var param = func_ptr->params[i]; - Var v_arg = Var("arg" + std::to_string(i), param->dtype); auto it = func_ptr->buffer_map.find(param); if (it != func_ptr->buffer_map.end()) { - buffer_def.emplace_back(v_arg, (*it).second); + args.push_back((*it).second->data); + buffer_map_found = true; } else { - var_def.emplace_back(v_arg, param); + args.push_back(param); } - - args.push_back(v_arg); - } - - // Bind variables then bind buffers to them to ensure correct ordering - for (const auto& kv : var_def) { - binder.Bind(kv.second, kv.first, kv.first->name_hint, true); - } - for (const auto& kv : buffer_def) { - binder.Bind(kv.second->data, kv.first, kv.first->name_hint, true); } - if (buffer_def.size()) { + if (buffer_map_found) { device_init.push_back(AttrStmt(node, attr::device_id, device_id, nop)); device_init.push_back(AttrStmt(node, attr::device_type, device_type, nop)); } diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index b6cf873cb6f3..dd4108728d0a 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -154,14 +154,14 @@ def create_graph_activation(input_tensor_name, input_tensor_shape, input_tensor_ ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = ( + compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] + ) # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -241,15 +241,12 @@ def representative_dataset(): ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -328,15 +325,12 @@ def representative_dataset(): ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -423,15 +417,12 @@ def representative_dataset(): ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -501,15 +492,12 @@ def representative_dataset(): ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -551,15 +539,12 @@ def create_relay_graph(): ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -608,15 +593,12 @@ def create_model(): ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) @@ -705,18 +687,16 @@ def rounding_right_shift(lhs, rhs): [rounding_right_shift(x[0], x[1]) for x in zip(lhs.flat, rhs.flat)] ).astype(ofm_dtype) - compiled_model = infra.build_source(mod, input_data, [output_data], accel_type) - imported_modules = compiled_model[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + compiled_models = infra.build_source(mod, input_data, [output_data], accel_type) + # Assumes only two runtime.Modules are created -- i.e. single offload module + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) - infra.verify_source(compiled_model, accel_type) + infra.verify_source(compiled_models, accel_type) @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @@ -738,15 +718,13 @@ def test_ethosu_identity_codegen(ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp mod, {"ifm": in_data}, [out_data], accel_type, output_tolerance=1 ) - imported_modules = compiled_model[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + # Assumes only two runtime.Modules are created -- i.e. single offload module + ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_model, accel_type) @@ -786,15 +764,13 @@ def test_relay_reshape_codegen(ifm_shape, new_shape, accel_type): accel_type, ) - imported_modules = compiled_model[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + # Assumes only two runtime.Modules are created -- i.e. single offload module + ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_model, accel_type) @@ -831,15 +807,13 @@ def test_relay_strided_slice_codegen(ifm_shape, begin, end, accel_type): accel_type, ) - imported_modules = compiled_model[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + # Assumes only two runtime.Modules are created -- i.e. single offload module + ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_model, accel_type) @@ -907,15 +881,12 @@ def representative_dataset(): ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_models, accel_type) From 3469a2ec97ace8599a3ca829e57a5fa62b3aafeb Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Sat, 27 Nov 2021 01:10:34 +0000 Subject: [PATCH 2/7] [microNPU] Move the compilation to use Target Hooks. *Fixing unpacked API tests *Adding use_device_api target attr to example target hooks Change-Id: I72c51caa57e9a0c2a538f40eb73939e28d4f112f --- .../contrib/example_target_hooks/target.cc | 1 + .../test_tir_transform_make_unpacked_api.py | 37 +++++-------------- 2 files changed, 11 insertions(+), 27 deletions(-) diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc index 75b161ad4499..6f1914eac4c3 100644 --- a/src/relay/backend/contrib/example_target_hooks/target.cc +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -33,6 +33,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target); } // namespace relay TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) + .set_attr("use_device_api", Bool(true)) .set_attr("RelayToTIR", relay::contrib::example_target_hooks::RelayToTIR()) .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime); diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py index 9d917466758b..e5f41e7b520f 100644 --- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py +++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py @@ -58,7 +58,7 @@ def test_device_setup(mod, target, dev): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target(target)))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 1 - assert f.params[0].name == "arg0" + assert f.params[0].name == "A" assert f.body.node == "default" assert f.body.attr_key == "device_id" assert f.body.value == 0 @@ -77,16 +77,13 @@ def test_no_buffers_no_device_setup(): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 1 - assert f.body.var.name == "A" - assert f.body.value.name == "arg0" + assert f.params[0].name == "A" def test_argument_mapping(mod): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 1 - assert f.params[0].name == "arg0" - assert f.body.body.body.var.name == "A" - assert f.body.body.body.value.name == "arg0" + assert f.params[0].name == "A" def test_argument_mapping_multiple(): @@ -101,12 +98,8 @@ def test_argument_mapping_multiple(): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 2 - assert f.params[0].name == "arg0" - assert f.params[1].name == "arg1" - assert f.body.body.body.var.name == "A" - assert f.body.body.body.value.name == "arg0" - assert f.body.body.body.body.var.name == "B" - assert f.body.body.body.body.value.name == "arg1" + assert f.params[0].name == "A" + assert f.params[1].name == "B" def test_argument_mapping_multiple_matching(): @@ -120,12 +113,8 @@ def test_argument_mapping_multiple_matching(): f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 2 - assert f.params[0].name == "arg0" - assert f.params[1].name == "arg1" - assert f.body.body.body.var.name == "A" - assert f.body.body.body.value.name == "arg0" - assert f.body.body.body.body.condition.a.name == "A" - assert f.body.body.body.body.condition.b.name == "arg1" + assert f.params[0].name == "A" + assert f.params[1].name == "A" def test_body(): @@ -140,15 +129,9 @@ def test_body(): mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"] assert len(f.params) == 3 - assert f.params[0].name == "arg0" - assert f.params[1].name == "arg1" - assert f.params[2].name == "arg2" - assert f.body.body.body.var.name == "A" - assert f.body.body.body.value.name == "arg2" - assert f.body.body.body.body.var.name == "B" - assert f.body.body.body.body.value.name == "arg1" - assert f.body.body.body.body.body.condition.a.name == "A" - assert f.body.body.body.body.body.condition.b.name == "arg0" + assert f.params[0].name == "A" + assert f.params[1].name == "B" + assert f.params[2].name == "A" if __name__ == "__main__": From 29ae4fd718cddf5db8f378282dc771e319dce0e7 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Sat, 27 Nov 2021 08:32:29 +0000 Subject: [PATCH 3/7] [microNPU] Move the compilation to use Target Hooks. * Modifed CLZ test case to support target hooks * Modifed reference TIR for test to include allocate annotation * TIR to CS translation tests are modified to run MakeUnpackedAPI Change-Id: I3a3d28777a6995e7f2b8789e14c5cb0f280dc763 --- .../contrib/test_ethosu/test_codegen.py | 24 +++++++++---------- .../test_ethosu/test_encode_constants.py | 16 ++++++++----- .../test_ethosu/test_replace_conv2d.py | 8 +++---- .../contrib/test_ethosu/test_replace_copy.py | 8 +++---- .../test_ethosu/test_tir_to_cs_translator.py | 20 +++++++++------- 5 files changed, 41 insertions(+), 35 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index dd4108728d0a..5008f51881c5 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -928,16 +928,18 @@ def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtyp ) # Assumes only two runtime.Modules are created -- i.e. single offload module - imported_modules = compiled_models[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source source = ethosu_module.get_source() assert ( - '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t cms_data_data' in source + '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t tvmgen_default_ethos_u_main_0_cms_data_data' + in source + ) + assert ( + '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t tvmgen_default_ethos_u_main_0_weights' + in source ) - assert '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t weights' in source @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @@ -961,15 +963,13 @@ def clz_comp(n): compiled_model = infra.build_source(mod, {"ifm": in_data}, [out_data], accel_type) - imported_modules = compiled_model[0].executor_factory.lib.imported_modules - assert len(imported_modules) == 2 - ethosu_module = imported_modules[0] + # Assumes only two runtime.Modules are created -- i.e. single offload module + ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0] # Verify generated C source - get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") - cmms = get_cs(ethosu_module) - cmms = bytes.fromhex(cmms) - + get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts") + compilation_artifacts = get_artifacts(ethosu_module) + cmms = bytes.fromhex(compilation_artifacts[0].command_stream) infra.print_payload(cmms) infra.verify_source(compiled_model, accel_type) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 91cee81a1565..de8a7f922390 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -48,8 +48,8 @@ def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer_7 = T.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = T.allocate([128], "uint8", "global") - placeholder_d_global = T.allocate([32], "uint8", "global") + placeholder_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6.data, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -122,7 +122,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_2 = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_3 = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([4096], "int8", "global") + ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 592, 12, T.load("uint8", buffer_2.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 160, 12, T.load("uint8", buffer_3.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -190,9 +190,9 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_8 = T.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_9 = T.match_buffer(placeholder_10, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([4096], "int8", "global") - placeholder_global = T.allocate([80], "uint8", "global") - placeholder_d_global = T.allocate([32], "uint8", "global") + ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin": True}) + placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_11.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_6.data, 0), 592, 12, T.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) @@ -312,6 +312,10 @@ def get_graph(): # More generally, check compiles successfully to make sure # nothing else was overrwritten. + # With Target Hooks the TIR module needs a target attached + # and lowered via make unpacked API. + tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u")) + tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod) tir_to_cs_translator.translate(tir_mod, params) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 7992f421a5bd..e1f9beff66a7 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -257,7 +257,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) buffer_3 = T.match_buffer(placeholder_1, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([1024], "int8", "global") + ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 304, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 12), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -278,7 +278,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle placeholder_5 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([1536], "int8", "global") + ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 2608, 12, T.load("uint8", buffer.data, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 48), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -299,7 +299,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_3 = T.match_buffer(placeholder_1, [880], dtype="uint8", elem_offset=0, align=128, offset_factor=1) placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([2560], "int8", "global") + ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer.data, 0), 1744, 12, T.load("uint8", buffer_1.data, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, T.load("int8", placeholder_5.data, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -322,7 +322,7 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_2 = T.match_buffer(placeholder_4, [272], dtype="uint8", elem_offset=0, align=128, offset_factor=1) buffer_3 = T.match_buffer(placeholder_3, [11040], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = T.allocate([2304], "int8", "global") + ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 11040, 12, T.load("uint8", buffer_2.data, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index b1f923de4646..cce414c4c8f7 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -39,8 +39,8 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle buffer_1 = T.match_buffer(placeholder_1, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = T.allocate([304], "uint8", "global") - placeholder_d_global = T.allocate([80], "uint8", "global") + placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) @@ -87,8 +87,8 @@ def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, buffer_2 = T.match_buffer(placeholder_3, [272], dtype="uint8") buffer_3 = T.match_buffer(placeholder_4, [64], dtype="uint8") # body - placeholder_global = T.allocate([416], "uint8", "global") - placeholder_d_global = T.allocate([112], "uint8", "global") + placeholder_global = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True}) + placeholder_d_global = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, 12, T.load("uint8", placeholder_d_global, 0), 112, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index 94c8f0ddc04e..59b7b2c21723 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -233,9 +233,12 @@ def test_buffer_info_extraction(): }, ] for test_case in test_cases: - buffer_info = tir_to_cs_translator.extract_buffer_info( - test_case["tir_module"], test_case["param_dict"] - ) + # With Target Hooks the TIR module needs a target attached + # and lowered via make unpacked API. + tir_mod = test_case["tir_module"] + tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u")) + tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod) + buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"]) for buffer_var, info in buffer_info.items(): buffer_name = buffer_var.name if buffer_name in test_case["constants"].keys(): @@ -247,8 +250,6 @@ def test_buffer_info_extraction(): ) info.btype == tir_to_cs_translator.BufferType.constant else: - assert list(info.shape) == test_case["data_buffers"][buffer_name][0] - assert info.dtype == test_case["data_buffers"][buffer_name][1] assert info.btype == test_case["data_buffers"][buffer_name][2] @@ -831,10 +832,11 @@ def check_buffer(address, region, length, buffer_var): ) for test_case in test_cases: - buffer_info = tir_to_cs_translator.extract_buffer_info( - test_case["tir_module"], test_case["param_dict"] - ) - extern_calls = extract_call_extern_list(test_case["tir_module"]) + tir_mod = test_case["tir_module"] + tir_mod["main"] = tir_mod["main"].with_attr("target", tvm.target.Target("ethos-u")) + tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod) + buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"]) + extern_calls = extract_call_extern_list(tir_mod) _npu_ops = list() for extern_call in extern_calls: _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_call_extern(extern_call)) From 1447f4f379844ee74fd6f65e6fce9a08036100e3 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Sat, 27 Nov 2021 13:21:24 +0000 Subject: [PATCH 4/7] [microNPU] Move the compilation to use Target Hooks. * Added a missed documentation to changes in source module * Skipping device api test for packed API as microNPU does not support it. Change-Id: I6da1adcf8fdd3f972ec9b37ff530ff673e93058c --- src/relay/backend/contrib/ethosu/source_module.cc | 8 ++++++++ tests/python/relay/aot/test_c_device_api.py | 5 ++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc index 6515df6245ae..d6af5c10b786 100644 --- a/src/relay/backend/contrib/ethosu/source_module.cc +++ b/src/relay/backend/contrib/ethosu/source_module.cc @@ -52,6 +52,14 @@ using CompilationArtifact = relay::contrib::ethosu::CompilationArtifact; // required for invoking the NPU with the command stream class EthosUModuleNode : public ModuleNode { public: + /*! + * \brief The microNPU runtime module. + * + * \param compilation_artifacts + * This is an array of CompilationArtifacts that is produced via + * lowering each PrimFunc to command stream. Here, those artifacts + * will be used to create the c-source. + */ explicit EthosUModuleNode(Array compilation_artifacts) : compilation_artifacts_(compilation_artifacts) { c_source += "#include \n"; diff --git a/tests/python/relay/aot/test_c_device_api.py b/tests/python/relay/aot/test_c_device_api.py index 3de4fecf5544..473b8d5ee300 100644 --- a/tests/python/relay/aot/test_c_device_api.py +++ b/tests/python/relay/aot/test_c_device_api.py @@ -92,7 +92,7 @@ def compile_to_main_func(interface_api="c", use_unpacked_api=True): workspace_byte_alignment=16, pass_config=test_runner.pass_config, ) - main_ir_module = list(compiled_models[0].executor_factory.lowered_ir_mods.values())[0] + main_ir_module = compiled_models[0].executor_factory.lowered_ir_mods.items()[1][1] main_func = main_ir_module["run_model"] return main_func @@ -177,6 +177,9 @@ def test_device_api_hooks_unpacked_api(device_api_main_func): ) +@pytest.mark.skip( + "Skipping this test as this is incorrectly using Arm(R) Ethos(TM)-U NPU with packed calling convention which is not supported by the NPU codegen's TIR to Runtime Hook. We need to use a different target to test this feature" +) def test_device_api_hooks_packed_api(device_api_main_func): """Check for Device API hooks with packed internal calls""" main_func = device_api_main_func(interface_api="packed", use_unpacked_api=False) From 0d0485442972a479462c45f48b6a95a6eae84672 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Sat, 27 Nov 2021 20:31:00 +0000 Subject: [PATCH 5/7] [microNPU] Move the compilation to use Target Hooks. * fixed tvmc test use unpacked-api for microNPU compilation Change-Id: Ib722d91ca3b3e4c6d13075ee0873acb86f487247 --- tests/python/driver/tvmc/test_compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 1bb854c1cf0a..4918e641adae 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -483,8 +483,8 @@ def test_compile_tflite_module_with_external_codegen_ethosu( tvmc.compiler.compile_model( tvmc_model, target=f"ethos-u -accelerator_config={accel_type}, c -mcpu=cortex-m55", - runtime=Runtime("crt", {"system-lib": True}), - executor=Executor("aot"), + runtime=Runtime("crt"), + executor=Executor("aot", {"unpacked-api": True}), output_format="mlf", package_path=output_file_name, pass_context_configs=["tir.disable_vectorize=true"], From 114709cef87dc2248978b36f4dcaac73d15e9a25 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Sun, 28 Nov 2021 08:10:10 +0000 Subject: [PATCH 6/7] [microNPU] Move the compilation to use Target Hooks. * adjust target name. Change-Id: I862957324440705fb6093939b97b1a00fa1d4b46 --- python/tvm/relay/backend/contrib/ethosu/codegen.py | 8 ++++---- src/relay/backend/contrib/ethosu/codegen.cc | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 9fabba3f19f7..df0868641808 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -25,7 +25,7 @@ from tvm.relay.backend.contrib.ethosu import util -@tvm._ffi.register_func("relay.ext.ethosu.constant_updater") +@tvm._ffi.register_func("relay.ext.ethos-u.constant_updater") def constant_updater(expr, symbol): # pylint: disable=unused-argument """ The constant updater process happen after lowering in the core compiler. @@ -35,11 +35,11 @@ def constant_updater(expr, symbol): # pylint: disable=unused-argument return dict() -@tvm._ffi.register_func("relay.ext.ethosu.relay_to_tir_func") +@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir_func") def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: """ This is hook for python-based lowering of relay function - that gets offloaded to Ethos-U. + that gets offloaded to the microNPU. Parameters ---------- @@ -75,7 +75,7 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: return primfunc -@tvm._ffi.register_func("relay.ext.ethosu.primfunc_to_artifact") +@tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact") def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact: """ This is hook for python-based lowering of TIR PrimFunc diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index d7f9510fc463..d618a4971189 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -75,7 +75,7 @@ class RelayToTIRMutator : public MixedModeMutator { auto codegen_name = func->GetAttr(attr::kCompiler); if (codegen_name.defined() && codegen_name == "ethos-u") { auto relay_to_tir_func_pf = - tvm::runtime::Registry::Get("relay.ext.ethosu.relay_to_tir_func"); + tvm::runtime::Registry::Get("relay.ext.ethos-u.relay_to_tir_func"); ICHECK(relay_to_tir_func_pf); tir::PrimFunc prim_func = (*relay_to_tir_func_pf)(func); prim_func = WithAttr(prim_func, tvm::attr::kTarget, Target("ethos-u")); @@ -101,7 +101,7 @@ tvm::transform::Pass RelayToTIR() { [=](IRModule ir_module, transform::PassContext pass_context) { return RelayToTIRMutator(ir_module)(); }; - return tvm::transform::CreateModulePass(pass_func, 0, "relay.contrib.ethosu.RelayToTIR", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "relay.contrib.ethos-u.RelayToTIR", {}); } /*! @@ -116,7 +116,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { prim_func->GetAttr>("ethos-u.constants"); ICHECK(params) << "microNPU params should be present"; auto primfunc_to_artifact_pf = - tvm::runtime::Registry::Get("relay.ext.ethosu.primfunc_to_artifact"); + tvm::runtime::Registry::Get("relay.ext.ethos-u.primfunc_to_artifact"); ICHECK(primfunc_to_artifact_pf); CompilationArtifact ca = (*primfunc_to_artifact_pf)(prim_func); compile_artifacts.push_back(ca); From 2e0c58176e0f1c38ba089f1b308c3b62813ead46 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Mon, 29 Nov 2021 14:57:25 +0000 Subject: [PATCH 7/7] [microNPU] follow up on using target hooks * Fixed few typos and cleaned up as per suggestions Change-Id: I2a744a4bc4015e1884dbef4165252aa13aa30b31 --- python/tvm/relay/backend/contrib/ethosu/codegen.py | 6 +++--- src/relay/backend/contrib/ethosu/source_module.cc | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index df0868641808..b78086260635 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -38,7 +38,7 @@ def constant_updater(expr, symbol): # pylint: disable=unused-argument @tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir_func") def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: """ - This is hook for python-based lowering of relay function + This is the hook for python-based lowering of relay function that gets offloaded to the microNPU. Parameters @@ -78,14 +78,14 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: @tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact") def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact: """ - This is hook for python-based lowering of TIR PrimFunc + This is the hook for python-based lowering of TIR PrimFunc that has undergone unified optimization to Compilation Artifact destined for the microNPU. Parameters ---------- primfunc : tir.PrimFunc - TIR PrimFuncthat has undergone unified optimization + TIR PrimFunc that has undergone unified optimizations Returns ------- diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc index d6af5c10b786..f56544aee99a 100644 --- a/src/relay/backend/contrib/ethosu/source_module.cc +++ b/src/relay/backend/contrib/ethosu/source_module.cc @@ -65,12 +65,10 @@ class EthosUModuleNode : public ModuleNode { c_source += "#include \n"; c_source += "#include \n"; c_source += "#include \n"; - c_source += "#include \n"; - c_source += "\n"; + c_source += "#include \n\n"; for (const CompilationArtifact& compilation_artifact : compilation_artifacts) { c_source += GenerateSource(compilation_artifact); - c_source += "\n"; - c_source += "\n"; + c_source += "\n\n"; } }