From 31e290f2750144cca4a2986caeb2d07e245f466e Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 4 Jun 2020 18:44:16 +0000 Subject: [PATCH 01/13] Separate code and metadata --- python/tvm/target/__init__.py | 1 + python/tvm/target/packaging_module.py | 52 ++++++++++ src/relay/backend/contrib/dnnl/codegen.cc | 28 ++++-- src/relay/backend/vm/compiler.cc | 18 ++-- src/target/source/source_module.cc | 106 +++++++++++++++++++- tests/python/relay/test_external_codegen.py | 10 +- 6 files changed, 196 insertions(+), 19 deletions(-) create mode 100644 python/tvm/target/packaging_module.py diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 2553fedb9869..dd470dd28975 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -61,3 +61,4 @@ from . import datatype from . import codegen from .intrin import register_intrin_rule +from .packaging_module import PackagingModule, CSourceModule diff --git a/python/tvm/target/packaging_module.py b/python/tvm/target/packaging_module.py new file mode 100644 index 000000000000..cb077d1d76e5 --- /dev/null +++ b/python/tvm/target/packaging_module.py @@ -0,0 +1,52 @@ +# License .to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin +""" +APIs for a packaging module +""" +from tvm.runtime import _ffi_api + +class PackagingModule: + """The Packaging module""" + def __init__(self, mod): + self.mod = mod + self._get_source = self.mod["get_source"] + self._get_metadata = self.mod["get_metadata"] + self._is_c_source = self.mod["is_c_source"] + + @property + def source(self): + """Get the source""" + return self._get_source() + + @property + def metadata(self): + """Get the metadata""" + return self._get_metadata() + + def is_c_source(self): + """Check if the source code is C/C++""" + return self._is_c_source() + + +def CSourceModule(code, fmt="c"): + """Create a C source module""" + return _ffi_api.CSourceModuleCreate(code, fmt) + +def ModuleInitWrapper(metadata, code=""): + """Create a module initialization wrapper""" + return _ffi_api.ModuleInitWrapper(metadata, code) diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 3f9ad7cdc69f..61924cb62d30 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -169,6 +169,8 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C output.dtype = "float"; runtime::NDArray array = cn->data; + CHECK_EQ(metadata_.count(output.name), 0U) << "variable must be unique: " << output.name; + metadata_.Set(output.name, array); // Get the number of elements. int64_t num_elems = 1; @@ -212,6 +214,8 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out); } + Map GetMetadata() const { return metadata_; } + private: struct GenerateBodyOutput { std::string decl; @@ -345,6 +349,8 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C std::vector ext_func_body; /*! \brief The declaration of intermeidate buffers. */ std::vector buf_decl_; + /*! \brief The variable name to constant mapping. */ + Map metadata_; }; /*! @@ -355,7 +361,7 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C class DNNLModuleCodegen : public CSourceModuleCodegenBase { public: // Create a corresponding DNNL function for the given relay Function. - void GenDNNLFunc(const Function& func) { + std::pair> GenDNNLFunc(const Function& func) { CHECK(func.defined()) << "Input error: expect a Relay function."; // Record the external symbol for runtime lookup. @@ -364,6 +370,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { CodegenDNNL builder(sid); auto out = builder.VisitExpr(func->body); code_stream_ << builder.JIT(out); + + return std::make_pair(sid, builder.GetMetadata()); } /*! @@ -392,22 +400,30 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { code_stream_ << "using namespace tvm::runtime::contrib;\n"; code_stream_ << "\n"; + Map code; + Map> metadata; if (ref->IsInstance()) { - GenDNNLFunc(Downcast(ref)); + auto ret = GenDNNLFunc(Downcast(ref)); + String sym = std::get<0>(ret); + code.Set(sym, code_stream_.str()); + metadata.Set(sym, std::get<1>(ret)); } else if (ref->IsInstance()) { IRModule mod = Downcast(ref); for (const auto& it : mod->functions) { - GenDNNLFunc(Downcast(it.second)); + auto ret = GenDNNLFunc(Downcast(it.second)); + String sym = std::get<0>(ret); + code.Set(sym, code_stream_.str()); + metadata.Set(sym, std::get<1>(ret)); } } else { LOG(FATAL) << "The input ref is expected to be a Relay function or module" << "\n"; } - // Create a CSourceModule - const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); + // Create a PackagingModule + const auto* pf = runtime::Registry::Get("runtime.PackagingModuleCreate"); CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - return (*pf)(code_stream_.str(), "cc"); + return (*pf)(code, "c", metadata); } private: diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 81db34125bd7..4d90c47ddf7f 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -41,6 +41,7 @@ #include #include +#include "../../../target/source/codegen_source_base.h" #include "../../backend/compile_engine.h" #include "../../op/op_common.h" #include "../../transforms/pass_util.h" @@ -1015,18 +1016,13 @@ void VMCompiler::Codegen() { mod = tvm::build(build_funcs, target_host_); CHECK(mod.operator->()); } else { - CHECK_EQ(ext_mods.size(), 1U) - << "Expect to have a TVM DSOModule when multiple runtime modules exist"; + // There is no function handled by TVM. We create a virtual master module + // to make sure a DSO module will be also available. + mod = codegen::CSourceModuleCreate(";", ""); } - if (!ext_mods.empty()) { - if (funcs.size() == 0) { - mod = ext_mods[0]; - } else { - // Import all external runtime modules. - for (auto it : ext_mods) { - mod.Import(it); - } - } + // Import all external runtime modules. + for (auto it : ext_mods) { + mod.Import(it); } exec_->lib = mod; } diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index ba7f075d0045..9b585cb5822b 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -21,6 +21,7 @@ * \file source_module.cc * \brief Source code module, only for viewing */ +#include #include #include @@ -152,8 +153,111 @@ runtime::Module DeviceSourceModuleCreate( return runtime::Module(n); } +// Pack the source code and metadata, where source code could be any +// user-defined code, i.e. c source code, json graph representation, etc. +class PackagingModule final : public runtime::ModuleNode { + public: + PackagingModule(Map code, const std::string& fmt, + Map> metadata) + : code_(code), fmt_(fmt), metadata_(metadata) {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + if (name == "get_source") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetSource(); }); + } else if (name == "get_metadata") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetMetadata(); }); + } else if (name == "is_c_source") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->IsCSourceCode(); }); + } else { + LOG(FATAL) << "Unknown packed function: " << name; + return PackedFunc(nullptr); + } + } + + Map GetSource() const { return code_; } + + Map> GetMetadata() const { return metadata_; } + + bool IsCSourceCode() { return fmt_ == "c" || fmt_ == "cc"; } + + const char* type_key() const { return "c"; } + + void SaveToFile(const std::string& file_name, const std::string& format) final { + std::string fmt = GetFileFormat(file_name, format); + CHECK_EQ(fmt, "cc") << "file_name: " << file_name << " must be a .cc file."; + SaveBinaryToFile(file_name, ";"); + } + + private: + /*! \brief Symbol to source (e.g. c source/json) mapping. */ + Map code_; + std::string fmt_; + /*! \brief Symbol to {var_name : NDArray} pair mapping. */ + Map> metadata_; +}; + +runtime::Module PackagingModuleCreate(Map code, std::string fmt, + Map> metadata) { + auto n = make_object(code, fmt, metadata); + return runtime::Module(n); +} + +class ModuleInitWrapper : public runtime::ModuleNode { + public: + ModuleInitWrapper(Map> metadata, Map code) + : metadata_(metadata), code_(code) {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + if (initialized_.count(name) == 0) { + this->InitSubModule(name); + initialized_[name] = true; + } else if (name != "__InitModule" && name != "__DestroyModule") { + CHECK(!this->imports().empty()); + runtime::Module submodule = this->imports().at(0); + return submodule->GetFunction(name); + } + + return PackedFunc(); + } + + const char* type_key() const { return "module_init"; } + + void InitSubModule(const std::string& symbol) {} + + private: + std::unordered_map initialized_; + /*! \brief A symbol to {var_name : NDArray} pair mapping. */ + Map> metadata_; + /*! + * \brief For JSON runtime we need the json code to build up an engine. For + * c source module, code has already been compiled into a DSO module, only + * metadata is needed to feed the correct data. + */ + Map code_; +}; + +runtime::Module ModuleInitWrapperCreate(Map> metadata, + Map code) { + auto n = make_object(metadata, code); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.PackagingModuleCreate").set_body_typed(PackagingModuleCreate); + TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); -TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate").set_body_typed(CSourceModuleCreate); +TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate").set_body_typed([](String code, String fmt) { + return CSourceModuleCreate(code.operator std::string(), fmt.operator std::string()); +}); + +TVM_REGISTER_GLOBAL("runtime.ModuleInitWrapper") + .set_body_typed([](Map> metadata, + Map code) { + return ModuleInitWrapperCreate(metadata, code); + }); + } // namespace codegen } // namespace tvm diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index c449ce39ff01..7c829d91c571 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -34,6 +34,14 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", return def update_lib(lib): + ext_mod = tvm.target.PackagingModule(lib.imported_modules[0]) + code = ext_mod.source + metadata = ext_mod.metadata + + new_lib = lib + for sym, src in code.items(): + new_lib.import_module(tvm.target.CSourceModule(src)) + test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) source_dir = os.path.join(test_dir, "..", "..", "..") contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") @@ -43,7 +51,7 @@ def update_lib(lib): tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) - lib.export_library(lib_path, fcompile=False, **kwargs) + new_lib.export_library(lib_path, fcompile=False, **kwargs) lib = tvm.runtime.load_module(lib_path) return lib From 1f5a7f1edf1b606f08433956ba5489f4ef0fcc61 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 5 Jun 2020 16:55:39 +0000 Subject: [PATCH 02/13] init wrapper WIP --- python/tvm/target/packaging_module.py | 11 ++- src/relay/backend/contrib/dnnl/codegen.cc | 2 +- src/target/source/source_module.cc | 96 +++++++++++++++------ tests/python/relay/test_external_codegen.py | 3 + 4 files changed, 83 insertions(+), 29 deletions(-) diff --git a/python/tvm/target/packaging_module.py b/python/tvm/target/packaging_module.py index cb077d1d76e5..a7c49bf29b24 100644 --- a/python/tvm/target/packaging_module.py +++ b/python/tvm/target/packaging_module.py @@ -25,6 +25,7 @@ class PackagingModule: def __init__(self, mod): self.mod = mod self._get_source = self.mod["get_source"] + self._get_source_type = self.mod["get_source_type"] self._get_metadata = self.mod["get_metadata"] self._is_c_source = self.mod["is_c_source"] @@ -33,6 +34,11 @@ def source(self): """Get the source""" return self._get_source() + @property + def source_type(self): + """Get the source type""" + return self._get_source_type() + @property def metadata(self): """Get the metadata""" @@ -47,6 +53,7 @@ def CSourceModule(code, fmt="c"): """Create a C source module""" return _ffi_api.CSourceModuleCreate(code, fmt) -def ModuleInitWrapper(metadata, code=""): + +def ModuleInitWrapper(metadata, code="", source_type="c"): """Create a module initialization wrapper""" - return _ffi_api.ModuleInitWrapper(metadata, code) + return _ffi_api.ModuleInitWrapper(metadata, code, source_type) diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 61924cb62d30..bcc8513fd9e9 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -165,7 +165,7 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C std::vector VisitExpr_(const ConstantNode* cn) final { Output output; - output.name = "const_" + std::to_string(const_idx_++); + output.name = ext_func_id_ + "_const_" + std::to_string(const_idx_++); output.dtype = "float"; runtime::NDArray array = cn->data; diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 9b585cb5822b..5fe428247a0f 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -157,17 +157,19 @@ runtime::Module DeviceSourceModuleCreate( // user-defined code, i.e. c source code, json graph representation, etc. class PackagingModule final : public runtime::ModuleNode { public: - PackagingModule(Map code, const std::string& fmt, + PackagingModule(Map code, const std::string& source_type, Map> metadata) - : code_(code), fmt_(fmt), metadata_(metadata) {} + : code_(code), source_type_(source_type), metadata_(metadata) {} PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "get_source") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->code_; }); + } else if (name == "get_source_type") { return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetSource(); }); + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->source_type_; }); } else if (name == "get_metadata") { return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetMetadata(); }); + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->metadata_; }); } else if (name == "is_c_source") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->IsCSourceCode(); }); @@ -177,44 +179,66 @@ class PackagingModule final : public runtime::ModuleNode { } } - Map GetSource() const { return code_; } - - Map> GetMetadata() const { return metadata_; } - - bool IsCSourceCode() { return fmt_ == "c" || fmt_ == "cc"; } + bool IsCSourceCode() { return source_type_ == "c" || source_type_ == "cc"; } const char* type_key() const { return "c"; } void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, "cc") << "file_name: " << file_name << " must be a .cc file."; + std::string source_type = GetFileFormat(file_name, format); + CHECK_EQ(source_type, "cc") << "file_name: " << file_name << " must be a .cc file."; SaveBinaryToFile(file_name, ";"); } private: /*! \brief Symbol to source (e.g. c source/json) mapping. */ Map code_; - std::string fmt_; + /*! \brief The type of the source code, e.g. c or any customized json type. */ + std::string source_type_; /*! \brief Symbol to {var_name : NDArray} pair mapping. */ Map> metadata_; }; -runtime::Module PackagingModuleCreate(Map code, std::string fmt, +runtime::Module PackagingModuleCreate(Map code, std::string source_type, Map> metadata) { - auto n = make_object(code, fmt, metadata); + auto n = make_object(code, source_type, metadata); + return runtime::Module(n); +} + +class CSourceModuleInitializer : public runtime::ModuleNode { + public: + explicit CSourceModuleInitializer(Map> metadata) + : metadata_(metadata) {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + LOG(FATAL) << "CSourceModuleInitializer cannot be executed"; + return PackedFunc(); + } + + void Init() {} + + const char* type_key() const { return "csourcemodule_initializer"; } + + private: + Map> metadata_; +}; + +runtime::Module CSourceModuleInitializerCreate( + Map> metadata) { + auto n = make_object(metadata); return runtime::Module(n); } class ModuleInitWrapper : public runtime::ModuleNode { public: - ModuleInitWrapper(Map> metadata, Map code) - : metadata_(metadata), code_(code) {} + ModuleInitWrapper(Map> metadata, Map code, + String source_type) + : metadata_(metadata), code_(code), source_type_(source_type) {} PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (initialized_.count(name) == 0) { this->InitSubModule(name); initialized_[name] = true; - } else if (name != "__InitModule" && name != "__DestroyModule") { + } else if (name != "init_module" && name != "destroy_module") { CHECK(!this->imports().empty()); runtime::Module submodule = this->imports().at(0); return submodule->GetFunction(name); @@ -222,10 +246,25 @@ class ModuleInitWrapper : public runtime::ModuleNode { return PackedFunc(); } - + const char* type_key() const { return "module_init"; } - void InitSubModule(const std::string& symbol) {} + void InitSubModule(const std::string& symbol) { + // Dispatch initializer according to the source type + std::string initializer = "runtime.init." + source_type_; + auto pf = tvm::runtime::Registry::Get(initializer); + + CHECK(pf) << "Failed to find the initializer for " << initializer; + if (source_type_ == "c") { + // Initialize the s source module. + runtime::Module c_mod = (*pf)(metadata_); + CHECK(c_mod->IsInstance()); + auto* c_mod_init = static_cast(c_mod.operator->()); + c_mod_init->Init(); + } else { + LOG(FATAL) << "Implement the initialization of json style runtime here"; + } + } private: std::unordered_map initialized_; @@ -237,11 +276,13 @@ class ModuleInitWrapper : public runtime::ModuleNode { * metadata is needed to feed the correct data. */ Map code_; + /*! \brief The type of the source, i.e. c, or any customized json */ + String source_type_; }; runtime::Module ModuleInitWrapperCreate(Map> metadata, - Map code) { - auto n = make_object(metadata, code); + Map code, String source_type) { + auto n = make_object(metadata, code, source_type); return runtime::Module(n); } @@ -249,15 +290,18 @@ TVM_REGISTER_GLOBAL("runtime.PackagingModuleCreate").set_body_typed(PackagingMod TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); -TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate").set_body_typed([](String code, String fmt) { - return CSourceModuleCreate(code.operator std::string(), fmt.operator std::string()); -}); +TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") + .set_body_typed([](String code, String source_type) { + return CSourceModuleCreate(code.operator std::string(), source_type.operator std::string()); + }); TVM_REGISTER_GLOBAL("runtime.ModuleInitWrapper") .set_body_typed([](Map> metadata, - Map code) { - return ModuleInitWrapperCreate(metadata, code); + Map code, String source_type) { + return ModuleInitWrapperCreate(metadata, code, source_type); }); +TVM_REGISTER_GLOBAL("runtime.init.c").set_body_typed(CSourceModuleInitializerCreate); + } // namespace codegen } // namespace tvm diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 7c829d91c571..21449cde9334 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -37,9 +37,12 @@ def update_lib(lib): ext_mod = tvm.target.PackagingModule(lib.imported_modules[0]) code = ext_mod.source metadata = ext_mod.metadata + src_ty = ext_mod.source_type new_lib = lib for sym, src in code.items(): + # init_mod = tvm.target.ModuleInitWrapper(metadata, code, src_ty) + # init_mod.import_module(tvm.target.CSourceModule(src)) new_lib.import_module(tvm.target.CSourceModule(src)) test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) From eaa592ab08db29411398da249bdef1cac9c867e7 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sat, 6 Jun 2020 19:00:15 +0000 Subject: [PATCH 03/13] use moduleinitwrapper --- python/tvm/runtime/module.py | 27 +++-- python/tvm/target/__init__.py | 2 +- python/tvm/target/packaging_module.py | 4 +- src/relay/backend/contrib/dnnl/codegen.cc | 21 +++- src/target/source/source_module.cc | 110 +++++++++++--------- tests/python/relay/test_external_codegen.py | 8 +- 6 files changed, 105 insertions(+), 67 deletions(-) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 3cdb28f8c496..e5f201bfd803 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -222,9 +222,12 @@ def evaluator(*args): except NameError: raise NameError("time_evaluate is only supported when RPC is enabled") - def _collect_dso_modules(self): - """Helper function to collect dso modules, then return it.""" - visited, stack, dso_modules = set(), [], [] + def _collect_dso_metadata_modules(self): + """ + Helper function to collect dso modules and metadata init module. There + is at most one medata init module if it exists. + """ + visited, stack, dso_modules, medata_init = set(), [], [], [] # append root module visited.add(self) stack.append(self) @@ -232,11 +235,15 @@ def _collect_dso_modules(self): module = stack.pop() if module._dso_exportable(): dso_modules.append(module) + elif module.type_key == "module_init": + medata_init.append(module) for m in module.imported_modules: if m not in visited: visited.add(m) stack.append(m) - return dso_modules + + assert len(medata_init) <= 1, "At most one metadata init module is allowed." + return dso_modules, medata_init def _dso_exportable(self): return self.type_key == "llvm" or self.type_key == "c" @@ -282,13 +289,13 @@ def export_library(self, self.save(file_name) return - modules = self._collect_dso_modules() + dso_modules, metadata_init = self._collect_dso_metadata_modules() temp = _util.tempdir() files = addons if addons else [] is_system_lib = False has_c_module = False llvm_target_triple = None - for index, module in enumerate(modules): + for index, module in enumerate(dso_modules): if fcompile is not None and hasattr(fcompile, "object_format"): object_format = fcompile.object_format else: @@ -305,6 +312,14 @@ def export_library(self, module.get_function("__tvm_is_system_module")()) llvm_target_triple = (module.type_key == "llvm" and module.get_function("_get_target_triple")()) + + metadata_import = None if not metadata_init else metadata_init[0].imported_modules + if metadata_import and metadata_import[0].type_key == "c": + module = metadata_init[0] + header = temp.relpath("metadata.h") + module.save(header) + files.append(header) + if not fcompile: if file_name.endswith(".tar"): fcompile = _tar.tar diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index dd470dd28975..cfc19363eb1a 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -61,4 +61,4 @@ from . import datatype from . import codegen from .intrin import register_intrin_rule -from .packaging_module import PackagingModule, CSourceModule +from .packaging_module import PackagingModule, CSourceModule, ModuleInitWrapper diff --git a/python/tvm/target/packaging_module.py b/python/tvm/target/packaging_module.py index a7c49bf29b24..20b71a2faed3 100644 --- a/python/tvm/target/packaging_module.py +++ b/python/tvm/target/packaging_module.py @@ -54,6 +54,6 @@ def CSourceModule(code, fmt="c"): return _ffi_api.CSourceModuleCreate(code, fmt) -def ModuleInitWrapper(metadata, code="", source_type="c"): +def ModuleInitWrapper(metadata, source_type="c"): """Create a module initialization wrapper""" - return _ffi_api.ModuleInitWrapper(metadata, code, source_type) + return _ffi_api.ModuleInitWrapper(metadata, source_type) diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index bcc8513fd9e9..422a58b69efc 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -405,16 +405,27 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { if (ref->IsInstance()) { auto ret = GenDNNLFunc(Downcast(ref)); String sym = std::get<0>(ret); - code.Set(sym, code_stream_.str()); - metadata.Set(sym, std::get<1>(ret)); + Map consts = std::get<1>(ret); + std::string code_str = code_stream_.str(); + if (!consts.empty()) { + code_str = "#include \"metadata.h\"\n" + code_str; + metadata.Set(sym, consts); + } + code.Set(sym, code_str); } else if (ref->IsInstance()) { IRModule mod = Downcast(ref); for (const auto& it : mod->functions) { auto ret = GenDNNLFunc(Downcast(it.second)); - String sym = std::get<0>(ret); - code.Set(sym, code_stream_.str()); - metadata.Set(sym, std::get<1>(ret)); + Map consts = std::get<1>(ret); + if (!consts.empty()) { + metadata.Set(std::get<0>(ret), consts); + } + } + std::string code_str = code_stream_.str(); + if (!metadata.empty()) { + code_str = "#include \"metadata.h\"\n" + code_str; } + code.Set("all", code_str); } else { LOG(FATAL) << "The input ref is expected to be a Relay function or module" << "\n"; diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 5fe428247a0f..c38c11587f8a 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -204,35 +204,10 @@ runtime::Module PackagingModuleCreate(Map code, std::string sour return runtime::Module(n); } -class CSourceModuleInitializer : public runtime::ModuleNode { - public: - explicit CSourceModuleInitializer(Map> metadata) - : metadata_(metadata) {} - - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - LOG(FATAL) << "CSourceModuleInitializer cannot be executed"; - return PackedFunc(); - } - - void Init() {} - - const char* type_key() const { return "csourcemodule_initializer"; } - - private: - Map> metadata_; -}; - -runtime::Module CSourceModuleInitializerCreate( - Map> metadata) { - auto n = make_object(metadata); - return runtime::Module(n); -} - class ModuleInitWrapper : public runtime::ModuleNode { public: - ModuleInitWrapper(Map> metadata, Map code, - String source_type) - : metadata_(metadata), code_(code), source_type_(source_type) {} + ModuleInitWrapper(Map> metadata, String source_type) + : metadata_(metadata), source_type_(source_type) {} PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (initialized_.count(name) == 0) { @@ -249,40 +224,78 @@ class ModuleInitWrapper : public runtime::ModuleNode { const char* type_key() const { return "module_init"; } + std::string InitCSourceMetadata() { + std::ostringstream os; + os.precision(16); + for (const auto& it : metadata_) { + auto vars = it.second; + if (!vars.defined()) continue; + String var_name = (*vars.begin()).first; + runtime::NDArray data = (*vars.begin()).second; + // Get the number of elements. + int64_t num_elems = 1; + for (auto i : data.Shape()) num_elems *= i; + // TODO(zhiics) Handle different data types. + os << "static float " << var_name.c_str() << "[" << num_elems << "] = {"; + const float* ptr = static_cast(data->data); + for (int64_t i = 0; i < num_elems - 1; i++) { + os << ptr[i] << ","; + } + if (num_elems > 0) os << ptr[num_elems - 1]; + os << "};\n"; + } + return os.str(); + } + void InitSubModule(const std::string& symbol) { // Dispatch initializer according to the source type - std::string initializer = "runtime.init." + source_type_; - auto pf = tvm::runtime::Registry::Get(initializer); + // std::string initializer = "runtime.init." + source_type_; + // auto pf = tvm::runtime::Registry::Get(initializer); + + // CHECK(pf) << "Failed to find the initializer for " << initializer; + if (source_type_ != "c") { + LOG(FATAL) << "Implement the initialization of json style runtime here"; + } + } - CHECK(pf) << "Failed to find the initializer for " << initializer; + void SaveToFile(const std::string& file_name, const std::string& format) final { + std::string fmt = GetFileFormat(file_name, format); + CHECK_EQ(fmt, "h") << "Can only save to .h file"; if (source_type_ == "c") { - // Initialize the s source module. - runtime::Module c_mod = (*pf)(metadata_); - CHECK(c_mod->IsInstance()); - auto* c_mod_init = static_cast(c_mod.operator->()); - c_mod_init->Init(); + SaveBinaryToFile(file_name, InitCSourceMetadata()); } else { - LOG(FATAL) << "Implement the initialization of json style runtime here"; + SaveBinaryToFile(file_name, ";"); } } + void SaveToBinary(dmlc::Stream* stream) final { + for (const auto& it : metadata_) { + // Save metadata using the NDArray serialization utility + } + stream->Write(source_type_.operator std::string()); + } + + static runtime::Module LoadFromBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::string source_type; + stream->Read(&source_type); + Map> metadata; + // Deserializing metadata using the NDArray serialization utility. + auto n = runtime::make_object(metadata, source_type); + return runtime::Module(n); + } + private: std::unordered_map initialized_; /*! \brief A symbol to {var_name : NDArray} pair mapping. */ Map> metadata_; - /*! - * \brief For JSON runtime we need the json code to build up an engine. For - * c source module, code has already been compiled into a DSO module, only - * metadata is needed to feed the correct data. - */ - Map code_; /*! \brief The type of the source, i.e. c, or any customized json */ String source_type_; }; runtime::Module ModuleInitWrapperCreate(Map> metadata, - Map code, String source_type) { - auto n = make_object(metadata, code, source_type); + String source_type) { + auto n = make_object(metadata, source_type); return runtime::Module(n); } @@ -296,12 +309,11 @@ TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") }); TVM_REGISTER_GLOBAL("runtime.ModuleInitWrapper") - .set_body_typed([](Map> metadata, - Map code, String source_type) { - return ModuleInitWrapperCreate(metadata, code, source_type); + .set_body_typed([](Map> metadata, String source_type) { + return ModuleInitWrapperCreate(metadata, source_type); }); -TVM_REGISTER_GLOBAL("runtime.init.c").set_body_typed(CSourceModuleInitializerCreate); - +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_module_init") + .set_body_typed(ModuleInitWrapper::LoadFromBinary); } // namespace codegen } // namespace tvm diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 21449cde9334..5476fbe965fb 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -40,10 +40,10 @@ def update_lib(lib): src_ty = ext_mod.source_type new_lib = lib - for sym, src in code.items(): - # init_mod = tvm.target.ModuleInitWrapper(metadata, code, src_ty) - # init_mod.import_module(tvm.target.CSourceModule(src)) - new_lib.import_module(tvm.target.CSourceModule(src)) + init_mod = tvm.target.ModuleInitWrapper(metadata, src_ty) + for _, src in code.items(): + init_mod.import_module(tvm.target.CSourceModule(src)) + new_lib.import_module(init_mod) test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) source_dir = os.path.join(test_dir, "..", "..", "..") From 82ac95208d7be44ea96711bc9b86126aea19ce89 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sun, 7 Jun 2020 00:37:28 +0000 Subject: [PATCH 04/13] serialize metadata, csourcemodule works --- src/relay/backend/contrib/dnnl/codegen.cc | 21 +---- src/target/source/source_module.cc | 96 ++++++++++++++++----- tests/python/relay/test_external_codegen.py | 43 +++++++++ 3 files changed, 117 insertions(+), 43 deletions(-) diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 422a58b69efc..fa0f5ea36569 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -168,32 +168,13 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C output.name = ext_func_id_ + "_const_" + std::to_string(const_idx_++); output.dtype = "float"; - runtime::NDArray array = cn->data; CHECK_EQ(metadata_.count(output.name), 0U) << "variable must be unique: " << output.name; - metadata_.Set(output.name, array); - - // Get the number of elements. - int64_t num_elems = 1; - for (auto i : array.Shape()) num_elems *= i; + metadata_.Set(output.name, cn->data); const auto* type_node = cn->checked_type().as(); CHECK(type_node); CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now."; - std::ostringstream buf_stream; - const float* ptr = static_cast(array->data); - - // Allocate large arrays on the static section to avoid stakc overflow. - // Note that this would probably increase compilation time as the source - // file could be really large. - buf_stream << "static float " << output.name << "[" << num_elems << "] = {"; - for (int64_t i = 0; i < num_elems - 1; i++) { - buf_stream << ptr[i] << ","; - } - if (num_elems > 0) buf_stream << ptr[num_elems - 1]; - buf_stream << "};\n"; - - ext_func_body.insert(ext_func_body.begin(), buf_stream.str()); return {output}; } diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index c38c11587f8a..fef16ea969d3 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include "../../runtime/file_util.h" #include "../../runtime/meta_data.h" @@ -213,7 +214,9 @@ class ModuleInitWrapper : public runtime::ModuleNode { if (initialized_.count(name) == 0) { this->InitSubModule(name); initialized_[name] = true; - } else if (name != "init_module" && name != "destroy_module") { + } + + if (name != "init_module" && name != "destroy_module") { CHECK(!this->imports().empty()); runtime::Module submodule = this->imports().at(0); return submodule->GetFunction(name); @@ -228,21 +231,21 @@ class ModuleInitWrapper : public runtime::ModuleNode { std::ostringstream os; os.precision(16); for (const auto& it : metadata_) { - auto vars = it.second; - if (!vars.defined()) continue; - String var_name = (*vars.begin()).first; - runtime::NDArray data = (*vars.begin()).second; - // Get the number of elements. - int64_t num_elems = 1; - for (auto i : data.Shape()) num_elems *= i; - // TODO(zhiics) Handle different data types. - os << "static float " << var_name.c_str() << "[" << num_elems << "] = {"; - const float* ptr = static_cast(data->data); - for (int64_t i = 0; i < num_elems - 1; i++) { - os << ptr[i] << ","; + for (const auto& vars : it.second) { + String var_name = vars.first; + runtime::NDArray data = vars.second; + // Get the number of elements. + int64_t num_elems = 1; + for (auto i : data.Shape()) num_elems *= i; + // TODO(zhiics) Handle different data types. + os << "static float " << var_name.c_str() << "[" << num_elems << "] = {"; + const float* ptr = static_cast(data->data); + for (int64_t i = 0; i < num_elems - 1; i++) { + os << ptr[i] << ","; + } + if (num_elems > 0) os << ptr[num_elems - 1]; + os << "};\n"; } - if (num_elems > 0) os << ptr[num_elems - 1]; - os << "};\n"; } return os.str(); } @@ -259,28 +262,75 @@ class ModuleInitWrapper : public runtime::ModuleNode { } void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, "h") << "Can only save to .h file"; if (source_type_ == "c") { + std::string fmt = GetFileFormat(file_name, format); + CHECK_EQ(fmt, "h") << "Can only save to .h file"; SaveBinaryToFile(file_name, InitCSourceMetadata()); - } else { - SaveBinaryToFile(file_name, ";"); } } void SaveToBinary(dmlc::Stream* stream) final { + stream->Write(source_type_.operator std::string()); + + // Save the total number of symbols + uint64_t sym_cnt = static_cast(metadata_.size()); + stream->Write(sym_cnt); + for (const auto& it : metadata_) { - // Save metadata using the NDArray serialization utility + // Save the symbol/function name + stream->Write(it.first.operator std::string()); + + std::vector variables; + std::vector metadata; + for (const auto& vit : it.second) { + String var_name = vit.first; + variables.push_back(var_name.operator std::string()); + metadata.push_back(vit.second); + } + + // Save all variables in the function. + stream->Write(variables); + // Save all constant data + uint64_t sz = static_cast(metadata.size()); + stream->Write(sz); + for (uint64_t i = 0; i < sz; i++) { + metadata[i].Save(stream); + } } - stream->Write(source_type_.operator std::string()); } static runtime::Module LoadFromBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); std::string source_type; - stream->Read(&source_type); + CHECK(stream->Read(&source_type)) << "Loading source type failed"; + Map> metadata; - // Deserializing metadata using the NDArray serialization utility. + + uint64_t sym_cnt; + CHECK(stream->Read(&sym_cnt, sizeof(sym_cnt))) << "Loading the number of symbols failed"; + + for (uint64_t i = 0; i < sym_cnt; i++) { + std::string sym; + CHECK(stream->Read(&sym)) << "Loading symbol name failed"; + // Load variable and ndarray pairs + std::vector variables; + std::vector arrays; + CHECK(stream->Read(&variables)) << "Loading variables failed"; + uint64_t sz; + CHECK(stream->Read(&sz, sizeof(sz))) << "Loading medata size failed"; + CHECK_EQ(static_cast(sz), variables.size()) + << "The number of variables and ndarray counts must match"; + for (uint64_t i = 0; i < sz; i++) { + tvm::runtime::NDArray temp; + temp.Load(stream); + arrays.push_back(temp); + } + Map var_const; + for (size_t i = 0; i < variables.size(); i++) { + var_const.Set(variables[i], arrays[i]); + } + metadata.Set(sym, var_const); + } auto n = runtime::make_object(metadata, source_type); return runtime::Module(n); } diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 5476fbe965fb..8502d227a966 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -260,6 +260,7 @@ def test_extern_dnnl(): f = set_external_func_attr(f, "dnnl", "dnnl_0") call = relay.Call(f, [data0, weight0, weight0]) mod = tvm.IRModule.from_expr(call) + print(mod) i_data = np.random.uniform(0, 1, ishape).astype(dtype) w_data = np.random.uniform(0, 1, w1shape).astype(dtype) @@ -270,6 +271,48 @@ def test_extern_dnnl(): (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5) +def test_extern_dnnl_const(): + if not tvm.get_global_func("relay.ext.dnnl", True): + print("skip because DNNL codegen is not available") + return + + dtype = 'float32' + ishape = (1, 32, 14, 14) + w1shape = (32, 1, 3, 3) + data0 = relay.var('data0', shape=(ishape), dtype=dtype) + w_data = np.random.uniform(0, 1, w1shape).astype(dtype) + + data1 = relay.var('data0', shape=(ishape), dtype=dtype) + weight1 = relay.const(w_data, dtype=dtype) + weight2 = relay.const(w_data, dtype=dtype) + depthwise_conv2d_1 = relay.nn.conv2d(data1, + weight1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1, + weight2, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) + + f = relay.Function([data1], out) + ref_mod = tvm.IRModule() + ref_mod['main'] = f + + f = set_external_func_attr(f, "dnnl", "dnnl_0") + call = relay.Call(f, [data0]) + mod = tvm.IRModule.from_expr(call) + + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + + ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu()) + ref_res = ref_ex.evaluate()(i_data) + check_result(mod, {"data0": i_data}, + (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5) + + if __name__ == "__main__": test_multi_node_subgraph() test_extern_gcc_single_op() From 5ac7c60767154929b16ce7ca1076e9771945fec3 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sun, 7 Jun 2020 05:43:47 +0000 Subject: [PATCH 05/13] handle more types, all tests pass --- .../backend/contrib/codegen_c/codegen.cc | 73 ++++++++++--------- src/target/source/source_module.cc | 55 ++++++++++---- tests/python/relay/test_external_codegen.py | 20 ++--- .../python/relay/test_pass_annotate_target.py | 14 +++- .../python/relay/test_pass_partition_graph.py | 14 +++- 5 files changed, 115 insertions(+), 61 deletions(-) diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 2968966e8039..1308972a39c8 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -76,43 +76,19 @@ class CodegenC : public MemoizedExprTranslator>, public Code } std::vector VisitExpr_(const ConstantNode* cn) final { - // Note this is for demonstration purpose. ConstantNode doesn't necessarily - // belong to calls. We need to revisit this when tuples come into play. - std::ostringstream decl_stream; std::ostringstream buf_stream; Output output; - output.name = "const_" + std::to_string(const_idx_++); - - runtime::NDArray array = cn->data; - const auto& shape = array.Shape(); - - // Get the number of elements. - int64_t num_elems = 1; - for (auto i : shape) num_elems *= i; - + output.name = ext_func_id_ + "const_" + std::to_string(const_idx_++); const auto* type_node = cn->checked_type().as(); CHECK(type_node); const auto& dtype = GetDtypeString(type_node); - // Define a const buffer: float const_0[64] = {1.0, 2.0, ...}; - // - // Technically, you may need: static float* const_0 = (float*)malloc(4 * 64) - // to avoid possible stack overflow. - buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {"; - if (dtype == "float") { - float* p_flt = static_cast(array->data); - for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", "; - if (num_elems) buf_stream << p_flt[num_elems - 1]; - } else if (dtype == "int") { - int* p_flt = static_cast(array->data); - for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", "; - if (num_elems) buf_stream << p_flt[num_elems - 1]; - } else { - LOG(FATAL) << "Only float and int are supported for now."; - } - buf_stream << "};"; - ext_func_body.insert(ext_func_body.begin(), buf_stream.str()); + CHECK(dtype == "float" || dtype == "int") << "Only float and int are supported for now."; + output.dtype = dtype; + + CHECK_EQ(metadata_.count(output.name), 0U) << "variable must be unique: " << output.name; + metadata_.Set(output.name, cn->data); return {output}; } @@ -201,6 +177,8 @@ class CodegenC : public MemoizedExprTranslator>, public Code return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out); } + Map GetMetadata() const { return metadata_; } + private: /*! \brief The function id that represents a C source function. */ std::string ext_func_id_ = ""; @@ -218,11 +196,13 @@ class CodegenC : public MemoizedExprTranslator>, public Code std::vector func_decl_; /*! \brief The declaration statements of buffers. */ std::vector buf_decl_; + /*! \brief The variable name to constant mapping. */ + Map metadata_; }; class CSourceCodegen : public CSourceModuleCodegenBase { public: - void GenCFunc(const Function& func) { + std::pair> GenCFunc(const Function& func) { CHECK(func.defined()) << "Input error: expect a Relay function."; // Record the external symbol for runtime lookup. @@ -231,6 +211,8 @@ class CSourceCodegen : public CSourceModuleCodegenBase { CodegenC builder(sid); auto out = builder.VisitExpr(func->body); code_stream_ << builder.JIT(out); + + return std::make_pair(sid, builder.GetMetadata()); } runtime::Module CreateCSourceModule(const ObjectRef& ref) override { @@ -262,22 +244,41 @@ class CSourceCodegen : public CSourceModuleCodegenBase { code_stream_ << operator_macro << "\n\n"; + Map code; + Map> metadata; if (ref->IsInstance()) { - GenCFunc(Downcast(ref)); + auto ret = GenCFunc(Downcast(ref)); + String sym = std::get<0>(ret); + Map consts = std::get<1>(ret); + std::string code_str = code_stream_.str(); + if (!consts.empty()) { + code_str = "#include \"metadata.h\"\n" + code_str; + metadata.Set(sym, consts); + } + code.Set(sym, code_str); } else if (ref->IsInstance()) { IRModule mod = Downcast(ref); for (const auto& it : mod->functions) { - GenCFunc(Downcast(it.second)); + auto ret = GenCFunc(Downcast(it.second)); + Map consts = std::get<1>(ret); + if (!consts.empty()) { + metadata.Set(std::get<0>(ret), consts); + } + } + std::string code_str = code_stream_.str(); + if (!metadata.empty()) { + code_str = "#include \"metadata.h\"\n" + code_str; } + code.Set("all", code_str); } else { LOG(FATAL) << "The input ref is expected to be a Relay function or module" << "\n"; } - // Create a CSourceModule - const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); + // Create a PackagingModule + const auto* pf = runtime::Registry::Get("runtime.PackagingModuleCreate"); CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - return (*pf)(code_stream_.str(), "cc"); + return (*pf)(code, "c", metadata); } private: diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index fef16ea969d3..6e4b2a398712 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -227,27 +227,56 @@ class ModuleInitWrapper : public runtime::ModuleNode { const char* type_key() const { return "module_init"; } - std::string InitCSourceMetadata() { + template + std::string GetElements(const std::string& var_name, const std::string& type_name, + const runtime::NDArray& arr) { std::ostringstream os; os.precision(16); + // Get the number of elements. + int64_t num_elems = 1; + for (auto i : arr.Shape()) num_elems *= i; + os << "static " << type_name << " " << var_name << "[" << num_elems << "] = {"; + T* ptr = static_cast(arr->data); + for (int64_t i = 0; i < num_elems - 1; i++) { + os << ptr[i] << ","; + } + if (num_elems > 0) os << ptr[num_elems - 1]; + os << "};\n"; + return os.str(); + } + + std::string InitCSourceMetadata() { + std::string ret = ""; for (const auto& it : metadata_) { for (const auto& vars : it.second) { - String var_name = vars.first; + std::string var_name = vars.first.operator std::string(); runtime::NDArray data = vars.second; - // Get the number of elements. - int64_t num_elems = 1; - for (auto i : data.Shape()) num_elems *= i; - // TODO(zhiics) Handle different data types. - os << "static float " << var_name.c_str() << "[" << num_elems << "] = {"; - const float* ptr = static_cast(data->data); - for (int64_t i = 0; i < num_elems - 1; i++) { - os << ptr[i] << ","; + CHECK(data->dtype.lanes == 1); + if (data->dtype.code == kDLFloat) { + if (data->dtype.bits == 32) { + ret += GetElements(var_name, "float", data); + } else { + CHECK_EQ(data->dtype.bits, 64); + ret += GetElements(var_name, "double", data); + } + } else if (data->dtype.code == kDLUInt) { + if (data->dtype.bits == 8) { + ret += GetElements(var_name, "uint8_t", data); + } else { + CHECK_EQ(data->dtype.bits, 32); + ret += GetElements(var_name, "uint32_t", data); + } + } else { + if (data->dtype.bits == 8) { + ret += GetElements(var_name, "int8_t", data); + } else { + CHECK_EQ(data->dtype.bits, 32); + ret += GetElements(var_name, "int32_t", data); + } } - if (num_elems > 0) os << ptr[num_elems - 1]; - os << "};\n"; } } - return os.str(); + return ret; } void InitSubModule(const std::string& symbol) { diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 8502d227a966..197de2aed972 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -34,16 +34,17 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", return def update_lib(lib): - ext_mod = tvm.target.PackagingModule(lib.imported_modules[0]) - code = ext_mod.source - metadata = ext_mod.metadata - src_ty = ext_mod.source_type - new_lib = lib - init_mod = tvm.target.ModuleInitWrapper(metadata, src_ty) - for _, src in code.items(): - init_mod.import_module(tvm.target.CSourceModule(src)) - new_lib.import_module(init_mod) + if lib.imported_modules: + ext_mod = tvm.target.PackagingModule(lib.imported_modules[0]) + code = ext_mod.source + metadata = ext_mod.metadata + src_ty = ext_mod.source_type + + init_mod = tvm.target.ModuleInitWrapper(metadata, src_ty) + for _, src in code.items(): + init_mod.import_module(tvm.target.CSourceModule(src)) + new_lib.import_module(init_mod) test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) source_dir = os.path.join(test_dir, "..", "..", "..") @@ -260,7 +261,6 @@ def test_extern_dnnl(): f = set_external_func_attr(f, "dnnl", "dnnl_0") call = relay.Call(f, [data0, weight0, weight0]) mod = tvm.IRModule.from_expr(call) - print(mod) i_data = np.random.uniform(0, 1, ishape).astype(dtype) w_data = np.random.uniform(0, 1, w1shape).astype(dtype) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 273c27b0d05f..49ed5fa0bb2b 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -35,6 +35,18 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", return def update_lib(lib): + new_lib = lib + if lib.imported_modules: + ext_mod = tvm.target.PackagingModule(lib.imported_modules[0]) + code = ext_mod.source + metadata = ext_mod.metadata + src_ty = ext_mod.source_type + + init_mod = tvm.target.ModuleInitWrapper(metadata, src_ty) + for _, src in code.items(): + init_mod.import_module(tvm.target.CSourceModule(src)) + new_lib.import_module(init_mod) + test_dir = os.path.dirname( os.path.realpath(os.path.expanduser(__file__))) source_dir = os.path.join(test_dir, "..", "..", "..") @@ -45,7 +57,7 @@ def update_lib(lib): tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) - lib.export_library(lib_path, fcompile=False, **kwargs) + new_lib.export_library(lib_path, fcompile=False, **kwargs) lib = runtime.load_module(lib_path) return lib diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 473ca9d66106..ddf753ff03df 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -179,6 +179,18 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", return def update_lib(lib): + new_lib = lib + if lib.imported_modules: + ext_mod = tvm.target.PackagingModule(lib.imported_modules[0]) + code = ext_mod.source + metadata = ext_mod.metadata + src_ty = ext_mod.source_type + + init_mod = tvm.target.ModuleInitWrapper(metadata, src_ty) + for _, src in code.items(): + init_mod.import_module(tvm.target.CSourceModule(src)) + new_lib.import_module(init_mod) + test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) source_dir = os.path.join(test_dir, "..", "..", "..") contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") @@ -188,7 +200,7 @@ def update_lib(lib): tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) - lib.export_library(lib_path, fcompile=False, **kwargs) + new_lib.export_library(lib_path, fcompile=False, **kwargs) lib = runtime.load_module(lib_path) return lib From ce99649ce4e5d2591a103063a6d48d7a81a214f5 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sun, 7 Jun 2020 23:42:35 +0000 Subject: [PATCH 06/13] create module_init_wrapper file --- 3rdparty/dlpack | 2 +- python/tvm/runtime/__init__.py | 2 +- python/tvm/runtime/module.py | 5 + python/tvm/target/__init__.py | 2 +- .../{packaging_module.py => source_module.py} | 12 +- .../backend/contrib/codegen_c/codegen.cc | 4 +- src/relay/backend/contrib/dnnl/codegen.cc | 4 +- src/runtime/module_init_wrapper.cc | 236 ++++++++++++++++++ src/target/source/source_module.cc | 201 +-------------- tests/python/relay/test_external_codegen.py | 4 +- .../python/relay/test_pass_annotate_target.py | 4 +- .../python/relay/test_pass_partition_graph.py | 4 +- 12 files changed, 265 insertions(+), 215 deletions(-) rename python/tvm/target/{packaging_module.py => source_module.py} (85%) create mode 100644 src/runtime/module_init_wrapper.cc diff --git a/3rdparty/dlpack b/3rdparty/dlpack index 3ec04430e89a..0acb731e0e43 160000 --- a/3rdparty/dlpack +++ b/3rdparty/dlpack @@ -1 +1 @@ -Subproject commit 3ec04430e89a6834e5a1b99471f415fa939bf642 +Subproject commit 0acb731e0e43d15deee27b66f10e4c5b4e667913 diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 21c06c517bd7..8bc81935695e 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -21,7 +21,7 @@ from .object import Object from .object_generic import ObjectGeneric, ObjectTypes from .ndarray import NDArray, DataType, DataTypeCode, TVMContext -from .module import Module +from .module import Module, ModuleInitWrapper # function exposures from .object_generic import convert_to_object, convert, const diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index e5f201bfd803..201b31f8e3a3 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -33,6 +33,11 @@ ProfileResult = namedtuple("ProfileResult", ["mean", "results"]) +def ModuleInitWrapper(metadata, source_type="c"): + """Create a module initialization wrapper""""" + return _ffi_api.ModuleInitWrapper(metadata, source_type) + + class Module(object): """Runtime Module.""" __slots__ = ["handle", "_entry", "entry_name"] diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index cfc19363eb1a..56171ef3cd8e 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -61,4 +61,4 @@ from . import datatype from . import codegen from .intrin import register_intrin_rule -from .packaging_module import PackagingModule, CSourceModule, ModuleInitWrapper +from .source_module import SourceMetadataModule, CSourceModule diff --git a/python/tvm/target/packaging_module.py b/python/tvm/target/source_module.py similarity index 85% rename from python/tvm/target/packaging_module.py rename to python/tvm/target/source_module.py index 20b71a2faed3..9f73ee207394 100644 --- a/python/tvm/target/packaging_module.py +++ b/python/tvm/target/source_module.py @@ -16,18 +16,17 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin """ -APIs for a packaging module +Helper functions and classes for hanlding source and metdata. """ from tvm.runtime import _ffi_api -class PackagingModule: +class SourceMetadataModule: """The Packaging module""" def __init__(self, mod): self.mod = mod self._get_source = self.mod["get_source"] self._get_source_type = self.mod["get_source_type"] self._get_metadata = self.mod["get_metadata"] - self._is_c_source = self.mod["is_c_source"] @property def source(self): @@ -46,14 +45,9 @@ def metadata(self): def is_c_source(self): """Check if the source code is C/C++""" - return self._is_c_source() + return self.source_type == "c" or self.source_type == "cc" def CSourceModule(code, fmt="c"): """Create a C source module""" return _ffi_api.CSourceModuleCreate(code, fmt) - - -def ModuleInitWrapper(metadata, source_type="c"): - """Create a module initialization wrapper""" - return _ffi_api.ModuleInitWrapper(metadata, source_type) diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 1308972a39c8..963382cca8f6 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -275,8 +275,8 @@ class CSourceCodegen : public CSourceModuleCodegenBase { << "\n"; } - // Create a PackagingModule - const auto* pf = runtime::Registry::Get("runtime.PackagingModuleCreate"); + // Create a SourceMetadataModuleNode + const auto* pf = runtime::Registry::Get("runtime.SourceMetadataModuleCreate"); CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; return (*pf)(code, "c", metadata); } diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index fa0f5ea36569..ba403db1dd62 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -412,8 +412,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { << "\n"; } - // Create a PackagingModule - const auto* pf = runtime::Registry::Get("runtime.PackagingModuleCreate"); + // Create a SourceMetadataModuleNode + const auto* pf = runtime::Registry::Get("runtime.SourceMetadataModuleCreate"); CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; return (*pf)(code, "c", metadata); } diff --git a/src/runtime/module_init_wrapper.cc b/src/runtime/module_init_wrapper.cc new file mode 100644 index 000000000000..9d71ff01214c --- /dev/null +++ b/src/runtime/module_init_wrapper.cc @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/module_init_wrapper.cc + * \brief A wrapper for initializing modules using metadata + */ +#include +#include +#include +#include +#include +#include + +#include "file_util.h" + +namespace tvm { +namespace runtime { + +class CSourceMetadataInitializer { + public: + explicit CSourceMetadataInitializer(Map> metadata) + : metadata_(metadata) {} + + template + void GetElements(const std::string& var_name, const std::string& type_name, + const runtime::NDArray& arr) { + // Get the number of elements. + int64_t num_elems = 1; + for (auto i : arr.Shape()) num_elems *= i; + stream_ << "static " << type_name << " " << var_name << "[" << num_elems << "] = {"; + T* ptr = static_cast(arr->data); + for (int64_t i = 0; i < num_elems - 1; i++) { + stream_ << ptr[i] << ","; + } + if (num_elems > 0) stream_ << ptr[num_elems - 1]; + stream_ << "};\n"; + } + + std::string Init() { + for (const auto& it : metadata_) { + for (const auto& vars : it.second) { + std::string var_name = vars.first.operator std::string(); + runtime::NDArray data = vars.second; + CHECK(data->dtype.lanes == 1); + if (data->dtype.code == kDLFloat) { + if (data->dtype.bits == 32) { + stream_.precision(std::numeric_limits::digits10 + 1); + GetElements(var_name, "float", data); + } else { + CHECK_EQ(data->dtype.bits, 64); + stream_.precision(std::numeric_limits::digits10 + 1); + GetElements(var_name, "double", data); + } + } else if (data->dtype.code == kDLUInt) { + if (data->dtype.bits == 8) { + GetElements(var_name, "uint8_t", data); + } else { + CHECK_EQ(data->dtype.bits, 32); + GetElements(var_name, "uint32_t", data); + } + } else { + if (data->dtype.bits == 8) { + GetElements(var_name, "int8_t", data); + } else { + CHECK_EQ(data->dtype.bits, 32); + GetElements(var_name, "int32_t", data); + } + } + } + } + return stream_.str(); + } + + private: + /*! \brief The stream to print constant data. */ + std::ostringstream stream_; + /*! \brief A symbol to {var_name : NDArray} pair mapping. */ + Map> metadata_; +}; + +class ModuleInitWrapper : public runtime::ModuleNode { + public: + ModuleInitWrapper(Map> metadata, String source_type) + : metadata_(metadata), source_type_(source_type) {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + if (initialized_.count(name) == 0) { + this->InitSubModule(name); + initialized_[name] = true; + } + + if (name != "init_module" && name != "destroy_module") { + CHECK(!this->imports().empty()); + runtime::Module submodule = this->imports().at(0); + return submodule->GetFunction(name); + } + + return PackedFunc(); + } + + const char* type_key() const { return "module_init"; } + + void InitSubModule(const std::string& symbol) { + // Dispatch initializer according to the source type + if (source_type_ != "c") { + LOG(FATAL) << "Implement the initialization of json style runtime here"; + } else { + // TODO(zhiics) Handle json runtime. + // std::string initializer = "runtime.init." + source_type_; + // auto pf = tvm::runtime::Registry::Get(initializer); + // CHECK(pf) << "Failed to find the initializer for " << initializer; + } + } + + void SaveToFile(const std::string& file_name, const std::string& format) final { + // C source module relies on AOT compilation. The source code has already + // been generated. The used metadata is saved a separate file for + // compilation. + if (source_type_ == "c") { + std::string fmt = GetFileFormat(file_name, format); + CHECK_EQ(fmt, "h") << "Can only save to .h file"; + CSourceMetadataInitializer c_init(metadata_); + SaveBinaryToFile(file_name, c_init.Init()); + } + } + + void SaveToBinary(dmlc::Stream* stream) final { + stream->Write(source_type_.operator std::string()); + + // Save the total number of symbols + uint64_t sym_cnt = static_cast(metadata_.size()); + stream->Write(sym_cnt); + + for (const auto& it : metadata_) { + // Save the symbol/function name + stream->Write(it.first.operator std::string()); + + std::vector variables; + std::vector metadata; + for (const auto& vit : it.second) { + String var_name = vit.first; + variables.push_back(var_name.operator std::string()); + metadata.push_back(vit.second); + } + + // Save all variables in the function. + stream->Write(variables); + // Save all constant data + uint64_t sz = static_cast(metadata.size()); + stream->Write(sz); + for (uint64_t i = 0; i < sz; i++) { + metadata[i].Save(stream); + } + } + } + + static runtime::Module LoadFromBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::string source_type; + CHECK(stream->Read(&source_type)) << "Loading source type failed"; + + Map> metadata; + + uint64_t sym_cnt; + CHECK(stream->Read(&sym_cnt, sizeof(sym_cnt))) << "Loading the number of symbols failed"; + + for (uint64_t i = 0; i < sym_cnt; i++) { + std::string sym; + CHECK(stream->Read(&sym)) << "Loading symbol name failed"; + // Load variable and ndarray pairs + std::vector variables; + std::vector arrays; + CHECK(stream->Read(&variables)) << "Loading variables failed"; + uint64_t sz; + CHECK(stream->Read(&sz, sizeof(sz))) << "Loading medata size failed"; + CHECK_EQ(static_cast(sz), variables.size()) + << "The number of variables and ndarray counts must match"; + for (uint64_t i = 0; i < sz; i++) { + tvm::runtime::NDArray temp; + temp.Load(stream); + arrays.push_back(temp); + } + Map var_const; + for (size_t i = 0; i < variables.size(); i++) { + var_const.Set(variables[i], arrays[i]); + } + metadata.Set(sym, var_const); + } + auto n = runtime::make_object(metadata, source_type); + return runtime::Module(n); + } + + private: + /*! + * \brief Record if a module is initialized. It is needed by imported + * modules using execution engine. + */ + std::unordered_map initialized_; + /*! \brief A symbol to {var_name : NDArray} pair mapping. */ + Map> metadata_; + /*! \brief The type of the source, i.e. c, or any customized json */ + String source_type_; +}; + +runtime::Module ModuleInitWrapperCreate(Map> metadata, + String source_type) { + auto n = make_object(metadata, source_type); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.ModuleInitWrapper") + .set_body_typed([](Map> metadata, String source_type) { + return ModuleInitWrapperCreate(metadata, source_type); + }); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_module_init") + .set_body_typed(ModuleInitWrapper::LoadFromBinary); +} // namespace runtime +} // namespace tvm diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 6e4b2a398712..445d3744d538 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include "../../runtime/file_util.h" #include "../../runtime/meta_data.h" @@ -156,10 +155,10 @@ runtime::Module DeviceSourceModuleCreate( // Pack the source code and metadata, where source code could be any // user-defined code, i.e. c source code, json graph representation, etc. -class PackagingModule final : public runtime::ModuleNode { +class SourceMetadataModuleNode final : public runtime::ModuleNode { public: - PackagingModule(Map code, const std::string& source_type, - Map> metadata) + SourceMetadataModuleNode(Map code, const std::string& source_type, + Map> metadata) : code_(code), source_type_(source_type), metadata_(metadata) {} PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { @@ -171,17 +170,12 @@ class PackagingModule final : public runtime::ModuleNode { } else if (name == "get_metadata") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->metadata_; }); - } else if (name == "is_c_source") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->IsCSourceCode(); }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc(nullptr); } } - bool IsCSourceCode() { return source_type_ == "c" || source_type_ == "cc"; } - const char* type_key() const { return "c"; } void SaveToFile(const std::string& file_name, const std::string& format) final { @@ -199,186 +193,14 @@ class PackagingModule final : public runtime::ModuleNode { Map> metadata_; }; -runtime::Module PackagingModuleCreate(Map code, std::string source_type, - Map> metadata) { - auto n = make_object(code, source_type, metadata); - return runtime::Module(n); -} - -class ModuleInitWrapper : public runtime::ModuleNode { - public: - ModuleInitWrapper(Map> metadata, String source_type) - : metadata_(metadata), source_type_(source_type) {} - - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - if (initialized_.count(name) == 0) { - this->InitSubModule(name); - initialized_[name] = true; - } - - if (name != "init_module" && name != "destroy_module") { - CHECK(!this->imports().empty()); - runtime::Module submodule = this->imports().at(0); - return submodule->GetFunction(name); - } - - return PackedFunc(); - } - - const char* type_key() const { return "module_init"; } - - template - std::string GetElements(const std::string& var_name, const std::string& type_name, - const runtime::NDArray& arr) { - std::ostringstream os; - os.precision(16); - // Get the number of elements. - int64_t num_elems = 1; - for (auto i : arr.Shape()) num_elems *= i; - os << "static " << type_name << " " << var_name << "[" << num_elems << "] = {"; - T* ptr = static_cast(arr->data); - for (int64_t i = 0; i < num_elems - 1; i++) { - os << ptr[i] << ","; - } - if (num_elems > 0) os << ptr[num_elems - 1]; - os << "};\n"; - return os.str(); - } - - std::string InitCSourceMetadata() { - std::string ret = ""; - for (const auto& it : metadata_) { - for (const auto& vars : it.second) { - std::string var_name = vars.first.operator std::string(); - runtime::NDArray data = vars.second; - CHECK(data->dtype.lanes == 1); - if (data->dtype.code == kDLFloat) { - if (data->dtype.bits == 32) { - ret += GetElements(var_name, "float", data); - } else { - CHECK_EQ(data->dtype.bits, 64); - ret += GetElements(var_name, "double", data); - } - } else if (data->dtype.code == kDLUInt) { - if (data->dtype.bits == 8) { - ret += GetElements(var_name, "uint8_t", data); - } else { - CHECK_EQ(data->dtype.bits, 32); - ret += GetElements(var_name, "uint32_t", data); - } - } else { - if (data->dtype.bits == 8) { - ret += GetElements(var_name, "int8_t", data); - } else { - CHECK_EQ(data->dtype.bits, 32); - ret += GetElements(var_name, "int32_t", data); - } - } - } - } - return ret; - } - - void InitSubModule(const std::string& symbol) { - // Dispatch initializer according to the source type - // std::string initializer = "runtime.init." + source_type_; - // auto pf = tvm::runtime::Registry::Get(initializer); - - // CHECK(pf) << "Failed to find the initializer for " << initializer; - if (source_type_ != "c") { - LOG(FATAL) << "Implement the initialization of json style runtime here"; - } - } - - void SaveToFile(const std::string& file_name, const std::string& format) final { - if (source_type_ == "c") { - std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, "h") << "Can only save to .h file"; - SaveBinaryToFile(file_name, InitCSourceMetadata()); - } - } - - void SaveToBinary(dmlc::Stream* stream) final { - stream->Write(source_type_.operator std::string()); - - // Save the total number of symbols - uint64_t sym_cnt = static_cast(metadata_.size()); - stream->Write(sym_cnt); - - for (const auto& it : metadata_) { - // Save the symbol/function name - stream->Write(it.first.operator std::string()); - - std::vector variables; - std::vector metadata; - for (const auto& vit : it.second) { - String var_name = vit.first; - variables.push_back(var_name.operator std::string()); - metadata.push_back(vit.second); - } - - // Save all variables in the function. - stream->Write(variables); - // Save all constant data - uint64_t sz = static_cast(metadata.size()); - stream->Write(sz); - for (uint64_t i = 0; i < sz; i++) { - metadata[i].Save(stream); - } - } - } - - static runtime::Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); - std::string source_type; - CHECK(stream->Read(&source_type)) << "Loading source type failed"; - - Map> metadata; - - uint64_t sym_cnt; - CHECK(stream->Read(&sym_cnt, sizeof(sym_cnt))) << "Loading the number of symbols failed"; - - for (uint64_t i = 0; i < sym_cnt; i++) { - std::string sym; - CHECK(stream->Read(&sym)) << "Loading symbol name failed"; - // Load variable and ndarray pairs - std::vector variables; - std::vector arrays; - CHECK(stream->Read(&variables)) << "Loading variables failed"; - uint64_t sz; - CHECK(stream->Read(&sz, sizeof(sz))) << "Loading medata size failed"; - CHECK_EQ(static_cast(sz), variables.size()) - << "The number of variables and ndarray counts must match"; - for (uint64_t i = 0; i < sz; i++) { - tvm::runtime::NDArray temp; - temp.Load(stream); - arrays.push_back(temp); - } - Map var_const; - for (size_t i = 0; i < variables.size(); i++) { - var_const.Set(variables[i], arrays[i]); - } - metadata.Set(sym, var_const); - } - auto n = runtime::make_object(metadata, source_type); - return runtime::Module(n); - } - - private: - std::unordered_map initialized_; - /*! \brief A symbol to {var_name : NDArray} pair mapping. */ - Map> metadata_; - /*! \brief The type of the source, i.e. c, or any customized json */ - String source_type_; -}; - -runtime::Module ModuleInitWrapperCreate(Map> metadata, - String source_type) { - auto n = make_object(metadata, source_type); +runtime::Module SourceMetadataModuleCreate(Map code, std::string source_type, + Map> metadata) { + auto n = make_object(code, source_type, metadata); return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.PackagingModuleCreate").set_body_typed(PackagingModuleCreate); +TVM_REGISTER_GLOBAL("runtime.SourceMetadataModuleCreate") + .set_body_typed(SourceMetadataModuleCreate); TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); @@ -387,12 +209,5 @@ TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") return CSourceModuleCreate(code.operator std::string(), source_type.operator std::string()); }); -TVM_REGISTER_GLOBAL("runtime.ModuleInitWrapper") - .set_body_typed([](Map> metadata, String source_type) { - return ModuleInitWrapperCreate(metadata, source_type); - }); - -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_module_init") - .set_body_typed(ModuleInitWrapper::LoadFromBinary); } // namespace codegen } // namespace tvm diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 197de2aed972..c88372e04cb5 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -36,12 +36,12 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", def update_lib(lib): new_lib = lib if lib.imported_modules: - ext_mod = tvm.target.PackagingModule(lib.imported_modules[0]) + ext_mod = tvm.target.SourceMetadataModule(lib.imported_modules[0]) code = ext_mod.source metadata = ext_mod.metadata src_ty = ext_mod.source_type - init_mod = tvm.target.ModuleInitWrapper(metadata, src_ty) + init_mod = runtime.ModuleInitWrapper(metadata, src_ty) for _, src in code.items(): init_mod.import_module(tvm.target.CSourceModule(src)) new_lib.import_module(init_mod) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 49ed5fa0bb2b..376c78b7cce6 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -37,12 +37,12 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", def update_lib(lib): new_lib = lib if lib.imported_modules: - ext_mod = tvm.target.PackagingModule(lib.imported_modules[0]) + ext_mod = tvm.target.SourceMetadataModule(lib.imported_modules[0]) code = ext_mod.source metadata = ext_mod.metadata src_ty = ext_mod.source_type - init_mod = tvm.target.ModuleInitWrapper(metadata, src_ty) + init_mod = runtime.ModuleInitWrapper(metadata, src_ty) for _, src in code.items(): init_mod.import_module(tvm.target.CSourceModule(src)) new_lib.import_module(init_mod) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index ddf753ff03df..2ae7d8943863 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -181,12 +181,12 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", def update_lib(lib): new_lib = lib if lib.imported_modules: - ext_mod = tvm.target.PackagingModule(lib.imported_modules[0]) + ext_mod = tvm.target.SourceMetadataModule(lib.imported_modules[0]) code = ext_mod.source metadata = ext_mod.metadata src_ty = ext_mod.source_type - init_mod = tvm.target.ModuleInitWrapper(metadata, src_ty) + init_mod = runtime.ModuleInitWrapper(metadata, src_ty) for _, src in code.items(): init_mod.import_module(tvm.target.CSourceModule(src)) new_lib.import_module(init_mod) From 0131bf1cc46a40090bcae53f25234847dd0bfac7 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 9 Jun 2020 16:43:30 +0000 Subject: [PATCH 07/13] module wrapper --- 3rdparty/dlpack | 2 +- python/tvm/runtime/module.py | 38 ++++++++++++++----- python/tvm/target/source_module.py | 4 +- src/relay/backend/build_module.cc | 18 ++++++--- src/relay/backend/vm/compiler.cc | 13 +++++-- src/target/source/codegen_source_base.h | 6 +++ src/target/source/source_module.cc | 29 ++++++++++---- tests/python/relay/test_external_codegen.py | 7 ++-- .../python/relay/test_pass_annotate_target.py | 6 +-- .../python/relay/test_pass_partition_graph.py | 6 +-- 10 files changed, 91 insertions(+), 38 deletions(-) diff --git a/3rdparty/dlpack b/3rdparty/dlpack index 0acb731e0e43..3ec04430e89a 160000 --- a/3rdparty/dlpack +++ b/3rdparty/dlpack @@ -1 +1 @@ -Subproject commit 0acb731e0e43d15deee27b66f10e4c5b4e667913 +Subproject commit 3ec04430e89a6834e5a1b99471f415fa939bf642 diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 201b31f8e3a3..c2cbb7b5a5f0 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -232,7 +232,7 @@ def _collect_dso_metadata_modules(self): Helper function to collect dso modules and metadata init module. There is at most one medata init module if it exists. """ - visited, stack, dso_modules, medata_init = set(), [], [], [] + visited, stack, dso_modules, metadata_init = set(), [], [], [] # append root module visited.add(self) stack.append(self) @@ -241,18 +241,34 @@ def _collect_dso_metadata_modules(self): if module._dso_exportable(): dso_modules.append(module) elif module.type_key == "module_init": - medata_init.append(module) + metadata_init.append(module) for m in module.imported_modules: if m not in visited: visited.add(m) stack.append(m) - assert len(medata_init) <= 1, "At most one metadata init module is allowed." - return dso_modules, medata_init + return dso_modules, metadata_init def _dso_exportable(self): return self.type_key == "llvm" or self.type_key == "c" + def unwrap_modules(self): + """Unwrap the host and source metadata modules. + + Returns + ------- + ret : Tuple(runtime.Module, List[runtime.Module]) + The host module and a list of source metadata module pair. + """ + if not self.type_key == "module_class_wrapper": + return (self, None) + + assert len(self.imported_modules) > 1, \ + "Expect both host and source metadata module" + host_mod = self.imported_modules[0] + source_metadata_mods = self.imported_modules[1:] + return (host_mod, source_metadata_mods) + def export_library(self, file_name, fcompile=None, @@ -318,12 +334,14 @@ def export_library(self, llvm_target_triple = (module.type_key == "llvm" and module.get_function("_get_target_triple")()) - metadata_import = None if not metadata_init else metadata_init[0].imported_modules - if metadata_import and metadata_import[0].type_key == "c": - module = metadata_init[0] - header = temp.relpath("metadata.h") - module.save(header) - files.append(header) + for m in metadata_init: + metadata_import = m.imported_modules + assert len(metadata_import) == 1, \ + "A module should be wrapped in the initialization module." + if metadata_import[0].type_key == "c": + header = temp.relpath("metadata.h") + m.save(header) + files.append(header) if not fcompile: if file_name.endswith(".tar"): diff --git a/python/tvm/target/source_module.py b/python/tvm/target/source_module.py index 9f73ee207394..9d4c68715416 100644 --- a/python/tvm/target/source_module.py +++ b/python/tvm/target/source_module.py @@ -16,12 +16,12 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin """ -Helper functions and classes for hanlding source and metdata. +Helper functions and classes for handling source and metdata. """ from tvm.runtime import _ffi_api class SourceMetadataModule: - """The Packaging module""" + """The module used to wrap both source and metadata.""" def __init__(self, mod): self.mod = mod self._get_source = self.mod["get_source"] diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index f9ce24d410b7..55e438f928e6 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -429,6 +429,8 @@ class RelayBuildModule : public runtime::ModuleNode { auto lowered_funcs = graph_codegen_->GetIRModule(); + runtime::Module tvm_lib; + // When there is no lowered_funcs due to reasons such as optimization. if (lowered_funcs.size() == 0) { Target target_host = GetTargetHost(); @@ -441,20 +443,26 @@ class RelayBuildModule : public runtime::ModuleNode { if (target_host.defined() && target_host->target_name == "llvm") { // If we can decide the target is LLVM, we then create an empty LLVM module. - ret_.mod = (*pf)(target_host->str(), "empty_module"); + tvm_lib = (*pf)(target_host->str(), "empty_module"); } else { // If we cannot decide the target is LLVM, we create an empty CSourceModule. // The code content is initialized with ";" to prevent complaining // from CSourceModuleNode::SaveToFile. - ret_.mod = tvm::codegen::CSourceModuleCreate(";", ""); + tvm_lib = tvm::codegen::CSourceModuleCreate(";", ""); } } else { - ret_.mod = tvm::build(lowered_funcs, target_host_); + tvm_lib = tvm::build(lowered_funcs, target_host_); } Array ext_mods = graph_codegen_->GetExternalModules(); - // Import all external runtime modules. - for (const auto& it : ext_mods) ret_.mod.Import(it); + if (!ext_mods.empty()) { + ret_.mod = tvm::codegen::ModuleClassWrapperCreate(); + ret_.mod.Import(tvm_lib); + // Import all external runtime modules. + for (const auto& it : ext_mods) ret_.mod.Import(it); + } else { + ret_.mod = std::move(tvm_lib); + } } private: diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 4d90c47ddf7f..b371a873ecb2 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1020,11 +1020,16 @@ void VMCompiler::Codegen() { // to make sure a DSO module will be also available. mod = codegen::CSourceModuleCreate(";", ""); } - // Import all external runtime modules. - for (auto it : ext_mods) { - mod.Import(it); + if (!ext_mods.empty()) { + exec_->lib = codegen::ModuleClassWrapperCreate(); + exec_->lib.Import(mod); + // Import all external runtime modules. + for (auto it : ext_mods) { + exec_->lib.Import(it); + } + } else { + exec_->lib = mod; } - exec_->lib = mod; } runtime::Module CreateVMCompiler() { diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 39016590abdc..316a7e5bd2f0 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -139,6 +139,12 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt); */ runtime::Module CSourceModuleCreate(std::string code, std::string fmt); +/*! + * \brief Create a helper module to wrap different modules. + * \return The created module. + */ +runtime::Module ModuleClassWrapperCreate(); + /*! * \brief Create a source module for viewing and limited saving for device. * \param data The code data to be viewed. diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 445d3744d538..60b4905548e5 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -153,6 +153,23 @@ runtime::Module DeviceSourceModuleCreate( return runtime::Module(n); } +// A helper used to wrap different types of modules and pass through packedfunc. +// This module will never be used for compilation and execution. +class ModuleClassWrapperNode : public runtime::ModuleNode { + public: + ModuleClassWrapperNode() = default; + const char* type_key() const { return "module_class_wrapper"; } + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + LOG(FATAL) << "Cannot execute module wrapper"; + return PackedFunc(); + } +}; + +runtime::Module ModuleClassWrapperCreate() { + auto n = make_object(); + return runtime::Module(n); +} + // Pack the source code and metadata, where source code could be any // user-defined code, i.e. c source code, json graph representation, etc. class SourceMetadataModuleNode final : public runtime::ModuleNode { @@ -176,13 +193,7 @@ class SourceMetadataModuleNode final : public runtime::ModuleNode { } } - const char* type_key() const { return "c"; } - - void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string source_type = GetFileFormat(file_name, format); - CHECK_EQ(source_type, "cc") << "file_name: " << file_name << " must be a .cc file."; - SaveBinaryToFile(file_name, ";"); - } + const char* type_key() const { return "source_metadata"; } private: /*! \brief Symbol to source (e.g. c source/json) mapping. */ @@ -204,6 +215,10 @@ TVM_REGISTER_GLOBAL("runtime.SourceMetadataModuleCreate") TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); +TVM_REGISTER_GLOBAL("runtime.ModuleClassWrapperCreate").set_body_typed([]() { + return ModuleClassWrapperCreate(); +}); + TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") .set_body_typed([](String code, String source_type) { return CSourceModuleCreate(code.operator std::string(), source_type.operator std::string()); diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index c88372e04cb5..6441ced18c3c 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -34,9 +34,9 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", return def update_lib(lib): - new_lib = lib - if lib.imported_modules: - ext_mod = tvm.target.SourceMetadataModule(lib.imported_modules[0]) + new_lib, source_metadata_mods = lib.unwrap_modules() + if source_metadata_mods: + ext_mod = tvm.target.SourceMetadataModule(source_metadata_mods[0]) code = ext_mod.source metadata = ext_mod.metadata src_ty = ext_mod.source_type @@ -319,3 +319,4 @@ def test_extern_dnnl_const(): test_extern_gcc_single_op_int() test_extern_gcc() test_extern_dnnl() + test_extern_dnnl_const() diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 376c78b7cce6..7898ae14b226 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -35,9 +35,9 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", return def update_lib(lib): - new_lib = lib - if lib.imported_modules: - ext_mod = tvm.target.SourceMetadataModule(lib.imported_modules[0]) + new_lib, source_metadata_mods = lib.unwrap_modules() + if source_metadata_mods: + ext_mod = tvm.target.SourceMetadataModule(source_metadata_mods[0]) code = ext_mod.source metadata = ext_mod.metadata src_ty = ext_mod.source_type diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 2ae7d8943863..bb33d1042a6e 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -179,9 +179,9 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", return def update_lib(lib): - new_lib = lib - if lib.imported_modules: - ext_mod = tvm.target.SourceMetadataModule(lib.imported_modules[0]) + new_lib, source_metadata_mods = lib.unwrap_modules() + if source_metadata_mods: + ext_mod = tvm.target.SourceMetadataModule(source_metadata_mods[0]) code = ext_mod.source metadata = ext_mod.metadata src_ty = ext_mod.source_type From 4bc791e293615264e272dcf0e17b34cb5080fd73 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 11 Jun 2020 15:17:01 +0000 Subject: [PATCH 08/13] map->array --- python/tvm/runtime/module.py | 41 ++-- python/tvm/target/source_module.py | 13 ++ .../backend/contrib/codegen_c/codegen.cc | 33 +-- src/relay/backend/contrib/dnnl/codegen.cc | 33 +-- src/runtime/module_init_wrapper.cc | 216 +++++++++--------- src/target/source/source_module.cc | 38 ++- tests/python/relay/test_external_codegen.py | 7 +- .../python/relay/test_pass_annotate_target.py | 7 +- .../python/relay/test_pass_partition_graph.py | 7 +- 9 files changed, 222 insertions(+), 173 deletions(-) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index c2cbb7b5a5f0..b3ee7b0cf938 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -33,9 +33,23 @@ ProfileResult = namedtuple("ProfileResult", ["mean", "results"]) -def ModuleInitWrapper(metadata, source_type="c"): - """Create a module initialization wrapper""""" - return _ffi_api.ModuleInitWrapper(metadata, source_type) +def ModuleInitWrapper(variables, metadata): + """Create a module initialization wrapper. + + Parameters + ---------- + variables : List[Str] + The list of variables. + + metadata : List[runtime.NDArray] + The list of used NDArray. + + Returns + ------- + ret : runtime.Module + The created module wrapper for initialization + """"" + return _ffi_api.ModuleInitWrapper(variables, metadata) class Module(object): @@ -232,7 +246,7 @@ def _collect_dso_metadata_modules(self): Helper function to collect dso modules and metadata init module. There is at most one medata init module if it exists. """ - visited, stack, dso_modules, metadata_init = set(), [], [], [] + visited, stack, dso_modules, metadata_init = set(), [], [], None # append root module visited.add(self) stack.append(self) @@ -241,7 +255,9 @@ def _collect_dso_metadata_modules(self): if module._dso_exportable(): dso_modules.append(module) elif module.type_key == "module_init": - metadata_init.append(module) + assert not metadata_init, \ + "At most one module initializer is allowed" + metadata_init = module for m in module.imported_modules: if m not in visited: visited.add(m) @@ -334,14 +350,13 @@ def export_library(self, llvm_target_triple = (module.type_key == "llvm" and module.get_function("_get_target_triple")()) - for m in metadata_init: - metadata_import = m.imported_modules - assert len(metadata_import) == 1, \ - "A module should be wrapped in the initialization module." - if metadata_import[0].type_key == "c": - header = temp.relpath("metadata.h") - m.save(header) - files.append(header) + if metadata_init: + for m in metadata_init.imported_modules: + if m.type_key == "c": + header = temp.relpath("metadata.h") + metadata_init.save(header) + files.append(header) + break if not fcompile: if file_name.endswith(".tar"): diff --git a/python/tvm/target/source_module.py b/python/tvm/target/source_module.py index 9d4c68715416..a4acb74138b6 100644 --- a/python/tvm/target/source_module.py +++ b/python/tvm/target/source_module.py @@ -25,9 +25,16 @@ class SourceMetadataModule: def __init__(self, mod): self.mod = mod self._get_source = self.mod["get_source"] + self._get_symbol = self.mod["get_symbol"] self._get_source_type = self.mod["get_source_type"] + self._get_variables = self.mod["get_vars"] self._get_metadata = self.mod["get_metadata"] + @property + def symbol(self): + """Get the source""" + return self._get_symbol() + @property def source(self): """Get the source""" @@ -43,6 +50,12 @@ def metadata(self): """Get the metadata""" return self._get_metadata() + @property + def variables(self): + """Get the metadata""" + return self._get_variables() + + def is_c_source(self): """Check if the source code is C/C++""" return self.source_type == "c" or self.source_type == "cc" diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 963382cca8f6..7650ff636069 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -202,7 +202,7 @@ class CodegenC : public MemoizedExprTranslator>, public Code class CSourceCodegen : public CSourceModuleCodegenBase { public: - std::pair> GenCFunc(const Function& func) { + Map GenCFunc(const Function& func) { CHECK(func.defined()) << "Input error: expect a Relay function."; // Record the external symbol for runtime lookup. @@ -212,7 +212,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase { auto out = builder.VisitExpr(func->body); code_stream_ << builder.JIT(out); - return std::make_pair(sid, builder.GetMetadata()); + return builder.GetMetadata(); } runtime::Module CreateCSourceModule(const ObjectRef& ref) override { @@ -244,32 +244,37 @@ class CSourceCodegen : public CSourceModuleCodegenBase { code_stream_ << operator_macro << "\n\n"; - Map code; - Map> metadata; + String func_symbol("all"); + String code; + Array variables; + Array metadata; if (ref->IsInstance()) { - auto ret = GenCFunc(Downcast(ref)); - String sym = std::get<0>(ret); - Map consts = std::get<1>(ret); + Map consts = GenCFunc(Downcast(ref)); std::string code_str = code_stream_.str(); if (!consts.empty()) { code_str = "#include \"metadata.h\"\n" + code_str; - metadata.Set(sym, consts); + for (const auto& it : consts) { + variables.push_back(it.first); + metadata.push_back(it.second); + } } - code.Set(sym, code_str); + code = code_str; } else if (ref->IsInstance()) { IRModule mod = Downcast(ref); for (const auto& it : mod->functions) { - auto ret = GenCFunc(Downcast(it.second)); - Map consts = std::get<1>(ret); + Map consts = GenCFunc(Downcast(it.second)); if (!consts.empty()) { - metadata.Set(std::get<0>(ret), consts); + for (const auto& it : consts) { + variables.push_back(it.first); + metadata.push_back(it.second); + } } } std::string code_str = code_stream_.str(); if (!metadata.empty()) { code_str = "#include \"metadata.h\"\n" + code_str; } - code.Set("all", code_str); + code = code_str; } else { LOG(FATAL) << "The input ref is expected to be a Relay function or module" << "\n"; @@ -278,7 +283,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase { // Create a SourceMetadataModuleNode const auto* pf = runtime::Registry::Get("runtime.SourceMetadataModuleCreate"); CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - return (*pf)(code, "c", metadata); + return (*pf)(func_symbol, code, "c", variables, metadata); } private: diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index ba403db1dd62..13206ce42868 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -342,7 +342,7 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C class DNNLModuleCodegen : public CSourceModuleCodegenBase { public: // Create a corresponding DNNL function for the given relay Function. - std::pair> GenDNNLFunc(const Function& func) { + Map GenDNNLFunc(const Function& func) { CHECK(func.defined()) << "Input error: expect a Relay function."; // Record the external symbol for runtime lookup. @@ -352,7 +352,7 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { auto out = builder.VisitExpr(func->body); code_stream_ << builder.JIT(out); - return std::make_pair(sid, builder.GetMetadata()); + return builder.GetMetadata(); } /*! @@ -381,32 +381,37 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { code_stream_ << "using namespace tvm::runtime::contrib;\n"; code_stream_ << "\n"; - Map code; - Map> metadata; + String func_symbol("all"); + String code; + Array variables; + Array metadata; if (ref->IsInstance()) { - auto ret = GenDNNLFunc(Downcast(ref)); - String sym = std::get<0>(ret); - Map consts = std::get<1>(ret); + Map consts = GenDNNLFunc(Downcast(ref)); std::string code_str = code_stream_.str(); if (!consts.empty()) { code_str = "#include \"metadata.h\"\n" + code_str; - metadata.Set(sym, consts); + for (const auto& it : consts) { + variables.push_back(it.first); + metadata.push_back(it.second); + } } - code.Set(sym, code_str); + code = code_str; } else if (ref->IsInstance()) { IRModule mod = Downcast(ref); for (const auto& it : mod->functions) { - auto ret = GenDNNLFunc(Downcast(it.second)); - Map consts = std::get<1>(ret); + Map consts = GenDNNLFunc(Downcast(it.second)); if (!consts.empty()) { - metadata.Set(std::get<0>(ret), consts); + for (const auto& it : consts) { + variables.push_back(it.first); + metadata.push_back(it.second); + } } } std::string code_str = code_stream_.str(); if (!metadata.empty()) { code_str = "#include \"metadata.h\"\n" + code_str; } - code.Set("all", code_str); + code = code_str; } else { LOG(FATAL) << "The input ref is expected to be a Relay function or module" << "\n"; @@ -415,7 +420,7 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { // Create a SourceMetadataModuleNode const auto* pf = runtime::Registry::Get("runtime.SourceMetadataModuleCreate"); CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - return (*pf)(code, "c", metadata); + return (*pf)(func_symbol, code, "c", variables, metadata); } private: diff --git a/src/runtime/module_init_wrapper.cc b/src/runtime/module_init_wrapper.cc index 9d71ff01214c..a3ab635c22b6 100644 --- a/src/runtime/module_init_wrapper.cc +++ b/src/runtime/module_init_wrapper.cc @@ -25,6 +25,7 @@ #include #include #include + #include #include @@ -33,10 +34,11 @@ namespace tvm { namespace runtime { +using StringNDArrayMap = std::unordered_map; + class CSourceMetadataInitializer { public: - explicit CSourceMetadataInitializer(Map> metadata) - : metadata_(metadata) {} + explicit CSourceMetadataInitializer(const StringNDArrayMap& metadata) : metadata_(metadata) {} template void GetElements(const std::string& var_name, const std::string& type_name, @@ -55,33 +57,31 @@ class CSourceMetadataInitializer { std::string Init() { for (const auto& it : metadata_) { - for (const auto& vars : it.second) { - std::string var_name = vars.first.operator std::string(); - runtime::NDArray data = vars.second; - CHECK(data->dtype.lanes == 1); - if (data->dtype.code == kDLFloat) { - if (data->dtype.bits == 32) { - stream_.precision(std::numeric_limits::digits10 + 1); - GetElements(var_name, "float", data); - } else { - CHECK_EQ(data->dtype.bits, 64); - stream_.precision(std::numeric_limits::digits10 + 1); - GetElements(var_name, "double", data); - } - } else if (data->dtype.code == kDLUInt) { - if (data->dtype.bits == 8) { - GetElements(var_name, "uint8_t", data); - } else { - CHECK_EQ(data->dtype.bits, 32); - GetElements(var_name, "uint32_t", data); - } + std::string var_name = it.first.operator std::string(); + runtime::NDArray data = it.second; + CHECK_EQ(data->dtype.lanes, 1U); + if (data->dtype.code == kDLFloat) { + if (data->dtype.bits == 32) { + stream_.precision(std::numeric_limits::digits10 + 1); + GetElements(var_name, "float", data); + } else { + CHECK_EQ(data->dtype.bits, 64); + stream_.precision(std::numeric_limits::digits10 + 1); + GetElements(var_name, "double", data); + } + } else if (data->dtype.code == kDLUInt) { + if (data->dtype.bits == 8) { + GetElements(var_name, "uint8_t", data); + } else { + CHECK_EQ(data->dtype.bits, 32); + GetElements(var_name, "uint32_t", data); + } + } else { + if (data->dtype.bits == 8) { + GetElements(var_name, "int8_t", data); } else { - if (data->dtype.bits == 8) { - GetElements(var_name, "int8_t", data); - } else { - CHECK_EQ(data->dtype.bits, 32); - GetElements(var_name, "int32_t", data); - } + CHECK_EQ(data->dtype.bits, 32); + GetElements(var_name, "int32_t", data); } } } @@ -91,14 +91,18 @@ class CSourceMetadataInitializer { private: /*! \brief The stream to print constant data. */ std::ostringstream stream_; - /*! \brief A symbol to {var_name : NDArray} pair mapping. */ - Map> metadata_; + /*! \brief variable name to NDArray mapping. */ + StringNDArrayMap metadata_; }; class ModuleInitWrapper : public runtime::ModuleNode { public: - ModuleInitWrapper(Map> metadata, String source_type) - : metadata_(metadata), source_type_(source_type) {} + ModuleInitWrapper(const Array& variables, const Array& metadata) { + CHECK_EQ(variables.size(), metadata.size()); + for (size_t i = 0; i < variables.size(); i++) { + metadata_[variables[i]] = metadata[i]; + } + } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (initialized_.count(name) == 0) { @@ -106,7 +110,7 @@ class ModuleInitWrapper : public runtime::ModuleNode { initialized_[name] = true; } - if (name != "init_module" && name != "destroy_module") { + if (name != "__InitModule") { CHECK(!this->imports().empty()); runtime::Module submodule = this->imports().at(0); return submodule->GetFunction(name); @@ -117,93 +121,93 @@ class ModuleInitWrapper : public runtime::ModuleNode { const char* type_key() const { return "module_init"; } + /*! + * \brief Initialize each imported module. + * \param symobl The symbol used for initializing a module. It is also used + * for runtime lookup. + * + * \note A module could be like the following: + * ModuleInitWrapper (contains all the metadata) + * - CSourceModule + * - JSON runtime module + * + * The initializer iterates through the wrapped module and intilizes them + * accordingly by passing the needed metadata into it. + */ void InitSubModule(const std::string& symbol) { // Dispatch initializer according to the source type - if (source_type_ != "c") { - LOG(FATAL) << "Implement the initialization of json style runtime here"; - } else { - // TODO(zhiics) Handle json runtime. - // std::string initializer = "runtime.init." + source_type_; - // auto pf = tvm::runtime::Registry::Get(initializer); - // CHECK(pf) << "Failed to find the initializer for " << initializer; - } + // TODO(zhiics) iterate through the imported modules to initialize + // for (const auto& it : this->imports()) { + // } } void SaveToFile(const std::string& file_name, const std::string& format) final { // C source module relies on AOT compilation. The source code has already // been generated. The used metadata is saved a separate file for // compilation. - if (source_type_ == "c") { + std::string consts = ""; + for (auto& it : this->imports()) { + if (!std::strcmp(it->type_key(), "c")) { + // TODO(zhiics) Maybe we need to store the list of required + // variales in the CSourceModule so that we can validate the + // existence of the variable and feed it only with the required + // ones. + CSourceMetadataInitializer c_init(metadata_); + consts += c_init.Init(); + consts += "\n"; + } + } + if (consts != "") { std::string fmt = GetFileFormat(file_name, format); CHECK_EQ(fmt, "h") << "Can only save to .h file"; - CSourceMetadataInitializer c_init(metadata_); - SaveBinaryToFile(file_name, c_init.Init()); + SaveBinaryToFile(file_name, consts); } } void SaveToBinary(dmlc::Stream* stream) final { - stream->Write(source_type_.operator std::string()); - - // Save the total number of symbols - uint64_t sym_cnt = static_cast(metadata_.size()); - stream->Write(sym_cnt); - + std::vector variables; + std::vector metadata; for (const auto& it : metadata_) { - // Save the symbol/function name - stream->Write(it.first.operator std::string()); - - std::vector variables; - std::vector metadata; - for (const auto& vit : it.second) { - String var_name = vit.first; - variables.push_back(var_name.operator std::string()); - metadata.push_back(vit.second); - } + String var_name = it.first; + variables.push_back(var_name.operator std::string()); + metadata.push_back(it.second); + } - // Save all variables in the function. - stream->Write(variables); - // Save all constant data - uint64_t sz = static_cast(metadata.size()); - stream->Write(sz); - for (uint64_t i = 0; i < sz; i++) { - metadata[i].Save(stream); - } + // Save all variables in the function. + stream->Write(variables); + // Save all constant data. + uint64_t sz = static_cast(metadata.size()); + stream->Write(sz); + for (uint64_t i = 0; i < sz; i++) { + metadata[i].Save(stream); } } static runtime::Module LoadFromBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); - std::string source_type; - CHECK(stream->Read(&source_type)) << "Loading source type failed"; - - Map> metadata; - - uint64_t sym_cnt; - CHECK(stream->Read(&sym_cnt, sizeof(sym_cnt))) << "Loading the number of symbols failed"; - - for (uint64_t i = 0; i < sym_cnt; i++) { - std::string sym; - CHECK(stream->Read(&sym)) << "Loading symbol name failed"; - // Load variable and ndarray pairs - std::vector variables; - std::vector arrays; - CHECK(stream->Read(&variables)) << "Loading variables failed"; - uint64_t sz; - CHECK(stream->Read(&sz, sizeof(sz))) << "Loading medata size failed"; - CHECK_EQ(static_cast(sz), variables.size()) - << "The number of variables and ndarray counts must match"; - for (uint64_t i = 0; i < sz; i++) { - tvm::runtime::NDArray temp; - temp.Load(stream); - arrays.push_back(temp); - } - Map var_const; - for (size_t i = 0; i < variables.size(); i++) { - var_const.Set(variables[i], arrays[i]); - } - metadata.Set(sym, var_const); + + // Load the variables. + std::vector variables; + CHECK(stream->Read(&variables)) << "Loading variables failed"; + uint64_t sz; + CHECK(stream->Read(&sz, sizeof(sz))) << "Loading medata size failed"; + CHECK_EQ(static_cast(sz), variables.size()) + << "The number of variables and ndarray counts must match"; + // Load the list of ndarray. + std::vector metadata; + for (uint64_t i = 0; i < sz; i++) { + tvm::runtime::NDArray temp; + temp.Load(stream); + metadata.push_back(temp); + } + + Array vars; + Array consts; + for (size_t i = 0; i < variables.size(); i++) { + vars.push_back(variables[i]); + consts.push_back(metadata[i]); } - auto n = runtime::make_object(metadata, source_type); + auto n = runtime::make_object(vars, consts); return runtime::Module(n); } @@ -213,22 +217,16 @@ class ModuleInitWrapper : public runtime::ModuleNode { * modules using execution engine. */ std::unordered_map initialized_; - /*! \brief A symbol to {var_name : NDArray} pair mapping. */ - Map> metadata_; - /*! \brief The type of the source, i.e. c, or any customized json */ - String source_type_; + /*! \brief Variable name to NDArray mapping. */ + StringNDArrayMap metadata_; }; -runtime::Module ModuleInitWrapperCreate(Map> metadata, - String source_type) { - auto n = make_object(metadata, source_type); +runtime::Module ModuleInitWrapperCreate(Array variables, Array metadata) { + auto n = make_object(variables, metadata); return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.ModuleInitWrapper") - .set_body_typed([](Map> metadata, String source_type) { - return ModuleInitWrapperCreate(metadata, source_type); - }); +TVM_REGISTER_GLOBAL("runtime.ModuleInitWrapper").set_body_typed(ModuleInitWrapperCreate); TVM_REGISTER_GLOBAL("runtime.module.loadbinary_module_init") .set_body_typed(ModuleInitWrapper::LoadFromBinary); diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 60b4905548e5..00bbe22bea14 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -174,9 +174,13 @@ runtime::Module ModuleClassWrapperCreate() { // user-defined code, i.e. c source code, json graph representation, etc. class SourceMetadataModuleNode final : public runtime::ModuleNode { public: - SourceMetadataModuleNode(Map code, const std::string& source_type, - Map> metadata) - : code_(code), source_type_(source_type), metadata_(metadata) {} + SourceMetadataModuleNode(const String& func_symbol, const String& code, const String& source_type, + const Array& variables, const Array& metadata) + : func_symbol_(func_symbol), + code_(code), + source_type_(source_type), + variables_(variables), + metadata_(metadata) {} PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "get_source") { @@ -184,6 +188,12 @@ class SourceMetadataModuleNode final : public runtime::ModuleNode { } else if (name == "get_source_type") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->source_type_; }); + } else if (name == "get_symbol") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_symbol_; }); + } else if (name == "get_vars") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->variables_; }); } else if (name == "get_metadata") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->metadata_; }); @@ -196,17 +206,23 @@ class SourceMetadataModuleNode final : public runtime::ModuleNode { const char* type_key() const { return "source_metadata"; } private: - /*! \brief Symbol to source (e.g. c source/json) mapping. */ - Map code_; + /*! \brief The function symbols. */ + String func_symbol_; + /*! \brief The source code. */ + String code_; /*! \brief The type of the source code, e.g. c or any customized json type. */ - std::string source_type_; - /*! \brief Symbol to {var_name : NDArray} pair mapping. */ - Map> metadata_; + String source_type_; + /*! \brief The list of constant variables. */ + Array variables_; + /*! \brief The list of constant values that are corresponding to the variables. */ + Array metadata_; }; -runtime::Module SourceMetadataModuleCreate(Map code, std::string source_type, - Map> metadata) { - auto n = make_object(code, source_type, metadata); +runtime::Module SourceMetadataModuleCreate(String func_symbol, String code, String source_type, + Array variables, + Array metadata) { + auto n = + make_object(func_symbol, code, source_type, variables, metadata); return runtime::Module(n); } diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 6441ced18c3c..78f451a380ac 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -39,11 +39,10 @@ def update_lib(lib): ext_mod = tvm.target.SourceMetadataModule(source_metadata_mods[0]) code = ext_mod.source metadata = ext_mod.metadata - src_ty = ext_mod.source_type + variables = ext_mod.variables - init_mod = runtime.ModuleInitWrapper(metadata, src_ty) - for _, src in code.items(): - init_mod.import_module(tvm.target.CSourceModule(src)) + init_mod = runtime.ModuleInitWrapper(variables, metadata) + init_mod.import_module(tvm.target.CSourceModule(code)) new_lib.import_module(init_mod) test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 7898ae14b226..d88118060281 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -40,11 +40,10 @@ def update_lib(lib): ext_mod = tvm.target.SourceMetadataModule(source_metadata_mods[0]) code = ext_mod.source metadata = ext_mod.metadata - src_ty = ext_mod.source_type + variables = ext_mod.variables - init_mod = runtime.ModuleInitWrapper(metadata, src_ty) - for _, src in code.items(): - init_mod.import_module(tvm.target.CSourceModule(src)) + init_mod = runtime.ModuleInitWrapper(variables, metadata) + init_mod.import_module(tvm.target.CSourceModule(code)) new_lib.import_module(init_mod) test_dir = os.path.dirname( diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index bb33d1042a6e..894ebddaed7c 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -184,11 +184,10 @@ def update_lib(lib): ext_mod = tvm.target.SourceMetadataModule(source_metadata_mods[0]) code = ext_mod.source metadata = ext_mod.metadata - src_ty = ext_mod.source_type + variables = ext_mod.variables - init_mod = runtime.ModuleInitWrapper(metadata, src_ty) - for _, src in code.items(): - init_mod.import_module(tvm.target.CSourceModule(src)) + init_mod = runtime.ModuleInitWrapper(variables, metadata) + init_mod.import_module(tvm.target.CSourceModule(code)) new_lib.import_module(init_mod) test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) From a2a1d9f91bb7e4d6ddb5279b02db7825ac2bcbf2 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 12 Jun 2020 18:37:40 +0000 Subject: [PATCH 09/13] hide module wrappers --- python/tvm/runtime/__init__.py | 2 +- python/tvm/runtime/module.py | 36 ---------- python/tvm/target/__init__.py | 1 - python/tvm/target/source_module.py | 66 ------------------- src/relay/backend/build_module.cc | 16 ++--- src/relay/backend/vm/compiler.cc | 16 ++--- src/target/source/codegen_source_base.h | 7 ++ src/target/source/source_module.cc | 58 ++++++++++------ tests/python/relay/test_external_codegen.py | 13 +--- .../python/relay/test_pass_annotate_target.py | 13 +--- .../python/relay/test_pass_partition_graph.py | 13 +--- 11 files changed, 57 insertions(+), 184 deletions(-) delete mode 100644 python/tvm/target/source_module.py diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 8bc81935695e..21c06c517bd7 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -21,7 +21,7 @@ from .object import Object from .object_generic import ObjectGeneric, ObjectTypes from .ndarray import NDArray, DataType, DataTypeCode, TVMContext -from .module import Module, ModuleInitWrapper +from .module import Module # function exposures from .object_generic import convert_to_object, convert, const diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index b3ee7b0cf938..b2a681337fbf 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -33,25 +33,6 @@ ProfileResult = namedtuple("ProfileResult", ["mean", "results"]) -def ModuleInitWrapper(variables, metadata): - """Create a module initialization wrapper. - - Parameters - ---------- - variables : List[Str] - The list of variables. - - metadata : List[runtime.NDArray] - The list of used NDArray. - - Returns - ------- - ret : runtime.Module - The created module wrapper for initialization - """"" - return _ffi_api.ModuleInitWrapper(variables, metadata) - - class Module(object): """Runtime Module.""" __slots__ = ["handle", "_entry", "entry_name"] @@ -268,23 +249,6 @@ def _collect_dso_metadata_modules(self): def _dso_exportable(self): return self.type_key == "llvm" or self.type_key == "c" - def unwrap_modules(self): - """Unwrap the host and source metadata modules. - - Returns - ------- - ret : Tuple(runtime.Module, List[runtime.Module]) - The host module and a list of source metadata module pair. - """ - if not self.type_key == "module_class_wrapper": - return (self, None) - - assert len(self.imported_modules) > 1, \ - "Expect both host and source metadata module" - host_mod = self.imported_modules[0] - source_metadata_mods = self.imported_modules[1:] - return (host_mod, source_metadata_mods) - def export_library(self, file_name, fcompile=None, diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 56171ef3cd8e..2553fedb9869 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -61,4 +61,3 @@ from . import datatype from . import codegen from .intrin import register_intrin_rule -from .source_module import SourceMetadataModule, CSourceModule diff --git a/python/tvm/target/source_module.py b/python/tvm/target/source_module.py deleted file mode 100644 index a4acb74138b6..000000000000 --- a/python/tvm/target/source_module.py +++ /dev/null @@ -1,66 +0,0 @@ -# License .to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin -""" -Helper functions and classes for handling source and metdata. -""" -from tvm.runtime import _ffi_api - -class SourceMetadataModule: - """The module used to wrap both source and metadata.""" - def __init__(self, mod): - self.mod = mod - self._get_source = self.mod["get_source"] - self._get_symbol = self.mod["get_symbol"] - self._get_source_type = self.mod["get_source_type"] - self._get_variables = self.mod["get_vars"] - self._get_metadata = self.mod["get_metadata"] - - @property - def symbol(self): - """Get the source""" - return self._get_symbol() - - @property - def source(self): - """Get the source""" - return self._get_source() - - @property - def source_type(self): - """Get the source type""" - return self._get_source_type() - - @property - def metadata(self): - """Get the metadata""" - return self._get_metadata() - - @property - def variables(self): - """Get the metadata""" - return self._get_variables() - - - def is_c_source(self): - """Check if the source code is C/C++""" - return self.source_type == "c" or self.source_type == "cc" - - -def CSourceModule(code, fmt="c"): - """Create a C source module""" - return _ffi_api.CSourceModuleCreate(code, fmt) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 55e438f928e6..25459e3686fa 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -429,8 +429,6 @@ class RelayBuildModule : public runtime::ModuleNode { auto lowered_funcs = graph_codegen_->GetIRModule(); - runtime::Module tvm_lib; - // When there is no lowered_funcs due to reasons such as optimization. if (lowered_funcs.size() == 0) { Target target_host = GetTargetHost(); @@ -443,25 +441,21 @@ class RelayBuildModule : public runtime::ModuleNode { if (target_host.defined() && target_host->target_name == "llvm") { // If we can decide the target is LLVM, we then create an empty LLVM module. - tvm_lib = (*pf)(target_host->str(), "empty_module"); + ret_.mod = (*pf)(target_host->str(), "empty_module"); } else { // If we cannot decide the target is LLVM, we create an empty CSourceModule. // The code content is initialized with ";" to prevent complaining // from CSourceModuleNode::SaveToFile. - tvm_lib = tvm::codegen::CSourceModuleCreate(";", ""); + ret_.mod = tvm::codegen::CSourceModuleCreate(";", ""); } } else { - tvm_lib = tvm::build(lowered_funcs, target_host_); + ret_.mod = tvm::build(lowered_funcs, target_host_); } Array ext_mods = graph_codegen_->GetExternalModules(); if (!ext_mods.empty()) { - ret_.mod = tvm::codegen::ModuleClassWrapperCreate(); - ret_.mod.Import(tvm_lib); - // Import all external runtime modules. - for (const auto& it : ext_mods) ret_.mod.Import(it); - } else { - ret_.mod = std::move(tvm_lib); + auto init_mod = tvm::codegen::WrapMetadataModule(ext_mods); + ret_.mod.Import(init_mod); } } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b371a873ecb2..912a5f65fd36 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1007,28 +1007,20 @@ void VMCompiler::Codegen() { auto compile_engine = CompileEngine::Global(); auto ext_mods = compile_engine->LowerExternalFunctions(); - runtime::Module mod; if (funcs.size() > 0) { Map build_funcs; for (const auto& i : funcs) { build_funcs.Set(i.first, i.second); } - mod = tvm::build(build_funcs, target_host_); - CHECK(mod.operator->()); + exec_->lib = tvm::build(build_funcs, target_host_); } else { // There is no function handled by TVM. We create a virtual master module // to make sure a DSO module will be also available. - mod = codegen::CSourceModuleCreate(";", ""); + exec_->lib = codegen::CSourceModuleCreate(";", ""); } if (!ext_mods.empty()) { - exec_->lib = codegen::ModuleClassWrapperCreate(); - exec_->lib.Import(mod); - // Import all external runtime modules. - for (auto it : ext_mods) { - exec_->lib.Import(it); - } - } else { - exec_->lib = mod; + auto init_mod = codegen::WrapMetadataModule(ext_mods); + exec_->lib.Import(init_mod); } } diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 316a7e5bd2f0..e677f6cd550d 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -145,6 +145,13 @@ runtime::Module CSourceModuleCreate(std::string code, std::string fmt); */ runtime::Module ModuleClassWrapperCreate(); +/*! + * \brief Wrap the submodules in a metadata module. + * \param modules The modules to be wrapped. + * \return The wrapped module. + */ +runtime::Module WrapMetadataModule(const Array& modules); + /*! * \brief Create a source module for viewing and limited saving for device. * \param data The code data to be viewed. diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 00bbe22bea14..fba7d5cb04ce 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -41,6 +41,43 @@ using runtime::GetFileFormat; using runtime::GetMetaFilePath; using runtime::SaveBinaryToFile; +runtime::Module WrapMetadataModule(const Array& modules) { + // Wrap all submodules in the initialization wrapper. + Map var_md; + Array source_modules; + for (runtime::Module it : modules) { + String code = it.GetFunction("get_source")(); + Array variables = it.GetFunction("get_vars")(); + Array metadata = it.GetFunction("get_metadata")(); + CHECK_EQ(variables.size(), metadata.size()) + << "Found mismatch in the number of variables and ndarray"; + for (size_t i = 0; i < variables.size(); i++) { + var_md.Set(variables[i], metadata[i]); + } + + // TODO(zhiics) Invoke the corresponding module create function using the + // type key when json runtime comes. + source_modules.push_back(tvm::codegen::CSourceModuleCreate(code, "c")); + } + Array vars; + Array arrs; + for (const auto& it : var_md) { + vars.push_back(it.first); + arrs.push_back(it.second); + } + + // Wrap the modules. + const auto* pf = runtime::Registry::Get("runtime.ModuleInitWrapper"); + CHECK(pf != nullptr) << "Cannot find the registry for runtime.ModuleInitWrapper"; + runtime::Module init_m = (*pf)(vars, arrs); + + for (const auto& it : source_modules) { + init_m.Import(it); + } + + return init_m; +} + // Simulator function class SourceModuleNode : public runtime::ModuleNode { public: @@ -153,23 +190,6 @@ runtime::Module DeviceSourceModuleCreate( return runtime::Module(n); } -// A helper used to wrap different types of modules and pass through packedfunc. -// This module will never be used for compilation and execution. -class ModuleClassWrapperNode : public runtime::ModuleNode { - public: - ModuleClassWrapperNode() = default; - const char* type_key() const { return "module_class_wrapper"; } - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - LOG(FATAL) << "Cannot execute module wrapper"; - return PackedFunc(); - } -}; - -runtime::Module ModuleClassWrapperCreate() { - auto n = make_object(); - return runtime::Module(n); -} - // Pack the source code and metadata, where source code could be any // user-defined code, i.e. c source code, json graph representation, etc. class SourceMetadataModuleNode final : public runtime::ModuleNode { @@ -231,10 +251,6 @@ TVM_REGISTER_GLOBAL("runtime.SourceMetadataModuleCreate") TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); -TVM_REGISTER_GLOBAL("runtime.ModuleClassWrapperCreate").set_body_typed([]() { - return ModuleClassWrapperCreate(); -}); - TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") .set_body_typed([](String code, String source_type) { return CSourceModuleCreate(code.operator std::string(), source_type.operator std::string()); diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 78f451a380ac..6771bd10c2d6 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -34,17 +34,6 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", return def update_lib(lib): - new_lib, source_metadata_mods = lib.unwrap_modules() - if source_metadata_mods: - ext_mod = tvm.target.SourceMetadataModule(source_metadata_mods[0]) - code = ext_mod.source - metadata = ext_mod.metadata - variables = ext_mod.variables - - init_mod = runtime.ModuleInitWrapper(variables, metadata) - init_mod.import_module(tvm.target.CSourceModule(code)) - new_lib.import_module(init_mod) - test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) source_dir = os.path.join(test_dir, "..", "..", "..") contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") @@ -54,7 +43,7 @@ def update_lib(lib): tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) - new_lib.export_library(lib_path, fcompile=False, **kwargs) + lib.export_library(lib_path, fcompile=False, **kwargs) lib = tvm.runtime.load_module(lib_path) return lib diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index d88118060281..273c27b0d05f 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -35,17 +35,6 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", return def update_lib(lib): - new_lib, source_metadata_mods = lib.unwrap_modules() - if source_metadata_mods: - ext_mod = tvm.target.SourceMetadataModule(source_metadata_mods[0]) - code = ext_mod.source - metadata = ext_mod.metadata - variables = ext_mod.variables - - init_mod = runtime.ModuleInitWrapper(variables, metadata) - init_mod.import_module(tvm.target.CSourceModule(code)) - new_lib.import_module(init_mod) - test_dir = os.path.dirname( os.path.realpath(os.path.expanduser(__file__))) source_dir = os.path.join(test_dir, "..", "..", "..") @@ -56,7 +45,7 @@ def update_lib(lib): tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) - new_lib.export_library(lib_path, fcompile=False, **kwargs) + lib.export_library(lib_path, fcompile=False, **kwargs) lib = runtime.load_module(lib_path) return lib diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 894ebddaed7c..473ca9d66106 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -179,17 +179,6 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", return def update_lib(lib): - new_lib, source_metadata_mods = lib.unwrap_modules() - if source_metadata_mods: - ext_mod = tvm.target.SourceMetadataModule(source_metadata_mods[0]) - code = ext_mod.source - metadata = ext_mod.metadata - variables = ext_mod.variables - - init_mod = runtime.ModuleInitWrapper(variables, metadata) - init_mod.import_module(tvm.target.CSourceModule(code)) - new_lib.import_module(init_mod) - test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) source_dir = os.path.join(test_dir, "..", "..", "..") contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") @@ -199,7 +188,7 @@ def update_lib(lib): tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) - new_lib.export_library(lib_path, fcompile=False, **kwargs) + lib.export_library(lib_path, fcompile=False, **kwargs) lib = runtime.load_module(lib_path) return lib From acf2e51b907887bb2e5c180470d49fc53e9df00b Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sat, 13 Jun 2020 01:09:57 +0000 Subject: [PATCH 10/13] init each function --- python/tvm/contrib/graph_runtime.py | 6 +- python/tvm/runtime/module.py | 29 +-- python/tvm/runtime/vm.py | 13 +- src/relay/backend/build_module.cc | 3 +- src/relay/backend/compile_engine.cc | 39 +-- .../backend/contrib/codegen_c/codegen.cc | 87 +++---- .../backend/contrib/codegen_c/codegen_c.h | 30 ++- src/relay/backend/contrib/dnnl/codegen.cc | 88 +++---- src/relay/backend/graph_runtime_codegen.cc | 8 + src/relay/backend/utils.h | 20 ++ src/relay/backend/vm/compiler.cc | 7 +- src/runtime/meta_data.h | 16 ++ src/runtime/metadata_module.cc | 220 ++++++++++++++++ src/runtime/module_init_wrapper.cc | 234 ------------------ src/target/source/codegen_source_base.h | 20 +- src/target/source/source_module.cc | 126 +++------- tests/python/relay/test_external_runtime.py | 6 +- .../unittest/test_runtime_module_export.py | 3 +- 18 files changed, 468 insertions(+), 487 deletions(-) create mode 100644 src/runtime/metadata_module.cc delete mode 100644 src/runtime/module_init_wrapper.cc diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 740d1c3f19f3..46077278e473 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -162,7 +162,11 @@ def set_input(self, key=None, value=None, **params): keys = list(params.keys()) keys.sort(key=lambda x: -np.prod(params[x].shape)) for k in keys: - self._get_input(k).copyfrom(params[k]) + # TODO(zhiics) Skip the weights for submodule in a better way. + # We could get all inputs required by graphruntime first, + # we should use MetadataModule for initialization. + if "_const_" not in k: + self._get_input(k).copyfrom(params[k]) def run(self, **input_dict): """Run forward execution of the graph diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index b2a681337fbf..3cdb28f8c496 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -222,12 +222,9 @@ def evaluator(*args): except NameError: raise NameError("time_evaluate is only supported when RPC is enabled") - def _collect_dso_metadata_modules(self): - """ - Helper function to collect dso modules and metadata init module. There - is at most one medata init module if it exists. - """ - visited, stack, dso_modules, metadata_init = set(), [], [], None + def _collect_dso_modules(self): + """Helper function to collect dso modules, then return it.""" + visited, stack, dso_modules = set(), [], [] # append root module visited.add(self) stack.append(self) @@ -235,16 +232,11 @@ def _collect_dso_metadata_modules(self): module = stack.pop() if module._dso_exportable(): dso_modules.append(module) - elif module.type_key == "module_init": - assert not metadata_init, \ - "At most one module initializer is allowed" - metadata_init = module for m in module.imported_modules: if m not in visited: visited.add(m) stack.append(m) - - return dso_modules, metadata_init + return dso_modules def _dso_exportable(self): return self.type_key == "llvm" or self.type_key == "c" @@ -290,13 +282,13 @@ def export_library(self, self.save(file_name) return - dso_modules, metadata_init = self._collect_dso_metadata_modules() + modules = self._collect_dso_modules() temp = _util.tempdir() files = addons if addons else [] is_system_lib = False has_c_module = False llvm_target_triple = None - for index, module in enumerate(dso_modules): + for index, module in enumerate(modules): if fcompile is not None and hasattr(fcompile, "object_format"): object_format = fcompile.object_format else: @@ -313,15 +305,6 @@ def export_library(self, module.get_function("__tvm_is_system_module")()) llvm_target_triple = (module.type_key == "llvm" and module.get_function("_get_target_triple")()) - - if metadata_init: - for m in metadata_init.imported_modules: - if m.type_key == "c": - header = temp.relpath("metadata.h") - metadata_init.save(header) - files.append(header) - break - if not fcompile: if file_name.endswith(".tar"): fcompile = _tar.tar diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 2643ff131ba0..8a85051d6bbd 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -106,7 +106,7 @@ def save(self): import numpy as np import tvm -from tvm import te + from tvm import te from tvm import relay # define a simple network. x = relay.var('x', shape=(10, 10)) @@ -309,12 +309,17 @@ def set_input(self, func_name, *args, **kwargs): Named arguments to the function. """ if kwargs: + # kwargs is a super set of the required function parameters. We + # only find the ones that are needed. func_params = self._exec.get_function_params(func_name) new_args = [None] * len(func_params) - assert len(args) + len(kwargs) == len(func_params) + cnt = 0 for k in kwargs: - idx = func_params.index(k) - new_args[idx] = kwargs[k] + if k in func_params: + idx = func_params.index(k) + new_args[idx] = kwargs[k] + cnt += 1 + assert len(args) + cnt == len(func_params) idx = 0 for i, arg in enumerate(new_args): if arg is None: diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 25459e3686fa..01cf6057d474 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -454,8 +454,7 @@ class RelayBuildModule : public runtime::ModuleNode { Array ext_mods = graph_codegen_->GetExternalModules(); if (!ext_mods.empty()) { - auto init_mod = tvm::codegen::WrapMetadataModule(ext_mods); - ret_.mod.Import(init_mod); + ret_.mod = tvm::codegen::WrapMetadataModule(ret_.params, ret_.mod, ext_mods); } } diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index be749fdd3a97..ecd7fb9e8f62 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -38,6 +38,7 @@ #include #include +#include #include #include #include @@ -572,7 +573,8 @@ class CompileEngineImpl : public CompileEngineNode { } Array LowerExternalFunctions() { - std::unordered_map ext_mods; + Array ret; + std::unordered_map cached_symbol; std::vector cached_ext_funcs; for (const auto& it : cache_) { auto src_func = it.first->source_func; @@ -581,29 +583,31 @@ class CompileEngineImpl : public CompileEngineNode { auto code_gen = src_func->GetAttr(attr::kCompiler); CHECK(code_gen.defined()) << "No external codegen is set"; std::string code_gen_name = code_gen.value(); - if (ext_mods.find(code_gen_name) == ext_mods.end()) { - ext_mods[code_gen_name] = IRModule({}, {}); - } + cached_ext_funcs.push_back(it.first); + auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); - auto gv = GlobalVar(symbol_name.value()); + + std::string sn = symbol_name.value(); + if (cached_symbol.count(sn)) { + cached_symbol[sn] = code_gen_name; + } else { + CHECK_NE(sn, code_gen_name) + << "Found duplicated symbol: " << sn << " for: " << code_gen_name; + } + + std::string ext_name = "relay.ext." + code_gen_name; + auto pf = tvm::runtime::Registry::Get(ext_name); + CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; // No need to keep compiler attribute at this point, functions have been // extracted for specific codegen. src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); - ext_mods[code_gen_name]->Add(gv, src_func); - cached_ext_funcs.push_back(it.first); - } - } + runtime::Module ext_mod = (*pf)(src_func); - Array ret; - for (const auto& it : ext_mods) { - std::string ext_name = "relay.ext." + it.first; - auto pf = tvm::runtime::Registry::Get(ext_name); - CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; - runtime::Module ext_mod = (*pf)(it.second); - CHECK(ext_mod.defined()) << "No external runtime is generated."; - ret.push_back(ext_mod); + CHECK(ext_mod.defined()) << "No external runtime is generated."; + ret.push_back(ext_mod); + } } // No need to cache external functions as we collected them all to create @@ -659,6 +663,7 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(name_node.defined()) << "External function has not been attached a name yet."; cache_node->func_name = std::string(name_node.value()); cache_node->target = tvm::target::ext_dev(); + cache_node->funcs->Add(GlobalVar(cache_node->func_name), key->source_func); value->cached_func = CachedFunc(cache_node); return value; } diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 7650ff636069..a0118e3da544 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -25,6 +25,7 @@ #include #include +#include #include "../../utils.h" #include "codegen_c.h" @@ -80,15 +81,27 @@ class CodegenC : public MemoizedExprTranslator>, public Code std::ostringstream buf_stream; Output output; - output.name = ext_func_id_ + "const_" + std::to_string(const_idx_++); + // Get const: static_cast(dnnl_0_consts[0]->data) + output.name = "static_cast(" + ext_func_id_ + "_consts[" + std::to_string(const_idx_) + + "]->data)"; const auto* type_node = cn->checked_type().as(); CHECK(type_node); const auto& dtype = GetDtypeString(type_node); + + // Generate the global variable for needed ndarrays + if (const_array_.empty()) { + const_array_ = "Array " + ext_func_id_ + "_consts;"; + std::ostringstream buf_stream; + buf_stream << "CHECK(!" << ext_func_id_ + << "_consts.empty()) << \"C source module hasn't been initialized.\";\n"; + ext_func_body_.insert(ext_func_body_.begin(), buf_stream.str()); + } + CHECK(dtype == "float" || dtype == "int") << "Only float and int are supported for now."; output.dtype = dtype; - CHECK_EQ(metadata_.count(output.name), 0U) << "variable must be unique: " << output.name; - metadata_.Set(output.name, cn->data); + std::string const_var_name = ext_func_id_ + "_const_" + std::to_string(const_idx_++); + const_vars_.push_back(const_var_name); return {output}; } @@ -151,7 +164,7 @@ class CodegenC : public MemoizedExprTranslator>, public Code buf_decl_.push_back(buf_stream.str()); decl_stream << ", " << out << ");"; - ext_func_body.push_back(decl_stream.str()); + ext_func_body_.push_back(decl_stream.str()); // Update output buffer // Note C codegen only handles TensorType. Therefore, we don't flatten @@ -174,11 +187,9 @@ class CodegenC : public MemoizedExprTranslator>, public Code for (auto decl : func_decl_) { code_stream_ << decl << "\n"; } - return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out); + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_, out); } - Map GetMetadata() const { return metadata_; } - private: /*! \brief The function id that represents a C source function. */ std::string ext_func_id_ = ""; @@ -191,18 +202,22 @@ class CodegenC : public MemoizedExprTranslator>, public Code /*! \brief The arguments of a C compiler compatible function. */ Array ext_func_args_; /*! \brief The statements of a C compiler compatible function. */ - std::vector ext_func_body; + std::vector ext_func_body_; + /*! \brief The array declared to store the constant values. */ + std::string const_array_; /*! \brief The declaration statements of a C compiler compatible function. */ std::vector func_decl_; /*! \brief The declaration statements of buffers. */ std::vector buf_decl_; /*! \brief The variable name to constant mapping. */ - Map metadata_; + Array const_vars_; + + friend class CSourceCodegen; }; class CSourceCodegen : public CSourceModuleCodegenBase { public: - Map GenCFunc(const Function& func) { + std::pair> GenCFunc(const Function& func) { CHECK(func.defined()) << "Input error: expect a Relay function."; // Record the external symbol for runtime lookup. @@ -212,15 +227,18 @@ class CSourceCodegen : public CSourceModuleCodegenBase { auto out = builder.VisitExpr(func->body); code_stream_ << builder.JIT(out); - return builder.GetMetadata(); + return {sid, builder.const_vars_}; } runtime::Module CreateCSourceModule(const ObjectRef& ref) override { // Create headers code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "using namespace tvm::runtime;\n"; // Append some common macro for operator definition. const char* operator_macro = R"op_macro( @@ -244,46 +262,17 @@ class CSourceCodegen : public CSourceModuleCodegenBase { code_stream_ << operator_macro << "\n\n"; - String func_symbol("all"); - String code; - Array variables; - Array metadata; - if (ref->IsInstance()) { - Map consts = GenCFunc(Downcast(ref)); - std::string code_str = code_stream_.str(); - if (!consts.empty()) { - code_str = "#include \"metadata.h\"\n" + code_str; - for (const auto& it : consts) { - variables.push_back(it.first); - metadata.push_back(it.second); - } - } - code = code_str; - } else if (ref->IsInstance()) { - IRModule mod = Downcast(ref); - for (const auto& it : mod->functions) { - Map consts = GenCFunc(Downcast(it.second)); - if (!consts.empty()) { - for (const auto& it : consts) { - variables.push_back(it.first); - metadata.push_back(it.second); - } - } - } - std::string code_str = code_stream_.str(); - if (!metadata.empty()) { - code_str = "#include \"metadata.h\"\n" + code_str; - } - code = code_str; - } else { - LOG(FATAL) << "The input ref is expected to be a Relay function or module" - << "\n"; - } + CHECK(ref->IsInstance()); + auto res = GenCFunc(Downcast(ref)); + std::string code = code_stream_.str(); + + String sym = std::get<0>(res); + Array variables = std::get<1>(res); - // Create a SourceMetadataModuleNode - const auto* pf = runtime::Registry::Get("runtime.SourceMetadataModuleCreate"); + // Create a CSource module + const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - return (*pf)(func_symbol, code, "c", variables, metadata); + return (*pf)(code, "c", sym, variables); } private: diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 3a3c486bb035..f4512f38e858 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -110,6 +110,8 @@ class CodegenCBase { * * \code * + * Array foo_consts; + * * // An example code for the generated C function. * extern "C" void foo_wrapper_(DLTensor* arg0, * DLTensor* arg1, @@ -122,10 +124,16 @@ class CodegenCBase { * * TVM_DLL_EXPORT_TYPED_FUNC(foo, foo_wrapper_); * + * void foo_init_wrapper_(Array arr) { + * foo_consts = arr; + * } + * + * TVM_DLL_EXPORT_TYPED_FUNC(__init_foo, foo_init_wrapper_); + * * \endcode */ void GenerateBackendCFunc(const std::string& func_name, const Array& args, - const std::vector& outs) { + const std::string& const_arr, const std::vector& outs) { // Print signature code_stream_ << "\n"; code_stream_ << "extern \"C\" int " << func_name << "_wrapper_("; @@ -163,6 +171,17 @@ class CodegenCBase { // Generate the macro code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", " << func_name << "_wrapper_);\n\n"; + + if (!const_arr.empty()) { + code_stream_ << "void " << func_name << "_init_wrapper_(Array arr) {\n"; + EnterScope(); + PrintIndents(); + code_stream_ << func_name << "_consts = arr;\n"; + code_stream_ << "}\n\n"; + ExitScope(); + code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(__init_" << func_name << ", " << func_name + << "_init_wrapper_);\n\n"; + } } /*! @@ -190,7 +209,12 @@ class CodegenCBase { */ std::string JitImpl(const std::string& ext_func_id, const Array& args, const std::vector& buf_decl, - const std::vector& body, const std::vector& outs) { + const std::vector& body, const std::string& const_arr, + const std::vector& outs) { + // Create a declaration for global ndarrays that contain constant data. + if (!const_arr.empty()) { + code_stream_ << const_arr << "\n\n"; + } // Create the signature. For example, it could be: // extern "C" void dnnl_0_(float* in0, float* in1, float* out0, float* out1) {} code_stream_ << "extern \"C\" void " << ext_func_id << "_("; @@ -236,7 +260,7 @@ class CodegenCBase { code_stream_ << "}\n"; // Create the wrapper to call the ext_func - this->GenerateBackendCFunc(ext_func_id, args, outs); + this->GenerateBackendCFunc(ext_func_id, args, const_arr, outs); return code_stream_.str(); } diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 13206ce42868..4d9477ad9f7d 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -165,11 +165,24 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C std::vector VisitExpr_(const ConstantNode* cn) final { Output output; - output.name = ext_func_id_ + "_const_" + std::to_string(const_idx_++); + // Get const: static_cast(dnnl_0_consts[0]->data) + output.name = "static_cast(" + ext_func_id_ + "_consts[" + std::to_string(const_idx_) + + "]->data)"; output.dtype = "float"; - CHECK_EQ(metadata_.count(output.name), 0U) << "variable must be unique: " << output.name; - metadata_.Set(output.name, cn->data); + // Generate the global variable for needed ndarrays + if (const_array_.empty()) { + const_array_ = "Array " + ext_func_id_ + "_consts;"; + std::ostringstream buf_stream; + buf_stream << "CHECK(!" << ext_func_id_ + << "_consts.empty()) << \"DNNL source module hasn't been initialized.\";\n"; + ext_func_body_.insert(ext_func_body_.begin(), buf_stream.str()); + } + + // Give the ndarray a unique name to ease the initialization of it at + // runtime. + std::string const_var_name = ext_func_id_ + "_const_" + std::to_string(const_idx_++); + const_vars_.push_back(const_var_name); const auto* type_node = cn->checked_type().as(); CHECK(type_node); @@ -187,16 +200,14 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C } buf_decl_.insert(buf_decl_.end(), ret.buffers.begin(), ret.buffers.end()); - ext_func_body.push_back(ret.decl); + ext_func_body_.push_back(ret.decl); return ret.outputs; } std::string JIT(const std::vector& out) { - return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out); + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_, out); } - Map GetMetadata() const { return metadata_; } - private: struct GenerateBodyOutput { std::string decl; @@ -326,12 +337,16 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C int const_idx_{0}; /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */ Array ext_func_args_; - /*! \brief statement of the function that will be compiled using DNNL kernels. */ - std::vector ext_func_body; + /*! \brief Statement of the function that will be compiled using DNNL kernels. */ + std::vector ext_func_body_; + /*! \brief The array declared to store the constant values. */ + std::string const_array_; /*! \brief The declaration of intermeidate buffers. */ std::vector buf_decl_; /*! \brief The variable name to constant mapping. */ - Map metadata_; + Array const_vars_; + + friend class DNNLModuleCodegen; }; /*! @@ -342,7 +357,7 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C class DNNLModuleCodegen : public CSourceModuleCodegenBase { public: // Create a corresponding DNNL function for the given relay Function. - Map GenDNNLFunc(const Function& func) { + std::pair> GenDNNLFunc(const Function& func) { CHECK(func.defined()) << "Input error: expect a Relay function."; // Record the external symbol for runtime lookup. @@ -352,7 +367,7 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { auto out = builder.VisitExpr(func->body); code_stream_ << builder.JIT(out); - return builder.GetMetadata(); + return {sid, builder.const_vars_}; } /*! @@ -371,56 +386,29 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; // dnnl_kernel file is saved under src/runtime/contrib/dnnl so that we don't // expose it to ordinary users. To make export_library use it, users need to // pass -I${PATH_TO_TVM}/src/runtime/contrib code_stream_ << "#include \n"; + code_stream_ << "using namespace tvm::runtime;\n"; code_stream_ << "using namespace tvm::runtime::contrib;\n"; code_stream_ << "\n"; - String func_symbol("all"); - String code; - Array variables; - Array metadata; - if (ref->IsInstance()) { - Map consts = GenDNNLFunc(Downcast(ref)); - std::string code_str = code_stream_.str(); - if (!consts.empty()) { - code_str = "#include \"metadata.h\"\n" + code_str; - for (const auto& it : consts) { - variables.push_back(it.first); - metadata.push_back(it.second); - } - } - code = code_str; - } else if (ref->IsInstance()) { - IRModule mod = Downcast(ref); - for (const auto& it : mod->functions) { - Map consts = GenDNNLFunc(Downcast(it.second)); - if (!consts.empty()) { - for (const auto& it : consts) { - variables.push_back(it.first); - metadata.push_back(it.second); - } - } - } - std::string code_str = code_stream_.str(); - if (!metadata.empty()) { - code_str = "#include \"metadata.h\"\n" + code_str; - } - code = code_str; - } else { - LOG(FATAL) << "The input ref is expected to be a Relay function or module" - << "\n"; - } + CHECK(ref->IsInstance()); + auto res = GenDNNLFunc(Downcast(ref)); + std::string code = code_stream_.str(); + String sym = std::get<0>(res); + Array variables = std::get<1>(res); - // Create a SourceMetadataModuleNode - const auto* pf = runtime::Registry::Get("runtime.SourceMetadataModuleCreate"); + // Create a CSource module + const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - return (*pf)(func_symbol, code, "c", variables, metadata); + return (*pf)(code, "c", sym, variables); } private: diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 4226cc872589..bc8b390716ee 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -368,6 +368,14 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorGetAttr(tvm::attr::kGlobalSymbol); + std::string symobl = std::string(name_node.value()); + ConstantUpdater const_visit(symobl, ¶ms_); + const_visit(func); + return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name); } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 4475d43f2898..cac6f55329c8 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -43,6 +43,26 @@ namespace tvm { namespace relay { namespace backend { +/*! + * \brief A helper to expand the params by adding the ones used in a given expression. + */ +struct ConstantUpdater : public ExprVisitor { + public: + ConstantUpdater(const std::string& symbol, + std::unordered_map* params) + : symbol_(symbol), params_(params) {} + + void VisitExpr_(const ConstantNode* cn) final { + std::string name = symbol_ + "_const_" + std::to_string(const_idx_++); + (*params_)[name] = cn->data; + } + + private: + int const_idx_{0}; + std::string symbol_; + std::unordered_map* params_; +}; + /*! * \brief A simple wrapper around ExprFunctor for a single argument case. * The result of visit is memoized. diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 912a5f65fd36..75e2cd5a40e6 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -997,6 +997,10 @@ void VMCompiler::Codegen() { mod.CopyOnWrite(); if (target_str == "ext_dev") { + // Collect metadata in functions that are handled by external codegen. + CHECK(mod->ContainGlobalVar(cfunc->func_name)); + backend::ConstantUpdater const_visit(cfunc->func_name, ¶ms_); + const_visit(Downcast(mod->Lookup(cfunc->func_name))); continue; } else if (funcs.count(target_str) == 0) { funcs.emplace(target_str, mod); @@ -1019,8 +1023,7 @@ void VMCompiler::Codegen() { exec_->lib = codegen::CSourceModuleCreate(";", ""); } if (!ext_mods.empty()) { - auto init_mod = codegen::WrapMetadataModule(ext_mods); - exec_->lib.Import(init_mod); + exec_->lib = codegen::WrapMetadataModule(params_, exec_->lib, ext_mods); } } diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 451c0e88fcb0..03dba399fcb4 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -26,9 +26,12 @@ #include #include +#include +#include #include #include +#include #include #include "runtime_base.h" @@ -36,6 +39,19 @@ namespace tvm { namespace runtime { +/*! + * \brief Create a metadata module object. + * + * \param metadata The variable name to ndarray mapping. + * \param sym_vars The symbol to the list of required constant variables + * mapping. + * + * \return The created metadata module. + */ +Module MetadataModuleCreate( + const std::unordered_map& metadata, + const std::unordered_map>& sym_vars); + /*! \brief function information needed by device */ struct FunctionInfo { std::string name; diff --git a/src/runtime/metadata_module.cc b/src/runtime/metadata_module.cc new file mode 100644 index 000000000000..fa03e02a66db --- /dev/null +++ b/src/runtime/metadata_module.cc @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/metadata_module.cc + * \brief A wrapper for initializing imported modules using metadata. This + * module is intended to be used by various runtime in the TVM stack, i.e. + * graph runtime, relay VM, AOT runtime, and various user defined runtimes. It + * paves the way to separate the code and metedata, which makes compilation + * and/or interpretation more convenient. In addition, the clear separation of + * code and metadata significantly reduces the efforts for handling external + * codegen and runtimes. + */ +#include +#include +#include +#include + +#include +#include + +#include "meta_data.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief The metadata module is designed to manage initialization of the + * imported submodules. + */ +class MetadataModuleNode : public ModuleNode { + public: + MetadataModuleNode(const std::unordered_map& metadata, + const std::unordered_map>& sym_vars) + : metadata_(metadata), sym_vars_(sym_vars) {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + // Initialize and memoize the module. + // Usually, we have some warmup runs. The module initialization should be + // done at this stage. Therefore, runtime overhead is not a concern. + if (initialized_.count(name) == 0) { + this->InitSubModule(name); + initialized_[name] = true; + } + + // Run the module. + // Normally we would only have a limited number of submodules. The runtime + // symobl lookup overhead should be minimal. + CHECK(!this->imports().empty()); + for (Module it : this->imports()) { + PackedFunc pf = it.GetFunction(name); + if (pf != nullptr) return pf; + } + return PackedFunc(nullptr); + } + + const char* type_key() const { return "module_init"; } + + /*! + * \brief Get the list of metadata that is required by the given module. + * \param symbol The symbol that is being queried. + * \return The list of needed NDArray. + */ + Array GetRequiredMetadata(const std::string& symbol) { + Array ret; + CHECK_GT(sym_vars_.count(symbol), 0U) << "Not symbol is recorded for " << symbol; + std::vector vars = sym_vars_[symbol]; + for (const auto& it : vars) { + CHECK_GT(metadata_.count(it), 0U) << "Found not recorded constant variable: " << it; + ret.push_back(metadata_[it]); + } + return ret; + } + + /*! + * \brief Initialize each imported module. + * \param symobl The symbol used for initializing a module. It is also used + * for runtime lookup. + * + * \note A module could be like the following: + * MetadataModuleNode (contains all the metadata) + * - CSourceModule + * - JSON runtime module + * + * The initializer iterates through the imported modules and intilizes the + * found module accordingly by passing the needed metadata into it. + */ + void InitSubModule(const std::string& symbol) { + PackedFunc init(nullptr); + for (Module it : this->imports()) { + // Get the initialization function from the imported modules. + std::string init_name = "__init_" + symbol; + init = it.GetFunction(init_name, false); + if (init != nullptr) { + auto md = GetRequiredMetadata(symbol); + // Initialize the module with metadata. + init(md); + break; + } + } + } + + void SaveToBinary(dmlc::Stream* stream) final { + std::vector variables; + std::vector metadata; + for (const auto& it : metadata_) { + String var_name = it.first; + variables.push_back(var_name); + metadata.push_back(it.second); + } + + // Save all variables in the function. + stream->Write(variables); + // Save all constant data. + uint64_t sz = static_cast(metadata.size()); + stream->Write(sz); + for (uint64_t i = 0; i < sz; i++) { + metadata[i].Save(stream); + } + + // Save the symbol to list of required constant variables mapping + std::vector symbols; + std::vector> const_vars; + for (const auto& it : sym_vars_) { + symbols.push_back(it.first); + const_vars.push_back(it.second); + } + + stream->Write(symbols); + sz = static_cast(sym_vars_.size()); + stream->Write(sz); + for (uint64_t i = 0; i < sz; i++) { + stream->Write(const_vars[i]); + } + } + + static Module LoadFromBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + + // Load the variables. + std::vector variables; + CHECK(stream->Read(&variables)) << "Loading variables failed"; + uint64_t sz; + CHECK(stream->Read(&sz, sizeof(sz))) << "Loading medata size failed"; + CHECK_EQ(static_cast(sz), variables.size()) + << "The number of variables and ndarray counts must match"; + // Load the list of ndarray. + std::vector arrays; + for (uint64_t i = 0; i < sz; i++) { + NDArray temp; + temp.Load(stream); + arrays.push_back(temp); + } + + std::unordered_map metadata; + for (size_t i = 0; i < variables.size(); i++) { + CHECK_EQ(metadata.count(variables[i]), 0U); + metadata[variables[i]] = arrays[i]; + } + + // Load the symbol to list of required constant variables mapping + std::vector symbols; + CHECK(stream->Read(&symbols)) << "Loading symbols failed"; + CHECK(stream->Read(&sz, sizeof(sz))) << "Loading number of symbols failed"; + CHECK_EQ(static_cast(sz), symbols.size()); + std::vector> const_vars; + for (uint64_t i = 0; i < sz; i++) { + std::vector vars; + CHECK(stream->Read(&vars)) << "Loading const variables failed"; + const_vars.push_back(vars); + } + + std::unordered_map> sym_vars; + for (uint64_t i = 0; i < sz; i++) { + sym_vars[symbols[i]] = const_vars[i]; + } + + auto n = make_object(metadata, sym_vars); + return Module(n); + } + + private: + /*! + * \brief Record if a module is initialized. It is needed by imported + * modules using execution engine. + */ + std::unordered_map initialized_; + /*! \brief Variable name to NDArray mapping. */ + std::unordered_map metadata_; + /*! \brief Symbol name to required constant variables mapping. */ + std::unordered_map> sym_vars_; +}; + +Module MetadataModuleCreate( + const std::unordered_map& metadata, + const std::unordered_map>& sym_vars) { + auto n = make_object(metadata, sym_vars); + return Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_module_init") + .set_body_typed(MetadataModuleNode::LoadFromBinary); +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/module_init_wrapper.cc b/src/runtime/module_init_wrapper.cc deleted file mode 100644 index a3ab635c22b6..000000000000 --- a/src/runtime/module_init_wrapper.cc +++ /dev/null @@ -1,234 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/runtime/module_init_wrapper.cc - * \brief A wrapper for initializing modules using metadata - */ -#include -#include -#include -#include - -#include -#include - -#include "file_util.h" - -namespace tvm { -namespace runtime { - -using StringNDArrayMap = std::unordered_map; - -class CSourceMetadataInitializer { - public: - explicit CSourceMetadataInitializer(const StringNDArrayMap& metadata) : metadata_(metadata) {} - - template - void GetElements(const std::string& var_name, const std::string& type_name, - const runtime::NDArray& arr) { - // Get the number of elements. - int64_t num_elems = 1; - for (auto i : arr.Shape()) num_elems *= i; - stream_ << "static " << type_name << " " << var_name << "[" << num_elems << "] = {"; - T* ptr = static_cast(arr->data); - for (int64_t i = 0; i < num_elems - 1; i++) { - stream_ << ptr[i] << ","; - } - if (num_elems > 0) stream_ << ptr[num_elems - 1]; - stream_ << "};\n"; - } - - std::string Init() { - for (const auto& it : metadata_) { - std::string var_name = it.first.operator std::string(); - runtime::NDArray data = it.second; - CHECK_EQ(data->dtype.lanes, 1U); - if (data->dtype.code == kDLFloat) { - if (data->dtype.bits == 32) { - stream_.precision(std::numeric_limits::digits10 + 1); - GetElements(var_name, "float", data); - } else { - CHECK_EQ(data->dtype.bits, 64); - stream_.precision(std::numeric_limits::digits10 + 1); - GetElements(var_name, "double", data); - } - } else if (data->dtype.code == kDLUInt) { - if (data->dtype.bits == 8) { - GetElements(var_name, "uint8_t", data); - } else { - CHECK_EQ(data->dtype.bits, 32); - GetElements(var_name, "uint32_t", data); - } - } else { - if (data->dtype.bits == 8) { - GetElements(var_name, "int8_t", data); - } else { - CHECK_EQ(data->dtype.bits, 32); - GetElements(var_name, "int32_t", data); - } - } - } - return stream_.str(); - } - - private: - /*! \brief The stream to print constant data. */ - std::ostringstream stream_; - /*! \brief variable name to NDArray mapping. */ - StringNDArrayMap metadata_; -}; - -class ModuleInitWrapper : public runtime::ModuleNode { - public: - ModuleInitWrapper(const Array& variables, const Array& metadata) { - CHECK_EQ(variables.size(), metadata.size()); - for (size_t i = 0; i < variables.size(); i++) { - metadata_[variables[i]] = metadata[i]; - } - } - - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - if (initialized_.count(name) == 0) { - this->InitSubModule(name); - initialized_[name] = true; - } - - if (name != "__InitModule") { - CHECK(!this->imports().empty()); - runtime::Module submodule = this->imports().at(0); - return submodule->GetFunction(name); - } - - return PackedFunc(); - } - - const char* type_key() const { return "module_init"; } - - /*! - * \brief Initialize each imported module. - * \param symobl The symbol used for initializing a module. It is also used - * for runtime lookup. - * - * \note A module could be like the following: - * ModuleInitWrapper (contains all the metadata) - * - CSourceModule - * - JSON runtime module - * - * The initializer iterates through the wrapped module and intilizes them - * accordingly by passing the needed metadata into it. - */ - void InitSubModule(const std::string& symbol) { - // Dispatch initializer according to the source type - // TODO(zhiics) iterate through the imported modules to initialize - // for (const auto& it : this->imports()) { - // } - } - - void SaveToFile(const std::string& file_name, const std::string& format) final { - // C source module relies on AOT compilation. The source code has already - // been generated. The used metadata is saved a separate file for - // compilation. - std::string consts = ""; - for (auto& it : this->imports()) { - if (!std::strcmp(it->type_key(), "c")) { - // TODO(zhiics) Maybe we need to store the list of required - // variales in the CSourceModule so that we can validate the - // existence of the variable and feed it only with the required - // ones. - CSourceMetadataInitializer c_init(metadata_); - consts += c_init.Init(); - consts += "\n"; - } - } - if (consts != "") { - std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, "h") << "Can only save to .h file"; - SaveBinaryToFile(file_name, consts); - } - } - - void SaveToBinary(dmlc::Stream* stream) final { - std::vector variables; - std::vector metadata; - for (const auto& it : metadata_) { - String var_name = it.first; - variables.push_back(var_name.operator std::string()); - metadata.push_back(it.second); - } - - // Save all variables in the function. - stream->Write(variables); - // Save all constant data. - uint64_t sz = static_cast(metadata.size()); - stream->Write(sz); - for (uint64_t i = 0; i < sz; i++) { - metadata[i].Save(stream); - } - } - - static runtime::Module LoadFromBinary(void* strm) { - dmlc::Stream* stream = static_cast(strm); - - // Load the variables. - std::vector variables; - CHECK(stream->Read(&variables)) << "Loading variables failed"; - uint64_t sz; - CHECK(stream->Read(&sz, sizeof(sz))) << "Loading medata size failed"; - CHECK_EQ(static_cast(sz), variables.size()) - << "The number of variables and ndarray counts must match"; - // Load the list of ndarray. - std::vector metadata; - for (uint64_t i = 0; i < sz; i++) { - tvm::runtime::NDArray temp; - temp.Load(stream); - metadata.push_back(temp); - } - - Array vars; - Array consts; - for (size_t i = 0; i < variables.size(); i++) { - vars.push_back(variables[i]); - consts.push_back(metadata[i]); - } - auto n = runtime::make_object(vars, consts); - return runtime::Module(n); - } - - private: - /*! - * \brief Record if a module is initialized. It is needed by imported - * modules using execution engine. - */ - std::unordered_map initialized_; - /*! \brief Variable name to NDArray mapping. */ - StringNDArrayMap metadata_; -}; - -runtime::Module ModuleInitWrapperCreate(Array variables, Array metadata) { - auto n = make_object(variables, metadata); - return runtime::Module(n); -} - -TVM_REGISTER_GLOBAL("runtime.ModuleInitWrapper").set_body_typed(ModuleInitWrapperCreate); - -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_module_init") - .set_body_typed(ModuleInitWrapper::LoadFromBinary); -} // namespace runtime -} // namespace tvm diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index e677f6cd550d..f59d4dfefd97 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -135,22 +135,26 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt); /*! * \brief Create a C source module for viewing and compiling GCC code. * \param code The code to be viewed. - * \param fmt The code. format. - */ -runtime::Module CSourceModuleCreate(std::string code, std::string fmt); - -/*! - * \brief Create a helper module to wrap different modules. + * \param fmt The code format. + * \param symbol The symbol that the c source module represents. + * \param const_vars. The constant variables that the c source module needs. * \return The created module. */ -runtime::Module ModuleClassWrapperCreate(); +runtime::Module CSourceModuleCreate(const String& code, const String& fmt, + const String& symbol = "", + const Array& const_vars = {}); /*! * \brief Wrap the submodules in a metadata module. + * \param params The variable to constant mapping that is collected by the host + * module. + * \param dso_module The host module to be wrapped. * \param modules The modules to be wrapped. * \return The wrapped module. */ -runtime::Module WrapMetadataModule(const Array& modules); +runtime::Module WrapMetadataModule(const std::unordered_map& params, + const runtime::Module& dso_module, + const Array& modules); /*! * \brief Create a source module for viewing and limited saving for device. diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index fba7d5cb04ce..3d195ec3651c 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -41,37 +41,27 @@ using runtime::GetFileFormat; using runtime::GetMetaFilePath; using runtime::SaveBinaryToFile; -runtime::Module WrapMetadataModule(const Array& modules) { +runtime::Module WrapMetadataModule(const std::unordered_map& params, + const runtime::Module& dso_module, + const Array& modules) { // Wrap all submodules in the initialization wrapper. - Map var_md; - Array source_modules; + std::unordered_map> sym_metadata; for (runtime::Module it : modules) { - String code = it.GetFunction("get_source")(); - Array variables = it.GetFunction("get_vars")(); - Array metadata = it.GetFunction("get_metadata")(); - CHECK_EQ(variables.size(), metadata.size()) - << "Found mismatch in the number of variables and ndarray"; + CHECK_EQ(it->type_key(), "c") << "Only csource submodule is handled for now"; + String symbol = it.GetFunction("get_symbol")(); + Array variables = it.GetFunction("get_const_vars")(); + std::vector arrays; for (size_t i = 0; i < variables.size(); i++) { - var_md.Set(variables[i], metadata[i]); + arrays.push_back(variables[i].operator std::string()); } - - // TODO(zhiics) Invoke the corresponding module create function using the - // type key when json runtime comes. - source_modules.push_back(tvm::codegen::CSourceModuleCreate(code, "c")); - } - Array vars; - Array arrs; - for (const auto& it : var_md) { - vars.push_back(it.first); - arrs.push_back(it.second); + CHECK_EQ(sym_metadata.count(symbol), 0U) << "Found duplicated symbol: " << symbol; + sym_metadata[symbol] = arrays; } // Wrap the modules. - const auto* pf = runtime::Registry::Get("runtime.ModuleInitWrapper"); - CHECK(pf != nullptr) << "Cannot find the registry for runtime.ModuleInitWrapper"; - runtime::Module init_m = (*pf)(vars, arrs); - - for (const auto& it : source_modules) { + runtime::Module init_m = runtime::MetadataModuleCreate(params, sym_metadata); + init_m.Import(dso_module); + for (const auto& it : modules) { init_m.Import(it); } @@ -105,13 +95,22 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) { // Simulator function class CSourceModuleNode : public runtime::ModuleNode { public: - CSourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} + CSourceModuleNode(const std::string& code, const std::string& fmt, const std::string& symbol, + const Array& const_vars) + : code_(code), fmt_(fmt), symbol_(symbol), const_vars_(const_vars) {} const char* type_key() const { return "c"; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - LOG(FATAL) << "C Source module cannot execute, to get executable module" - << " build TVM with \'" << fmt_ << "\' runtime support"; - return PackedFunc(); + if (name == "get_symbol") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_; }); + } else if (name == "get_const_vars") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_vars_; }); + } else { + LOG(FATAL) << "Unknown packed function: " << name; + return PackedFunc(nullptr); + } } std::string GetSource(const std::string& format) final { return code_; } @@ -130,10 +129,14 @@ class CSourceModuleNode : public runtime::ModuleNode { protected: std::string code_; std::string fmt_; + std::string symbol_; + Array const_vars_; }; -runtime::Module CSourceModuleCreate(std::string code, std::string fmt) { - auto n = make_object(code, fmt); +runtime::Module CSourceModuleCreate(const String& code, const String& fmt, const String& symbol, + const Array& const_vars) { + auto n = make_object(code.operator std::string(), fmt.operator std::string(), + symbol.operator std::string(), const_vars); return runtime::Module(n); } @@ -190,70 +193,11 @@ runtime::Module DeviceSourceModuleCreate( return runtime::Module(n); } -// Pack the source code and metadata, where source code could be any -// user-defined code, i.e. c source code, json graph representation, etc. -class SourceMetadataModuleNode final : public runtime::ModuleNode { - public: - SourceMetadataModuleNode(const String& func_symbol, const String& code, const String& source_type, - const Array& variables, const Array& metadata) - : func_symbol_(func_symbol), - code_(code), - source_type_(source_type), - variables_(variables), - metadata_(metadata) {} - - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - if (name == "get_source") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->code_; }); - } else if (name == "get_source_type") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->source_type_; }); - } else if (name == "get_symbol") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_symbol_; }); - } else if (name == "get_vars") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->variables_; }); - } else if (name == "get_metadata") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->metadata_; }); - } else { - LOG(FATAL) << "Unknown packed function: " << name; - return PackedFunc(nullptr); - } - } - - const char* type_key() const { return "source_metadata"; } - - private: - /*! \brief The function symbols. */ - String func_symbol_; - /*! \brief The source code. */ - String code_; - /*! \brief The type of the source code, e.g. c or any customized json type. */ - String source_type_; - /*! \brief The list of constant variables. */ - Array variables_; - /*! \brief The list of constant values that are corresponding to the variables. */ - Array metadata_; -}; - -runtime::Module SourceMetadataModuleCreate(String func_symbol, String code, String source_type, - Array variables, - Array metadata) { - auto n = - make_object(func_symbol, code, source_type, variables, metadata); - return runtime::Module(n); -} - -TVM_REGISTER_GLOBAL("runtime.SourceMetadataModuleCreate") - .set_body_typed(SourceMetadataModuleCreate); - TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") - .set_body_typed([](String code, String source_type) { - return CSourceModuleCreate(code.operator std::string(), source_type.operator std::string()); + .set_body_typed([](String code, String fmt, String symbol, Array const_vars) { + return CSourceModuleCreate(code, fmt, symbol, const_vars); }); } // namespace codegen diff --git a/tests/python/relay/test_external_runtime.py b/tests/python/relay/test_external_runtime.py index 39209232f3d0..7928e4d61b37 100644 --- a/tests/python/relay/test_external_runtime.py +++ b/tests/python/relay/test_external_runtime.py @@ -109,7 +109,8 @@ def generate_csource_module(): TVM_DLL_EXPORT_TYPED_FUNC(json_rt_0, ccompiler_wrapper_0_); ''' - csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc") + csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "", + None) return csource_module @@ -175,7 +176,8 @@ def generate_engine_module(): ''' gen_json_engine() - csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc") + csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "", + None) return csource_module diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py index 8473a67e6e41..8ee197d643ac 100644 --- a/tests/python/unittest/test_runtime_module_export.py +++ b/tests/python/unittest/test_runtime_module_export.py @@ -54,7 +54,8 @@ def generate_engine_module(): ''' import tvm.runtime._ffi_api gen_engine_header() - csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc") + csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "", + None) return csource_module From b1b97948146373c35417bab20fd1dbafaff39da6 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 16 Jun 2020 15:49:15 +0000 Subject: [PATCH 11/13] add comments --- src/relay/backend/build_module.cc | 4 +++- src/relay/backend/vm/compiler.cc | 2 +- src/runtime/metadata_module.cc | 4 ++-- src/target/source/codegen_source_base.h | 6 +++--- src/target/source/source_module.cc | 17 ++++++++++++++--- 5 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 01cf6057d474..b3cbde4a0b48 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -453,8 +453,10 @@ class RelayBuildModule : public runtime::ModuleNode { } Array ext_mods = graph_codegen_->GetExternalModules(); + // TODO(zhiics) We should be able to completely switch to MetadataModule no + // matter whether there are external modules or not. if (!ext_mods.empty()) { - ret_.mod = tvm::codegen::WrapMetadataModule(ret_.params, ret_.mod, ext_mods); + ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods); } } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 75e2cd5a40e6..0af19491b298 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1023,7 +1023,7 @@ void VMCompiler::Codegen() { exec_->lib = codegen::CSourceModuleCreate(";", ""); } if (!ext_mods.empty()) { - exec_->lib = codegen::WrapMetadataModule(params_, exec_->lib, ext_mods); + exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods); } } diff --git a/src/runtime/metadata_module.cc b/src/runtime/metadata_module.cc index fa03e02a66db..1cb118d877f3 100644 --- a/src/runtime/metadata_module.cc +++ b/src/runtime/metadata_module.cc @@ -70,7 +70,7 @@ class MetadataModuleNode : public ModuleNode { return PackedFunc(nullptr); } - const char* type_key() const { return "module_init"; } + const char* type_key() const { return "metadata"; } /*! * \brief Get the list of metadata that is required by the given module. @@ -214,7 +214,7 @@ Module MetadataModuleCreate( return Module(n); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_module_init") +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata") .set_body_typed(MetadataModuleNode::LoadFromBinary); } // namespace runtime } // namespace tvm diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index f59d4dfefd97..7e5e40324c47 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -152,9 +152,9 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, * \param modules The modules to be wrapped. * \return The wrapped module. */ -runtime::Module WrapMetadataModule(const std::unordered_map& params, - const runtime::Module& dso_module, - const Array& modules); +runtime::Module CreateMetadataModule( + const std::unordered_map& params, + const runtime::Module& dso_module, const Array& modules); /*! * \brief Create a source module for viewing and limited saving for device. diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 3d195ec3651c..1e201e50f0ea 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -41,9 +41,20 @@ using runtime::GetFileFormat; using runtime::GetMetaFilePath; using runtime::SaveBinaryToFile; -runtime::Module WrapMetadataModule(const std::unordered_map& params, - const runtime::Module& dso_module, - const Array& modules) { +/*! + * \brief Create a metadata module wrapper. The helper is used by different + * codegens, such as graph runtime codegen and the vm compiler. + * + * \param params The metadata for initialization of all modules. + * \param dso_module The DSO module that contains TVM primitives. + * \param modules The submodules that will be wrapped, e.g. CSource modules that + * contain vendor library calls or customized runtime modules. + * + * \return The created metadata module that manages initialization of metadata. + */ +runtime::Module CreateMetadataModule( + const std::unordered_map& params, + const runtime::Module& dso_module, const Array& modules) { // Wrap all submodules in the initialization wrapper. std::unordered_map> sym_metadata; for (runtime::Module it : modules) { From 6be3fe8d29aeb15c1ade40682f7527495ac6ab92 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 17 Jun 2020 18:27:29 +0000 Subject: [PATCH 12/13] fix comments --- src/relay/backend/compile_engine.cc | 1 - .../backend/contrib/codegen_c/codegen.cc | 22 +++--- .../backend/contrib/codegen_c/codegen_c.h | 71 ++++++++++++++++--- src/relay/backend/contrib/dnnl/codegen.cc | 20 +++--- src/runtime/metadata_module.cc | 14 ++-- 5 files changed, 88 insertions(+), 40 deletions(-) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index ecd7fb9e8f62..11fe61e4e697 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -38,7 +38,6 @@ #include #include -#include #include #include #include diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index a0118e3da544..c7b5a8da1fed 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -81,27 +81,25 @@ class CodegenC : public MemoizedExprTranslator>, public Code std::ostringstream buf_stream; Output output; - // Get const: static_cast(dnnl_0_consts[0]->data) - output.name = "static_cast(" + ext_func_id_ + "_consts[" + std::to_string(const_idx_) + - "]->data)"; + // Get const: static_cast(gcc_0_consts[0]->data) + output.name = CreateDataReference(ext_func_id_, const_idx_); const auto* type_node = cn->checked_type().as(); CHECK(type_node); const auto& dtype = GetDtypeString(type_node); // Generate the global variable for needed ndarrays - if (const_array_.empty()) { - const_array_ = "Array " + ext_func_id_ + "_consts;"; - std::ostringstream buf_stream; - buf_stream << "CHECK(!" << ext_func_id_ - << "_consts.empty()) << \"C source module hasn't been initialized.\";\n"; - ext_func_body_.insert(ext_func_body_.begin(), buf_stream.str()); + if (const_array_name_.empty()) { + const_array_name_ = CreateNDArrayPool(ext_func_id_); + std::string checker = CreateInitChecker(ext_func_id_); + ext_func_body_.insert(ext_func_body_.begin(), checker); } CHECK(dtype == "float" || dtype == "int") << "Only float and int are supported for now."; output.dtype = dtype; - std::string const_var_name = ext_func_id_ + "_const_" + std::to_string(const_idx_++); + std::string const_var_name = CreateConstVar(ext_func_id_, const_idx_); const_vars_.push_back(const_var_name); + const_idx_++; return {output}; } @@ -187,7 +185,7 @@ class CodegenC : public MemoizedExprTranslator>, public Code for (auto decl : func_decl_) { code_stream_ << decl << "\n"; } - return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_, out); + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out); } private: @@ -204,7 +202,7 @@ class CodegenC : public MemoizedExprTranslator>, public Code /*! \brief The statements of a C compiler compatible function. */ std::vector ext_func_body_; /*! \brief The array declared to store the constant values. */ - std::string const_array_; + std::string const_array_name_; /*! \brief The declaration statements of a C compiler compatible function. */ std::vector func_decl_; /*! \brief The declaration statements of buffers. */ diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index f4512f38e858..32ab15058989 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -113,7 +113,7 @@ class CodegenCBase { * Array foo_consts; * * // An example code for the generated C function. - * extern "C" void foo_wrapper_(DLTensor* arg0, + * extern "C" int foo_wrapper_(DLTensor* arg0, * DLTensor* arg1, * DLTensor* out) { * foo_(static_cast(arg0->data), @@ -124,8 +124,9 @@ class CodegenCBase { * * TVM_DLL_EXPORT_TYPED_FUNC(foo, foo_wrapper_); * - * void foo_init_wrapper_(Array arr) { + * int foo_init_wrapper_(Array arr) { * foo_consts = arr; + * return 0; * } * * TVM_DLL_EXPORT_TYPED_FUNC(__init_foo, foo_init_wrapper_); @@ -133,7 +134,7 @@ class CodegenCBase { * \endcode */ void GenerateBackendCFunc(const std::string& func_name, const Array& args, - const std::string& const_arr, const std::vector& outs) { + const std::string& const_arr_name, const std::vector& outs) { // Print signature code_stream_ << "\n"; code_stream_ << "extern \"C\" int " << func_name << "_wrapper_("; @@ -172,13 +173,14 @@ class CodegenCBase { code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", " << func_name << "_wrapper_);\n\n"; - if (!const_arr.empty()) { - code_stream_ << "void " << func_name << "_init_wrapper_(Array arr) {\n"; + if (!const_arr_name.empty()) { + code_stream_ << "int " << func_name << "_init_wrapper_(Array arr) {\n"; EnterScope(); PrintIndents(); code_stream_ << func_name << "_consts = arr;\n"; - code_stream_ << "}\n\n"; + code_stream_ << "return 0;\n"; ExitScope(); + code_stream_ << "}\n\n"; code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(__init_" << func_name << ", " << func_name << "_init_wrapper_);\n\n"; } @@ -209,11 +211,11 @@ class CodegenCBase { */ std::string JitImpl(const std::string& ext_func_id, const Array& args, const std::vector& buf_decl, - const std::vector& body, const std::string& const_arr, + const std::vector& body, const std::string& const_arr_name, const std::vector& outs) { // Create a declaration for global ndarrays that contain constant data. - if (!const_arr.empty()) { - code_stream_ << const_arr << "\n\n"; + if (!const_arr_name.empty()) { + code_stream_ << const_arr_name << "\n\n"; } // Create the signature. For example, it could be: // extern "C" void dnnl_0_(float* in0, float* in1, float* out0, float* out1) {} @@ -260,7 +262,7 @@ class CodegenCBase { code_stream_ << "}\n"; // Create the wrapper to call the ext_func - this->GenerateBackendCFunc(ext_func_id, args, const_arr, outs); + this->GenerateBackendCFunc(ext_func_id, args, const_arr_name, outs); return code_stream_.str(); } @@ -299,6 +301,55 @@ class CodegenCBase { return dtype; } + /*! + * \brief Creates a checker to check if the NDArray pool is initialized + * + * \param symobl The Symbol of the current function + * + * \return The created checker + */ + std::string CreateInitChecker(const std::string& symbol) const { + std::ostringstream oss; + oss << "CHECK(!" << symbol + << "_consts.empty()) << \"C source module hasn't been initialized.\";\n"; + return oss.str(); + } + + /*! + * \brief Generates the global ndarray pool declaration + * + * \param symobl The Symbol of the current function + * + * \return The created declaration + */ + std::string CreateNDArrayPool(const std::string& symbol) const { + return "Array " + symbol + "_consts;"; + } + + /*! + * \brief Generates the reference to the data of a constant ndarray + * + * \param symobl The Symbol of the current function + * \param symobl const_id The index of the constant + * + * \return The created reference + */ + std::string CreateDataReference(const std::string& symbol, int const_id) const { + return "static_cast(" + symbol + "_consts[" + std::to_string(const_id) + "]->data)"; + } + + /*! + * \brief Returns the variable name for a constant variable + * + * \param symobl The Symbol of the current function + * \param symobl const_id The index of the constant + * + * \return The created variable name + */ + std::string CreateConstVar(const std::string& symbol, int const_id) const { + return symbol + "_const_" + std::to_string(const_id++); + } + /*! \brief The external function source code stream. */ std::ostringstream code_stream_; diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 4d9477ad9f7d..60138ae99b3e 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -166,23 +166,21 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C std::vector VisitExpr_(const ConstantNode* cn) final { Output output; // Get const: static_cast(dnnl_0_consts[0]->data) - output.name = "static_cast(" + ext_func_id_ + "_consts[" + std::to_string(const_idx_) + - "]->data)"; + output.name = CreateDataReference(ext_func_id_, const_idx_); output.dtype = "float"; // Generate the global variable for needed ndarrays - if (const_array_.empty()) { - const_array_ = "Array " + ext_func_id_ + "_consts;"; - std::ostringstream buf_stream; - buf_stream << "CHECK(!" << ext_func_id_ - << "_consts.empty()) << \"DNNL source module hasn't been initialized.\";\n"; - ext_func_body_.insert(ext_func_body_.begin(), buf_stream.str()); + if (const_array_name_.empty()) { + const_array_name_ = CreateNDArrayPool(ext_func_id_); + std::string checker = CreateInitChecker(ext_func_id_); + ext_func_body_.insert(ext_func_body_.begin(), checker); } // Give the ndarray a unique name to ease the initialization of it at // runtime. - std::string const_var_name = ext_func_id_ + "_const_" + std::to_string(const_idx_++); + std::string const_var_name = CreateConstVar(ext_func_id_, const_idx_); const_vars_.push_back(const_var_name); + const_idx_++; const auto* type_node = cn->checked_type().as(); CHECK(type_node); @@ -205,7 +203,7 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C } std::string JIT(const std::vector& out) { - return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_, out); + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out); } private: @@ -340,7 +338,7 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C /*! \brief Statement of the function that will be compiled using DNNL kernels. */ std::vector ext_func_body_; /*! \brief The array declared to store the constant values. */ - std::string const_array_; + std::string const_array_name_; /*! \brief The declaration of intermeidate buffers. */ std::vector buf_decl_; /*! \brief The variable name to constant mapping. */ diff --git a/src/runtime/metadata_module.cc b/src/runtime/metadata_module.cc index 1cb118d877f3..cf3d5476d56f 100644 --- a/src/runtime/metadata_module.cc +++ b/src/runtime/metadata_module.cc @@ -56,7 +56,7 @@ class MetadataModuleNode : public ModuleNode { // done at this stage. Therefore, runtime overhead is not a concern. if (initialized_.count(name) == 0) { this->InitSubModule(name); - initialized_[name] = true; + initialized_.emplace(name); } // Run the module. @@ -79,7 +79,7 @@ class MetadataModuleNode : public ModuleNode { */ Array GetRequiredMetadata(const std::string& symbol) { Array ret; - CHECK_GT(sym_vars_.count(symbol), 0U) << "Not symbol is recorded for " << symbol; + CHECK_GT(sym_vars_.count(symbol), 0U) << "No symbol is recorded for " << symbol; std::vector vars = sym_vars_[symbol]; for (const auto& it : vars) { CHECK_GT(metadata_.count(it), 0U) << "Found not recorded constant variable: " << it; @@ -110,7 +110,9 @@ class MetadataModuleNode : public ModuleNode { if (init != nullptr) { auto md = GetRequiredMetadata(symbol); // Initialize the module with metadata. - init(md); + int ret = init(md); + // Report the error if initialization is failed. + CHECK_EQ(ret, 0) << TVMGetLastError(); break; } } @@ -157,7 +159,7 @@ class MetadataModuleNode : public ModuleNode { std::vector variables; CHECK(stream->Read(&variables)) << "Loading variables failed"; uint64_t sz; - CHECK(stream->Read(&sz, sizeof(sz))) << "Loading medata size failed"; + CHECK(stream->Read(&sz, sizeof(sz))) << "Loading metadata size failed"; CHECK_EQ(static_cast(sz), variables.size()) << "The number of variables and ndarray counts must match"; // Load the list of ndarray. @@ -169,7 +171,7 @@ class MetadataModuleNode : public ModuleNode { } std::unordered_map metadata; - for (size_t i = 0; i < variables.size(); i++) { + for (uint64_t i = 0; i < sz; i++) { CHECK_EQ(metadata.count(variables[i]), 0U); metadata[variables[i]] = arrays[i]; } @@ -200,7 +202,7 @@ class MetadataModuleNode : public ModuleNode { * \brief Record if a module is initialized. It is needed by imported * modules using execution engine. */ - std::unordered_map initialized_; + std::unordered_set initialized_; /*! \brief Variable name to NDArray mapping. */ std::unordered_map metadata_; /*! \brief Symbol name to required constant variables mapping. */ From 7397296759197f898f1b92f6c94aca5710a1d419 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 17 Jun 2020 22:29:52 +0000 Subject: [PATCH 13/13] don't trap on missing index --- python/tvm/contrib/graph_runtime.py | 7 ++++--- src/runtime/graph/graph_runtime.cc | 9 +++++---- tests/python/frontend/onnx/test_forward.py | 5 ++--- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 46077278e473..9b714a84b541 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -163,9 +163,10 @@ def set_input(self, key=None, value=None, **params): keys.sort(key=lambda x: -np.prod(params[x].shape)) for k in keys: # TODO(zhiics) Skip the weights for submodule in a better way. - # We could get all inputs required by graphruntime first, - # we should use MetadataModule for initialization. - if "_const_" not in k: + # We should use MetadataModule for initialization and remove + # params from set_input + val = self._get_input(k) + if val: self._get_input(k).copyfrom(params[k]) def run(self, **input_dict): diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 59bfb68f039b..146c0975cb5a 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -198,7 +198,7 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { CHECK(size == names.size()) << "Invalid parameters file format"; for (size_t i = 0; i < size; ++i) { int in_idx = GetInputIndex(names[i]); - CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i]; + if (in_idx < 0) continue; uint32_t eid = this->entry_id(input_nodes_[in_idx], 0); CHECK_LT(eid, data_entry_.size()); @@ -222,7 +222,7 @@ void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) { CHECK(size == names.size()) << "Invalid parameters file format"; for (size_t i = 0; i < size; ++i) { int in_idx = GetInputIndex(names[i]); - CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i]; + if (in_idx < 0) continue; uint32_t eid = this->entry_id(input_nodes_[in_idx], 0); CHECK_LT(eid, data_entry_.size()); CHECK_EQ(data_entry_[eid].use_count(), 1); @@ -422,8 +422,9 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name, } else { in_idx = args[0]; } - CHECK_GE(in_idx, 0); - *rv = this->GetInput(in_idx); + if (in_idx >= 0) { + *rv = this->GetInput(in_idx); + } }); } else if (name == "get_num_outputs") { return PackedFunc( diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 178f059e2635..f033a4b7e6ff 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -78,11 +78,10 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output # Its possible for some onnx inputs to not be needed in the tvm # module, confirm its present before setting. try: - m.get_input(input_names[i]) + m.set_input(input_names[i], tvm.nd.array( + input_data[i].astype(input_data[i].dtype))) except: continue - m.set_input(input_names[i], tvm.nd.array( - input_data[i].astype(input_data[i].dtype))) else: m.set_input(input_names, tvm.nd.array( input_data.astype(input_data.dtype)))