From 0316e879786f4869c80be36b01ece56d13b41718 Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Wed, 27 Oct 2021 18:09:36 +0100 Subject: [PATCH 1/3] [CMSIS-NN] Convert CMSIS-NN to use Target Hooks This migrates CMSIS-NN to use Target Hooks instead of fully BYOC, which means it will now go through any central passes in the Driver API. Found a few things whilst doing this: * Forgot to mutate PrimFunc arguments in LowerTE which meant functions weren't getting lowered past the first function in test_networks * Target `cmsis-nn` needs to match external code generator `cmsis-nn` to connect the Target with the external code generator * Partition Graph needed to sanitise compiler names to generate them properly in C --- python/tvm/relay/op/contrib/cmsisnn.py | 15 ++- src/driver/driver_api.cc | 5 +- .../backend/contrib/cmsisnn/relay_to_tir.cc | 127 ++++++++++-------- .../cmsisnn/{codegen_cmsisnn.cc => target.cc} | 29 ++-- .../backend/contrib/cmsisnn/tir_to_runtime.cc | 104 +++----------- src/relay/backend/te_compiler.cc | 17 +-- src/relay/transforms/partition_graph.cc | 6 +- .../contrib/test_cmsisnn/test_binary_ops.py | 2 +- .../contrib/test_cmsisnn/test_softmax.py | 2 +- .../python/relay/test_pass_partition_graph.py | 43 ++++++ 10 files changed, 172 insertions(+), 178 deletions(-) rename src/relay/backend/contrib/cmsisnn/{codegen_cmsisnn.cc => target.cc} (57%) diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index cf0e9156e65f..824343e0066b 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name, unused-argument """Arm(R) CMSIS-NN supported operators for Cortex-M.""" import tvm.ir +from tvm.target import Target from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name @@ -25,7 +26,7 @@ def enabled(): - return bool(tvm.get_global_func("relay.ext.cmsisnn", True)) + return "cmsis-nn" in Target.list_kinds() def 
partition_for_cmsisnn(mod, params=None, **opts): @@ -51,7 +52,7 @@ def partition_for_cmsisnn(mod, params=None, **opts): [ transform.InferType(), transform.MergeComposite(pattern_table()), - transform.AnnotateTarget("cmsisnn"), + transform.AnnotateTarget("cmsis-nn"), transform.MergeCompilerRegions(), transform.PartitionGraph(), ] @@ -60,9 +61,9 @@ def partition_for_cmsisnn(mod, params=None, **opts): return seq(mod) -@register_pattern_table("cmsisnn") +@register_pattern_table("cmsis-nn") def pattern_table(): - """Get the cmsisnn compiler pattern table.""" + """Get the CMSIS-NN compiler pattern table.""" def softmax_pattern(): pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant()) @@ -104,14 +105,14 @@ def check_quantized_binary_op(extract): ) return [ - ("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax), + ("cmsis-nn.quantized_softmax", softmax_pattern(), check_quantized_softmax), ( - "cmsisnn.quantized_mul", + "cmsis-nn.quantized_mul", binary_op_pattern("mul"), check_quantized_binary_op, ), ( - "cmsisnn.quantized_add", + "cmsis-nn.quantized_add", binary_op_pattern("add"), check_quantized_binary_op, ), diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 34661f81c847..b1e56bdb2078 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -607,7 +607,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::InferFragment()); mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); - if (target->GetAttr("unpacked-api").value_or(Bool(false))) { + // The host Target contains these parameters at the moment rather than + // the specific Target + // TODO(Mousius) - Move these to the Executor object rather than Target + if (target->GetHost().value()->GetAttr("unpacked-api").value_or(Bool(false))) { mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); } else { 
mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1)); diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 3c3346340f04..bd0ac52330d5 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -17,6 +17,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -33,29 +34,46 @@ namespace relay { namespace contrib { namespace cmsisnn { -class RelayToTIRVisitor : public MixedModeVisitor { +class RelayToTIRVisitor : public MixedModeMutator { public: - explicit RelayToTIRVisitor(String func_name) : func_name_(func_name) {} + explicit RelayToTIRVisitor(IRModule ir_module, Target target) + : ir_module_(ir_module), target_(target) {} - tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; } + IRModule Mutate() { + GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); + BaseFunc main = ir_module_->Lookup(main_global_var); + Function main_func = GetRef(main.as()); + + // Copy everything across and mutate the body + Function mutated_main = + Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, + main_func->type_params, main_func->attrs, main_func->span); + + ir_module_->Update(main_global_var, mutated_main); + + return ir_module_; + } private: inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); } - void CreatePrimFuncForExtern(Array func_signature, + void CreatePrimFuncForExtern(const GlobalVar& global_var, Array func_signature, tvm::Array call_extern_args) { Map dict_attrs; - dict_attrs.Set("global_symbol", func_name_); + dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint); + dict_attrs.Set(tvm::attr::kTarget, target_); dict_attrs.Set("tir.noalias", Bool(true)); tir::Stmt body = tir::Evaluate( tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args)); - primfunc_ = 
tir::PrimFunc(func_signature, body, VoidType(), Map(), - DictAttrs(dict_attrs)); + tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map(), + DictAttrs(dict_attrs)); + + ir_module_->Add(global_var, replacement_func); } - void EmitSoftMax(const Expr& expr) { + void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) { auto* quantize_call = expr.as(); auto* softmax_call = quantize_call->args[0].as(); auto* dequant_call = softmax_call->args[0].as(); @@ -102,10 +120,10 @@ class RelayToTIRVisitor : public MixedModeVisitor { out_var, }; - CreatePrimFuncForExtern(func_signature, args); + CreatePrimFuncForExtern(global_var, func_signature, args); } - void EmitMul(const Expr& expr) { + void EmitMul(const GlobalVar& global_var, const Expr& expr) { auto* mul_call = expr.as(); const float input_0_scale = GetScalarFromConstant(mul_call->args[2]); @@ -145,10 +163,10 @@ class RelayToTIRVisitor : public MixedModeVisitor { tensor_size, }; - CreatePrimFuncForExtern(func_signature, args); + CreatePrimFuncForExtern(global_var, func_signature, args); } - void EmitAdd(const Expr& expr) { + void EmitAdd(const GlobalVar& global_var, const Expr& expr) { auto* add_call = expr.as(); const float input_0_scale = GetScalarFromConstant(add_call->args[2]); @@ -212,58 +230,59 @@ class RelayToTIRVisitor : public MixedModeVisitor { tensor_size, }; - CreatePrimFuncForExtern(func_signature, args); + CreatePrimFuncForExtern(global_var, func_signature, args); } - void VisitExpr_(const CallNode* call) final { - auto* func = call->op.as(); - if (func == nullptr) { - return; - } - - auto comp_name = func->GetAttr(attr::kComposite); - if (comp_name.defined()) { - if (comp_name == "cmsisnn.quantized_softmax") { - EmitSoftMax(func->body); + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + if (const CallNode* call = post.as()) { + auto* func = call->op.as(); + if (func == nullptr) { + return post; } - if (comp_name == "cmsisnn.quantized_mul") { - EmitMul(func->body); - } 
- if (comp_name == "cmsisnn.quantized_add") { - EmitAdd(func->body); + + auto codegen_name = func->GetAttr(attr::kCompiler); + if (codegen_name.defined() && codegen_name == "cmsis-nn") { + const CallNode* inner_call = func->body.as(); + const FunctionNode* composite_func = inner_call->op.as(); + auto comp_name = composite_func->GetAttr(attr::kComposite); + auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); + + GlobalVar new_global_var(func_name.value()); + new_global_var->checked_type_ = composite_func->checked_type(); + + if (comp_name == "cmsis-nn.quantized_softmax") { + EmitSoftMax(new_global_var, composite_func->body); + } + if (comp_name == "cmsis-nn.quantized_mul") { + EmitMul(new_global_var, composite_func->body); + } + if (comp_name == "cmsis-nn.quantized_add") { + EmitAdd(new_global_var, composite_func->body); + } + + Array args; + for (const auto& arg : call->args) { + args.push_back(VisitExpr(arg)); + } + + return Call(new_global_var, args, call->attrs, call->type_args, call->span); } } - } - - public: - String func_name_; - tir::PrimFunc primfunc_; -}; - -IRModule GenerateTIR(IRModule mod) { - String func_name; - Function func; - // Obtain external Relay Function that needs to be translated into TIR - ICHECK(mod->functions.size() == 1) << "Supports modules with single external Relay function."; - for (auto kv : mod->functions) { - func = Downcast(kv.second); - func_name = func->GetAttr(tvm::attr::kGlobalSymbol).value(); + return post; } - // Prepare PrimFunc from Relay Function - auto relay_to_tir = RelayToTIRVisitor(func_name); - relay_to_tir.VisitExpr(func->body); - - // Build the TIR IRModule from the generated PrimFunc - Map var_func_map; - var_func_map.Set(GlobalVar(func_name), relay_to_tir.GetReplacementPrimFunc()); - return IRModule(var_func_map); -} + private: + IRModule ir_module_; + Target target_; +}; -transform::Pass RelayToTIR() { +tvm::transform::Pass RelayToTIR() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, 
transform::PassContext pc) { return GenerateTIR(m); }; + [=](IRModule ir_module, transform::PassContext pass_context) { + auto relay_to_tir = RelayToTIRVisitor(ir_module, Target("cmsis-nn")); + return relay_to_tir.Mutate(); + }; return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {}); } diff --git a/src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc b/src/relay/backend/contrib/cmsisnn/target.cc similarity index 57% rename from src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc rename to src/relay/backend/contrib/cmsisnn/target.cc index c8094109771b..99bc0bc7cb20 100644 --- a/src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc +++ b/src/relay/backend/contrib/cmsisnn/target.cc @@ -1,3 +1,4 @@ + /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -16,34 +17,22 @@ * specific language governing permissions and limitations * under the License. */ + #include -#include -#include +#include namespace tvm { + namespace relay { namespace contrib { namespace cmsisnn { -transform::Pass RelayToTIR(); - -runtime::Module CompileCMSISNN(const ObjectRef& ref) { - IRModule relay_mod; - Function relay_func = Downcast(ref); - auto func_name = relay_func->GetAttr(tvm::attr::kGlobalSymbol); - GlobalVar var = GlobalVar(func_name.value()); - relay_mod->Add(var, relay_func); - relay_mod = transform::InferType()(relay_mod); - - Array pass_seqs{transform::InferType(), RelayToTIR()}; - transform::Sequential seq(pass_seqs); - IRModule tir_mod = seq(relay_mod); - - const auto* pf = runtime::Registry::Get("runtime.CMSISNNModuleNodeCreate"); - return (*pf)(tir_mod); -} +tvm::transform::Pass RelayToTIR(); +runtime::Module TIRToRuntime(IRModule mod, Target target); -TVM_REGISTER_GLOBAL("relay.ext.cmsisnn").set_body_typed(CompileCMSISNN); +TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU) + .set_attr("RelayToTIR", RelayToTIR()) + .set_attr("TIRToRuntime", TIRToRuntime); } // namespace cmsisnn } // namespace 
contrib diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index fb612e70311b..7350107d186c 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -25,21 +25,23 @@ #include "../../../../runtime/file_utils.h" #include "../../../../target/source/codegen_c.h" +#include "../../../../target/source/codegen_c_host.h" namespace tvm { -namespace codegen { - using namespace tir; +namespace relay { +namespace contrib { +namespace cmsisnn { -class CodeGenCMSISNN : public CodeGenC { +class CodeGenCMSISNN : public codegen::CodeGenCHost { public: - void Init(bool output_ssa) { + void Init(bool output_ssa, bool emit_asserts, std::string target_str) { decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; - CodeGenC::Init(output_ssa); + CodeGenCHost::Init(output_ssa, emit_asserts, target_str); } /*! @@ -47,92 +49,26 @@ class CodeGenCMSISNN : public CodeGenC { * * \return string of code that offloads a subgraph to the Cortex-M */ - void AddFunction(const PrimFunc& prim_func) { - PrintExternCPrefix(stream); - CodeGenC::AddFunction(prim_func); - PrintExternCPostfix(stream); - } - - private: - /*! * \brief Creates a cplusplus guard prefix for extern "C" printing */ - void PrintExternCPrefix(std::ostringstream& ss) { - PrintIndent(); - ss << "#ifdef __cplusplus\n"; - ss << "extern \"C\" {\n"; - ss << "#endif\n"; - } - - /*! 
* \brief Creates a cplusplus guard postfix for extern "C" printing */ - void PrintExternCPostfix(std::ostringstream& ss) { - PrintIndent(); - ss << "#ifdef __cplusplus\n"; - ss << "}\n"; - ss << "#endif\n"; - } + void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); } }; -class CMSISNNModuleNode : public runtime::ModuleNode { - public: - CMSISNNModuleNode(const std::string& code, const std::string& fmt, - const Array& func_names) - : code_(code), fmt_(fmt), func_names_(func_names) {} - - std::string GetSource(const std::string& format) final { return code_; } - - const char* type_key() const { return "c"; } - - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - if (name == "get_symbol") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_[0]; }); - } else if (name == "get_func_names") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; }); - } else { - return PackedFunc(nullptr); - } - } - - void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string fmt = runtime::GetFileFormat(file_name, format); - std::string meta_file = runtime::GetMetaFilePath(file_name); - if (fmt == "c") { - ICHECK_NE(code_.length(), 0); - runtime::SaveBinaryToFile(file_name, code_); - } else { - ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; - } - } - - protected: - std::string code_; - std::string fmt_; - Array func_names_; -}; - -static runtime::Module CMSISNNModuleNodeCreate(IRModule mod) { +runtime::Module TIRToRuntime(IRModule mod, Target target) { bool output_ssa = false; - CodeGenCMSISNN cg; + bool emit_asserts = false; + CodeGenCMSISNN codegen; Array function_names; - cg.Init(output_ssa); - ICHECK(mod->functions.size() == 1) << "Supports modules with single PrimFunc."; + codegen.Init(output_ssa, emit_asserts, target->str()); for (auto kv : mod->functions) { - 
ICHECK(kv.second->IsInstance()) << "CodegenCMSISNN: Can only take PrimFunc"; - auto f = Downcast(kv.second); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; + auto prim_func = Downcast(kv.second); + auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); function_names.push_back(global_symbol.value()); - cg.AddFunction(f); + codegen.AddFunction(prim_func); } - std::string code = cg.Finish(); - auto n = make_object(code, "c", function_names); - return runtime::Module(n); + std::string code = codegen.Finish(); + return codegen::CSourceModuleCreate(code, "c", function_names); } -TVM_REGISTER_GLOBAL("runtime.CMSISNNModuleNodeCreate").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = CMSISNNModuleNodeCreate(args[0]); -}); - -} // namespace codegen +} // namespace cmsisnn +} // namespace contrib +} // namespace relay } // namespace tvm diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index a8c27a126032..02a8e611f827 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -580,6 +580,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { Expr DeviceAwareVisitExpr_(const CallNode* call_node) override { Call call = GetRef(call_node); + // Look for (indirect) calls to primitives. BaseFunc prim_func = ResolveToPrimitive(call_node->op); if (!prim_func.defined()) { @@ -590,10 +591,16 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { return ExprMutator::VisitExpr_(call_node); } + // Similarly transform arguments. 
+ Array args; + for (const auto& arg : call_node->args) { + args.push_back(VisitExpr(arg)); + } + // Already lowered by other means so we don't need to mutate - // the call + // the call but we do need to mutate the arguments if (prim_func->IsInstance()) { - return std::move(call); + return Call(call_node->op, args, call_node->attrs); } // Find the desired target device. @@ -612,12 +619,6 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { Function func = Downcast(prim_func); std::pair pair = LowerFunction(func, target); - // Similarly transform arguments. - Array args; - for (const auto& arg : call_node->args) { - args.push_back(VisitExpr(arg)); - } - // Replace with direct call to lowered primitive, and attach annotations to record calling // convention. return Call(pair.first, args, pair.second); diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index f74cf983ccae..6e52cbfbe55a 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -43,6 +43,7 @@ #include #include "../analysis/annotated_region_set.h" +#include "../backend/name_transforms.h" #include "../backend/utils.h" #include "pass_utils.h" @@ -501,7 +502,7 @@ class NameMangleExtFuncs : public MixedModeMutator { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); if (func->GetAttr(attr::kCompiler).defined()) { - auto fn_name_mangled = mangle_fn_(pair.first->name_hint); + auto fn_name_mangled = relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint)); GlobalVar gvar = GlobalVar(fn_name_mangled); mangled_gvars_[pair.first->name_hint] = gvar; } @@ -519,7 +520,8 @@ class NameMangleExtFuncs : public MixedModeMutator { if (func->GetAttr(attr::kCompiler).defined()) { auto new_dict = func->attrs->dict; - new_dict.Set(tvm::attr::kGlobalSymbol, String(mangle_fn_(pair.first->name_hint))); + new_dict.Set(tvm::attr::kGlobalSymbol, + 
String(relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint)))); func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, DictAttrs(new_dict)); new_module->Add(mangled_gvars_[pair.first->name_hint], func); diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py index d785cfa199ae..42eb31a3532c 100644 --- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py +++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py @@ -103,7 +103,7 @@ def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_z assert any(attrs), "At least one function with external attributes was expected." compilers = [ - key == "Compiler" and value == "cmsisnn" for attr in attrs for key, value in attr.items() + key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() ] assert any(compilers), "Module does not contain function for cmsisnn target." diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py index b030437252dc..40e12fc962b2 100644 --- a/tests/python/contrib/test_cmsisnn/test_softmax.py +++ b/tests/python/contrib/test_cmsisnn/test_softmax.py @@ -85,7 +85,7 @@ def test_softmax_int8(zero_point, scale): assert any(attrs), "At least one function with external attributes was expected." compilers = [ - key == "Compiler" and value == "cmsisnn" for attr in attrs for key, value in attr.items() + key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() ] assert any(compilers), "Module does not contain function for cmsisnn target." 
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 5aba6229c5e2..90d88169225c 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -326,6 +326,49 @@ def expected(): check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res) +def test_extern_compiler_sanitized_ops(): + def expected(): + mod = tvm.IRModule() + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + x0 = relay.var("x0", shape=(8, 8)) + y0 = relay.var("y0", shape=(8, 8)) + add = x0 + y0 + # Function that uses C compiler + func = relay.Function([x0, y0], add) + func = set_func_attr(func, "unsanitary-name++", "tvmgen_default_unsanitary_name___main_0") + glb_0 = relay.GlobalVar("tvmgen_default_unsanitary_name___main_0") + mod[glb_0] = func + add_call = relay.Call(glb_0, [x, y]) + # Function that uses default compiler. Ops are fused in this function. + p0 = relay.var("p0", shape=(8, 8)) + log = relay.log(p0) + exp = relay.exp(p0) + concat = relay.concatenate([log, exp], axis=0) + fused_func = relay.Function([p0], concat) + fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + fused_call = relay.Call(fused_func, [add_call]) + main = relay.Function([x, y], fused_call) + mod["main"] = main + mod = transform.InferType()(mod) + return mod + + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + add = x + y + log = relay.log(add) + exp = relay.exp(add) + concat = relay.concatenate([log, exp], axis=0) + f = relay.Function([x, y], concat) + mod = tvm.IRModule() + mod["main"] = f + mod = WhiteListAnnotator(["add", "subtract", "multiply"], "unsanitary-name++")(mod) + mod = transform.PartitionGraph()(mod) + fused_mod = transform.FuseOps(2)(mod) + expected_mod = expected() + assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True) + + def test_extern_ccompiler_multiple_functions(): def expected(): mod = 
tvm.IRModule() From b911e5c462290539cb8538cfdab14de92312160b Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Fri, 29 Oct 2021 10:24:33 +0000 Subject: [PATCH 2/3] Port tvmc fixes for hybrid targets --- python/tvm/driver/tvmc/common.py | 4 +++- python/tvm/driver/tvmc/target.py | 14 +++++++++----- tests/python/driver/tvmc/test_target.py | 13 +++++++++++++ tests/python/driver/tvmc/test_target_options.py | 10 ++++++++++ 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 1ee24cf69d44..65b0c3dbc0aa 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -27,6 +27,7 @@ import tvm +from tvm.driver import tvmc from tvm import relay from tvm import transform from tvm._ffi import registry @@ -206,6 +207,7 @@ def parse_target(target): a key-value for all options passed via CLI; 'raw', containing the plain string for this codegen """ + codegen_names = tvmc.composite_target.get_codegen_names() codegens = [] tvm_target_kinds = tvm.target.Target.list_kinds() @@ -232,7 +234,7 @@ def parse_target(target): for codegen_def in split_codegens: # the first is expected to be the name name = codegen_def[0] - is_tvm_target = name in tvm_target_kinds + is_tvm_target = name in tvm_target_kinds and name not in codegen_names raw_target = " ".join(codegen_def) all_opts = codegen_def[1:] if len(codegen_def) > 1 else [] opts = {} diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index 7a078b8be087..a551293b26a5 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -18,6 +18,7 @@ This file contains functions for processing target inputs for the TVMC CLI """ +from tvm.driver import tvmc from tvm.target import Target # We can't tell the type inside an Array but all current options are strings so @@ -27,6 +28,11 @@ INTERNAL_TO_HELP = {"runtime.String": " string", "IntImm": "", "Array": " options"} +def 
_valid_target_kinds(): + codegen_names = tvmc.composite_target.get_codegen_names() + return filter(lambda target: target not in codegen_names, Target.list_kinds()) + + def _generate_target_kind_args(parser, kind): target_group = parser.add_argument_group(f"target {kind.name}") for target_option, target_type in kind.options.items(): @@ -45,8 +51,7 @@ def generate_target_args(parser): help="compilation target as plain string, inline JSON or path to a JSON file", required=True, ) - target_kinds = Target.list_kinds() - for target_kind in target_kinds: + for target_kind in _valid_target_kinds(): target = Target(target_kind) _generate_target_kind_args(parser, target.kind) @@ -55,7 +60,7 @@ def _reconstruct_target_kind_args(args, kind): kind_options = {} for target_option, target_type in kind.options.items(): if target_type in INTERNAL_TO_NATIVE_TYPE: - var_name = f"target_{kind.name}_{target_option.replace('-', '_')}" + var_name = f"target_{kind.name.replace('-', '_')}_{target_option.replace('-', '_')}" option_value = getattr(args, var_name) if option_value is not None: kind_options[target_option] = getattr(args, var_name) @@ -64,9 +69,8 @@ def _reconstruct_target_kind_args(args, kind): def reconstruct_target_args(args): """Reconstructs the target options from the arguments""" - target_kinds = Target.list_kinds() reconstructed = {} - for target_kind in target_kinds: + for target_kind in _valid_target_kinds(): target = Target(target_kind) kind_options = _reconstruct_target_kind_args(args, target.kind) if kind_options: diff --git a/tests/python/driver/tvmc/test_target.py b/tests/python/driver/tvmc/test_target.py index 001ac18ca6d9..865542ee25c1 100644 --- a/tests/python/driver/tvmc/test_target.py +++ b/tests/python/driver/tvmc/test_target.py @@ -118,6 +118,19 @@ def test_parse_multiple_target(): assert "llvm" == targets[1]["name"] +def test_parse_hybrid_target(): + """Hybrid Target and external codegen""" + targets = tvmc.common.parse_target( + "cmsis-nn 
-accelerator_config=ethos-u55-256, llvm -device=arm_cpu --system-lib" + ) + + assert len(targets) == 2 + assert "cmsis-nn" == targets[0]["name"] + assert not targets[0]["is_tvm_target"] + assert "llvm" == targets[1]["name"] + assert targets[1]["is_tvm_target"] + + def test_parse_quotes_and_separators_on_options(): targets_no_quote = tvmc.common.parse_target("foo -option1=+v1.0x,+value,+bar") targets_single_quote = tvmc.common.parse_target("foo -option1='+v1.0x,+value'") diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py index f6942299b751..b592d504fe7f 100644 --- a/tests/python/driver/tvmc/test_target_options.py +++ b/tests/python/driver/tvmc/test_target_options.py @@ -42,6 +42,16 @@ def test_mapping_target_args(): assert reconstruct_target_args(parsed) == {"llvm": {"mcpu": "cortex-m3"}} +def test_skip_target_from_codegen(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, left = parser.parse_known_args( + ["--target=cmsis-nn, c", "--target-cmsis-nn-from_device=1", "--target-c-mcpu=cortex-m55"] + ) + assert left == ["--target-cmsis-nn-from_device=1"] + assert reconstruct_target_args(parsed) == {"c": {"mcpu": "cortex-m55"}} + + def test_target_recombobulation_single(): tvm_target, _ = tvmc.common.target_from_cli("llvm", {"llvm": {"mcpu": "cortex-m3"}}) From 5e22a878ddee4e1e5457c40f542d0cd28a4de704 Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Sat, 30 Oct 2021 15:35:39 +0100 Subject: [PATCH 3/3] Update NPU tests with new sanitisation --- .../python/contrib/test_ethosn/infrastructure.py | 16 ++++++++++++---- .../python/contrib/test_ethosn/test_networks.py | 16 ++++++++-------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/tests/python/contrib/test_ethosn/infrastructure.py b/tests/python/contrib/test_ethosn/infrastructure.py index c284f6488bfb..f16c37fe19af 100644 --- a/tests/python/contrib/test_ethosn/infrastructure.py +++ 
b/tests/python/contrib/test_ethosn/infrastructure.py @@ -170,11 +170,19 @@ def build(mod, params, npu=True, expected_host_ops=0, npu_partitions=1): assert ( host_op_count == expected_host_ops ), "Got {} host operators, expected {}".format(host_op_count, expected_host_ops) - partition_count = 0 - for global_var in mod.get_global_vars(): - if "ethos-n" in global_var.name_hint: - partition_count += 1 + attrs = [ + mod[var.name_hint].attrs + for var in mod.get_global_vars() + if mod[var.name_hint].attrs + ] + partition_count = sum( + [ + key == "Compiler" and value == "ethos-n" + for attr in attrs + for key, value in attr.items() + ] + ) assert ( npu_partitions == partition_count ), "Got {} ethos-n partitions, expected {}".format(partition_count, npu_partitions) diff --git a/tests/python/contrib/test_ethosn/test_networks.py b/tests/python/contrib/test_ethosn/test_networks.py index f720c55c567a..3a8b95496fde 100644 --- a/tests/python/contrib/test_ethosn/test_networks.py +++ b/tests/python/contrib/test_ethosn/test_networks.py @@ -123,9 +123,9 @@ def test_mobilenet_v1(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"1fd4ef29a1ea9f3a015cab87c0b8014a"} + _compile_hash = {"0433d3c3947a067b36f0228bdb5f1838"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"b879dfbff1f907eaf6129dfd41b44ece"} + _compile_hash = {"e4ed29dceb1187505948ab17fc3cc6d6"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"9c9f63b30824f5b223cdb27d2f22c857"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": @@ -150,9 +150,9 @@ def test_inception_v3(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. 
- _compile_hash = {"b90ed315639c6a0e97584c2dbc42a55c"} + _compile_hash = {"43dc2097127eb224c0191b1a15f8acca"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"5693569055695e581a8739194d0301aa"} + _compile_hash = {"7db23387bdc5af6eaa1ae3f7d456caf0"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"46ccafc840633633aca441645e41b444"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": @@ -176,9 +176,9 @@ def test_inception_v4(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"b36877d2386d9f9c37a11772e3c4072c"} + _compile_hash = {"fab6c2297502f95d33079c6ce1a737f9"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"b5046a6f56d78af0b4f51960bf2deeda"} + _compile_hash = {"8da68849b75613ac3dffd3fff2dd87da"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"4a1a56393078367dd27915a188d6a6af"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": @@ -202,9 +202,9 @@ def test_ssd_mobilenet_v1(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"956caf9e7fe5cfd5c042bd17857f7407", "4313033d14328e2aa022b1bd71b27b1c"} + _compile_hash = {"2345cf5d6c0013bad7c76dcccee9d862", "7795b6c67178da9d1f9b98063bad75b1"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"dc60cc687d892cd2877873094e9dfc0b", "6b3deeec16c24c0dcef23df0db5fb162"} + _compile_hash = {"928dc6ae5ce49a4ad63ca87f7575970f", "b092f9820f7e9341fc53daa781b98772"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"10826406ae724e52f360a06c35ced09d", "9a484d5ecec7acb18c9d6bc6058be031"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO":