diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index c439bde82497..81eb30ee12d2 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -37,62 +37,102 @@ namespace tvm { namespace runtime { -// Dynamic shared libary. -// This is the default module TVM used for host-side AOT +/*! + * \brief Dynamic shared library object used to load + * and retrieve symbols by name. This is the default + * module TVM uses for host-side AOT compilation. + */ class DSOLibrary final : public Library { public: - ~DSOLibrary() { - if (lib_handle_) Unload(); - } - void Init(const std::string& name) { Load(name); } - - void* GetSymbol(const char* name) final { return GetSymbol_(name); } + ~DSOLibrary(); + /*! + * \brief Initialize by loading and storing + * a handle to the underlying shared library. + * \param name The string name/path to the + * shared library over which to initialize. + */ + void Init(const std::string& name); + /*! + * \brief Returns the symbol address within + * the shared library for a given symbol name. + * \param name The name of the symbol. + * \return The symbol. + */ + void* GetSymbol(const char* name) final; private: - // Platform dependent handling. + /*! \brief Private implementation of symbol lookup. + * Implementation is operating system dependent. + * \param The name of the symbol. + * \return The symbol. + */ + void* GetSymbol_(const char* name); + /*! \brief Implementation of shared library load. + * Implementation is operating system dependent. + * \param The name/path of the shared library. + */ + void Load(const std::string& name); + /*! \brief Implementation of shared library unload. + * Implementation is operating system dependent. + */ + void Unload(); + #if defined(_WIN32) - // library handle + //! \brief Windows library handle HMODULE lib_handle_{nullptr}; - - void* GetSymbol_(const char* name) { - return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) - } - - // Load the library - void Load(const std::string& name) { - // use wstring version that is needed by LLVM. - std::wstring wname(name.begin(), name.end()); - lib_handle_ = LoadLibraryW(wname.c_str()); - ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; - } - - void Unload() { - FreeLibrary(lib_handle_); - lib_handle_ = nullptr; - } #else - // Library handle + // \brief Linux library handle void* lib_handle_{nullptr}; - // load the library - void Load(const std::string& name) { - lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - ICHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name << " " << dlerror(); - } - - void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } - - void Unload() { - dlclose(lib_handle_); - lib_handle_ = nullptr; - } #endif }; -TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { +DSOLibrary::~DSOLibrary() { + if (lib_handle_) Unload(); +} + +void DSOLibrary::Init(const std::string& name) { Load(name); } + +void* DSOLibrary::GetSymbol(const char* name) { return GetSymbol_(name); } + +#if defined(_WIN32) + +void* DSOLibrary::GetSymbol_(const char* name) { + return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) +} + +void DSOLibrary::Load(const std::string& name) { + // use wstring version that is needed by LLVM. + std::wstring wname(name.begin(), name.end()); + lib_handle_ = LoadLibraryW(wname.c_str()); + ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; +} + +void DSOLibrary::Unload() { + FreeLibrary(lib_handle_); + lib_handle_ = nullptr; +} + +#else + +void DSOLibrary::Load(const std::string& name) { + lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); + ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " " + << dlerror(); +} + +void* DSOLibrary::GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } + +void DSOLibrary::Unload() { + dlclose(lib_handle_); + lib_handle_ = nullptr; +} + +#endif + +ObjectPtr CreateDSOLibraryObject(std::string library_path) { auto n = make_object(); - n->Init(args[0]); - *rv = CreateModuleFromLibrary(n); -}); + n->Init(library_path); + return n; +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 5dfd5e8ad7d5..7efa91d912eb 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -37,7 +37,8 @@ namespace runtime { // Library module that exposes symbols from a library. class LibraryModuleNode final : public ModuleNode { public: - explicit LibraryModuleNode(ObjectPtr lib) : lib_(lib) {} + explicit LibraryModuleNode(ObjectPtr lib, PackedFuncWrapper wrapper) + : lib_(lib), packed_func_wrapper_(wrapper) {} const char* type_key() const final { return "library"; } @@ -53,11 +54,12 @@ class LibraryModuleNode final : public ModuleNode { faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); } if (faddr == nullptr) return PackedFunc(); - return WrapPackedFunc(faddr, sptr_to_self); + return packed_func_wrapper_(faddr, sptr_to_self); } private: ObjectPtr lib_; + PackedFuncWrapper packed_func_wrapper_; }; /*! @@ -128,7 +130,8 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { * \param root_module the output root module * \param dso_ctx_addr the output dso module */ -void ProcessModuleBlob(const char* mblob, ObjectPtr lib, runtime::Module* root_module, +void ProcessModuleBlob(const char* mblob, ObjectPtr lib, + PackedFuncWrapper packed_func_wrapper, runtime::Module* root_module, runtime::ModuleNode** dso_ctx_addr = nullptr) { ICHECK(mblob != nullptr); uint64_t nbytes = 0; @@ -152,7 +155,7 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, runtime::Modul // "_lib" serves as a placeholder in the module import tree to indicate where // to place the DSOModule if (tkey == "_lib") { - auto dso_module = Module(make_object(lib)); + auto dso_module = Module(make_object(lib, packed_func_wrapper)); *dso_ctx_addr = dso_module.operator->(); ++num_dso_module; modules.emplace_back(dso_module); @@ -170,7 +173,7 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, runtime::Modul // if we are using old dll, we don't have import tree // so that we can't reconstruct module relationship using import tree if (import_tree_row_ptr.empty()) { - auto n = make_object(lib); + auto n = make_object(lib, packed_func_wrapper); auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->()); for (const auto& m : modules) { module_import_addr->emplace_back(m); @@ -194,9 +197,9 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, runtime::Modul } } -Module CreateModuleFromLibrary(ObjectPtr lib) { +Module CreateModuleFromLibrary(ObjectPtr lib, PackedFuncWrapper packed_func_wrapper) { InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); - auto n = make_object(lib); + auto n = make_object(lib, packed_func_wrapper); // Load the imported modules const char* dev_mblob = reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); @@ -204,7 +207,7 @@ Module CreateModuleFromLibrary(ObjectPtr lib) { Module root_mod; runtime::ModuleNode* dso_ctx_addr = nullptr; if (dev_mblob != nullptr) { - ProcessModuleBlob(dev_mblob, lib, &root_mod, &dso_ctx_addr); + ProcessModuleBlob(dev_mblob, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr); } else { // Only have one single DSO Module root_mod = Module(n); @@ -218,5 +221,10 @@ Module CreateModuleFromLibrary(ObjectPtr lib) { return root_mod; } + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { + ObjectPtr n = CreateDSOLibraryObject(args[0]); + *rv = CreateModuleFromLibrary(n); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 00c79e8248f4..b5780975f43a 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -78,16 +78,35 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& */ void InitContextFunctions(std::function fgetsymbol); +/*! + * \brief Type alias for funcion to wrap a TVMBackendPackedCFunc. + * \param The function address imported from a module. + * \param mptr The module pointer node. + * \return Packed function that wraps the invocation of the function at faddr. + */ +using PackedFuncWrapper = + std::function& mptr)>; + +/*! \brief Return a library object interface over dynamic shared + * libraries in Windows and Linux providing support for + * loading/unloading and symbol lookup. + * \param Full path to shared library. + * \return Returns pointer to the Library providing symbol lookup. + */ +ObjectPtr CreateDSOLibraryObject(std::string library_path); + /*! * \brief Create a module from a library. * * \param lib The library. + * \param wrapper Optional function used to wrap a TVMBackendPackedCFunc, + * by default WrapPackedFunc is used. * \return The corresponding loaded module. * * \note This function can create multiple linked modules * by parsing the binary blob section of the library. */ -Module CreateModuleFromLibrary(ObjectPtr lib); +Module CreateModuleFromLibrary(ObjectPtr lib, PackedFuncWrapper wrapper = WrapPackedFunc); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_LIBRARY_MODULE_H_