diff --git a/mllm/CMakeLists.txt b/mllm/CMakeLists.txt index 05e670275..f28afdce2 100644 --- a/mllm/CMakeLists.txt +++ b/mllm/CMakeLists.txt @@ -15,6 +15,7 @@ add_library( mllm.cpp backends/base/Backend.cpp backends/base/Allocator.cpp + backends/base/PluginSystem.cpp ${MLLM_RT_CORE_SRC} ${MLLM_RT_UTILS_SRC} ${MLLM_RT_ENGINE_SRC} diff --git a/mllm/backends/base/PluginInterface.cpp b/mllm/backends/base/PluginInterface.cpp deleted file mode 100644 index 8a948dbc3..000000000 --- a/mllm/backends/base/PluginInterface.cpp +++ /dev/null @@ -1,2 +0,0 @@ -// Copyright (c) MLLM Team. -// Licensed under the MIT License. diff --git a/mllm/backends/base/PluginInterface.hpp b/mllm/backends/base/PluginInterface.hpp index 73fbb6aa8..867d34d33 100644 --- a/mllm/backends/base/PluginInterface.hpp +++ b/mllm/backends/base/PluginInterface.hpp @@ -4,9 +4,21 @@ #include "mllm/core/BaseOp.hpp" +#define MLLM_PLUGIN_OP_PACKAGE_DESCRIPTOR_VERSION 1 +#define MLLM_PLUGIN_OP_PACKAGE_NAME_LEN 256 +#define MLLM_PLUGIN_OP_PACKAGE_DESCRIPTOR_LEN 256 + +#define MLLM_PLUGIN_OP_INTERFACE_DEFINE_BEGIN \ + extern "C" { \ + void* opPackageDescriptor(); +#define MLLM_PLUGIN_OP_INTERFACE_DEFINE_END } + namespace mllm::plugin::interface { -class CustomizedOp : BaseOp {}; +class CustomizedOp : public BaseOp { + public: + explicit CustomizedOp(const std::string& name) : BaseOp(OpTypes::kDynamicOp_Start) { setName(name); } +}; template class CustomizedOpFactory : protected TypedOpFactory { @@ -17,3 +29,20 @@ class CustomizedOpFactory : protected TypedOpFactory +#else +#include +#endif + +#include +#include + #include "mllm/engine/Context.hpp" #include "mllm/backends/base/PluginSystem.hpp" namespace mllm::plugin { -int32_t OpPluginSystem::registerCustomizedOp(DeviceTypes device_type, const std::shared_ptr& factory) { +OpPluginSystem::~OpPluginSystem() { +#ifdef _WIN32 + for (auto handle : loaded_libraries_) { + if (handle) { FreeLibrary(handle); } + } +#else + for (auto handle : loaded_libraries_) { + if (handle) { dlclose(handle); } + } +#endif + for (auto descriptor : op_packages_) { + if (descriptor) { delete descriptor; } + } +} + +void OpPluginSystem::loadOpPackage(const std::string& path) { + // 1. DLOPEN +#ifdef _WIN32 + HMODULE handle = LoadLibraryA(path.c_str()); + if (!handle) { throw std::runtime_error("Failed to load plugin library on Windows: " + path); } +#else // Linux and other Unix-like systems + void* handle = dlopen(path.c_str(), RTLD_LAZY); + if (!handle) { + throw std::runtime_error("Failed to load plugin library on Linux and other Unix-like systems: " + std::string(dlerror())); + } +#endif + + // 2. Call function opPackageDescriptor() in the library + typedef void* (*DescriptorFunc)(); // NOLINT +#ifdef _WIN32 + DescriptorFunc descriptor = (DescriptorFunc)GetProcAddress(handle, "opPackageDescriptor"); + if (!descriptor) { + FreeLibrary(handle); + throw std::runtime_error("Failed to find opPackageDescriptor function in plugin: " + path); + } +#else + DescriptorFunc descriptor = (DescriptorFunc)dlsym(handle, "opPackageDescriptor"); + if (!descriptor) { + dlclose(handle); + throw std::runtime_error("Failed to find opPackageDescriptor function in plugin: " + std::string(dlerror())); + } +#endif + + // 3. Cast void* return to PluginOpPackageDescriptor* + void* result = descriptor(); + PluginOpPackageDescriptor* package_descriptor = static_cast(result); + MLLM_RT_ASSERT(package_descriptor != nullptr); + op_packages_.push_back(package_descriptor); + loaded_libraries_.push_back(handle); + + // 4. Load all factory to the backend + MLLM_INFO("Load customized op package: {}, find op nums: {}", package_descriptor->name, + package_descriptor->op_factories_count); + for (int i = 0; i < package_descriptor->op_factories_count; ++i) { + std::shared_ptr factory((BaseOpFactory*)(package_descriptor->op_factory_create_funcs[i]()), + [&package_descriptor, &i](BaseOpFactory* ptr) { + // FIXME: will raise error + // package_descriptor->op_factory_free_funcs[i](ptr); + }); + registerCustomizedOp(static_cast(package_descriptor->device_type), + std::string(package_descriptor->op_factories_names[i]), factory); + } +} + +int32_t OpPluginSystem::registerCustomizedOp(DeviceTypes device_type, const std::string& name, + const std::shared_ptr& factory) { auto op_type_ret = ++dynamic_op_type_counter_; - Context::instance().getBackend(device_type)->regOpFactory(factory); factory->__forceSetType(op_type_ret); + Context::instance().getBackend(device_type)->regOpFactory(factory); + if (!op_name_table_.has(device_type)) { op_name_table_.reg(device_type, SymbolTable{}); } + op_name_table_[device_type].reg(name, op_type_ret); + MLLM_INFO("Register customized op: {}:{} -> {}", name, op_type_ret, deviceTypes2Str(device_type)); return op_type_ret; } +int32_t OpPluginSystem::lookupCustomizedOp(DeviceTypes device_type, const std::string& name) { + return op_name_table_[device_type][name]; +} + } // namespace mllm::plugin diff --git a/mllm/backends/base/PluginSystem.hpp b/mllm/backends/base/PluginSystem.hpp index 67a2c5a2b..2e5533214 100644 --- a/mllm/backends/base/PluginSystem.hpp +++ b/mllm/backends/base/PluginSystem.hpp @@ -12,6 +12,7 @@ #include "mllm/core/OpTypes.hpp" #include "mllm/core/DeviceTypes.hpp" #include "mllm/utils/SymbolTable.hpp" +#include "mllm/backends/base/PluginInterface.hpp" namespace mllm::plugin { @@ -20,11 +21,25 @@ class OpPluginSystem { using op_type_t = int32_t; using op_name_t = std::string; - int32_t registerCustomizedOp(DeviceTypes device_type, const std::shared_ptr& factory); + OpPluginSystem() = default; + + ~OpPluginSystem(); + + void loadOpPackage(const std::string& path); + + int32_t registerCustomizedOp(DeviceTypes device_type, const std::string& name, const std::shared_ptr& factory); + + int32_t lookupCustomizedOp(DeviceTypes device_type, const std::string& name); private: int32_t dynamic_op_type_counter_ = (int32_t)OpTypes::kDynamicOp_Start; - SymbolTable op_name_table_; + SymbolTable> op_name_table_; + std::vector op_packages_; +#ifdef _WIN32 + std::vector loaded_libraries_; +#else + std::vector loaded_libraries_; +#endif }; } // namespace mllm::plugin diff --git a/mllm/core/BaseOp.cpp b/mllm/core/BaseOp.cpp index 39109d7e5..b86000269 100644 --- a/mllm/core/BaseOp.cpp +++ b/mllm/core/BaseOp.cpp @@ -23,4 +23,6 @@ void BaseOp::setDeviceType(DeviceTypes device_type) { device_type_ = device_type OpTypes BaseOp::getOpType() const { return op_type_; } +void BaseOp::setOpType(OpTypes op_type) { op_type_ = op_type; } + } // namespace mllm diff --git a/mllm/core/BaseOp.hpp b/mllm/core/BaseOp.hpp index b2629f9d5..0ca8e6d9f 100644 --- a/mllm/core/BaseOp.hpp +++ b/mllm/core/BaseOp.hpp @@ -173,6 +173,8 @@ class BaseOp : public std::enable_shared_from_this { OpTypes getOpType() const; + void setOpType(OpTypes op_type); + private: DeviceTypes device_type_; std::string name_; diff --git a/mllm/engine/Context.cpp b/mllm/engine/Context.cpp index 5c9f0ffb6..e7de7c79c 100644 --- a/mllm/engine/Context.cpp +++ b/mllm/engine/Context.cpp @@ -124,4 +124,10 @@ void Context::setPrintMaxElementsPerDim(int max_elements) { print_max_elements_p int Context::getPrintMaxElementsPerDim() const { return print_max_elements_per_dim_; } +void Context::loadOpPackage(const std::string& path) { op_plugin_system_.loadOpPackage(path); } + +int32_t Context::lookupCustomizedOpId(DeviceTypes device_type, const std::string& name) { + return op_plugin_system_.lookupCustomizedOp(device_type, name); +} + } // namespace mllm diff --git a/mllm/engine/Context.hpp b/mllm/engine/Context.hpp index fe0b0bc84..19eb63158 100644 --- a/mllm/engine/Context.hpp +++ b/mllm/engine/Context.hpp @@ -14,6 +14,7 @@ #include "mllm/utils/SymbolTable.hpp" #include "mllm/engine/MemoryManager.hpp" #include "mllm/backends/base/Backend.hpp" +#include "mllm/backends/base/PluginSystem.hpp" namespace mllm { @@ -61,10 +62,17 @@ class Context { int getPrintMaxElementsPerDim() const; + void loadOpPackage(const std::string& path); + + int32_t lookupCustomizedOpId(DeviceTypes device_type, const std::string& name); + private: // NOTE: Context should be made private in singleton design pattern. Context(); + // Plugin system + plugin::OpPluginSystem op_plugin_system_; + uint64_t random_seed_ = 42; uint64_t random_state_ = 42; SessionTCB::ptr_t main_thread_; diff --git a/mllm/mllm.cpp b/mllm/mllm.cpp index 580c59a77..808f0fdb4 100644 --- a/mllm/mllm.cpp +++ b/mllm/mllm.cpp @@ -58,6 +58,8 @@ void cleanThisThread() { SessionTCB::ptr_t thisThread() { return Context::instance().thisThread(); } +void loadOpPackage(const std::string& path) { Context::instance().loadOpPackage(path); } + ParameterFile::ptr_t load(const std::string& file_name, ModelFileVersion v, DeviceTypes map_2_device) { if (v == ModelFileVersion::kV1 && map_2_device == kCPU) { return ParameterFileIOImpl::read(file_name); diff --git a/mllm/mllm.hpp b/mllm/mllm.hpp index 152435fbb..0c70a7446 100644 --- a/mllm/mllm.hpp +++ b/mllm/mllm.hpp @@ -178,6 +178,8 @@ void cleanThisThread(); SessionTCB::ptr_t thisThread(); +void loadOpPackage(const std::string& path); + ParameterFile::ptr_t load(const std::string& file_name, ModelFileVersion version = ModelFileVersion::kV1, DeviceTypes map_2_device = kCPU); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 4769286f9..ff5442f27 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(nn) add_subdirectory(core) add_subdirectory(utils) +add_subdirectory(plugin) add_subdirectory(engine) add_subdirectory(compile) diff --git a/tests/plugin/CMakeLists.txt b/tests/plugin/CMakeLists.txt new file mode 100644 index 000000000..c3a4a8fe2 --- /dev/null +++ b/tests/plugin/CMakeLists.txt @@ -0,0 +1,7 @@ +add_executable(Mllm-Test-Plugin-HostTest HostTest.cpp) +target_link_libraries(Mllm-Test-Plugin-HostTest PRIVATE MllmRT MllmCPUBackend) +target_include_directories(Mllm-Test-Plugin-HostTest PRIVATE ${MLLM_INCLUDE_DIR}) + +add_library(CustomPackageForHostTest SHARED CustomPackageForHostTest.cpp) +target_link_libraries(CustomPackageForHostTest PRIVATE MllmRT MllmCPUBackend) +target_include_directories(CustomPackageForHostTest PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/tests/plugin/CustomPackageForHostTest.cpp b/tests/plugin/CustomPackageForHostTest.cpp new file mode 100644 index 000000000..8354b7b4a --- /dev/null +++ b/tests/plugin/CustomPackageForHostTest.cpp @@ -0,0 +1,36 @@ +#include "CustomPackageForHostTest.hpp" + +MLLM_PLUGIN_OP_INTERFACE_DEFINE_BEGIN +void* createCustomOp1Factory() { return new CustomOp1Factory(); }; + +void freeCustomOp1Factory(void* factory) { delete static_cast(factory); }; + +void* createCustomOp2Factory() { return new CustomOp2Factory(); }; + +void freeCustomOp2Factory(void* factory) { delete static_cast(factory); }; + +void* opPackageDescriptor() { + auto package = new PluginOpPackageDescriptor{ + .version = MLLM_PLUGIN_OP_PACKAGE_DESCRIPTOR_VERSION, + .name = "CustomPackageForHostTest", + .device_type = 1, + .op_factories_count = 2, + .op_factories_names = + { + "custom_op1", + "custom_op2", + }, + .op_factory_create_funcs = + { + createCustomOp1Factory, + createCustomOp2Factory, + }, + .op_factory_free_funcs = + { + freeCustomOp1Factory, + freeCustomOp2Factory, + }, + }; + return package; +} +MLLM_PLUGIN_OP_INTERFACE_DEFINE_END diff --git a/tests/plugin/CustomPackageForHostTest.hpp b/tests/plugin/CustomPackageForHostTest.hpp new file mode 100644 index 000000000..5591532e7 --- /dev/null +++ b/tests/plugin/CustomPackageForHostTest.hpp @@ -0,0 +1,70 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#include "mllm/mllm.hpp" +#include "mllm/backends/base/PluginInterface.hpp" + +struct CustomOp1Options : public mllm::BaseOpOptions { + int32_t data = 0; +}; + +class CustomOp1 final : public mllm::plugin::interface::CustomizedOp { + public: + explicit CustomOp1(const CustomOp1Options& options) : CustomizedOp("custom_op1"), options_(options) {} + + void load(const mllm::ParameterFile::ptr_t& ploader) override {}; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override {}; + + void forward(const std::vector& inputs, std::vector& outputs) override { + MLLM_INFO("Hello from custom op1, data: {}", options_.data); + } + + void reshape(const std::vector& inputs, std::vector& outputs) override {} + + void setup(const std::vector& inputs, std::vector& outputs) override {} + + protected: + CustomOp1Options options_; +}; + +class CustomOp1Factory final : public mllm::plugin::interface::CustomizedOpFactory { + public: + inline std::shared_ptr createOpImpl(const CustomOp1Options& cargo) override { + auto p = std::make_shared(cargo); + p->setOpType(opType()); + return p; + } +}; + +struct CustomOp2Options : public mllm::BaseOpOptions { + int32_t data = 0; +}; + +class CustomOp2 final : public mllm::plugin::interface::CustomizedOp { + public: + explicit CustomOp2(const CustomOp2Options& options) : CustomizedOp("custom_op2"), options_(options) {} + + void load(const mllm::ParameterFile::ptr_t& ploader) override {}; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override {}; + + void forward(const std::vector& inputs, std::vector& outputs) override { + MLLM_INFO("Hello from custom op2, data: {}", options_.data); + } + + void reshape(const std::vector& inputs, std::vector& outputs) override {} + + void setup(const std::vector& inputs, std::vector& outputs) override {} + + protected: + CustomOp2Options options_; +}; + +class CustomOp2Factory final : public mllm::plugin::interface::CustomizedOpFactory { + public: + inline std::shared_ptr createOpImpl(const CustomOp2Options& cargo) override { + auto p = std::make_shared(cargo); + p->setOpType(opType()); + return p; + } +}; diff --git a/tests/plugin/HostTest.cpp b/tests/plugin/HostTest.cpp new file mode 100644 index 000000000..a9459f885 --- /dev/null +++ b/tests/plugin/HostTest.cpp @@ -0,0 +1,12 @@ +#include "mllm/mllm.hpp" +#include "CustomPackageForHostTest.hpp" + +MLLM_MAIN({ + mllm::loadOpPackage("./libCustomPackageForHostTest.so"); + std::vector inputs, outputs; + auto op = mllm::Context::instance() + .getBackend(mllm::kCPU) + ->createOp((mllm::OpTypes)mllm::Context::instance().lookupCustomizedOpId(mllm::kCPU, "custom_op1"), + CustomOp1Options{.data = 42}); + op->forward(inputs, outputs); +})