From 7b9b05efd683ac3c1ecbd38d28e81e1d99e88940 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Wed, 13 Oct 2021 11:52:59 -0700 Subject: [PATCH 1/6] Expose DSOLibrary and add documentation. --- src/runtime/dso_library.cc | 91 ++++++++++++++++--------------------- src/runtime/dso_library.h | 93 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+), 51 deletions(-) create mode 100644 src/runtime/dso_library.h diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index c439bde82497..289e77df27e2 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -21,73 +21,62 @@ * \file dso_libary.cc * \brief Create library module to load from dynamic shared library. */ +#include "dso_library.h" + #include #include #include #include -#include "library_module.h" - -#if defined(_WIN32) -#include -#else +#if !defined(_WIN32) #include #endif namespace tvm { namespace runtime { -// Dynamic shared libary. -// This is the default module TVM used for host-side AOT -class DSOLibrary final : public Library { - public: - ~DSOLibrary() { - if (lib_handle_) Unload(); - } - void Init(const std::string& name) { Load(name); } +DSOLibrary::~DSOLibrary() { + if (lib_handle_) Unload(); +} - void* GetSymbol(const char* name) final { return GetSymbol_(name); } +void DSOLibrary::Init(const std::string& name) { Load(name); } + +void* DSOLibrary::GetSymbol(const char* name) { return GetSymbol_(name); } - private: - // Platform dependent handling. #if defined(_WIN32) - // library handle - HMODULE lib_handle_{nullptr}; - - void* GetSymbol_(const char* name) { - return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) - } - - // Load the library - void Load(const std::string& name) { - // use wstring version that is needed by LLVM. - std::wstring wname(name.begin(), name.end()); - lib_handle_ = LoadLibraryW(wname.c_str()); - ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; - } - - void Unload() { - FreeLibrary(lib_handle_); - lib_handle_ = nullptr; - } + +void* DSOLibrary::GetSymbol_(const char* name) { + return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) +} + +void DSOLibrary::Load(const std::string& name) { + // use wstring version that is needed by LLVM. + std::wstring wname(name.begin(), name.end()); + lib_handle_ = LoadLibraryW(wname.c_str()); + ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; +} + +void DSOLibrary::Unload() { + FreeLibrary(lib_handle_); + lib_handle_ = nullptr; +} + #else - // Library handle - void* lib_handle_{nullptr}; - // load the library - void Load(const std::string& name) { - lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - ICHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name << " " << dlerror(); - } - - void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } - - void Unload() { - dlclose(lib_handle_); - lib_handle_ = nullptr; - } + +void DSOLibrary::Load(const std::string& name) { + lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); + ICHECK(lib_handle_ != nullptr) + << "Failed to load dynamic shared library " << name << " " << dlerror(); +} + +void* DSOLibrary::GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } + +void DSOLibrary::Unload() { + dlclose(lib_handle_); + lib_handle_ = nullptr; +} + #endif -}; TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { auto n = make_object(); diff --git a/src/runtime/dso_library.h b/src/runtime/dso_library.h new file mode 100644 index 000000000000..f8266ea0e38b --- /dev/null +++ b/src/runtime/dso_library.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dso_module.h + * \brief Abstraction over dynamic shared librariess in + * Windows and Linux providing support for loading/unloading + * and symbol lookup. + */ +#ifndef TVM_RUNTIME_DSO_LIBRARY_H_ +#define TVM_RUNTIME_DSO_LIBRARY_H_ + +#include + +#include + +#if defined(_WIN32) +#include +#endif + +#include "library_module.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief Dynamic shared library object used to load + * and retrieve symbols by name. This is the default + * module TVM uses for host-side AOT compilation. + */ +class DSOLibrary final : public Library { + public: + ~DSOLibrary(); + /*! + * \brief Initialize by loading and storing + * a handle to the underlying shared library. + * \param name The string name/path to the + * shared library over which to initialize. + */ + void Init(const std::string& name); + /*! + * \brief Returns the symbol address within + * the shared library for a given symbol name. + * \param name The name of the symbol. + * \return The symbol. + */ + void* GetSymbol(const char* name) final; + + private: + /*! \brief Private implementation of symbol lookup. + * Implementation is operating system dependent. + * \param The name of the symbol. + * \return The symbol. + */ + void* GetSymbol_(const char* name); + /*! \brief Implementation of shared library load. + * Implementation is operating system dependent. + * \param The name/path of the shared library. + */ + void Load(const std::string& name); + /*! \brief Implementation of shared library unload. + * Implementation is operating system dependent. + */ + void Unload(); + +#if defined(_WIN32) + //! \brief Windows library handle + HMODULE lib_handle_{nullptr}; +#else + // \brief Linux library handle + void* lib_handle_{nullptr}; +#endif +}; +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_DSO_LIBRARY_H_ From b637d25776e960d47de70bc9a007a164c7801173 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Wed, 13 Oct 2021 11:54:35 -0700 Subject: [PATCH 2/6] Allow runtimes to specialize library module function wrapping by providing a PackedFunctionWrapper object at construction. --- src/runtime/library_module.cc | 27 +++++++++++++++++++-------- src/runtime/library_module.h | 20 +++++++++++++++++++- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 5dfd5e8ad7d5..08aea2085c7a 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -37,7 +37,8 @@ namespace runtime { // Library module that exposes symbols from a library. class LibraryModuleNode final : public ModuleNode { public: - explicit LibraryModuleNode(ObjectPtr lib) : lib_(lib) {} + explicit LibraryModuleNode(ObjectPtr lib, ObjectPtr wrapper) + : lib_(lib), packed_func_wrapper_(wrapper) {} const char* type_key() const final { return "library"; } @@ -53,11 +54,12 @@ class LibraryModuleNode final : public ModuleNode { faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); } if (faddr == nullptr) return PackedFunc(); - return WrapPackedFunc(faddr, sptr_to_self); + return packed_func_wrapper_->operator()(faddr, sptr_to_self); } private: ObjectPtr lib_; + ObjectPtr packed_func_wrapper_; }; /*! @@ -69,6 +71,10 @@ class ModuleInternal { static std::vector* GetImportsAddr(ModuleNode* node) { return &(node->imports_); } }; +PackedFunc PackedFuncWrapper::operator()(TVMBackendPackedCFunc faddr, const ObjectPtr& mptr) { + return WrapPackedFunc(faddr, mptr); +} + PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& sptr_to_self) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { TVMValue ret_value; @@ -128,7 +134,9 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { * \param root_module the output root module * \param dso_ctx_addr the output dso module */ -void ProcessModuleBlob(const char* mblob, ObjectPtr lib, runtime::Module* root_module, +void ProcessModuleBlob(const char* mblob, ObjectPtr lib, + ObjectPtr packed_func_wrapper, + runtime::Module* root_module, runtime::ModuleNode** dso_ctx_addr = nullptr) { ICHECK(mblob != nullptr); uint64_t nbytes = 0; @@ -152,7 +160,7 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, runtime::Modul // "_lib" serves as a placeholder in the module import tree to indicate where // to place the DSOModule if (tkey == "_lib") { - auto dso_module = Module(make_object(lib)); + auto dso_module = Module(make_object(lib, packed_func_wrapper)); *dso_ctx_addr = dso_module.operator->(); ++num_dso_module; modules.emplace_back(dso_module); @@ -170,7 +178,7 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, runtime::Modul // if we are using old dll, we don't have import tree // so that we can't reconstruct module relationship using import tree if (import_tree_row_ptr.empty()) { - auto n = make_object(lib); + auto n = make_object(lib, packed_func_wrapper); auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->()); for (const auto& m : modules) { module_import_addr->emplace_back(m); @@ -194,9 +202,12 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, runtime::Modul } } -Module CreateModuleFromLibrary(ObjectPtr lib) { +Module CreateModuleFromLibrary(ObjectPtr lib, ObjectPtr packed_func_wrapper) { InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); - auto n = make_object(lib); + if (packed_func_wrapper == nullptr) { + packed_func_wrapper = make_object(); + } + auto n = make_object(lib, packed_func_wrapper); // Load the imported modules const char* dev_mblob = reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); @@ -204,7 +215,7 @@ Module CreateModuleFromLibrary(ObjectPtr lib) { Module root_mod; runtime::ModuleNode* dso_ctx_addr = nullptr; if (dev_mblob != nullptr) { - ProcessModuleBlob(dev_mblob, lib, &root_mod, &dso_ctx_addr); + ProcessModuleBlob(dev_mblob, lib, packed_func_wrapper, & root_mod, &dso_ctx_addr); } else { // Only have one single DSO Module root_mod = Module(n); diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 00c79e8248f4..8a42cb283148 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -65,6 +65,23 @@ class Library : public Object { // This is because we do not need dynamic type downcasting. }; +/*! +* \brief Default virtual functor that provides an interface to +* wrap a TVMBackendPackedCFunc. Virtual interface allows derivative +* runtime's that utilize a library module to to provide custom +* function wrapping. By default WrapPackedFunc is used. +*/ +class PackedFuncWrapper : public Object { + public: + /* + * \brief Virtual interface for wrapping a library function + * \param faddr The function address. + * \param mptr The module pointer node. + * \return A packed function wrapping the requested function. + */ + virtual PackedFunc operator()(TVMBackendPackedCFunc faddr, const ObjectPtr& mptr); +}; + /*! * \brief Wrap a TVMBackendPackedCFunc to packed function. * \param faddr The function address @@ -87,7 +104,8 @@ void InitContextFunctions(std::function fgetsymbol); * \note This function can create multiple linked modules * by parsing the binary blob section of the library. */ -Module CreateModuleFromLibrary(ObjectPtr lib); +Module CreateModuleFromLibrary(ObjectPtr lib, + ObjectPtr wrapper = nullptr); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_LIBRARY_MODULE_H_ From 5fa09294de8887a8244b91e244c33f83297e3dea Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Thu, 21 Oct 2021 08:42:34 -0700 Subject: [PATCH 3/6] Apply clang formatting. --- src/runtime/dso_library.cc | 4 ++-- src/runtime/library_module.cc | 13 +++++++------ src/runtime/library_module.h | 10 +++++----- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index 289e77df27e2..3225c1e03e8c 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -65,8 +65,8 @@ void DSOLibrary::Unload() { void DSOLibrary::Load(const std::string& name) { lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - ICHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name << " " << dlerror(); + ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " " + << dlerror(); } void* DSOLibrary::GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 08aea2085c7a..e16ca12ffc2a 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -38,7 +38,7 @@ namespace runtime { class LibraryModuleNode final : public ModuleNode { public: explicit LibraryModuleNode(ObjectPtr lib, ObjectPtr wrapper) - : lib_(lib), packed_func_wrapper_(wrapper) {} + : lib_(lib), packed_func_wrapper_(wrapper) {} const char* type_key() const final { return "library"; } @@ -71,7 +71,8 @@ class ModuleInternal { static std::vector* GetImportsAddr(ModuleNode* node) { return &(node->imports_); } }; -PackedFunc PackedFuncWrapper::operator()(TVMBackendPackedCFunc faddr, const ObjectPtr& mptr) { +PackedFunc PackedFuncWrapper::operator()(TVMBackendPackedCFunc faddr, + const ObjectPtr& mptr) { return WrapPackedFunc(faddr, mptr); } @@ -136,8 +137,7 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { */ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, ObjectPtr packed_func_wrapper, - runtime::Module* root_module, - runtime::ModuleNode** dso_ctx_addr = nullptr) { + runtime::Module* root_module, runtime::ModuleNode** dso_ctx_addr = nullptr) { ICHECK(mblob != nullptr); uint64_t nbytes = 0; for (size_t i = 0; i < sizeof(nbytes); ++i) { @@ -202,7 +202,8 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, } } -Module CreateModuleFromLibrary(ObjectPtr lib, ObjectPtr packed_func_wrapper) { +Module CreateModuleFromLibrary(ObjectPtr lib, + ObjectPtr packed_func_wrapper) { InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); if (packed_func_wrapper == nullptr) { packed_func_wrapper = make_object(); @@ -215,7 +216,7 @@ Module CreateModuleFromLibrary(ObjectPtr lib, ObjectPtr Date: Thu, 21 Oct 2021 12:05:06 -0700 Subject: [PATCH 4/6] Use std::function. --- src/runtime/library_module.cc | 21 ++++++--------------- src/runtime/library_module.h | 28 +++++++++------------------- 2 files changed, 15 insertions(+), 34 deletions(-) diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index e16ca12ffc2a..7e185e444c6d 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -37,7 +37,7 @@ namespace runtime { // Library module that exposes symbols from a library. class LibraryModuleNode final : public ModuleNode { public: - explicit LibraryModuleNode(ObjectPtr lib, ObjectPtr wrapper) + explicit LibraryModuleNode(ObjectPtr lib, PackedFuncWrapper wrapper) : lib_(lib), packed_func_wrapper_(wrapper) {} const char* type_key() const final { return "library"; } @@ -54,12 +54,12 @@ class LibraryModuleNode final : public ModuleNode { faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); } if (faddr == nullptr) return PackedFunc(); - return packed_func_wrapper_->operator()(faddr, sptr_to_self); + return packed_func_wrapper_(faddr, sptr_to_self); } private: ObjectPtr lib_; - ObjectPtr packed_func_wrapper_; + PackedFuncWrapper packed_func_wrapper_; }; /*! @@ -71,11 +71,6 @@ class ModuleInternal { static std::vector* GetImportsAddr(ModuleNode* node) { return &(node->imports_); } }; -PackedFunc PackedFuncWrapper::operator()(TVMBackendPackedCFunc faddr, - const ObjectPtr& mptr) { - return WrapPackedFunc(faddr, mptr); -} - PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& sptr_to_self) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { TVMValue ret_value; @@ -136,8 +131,8 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { * \param dso_ctx_addr the output dso module */ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, - ObjectPtr packed_func_wrapper, - runtime::Module* root_module, runtime::ModuleNode** dso_ctx_addr = nullptr) { + PackedFuncWrapper packed_func_wrapper, runtime::Module* root_module, + runtime::ModuleNode** dso_ctx_addr = nullptr) { ICHECK(mblob != nullptr); uint64_t nbytes = 0; for (size_t i = 0; i < sizeof(nbytes); ++i) { @@ -202,12 +197,8 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, } } -Module CreateModuleFromLibrary(ObjectPtr lib, - ObjectPtr packed_func_wrapper) { +Module CreateModuleFromLibrary(ObjectPtr lib, PackedFuncWrapper packed_func_wrapper) { InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); - if (packed_func_wrapper == nullptr) { - packed_func_wrapper = make_object(); - } auto n = make_object(lib, packed_func_wrapper); // Load the imported modules const char* dev_mblob = diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 486b052f96d8..23b9830b3a5f 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -65,23 +65,6 @@ class Library : public Object { // This is because we do not need dynamic type downcasting. }; -/*! - * \brief Default virtual functor that provides an interface to - * wrap a TVMBackendPackedCFunc. Virtual interface allows derivative - * runtime's that utilize a library module to to provide custom - * function wrapping. By default WrapPackedFunc is used. - */ -class PackedFuncWrapper : public Object { - public: - /* - * \brief Virtual interface for wrapping a library function - * \param faddr The function address. - * \param mptr The module pointer node. - * \return A packed function wrapping the requested function. - */ - virtual PackedFunc operator()(TVMBackendPackedCFunc faddr, const ObjectPtr& mptr); -}; - /*! * \brief Wrap a TVMBackendPackedCFunc to packed function. * \param faddr The function address @@ -95,17 +78,24 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& */ void InitContextFunctions(std::function fgetsymbol); +/*! + * \brief Type alias for funcion to wrap a TVMBackendPackedCFunc. + */ +using PackedFuncWrapper = + std::function&)>; + /*! * \brief Create a module from a library. * * \param lib The library. + * \param wrapper Optional function used to wrap a TVMBackendPackedCFunc, + * by default WrapPackedFunc is used. * \return The corresponding loaded module. * * \note This function can create multiple linked modules * by parsing the binary blob section of the library. */ -Module CreateModuleFromLibrary(ObjectPtr lib, - ObjectPtr wrapper = nullptr); +Module CreateModuleFromLibrary(ObjectPtr lib, PackedFuncWrapper wrapper = WrapPackedFunc); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_LIBRARY_MODULE_H_ From 2b897187cf5ba64a030fd58960067f493716169d Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Mon, 25 Oct 2021 16:05:41 -0700 Subject: [PATCH 5/6] Minimize DSOLibrary interface. --- src/runtime/dso_library.cc | 65 +++++++++++++++++++++--- src/runtime/dso_library.h | 93 ----------------------------------- src/runtime/library_module.cc | 5 ++ src/runtime/library_module.h | 8 +++ 4 files changed, 71 insertions(+), 100 deletions(-) delete mode 100644 src/runtime/dso_library.h diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index 3225c1e03e8c..81eb30ee12d2 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -21,20 +21,71 @@ * \file dso_libary.cc * \brief Create library module to load from dynamic shared library. */ -#include "dso_library.h" - #include #include #include #include -#if !defined(_WIN32) +#include "library_module.h" + +#if defined(_WIN32) +#include +#else #include #endif namespace tvm { namespace runtime { +/*! + * \brief Dynamic shared library object used to load + * and retrieve symbols by name. This is the default + * module TVM uses for host-side AOT compilation. + */ +class DSOLibrary final : public Library { + public: + ~DSOLibrary(); + /*! + * \brief Initialize by loading and storing + * a handle to the underlying shared library. + * \param name The string name/path to the + * shared library over which to initialize. + */ + void Init(const std::string& name); + /*! + * \brief Returns the symbol address within + * the shared library for a given symbol name. + * \param name The name of the symbol. + * \return The symbol. + */ + void* GetSymbol(const char* name) final; + + private: + /*! \brief Private implementation of symbol lookup. + * Implementation is operating system dependent. + * \param The name of the symbol. + * \return The symbol. + */ + void* GetSymbol_(const char* name); + /*! \brief Implementation of shared library load. + * Implementation is operating system dependent. + * \param The name/path of the shared library. + */ + void Load(const std::string& name); + /*! \brief Implementation of shared library unload. + * Implementation is operating system dependent. + */ + void Unload(); + +#if defined(_WIN32) + //! \brief Windows library handle + HMODULE lib_handle_{nullptr}; +#else + // \brief Linux library handle + void* lib_handle_{nullptr}; +#endif +}; + DSOLibrary::~DSOLibrary() { if (lib_handle_) Unload(); } @@ -78,10 +129,10 @@ void DSOLibrary::Unload() { #endif -TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { +ObjectPtr CreateDSOLibraryObject(std::string library_path) { auto n = make_object(); - n->Init(args[0]); - *rv = CreateModuleFromLibrary(n); -}); + n->Init(library_path); + return n; +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/dso_library.h b/src/runtime/dso_library.h deleted file mode 100644 index f8266ea0e38b..000000000000 --- a/src/runtime/dso_library.h +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file dso_module.h - * \brief Abstraction over dynamic shared librariess in - * Windows and Linux providing support for loading/unloading - * and symbol lookup. - */ -#ifndef TVM_RUNTIME_DSO_LIBRARY_H_ -#define TVM_RUNTIME_DSO_LIBRARY_H_ - -#include - -#include - -#if defined(_WIN32) -#include -#endif - -#include "library_module.h" - -namespace tvm { -namespace runtime { - -/*! - * \brief Dynamic shared library object used to load - * and retrieve symbols by name. This is the default - * module TVM uses for host-side AOT compilation. - */ -class DSOLibrary final : public Library { - public: - ~DSOLibrary(); - /*! - * \brief Initialize by loading and storing - * a handle to the underlying shared library. - * \param name The string name/path to the - * shared library over which to initialize. - */ - void Init(const std::string& name); - /*! - * \brief Returns the symbol address within - * the shared library for a given symbol name. - * \param name The name of the symbol. - * \return The symbol. - */ - void* GetSymbol(const char* name) final; - - private: - /*! \brief Private implementation of symbol lookup. - * Implementation is operating system dependent. - * \param The name of the symbol. - * \return The symbol. - */ - void* GetSymbol_(const char* name); - /*! \brief Implementation of shared library load. - * Implementation is operating system dependent. - * \param The name/path of the shared library. - */ - void Load(const std::string& name); - /*! \brief Implementation of shared library unload. - * Implementation is operating system dependent. - */ - void Unload(); - -#if defined(_WIN32) - //! \brief Windows library handle - HMODULE lib_handle_{nullptr}; -#else - // \brief Linux library handle - void* lib_handle_{nullptr}; -#endif -}; -} // namespace runtime -} // namespace tvm - -#endif // TVM_RUNTIME_DSO_LIBRARY_H_ diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 7e185e444c6d..7efa91d912eb 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -221,5 +221,10 @@ Module CreateModuleFromLibrary(ObjectPtr lib, PackedFuncWrapper packed_ return root_mod; } + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { + ObjectPtr n = CreateDSOLibraryObject(args[0]); + *rv = CreateModuleFromLibrary(n); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 23b9830b3a5f..81f222027d9c 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -84,6 +84,14 @@ void InitContextFunctions(std::function fgetsymbol); using PackedFuncWrapper = std::function&)>; +/*! \brief Return a library object interface over dynamic shared + * libraries in Windows and Linux providing support for + * loading/unloading and symbol lookup. + * \param Full path to shared library. + * \return Returns pointer to the Library providing symbol lookup. + */ +ObjectPtr CreateDSOLibraryObject(std::string library_path); + /*! * \brief Create a module from a library. * From 51ff91738f2b4694b2d8285eeeefb139e4cf6c6a Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Mon, 25 Oct 2021 17:07:32 -0700 Subject: [PATCH 6/6] Add param and return documentation to PackedFuncWrapper type alias. --- src/runtime/library_module.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 81f222027d9c..b5780975f43a 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -80,9 +80,12 @@ void InitContextFunctions(std::function fgetsymbol); /*! * \brief Type alias for funcion to wrap a TVMBackendPackedCFunc. + * \param The function address imported from a module. + * \param mptr The module pointer node. + * \return Packed function that wraps the invocation of the function at faddr. */ using PackedFuncWrapper = - std::function&)>; + std::function& mptr)>; /*! \brief Return a library object interface over dynamic shared * libraries in Windows and Linux providing support for