Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mllm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ add_library(
mllm.cpp
backends/base/Backend.cpp
backends/base/Allocator.cpp
backends/base/PluginSystem.cpp
${MLLM_RT_CORE_SRC}
${MLLM_RT_UTILS_SRC}
${MLLM_RT_ENGINE_SRC}
Expand Down
2 changes: 0 additions & 2 deletions mllm/backends/base/PluginInterface.cpp

This file was deleted.

31 changes: 30 additions & 1 deletion mllm/backends/base/PluginInterface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,21 @@

#include "mllm/core/BaseOp.hpp"

#define MLLM_PLUGIN_OP_PACKAGE_DESCRIPTOR_VERSION 1
#define MLLM_PLUGIN_OP_PACKAGE_NAME_LEN 256
#define MLLM_PLUGIN_OP_PACKAGE_DESCRIPTOR_LEN 256

#define MLLM_PLUGIN_OP_INTERFACE_DEFINE_BEGIN \
extern "C" { \
void* opPackageDescriptor();
#define MLLM_PLUGIN_OP_INTERFACE_DEFINE_END }

namespace mllm::plugin::interface {

class CustomizedOp : BaseOp {};
class CustomizedOp : public BaseOp {
public:
explicit CustomizedOp(const std::string& name) : BaseOp(OpTypes::kDynamicOp_Start) { setName(name); }
};

template<typename CargoT>
class CustomizedOpFactory : protected TypedOpFactory<OpTypes::kDynamicOp_Start, CargoT> {
Expand All @@ -17,3 +29,20 @@ class CustomizedOpFactory : protected TypedOpFactory<OpTypes::kDynamicOp_Start,
};

} // namespace mllm::plugin::interface

extern "C" {

typedef void* (*OpFactoryCreateFunc)(); // NOLINT
typedef void (*OpFactoryFreeFunc)(void*); // NOLINT

struct PluginOpPackageDescriptor {
int32_t version = MLLM_PLUGIN_OP_PACKAGE_DESCRIPTOR_VERSION;
char name[MLLM_PLUGIN_OP_PACKAGE_NAME_LEN];

int32_t device_type;
int32_t op_factories_count = 0;
char op_factories_names[MLLM_PLUGIN_OP_PACKAGE_DESCRIPTOR_LEN][MLLM_PLUGIN_OP_PACKAGE_NAME_LEN];
OpFactoryCreateFunc op_factory_create_funcs[MLLM_PLUGIN_OP_PACKAGE_DESCRIPTOR_LEN];
OpFactoryFreeFunc op_factory_free_funcs[MLLM_PLUGIN_OP_PACKAGE_DESCRIPTOR_LEN];
};
}
85 changes: 83 additions & 2 deletions mllm/backends/base/PluginSystem.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,97 @@
// Copyright (c) MLLM Team.
// Licensed under the MIT License.

#ifdef _WIN32
#include <windows.h>
#else
#include <dlfcn.h>
#endif

#include <string>
#include <stdexcept>

#include "mllm/engine/Context.hpp"
#include "mllm/backends/base/PluginSystem.hpp"

namespace mllm::plugin {

int32_t OpPluginSystem::registerCustomizedOp(DeviceTypes device_type, const std::shared_ptr<BaseOpFactory>& factory) {
OpPluginSystem::~OpPluginSystem() {
#ifdef _WIN32
for (auto handle : loaded_libraries_) {
if (handle) { FreeLibrary(handle); }
}
#else
for (auto handle : loaded_libraries_) {
if (handle) { dlclose(handle); }
}
#endif
for (auto descriptor : op_packages_) {
if (descriptor) { delete descriptor; }
}
}

void OpPluginSystem::loadOpPackage(const std::string& path) {
// 1. DLOPEN
#ifdef _WIN32
HMODULE handle = LoadLibraryA(path.c_str());
if (!handle) { throw std::runtime_error("Failed to load plugin library on Windows: " + path); }
#else // Linux and other Unix-like systems
void* handle = dlopen(path.c_str(), RTLD_LAZY);
if (!handle) {
throw std::runtime_error("Failed to load plugin library on Linux and other Unix-like systems: " + std::string(dlerror()));
}
#endif

// 2. Call function opPackageDescriptor() in the library
typedef void* (*DescriptorFunc)(); // NOLINT
#ifdef _WIN32
DescriptorFunc descriptor = (DescriptorFunc)GetProcAddress(handle, "opPackageDescriptor");
if (!descriptor) {
FreeLibrary(handle);
throw std::runtime_error("Failed to find opPackageDescriptor function in plugin: " + path);
}
#else
DescriptorFunc descriptor = (DescriptorFunc)dlsym(handle, "opPackageDescriptor");
if (!descriptor) {
dlclose(handle);
throw std::runtime_error("Failed to find opPackageDescriptor function in plugin: " + std::string(dlerror()));
}
#endif

// 3. Cast void* return to PluginOpPackageDescriptor*
void* result = descriptor();
PluginOpPackageDescriptor* package_descriptor = static_cast<PluginOpPackageDescriptor*>(result);
MLLM_RT_ASSERT(package_descriptor != nullptr);
op_packages_.push_back(package_descriptor);
loaded_libraries_.push_back(handle);

// 4. Load all factory to the backend
MLLM_INFO("Load customized op package: {}, find op nums: {}", package_descriptor->name,
package_descriptor->op_factories_count);
for (int i = 0; i < package_descriptor->op_factories_count; ++i) {
std::shared_ptr<BaseOpFactory> factory((BaseOpFactory*)(package_descriptor->op_factory_create_funcs[i]()),
[&package_descriptor, &i](BaseOpFactory* ptr) {
// FIXME: will raise error
// package_descriptor->op_factory_free_funcs[i](ptr);
});
registerCustomizedOp(static_cast<DeviceTypes>(package_descriptor->device_type),
std::string(package_descriptor->op_factories_names[i]), factory);
}
}

int32_t OpPluginSystem::registerCustomizedOp(DeviceTypes device_type, const std::string& name,
const std::shared_ptr<BaseOpFactory>& factory) {
auto op_type_ret = ++dynamic_op_type_counter_;
Context::instance().getBackend(device_type)->regOpFactory(factory);
factory->__forceSetType(op_type_ret);
Context::instance().getBackend(device_type)->regOpFactory(factory);
if (!op_name_table_.has(device_type)) { op_name_table_.reg(device_type, SymbolTable<op_name_t, op_type_t>{}); }
op_name_table_[device_type].reg(name, op_type_ret);
MLLM_INFO("Register customized op: {}:{} -> {}", name, op_type_ret, deviceTypes2Str(device_type));
return op_type_ret;
}

int32_t OpPluginSystem::lookupCustomizedOp(DeviceTypes device_type, const std::string& name) {
return op_name_table_[device_type][name];
}

} // namespace mllm::plugin
19 changes: 17 additions & 2 deletions mllm/backends/base/PluginSystem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mllm/core/OpTypes.hpp"
#include "mllm/core/DeviceTypes.hpp"
#include "mllm/utils/SymbolTable.hpp"
#include "mllm/backends/base/PluginInterface.hpp"

namespace mllm::plugin {

Expand All @@ -20,11 +21,25 @@ class OpPluginSystem {
using op_type_t = int32_t;
using op_name_t = std::string;

int32_t registerCustomizedOp(DeviceTypes device_type, const std::shared_ptr<BaseOpFactory>& factory);
OpPluginSystem() = default;

~OpPluginSystem();

void loadOpPackage(const std::string& path);

int32_t registerCustomizedOp(DeviceTypes device_type, const std::string& name, const std::shared_ptr<BaseOpFactory>& factory);

int32_t lookupCustomizedOp(DeviceTypes device_type, const std::string& name);

private:
int32_t dynamic_op_type_counter_ = (int32_t)OpTypes::kDynamicOp_Start;
SymbolTable<op_type_t, op_name_t> op_name_table_;
SymbolTable<DeviceTypes, SymbolTable<op_name_t, op_type_t>> op_name_table_;
std::vector<PluginOpPackageDescriptor*> op_packages_;
#ifdef _WIN32
std::vector<HMODULE> loaded_libraries_;
#else
std::vector<void*> loaded_libraries_;
#endif
};

} // namespace mllm::plugin
2 changes: 2 additions & 0 deletions mllm/core/BaseOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@ void BaseOp::setDeviceType(DeviceTypes device_type) { device_type_ = device_type

OpTypes BaseOp::getOpType() const { return op_type_; }

void BaseOp::setOpType(OpTypes op_type) { op_type_ = op_type; }

} // namespace mllm
2 changes: 2 additions & 0 deletions mllm/core/BaseOp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ class BaseOp : public std::enable_shared_from_this<BaseOp> {

OpTypes getOpType() const;

void setOpType(OpTypes op_type);

private:
DeviceTypes device_type_;
std::string name_;
Expand Down
6 changes: 6 additions & 0 deletions mllm/engine/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,10 @@ void Context::setPrintMaxElementsPerDim(int max_elements) { print_max_elements_p

int Context::getPrintMaxElementsPerDim() const { return print_max_elements_per_dim_; }

void Context::loadOpPackage(const std::string& path) { op_plugin_system_.loadOpPackage(path); }

int32_t Context::lookupCustomizedOpId(DeviceTypes device_type, const std::string& name) {
return op_plugin_system_.lookupCustomizedOp(device_type, name);
}

} // namespace mllm
8 changes: 8 additions & 0 deletions mllm/engine/Context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mllm/utils/SymbolTable.hpp"
#include "mllm/engine/MemoryManager.hpp"
#include "mllm/backends/base/Backend.hpp"
#include "mllm/backends/base/PluginSystem.hpp"

namespace mllm {

Expand Down Expand Up @@ -61,10 +62,17 @@ class Context {

int getPrintMaxElementsPerDim() const;

void loadOpPackage(const std::string& path);

int32_t lookupCustomizedOpId(DeviceTypes device_type, const std::string& name);

private:
// NOTE: Context should be made private in singleton design pattern.
Context();

// Plugin system
plugin::OpPluginSystem op_plugin_system_;

uint64_t random_seed_ = 42;
uint64_t random_state_ = 42;
SessionTCB::ptr_t main_thread_;
Expand Down
2 changes: 2 additions & 0 deletions mllm/mllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ void cleanThisThread() {

SessionTCB::ptr_t thisThread() { return Context::instance().thisThread(); }

void loadOpPackage(const std::string& path) { Context::instance().loadOpPackage(path); }

ParameterFile::ptr_t load(const std::string& file_name, ModelFileVersion v, DeviceTypes map_2_device) {
if (v == ModelFileVersion::kV1 && map_2_device == kCPU) {
return ParameterFileIOImpl<kCPU, ModelFileVersion::kV1>::read(file_name);
Expand Down
2 changes: 2 additions & 0 deletions mllm/mllm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ void cleanThisThread();

SessionTCB::ptr_t thisThread();

void loadOpPackage(const std::string& path);

ParameterFile::ptr_t load(const std::string& file_name, ModelFileVersion version = ModelFileVersion::kV1,
DeviceTypes map_2_device = kCPU);

Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_subdirectory(nn)
add_subdirectory(core)
add_subdirectory(utils)
add_subdirectory(plugin)
add_subdirectory(engine)
add_subdirectory(compile)

Expand Down
7 changes: 7 additions & 0 deletions tests/plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
add_executable(Mllm-Test-Plugin-HostTest HostTest.cpp)
target_link_libraries(Mllm-Test-Plugin-HostTest PRIVATE MllmRT MllmCPUBackend)
target_include_directories(Mllm-Test-Plugin-HostTest PRIVATE ${MLLM_INCLUDE_DIR})

add_library(CustomPackageForHostTest SHARED CustomPackageForHostTest.cpp)
target_link_libraries(CustomPackageForHostTest PRIVATE MllmRT MllmCPUBackend)
target_include_directories(CustomPackageForHostTest PRIVATE ${MLLM_INCLUDE_DIR})
36 changes: 36 additions & 0 deletions tests/plugin/CustomPackageForHostTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "CustomPackageForHostTest.hpp"

MLLM_PLUGIN_OP_INTERFACE_DEFINE_BEGIN
void* createCustomOp1Factory() { return new CustomOp1Factory(); };

void freeCustomOp1Factory(void* factory) { delete static_cast<CustomOp1Factory*>(factory); };

void* createCustomOp2Factory() { return new CustomOp2Factory(); };

void freeCustomOp2Factory(void* factory) { delete static_cast<CustomOp2Factory*>(factory); };

void* opPackageDescriptor() {
auto package = new PluginOpPackageDescriptor{
.version = MLLM_PLUGIN_OP_PACKAGE_DESCRIPTOR_VERSION,
.name = "CustomPackageForHostTest",
.device_type = 1,
.op_factories_count = 2,
.op_factories_names =
{
"custom_op1",
"custom_op2",
},
.op_factory_create_funcs =
{
createCustomOp1Factory,
createCustomOp2Factory,
},
.op_factory_free_funcs =
{
freeCustomOp1Factory,
freeCustomOp2Factory,
},
};
return package;
}
MLLM_PLUGIN_OP_INTERFACE_DEFINE_END
70 changes: 70 additions & 0 deletions tests/plugin/CustomPackageForHostTest.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) MLLM Team.
// Licensed under the MIT License.
#include "mllm/mllm.hpp"
#include "mllm/backends/base/PluginInterface.hpp"

struct CustomOp1Options : public mllm::BaseOpOptions<CustomOp1Options> {
int32_t data = 0;
};

class CustomOp1 final : public mllm::plugin::interface::CustomizedOp {
public:
explicit CustomOp1(const CustomOp1Options& options) : CustomizedOp("custom_op1"), options_(options) {}

void load(const mllm::ParameterFile::ptr_t& ploader) override {};

void trace(void* trace_context, const std::vector<mllm::Tensor>& inputs, std::vector<mllm::Tensor>& outputs) override {};

void forward(const std::vector<mllm::Tensor>& inputs, std::vector<mllm::Tensor>& outputs) override {
MLLM_INFO("Hello from custom op1, data: {}", options_.data);
}

void reshape(const std::vector<mllm::Tensor>& inputs, std::vector<mllm::Tensor>& outputs) override {}

void setup(const std::vector<mllm::Tensor>& inputs, std::vector<mllm::Tensor>& outputs) override {}

protected:
CustomOp1Options options_;
};

class CustomOp1Factory final : public mllm::plugin::interface::CustomizedOpFactory<CustomOp1Options> {
public:
inline std::shared_ptr<mllm::BaseOp> createOpImpl(const CustomOp1Options& cargo) override {
auto p = std::make_shared<CustomOp1>(cargo);
p->setOpType(opType());
return p;
}
};

struct CustomOp2Options : public mllm::BaseOpOptions<CustomOp2Options> {
int32_t data = 0;
};

class CustomOp2 final : public mllm::plugin::interface::CustomizedOp {
public:
explicit CustomOp2(const CustomOp2Options& options) : CustomizedOp("custom_op2"), options_(options) {}

void load(const mllm::ParameterFile::ptr_t& ploader) override {};

void trace(void* trace_context, const std::vector<mllm::Tensor>& inputs, std::vector<mllm::Tensor>& outputs) override {};

void forward(const std::vector<mllm::Tensor>& inputs, std::vector<mllm::Tensor>& outputs) override {
MLLM_INFO("Hello from custom op2, data: {}", options_.data);
}

void reshape(const std::vector<mllm::Tensor>& inputs, std::vector<mllm::Tensor>& outputs) override {}

void setup(const std::vector<mllm::Tensor>& inputs, std::vector<mllm::Tensor>& outputs) override {}

protected:
CustomOp2Options options_;
};

class CustomOp2Factory final : public mllm::plugin::interface::CustomizedOpFactory<CustomOp2Options> {
public:
inline std::shared_ptr<mllm::BaseOp> createOpImpl(const CustomOp2Options& cargo) override {
auto p = std::make_shared<CustomOp2>(cargo);
p->setOpType(opType());
return p;
}
};
12 changes: 12 additions & 0 deletions tests/plugin/HostTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#include "mllm/mllm.hpp"
#include "CustomPackageForHostTest.hpp"

MLLM_MAIN({
mllm::loadOpPackage("./libCustomPackageForHostTest.so");
std::vector<mllm::Tensor> inputs, outputs;
auto op = mllm::Context::instance()
.getBackend(mllm::kCPU)
->createOp((mllm::OpTypes)mllm::Context::instance().lookupCustomizedOpId(mllm::kCPU, "custom_op1"),
CustomOp1Options{.data = 42});
op->forward(inputs, outputs);
})