From 36b2e64c389c6ca2b7fd3aad86e660fb3048eb37 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 11 Dec 2025 10:00:51 +0000 Subject: [PATCH 1/7] feat(qnn): add Qualcomm QNN AOT support for x86 platforms - Introduce CMake options and build configurations to enable Qualcomm NPU AOT on x86 - Add QNN AOT wrapper APIs for dynamic library loading and QNN context management - Implement FFI bindings and Python interfaces for QNN AOT environment and device context - Include new source files and update build scripts to integrate QNN AOT functionality - Add test script to verify QNN context creation in Python bindings --- CMakeLists.txt | 3 + mllm/CMakeLists.txt | 15 ++ mllm/backends/qnn/aot/QnnWrappersAPI.cpp | 199 +++++++++++++++++++ mllm/backends/qnn/aot/QnnWrappersAPI.hpp | 117 +++++++++++ mllm/backends/qnn/aot/README.md | 3 + mllm/backends/qnn/aot_rt/README.md | 1 + mllm/ffi/CMakeLists.txt | 13 ++ mllm/ffi/Extension.cc | 15 ++ mllm/ffi/Object.hh | 38 ++++ mllm/ffi/qualcomm/QnnAOT.cc | 30 +++ mllm/ffi/qualcomm/QnnAOT.hh | 64 ++++++ pymllm/backends/__init__.py | 4 + pymllm/backends/cuda/__init__.py | 0 pymllm/backends/qualcomm/README.md | 1 + pymllm/backends/qualcomm/__init__.py | 4 + pymllm/backends/qualcomm/nn.py | 11 + pymllm/backends/qualcomm/qnn_aot_env.py | 1 + pymllm/compile/mllm_ir/trace.py | 0 pymllm/ffi/__init__.py | 35 ++++ pymllm/nn/_layers.py | 34 +++- pymllm/nn/_module.py | 28 +++ pymllm/tests/qualcomm/test_context_create.py | 8 + pyproject.toml | 4 +- tasks/build_x86.yaml | 1 + tasks/build_x86_qnn_aot.yaml | 18 ++ 25 files changed, 645 insertions(+), 2 deletions(-) create mode 100644 mllm/backends/qnn/aot/QnnWrappersAPI.cpp create mode 100644 mllm/backends/qnn/aot/QnnWrappersAPI.hpp create mode 100644 mllm/backends/qnn/aot/README.md create mode 100644 mllm/backends/qnn/aot_rt/README.md create mode 100644 mllm/ffi/qualcomm/QnnAOT.cc create mode 100644 mllm/ffi/qualcomm/QnnAOT.hh create mode 100644 pymllm/backends/__init__.py create mode 100644 pymllm/backends/cuda/__init__.py create mode 100644 pymllm/backends/qualcomm/README.md create mode 100644 pymllm/backends/qualcomm/__init__.py create mode 100644 pymllm/backends/qualcomm/nn.py create mode 100644 pymllm/backends/qualcomm/qnn_aot_env.py create mode 100644 pymllm/compile/mllm_ir/trace.py create mode 100644 pymllm/tests/qualcomm/test_context_create.py create mode 100644 tasks/build_x86_qnn_aot.yaml diff --git a/CMakeLists.txt b/CMakeLists.txt index 92d70fd5c..221e956d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,6 +55,9 @@ option(MLLM_KERNEL_THREADS_VENDOR_APPLE_GCD "Enable Apple GCD Threads" OFF) option(MLLM_PERFETTO_ENABLE "Enable perfetto" OFF) option(MLLM_TRACY_ENABLE "Enable Tracy. A more advanced profiler" OFF) +# NPU AOT things +option(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE "Enable Qualcomm NPU AOT on X86 devices" OFF) + # Platform Hints option(MLLM_ANDROID_BURST_PERFORMANCE_HINTS "If MLLM need use APerformanceHintManager to tell android we need best performance" OFF) diff --git a/mllm/CMakeLists.txt b/mllm/CMakeLists.txt index 1671ddc53..ea2e6c358 100644 --- a/mllm/CMakeLists.txt +++ b/mllm/CMakeLists.txt @@ -8,6 +8,9 @@ file(GLOB_RECURSE MLLM_RT_PREPROCESSOR_SRC ${CMAKE_CURRENT_LIST_DIR}/preprocesso if(MLLM_BUILD_EXPERIMENTS) file(GLOB_RECURSE MLLM_RT_AUTO_TUNE_SRC ${CMAKE_CURRENT_LIST_DIR}/experiments/auto_tune/*.cpp) endif() +if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) + file(GLOB_RECURSE MLLM_QUALCOMM_AOT_SRC ${CMAKE_CURRENT_LIST_DIR}/backends/qnn/aot/*.cpp) +endif() file(GLOB WENET_AUDIO_SOURCES ${PROJECT_SOURCE_DIR}/third_party/wenet_audio/*) add_library( @@ -24,6 +27,7 @@ add_library( ${MLLM_RT_MODELS_SRC} ${MLLM_RT_COMPILE_SRC} ${MLLM_RT_AUTO_TUNE_SRC} + ${MLLM_QUALCOMM_AOT_SRC} ${WENET_AUDIO_SOURCES} ) @@ -113,6 +117,17 @@ if(MLLM_BUILD_OPENCL_BACKEND) ) endif() +if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) + # Build + target_include_directories(MllmRT PRIVATE + $ENV{QAIRT_SDK_ROOT}/include # QNN SDK include + $ENV{QAIRT_SDK_ROOT}/include/QNN # QNN SDK include + ) + add_compile_definitions( + MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE + ) +endif() + if(MLLM_BUILD_QNN_BACKEND) add_subdirectory(backends/qnn) add_compile_definitions( diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp new file mode 100644 index 000000000..8144e32d2 --- /dev/null +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp @@ -0,0 +1,199 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp" + +namespace mllm::qnn::aot { + +void __mllmLoggerCallback4QnnLogger(const char* fmt, QnnLog_Level_t level, uint64_t times_tamp, va_list argp) { + const char* level_str = ""; + switch (level) { + case QNN_LOG_LEVEL_ERROR: level_str = "[ERROR] "; break; + case QNN_LOG_LEVEL_WARN: level_str = "[WARN] "; break; + case QNN_LOG_LEVEL_INFO: level_str = "[INFO] "; break; + case QNN_LOG_LEVEL_DEBUG: level_str = "[DEBUG] "; break; + case QNN_LOG_LEVEL_VERBOSE: level_str = "[VERBOSE]"; break; + case QNN_LOG_LEVEL_MAX: level_str = "[UNKNOWN]"; break; + } + + double ms = (double)times_tamp / 1000000.0; + + { + fprintf(stdout, "QnnLogger(%8.1fms, %ld) %s: ", ms, times_tamp, level_str); + vfprintf(stdout, fmt, argp); + } +} + +const std::vector QnnDynSymbolLoader::possible_qnn_dyn_lib_paths_ = { + "/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", +}; + +QnnDynSymbolLoader::~QnnDynSymbolLoader() { + for (auto& item : libs_) { + if (item.second.handle_) { dlclose(item.second.handle_); } + } +} + +bool QnnDynSymbolLoader::loadQnnDynLib(const std::string& lib_name, int flag) { + for (auto const& path : possible_qnn_dyn_lib_paths_) { + auto real_path = path + lib_name; + auto handle = dlopen(real_path.c_str(), flag); + if (handle) { + auto descriptor = QnnDynLibDescriptor{.lib_name_ = lib_name, .lib_path_ = path, .handle_ = handle}; + libs_.insert({lib_name, descriptor}); + MLLM_INFO("QnnDynSymbolLoader::loadQnnDynLib {} success.", real_path); + return true; + } else { + char* error = dlerror(); + MLLM_ERROR("QnnDynSymbolLoader::loadQnnDynLib try for {} failed: {}", real_path, error ? error : "Unknown error"); + } + } + MLLM_ERROR("QnnDynSymbolLoader::loadQnnDynLib {} failed.", lib_name); + return false; +} + +bool QnnDynSymbolLoader::loadQnnDynLibAtPath(const std::string& path, const std::string& lib_name, int flag) { + auto real_path = path + lib_name; + auto handle = dlopen(real_path.c_str(), flag); + if (handle) { + auto descriptor = QnnDynLibDescriptor{.lib_name_ = lib_name, .lib_path_ = path, .handle_ = handle}; + libs_.insert({lib_name, descriptor}); + MLLM_INFO("QnnDynSymbolLoader::loadQnnDynLib {} success.", real_path); + return true; + } else { + char* error = dlerror(); + MLLM_ERROR("QnnDynSymbolLoader::loadQnnDynLib try for {} failed: {}", real_path, error ? error : "Unknown error"); + } + MLLM_ERROR("QnnDynSymbolLoader::loadQnnDynLib {} failed.", lib_name); + return false; +} + +QnnAOTEnv::QnnAOTEnv() { _setup(); } + +QnnAOTEnv::QnnAOTEnv(const std::string& lib_path) { _setup(lib_path); } + +void QnnAOTEnv::_setup(const std::string& path) { + auto& loader = QnnDynSymbolLoader::instance(); + std::string htp_backend_lib_name = "libQnnHtp.so"; + // GLOBAL Load + if (path.empty()) { + if (!loader.loadQnnDynLib(htp_backend_lib_name, + QnnDynSymbolLoader::DynFlag::kRTLD_NOW | QnnDynSymbolLoader::DynFlag::kRTLD_GLOBAL)) { + MLLM_ERROR("QnnAOTEnv::QnnAOTEnv {} failed.", htp_backend_lib_name); + exit(1); + } + } else { + if (!loader.loadQnnDynLibAtPath(path, htp_backend_lib_name, + QnnDynSymbolLoader::DynFlag::kRTLD_NOW | QnnDynSymbolLoader::DynFlag::kRTLD_GLOBAL)) { + MLLM_ERROR("QnnAOTEnv::QnnAOTEnv {} failed.", htp_backend_lib_name); + exit(1); + } + } + + auto qnn_interface_get_providers_func = + loader(htp_backend_lib_name).func("QnnInterface_getProviders"); + + QnnInterface_t** interface_providers = nullptr; + uint32_t num_providers = 0; + + MLLM_RT_ASSERT_EQ(qnn_interface_get_providers_func((const QnnInterface_t***)&interface_providers, &num_providers), + QNN_SUCCESS); + MLLM_RT_ASSERT(interface_providers != nullptr); + MLLM_RT_ASSERT(num_providers != 0); + + MLLM_INFO("QnnAOTEnv::QnnAOTEnv get HTP num_providers: {}", num_providers); + + bool found_valid_interface = false; + // Get correct provider + for (size_t provider_id = 0; provider_id < num_providers; provider_id++) { + if (QNN_API_VERSION_MAJOR == interface_providers[provider_id]->apiVersion.coreApiVersion.major + && QNN_API_VERSION_MINOR <= interface_providers[provider_id]->apiVersion.coreApiVersion.minor) { + found_valid_interface = true; + qnn_htp_func_symbols_.qnn_interface_ = interface_providers[provider_id]->QNN_INTERFACE_VER_NAME; + break; + } + } + MLLM_RT_ASSERT_EQ(found_valid_interface, true); + + // Check if this HTP Backend has specific property + if (nullptr != qnn_htp_func_symbols_.qnn_interface_.propertyHasCapability) { + auto status = qnn_htp_func_symbols_.qnn_interface_.propertyHasCapability(QNN_PROPERTY_GROUP_DEVICE); + if (status == QNN_PROPERTY_NOT_SUPPORTED) { MLLM_WARN("Device property is not supported"); } + + MLLM_RT_ASSERT(status != QNN_PROPERTY_ERROR_UNKNOWN_KEY); + } +} + +std::shared_ptr QnnAOTEnv::createContext(const std::string& name) { + std::shared_ptr context = std::make_shared(); + context->name_ = name; + + // 1. create logger and register callback. + // clang-format off + MLLM_RT_ASSERT_EQ(qnn_htp_func_symbols_.qnn_interface_.logCreate(__mllmLoggerCallback4QnnLogger,QNN_LOG_LEVEL_VERBOSE, &context->log_), QNN_SUCCESS) + MLLM_RT_ASSERT_EQ(QNN_BACKEND_NO_ERROR, qnn_htp_func_symbols_.qnn_interface_.backendCreate(context->log_, (const QnnBackend_Config_t**)context->bk_cfg_, &context->bk_handle_)) + // clang-format on + + // 2. Create HTP Device + // FIXME(wch): we need to model each Hexagon machine with its special device info. + // clang-format off + if (nullptr != qnn_htp_func_symbols_.qnn_interface_.deviceCreate) { + auto status = qnn_htp_func_symbols_.qnn_interface_.deviceCreate(context->log_, nullptr, &context->device_handle_); + MLLM_RT_ASSERT_EQ(status, QNN_SUCCESS); + } + // clang-format on + + // 3. Create Profile + { + auto status = qnn_htp_func_symbols_.qnn_interface_.profileCreate(context->bk_handle_, QNN_PROFILE_LEVEL_DETAILED, + &context->profile_bk_handle_); + MLLM_RT_ASSERT_EQ(status, QNN_SUCCESS); + } + + // 4. Create Context + { + auto status = qnn_htp_func_symbols_.qnn_interface_.contextCreate(context->bk_handle_, context->device_handle_, + (const QnnContext_Config_t**)&context->qnn_context_config_, + &context->qnn_ctx_handle_); + MLLM_RT_ASSERT_EQ(QNN_CONTEXT_NO_ERROR, status); + } + + // 5. Register MLLM's Qnn Opset + // clang-format off + { + // FIXME(wch): we need to register our own opset of qnn. + // struct OpPackageInfo { + // std::string path; + // std::string interface_provider; + // std::string target; + // }; + + // std::vector op_packages = { + // {.path = "libQnnMllmPackageCPU.so", .interface_provider = "MllmPackageInterfaceProvider", .target = "CPU"}, + // {.path = "libQnnMllmPackageHTP.so", .interface_provider = "MllmPackageInterfaceProvider", .target = "HTP"}, + // }; + + // for (const auto& pkg : op_packages) { + // if (!qnn_htp_func_symbols_.qnn_interface_.backendRegisterOpPackage) { + // MLLM_ERROR_EXIT(ExitCode::kCoreError, "qnn_htp_func_symbols_.qnn_interface_.backendRegisterOpPackage is nullptr."); + // } + // auto status = qnn_htp_func_symbols_.qnn_interface_.backendRegisterOpPackage(context->bk_handle_, pkg.path.c_str(), pkg.interface_provider.c_str(), pkg.target.c_str()); + // MLLM_RT_ASSERT_EQ(status, QNN_BACKEND_NO_ERROR); + // MLLM_INFO("QNN Registered op package: {}, interface provider: {}, target: {}", pkg.path, pkg.interface_provider, pkg.target); + // } + } + // clang-format on + + MLLM_RT_ASSERT_EQ(contexts_.count(name), 0); + contexts_[name] = context; + return context; +} + +void QnnAOTEnv::saveContext(const std::string& name, const std::string& path) { + // TODO +} + +void QnnAOTEnv::destroyContext(const std::string& name) { + // TODO +} + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp new file mode 100644 index 000000000..aeaa32785 --- /dev/null +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp @@ -0,0 +1,117 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "mllm/utils/Common.hpp" + +namespace mllm::qnn::aot { + +void __mllmLoggerCallback4QnnLogger(const char* fmt, QnnLog_Level_t level, uint64_t times_tamp, va_list argp); + +// Collection of symbols that we need to load from qnn dyn lib. +struct QnnFuncSymbols { + using QnnInterfaceGetProvidersFuncType = Qnn_ErrorHandle_t(const QnnInterface_t*** providerList, uint32_t* numProviders); + using QnnSystemInterfaceGetProvidersFuncType = Qnn_ErrorHandle_t(const QnnSystemInterface_t*** providerList, + uint32_t* numProviders); + + QNN_INTERFACE_VER_TYPE qnn_interface_; + QNN_SYSTEM_INTERFACE_VER_TYPE qnn_system_interface_; +}; + +struct QnnDeviceAndContext { + std::string name_; + Qnn_LogHandle_t log_ = nullptr; + Qnn_BackendHandle_t bk_handle_ = nullptr; + Qnn_DeviceHandle_t device_handle_ = nullptr; + QnnBackend_Config_t** bk_cfg_ = nullptr; + QnnContext_Config_t** qnn_context_config_ = nullptr; + Qnn_ProfileHandle_t profile_bk_handle_ = nullptr; + Qnn_ContextHandle_t qnn_ctx_handle_; +}; + +struct QnnDynLibDescriptor { + std::string lib_name_; + std::string lib_path_; + void* handle_ = nullptr; + + template + std::function func(const std::string& symbol_name) { + if (handle_ == nullptr) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "QnnDynSymbolLoader: handle is nullptr."); } + auto func_ptr = dlsym(handle_, symbol_name.c_str()); + MLLM_RT_ASSERT(func_ptr != nullptr); + return (FuncType*)(func_ptr); + }; +}; + +class QnnDynSymbolLoader { + public: + enum DynFlag : int { // NOLINT performance-enum-size + kRTLD_NOW = RTLD_NOW, + kRTLD_LOCAL = RTLD_LOCAL, + kRTLD_GLOBAL = RTLD_GLOBAL, + }; + + static QnnDynSymbolLoader& instance() { + static QnnDynSymbolLoader instance; + return instance; + } + + ~QnnDynSymbolLoader(); + + QnnDynSymbolLoader() = default; + + QnnDynSymbolLoader(const QnnDynSymbolLoader&) = delete; + + QnnDynSymbolLoader& operator=(const QnnDynSymbolLoader&) = delete; + + bool loadQnnDynLib(const std::string& lib_name, int flag); + + bool loadQnnDynLibAtPath(const std::string& path, const std::string& lib_name, int flag); + + inline QnnDynLibDescriptor& operator()(const std::string& lib_name) { return libs_.at(lib_name); } + + private: + std::unordered_map libs_; + static const std::vector possible_qnn_dyn_lib_paths_; +}; + +// Device and Dynamic Lib included +class QnnAOTEnv { + public: + using ptr_t = std::shared_ptr; + + QnnAOTEnv(); + + explicit QnnAOTEnv(const std::string& lib_path); + + std::shared_ptr createContext(const std::string& name); + + void saveContext(const std::string& name, const std::string& path); + + void destroyContext(const std::string& name); + + private: + void _setup(const std::string& path = ""); + + QnnFuncSymbols qnn_htp_func_symbols_; + std::unordered_map> contexts_; +}; + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/README.md b/mllm/backends/qnn/aot/README.md new file mode 100644 index 000000000..d2d28d1d4 --- /dev/null +++ b/mllm/backends/qnn/aot/README.md @@ -0,0 +1,3 @@ +# Qnn AOT + +This is the Qnn AOT API for X86 platform to build executable qnn model. This is not depends on QNNBackend target. diff --git a/mllm/backends/qnn/aot_rt/README.md b/mllm/backends/qnn/aot_rt/README.md new file mode 100644 index 000000000..f7930caee --- /dev/null +++ b/mllm/backends/qnn/aot_rt/README.md @@ -0,0 +1 @@ +# Runtime of AOT Models diff --git a/mllm/ffi/CMakeLists.txt b/mllm/ffi/CMakeLists.txt index c46d15af5..549d0a68f 100644 --- a/mllm/ffi/CMakeLists.txt +++ b/mllm/ffi/CMakeLists.txt @@ -13,11 +13,24 @@ add_library(MllmFFIExtension SHARED ${CMAKE_CURRENT_LIST_DIR}/ModelService.cc ${CMAKE_CURRENT_LIST_DIR}/Nn.cc ${CMAKE_CURRENT_LIST_DIR}/Compile.cc + ${CMAKE_CURRENT_LIST_DIR}/qualcomm/QnnAOT.cc ) target_link_libraries(MllmFFIExtension PUBLIC tvm_ffi_header) target_link_libraries(MllmFFIExtension PUBLIC tvm_ffi_shared MllmRT MllmCPUBackend) set_target_properties(MllmFFIExtension PROPERTIES PREFIX "") +if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) + # Build + target_include_directories(MllmFFIExtension PRIVATE + $ENV{QAIRT_SDK_ROOT}/include # QNN SDK include + $ENV{QAIRT_SDK_ROOT}/include/QNN # QNN SDK include + ) + add_compile_definitions( + MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE + ) +endif() + + # Set the depend search path. Windows do not need this, it will search dlls in the same directory first. if(APPLE) set_target_properties(MllmFFIExtension PROPERTIES diff --git a/mllm/ffi/Extension.cc b/mllm/ffi/Extension.cc index 71d4ce120..4744e172d 100644 --- a/mllm/ffi/Extension.cc +++ b/mllm/ffi/Extension.cc @@ -319,6 +319,17 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } +//===----------------------------------------------------------------------===// +// REGISTER: BaseOp Functions. +//===----------------------------------------------------------------------===// +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + + refl::GlobalDef().def("mllm.BaseOp.load", [](const mllm::ffi::BaseOp& self, const mllm::ffi::ParameterFile& obj) -> void { + self.get()->op_ptr_->load(obj.get()->pf_ptr_); + }); +} + //===----------------------------------------------------------------------===// // REGISTER: Service Functions. //===----------------------------------------------------------------------===// @@ -338,6 +349,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } +//===----------------------------------------------------------------------===// +// REGISTER: _Context Functions. +//===----------------------------------------------------------------------===// + //===----------------------------------------------------------------------===// // REGISTER: Quantize && Packing Functions. //===----------------------------------------------------------------------===// diff --git a/mllm/ffi/Object.hh b/mllm/ffi/Object.hh index 31329db08..13021c871 100644 --- a/mllm/ffi/Object.hh +++ b/mllm/ffi/Object.hh @@ -91,4 +91,42 @@ class Session : public tvm::ffi::ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Session, tvm::ffi::ObjectRef, SessionObj); // NOLINT }; +//===----------------------------------------------------------------------===// +// MLLM BaseOp Define +//===----------------------------------------------------------------------===// +class BaseOpObj : public tvm::ffi::Object { + public: + ::mllm::BaseOp::ptr_t op_ptr_ = nullptr; + + explicit BaseOpObj(const ::mllm::BaseOp::ptr_t& op_ptr) : op_ptr_(op_ptr) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.BaseOp", BaseOpObj, tvm::ffi::Object); +}; + +class BaseOp : public tvm::ffi::ObjectRef { + public: + explicit BaseOp(::mllm::BaseOp::ptr_t& base_op_ptr) { data_ = tvm::ffi::make_object(base_op_ptr); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BaseOp, tvm::ffi::ObjectRef, BaseOpObj); // NOLINT +}; + +//===----------------------------------------------------------------------===// +// MLLM Parameter File Define +//===----------------------------------------------------------------------===// +class ParameterFileObj : public tvm::ffi::Object { + public: + ::mllm::ParameterFile::ptr_t pf_ptr_ = nullptr; + + explicit ParameterFileObj(const ::mllm::ParameterFile::ptr_t& pf_ptr) : pf_ptr_(pf_ptr) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.ParameterFile", ParameterFileObj, tvm::ffi::Object); +}; + +class ParameterFile : public tvm::ffi::ObjectRef { + public: + explicit ParameterFile(::mllm::ParameterFile::ptr_t& pf_ptr) { data_ = tvm::ffi::make_object(pf_ptr); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ParameterFile, tvm::ffi::ObjectRef, ParameterFileObj); // NOLINT +}; + } // namespace mllm::ffi diff --git a/mllm/ffi/qualcomm/QnnAOT.cc b/mllm/ffi/qualcomm/QnnAOT.cc new file mode 100644 index 000000000..d77121959 --- /dev/null +++ b/mllm/ffi/qualcomm/QnnAOT.cc @@ -0,0 +1,30 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +#include "mllm/ffi/qualcomm/QnnAOT.hh" + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + + refl::ObjectDef().def_static("__create__", [](const std::string& path) -> mllm::ffi::QnnAOTEnv { + if (path.empty()) { + auto s = std::make_shared<::mllm::qnn::aot::QnnAOTEnv>(); + return ::mllm::ffi::QnnAOTEnv(s); + } else { + auto s = std::make_shared<::mllm::qnn::aot::QnnAOTEnv>(path); + return ::mllm::ffi::QnnAOTEnv(s); + } + }); + + refl::GlobalDef().def("mllm.qualcomm.QnnAOTEnv.createContext", [](const mllm::ffi::QnnAOTEnv& self, const std::string& name) { + auto s = self.get()->qnn_aot_env_ptr_->createContext(name); + return mllm::ffi::QnnDeviceAndContext(s); + }); +} diff --git a/mllm/ffi/qualcomm/QnnAOT.hh b/mllm/ffi/qualcomm/QnnAOT.hh new file mode 100644 index 000000000..321acc6ea --- /dev/null +++ b/mllm/ffi/qualcomm/QnnAOT.hh @@ -0,0 +1,64 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#ifdef MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE +#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp" +#endif + +namespace mllm::ffi { + +#ifdef MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE + +//===----------------------------------------------------------------------===// +// MLLM Parameter File Define +//===----------------------------------------------------------------------===// +class QnnAOTEnvObj : public tvm::ffi::Object { + public: + ::mllm::qnn::aot::QnnAOTEnv::ptr_t qnn_aot_env_ptr_ = nullptr; + + explicit QnnAOTEnvObj(const ::mllm::qnn::aot::QnnAOTEnv::ptr_t& ptr) : qnn_aot_env_ptr_(ptr) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.qualcomm.QnnAOTEnv", QnnAOTEnvObj, tvm::ffi::Object); +}; + +class QnnAOTEnv : public tvm::ffi::ObjectRef { + public: + explicit QnnAOTEnv(::mllm::qnn::aot::QnnAOTEnv::ptr_t& ptr) { data_ = tvm::ffi::make_object(ptr); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QnnAOTEnv, tvm::ffi::ObjectRef, QnnAOTEnvObj); // NOLINT +}; + +//===----------------------------------------------------------------------===// +// MLLM QnnDeviceAndContext Define +//===----------------------------------------------------------------------===// +class QnnDeviceAndContextObj : public tvm::ffi::Object { + public: + std::shared_ptr<::mllm::qnn::aot::QnnDeviceAndContext> qnn_device_and_context_ptr_ = nullptr; + + explicit QnnDeviceAndContextObj(const std::shared_ptr<::mllm::qnn::aot::QnnDeviceAndContext>& ptr) + : qnn_device_and_context_ptr_(ptr) { + MLLM_EMPTY_SCOPE; + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.qualcomm.QnnDeviceAndContext", QnnDeviceAndContextObj, tvm::ffi::Object); +}; + +class QnnDeviceAndContext : public tvm::ffi::ObjectRef { + public: + explicit QnnDeviceAndContext(std::shared_ptr<::mllm::qnn::aot::QnnDeviceAndContext>& ptr) { + data_ = tvm::ffi::make_object(ptr); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QnnDeviceAndContext, tvm::ffi::ObjectRef, QnnDeviceAndContextObj); // NOLINT +}; + +#endif + +} // namespace mllm::ffi diff --git a/pymllm/backends/__init__.py b/pymllm/backends/__init__.py new file mode 100644 index 000000000..5e926d580 --- /dev/null +++ b/pymllm/backends/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) MLLM Team. +# Licensed under the MIT License. + +from . import cuda, qualcomm diff --git a/pymllm/backends/cuda/__init__.py b/pymllm/backends/cuda/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymllm/backends/qualcomm/README.md b/pymllm/backends/qualcomm/README.md new file mode 100644 index 000000000..27122dbc2 --- /dev/null +++ b/pymllm/backends/qualcomm/README.md @@ -0,0 +1 @@ +# Qualcomm Qnn AOT API diff --git a/pymllm/backends/qualcomm/__init__.py b/pymllm/backends/qualcomm/__init__.py new file mode 100644 index 000000000..3dc11529f --- /dev/null +++ b/pymllm/backends/qualcomm/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) MLLM Team. +# Licensed under the MIT License. + +from . import qnn_aot_env diff --git a/pymllm/backends/qualcomm/nn.py b/pymllm/backends/qualcomm/nn.py new file mode 100644 index 000000000..0ba9aef55 --- /dev/null +++ b/pymllm/backends/qualcomm/nn.py @@ -0,0 +1,11 @@ +from pymllm.nn._layers import Softmax, RoPE + + +class QnnSoftmax(Softmax): + def __init__(self): + super().__init__() + + +class QnnRoPE(RoPE): + def __init__(self): + super().__init__() diff --git a/pymllm/backends/qualcomm/qnn_aot_env.py b/pymllm/backends/qualcomm/qnn_aot_env.py new file mode 100644 index 000000000..7737a5c02 --- /dev/null +++ b/pymllm/backends/qualcomm/qnn_aot_env.py @@ -0,0 +1 @@ +from pymllm.ffi import QnnDeviceAndContext, QnnAOTEnv diff --git a/pymllm/compile/mllm_ir/trace.py b/pymllm/compile/mllm_ir/trace.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymllm/ffi/__init__.py b/pymllm/ffi/__init__.py index e85176d0e..d49a6fc93 100644 --- a/pymllm/ffi/__init__.py +++ b/pymllm/ffi/__init__.py @@ -305,6 +305,41 @@ def __init__(self): pass +@tvm_ffi.register_object("mllm.ParameterFile") +class ParameterFile(tvm_ffi.Object): + def __init__(self): + pass + + +@tvm_ffi.register_object("mllm.BaseOp") +class BaseOp(tvm_ffi.Object): + def __init__(self): + pass + + def load(self, pf: ParameterFile): + return tvm_ffi.get_global_func("mllm.BaseOp.load")(self, pf) + + +@tvm_ffi.register_object("mllm.qualcomm.QnnDeviceAndContext") +class QnnDeviceAndContext(tvm_ffi.Object): + def __init__(self): + pass + + +@tvm_ffi.register_object("mllm.qualcomm.QnnAOTEnv") +class QnnAOTEnv(tvm_ffi.Object): + def __init__(self, path=None): + if path is None or path == "": + self.__init_handle_by_constructor__(QnnAOTEnv.__create__, "") + else: + self.__init_handle_by_constructor__(QnnAOTEnv.__create__, path) + + def create_context(self, name: str) -> QnnDeviceAndContext: + return tvm_ffi.get_global_func("mllm.qualcomm.QnnAOTEnv.createContext")( + self, name + ) + + # Initialize context initialize_context() diff --git a/pymllm/nn/_layers.py b/pymllm/nn/_layers.py index b4fdafb43..4dc6c0d66 100644 --- a/pymllm/nn/_layers.py +++ b/pymllm/nn/_layers.py @@ -4,6 +4,38 @@ from .. import ffi -class Linear: +class _Layer: def __init__(self): + self.this_layer_name: str = None + self.absolute_name: str = None + self._mllm_c_op_ptr: ffi.BaseOp = None + self._params_file_ptr: ffi.ParameterFile = None + + def load(self, pf: ffi.ParameterFile): + self._mllm_c_op_ptr.load(pf) + self._params_file_ptr = pf + + def trace(self): + pass + + def forward(self): + # TODO dispatch op pass + + def __call__(self, *args, **kwds): + pass + + +class Linear(_Layer): + def __init__(self): + super().__init__() + + +class Softmax(_Layer): + def __init__(self): + super().__init__() + + +class RoPE(_Layer): + def __init__(self): + super().__init__() diff --git a/pymllm/nn/_module.py b/pymllm/nn/_module.py index fa8434da0..309721ec2 100644 --- a/pymllm/nn/_module.py +++ b/pymllm/nn/_module.py @@ -2,8 +2,36 @@ # Licensed under the MIT License. from .. import ffi +from ._layers import _Layer class Module: def __init__(self): + self.this_module_name: str = None + self.absolute_name: str = None + self.module_layer_list: list = [] + + def load(self, pf: ffi.ParameterFile): + for module_layer in self.module_layer_list: + if isinstance(module_layer, Module, _Layer): + module_layer.load(pf) + else: + raise TypeError( + "Module layer must be Module or _Layer, but got {}".format( + type(module_layer) + ) + ) + + def trace(self): pass + + def forward(self, *args): + # TODO send to engine's dispatcher + pass + + def __call__(self, *args, **kwds): + # __send_graph_begin() + if kwds.get("__mllm_trace_mode_enabled", False): + return self.trace(*args, **kwds) + return self.forward(*args, **kwds) + # __send_graph_end() diff --git a/pymllm/tests/qualcomm/test_context_create.py b/pymllm/tests/qualcomm/test_context_create.py new file mode 100644 index 000000000..f34ef2393 --- /dev/null +++ b/pymllm/tests/qualcomm/test_context_create.py @@ -0,0 +1,8 @@ +import pymllm as mllm +from pymllm.backends.qualcomm.qnn_aot_env import QnnAOTEnv, QnnDeviceAndContext + +qnn_aot_env: QnnAOTEnv = QnnAOTEnv() + +if __name__ == "__main__": + mllm.echo("Testing mllm's tvm-ffi abi compatibility") + qnn_context: QnnDeviceAndContext = qnn_aot_env.create_context("model.layer.0") diff --git a/pyproject.toml b/pyproject.toml index 014234dd2..0eae93363 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = [ - "scikit-build-core==0.10.0", "apache-tvm-ffi" + "scikit-build-core>=0.11.0", "apache-tvm-ffi" ] build-backend = "scikit_build_core.build" @@ -54,6 +54,8 @@ cmake.args = [ "-DCMAKE_BUILD_TYPE=Release", "-DMLLM_ENABLE_PY_MLLM=on" ] +sdist.exclude = [".*", ".*/*"] +wheel.exclude = [".*", ".*/*"] minimum-version = "build-system.requires" # Build configuration diff --git a/tasks/build_x86.yaml b/tasks/build_x86.yaml index a2b60952d..617f05f9c 100644 --- a/tasks/build_x86.yaml +++ b/tasks/build_x86.yaml @@ -11,6 +11,7 @@ Tasks: - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=native"' - "-DMLLM_KERNEL_USE_THREADS=ON" - "-DMLLM_KERNEL_THREADS_VENDOR_OPENMP=ON" + - "-DMLLM_KERNEL_USE_THREADS_VENDOR_MLLM=OFF" - CMakeBuildTask: cmake_cfg_path: "build" diff --git a/tasks/build_x86_qnn_aot.yaml b/tasks/build_x86_qnn_aot.yaml new file mode 100644 index 000000000..98870aa1f --- /dev/null +++ b/tasks/build_x86_qnn_aot.yaml @@ -0,0 +1,18 @@ +Tasks: + - CMakeConfigTask: + cmake_cfg_path: "build-qnn-aot" + cmake_build_type: "Release" + cmake_extra_args: + # Optional, If use Highway + - "-DHWY_ENABLE_TESTS=OFF" + - "-DHWY_ENABLE_EXAMPLES=OFF" + - "-DHWY_ENABLE_CONTRIB=OFF" + # Optional + - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=native"' + - "-DMLLM_KERNEL_USE_THREADS=ON" + - "-DMLLM_KERNEL_THREADS_VENDOR_OPENMP=ON" + - "-DMLLM_KERNEL_USE_THREADS_VENDOR_MLLM=OFF" + - "-DMLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE=ON" + + - CMakeBuildTask: + cmake_cfg_path: "build-qnn-aot" From 2211cff4bfd6e5e99fa12eea89554fc039d060c7 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 11 Dec 2025 10:41:04 +0000 Subject: [PATCH 2/7] feat(qualcomm): add conditional compilation for QnnAOT on x86 Add preprocessor directives to enable conditional compilation for QnnAOT on x86 platforms. This allows the code to be compiled only when the MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE macro is defined. --- mllm/ffi/qualcomm/QnnAOT.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mllm/ffi/qualcomm/QnnAOT.cc b/mllm/ffi/qualcomm/QnnAOT.cc index d77121959..dd1fac055 100644 --- a/mllm/ffi/qualcomm/QnnAOT.cc +++ b/mllm/ffi/qualcomm/QnnAOT.cc @@ -10,6 +10,8 @@ #include "mllm/ffi/qualcomm/QnnAOT.hh" +#ifdef MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; @@ -28,3 +30,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { return mllm::ffi::QnnDeviceAndContext(s); }); } + +#endif From ca67e5814faff3911e9664955246c34e20ab6ffa Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 11 Dec 2025 14:12:35 +0000 Subject: [PATCH 3/7] feat(qnn): add QcomTargetMachine and related enums for AOT environment - Introduce `QcomTargetMachine` struct to encapsulate target machine configurations including chipset, HTP architecture, performance mode, and security session. - Define enums for Qualcomm chipsets, HTP architectures, performance levels, and protection domain sessions. - Update `QnnAOTEnv` constructor to accept `QcomTargetMachine` for configuration. - Expose new FFI objects and Python bindings for target machine configuration. - Add test case demonstrating usage of `QcomTargetMachine` in QNN AOT environment. --- mllm/backends/qnn/aot/QnnTargetMachine.hpp | 66 ++++++ mllm/backends/qnn/aot/QnnWrappersAPI.cpp | 7 +- mllm/backends/qnn/aot/QnnWrappersAPI.hpp | 6 +- mllm/ffi/Extension.cc | 4 + mllm/ffi/qualcomm/QnnAOT.cc | 190 +++++++++++++++- mllm/ffi/qualcomm/QnnAOT.hh | 101 +++++++++ pymllm/backends/qualcomm/qnn_aot_env.py | 10 +- pymllm/ffi/__init__.py | 220 ++++++++++++++++++- pymllm/tests/qualcomm/test_context_create.py | 25 ++- 9 files changed, 609 insertions(+), 20 deletions(-) create mode 100644 mllm/backends/qnn/aot/QnnTargetMachine.hpp diff --git a/mllm/backends/qnn/aot/QnnTargetMachine.hpp b/mllm/backends/qnn/aot/QnnTargetMachine.hpp new file mode 100644 index 000000000..106a1bab9 --- /dev/null +++ b/mllm/backends/qnn/aot/QnnTargetMachine.hpp @@ -0,0 +1,66 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace mllm::qnn::aot { + +enum class QcomHTPArch : uint32_t { + NONE = 0, + V68 = 68, + V69 = 69, + V73 = 73, + V75 = 75, + V79 = 79, + V81 = 81, +}; + +enum QcomChipset : int { + UNKNOWN_SM = 0, + SA8295 = 39, + SM8350 = 35, + SM8450 = 36, + SM8475 = 42, + SM8550 = 43, + SM8650 = 57, + SM8750 = 69, + SM8850 = 87, + SSG2115P = 46, + SSG2125P = 58, + SXR1230P = 45, + SXR2230P = 53, + SXR2330P = 75, + QCS9100 = 77, + SAR2230P = 95, + SA8255 = 52, + SW6100 = 96, +}; + +enum QcomTryBestPerformance : int { + kHtpDefault = 0, + kHtpSustainedHighPerformance, + kHtpBurst, + kHtpHighPerformance, + kHtpPowerSaver, + kHtpLowPowerSaver, + kHtpHighPowerSaver, + kHtpLowBalanced, + kHtpBalanced, +}; + +// Protection Domain Session +enum QcomSecurityPDSession : int { + kHtpUnsignedPd = 0, + kHtpSignedPd, +}; + +struct QcomTargetMachine { + QcomChipset soc_htp_chipset; + QcomHTPArch soc_htp_arch; + QcomTryBestPerformance soc_htp_performance; + QcomSecurityPDSession soc_htp_security_pd_session; +}; + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp index 8144e32d2..81c221500 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp @@ -1,6 +1,7 @@ // Copyright (c) MLLM Team. // Licensed under the MIT License. #include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp" +#include "mllm/backends/qnn/aot/QnnTargetMachine.hpp" namespace mllm::qnn::aot { @@ -67,9 +68,11 @@ bool QnnDynSymbolLoader::loadQnnDynLibAtPath(const std::string& path, const std: return false; } -QnnAOTEnv::QnnAOTEnv() { _setup(); } +QnnAOTEnv::QnnAOTEnv(QcomTargetMachine& target_machine) : target_machine_(target_machine) { _setup(); } -QnnAOTEnv::QnnAOTEnv(const std::string& lib_path) { _setup(lib_path); } +QnnAOTEnv::QnnAOTEnv(const std::string& lib_path, QcomTargetMachine& target_machine) : target_machine_(target_machine) { + _setup(lib_path); +} void QnnAOTEnv::_setup(const std::string& path) { auto& loader = QnnDynSymbolLoader::instance(); diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp index aeaa32785..d54f92fad 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp @@ -19,6 +19,7 @@ #include #include +#include "mllm/backends/qnn/aot/QnnTargetMachine.hpp" #include "mllm/utils/Common.hpp" namespace mllm::qnn::aot { @@ -97,9 +98,9 @@ class QnnAOTEnv { public: using ptr_t = std::shared_ptr; - QnnAOTEnv(); + explicit QnnAOTEnv(QcomTargetMachine& target_machine); - explicit QnnAOTEnv(const std::string& lib_path); + QnnAOTEnv(const std::string& lib_path, QcomTargetMachine& target_machine); std::shared_ptr createContext(const std::string& name); @@ -110,6 +111,7 @@ class QnnAOTEnv { private: void _setup(const std::string& path = ""); + QcomTargetMachine target_machine_; QnnFuncSymbols qnn_htp_func_symbols_; std::unordered_map> contexts_; }; diff --git a/mllm/ffi/Extension.cc b/mllm/ffi/Extension.cc index 4744e172d..1bcb0a1e6 100644 --- a/mllm/ffi/Extension.cc +++ b/mllm/ffi/Extension.cc @@ -47,6 +47,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("mllm.shutdown_context", mllm::shutdownContext); // Primitives + refl::ObjectDef<::mllm::ffi::DeviceObj>(); + refl::ObjectDef<::mllm::ffi::DTypeObj>(); refl::GlobalDef().def("mllm.cpu_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kCPU); }); refl::GlobalDef().def("mllm.cuda_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kCUDA); }); refl::GlobalDef().def("mllm.qnn_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kQNN); }); @@ -325,6 +327,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + refl::ObjectDef<::mllm::ffi::BaseOpObj>(); refl::GlobalDef().def("mllm.BaseOp.load", [](const mllm::ffi::BaseOp& self, const mllm::ffi::ParameterFile& obj) -> void { self.get()->op_ptr_->load(obj.get()->pf_ptr_); }); @@ -336,6 +339,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + refl::ObjectDef<::mllm::ffi::SessionObj>(); refl::GlobalDef().def("mllm.service.startService", [](int work_threads = 1) -> void { ::mllm::service::startService(work_threads); }); refl::GlobalDef().def("mllm.service.stopService", []() -> void { ::mllm::service::stopService(); }); diff --git a/mllm/ffi/qualcomm/QnnAOT.cc b/mllm/ffi/qualcomm/QnnAOT.cc index dd1fac055..c4ddc6f97 100644 --- a/mllm/ffi/qualcomm/QnnAOT.cc +++ b/mllm/ffi/qualcomm/QnnAOT.cc @@ -8,6 +8,7 @@ #include #include +#include "mllm/backends/qnn/aot/QnnTargetMachine.hpp" #include "mllm/ffi/qualcomm/QnnAOT.hh" #ifdef MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE @@ -15,15 +16,188 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_static("__create__", [](const std::string& path) -> mllm::ffi::QnnAOTEnv { - if (path.empty()) { - auto s = std::make_shared<::mllm::qnn::aot::QnnAOTEnv>(); - return ::mllm::ffi::QnnAOTEnv(s); - } else { - auto s = std::make_shared<::mllm::qnn::aot::QnnAOTEnv>(path); - return ::mllm::ffi::QnnAOTEnv(s); - } + refl::ObjectDef<::mllm::ffi::QcomHTPArchObj>(); + + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.NONE", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::NONE; + return mllm::ffi::QcomHTPArch(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V68", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::V68; + return mllm::ffi::QcomHTPArch(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V69", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::V69; + return mllm::ffi::QcomHTPArch(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V73", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::V73; + return mllm::ffi::QcomHTPArch(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V75", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::V79; + return mllm::ffi::QcomHTPArch(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V79", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::V79; + return mllm::ffi::QcomHTPArch(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V81", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::V81; + return mllm::ffi::QcomHTPArch(ret); + }); + + refl::ObjectDef<::mllm::ffi::QcomChipsetObj>(); + + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.UNKNOWN_SM", []() { + auto ret = mllm::qnn::aot::QcomChipset::UNKNOWN_SM; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SA8295", []() { + auto ret = mllm::qnn::aot::QcomChipset::SA8295; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8350", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8350; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8450", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8450; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8475", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8475; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8550", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8550; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8650", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8650; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8750", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8750; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8850", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8850; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SSG2115P", []() { + auto ret = mllm::qnn::aot::QcomChipset::SSG2115P; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SSG2125P", []() { + auto ret = mllm::qnn::aot::QcomChipset::SSG2125P; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SXR1230P", []() { + auto ret = mllm::qnn::aot::QcomChipset::SXR1230P; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SXR2230P", []() { + auto ret = mllm::qnn::aot::QcomChipset::SXR2230P; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SXR2330P", []() { + auto ret = mllm::qnn::aot::QcomChipset::SXR2330P; + return mllm::ffi::QcomChipset(ret); }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.QCS9100", []() { + auto ret = mllm::qnn::aot::QcomChipset::QCS9100; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SAR2230P", []() { + auto ret = mllm::qnn::aot::QcomChipset::SAR2230P; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SA8255", []() { + auto ret = mllm::qnn::aot::QcomChipset::SA8255; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SW6100", []() { + auto ret = mllm::qnn::aot::QcomChipset::SW6100; + return mllm::ffi::QcomChipset(ret); + }); + + refl::ObjectDef<::mllm::ffi::QcomTryBestPerformanceObj>(); + + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpDefault", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpDefault; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpSustainedHighPerformance", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpSustainedHighPerformance; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpBurst", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpBurst; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpHighPerformance", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpHighPerformance; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpPowerSaver", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpPowerSaver; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpLowPowerSaver", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpLowPowerSaver; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpHighPowerSaver", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpHighPowerSaver; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpLowBalanced", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpLowBalanced; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpBalanced", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpBalanced; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + + refl::ObjectDef<::mllm::ffi::QcomSecurityPDSessionObj>(); + + refl::GlobalDef().def("mllm.qualcomm.QcomSecurityPDSession.HtpUnsignedPd", []() { + auto ret = mllm::qnn::aot::QcomSecurityPDSession::kHtpUnsignedPd; + return mllm::ffi::QcomSecurityPDSession(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomSecurityPDSession.HtpSignedPd", []() { + auto ret = mllm::qnn::aot::QcomSecurityPDSession::kHtpSignedPd; + return mllm::ffi::QcomSecurityPDSession(ret); + }); + + refl::ObjectDef().def_static( + "__create__", [](const mllm::ffi::QcomChipset& chipset, const mllm::ffi::QcomHTPArch& arch, + const mllm::ffi::QcomTryBestPerformance& perf, const mllm::ffi::QcomSecurityPDSession& pd_session) { + auto tm = mllm::qnn::aot::QcomTargetMachine{ + .soc_htp_chipset = chipset.get()->chipset_, + .soc_htp_arch = arch.get()->htp_arch_, + .soc_htp_performance = perf.get()->perf_, + .soc_htp_security_pd_session = pd_session.get()->pd_, + }; + return ::mllm::ffi::QcomTargetMachine(tm); + }); + + refl::ObjectDef().def_static( + "__create__", [](const mllm::ffi::QcomTargetMachine& machine, const std::string& path) -> mllm::ffi::QnnAOTEnv { + if (path.empty()) { + auto tm = machine.get()->target_machine_; + auto s = std::make_shared<::mllm::qnn::aot::QnnAOTEnv>(tm); + return ::mllm::ffi::QnnAOTEnv(s); + } else { + auto tm = machine.get()->target_machine_; + auto s = std::make_shared<::mllm::qnn::aot::QnnAOTEnv>(path, tm); + return ::mllm::ffi::QnnAOTEnv(s); + } + }); + + refl::ObjectDef<::mllm::ffi::QnnDeviceAndContextObj>(); refl::GlobalDef().def("mllm.qualcomm.QnnAOTEnv.createContext", [](const mllm::ffi::QnnAOTEnv& self, const std::string& name) { auto s = self.get()->qnn_aot_env_ptr_->createContext(name); diff --git a/mllm/ffi/qualcomm/QnnAOT.hh b/mllm/ffi/qualcomm/QnnAOT.hh index 321acc6ea..f0feb46f3 100644 --- a/mllm/ffi/qualcomm/QnnAOT.hh +++ b/mllm/ffi/qualcomm/QnnAOT.hh @@ -59,6 +59,107 @@ class QnnDeviceAndContext : public tvm::ffi::ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QnnDeviceAndContext, tvm::ffi::ObjectRef, QnnDeviceAndContextObj); // NOLINT }; +//===----------------------------------------------------------------------===// +// MLLM QcomHTPArch Define +//===----------------------------------------------------------------------===// +class QcomHTPArchObj : public tvm::ffi::Object { + public: + mllm::qnn::aot::QcomHTPArch htp_arch_; + + explicit QcomHTPArchObj(const mllm::qnn::aot::QcomHTPArch& obj) : htp_arch_(obj) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.qualcomm.QcomHTPArch", QcomHTPArchObj, tvm::ffi::Object); +}; + +class QcomHTPArch : public tvm::ffi::ObjectRef { + public: + explicit QcomHTPArch(mllm::qnn::aot::QcomHTPArch& ptr) { data_ = tvm::ffi::make_object(ptr); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QcomHTPArch, tvm::ffi::ObjectRef, QcomHTPArchObj); // NOLINT +}; + +//===----------------------------------------------------------------------===// +// MLLM QcomChipset Define +//===----------------------------------------------------------------------===// +class QcomChipsetObj : public tvm::ffi::Object { + public: + mllm::qnn::aot::QcomChipset chipset_; + + explicit QcomChipsetObj(const mllm::qnn::aot::QcomChipset& obj) : chipset_(obj) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.qualcomm.QcomChipset", QcomChipsetObj, tvm::ffi::Object); +}; + +class QcomChipset : public tvm::ffi::ObjectRef { + public: + explicit QcomChipset(mllm::qnn::aot::QcomChipset& ptr) { data_ = tvm::ffi::make_object(ptr); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QcomChipset, tvm::ffi::ObjectRef, QcomChipsetObj); // NOLINT +}; + +//===----------------------------------------------------------------------===// +// MLLM QcomTryBestPerformance Define +//===----------------------------------------------------------------------===// +class QcomTryBestPerformanceObj : public tvm::ffi::Object { + public: + mllm::qnn::aot::QcomTryBestPerformance perf_; + + explicit QcomTryBestPerformanceObj(const mllm::qnn::aot::QcomTryBestPerformance& obj) : perf_(obj) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.qualcomm.QcomTryBestPerformance", QcomTryBestPerformanceObj, tvm::ffi::Object); +}; + +class QcomTryBestPerformance : public tvm::ffi::ObjectRef { + public: + explicit QcomTryBestPerformance(mllm::qnn::aot::QcomTryBestPerformance& ptr) { + data_ = tvm::ffi::make_object(ptr); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QcomTryBestPerformance, tvm::ffi::ObjectRef, QcomTryBestPerformanceObj); // NOLINT +}; + +//===----------------------------------------------------------------------===// +// MLLM QcomSecurityPDSession Define +//===----------------------------------------------------------------------===// +class QcomSecurityPDSessionObj : public tvm::ffi::Object { + public: + mllm::qnn::aot::QcomSecurityPDSession pd_; + + explicit QcomSecurityPDSessionObj(const mllm::qnn::aot::QcomSecurityPDSession& obj) : pd_(obj) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.qualcomm.QcomSecurityPDSession", QcomSecurityPDSessionObj, tvm::ffi::Object); +}; + +class QcomSecurityPDSession : public tvm::ffi::ObjectRef { + public: + explicit QcomSecurityPDSession(mllm::qnn::aot::QcomSecurityPDSession& ptr) { + data_ = tvm::ffi::make_object(ptr); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QcomSecurityPDSession, tvm::ffi::ObjectRef, QcomSecurityPDSessionObj); // NOLINT +}; + +//===----------------------------------------------------------------------===// +// MLLM QcomTargetMachine Define +//===----------------------------------------------------------------------===// +class QcomTargetMachineObj : public tvm::ffi::Object { + public: + mllm::qnn::aot::QcomTargetMachine target_machine_; + + explicit QcomTargetMachineObj(const mllm::qnn::aot::QcomTargetMachine& obj) : target_machine_(obj) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.qualcomm.QcomTargetMachine", QcomTargetMachineObj, tvm::ffi::Object); +}; + +class QcomTargetMachine : public tvm::ffi::ObjectRef { + public: + explicit QcomTargetMachine(mllm::qnn::aot::QcomTargetMachine& ptr) { + data_ = tvm::ffi::make_object(ptr); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QcomTargetMachine, tvm::ffi::ObjectRef, QcomTargetMachineObj); // NOLINT +}; + #endif } // namespace mllm::ffi diff --git a/pymllm/backends/qualcomm/qnn_aot_env.py b/pymllm/backends/qualcomm/qnn_aot_env.py index 7737a5c02..af4bf2c1a 100644 --- a/pymllm/backends/qualcomm/qnn_aot_env.py +++ b/pymllm/backends/qualcomm/qnn_aot_env.py @@ -1 +1,9 @@ -from pymllm.ffi import QnnDeviceAndContext, QnnAOTEnv +from pymllm.ffi import ( + QnnDeviceAndContext, + QnnAOTEnv, + QcomChipset, + QcomHTPArch, + QcomSecurityPDSession, + QcomTargetMachine, + QcomTryBestPerformance, +) diff --git a/pymllm/ffi/__init__.py b/pymllm/ffi/__init__.py index d49a6fc93..08667b8f2 100644 --- a/pymllm/ffi/__init__.py +++ b/pymllm/ffi/__init__.py @@ -326,13 +326,227 @@ def __init__(self): pass +@tvm_ffi.register_object("mllm.qualcomm.QcomHTPArch") +class QcomHTPArch(tvm_ffi.Object): + def __init__(self): + pass + + @staticmethod + def NONE() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.NONE")() + + @staticmethod + def V68() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V68")() + + @staticmethod + def V69() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V69")() + + @staticmethod + def V73() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V73")() + + @staticmethod + def V75() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V75")() + + @staticmethod + def V79() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V79")() + + @staticmethod + def V81() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V81")() + + +@tvm_ffi.register_object("mllm.qualcomm.QcomChipset") +class QcomChipset(tvm_ffi.Object): + def __init__(self): + pass + + @staticmethod + def UNKNOWN_SM() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.UNKNOWN_SM")() + + @staticmethod + def SA8295() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SA8295")() + + @staticmethod + def SM8350() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8350")() + + @staticmethod + def SM8450() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8450")() + + @staticmethod + def SM8475() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8475")() + + @staticmethod + def SM8550() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8550")() + + @staticmethod + def SM8650() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8650")() + + @staticmethod + def SM8750() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8750")() + + @staticmethod + def SM8850() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8850")() + + @staticmethod + def SSG2115P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SSG2115P")() + + @staticmethod + def SSG2125P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SSG2125P")() + + @staticmethod + def SXR1230P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SXR1230P")() + + @staticmethod + def SXR2230P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SXR2230P")() + + @staticmethod + def SXR2330P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SXR2330P")() + + @staticmethod + def QCS9100() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.QCS9100")() + + @staticmethod + def SAR2230P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SAR2230P")() + + @staticmethod + def SA8255() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SA8255")() + + @staticmethod + def SW6100() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SW6100")() + + +@tvm_ffi.register_object("mllm.qualcomm.QcomTryBestPerformance") +class QcomTryBestPerformance(tvm_ffi.Object): + def __init__(self): + pass + + @staticmethod + def HtpDefault() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpDefault" + )() + + @staticmethod + def HtpSustainedHighPerformance() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpSustainedHighPerformance" + )() + + @staticmethod + def HtpBurst() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpBurst" + )() + + @staticmethod + def HtpHighPerformance() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpHighPerformance" + )() + + @staticmethod + def HtpPowerSaver() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpPowerSaver" + )() + + @staticmethod + def HtpLowPowerSaver() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpLowPowerSaver" + )() + + @staticmethod + def HtpHighPowerSaver() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpHighPowerSaver" + )() + + @staticmethod + def HtpLowBalanced() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpLowBalanced" + )() + + @staticmethod + def HtpBalanced() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpBalanced" + )() + + +@tvm_ffi.register_object("mllm.qualcomm.QcomSecurityPDSession") +class QcomSecurityPDSession(tvm_ffi.Object): + def __init__(self): + pass + + @staticmethod + def HtpUnsignedPd() -> QcomSecurityPDSession: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomSecurityPDSession.HtpUnsignedPd" + )() + + @staticmethod + def HtpSignedPd() -> QcomSecurityPDSession: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomSecurityPDSession.HtpSignedPd" + )() + + +@tvm_ffi.register_object("mllm.qualcomm.QcomTargetMachine") +class QcomTargetMachine(tvm_ffi.Object): + def __init__( + self, + soc_htp_chipset: QcomChipset, + soc_htp_arch: QcomHTPArch, + soc_htp_performance: QcomTryBestPerformance, + soc_htp_security_pd_session: QcomSecurityPDSession, + ): + self.__init_handle_by_constructor__( + QcomTargetMachine.__create__, + soc_htp_chipset, + soc_htp_arch, + soc_htp_performance, + soc_htp_security_pd_session, + ) + + @tvm_ffi.register_object("mllm.qualcomm.QnnAOTEnv") class QnnAOTEnv(tvm_ffi.Object): - def __init__(self, path=None): + def __init__( + self, + machine: QcomTargetMachine = None, + path: str = None, + ): + if machine is None: + raise RuntimeError("machine target is none!") if path is None or path == "": - self.__init_handle_by_constructor__(QnnAOTEnv.__create__, "") + self.__init_handle_by_constructor__(QnnAOTEnv.__create__, machine, "") else: - self.__init_handle_by_constructor__(QnnAOTEnv.__create__, path) + self.__init_handle_by_constructor__(QnnAOTEnv.__create__, machine, path) def create_context(self, name: str) -> QnnDeviceAndContext: return tvm_ffi.get_global_func("mllm.qualcomm.QnnAOTEnv.createContext")( diff --git a/pymllm/tests/qualcomm/test_context_create.py b/pymllm/tests/qualcomm/test_context_create.py index f34ef2393..b80d6df26 100644 --- a/pymllm/tests/qualcomm/test_context_create.py +++ b/pymllm/tests/qualcomm/test_context_create.py @@ -1,8 +1,25 @@ import pymllm as mllm -from pymllm.backends.qualcomm.qnn_aot_env import QnnAOTEnv, QnnDeviceAndContext +from pymllm.backends.qualcomm.qnn_aot_env import ( + QnnAOTEnv, + QnnDeviceAndContext, + QcomTryBestPerformance, + QcomSecurityPDSession, + QcomTargetMachine, + QcomChipset, + QcomHTPArch, +) -qnn_aot_env: QnnAOTEnv = QnnAOTEnv() + +qnn_aot_env: QnnAOTEnv = QnnAOTEnv( + machine=QcomTargetMachine( + soc_htp_chipset=QcomChipset.SM8850(), + soc_htp_arch=QcomHTPArch.V81(), + soc_htp_performance=QcomTryBestPerformance.HtpBurst(), + soc_htp_security_pd_session=QcomSecurityPDSession.HtpUnsignedPd(), + ), + path="/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", +) if __name__ == "__main__": - mllm.echo("Testing mllm's tvm-ffi abi compatibility") - qnn_context: QnnDeviceAndContext = qnn_aot_env.create_context("model.layer.0") + mllm.echo("Testing tvm-ffi compatibility") + qnn_context: QnnDeviceAndContext = qnn_aot_env.create_context("context.0") From a8211a15e1d4deb4e034d5f6e41e012cbd956e4a Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 12 Dec 2025 10:16:34 +0000 Subject: [PATCH 4/7] feat(qnn): qualcomm aot target machine. --- mllm/backends/qnn/aot/QnnTargetMachine.hpp | 7 +- mllm/backends/qnn/aot/QnnWrappersAPI.cpp | 127 ++++++++++++++++++- mllm/backends/qnn/aot/QnnWrappersAPI.hpp | 18 ++- mllm/ffi/qualcomm/QnnAOT.cc | 17 ++- pymllm/ffi/__init__.py | 8 +- pymllm/tests/qualcomm/test_context_create.py | 5 +- 6 files changed, 165 insertions(+), 17 deletions(-) diff --git a/mllm/backends/qnn/aot/QnnTargetMachine.hpp b/mllm/backends/qnn/aot/QnnTargetMachine.hpp index 106a1bab9..6d3823130 100644 --- a/mllm/backends/qnn/aot/QnnTargetMachine.hpp +++ b/mllm/backends/qnn/aot/QnnTargetMachine.hpp @@ -17,7 +17,7 @@ enum class QcomHTPArch : uint32_t { V81 = 81, }; -enum QcomChipset : int { +enum QcomChipset : uint32_t { UNKNOWN_SM = 0, SA8295 = 39, SM8350 = 35, @@ -38,7 +38,7 @@ enum QcomChipset : int { SW6100 = 96, }; -enum QcomTryBestPerformance : int { +enum QcomTryBestPerformance : uint32_t { kHtpDefault = 0, kHtpSustainedHighPerformance, kHtpBurst, @@ -51,7 +51,7 @@ enum QcomTryBestPerformance : int { }; // Protection Domain Session -enum QcomSecurityPDSession : int { +enum QcomSecurityPDSession : uint32_t { kHtpUnsignedPd = 0, kHtpSignedPd, }; @@ -61,6 +61,7 @@ struct QcomTargetMachine { QcomHTPArch soc_htp_arch; QcomTryBestPerformance soc_htp_performance; QcomSecurityPDSession soc_htp_security_pd_session; + uint32_t soc_htp_vtcm_total_memory_size; }; } // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp index 81c221500..86cf75a7c 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp @@ -1,5 +1,11 @@ // Copyright (c) MLLM Team. // Licensed under the MIT License. +#include +#include +#include + +#include "QnnContext.h" +#include "mllm/utils/Common.hpp" #include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp" #include "mllm/backends/qnn/aot/QnnTargetMachine.hpp" @@ -124,9 +130,50 @@ void QnnAOTEnv::_setup(const std::string& path) { MLLM_RT_ASSERT(status != QNN_PROPERTY_ERROR_UNKNOWN_KEY); } + + // Try to config this target machine + { + auto device_custom_config = createDecideCustomConfigInfo(); + QnnHtpDevice_CustomConfig_t* p_custom_config = nullptr; + + switch (target_machine_.soc_htp_security_pd_session) { + case QcomSecurityPDSession::kHtpSignedPd: { + p_custom_config = (QnnHtpDevice_CustomConfig_t*)malloc(sizeof(QnnHtpDevice_CustomConfig_t)); + unreachable_handel_.push_back(p_custom_config); + p_custom_config->option = QNN_HTP_DEVICE_CONFIG_OPTION_SIGNEDPD; + p_custom_config->useSignedProcessDomain.useSignedProcessDomain = true; + p_custom_config->useSignedProcessDomain.deviceId = 0; + device_custom_config.push_back(static_cast(p_custom_config)); + break; + } + case QcomSecurityPDSession::kHtpUnsignedPd: + default: break; + } + + const std::vector device_platform_info = createDevicePlatformInfo(); + uint32_t num_custom_configs = device_platform_info.size() + device_custom_config.size(); + target_machine_qnn_config_.resize(num_custom_configs); + + for (std::size_t i = 0; i < device_custom_config.size(); ++i) { + target_machine_qnn_config_[i].option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; + target_machine_qnn_config_[i].customConfig = device_custom_config[i]; + target_machine_qnn_config_ptrs_.push_back(&target_machine_qnn_config_[i]); + } + + if (!device_platform_info.empty()) { + // The length of platform info can only be 1. + MLLM_RT_ASSERT_EQ(device_platform_info.size(), 1u); + target_machine_qnn_config_[device_custom_config.size()].option = QNN_DEVICE_CONFIG_OPTION_PLATFORM_INFO; + target_machine_qnn_config_[device_custom_config.size()].hardwareInfo = device_platform_info.back(); + target_machine_qnn_config_ptrs_.push_back(&target_machine_qnn_config_[device_custom_config.size()]); + } + + // null terminated + target_machine_qnn_config_ptrs_.push_back(nullptr); + } } -std::shared_ptr QnnAOTEnv::createContext(const std::string& name) { +std::shared_ptr QnnAOTEnv::createContext(const std::string& name, bool weights_sharing) { std::shared_ptr context = std::make_shared(); context->name_ = name; @@ -137,10 +184,9 @@ std::shared_ptr QnnAOTEnv::createContext(const std::string& // clang-format on // 2. Create HTP Device - // FIXME(wch): we need to model each Hexagon machine with its special device info. // clang-format off if (nullptr != qnn_htp_func_symbols_.qnn_interface_.deviceCreate) { - auto status = qnn_htp_func_symbols_.qnn_interface_.deviceCreate(context->log_, nullptr, &context->device_handle_); + auto status = qnn_htp_func_symbols_.qnn_interface_.deviceCreate(context->log_, target_machine_qnn_config_ptrs_.data(), &context->device_handle_); MLLM_RT_ASSERT_EQ(status, QNN_SUCCESS); } // clang-format on @@ -154,6 +200,9 @@ std::shared_ptr QnnAOTEnv::createContext(const std::string& // 4. Create Context { + auto cfgs = createContextCustomConfig(weights_sharing); + // Current not support + MLLM_RT_ASSERT_EQ(cfgs.size(), 0); auto status = qnn_htp_func_symbols_.qnn_interface_.contextCreate(context->bk_handle_, context->device_handle_, (const QnnContext_Config_t**)&context->qnn_context_config_, &context->qnn_ctx_handle_); @@ -199,4 +248,76 @@ void QnnAOTEnv::destroyContext(const std::string& name) { // TODO } +std::vector QnnAOTEnv::createDevicePlatformInfo() { + std::vector ret; + QnnDevice_PlatformInfo_t* p_platform_info = nullptr; + QnnDevice_HardwareDeviceInfo_t* p_hw_device_info = nullptr; + QnnHtpDevice_DeviceInfoExtension_t* p_device_info_extension = nullptr; + QnnDevice_CoreInfo_t* p_core_info = nullptr; + + p_platform_info = (QnnDevice_PlatformInfo_t*)malloc(sizeof(QnnDevice_PlatformInfo_t)); + unreachable_handel_.push_back(p_platform_info); + p_platform_info->version = QNN_DEVICE_PLATFORM_INFO_VERSION_1; + p_platform_info->v1.numHwDevices = 1; + + p_hw_device_info = (QnnDevice_HardwareDeviceInfo_t*)malloc(sizeof(QnnDevice_HardwareDeviceInfo_t)); + unreachable_handel_.push_back(p_hw_device_info); + p_hw_device_info->version = QNN_DEVICE_HARDWARE_DEVICE_INFO_VERSION_1; + p_hw_device_info->v1.deviceId = 0; + p_hw_device_info->v1.deviceType = 0; + p_hw_device_info->v1.numCores = 1; + + p_device_info_extension = (QnnHtpDevice_DeviceInfoExtension_t*)malloc(sizeof(QnnHtpDevice_DeviceInfoExtension_t)); + unreachable_handel_.push_back(p_device_info_extension); + // clang-format off + p_device_info_extension->devType = QNN_HTP_DEVICE_TYPE_ON_CHIP; + p_device_info_extension->onChipDevice.vtcmSize = target_machine_.soc_htp_vtcm_total_memory_size; // in MB + p_device_info_extension->onChipDevice.signedPdSupport = target_machine_.soc_htp_security_pd_session == QcomSecurityPDSession::kHtpSignedPd; + p_device_info_extension->onChipDevice.socModel = static_cast(target_machine_.soc_htp_chipset); + p_device_info_extension->onChipDevice.arch = static_cast(target_machine_.soc_htp_arch); + p_device_info_extension->onChipDevice.dlbcSupport = true; + p_hw_device_info->v1.deviceInfoExtension = p_device_info_extension; + // clang-format on + + p_core_info = (QnnDevice_CoreInfo_t*)malloc(sizeof(QnnDevice_CoreInfo_t)); + unreachable_handel_.push_back(p_core_info); + p_core_info->version = QNN_DEVICE_CORE_INFO_VERSION_1; + p_core_info->v1.coreId = 0; + p_core_info->v1.coreType = 0; + p_core_info->v1.coreInfoExtension = nullptr; + p_hw_device_info->v1.cores = p_core_info; + + p_platform_info->v1.hwDevices = p_hw_device_info; + ret.push_back(p_platform_info); + + return ret; +} + +std::vector QnnAOTEnv::createDecideCustomConfigInfo() { + std::vector ret; + + QnnHtpDevice_CustomConfig_t* p_custom_config = (QnnHtpDevice_CustomConfig_t*)malloc(sizeof(QnnHtpDevice_CustomConfig_t)); + unreachable_handel_.push_back(p_custom_config); + p_custom_config->option = QNN_HTP_DEVICE_CONFIG_OPTION_SOC; + p_custom_config->socModel = static_cast(target_machine_.soc_htp_chipset); + ret.push_back(static_cast(p_custom_config)); + + return ret; +} + +std::vector QnnAOTEnv::createContextCustomConfig(bool weights_sharing) { + std::vector ret; + QnnHtpContext_CustomConfig_t* p_custom_config = nullptr; + + if (weights_sharing) { + p_custom_config = (QnnHtpContext_CustomConfig_t*)malloc(sizeof(QnnHtpContext_CustomConfig_t)); + unreachable_handel_.push_back(p_custom_config); + p_custom_config->option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; + p_custom_config->weightSharingEnabled = true; + ret.push_back(static_cast(p_custom_config)); + } + + return ret; +} + } // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp index d54f92fad..5c2f9a0be 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -102,18 +103,33 @@ class QnnAOTEnv { QnnAOTEnv(const std::string& lib_path, QcomTargetMachine& target_machine); - std::shared_ptr createContext(const std::string& name); + std::shared_ptr createContext(const std::string& name, bool weights_sharing = false); void saveContext(const std::string& name, const std::string& path); void destroyContext(const std::string& name); + // This is for All PUs, such as CPU, GPU, NPU + std::vector createDevicePlatformInfo(); + + // This function is for NPU only. + std::vector createDecideCustomConfigInfo(); + + std::vector createContextCustomConfig(bool weights_sharing); + private: void _setup(const std::string& path = ""); QcomTargetMachine target_machine_; QnnFuncSymbols qnn_htp_func_symbols_; std::unordered_map> contexts_; + + // device config for all to use + std::vector target_machine_qnn_config_; + std::vector target_machine_qnn_config_ptrs_; + + // void* handle that should be freed when QnnAOTEnv end + std::vector unreachable_handel_; }; } // namespace mllm::qnn::aot diff --git a/mllm/ffi/qualcomm/QnnAOT.cc b/mllm/ffi/qualcomm/QnnAOT.cc index c4ddc6f97..e36cad641 100644 --- a/mllm/ffi/qualcomm/QnnAOT.cc +++ b/mllm/ffi/qualcomm/QnnAOT.cc @@ -35,7 +35,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { return mllm::ffi::QcomHTPArch(ret); }); refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V75", []() { - auto ret = mllm::qnn::aot::QcomHTPArch::V79; + auto ret = mllm::qnn::aot::QcomHTPArch::V75; return mllm::ffi::QcomHTPArch(ret); }); refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V79", []() { @@ -173,13 +173,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); refl::ObjectDef().def_static( - "__create__", [](const mllm::ffi::QcomChipset& chipset, const mllm::ffi::QcomHTPArch& arch, - const mllm::ffi::QcomTryBestPerformance& perf, const mllm::ffi::QcomSecurityPDSession& pd_session) { + "__create__", + [](const mllm::ffi::QcomChipset& chipset, const mllm::ffi::QcomHTPArch& arch, + const mllm::ffi::QcomTryBestPerformance& perf, const mllm::ffi::QcomSecurityPDSession& pd_session, uint32_t htp_vtcm) { auto tm = mllm::qnn::aot::QcomTargetMachine{ .soc_htp_chipset = chipset.get()->chipset_, .soc_htp_arch = arch.get()->htp_arch_, .soc_htp_performance = perf.get()->perf_, .soc_htp_security_pd_session = pd_session.get()->pd_, + .soc_htp_vtcm_total_memory_size = htp_vtcm, }; return ::mllm::ffi::QcomTargetMachine(tm); }); @@ -199,10 +201,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::ObjectDef<::mllm::ffi::QnnDeviceAndContextObj>(); - refl::GlobalDef().def("mllm.qualcomm.QnnAOTEnv.createContext", [](const mllm::ffi::QnnAOTEnv& self, const std::string& name) { - auto s = self.get()->qnn_aot_env_ptr_->createContext(name); - return mllm::ffi::QnnDeviceAndContext(s); - }); + refl::GlobalDef().def("mllm.qualcomm.QnnAOTEnv.createContext", + [](const mllm::ffi::QnnAOTEnv& self, const std::string& name, bool weights_sharing) { + auto s = self.get()->qnn_aot_env_ptr_->createContext(name, weights_sharing); + return mllm::ffi::QnnDeviceAndContext(s); + }); } #endif diff --git a/pymllm/ffi/__init__.py b/pymllm/ffi/__init__.py index 08667b8f2..c22971f7c 100644 --- a/pymllm/ffi/__init__.py +++ b/pymllm/ffi/__init__.py @@ -524,6 +524,7 @@ def __init__( soc_htp_arch: QcomHTPArch, soc_htp_performance: QcomTryBestPerformance, soc_htp_security_pd_session: QcomSecurityPDSession, + soc_htp_vtcm: int, ): self.__init_handle_by_constructor__( QcomTargetMachine.__create__, @@ -531,6 +532,7 @@ def __init__( soc_htp_arch, soc_htp_performance, soc_htp_security_pd_session, + soc_htp_vtcm, ) @@ -548,9 +550,11 @@ def __init__( else: self.__init_handle_by_constructor__(QnnAOTEnv.__create__, machine, path) - def create_context(self, name: str) -> QnnDeviceAndContext: + def create_context( + self, name: str, weights_sharing: bool = False + ) -> QnnDeviceAndContext: return tvm_ffi.get_global_func("mllm.qualcomm.QnnAOTEnv.createContext")( - self, name + self, name, weights_sharing ) diff --git a/pymllm/tests/qualcomm/test_context_create.py b/pymllm/tests/qualcomm/test_context_create.py index b80d6df26..18983daa7 100644 --- a/pymllm/tests/qualcomm/test_context_create.py +++ b/pymllm/tests/qualcomm/test_context_create.py @@ -16,10 +16,13 @@ soc_htp_arch=QcomHTPArch.V81(), soc_htp_performance=QcomTryBestPerformance.HtpBurst(), soc_htp_security_pd_session=QcomSecurityPDSession.HtpUnsignedPd(), + soc_htp_vtcm=8, # in MB ), path="/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", ) if __name__ == "__main__": mllm.echo("Testing tvm-ffi compatibility") - qnn_context: QnnDeviceAndContext = qnn_aot_env.create_context("context.0") + qnn_context: QnnDeviceAndContext = qnn_aot_env.create_context( + "context.0", weights_sharing=False + ) From eaab294cefc6b0536032427e6337fe1711b6c380 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 12 Dec 2025 13:22:44 +0000 Subject: [PATCH 5/7] feat(ffi): add SoftmaxOp FFI bindings and implementation - Introduce SoftmaxOpOptions and SoftmaxOp classes in Object.hh for FFI integration - Implement SoftmaxOp creation and dispatch logic in Nn.cc - Add Python bindings for SoftmaxOp and SoftmaxOpOptions in pymllm/ffi/__init__.py - Expose Softmax layer in pymllm.nn module and update Layer base class to support dispatching operations - Enhance Module class with better attribute handling, device management, and string representation - Add test case for Softmax layer usage within a module --- mllm/ffi/Nn.cc | 38 ++++++++++++++++++++++ mllm/ffi/Object.hh | 36 +++++++++++++++++++++ pymllm/ffi/__init__.py | 23 +++++++++++++ pymllm/nn/__init__.py | 2 +- pymllm/nn/_layers.py | 40 ++++++++++++++++++++--- pymllm/nn/_module.py | 72 +++++++++++++++++++++++++++++++++++++---- pymllm/tests/test_nn.py | 19 +++++++++++ 7 files changed, 217 insertions(+), 13 deletions(-) create mode 100644 pymllm/tests/test_nn.py diff --git a/mllm/ffi/Nn.cc b/mllm/ffi/Nn.cc index e69de29bb..71c85aea3 100644 --- a/mllm/ffi/Nn.cc +++ b/mllm/ffi/Nn.cc @@ -0,0 +1,38 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#include +#include + +#include "mllm/core/aops/SoftmaxOp.hpp" +#include "mllm/engine/Context.hpp" +#include "mllm/ffi/Object.hh" + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + + refl::ObjectDef<::mllm::ffi::SoftmaxOpOptionsObj>().def_static("__create__", [](int dim) -> mllm::ffi::SoftmaxOpOptions { + auto v = ::mllm::aops::SoftmaxOpOptions{.axis = dim}; + return mllm::ffi::SoftmaxOpOptions(v); + }); + + refl::ObjectDef<::mllm::ffi::SoftmaxOpObj>(); + refl::GlobalDef().def( + "mllm.aops.__ctx_create_softmax_op", [](const mllm::ffi::Device& d, const mllm::ffi::SoftmaxOpOptions& o) { + auto v = mllm::Context::instance().getBackend(d.get()->device)->createOp(mllm::OpTypes::kSoftmax, o.get()->options_); + return mllm::ffi::BaseOp(v); + }); + // =============================================================== + // Dispatcher things + // =============================================================== + refl::GlobalDef().def("mllm.engine.dispatch", [](const mllm::ffi::Device& d, const mllm::ffi::BaseOp& op, + const tvm::ffi::Array& input_ffi) { + mllm::DispatcherManager::dispatcher_id_t id = (int32_t)(d.get()->device); + std::vector inputs; + for (auto& t : input_ffi) { inputs.push_back(t.get()->mllm_tensor_); } + auto task = mllm::Task::createExecuteOpTask(op.get()->op_ptr_, inputs, {}); + mllm::Context::instance().dispatcherManager()->submit(id, task); + tvm::ffi::Array ret; + for (auto& o : task->outputs) { ret.push_back(mllm::ffi::Tensor(o)); } + return ret; + }); +} diff --git a/mllm/ffi/Object.hh b/mllm/ffi/Object.hh index 13021c871..b23164bcc 100644 --- a/mllm/ffi/Object.hh +++ b/mllm/ffi/Object.hh @@ -129,4 +129,40 @@ class ParameterFile : public tvm::ffi::ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ParameterFile, tvm::ffi::ObjectRef, ParameterFileObj); // NOLINT }; +//===----------------------------------------------------------------------===// +// MLLM Ops +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// MLLM Softmax Op +//===----------------------------------------------------------------------===// +class SoftmaxOpOptionsObj : public tvm::ffi::Object { + public: + ::mllm::aops::SoftmaxOpOptions options_; + + explicit SoftmaxOpOptionsObj(const ::mllm::aops::SoftmaxOpOptions& opt) : options_(opt) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.aops.SoftmaxOpOptions", SoftmaxOpOptionsObj, tvm::ffi::Object); +}; + +class SoftmaxOpOptions : public tvm::ffi::ObjectRef { + public: + explicit SoftmaxOpOptions(::mllm::aops::SoftmaxOpOptions& opt) { data_ = tvm::ffi::make_object(opt); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SoftmaxOpOptions, tvm::ffi::ObjectRef, SoftmaxOpOptionsObj); // NOLINT +}; + +class SoftmaxOpObj : public BaseOpObj { + public: + explicit SoftmaxOpObj(const ::mllm::BaseOp::ptr_t& opt) : BaseOpObj(opt) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.aops.SoftmaxOp", SoftmaxOpObj, tvm::ffi::Object); +}; + +class SoftmaxOp : public BaseOp { + public: + explicit SoftmaxOp(::mllm::aops::SoftmaxOp::ptr_t& opt) { data_ = tvm::ffi::make_object(opt); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SoftmaxOp, BaseOp, SoftmaxOpObj); // NOLINT +}; + } // namespace mllm::ffi diff --git a/pymllm/ffi/__init__.py b/pymllm/ffi/__init__.py index c22971f7c..b1f9799d6 100644 --- a/pymllm/ffi/__init__.py +++ b/pymllm/ffi/__init__.py @@ -558,6 +558,29 @@ def create_context( ) +# ============================================================================= +# Mllm Ops Binding +# +# ============================================================================= +@tvm_ffi.register_object("mllm.aops.SoftmaxOpOptions") +class SoftmaxOpOptions(tvm_ffi.Object): + def __init__(self, dim=-1): + super().__init__() + self.__init_handle_by_constructor__(SoftmaxOpOptions.__create__, dim) + + +@tvm_ffi.register_object("mllm.aops.SoftmaxOp") +class SoftmaxOp(BaseOp): + def __init__(self): + super().__init__() + + @staticmethod + def create(device: Device, options: SoftmaxOpOptions): + return tvm_ffi.get_global_func("mllm.aops.__ctx_create_softmax_op")( + device, options + ) + + # Initialize context initialize_context() diff --git a/pymllm/nn/__init__.py b/pymllm/nn/__init__.py index baf5de492..f6be0719b 100644 --- a/pymllm/nn/__init__.py +++ b/pymllm/nn/__init__.py @@ -3,4 +3,4 @@ from . import functional from ._module import Module -from ._layers import Linear +from ._layers import Linear, Softmax diff --git a/pymllm/nn/_layers.py b/pymllm/nn/_layers.py index 4dc6c0d66..5adc79cbf 100644 --- a/pymllm/nn/_layers.py +++ b/pymllm/nn/_layers.py @@ -1,11 +1,13 @@ # Copyright (c) MLLM Team. # Licensed under the MIT License. +import tvm_ffi from .. import ffi class _Layer: def __init__(self): + self.device: ffi.Device = ffi.cpu_() self.this_layer_name: str = None self.absolute_name: str = None self._mllm_c_op_ptr: ffi.BaseOp = None @@ -18,22 +20,50 @@ def load(self, pf: ffi.ParameterFile): def trace(self): pass - def forward(self): - # TODO dispatch op - pass + def forward(self, *args): + inputs = [] + for arg in args: + if isinstance(arg, (ffi.Tensor)): + inputs.append(arg) + else: + print( + f"The layer's forward function received a none Tensor type of {type(arg)}. Which is not supported." + ) + ret = tvm_ffi.get_global_func("mllm.engine.dispatch")( + self.device, self._mllm_c_op_ptr, inputs + ) + if len(ret) == 1: + return ret[0] + return ret def __call__(self, *args, **kwds): - pass + return self.forward(*args) + + def __repr__(self): + return "_Layer" class Linear(_Layer): def __init__(self): super().__init__() + def __repr__(self): + return "nn.Linear" + class Softmax(_Layer): - def __init__(self): + def __init__( + self, + dim=-1, + ): super().__init__() + self.dim = dim + self._mllm_c_op_ptr = ffi.SoftmaxOp.create( + self.device, ffi.SoftmaxOpOptions(dim) + ) + + def __repr__(self): + return f"mllm.aops.Softmax(dim={self.dim})" class RoPE(_Layer): diff --git a/pymllm/nn/_module.py b/pymllm/nn/_module.py index 309721ec2..a4f28e17c 100644 --- a/pymllm/nn/_module.py +++ b/pymllm/nn/_module.py @@ -6,13 +6,44 @@ class Module: - def __init__(self): - self.this_module_name: str = None - self.absolute_name: str = None - self.module_layer_list: list = [] + def __init__(self, name: str = "model"): + super().__setattr__("this_module_name", name) + super().__setattr__("absolute_name", name) + super().__setattr__("module_layer_list", {}) + super().__setattr__("_is_initializing", True) + self.device: ffi.Device = ffi.cpu_() + + def __setattr__(self, name: str, value): + super().__setattr__(name, value) + if ( + getattr(self, "_is_initializing", False) + and not name.startswith("_") + and isinstance(value, (Module, _Layer)) + ): + value.this_module_name = name + value.absolute_name = f"{self.absolute_name}.{name}" + self.module_layer_list[name] = value + if isinstance(value, Module): + value._is_initializing = True + + def to(self, x): + if isinstance(x, str): + if x == "qnn": + pass + self.device = ffi.device(x) + elif isinstance(x, ffi.Device): + self.device = x + elif isinstance(x, ffi.DType): + raise NotImplementedError("Module.to(DType) is not supported") + else: + raise TypeError("device must be str or Device, but got {}".format(type(x))) + return self + + def re_naming_finish_initialization(self): + self._is_initializing = False def load(self, pf: ffi.ParameterFile): - for module_layer in self.module_layer_list: + for module_layer in self.module_layer_list.values(): if isinstance(module_layer, Module, _Layer): module_layer.load(pf) else: @@ -22,11 +53,10 @@ def load(self, pf: ffi.ParameterFile): ) ) - def trace(self): + def trace(self, *args): pass def forward(self, *args): - # TODO send to engine's dispatcher pass def __call__(self, *args, **kwds): @@ -35,3 +65,31 @@ def __call__(self, *args, **kwds): return self.trace(*args, **kwds) return self.forward(*args, **kwds) # __send_graph_end() + + def __str__(self): + return self._repr_helper() + + def __repr__(self): + return self._repr_helper() + + def _repr_helper(self, indent_level: int = 0) -> str: + indent = " " * indent_level + next_indent = " " * (indent_level + 1) + module_str = f"{self.__class__.__name__}(" + if self.module_layer_list: + child_lines = [] + for name, child in self.module_layer_list.items(): + if isinstance(child, Module): + child_repr = child._repr_helper(indent_level + 1) + child_lines.append(f"{next_indent}({name}): {child_repr}") + elif isinstance(child, _Layer): + child_lines.append(f"{next_indent}({name}): {repr(child)}") + else: + child_lines.append( + f"{next_indent}({name}): {type(child).__name__}()" + ) + module_str += "\n" + "\n".join(child_lines) + "\n" + indent + ")" + else: + module_str += ")" + + return module_str diff --git a/pymllm/tests/test_nn.py b/pymllm/tests/test_nn.py new file mode 100644 index 000000000..d9a3db2d8 --- /dev/null +++ b/pymllm/tests/test_nn.py @@ -0,0 +1,19 @@ +import pymllm as mllm +from pymllm import nn + + +class FooModule(nn.Module): + def __init__(self): + super().__init__() + self.sf = nn.Softmax(dim=-1) + + def forward(self, x): + x = self.sf(x) + return x + + +if __name__ == "__main__": + x = mllm.ones([6, 10]) + foo = FooModule() + print(foo) + print(foo(x)) From cce49c54084d7972d47c6b77f39359dfe21f5900 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Tue, 16 Dec 2025 08:08:10 +0000 Subject: [PATCH 6/7] fix(qnn): correct typo in unreachable handle variable name The variable `unreachable_handel_` was misspelled and has been corrected to `unreachable_handle_` across multiple files in the QNN AOT wrapper API. Also updated BaseOpObj and SoftmaxOpObj to use proper TVM FFI object info macros, and improved type hint and error message in Python FFI code. Added new build configuration file for x86 QNN AOT SDK setup with Highway support and OpenMP threading enabled. --- mllm/backends/cpu/CMakeLists.txt | 8 +++++-- mllm/backends/qnn/aot/QnnWrappersAPI.cpp | 14 ++++++------- mllm/backends/qnn/aot/QnnWrappersAPI.hpp | 2 +- mllm/backends/qnn/aot/passes/AOTPass.cpp | 0 mllm/backends/qnn/aot/passes/AOTPass.hpp | 0 mllm/backends/qnn/aot/passes/AOTPipeline.cpp | 0 mllm/backends/qnn/aot/passes/AOTPipeline.hpp | 0 .../qnn/aot/passes/MarkQnnGraphPass.cpp | 0 .../qnn/aot/passes/MarkQnnGraphPass.hpp | 0 .../qnn/aot/passes/QuantPrecisionsCfgPass.cpp | 0 .../qnn/aot/passes/QuantPrecisionsCfgPass.hpp | 0 .../qnn/aot/passes/SplitGraphPass.cpp | 0 .../qnn/aot/passes/SplitGraphPass.hpp | 0 mllm/backends/qnn/aot/visitor/Elewise.cpp | 0 mllm/backends/qnn/aot/visitor/Elewise.hpp | 0 mllm/backends/qnn/aot/visitor/Linear.cpp | 0 mllm/backends/qnn/aot/visitor/Linear.hpp | 0 mllm/backends/qnn/aot/visitor/Matmul.cpp | 0 mllm/backends/qnn/aot/visitor/Matmul.hpp | 0 mllm/backends/qnn/aot/visitor/RMSNorm.cpp | 0 mllm/backends/qnn/aot/visitor/RMSNorm.hpp | 0 mllm/backends/qnn/aot/visitor/RoPE.cpp | 0 mllm/backends/qnn/aot/visitor/RoPE.hpp | 0 mllm/backends/qnn/aot/visitor/SiLU.cpp | 0 mllm/backends/qnn/aot/visitor/SiLU.hpp | 0 mllm/backends/qnn/aot/visitor/Softmax.cpp | 0 mllm/backends/qnn/aot/visitor/Softmax.hpp | 0 mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp | 0 mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp | 0 mllm/backends/qnn/aot_rt/utils/MaskGen.cpp | 0 mllm/backends/qnn/aot_rt/utils/MaskGen.hpp | 0 .../qnn/aot_rt/utils/PositionIdGen.cpp | 0 .../qnn/aot_rt/utils/PositionIdGen.hpp | 0 mllm/backends/qnn/aot_rt/utils/RoPEGen.cpp | 0 mllm/backends/qnn/aot_rt/utils/RoPEGen.hpp | 0 mllm/ffi/Object.hh | 4 ++-- pymllm/ffi/__init__.py | 4 ++-- tasks/build_sdk_x86_qnn_aot.yaml | 21 +++++++++++++++++++ 38 files changed, 39 insertions(+), 14 deletions(-) create mode 100644 mllm/backends/qnn/aot/passes/AOTPass.cpp create mode 100644 mllm/backends/qnn/aot/passes/AOTPass.hpp create mode 100644 mllm/backends/qnn/aot/passes/AOTPipeline.cpp create mode 100644 mllm/backends/qnn/aot/passes/AOTPipeline.hpp create mode 100644 mllm/backends/qnn/aot/passes/MarkQnnGraphPass.cpp create mode 100644 mllm/backends/qnn/aot/passes/MarkQnnGraphPass.hpp create mode 100644 mllm/backends/qnn/aot/passes/QuantPrecisionsCfgPass.cpp create mode 100644 mllm/backends/qnn/aot/passes/QuantPrecisionsCfgPass.hpp create mode 100644 mllm/backends/qnn/aot/passes/SplitGraphPass.cpp create mode 100644 mllm/backends/qnn/aot/passes/SplitGraphPass.hpp create mode 100644 mllm/backends/qnn/aot/visitor/Elewise.cpp create mode 100644 mllm/backends/qnn/aot/visitor/Elewise.hpp create mode 100644 mllm/backends/qnn/aot/visitor/Linear.cpp create mode 100644 mllm/backends/qnn/aot/visitor/Linear.hpp create mode 100644 mllm/backends/qnn/aot/visitor/Matmul.cpp create mode 100644 mllm/backends/qnn/aot/visitor/Matmul.hpp create mode 100644 mllm/backends/qnn/aot/visitor/RMSNorm.cpp create mode 100644 mllm/backends/qnn/aot/visitor/RMSNorm.hpp create mode 100644 mllm/backends/qnn/aot/visitor/RoPE.cpp create mode 100644 mllm/backends/qnn/aot/visitor/RoPE.hpp create mode 100644 mllm/backends/qnn/aot/visitor/SiLU.cpp create mode 100644 mllm/backends/qnn/aot/visitor/SiLU.hpp create mode 100644 mllm/backends/qnn/aot/visitor/Softmax.cpp create mode 100644 mllm/backends/qnn/aot/visitor/Softmax.hpp create mode 100644 mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp create mode 100644 mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp create mode 100644 mllm/backends/qnn/aot_rt/utils/MaskGen.cpp create mode 100644 mllm/backends/qnn/aot_rt/utils/MaskGen.hpp create mode 100644 mllm/backends/qnn/aot_rt/utils/PositionIdGen.cpp create mode 100644 mllm/backends/qnn/aot_rt/utils/PositionIdGen.hpp create mode 100644 mllm/backends/qnn/aot_rt/utils/RoPEGen.cpp create mode 100644 mllm/backends/qnn/aot_rt/utils/RoPEGen.hpp create mode 100644 tasks/build_sdk_x86_qnn_aot.yaml diff --git a/mllm/backends/cpu/CMakeLists.txt b/mllm/backends/cpu/CMakeLists.txt index 989878ca2..779d636b5 100644 --- a/mllm/backends/cpu/CMakeLists.txt +++ b/mllm/backends/cpu/CMakeLists.txt @@ -168,6 +168,10 @@ if(MLLM_BUILD_ARM_BACKEND) PATTERN "*.h" PATTERN "*.hpp") else() - # X86 highway - # TODO + install( + TARGETS hwy + EXPORT MllmTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin) endif() diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp index 86cf75a7c..b2f72e732 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp @@ -139,7 +139,7 @@ void QnnAOTEnv::_setup(const std::string& path) { switch (target_machine_.soc_htp_security_pd_session) { case QcomSecurityPDSession::kHtpSignedPd: { p_custom_config = (QnnHtpDevice_CustomConfig_t*)malloc(sizeof(QnnHtpDevice_CustomConfig_t)); - unreachable_handel_.push_back(p_custom_config); + unreachable_handle_.push_back(p_custom_config); p_custom_config->option = QNN_HTP_DEVICE_CONFIG_OPTION_SIGNEDPD; p_custom_config->useSignedProcessDomain.useSignedProcessDomain = true; p_custom_config->useSignedProcessDomain.deviceId = 0; @@ -256,19 +256,19 @@ std::vector QnnAOTEnv::createDevicePlatformInfo() { QnnDevice_CoreInfo_t* p_core_info = nullptr; p_platform_info = (QnnDevice_PlatformInfo_t*)malloc(sizeof(QnnDevice_PlatformInfo_t)); - unreachable_handel_.push_back(p_platform_info); + unreachable_handle_.push_back(p_platform_info); p_platform_info->version = QNN_DEVICE_PLATFORM_INFO_VERSION_1; p_platform_info->v1.numHwDevices = 1; p_hw_device_info = (QnnDevice_HardwareDeviceInfo_t*)malloc(sizeof(QnnDevice_HardwareDeviceInfo_t)); - unreachable_handel_.push_back(p_hw_device_info); + unreachable_handle_.push_back(p_hw_device_info); p_hw_device_info->version = QNN_DEVICE_HARDWARE_DEVICE_INFO_VERSION_1; p_hw_device_info->v1.deviceId = 0; p_hw_device_info->v1.deviceType = 0; p_hw_device_info->v1.numCores = 1; p_device_info_extension = (QnnHtpDevice_DeviceInfoExtension_t*)malloc(sizeof(QnnHtpDevice_DeviceInfoExtension_t)); - unreachable_handel_.push_back(p_device_info_extension); + unreachable_handle_.push_back(p_device_info_extension); // clang-format off p_device_info_extension->devType = QNN_HTP_DEVICE_TYPE_ON_CHIP; p_device_info_extension->onChipDevice.vtcmSize = target_machine_.soc_htp_vtcm_total_memory_size; // in MB @@ -280,7 +280,7 @@ std::vector QnnAOTEnv::createDevicePlatformInfo() { // clang-format on p_core_info = (QnnDevice_CoreInfo_t*)malloc(sizeof(QnnDevice_CoreInfo_t)); - unreachable_handel_.push_back(p_core_info); + unreachable_handle_.push_back(p_core_info); p_core_info->version = QNN_DEVICE_CORE_INFO_VERSION_1; p_core_info->v1.coreId = 0; p_core_info->v1.coreType = 0; @@ -297,7 +297,7 @@ std::vector QnnAOTEnv::createDecideCustomConfigInfo() std::vector ret; QnnHtpDevice_CustomConfig_t* p_custom_config = (QnnHtpDevice_CustomConfig_t*)malloc(sizeof(QnnHtpDevice_CustomConfig_t)); - unreachable_handel_.push_back(p_custom_config); + unreachable_handle_.push_back(p_custom_config); p_custom_config->option = QNN_HTP_DEVICE_CONFIG_OPTION_SOC; p_custom_config->socModel = static_cast(target_machine_.soc_htp_chipset); ret.push_back(static_cast(p_custom_config)); @@ -311,7 +311,7 @@ std::vector QnnAOTEnv::createContextCustomConfig(bool if (weights_sharing) { p_custom_config = (QnnHtpContext_CustomConfig_t*)malloc(sizeof(QnnHtpContext_CustomConfig_t)); - unreachable_handel_.push_back(p_custom_config); + unreachable_handle_.push_back(p_custom_config); p_custom_config->option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; p_custom_config->weightSharingEnabled = true; ret.push_back(static_cast(p_custom_config)); diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp index 5c2f9a0be..ad351d34a 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp @@ -129,7 +129,7 @@ class QnnAOTEnv { std::vector target_machine_qnn_config_ptrs_; // void* handle that should be freed when QnnAOTEnv end - std::vector unreachable_handel_; + std::vector unreachable_handle_; }; } // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/passes/AOTPass.cpp b/mllm/backends/qnn/aot/passes/AOTPass.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/AOTPass.hpp b/mllm/backends/qnn/aot/passes/AOTPass.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/AOTPipeline.cpp b/mllm/backends/qnn/aot/passes/AOTPipeline.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/AOTPipeline.hpp b/mllm/backends/qnn/aot/passes/AOTPipeline.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/MarkQnnGraphPass.cpp b/mllm/backends/qnn/aot/passes/MarkQnnGraphPass.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/MarkQnnGraphPass.hpp b/mllm/backends/qnn/aot/passes/MarkQnnGraphPass.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/QuantPrecisionsCfgPass.cpp b/mllm/backends/qnn/aot/passes/QuantPrecisionsCfgPass.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/QuantPrecisionsCfgPass.hpp b/mllm/backends/qnn/aot/passes/QuantPrecisionsCfgPass.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/SplitGraphPass.cpp b/mllm/backends/qnn/aot/passes/SplitGraphPass.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/SplitGraphPass.hpp b/mllm/backends/qnn/aot/passes/SplitGraphPass.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Elewise.cpp b/mllm/backends/qnn/aot/visitor/Elewise.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Elewise.hpp b/mllm/backends/qnn/aot/visitor/Elewise.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Linear.cpp b/mllm/backends/qnn/aot/visitor/Linear.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Linear.hpp b/mllm/backends/qnn/aot/visitor/Linear.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Matmul.cpp b/mllm/backends/qnn/aot/visitor/Matmul.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Matmul.hpp b/mllm/backends/qnn/aot/visitor/Matmul.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/RMSNorm.cpp b/mllm/backends/qnn/aot/visitor/RMSNorm.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/RMSNorm.hpp b/mllm/backends/qnn/aot/visitor/RMSNorm.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/RoPE.cpp b/mllm/backends/qnn/aot/visitor/RoPE.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/RoPE.hpp b/mllm/backends/qnn/aot/visitor/RoPE.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/SiLU.cpp b/mllm/backends/qnn/aot/visitor/SiLU.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/SiLU.hpp b/mllm/backends/qnn/aot/visitor/SiLU.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Softmax.cpp b/mllm/backends/qnn/aot/visitor/Softmax.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Softmax.hpp b/mllm/backends/qnn/aot/visitor/Softmax.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/utils/MaskGen.cpp b/mllm/backends/qnn/aot_rt/utils/MaskGen.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/utils/MaskGen.hpp b/mllm/backends/qnn/aot_rt/utils/MaskGen.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/utils/PositionIdGen.cpp b/mllm/backends/qnn/aot_rt/utils/PositionIdGen.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/utils/PositionIdGen.hpp b/mllm/backends/qnn/aot_rt/utils/PositionIdGen.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/utils/RoPEGen.cpp b/mllm/backends/qnn/aot_rt/utils/RoPEGen.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/utils/RoPEGen.hpp b/mllm/backends/qnn/aot_rt/utils/RoPEGen.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/ffi/Object.hh b/mllm/ffi/Object.hh index b23164bcc..7cfd1ae21 100644 --- a/mllm/ffi/Object.hh +++ b/mllm/ffi/Object.hh @@ -100,7 +100,7 @@ class BaseOpObj : public tvm::ffi::Object { explicit BaseOpObj(const ::mllm::BaseOp::ptr_t& op_ptr) : op_ptr_(op_ptr) { MLLM_EMPTY_SCOPE; } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.BaseOp", BaseOpObj, tvm::ffi::Object); + TVM_FFI_DECLARE_OBJECT_INFO("mllm.BaseOp", BaseOpObj, tvm::ffi::Object); }; class BaseOp : public tvm::ffi::ObjectRef { @@ -155,7 +155,7 @@ class SoftmaxOpObj : public BaseOpObj { public: explicit SoftmaxOpObj(const ::mllm::BaseOp::ptr_t& opt) : BaseOpObj(opt) { MLLM_EMPTY_SCOPE; } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.aops.SoftmaxOp", SoftmaxOpObj, tvm::ffi::Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.aops.SoftmaxOp", SoftmaxOpObj, BaseOpObj); }; class SoftmaxOp : public BaseOp { diff --git a/pymllm/ffi/__init__.py b/pymllm/ffi/__init__.py index b1f9799d6..f363f0d77 100644 --- a/pymllm/ffi/__init__.py +++ b/pymllm/ffi/__init__.py @@ -540,11 +540,11 @@ def __init__( class QnnAOTEnv(tvm_ffi.Object): def __init__( self, - machine: QcomTargetMachine = None, + machine: QcomTargetMachine | None = None, path: str = None, ): if machine is None: - raise RuntimeError("machine target is none!") + raise ValueError("QnnAOTEnv requires a non-None QcomTargetMachine") if path is None or path == "": self.__init_handle_by_constructor__(QnnAOTEnv.__create__, machine, "") else: diff --git a/tasks/build_sdk_x86_qnn_aot.yaml b/tasks/build_sdk_x86_qnn_aot.yaml new file mode 100644 index 000000000..f33281616 --- /dev/null +++ b/tasks/build_sdk_x86_qnn_aot.yaml @@ -0,0 +1,21 @@ +Tasks: + - CMakeConfigTask: + cmake_cfg_path: "build-qnn-aot" + cmake_build_type: "Release" + cmake_extra_args: + # Optional, If use Highway + - "-DHWY_ENABLE_TESTS=OFF" + - "-DHWY_ENABLE_EXAMPLES=OFF" + - "-DHWY_ENABLE_CONTRIB=OFF" + # Optional + - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=native"' + - "-DMLLM_KERNEL_USE_THREADS=ON" + - "-DMLLM_KERNEL_THREADS_VENDOR_OPENMP=ON" + - "-DMLLM_KERNEL_USE_THREADS_VENDOR_MLLM=OFF" + - "-DMLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE=ON" + - "-DCMAKE_INSTALL_PREFIX=./mllm-sdk-x86-qnn-aot" + + - CMakeBuildTask: + cmake_cfg_path: "build-qnn-aot" + - CMakeInstallTask: + cmake_cfg_path: "build-qnn-aot" From 5b9d9dc4dde68371f4592161b17196c85829c92d Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Fri, 19 Dec 2025 09:32:18 +0000 Subject: [PATCH 7/7] Add QNN AOT support for x86 and enhance Qwen3 model - Introduced `QuantizationAnnotation` in RTTI kind generation. - Added new data types for Int4 and Int2 in DataTypes.hpp. - Implemented `isQnnAOTOnX86Enabled` function in mllm.cpp and declared it in mllm.hpp. - Created a new header file for Qwen3 model with rotary position embedding and attention mechanisms. - Updated FFI to expose `is_qnn_aot_on_x86_enabled` function to Python. - Refactored QNN-related classes in pymllm/ffi/__init__.py to conditionally register based on QNN AOT support. --- examples/CMakeLists.txt | 6 + examples/qwen3_qnn_aot/CMakeLists.txt | 0 mllm/backends/qnn/aot/passes/PTQPass.cpp | 0 mllm/backends/qnn/aot/passes/PTQPass.hpp | 0 mllm/compile/ir/GeneratedRTTIKind.hpp | 4 +- mllm/compile/ir/NodeRTTIClassOfImpl.hpp | 8 +- mllm/compile/ir/linalg/Attribute.cpp | 21 + mllm/compile/ir/linalg/Attribute.hpp | 278 ++++++++++++ mllm/compile/ir/rtti_kind_gen.py | 2 + mllm/core/DataTypes.hpp | 6 + mllm/ffi/Extension.cc | 1 + mllm/mllm.cpp | 8 + mllm/mllm.hpp | 2 + mllm/models/qwen3/modeling_qwen3_qnn_aot.hpp | 355 +++++++++++++++ pymllm/ffi/__init__.py | 451 ++++++++++--------- 15 files changed, 914 insertions(+), 228 deletions(-) create mode 100644 examples/qwen3_qnn_aot/CMakeLists.txt create mode 100644 mllm/backends/qnn/aot/passes/PTQPass.cpp create mode 100644 mllm/backends/qnn/aot/passes/PTQPass.hpp create mode 100644 mllm/compile/ir/linalg/Attribute.cpp create mode 100644 mllm/compile/ir/linalg/Attribute.hpp create mode 100644 mllm/models/qwen3/modeling_qwen3_qnn_aot.hpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 5ce568e96..f963dc9e2 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -8,9 +8,15 @@ add_subdirectory(minicpm4) add_subdirectory(qwen3) add_subdirectory(qwen3_service) add_subdirectory(deepseek_ocr) + if(MLLM_BUILD_QNN_BACKEND) add_subdirectory(qwen_npu) endif() + if(MLLM_TRACY_ENABLE) add_subdirectory(tracy_example) endif() + +if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) + add_subdirectory(qwen3_qnn_aot) +endif() diff --git a/examples/qwen3_qnn_aot/CMakeLists.txt b/examples/qwen3_qnn_aot/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/PTQPass.hpp b/mllm/backends/qnn/aot/passes/PTQPass.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/compile/ir/GeneratedRTTIKind.hpp b/mllm/compile/ir/GeneratedRTTIKind.hpp index 1f84298ca..5e8583ff2 100644 --- a/mllm/compile/ir/GeneratedRTTIKind.hpp +++ b/mllm/compile/ir/GeneratedRTTIKind.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-11-26 11:54:51 +// Auto generated: 2025-12-19 07:36:12 // do not modify this file #pragma once @@ -133,6 +133,8 @@ enum NodeKind : uint32_t { RK_Val_Last, RK_Attr, RK_Attr_LinalgIRAttr, + RK_Attr_LinalgIRAttr_QuantizationAnnotation, + RK_Attr_LinalgIRAttr_Last, RK_Attr_GraphIRAttr, RK_Attr_TensorIRAttr, RK_Attr_BuiltinIRAttr, diff --git a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp index a45878534..9a36f283d 100644 --- a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp +++ b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-11-26 11:54:51 +// Auto generated: 2025-12-19 07:36:12 // do not modify this file #pragma once namespace mllm::ir { @@ -325,7 +325,11 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_ATTR_IMPL(v) return (v)->getKind() >= RK_Attr && (v)->getKind() <= RK_Attr_Last #define RTTI_RK_ATTR_LINALGIRATTR_IMPL(v) \ - return (v)->getKind() >= RK_Attr_LinalgIRAttr && (v)->getKind() <= RK_Attr_LinalgIRAttr + return (v)->getKind() >= RK_Attr_LinalgIRAttr && (v)->getKind() <= RK_Attr_LinalgIRAttr_Last + +#define RTTI_RK_ATTR_LINALGIRATTR_QUANTIZATIONANNOTATION_IMPL(v) \ + return (v)->getKind() >= RK_Attr_LinalgIRAttr_QuantizationAnnotation \ + && (v)->getKind() <= RK_Attr_LinalgIRAttr_QuantizationAnnotation #define RTTI_RK_ATTR_GRAPHIRATTR_IMPL(v) return (v)->getKind() >= RK_Attr_GraphIRAttr && (v)->getKind() <= RK_Attr_GraphIRAttr diff --git a/mllm/compile/ir/linalg/Attribute.cpp b/mllm/compile/ir/linalg/Attribute.cpp new file mode 100644 index 000000000..5d381f170 --- /dev/null +++ b/mllm/compile/ir/linalg/Attribute.cpp @@ -0,0 +1,21 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/compile/ir/linalg/Attribute.hpp" + +namespace mllm::ir::linalg { + +LinalgIRAttr::~LinalgIRAttr() = default; + +LinalgIRAttr::LinalgIRAttr() : Attr(RK_Attr_LinalgIRAttr) {} + +LinalgIRAttr::LinalgIRAttr(const NodeKind& kind) : Attr(kind) {} + +LinalgIRQuantizatonAnnotationAttr::~LinalgIRQuantizatonAnnotationAttr() = default; + +LinalgIRQuantizatonAnnotationAttr::LinalgIRQuantizatonAnnotationAttr() + : LinalgIRAttr(RK_Attr_LinalgIRAttr_QuantizationAnnotation) {} + +LinalgIRQuantizatonAnnotationAttr::LinalgIRQuantizatonAnnotationAttr(const NodeKind& kind) : LinalgIRAttr(kind) {} + +} // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/linalg/Attribute.hpp b/mllm/compile/ir/linalg/Attribute.hpp new file mode 100644 index 000000000..ea632afdb --- /dev/null +++ b/mllm/compile/ir/linalg/Attribute.hpp @@ -0,0 +1,278 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "mllm/compile/ir/GeneratedRTTIKind.hpp" +#include "mllm/compile/ir/Node.hpp" +#include "mllm/compile/ir/NodeRTTIClassOfImpl.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/core/Tensor.hpp" + +namespace mllm::ir::linalg { + +class LinalgIRAttr : public Attr { + public: + DEFINE_SPECIFIC_IR_CLASS(LinalgIRAttr); + + ~LinalgIRAttr() override; + + LinalgIRAttr(); + + explicit LinalgIRAttr(const NodeKind& kind); + + static inline bool classof(const Node* node) { RTTI_RK_ATTR_LINALGIRATTR_IMPL(node); } +}; + +enum class QuantizationSpecType : uint32_t { + kNone = 0, + kSymPerTensor, + kSymPerChannel, + kSymPerBlock, + kAsymPerTensor, + kAsymPerChannel, + kAsymPerBlock, + kLPBQ, +}; + +struct QuantizationSpec { + using ptr_t = std::shared_ptr; + QuantizationSpecType type; +}; + +struct QuantizationSpecSymPerTensor : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + DataTypes quant_to_type = kUInt8; + DataTypes scale_type = kFloat32; + Tensor scale = Tensor::nil(); + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, DataTypes quant_to_type, DataTypes scale_type, + Tensor scale) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kSymPerTensor; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->quant_to_type = quant_to_type; + spec->scale_type = scale_type; + spec->scale = std::move(scale); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kSymPerTensor; + return spec; + } +}; + +struct QuantizationSpecSymPerChannel : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + int32_t ch_axis = -1; + DataTypes quant_to_type = kUInt8; + DataTypes scale_type = kFloat32; + Tensor scale = Tensor::nil(); + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, int32_t ch_axis, DataTypes quant_to_type, + DataTypes scale_type, Tensor scale) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kSymPerChannel; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->ch_axis = ch_axis; + spec->quant_to_type = quant_to_type; + spec->scale_type = scale_type; + spec->scale = std::move(scale); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kSymPerChannel; + return spec; + } +}; + +struct QuantizationSpecSymPerBlock : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + int32_t block_size = -1; + DataTypes quant_to_type = kUInt8; + DataTypes scale_type = kFloat32; + Tensor scale = Tensor::nil(); ///< Flattened scale, blocks num + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, int32_t block_size, DataTypes quant_to_type, + DataTypes scale_type, Tensor scale) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kSymPerBlock; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->block_size = block_size; + spec->quant_to_type = quant_to_type; + spec->scale_type = scale_type; + spec->scale = std::move(scale); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kSymPerBlock; + return spec; + } +}; + +struct QuantizationSpecAsymPerTensor : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + DataTypes quant_to_type = kUInt8; + DataTypes scale_type = kFloat32; + DataTypes zero_point_type = kInt32; + Tensor scale = Tensor::nil(); + Tensor zero_point = Tensor::nil(); + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, DataTypes quant_to_type, DataTypes scale_type, + DataTypes zero_point_type, Tensor scale, Tensor zero_point) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kAsymPerTensor; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->quant_to_type = quant_to_type; + spec->scale_type = scale_type; + spec->zero_point_type = zero_point_type; + spec->scale = std::move(scale); + spec->zero_point = std::move(zero_point); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kAsymPerTensor; + return spec; + } +}; + +struct QuantizationSpecAsymPerChannel : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + int32_t ch_axis = -1; + DataTypes quant_to_type = kUInt8; + DataTypes scale_type = kFloat32; + DataTypes zero_point_type = kInt32; + Tensor scale = Tensor::nil(); + Tensor zero_point = Tensor::nil(); + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, int32_t ch_axis, DataTypes quant_to_type, + DataTypes scale_type, DataTypes zero_point_type, Tensor scale, Tensor zero_point) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kAsymPerChannel; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->ch_axis = ch_axis; + spec->quant_to_type = quant_to_type; + spec->scale_type = scale_type; + spec->zero_point_type = zero_point_type; + spec->scale = std::move(scale); + spec->zero_point = std::move(zero_point); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kAsymPerChannel; + return spec; + } +}; + +struct QuantizationSpecAsymPerBlock : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + int32_t block_size = -1; + DataTypes quant_to_type = kUInt8; + DataTypes scale_type = kFloat32; + DataTypes zero_point_type = kInt32; + Tensor scale = Tensor::nil(); ///< Flattened scale, blocks num + Tensor zero_point = Tensor::nil(); ///< Flattened zero_point, blocks num + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, int32_t block_size, DataTypes quant_to_type, + DataTypes scale_type, DataTypes zero_point_type, Tensor scale, Tensor zero_point) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kAsymPerBlock; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->block_size = block_size; + spec->quant_to_type = quant_to_type; + spec->scale_type = scale_type; + spec->zero_point_type = zero_point_type; + spec->scale = std::move(scale); + spec->zero_point = std::move(zero_point); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kAsymPerBlock; + return spec; + } +}; + +struct QuantizationSpecLPBQ : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + int32_t block_size = -1; + int32_t ch_axis = -1; + int32_t scale_level_0_bitwidth = 4; + DataTypes quant_to_type = kUInt8; + DataTypes scale_1_type = kFloat32; + Tensor scale_level_0_int = Tensor::nil(); ///< Flattened scale, blocks num + Tensor scale_level_1_fp = Tensor::nil(); ///< Flattened scale, channel num + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, int32_t block_size, int32_t ch_axis, + int32_t scale_level_0_bitwidth, DataTypes quant_to_type, DataTypes scale_1_type, + Tensor scale_level_0_int, Tensor scale_level_1_fp) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kLPBQ; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->block_size = block_size; + spec->ch_axis = ch_axis; + spec->scale_level_0_bitwidth = scale_level_0_bitwidth; + spec->quant_to_type = quant_to_type; + spec->scale_1_type = scale_1_type; + spec->scale_level_0_int = std::move(scale_level_0_int); + spec->scale_level_1_fp = std::move(scale_level_1_fp); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kLPBQ; + return spec; + } +}; + +struct QuantizationAnnotation { + std::vector inputs; + std::vector outputs; + std::unordered_map weights; +}; + +class LinalgIRQuantizatonAnnotationAttr final : public LinalgIRAttr { + public: + QuantizationAnnotation annotation_; + + DEFINE_SPECIFIC_IR_CLASS(LinalgIRQuantizatonAnnotationAttr); + + ~LinalgIRQuantizatonAnnotationAttr() override; + + LinalgIRQuantizatonAnnotationAttr(); + + explicit LinalgIRQuantizatonAnnotationAttr(const NodeKind& kind); + + static inline bool classof(const Node* node) { RTTI_RK_ATTR_LINALGIRATTR_QUANTIZATIONANNOTATION_IMPL(node); } +}; + +} // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/rtti_kind_gen.py b/mllm/compile/ir/rtti_kind_gen.py index 4d31ad861..03759db01 100644 --- a/mllm/compile/ir/rtti_kind_gen.py +++ b/mllm/compile/ir/rtti_kind_gen.py @@ -215,6 +215,8 @@ def define_lianlg_ir(ir: dict): val: Cls = ir["Value"] attr: Cls = ir["Attribute"] + attr.derive(Cls("QuantizationAnnotation")) + # op op.derive(Cls("RegisterOp")) op.derive(Cls("CustomKernelOp")) diff --git a/mllm/core/DataTypes.hpp b/mllm/core/DataTypes.hpp index f49b38fa2..a30dafc95 100644 --- a/mllm/core/DataTypes.hpp +++ b/mllm/core/DataTypes.hpp @@ -339,6 +339,12 @@ enum DataTypes : int32_t { kByte = 134, kMXFP4 = 135, + // Int4 and low bits + kInt4 = 136, + KUInt4 = 137, + kInt2 = 138, + kUInt2 = 139, + // complex dtypes for STFT and other ops kComplexFloat32 = 201, kComplexFloat64 = 202, diff --git a/mllm/ffi/Extension.cc b/mllm/ffi/Extension.cc index 1bcb0a1e6..22449f883 100644 --- a/mllm/ffi/Extension.cc +++ b/mllm/ffi/Extension.cc @@ -45,6 +45,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("mllm.echo", mllm::ffi::echo); refl::GlobalDef().def("mllm.initialize_context", mllm::initializeContext); refl::GlobalDef().def("mllm.shutdown_context", mllm::shutdownContext); + refl::GlobalDef().def("mllm.is_qnn_aot_on_x86_enabled", mllm::isQnnAOTOnX86Enabled); // Primitives refl::ObjectDef<::mllm::ffi::DeviceObj>(); diff --git a/mllm/mllm.cpp b/mllm/mllm.cpp index 81aa44309..08a45aefd 100644 --- a/mllm/mllm.cpp +++ b/mllm/mllm.cpp @@ -82,6 +82,14 @@ bool isOpenCLAvailable() { return false; } +bool isQnnAOTOnX86Enabled() { +#ifdef MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE + return true; +#else + return false; +#endif +} + bool isQnnAvailable() { #ifdef MLLM_QNN_BACKEND return true; diff --git a/mllm/mllm.hpp b/mllm/mllm.hpp index 9b5ec3581..4a07f0ee7 100644 --- a/mllm/mllm.hpp +++ b/mllm/mllm.hpp @@ -187,6 +187,8 @@ void memoryReport(); bool isOpenCLAvailable(); +bool isQnnAOTOnX86Enabled(); + extern void initOpenCLBackend(); extern void initCudaBackend(); diff --git a/mllm/models/qwen3/modeling_qwen3_qnn_aot.hpp b/mllm/models/qwen3/modeling_qwen3_qnn_aot.hpp new file mode 100644 index 000000000..bea991a55 --- /dev/null +++ b/mllm/models/qwen3/modeling_qwen3_qnn_aot.hpp @@ -0,0 +1,355 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/mllm.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/models/ARGeneration.hpp" + +namespace mllm::models::qwen3 { + +inline auto makeRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0 / std::pow(rope_theta, 2.0 * i / output_dim); } + return inv_freq; +} + +inline auto makeRotaryPosEmbedding(Tensor& position_ids, const Tensor& inv_freq, float attention_scaling = 1.0f) + -> std::pair { + auto batch_size = position_ids.shape()[0]; + auto seq_len = position_ids.shape()[1]; + auto inv_freq_len = inv_freq.shape()[0]; + auto dim = inv_freq_len * 2; + + // Create freqs tensor: position_ids @ inv_freq + auto freqs = Tensor::empty({batch_size, seq_len, inv_freq_len}, kFloat32, kCPU).alloc(); + auto freqs_ptr = freqs.ptr(); + auto position_ids_ptr = position_ids.ptr(); + auto inv_freq_ptr = inv_freq.ptr(); + + // Compute freqs = position_ids[:, :, None] @ inv_freq[None, :] + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + auto pos = position_ids_ptr[b * seq_len + s]; + for (int d = 0; d < inv_freq_len; ++d) { + freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d] = static_cast(pos) * inv_freq_ptr[d]; + } + } + } + + // Create sin and cos tensors with shape [batch_size, seq_len, dim] + auto sin_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto cos_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto sin_ptr = sin_emb.ptr(); + auto cos_ptr = cos_emb.ptr(); + + // Compute sin and cos embeddings: emb = [freqs, freqs] + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + for (int d = 0; d < inv_freq_len; ++d) { + auto freq = freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d]; + auto sin_val = std::sin(freq) * attention_scaling; + auto cos_val = std::cos(freq) * attention_scaling; + + // Store the same values in both halves: [freqs, freqs] + sin_ptr[b * seq_len * dim + s * dim + d] = sin_val; + sin_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = sin_val; + cos_ptr[b * seq_len * dim + s * dim + d] = cos_val; + cos_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = cos_val; + } + } + } + + return {sin_emb, cos_emb}; +} + +class Qwen3MLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen3MLP() = default; + Qwen3MLP(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, false, cfg.linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen3Attention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear o_proj_; + nn::RMSNorm rms_norm_q_; + nn::RMSNorm rms_norm_k_; + nn::RoPE q_rope_; + nn::RoPE k_rope_; + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + + public: + Qwen3Attention() = default; + + Qwen3Attention(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + q_proj_ = + reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, cfg.attention_bias, cfg.linear_impl_type); + k_proj_ = + reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); + v_proj_ = + reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); + o_proj_ = + reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, cfg.attention_bias, cfg.linear_impl_type); + + rms_norm_q_ = reg("q_norm", cfg.rms_norm_eps); + rms_norm_k_ = reg("k_norm", cfg.rms_norm_eps); + + q_rope_ = reg("q_rope", cfg.rope_theta, cfg.max_position_embeddings); + k_rope_ = reg("k_rope", cfg.rope_theta, cfg.max_position_embeddings); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + // [B, S, H * D] + auto query_states = q_proj_(x); + auto key_states = k_proj_(x); + auto value_states = v_proj_(x); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + // [B, S, H, D] + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + value_states = value_states.view({B, S, num_key_value_heads_, head_dim_}); + + // [B, S, H, D] + query_states = rms_norm_q_(query_states); + key_states = rms_norm_k_(key_states); + + // [B, H, S, D] + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + // [B, H, S, D] + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + // [B, H, S, D] + auto [key_states_new, value_states_new] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states); + key_states = key_states_new; + value_states = value_states_new; + + Tensor attn; + if (key_states.dtype() == kFloat32) { + // attention weight + // [B, H, S, S] + attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + } else if (key_states.dtype() == kFloat16) { + attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + attn = attn.to(kFloat16); + } + + // attn output + // [B, H, S, S] @ [B, H, S, D] -> [B, H, S, D] + auto output = nn::functional::matmul(attn, value_states); + // [B, H, S, D] -> [B, S, H, D] -> [B, S, H * D] + output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); + + return {output}; + } + + int layer_idx_; +}; + +class Qwen3Decoder final : public nn::Module { + public: + Qwen3Attention self_attn_; + Qwen3MLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen3Decoder() = default; + + Qwen3Decoder(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + x = mlp_(x)[0]; + x = x + tmp; + return {x}; + } +}; + +class Qwen3Text final : public nn::Module { + nn::ModuleList decode_blocks_; + nn::RMSNorm norm_; + nn::Embedding embedding_; + + public: + Qwen3Text() = default; + + Qwen3Text(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + + x = norm_(x); + + return {x}; + } +}; + +class Qwen3ForCausalLM : public ARGeneration, public nn::Module { + public: + explicit Qwen3ForCausalLM(const Qwen3Config& cfg) : cfg(cfg) { + kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers, + cfg.num_attention_heads, // q_heads + cfg.num_key_value_heads, // kv_heads + cfg.head_dim, // kv_dim + kFloat32, // k_dtype + kFloat32, // v_dtype + kCPU, // device_type + false // use_fa2 + ); + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head_out", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); + } + + // Init inv freq + auto inv = makeRoPEInvFreq(cfg.head_dim, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + auto sequence = input.at("sequence"); + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({0, position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({batch_size, 1}, kInt64, kCPU).alloc(); + *position_ids.offsettedPtr({0, 0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({batch_size, seq_len}, kInt64, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[b * seq_len + s] = s; } + } + } + + // Generate RoPE embeddings using the inv_freq buffer + auto [llm_embedding_sin, llm_embedding_cos] = makeRotaryPosEmbedding(position_ids, getBuffer("inv_freq"), 1.0f); + + sequence = llm(sequence, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + + // clip x to one seq length + { + auto S = sequence.shape()[1]; + sequence = sequence[{kAll, {S - 1}, kAll}]; + } + if (tie_word_embeddings_) { sequence = lm_head_(sequence); } + + return { + {"sequence", sequence}, + {"position_ids", position_ids}, + }; + } + + inline nn::StaticCache& kvCache() { return kv_cache_; } + + private: + const Qwen3Config& cfg; + Qwen3Text llm; + nn::Linear lm_head_; + bool tie_word_embeddings_; + nn::StaticCache kv_cache_; +}; + +} // namespace mllm::models::qwen3 diff --git a/pymllm/ffi/__init__.py b/pymllm/ffi/__init__.py index f363f0d77..17bd04c19 100644 --- a/pymllm/ffi/__init__.py +++ b/pymllm/ffi/__init__.py @@ -18,6 +18,10 @@ def echo(rec: str) -> None: return _ffi_api.echo(rec) +def is_qnn_aot_on_x86_enabled() -> bool: + return _ffi_api.is_qnn_aot_on_x86_enabled() + + def initialize_context() -> None: return _ffi_api.initialize_context() @@ -320,242 +324,239 @@ def load(self, pf: ParameterFile): return tvm_ffi.get_global_func("mllm.BaseOp.load")(self, pf) -@tvm_ffi.register_object("mllm.qualcomm.QnnDeviceAndContext") -class QnnDeviceAndContext(tvm_ffi.Object): - def __init__(self): - pass +if is_qnn_aot_on_x86_enabled(): + print("pymllm: is_qnn_aot_on_x86_enabled is true") + @tvm_ffi.register_object("mllm.qualcomm.QnnDeviceAndContext") + class QnnDeviceAndContext(tvm_ffi.Object): + def __init__(self): + pass -@tvm_ffi.register_object("mllm.qualcomm.QcomHTPArch") -class QcomHTPArch(tvm_ffi.Object): - def __init__(self): - pass + @tvm_ffi.register_object("mllm.qualcomm.QcomHTPArch") + class QcomHTPArch(tvm_ffi.Object): + def __init__(self): + pass - @staticmethod - def NONE() -> QcomHTPArch: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.NONE")() + @staticmethod + def NONE() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.NONE")() - @staticmethod - def V68() -> QcomHTPArch: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V68")() + @staticmethod + def V68() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V68")() - @staticmethod - def V69() -> QcomHTPArch: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V69")() + @staticmethod + def V69() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V69")() - @staticmethod - def V73() -> QcomHTPArch: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V73")() + @staticmethod + def V73() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V73")() - @staticmethod - def V75() -> QcomHTPArch: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V75")() + @staticmethod + def V75() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V75")() - @staticmethod - def V79() -> QcomHTPArch: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V79")() + @staticmethod + def V79() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V79")() - @staticmethod - def V81() -> QcomHTPArch: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V81")() - - -@tvm_ffi.register_object("mllm.qualcomm.QcomChipset") -class QcomChipset(tvm_ffi.Object): - def __init__(self): - pass - - @staticmethod - def UNKNOWN_SM() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.UNKNOWN_SM")() - - @staticmethod - def SA8295() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SA8295")() - - @staticmethod - def SM8350() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8350")() - - @staticmethod - def SM8450() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8450")() - - @staticmethod - def SM8475() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8475")() - - @staticmethod - def SM8550() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8550")() - - @staticmethod - def SM8650() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8650")() - - @staticmethod - def SM8750() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8750")() - - @staticmethod - def SM8850() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8850")() - - @staticmethod - def SSG2115P() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SSG2115P")() - - @staticmethod - def SSG2125P() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SSG2125P")() - - @staticmethod - def SXR1230P() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SXR1230P")() - - @staticmethod - def SXR2230P() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SXR2230P")() - - @staticmethod - def SXR2330P() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SXR2330P")() + @staticmethod + def V81() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V81")() - @staticmethod - def QCS9100() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.QCS9100")() - - @staticmethod - def SAR2230P() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SAR2230P")() - - @staticmethod - def SA8255() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SA8255")() - - @staticmethod - def SW6100() -> QcomChipset: - return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SW6100")() - - -@tvm_ffi.register_object("mllm.qualcomm.QcomTryBestPerformance") -class QcomTryBestPerformance(tvm_ffi.Object): - def __init__(self): - pass - - @staticmethod - def HtpDefault() -> QcomTryBestPerformance: - return tvm_ffi.get_global_func( - "mllm.qualcomm.QcomTryBestPerformance.HtpDefault" - )() - - @staticmethod - def HtpSustainedHighPerformance() -> QcomTryBestPerformance: - return tvm_ffi.get_global_func( - "mllm.qualcomm.QcomTryBestPerformance.HtpSustainedHighPerformance" - )() - - @staticmethod - def HtpBurst() -> QcomTryBestPerformance: - return tvm_ffi.get_global_func( - "mllm.qualcomm.QcomTryBestPerformance.HtpBurst" - )() - - @staticmethod - def HtpHighPerformance() -> QcomTryBestPerformance: - return tvm_ffi.get_global_func( - "mllm.qualcomm.QcomTryBestPerformance.HtpHighPerformance" - )() - - @staticmethod - def HtpPowerSaver() -> QcomTryBestPerformance: - return tvm_ffi.get_global_func( - "mllm.qualcomm.QcomTryBestPerformance.HtpPowerSaver" - )() - - @staticmethod - def HtpLowPowerSaver() -> QcomTryBestPerformance: - return tvm_ffi.get_global_func( - "mllm.qualcomm.QcomTryBestPerformance.HtpLowPowerSaver" - )() - - @staticmethod - def HtpHighPowerSaver() -> QcomTryBestPerformance: - return tvm_ffi.get_global_func( - "mllm.qualcomm.QcomTryBestPerformance.HtpHighPowerSaver" - )() - - @staticmethod - def HtpLowBalanced() -> QcomTryBestPerformance: - return tvm_ffi.get_global_func( - "mllm.qualcomm.QcomTryBestPerformance.HtpLowBalanced" - )() - - @staticmethod - def HtpBalanced() -> QcomTryBestPerformance: - return tvm_ffi.get_global_func( - "mllm.qualcomm.QcomTryBestPerformance.HtpBalanced" - )() - - -@tvm_ffi.register_object("mllm.qualcomm.QcomSecurityPDSession") -class QcomSecurityPDSession(tvm_ffi.Object): - def __init__(self): - pass - - @staticmethod - def HtpUnsignedPd() -> QcomSecurityPDSession: - return tvm_ffi.get_global_func( - "mllm.qualcomm.QcomSecurityPDSession.HtpUnsignedPd" - )() - - @staticmethod - def HtpSignedPd() -> QcomSecurityPDSession: - return tvm_ffi.get_global_func( - "mllm.qualcomm.QcomSecurityPDSession.HtpSignedPd" - )() - - -@tvm_ffi.register_object("mllm.qualcomm.QcomTargetMachine") -class QcomTargetMachine(tvm_ffi.Object): - def __init__( - self, - soc_htp_chipset: QcomChipset, - soc_htp_arch: QcomHTPArch, - soc_htp_performance: QcomTryBestPerformance, - soc_htp_security_pd_session: QcomSecurityPDSession, - soc_htp_vtcm: int, - ): - self.__init_handle_by_constructor__( - QcomTargetMachine.__create__, - soc_htp_chipset, - soc_htp_arch, - soc_htp_performance, - soc_htp_security_pd_session, - soc_htp_vtcm, - ) - - -@tvm_ffi.register_object("mllm.qualcomm.QnnAOTEnv") -class QnnAOTEnv(tvm_ffi.Object): - def __init__( - self, - machine: QcomTargetMachine | None = None, - path: str = None, - ): - if machine is None: - raise ValueError("QnnAOTEnv requires a non-None QcomTargetMachine") - if path is None or path == "": - self.__init_handle_by_constructor__(QnnAOTEnv.__create__, machine, "") - else: - self.__init_handle_by_constructor__(QnnAOTEnv.__create__, machine, path) + @tvm_ffi.register_object("mllm.qualcomm.QcomChipset") + class QcomChipset(tvm_ffi.Object): + def __init__(self): + pass + + @staticmethod + def UNKNOWN_SM() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.UNKNOWN_SM")() + + @staticmethod + def SA8295() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SA8295")() + + @staticmethod + def SM8350() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8350")() + + @staticmethod + def SM8450() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8450")() + + @staticmethod + def SM8475() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8475")() + + @staticmethod + def SM8550() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8550")() + + @staticmethod + def SM8650() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8650")() + + @staticmethod + def SM8750() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8750")() + + @staticmethod + def SM8850() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8850")() + + @staticmethod + def SSG2115P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SSG2115P")() + + @staticmethod + def SSG2125P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SSG2125P")() + + @staticmethod + def SXR1230P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SXR1230P")() + + @staticmethod + def SXR2230P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SXR2230P")() + + @staticmethod + def SXR2330P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SXR2330P")() + + @staticmethod + def QCS9100() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.QCS9100")() + + @staticmethod + def SAR2230P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SAR2230P")() + + @staticmethod + def SA8255() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SA8255")() + + @staticmethod + def SW6100() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SW6100")() + + @tvm_ffi.register_object("mllm.qualcomm.QcomTryBestPerformance") + class QcomTryBestPerformance(tvm_ffi.Object): + def __init__(self): + pass + + @staticmethod + def HtpDefault() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpDefault" + )() + + @staticmethod + def HtpSustainedHighPerformance() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpSustainedHighPerformance" + )() + + @staticmethod + def HtpBurst() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpBurst" + )() + + @staticmethod + def HtpHighPerformance() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpHighPerformance" + )() + + @staticmethod + def HtpPowerSaver() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpPowerSaver" + )() + + @staticmethod + def HtpLowPowerSaver() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpLowPowerSaver" + )() + + @staticmethod + def HtpHighPowerSaver() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpHighPowerSaver" + )() + + @staticmethod + def HtpLowBalanced() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpLowBalanced" + )() + + @staticmethod + def HtpBalanced() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpBalanced" + )() + + @tvm_ffi.register_object("mllm.qualcomm.QcomSecurityPDSession") + class QcomSecurityPDSession(tvm_ffi.Object): + def __init__(self): + pass + + @staticmethod + def HtpUnsignedPd() -> QcomSecurityPDSession: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomSecurityPDSession.HtpUnsignedPd" + )() + + @staticmethod + def HtpSignedPd() -> QcomSecurityPDSession: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomSecurityPDSession.HtpSignedPd" + )() + + @tvm_ffi.register_object("mllm.qualcomm.QcomTargetMachine") + class QcomTargetMachine(tvm_ffi.Object): + def __init__( + self, + soc_htp_chipset: QcomChipset, + soc_htp_arch: QcomHTPArch, + soc_htp_performance: QcomTryBestPerformance, + soc_htp_security_pd_session: QcomSecurityPDSession, + soc_htp_vtcm: int, + ): + self.__init_handle_by_constructor__( + QcomTargetMachine.__create__, + soc_htp_chipset, + soc_htp_arch, + soc_htp_performance, + soc_htp_security_pd_session, + soc_htp_vtcm, + ) - def create_context( - self, name: str, weights_sharing: bool = False - ) -> QnnDeviceAndContext: - return tvm_ffi.get_global_func("mllm.qualcomm.QnnAOTEnv.createContext")( - self, name, weights_sharing - ) + @tvm_ffi.register_object("mllm.qualcomm.QnnAOTEnv") + class QnnAOTEnv(tvm_ffi.Object): + def __init__( + self, + machine: QcomTargetMachine | None = None, + path: str = None, + ): + if machine is None: + raise ValueError("QnnAOTEnv requires a non-None QcomTargetMachine") + if path is None or path == "": + self.__init_handle_by_constructor__(QnnAOTEnv.__create__, machine, "") + else: + self.__init_handle_by_constructor__(QnnAOTEnv.__create__, machine, path) + + def create_context( + self, name: str, weights_sharing: bool = False + ) -> QnnDeviceAndContext: + return tvm_ffi.get_global_func("mllm.qualcomm.QnnAOTEnv.createContext")( + self, name, weights_sharing + ) # =============================================================================