diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 5ce568e96..f963dc9e2 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -8,9 +8,15 @@ add_subdirectory(minicpm4) add_subdirectory(qwen3) add_subdirectory(qwen3_service) add_subdirectory(deepseek_ocr) + if(MLLM_BUILD_QNN_BACKEND) add_subdirectory(qwen_npu) endif() + if(MLLM_TRACY_ENABLE) add_subdirectory(tracy_example) endif() + +if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) + add_subdirectory(qwen3_qnn_aot) +endif() diff --git a/examples/qwen3_qnn_aot/CMakeLists.txt b/examples/qwen3_qnn_aot/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/cpu/CMakeLists.txt b/mllm/backends/cpu/CMakeLists.txt index 989878ca2..779d636b5 100644 --- a/mllm/backends/cpu/CMakeLists.txt +++ b/mllm/backends/cpu/CMakeLists.txt @@ -168,6 +168,10 @@ if(MLLM_BUILD_ARM_BACKEND) PATTERN "*.h" PATTERN "*.hpp") else() - # X86 highway - # TODO + install( + TARGETS hwy + EXPORT MllmTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin) endif() diff --git a/mllm/backends/qnn/aot/QnnTargetMachine.hpp b/mllm/backends/qnn/aot/QnnTargetMachine.hpp new file mode 100644 index 000000000..6d3823130 --- /dev/null +++ b/mllm/backends/qnn/aot/QnnTargetMachine.hpp @@ -0,0 +1,67 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace mllm::qnn::aot { + +enum class QcomHTPArch : uint32_t { + NONE = 0, + V68 = 68, + V69 = 69, + V73 = 73, + V75 = 75, + V79 = 79, + V81 = 81, +}; + +enum QcomChipset : uint32_t { + UNKNOWN_SM = 0, + SA8295 = 39, + SM8350 = 35, + SM8450 = 36, + SM8475 = 42, + SM8550 = 43, + SM8650 = 57, + SM8750 = 69, + SM8850 = 87, + SSG2115P = 46, + SSG2125P = 58, + SXR1230P = 45, + SXR2230P = 53, + SXR2330P = 75, + QCS9100 = 77, + SAR2230P = 95, + SA8255 = 52, + SW6100 = 96, +}; + +enum QcomTryBestPerformance : uint32_t { + kHtpDefault = 0, + kHtpSustainedHighPerformance, + kHtpBurst, + kHtpHighPerformance, + kHtpPowerSaver, + kHtpLowPowerSaver, + kHtpHighPowerSaver, + kHtpLowBalanced, + kHtpBalanced, +}; + +// Protection Domain Session +enum QcomSecurityPDSession : uint32_t { + kHtpUnsignedPd = 0, + kHtpSignedPd, +}; + +struct QcomTargetMachine { + QcomChipset soc_htp_chipset; + QcomHTPArch soc_htp_arch; + QcomTryBestPerformance soc_htp_performance; + QcomSecurityPDSession soc_htp_security_pd_session; + uint32_t soc_htp_vtcm_total_memory_size; +}; + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp index 8144e32d2..b2f72e732 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp @@ -1,6 +1,13 @@ // Copyright (c) MLLM Team. // Licensed under the MIT License. +#include +#include +#include + +#include "QnnContext.h" +#include "mllm/utils/Common.hpp" #include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp" +#include "mllm/backends/qnn/aot/QnnTargetMachine.hpp" namespace mllm::qnn::aot { @@ -67,9 +74,11 @@ bool QnnDynSymbolLoader::loadQnnDynLibAtPath(const std::string& path, const std: return false; } -QnnAOTEnv::QnnAOTEnv() { _setup(); } +QnnAOTEnv::QnnAOTEnv(QcomTargetMachine& target_machine) : target_machine_(target_machine) { _setup(); } -QnnAOTEnv::QnnAOTEnv(const std::string& lib_path) { _setup(lib_path); } +QnnAOTEnv::QnnAOTEnv(const std::string& lib_path, QcomTargetMachine& target_machine) : target_machine_(target_machine) { + _setup(lib_path); +} void QnnAOTEnv::_setup(const std::string& path) { auto& loader = QnnDynSymbolLoader::instance(); @@ -121,9 +130,50 @@ void QnnAOTEnv::_setup(const std::string& path) { MLLM_RT_ASSERT(status != QNN_PROPERTY_ERROR_UNKNOWN_KEY); } + + // Try to config this target machine + { + auto device_custom_config = createDecideCustomConfigInfo(); + QnnHtpDevice_CustomConfig_t* p_custom_config = nullptr; + + switch (target_machine_.soc_htp_security_pd_session) { + case QcomSecurityPDSession::kHtpSignedPd: { + p_custom_config = (QnnHtpDevice_CustomConfig_t*)malloc(sizeof(QnnHtpDevice_CustomConfig_t)); + unreachable_handle_.push_back(p_custom_config); + p_custom_config->option = QNN_HTP_DEVICE_CONFIG_OPTION_SIGNEDPD; + p_custom_config->useSignedProcessDomain.useSignedProcessDomain = true; + p_custom_config->useSignedProcessDomain.deviceId = 0; + device_custom_config.push_back(static_cast(p_custom_config)); + break; + } + case QcomSecurityPDSession::kHtpUnsignedPd: + default: break; + } + + const std::vector device_platform_info = createDevicePlatformInfo(); + uint32_t num_custom_configs = device_platform_info.size() + device_custom_config.size(); + target_machine_qnn_config_.resize(num_custom_configs); + + for (std::size_t i = 0; i < device_custom_config.size(); ++i) { + target_machine_qnn_config_[i].option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; + target_machine_qnn_config_[i].customConfig = device_custom_config[i]; + target_machine_qnn_config_ptrs_.push_back(&target_machine_qnn_config_[i]); + } + + if (!device_platform_info.empty()) { + // The length of platform info can only be 1. + MLLM_RT_ASSERT_EQ(device_platform_info.size(), 1u); + target_machine_qnn_config_[device_custom_config.size()].option = QNN_DEVICE_CONFIG_OPTION_PLATFORM_INFO; + target_machine_qnn_config_[device_custom_config.size()].hardwareInfo = device_platform_info.back(); + target_machine_qnn_config_ptrs_.push_back(&target_machine_qnn_config_[device_custom_config.size()]); + } + + // null terminated + target_machine_qnn_config_ptrs_.push_back(nullptr); + } } -std::shared_ptr QnnAOTEnv::createContext(const std::string& name) { +std::shared_ptr QnnAOTEnv::createContext(const std::string& name, bool weights_sharing) { std::shared_ptr context = std::make_shared(); context->name_ = name; @@ -134,10 +184,9 @@ std::shared_ptr QnnAOTEnv::createContext(const std::string& // clang-format on // 2. Create HTP Device - // FIXME(wch): we need to model each Hexagon machine with its special device info. // clang-format off if (nullptr != qnn_htp_func_symbols_.qnn_interface_.deviceCreate) { - auto status = qnn_htp_func_symbols_.qnn_interface_.deviceCreate(context->log_, nullptr, &context->device_handle_); + auto status = qnn_htp_func_symbols_.qnn_interface_.deviceCreate(context->log_, target_machine_qnn_config_ptrs_.data(), &context->device_handle_); MLLM_RT_ASSERT_EQ(status, QNN_SUCCESS); } // clang-format on @@ -151,6 +200,9 @@ std::shared_ptr QnnAOTEnv::createContext(const std::string& // 4. Create Context { + auto cfgs = createContextCustomConfig(weights_sharing); + // Current not support + MLLM_RT_ASSERT_EQ(cfgs.size(), 0); auto status = qnn_htp_func_symbols_.qnn_interface_.contextCreate(context->bk_handle_, context->device_handle_, (const QnnContext_Config_t**)&context->qnn_context_config_, &context->qnn_ctx_handle_); @@ -196,4 +248,76 @@ void QnnAOTEnv::destroyContext(const std::string& name) { // TODO } +std::vector QnnAOTEnv::createDevicePlatformInfo() { + std::vector ret; + QnnDevice_PlatformInfo_t* p_platform_info = nullptr; + QnnDevice_HardwareDeviceInfo_t* p_hw_device_info = nullptr; + QnnHtpDevice_DeviceInfoExtension_t* p_device_info_extension = nullptr; + QnnDevice_CoreInfo_t* p_core_info = nullptr; + + p_platform_info = (QnnDevice_PlatformInfo_t*)malloc(sizeof(QnnDevice_PlatformInfo_t)); + unreachable_handle_.push_back(p_platform_info); + p_platform_info->version = QNN_DEVICE_PLATFORM_INFO_VERSION_1; + p_platform_info->v1.numHwDevices = 1; + + p_hw_device_info = (QnnDevice_HardwareDeviceInfo_t*)malloc(sizeof(QnnDevice_HardwareDeviceInfo_t)); + unreachable_handle_.push_back(p_hw_device_info); + p_hw_device_info->version = QNN_DEVICE_HARDWARE_DEVICE_INFO_VERSION_1; + p_hw_device_info->v1.deviceId = 0; + p_hw_device_info->v1.deviceType = 0; + p_hw_device_info->v1.numCores = 1; + + p_device_info_extension = (QnnHtpDevice_DeviceInfoExtension_t*)malloc(sizeof(QnnHtpDevice_DeviceInfoExtension_t)); + unreachable_handle_.push_back(p_device_info_extension); + // clang-format off + p_device_info_extension->devType = QNN_HTP_DEVICE_TYPE_ON_CHIP; + p_device_info_extension->onChipDevice.vtcmSize = target_machine_.soc_htp_vtcm_total_memory_size; // in MB + p_device_info_extension->onChipDevice.signedPdSupport = target_machine_.soc_htp_security_pd_session == QcomSecurityPDSession::kHtpSignedPd; + p_device_info_extension->onChipDevice.socModel = static_cast(target_machine_.soc_htp_chipset); + p_device_info_extension->onChipDevice.arch = static_cast(target_machine_.soc_htp_arch); + p_device_info_extension->onChipDevice.dlbcSupport = true; + p_hw_device_info->v1.deviceInfoExtension = p_device_info_extension; + // clang-format on + + p_core_info = (QnnDevice_CoreInfo_t*)malloc(sizeof(QnnDevice_CoreInfo_t)); + unreachable_handle_.push_back(p_core_info); + p_core_info->version = QNN_DEVICE_CORE_INFO_VERSION_1; + p_core_info->v1.coreId = 0; + p_core_info->v1.coreType = 0; + p_core_info->v1.coreInfoExtension = nullptr; + p_hw_device_info->v1.cores = p_core_info; + + p_platform_info->v1.hwDevices = p_hw_device_info; + ret.push_back(p_platform_info); + + return ret; +} + +std::vector QnnAOTEnv::createDecideCustomConfigInfo() { + std::vector ret; + + QnnHtpDevice_CustomConfig_t* p_custom_config = (QnnHtpDevice_CustomConfig_t*)malloc(sizeof(QnnHtpDevice_CustomConfig_t)); + unreachable_handle_.push_back(p_custom_config); + p_custom_config->option = QNN_HTP_DEVICE_CONFIG_OPTION_SOC; + p_custom_config->socModel = static_cast(target_machine_.soc_htp_chipset); + ret.push_back(static_cast(p_custom_config)); + + return ret; +} + +std::vector QnnAOTEnv::createContextCustomConfig(bool weights_sharing) { + std::vector ret; + QnnHtpContext_CustomConfig_t* p_custom_config = nullptr; + + if (weights_sharing) { + p_custom_config = (QnnHtpContext_CustomConfig_t*)malloc(sizeof(QnnHtpContext_CustomConfig_t)); + unreachable_handle_.push_back(p_custom_config); + p_custom_config->option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; + p_custom_config->weightSharingEnabled = true; + ret.push_back(static_cast(p_custom_config)); + } + + return ret; +} + } // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp index aeaa32785..ad351d34a 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp @@ -14,11 +14,13 @@ #include #include +#include #include #include #include #include +#include "mllm/backends/qnn/aot/QnnTargetMachine.hpp" #include "mllm/utils/Common.hpp" namespace mllm::qnn::aot { @@ -97,21 +99,37 @@ class QnnAOTEnv { public: using ptr_t = std::shared_ptr; - QnnAOTEnv(); + explicit QnnAOTEnv(QcomTargetMachine& target_machine); - explicit QnnAOTEnv(const std::string& lib_path); + QnnAOTEnv(const std::string& lib_path, QcomTargetMachine& target_machine); - std::shared_ptr createContext(const std::string& name); + std::shared_ptr createContext(const std::string& name, bool weights_sharing = false); void saveContext(const std::string& name, const std::string& path); void destroyContext(const std::string& name); + // This is for All PUs, such as CPU, GPU, NPU + std::vector createDevicePlatformInfo(); + + // This function is for NPU only. + std::vector createDecideCustomConfigInfo(); + + std::vector createContextCustomConfig(bool weights_sharing); + private: void _setup(const std::string& path = ""); + QcomTargetMachine target_machine_; QnnFuncSymbols qnn_htp_func_symbols_; std::unordered_map> contexts_; + + // device config for all to use + std::vector target_machine_qnn_config_; + std::vector target_machine_qnn_config_ptrs_; + + // void* handle that should be freed when QnnAOTEnv end + std::vector unreachable_handle_; }; } // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/passes/AOTPass.cpp b/mllm/backends/qnn/aot/passes/AOTPass.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/AOTPass.hpp b/mllm/backends/qnn/aot/passes/AOTPass.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/AOTPipeline.cpp b/mllm/backends/qnn/aot/passes/AOTPipeline.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/AOTPipeline.hpp b/mllm/backends/qnn/aot/passes/AOTPipeline.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/MarkQnnGraphPass.cpp b/mllm/backends/qnn/aot/passes/MarkQnnGraphPass.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/MarkQnnGraphPass.hpp b/mllm/backends/qnn/aot/passes/MarkQnnGraphPass.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/PTQPass.hpp b/mllm/backends/qnn/aot/passes/PTQPass.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/QuantPrecisionsCfgPass.cpp b/mllm/backends/qnn/aot/passes/QuantPrecisionsCfgPass.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/QuantPrecisionsCfgPass.hpp b/mllm/backends/qnn/aot/passes/QuantPrecisionsCfgPass.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/SplitGraphPass.cpp b/mllm/backends/qnn/aot/passes/SplitGraphPass.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/passes/SplitGraphPass.hpp b/mllm/backends/qnn/aot/passes/SplitGraphPass.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Elewise.cpp b/mllm/backends/qnn/aot/visitor/Elewise.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Elewise.hpp b/mllm/backends/qnn/aot/visitor/Elewise.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Linear.cpp b/mllm/backends/qnn/aot/visitor/Linear.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Linear.hpp b/mllm/backends/qnn/aot/visitor/Linear.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Matmul.cpp b/mllm/backends/qnn/aot/visitor/Matmul.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Matmul.hpp b/mllm/backends/qnn/aot/visitor/Matmul.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/RMSNorm.cpp b/mllm/backends/qnn/aot/visitor/RMSNorm.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/RMSNorm.hpp b/mllm/backends/qnn/aot/visitor/RMSNorm.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/RoPE.cpp b/mllm/backends/qnn/aot/visitor/RoPE.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/RoPE.hpp b/mllm/backends/qnn/aot/visitor/RoPE.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/SiLU.cpp b/mllm/backends/qnn/aot/visitor/SiLU.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/SiLU.hpp b/mllm/backends/qnn/aot/visitor/SiLU.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Softmax.cpp b/mllm/backends/qnn/aot/visitor/Softmax.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Softmax.hpp b/mllm/backends/qnn/aot/visitor/Softmax.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/utils/MaskGen.cpp b/mllm/backends/qnn/aot_rt/utils/MaskGen.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/utils/MaskGen.hpp b/mllm/backends/qnn/aot_rt/utils/MaskGen.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/utils/PositionIdGen.cpp b/mllm/backends/qnn/aot_rt/utils/PositionIdGen.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/utils/PositionIdGen.hpp b/mllm/backends/qnn/aot_rt/utils/PositionIdGen.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/utils/RoPEGen.cpp b/mllm/backends/qnn/aot_rt/utils/RoPEGen.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot_rt/utils/RoPEGen.hpp b/mllm/backends/qnn/aot_rt/utils/RoPEGen.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/compile/ir/GeneratedRTTIKind.hpp b/mllm/compile/ir/GeneratedRTTIKind.hpp index 1f84298ca..5e8583ff2 100644 --- a/mllm/compile/ir/GeneratedRTTIKind.hpp +++ b/mllm/compile/ir/GeneratedRTTIKind.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-11-26 11:54:51 +// Auto generated: 2025-12-19 07:36:12 // do not modify this file #pragma once @@ -133,6 +133,8 @@ enum NodeKind : uint32_t { RK_Val_Last, RK_Attr, RK_Attr_LinalgIRAttr, + RK_Attr_LinalgIRAttr_QuantizationAnnotation, + RK_Attr_LinalgIRAttr_Last, RK_Attr_GraphIRAttr, RK_Attr_TensorIRAttr, RK_Attr_BuiltinIRAttr, diff --git a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp index a45878534..9a36f283d 100644 --- a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp +++ b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp @@ -1,4 +1,4 @@ -// Auto generated: 2025-11-26 11:54:51 +// Auto generated: 2025-12-19 07:36:12 // do not modify this file #pragma once namespace mllm::ir { @@ -325,7 +325,11 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_ATTR_IMPL(v) return (v)->getKind() >= RK_Attr && (v)->getKind() <= RK_Attr_Last #define RTTI_RK_ATTR_LINALGIRATTR_IMPL(v) \ - return (v)->getKind() >= RK_Attr_LinalgIRAttr && (v)->getKind() <= RK_Attr_LinalgIRAttr + return (v)->getKind() >= RK_Attr_LinalgIRAttr && (v)->getKind() <= RK_Attr_LinalgIRAttr_Last + +#define RTTI_RK_ATTR_LINALGIRATTR_QUANTIZATIONANNOTATION_IMPL(v) \ + return (v)->getKind() >= RK_Attr_LinalgIRAttr_QuantizationAnnotation \ + && (v)->getKind() <= RK_Attr_LinalgIRAttr_QuantizationAnnotation #define RTTI_RK_ATTR_GRAPHIRATTR_IMPL(v) return (v)->getKind() >= RK_Attr_GraphIRAttr && (v)->getKind() <= RK_Attr_GraphIRAttr diff --git a/mllm/compile/ir/linalg/Attribute.cpp b/mllm/compile/ir/linalg/Attribute.cpp new file mode 100644 index 000000000..5d381f170 --- /dev/null +++ b/mllm/compile/ir/linalg/Attribute.cpp @@ -0,0 +1,21 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/compile/ir/linalg/Attribute.hpp" + +namespace mllm::ir::linalg { + +LinalgIRAttr::~LinalgIRAttr() = default; + +LinalgIRAttr::LinalgIRAttr() : Attr(RK_Attr_LinalgIRAttr) {} + +LinalgIRAttr::LinalgIRAttr(const NodeKind& kind) : Attr(kind) {} + +LinalgIRQuantizatonAnnotationAttr::~LinalgIRQuantizatonAnnotationAttr() = default; + +LinalgIRQuantizatonAnnotationAttr::LinalgIRQuantizatonAnnotationAttr() + : LinalgIRAttr(RK_Attr_LinalgIRAttr_QuantizationAnnotation) {} + +LinalgIRQuantizatonAnnotationAttr::LinalgIRQuantizatonAnnotationAttr(const NodeKind& kind) : LinalgIRAttr(kind) {} + +} // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/linalg/Attribute.hpp b/mllm/compile/ir/linalg/Attribute.hpp new file mode 100644 index 000000000..ea632afdb --- /dev/null +++ b/mllm/compile/ir/linalg/Attribute.hpp @@ -0,0 +1,278 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "mllm/compile/ir/GeneratedRTTIKind.hpp" +#include "mllm/compile/ir/Node.hpp" +#include "mllm/compile/ir/NodeRTTIClassOfImpl.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/core/Tensor.hpp" + +namespace mllm::ir::linalg { + +class LinalgIRAttr : public Attr { + public: + DEFINE_SPECIFIC_IR_CLASS(LinalgIRAttr); + + ~LinalgIRAttr() override; + + LinalgIRAttr(); + + explicit LinalgIRAttr(const NodeKind& kind); + + static inline bool classof(const Node* node) { RTTI_RK_ATTR_LINALGIRATTR_IMPL(node); } +}; + +enum class QuantizationSpecType : uint32_t { + kNone = 0, + kSymPerTensor, + kSymPerChannel, + kSymPerBlock, + kAsymPerTensor, + kAsymPerChannel, + kAsymPerBlock, + kLPBQ, +}; + +struct QuantizationSpec { + using ptr_t = std::shared_ptr; + QuantizationSpecType type; +}; + +struct QuantizationSpecSymPerTensor : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + DataTypes quant_to_type = kUInt8; + DataTypes scale_type = kFloat32; + Tensor scale = Tensor::nil(); + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, DataTypes quant_to_type, DataTypes scale_type, + Tensor scale) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kSymPerTensor; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->quant_to_type = quant_to_type; + spec->scale_type = scale_type; + spec->scale = std::move(scale); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kSymPerTensor; + return spec; + } +}; + +struct QuantizationSpecSymPerChannel : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + int32_t ch_axis = -1; + DataTypes quant_to_type = kUInt8; + DataTypes scale_type = kFloat32; + Tensor scale = Tensor::nil(); + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, int32_t ch_axis, DataTypes quant_to_type, + DataTypes scale_type, Tensor scale) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kSymPerChannel; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->ch_axis = ch_axis; + spec->quant_to_type = quant_to_type; + spec->scale_type = scale_type; + spec->scale = std::move(scale); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kSymPerChannel; + return spec; + } +}; + +struct QuantizationSpecSymPerBlock : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + int32_t block_size = -1; + DataTypes quant_to_type = kUInt8; + DataTypes scale_type = kFloat32; + Tensor scale = Tensor::nil(); ///< Flattened scale, blocks num + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, int32_t block_size, DataTypes quant_to_type, + DataTypes scale_type, Tensor scale) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kSymPerBlock; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->block_size = block_size; + spec->quant_to_type = quant_to_type; + spec->scale_type = scale_type; + spec->scale = std::move(scale); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kSymPerBlock; + return spec; + } +}; + +struct QuantizationSpecAsymPerTensor : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + DataTypes quant_to_type = kUInt8; + DataTypes scale_type = kFloat32; + DataTypes zero_point_type = kInt32; + Tensor scale = Tensor::nil(); + Tensor zero_point = Tensor::nil(); + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, DataTypes quant_to_type, DataTypes scale_type, + DataTypes zero_point_type, Tensor scale, Tensor zero_point) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kAsymPerTensor; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->quant_to_type = quant_to_type; + spec->scale_type = scale_type; + spec->zero_point_type = zero_point_type; + spec->scale = std::move(scale); + spec->zero_point = std::move(zero_point); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kAsymPerTensor; + return spec; + } +}; + +struct QuantizationSpecAsymPerChannel : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + int32_t ch_axis = -1; + DataTypes quant_to_type = kUInt8; + DataTypes scale_type = kFloat32; + DataTypes zero_point_type = kInt32; + Tensor scale = Tensor::nil(); + Tensor zero_point = Tensor::nil(); + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, int32_t ch_axis, DataTypes quant_to_type, + DataTypes scale_type, DataTypes zero_point_type, Tensor scale, Tensor zero_point) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kAsymPerChannel; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->ch_axis = ch_axis; + spec->quant_to_type = quant_to_type; + spec->scale_type = scale_type; + spec->zero_point_type = zero_point_type; + spec->scale = std::move(scale); + spec->zero_point = std::move(zero_point); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kAsymPerChannel; + return spec; + } +}; + +struct QuantizationSpecAsymPerBlock : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + int32_t block_size = -1; + DataTypes quant_to_type = kUInt8; + DataTypes scale_type = kFloat32; + DataTypes zero_point_type = kInt32; + Tensor scale = Tensor::nil(); ///< Flattened scale, blocks num + Tensor zero_point = Tensor::nil(); ///< Flattened zero_point, blocks num + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, int32_t block_size, DataTypes quant_to_type, + DataTypes scale_type, DataTypes zero_point_type, Tensor scale, Tensor zero_point) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kAsymPerBlock; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->block_size = block_size; + spec->quant_to_type = quant_to_type; + spec->scale_type = scale_type; + spec->zero_point_type = zero_point_type; + spec->scale = std::move(scale); + spec->zero_point = std::move(zero_point); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kAsymPerBlock; + return spec; + } +}; + +struct QuantizationSpecLPBQ : public QuantizationSpec { + int32_t quant_min = -1; + int32_t quant_max = -1; + int32_t block_size = -1; + int32_t ch_axis = -1; + int32_t scale_level_0_bitwidth = 4; + DataTypes quant_to_type = kUInt8; + DataTypes scale_1_type = kFloat32; + Tensor scale_level_0_int = Tensor::nil(); ///< Flattened scale, blocks num + Tensor scale_level_1_fp = Tensor::nil(); ///< Flattened scale, channel num + + static inline ptr_t create(int32_t quant_min, int32_t quant_max, int32_t block_size, int32_t ch_axis, + int32_t scale_level_0_bitwidth, DataTypes quant_to_type, DataTypes scale_1_type, + Tensor scale_level_0_int, Tensor scale_level_1_fp) { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kLPBQ; + spec->quant_min = quant_min; + spec->quant_max = quant_max; + spec->block_size = block_size; + spec->ch_axis = ch_axis; + spec->scale_level_0_bitwidth = scale_level_0_bitwidth; + spec->quant_to_type = quant_to_type; + spec->scale_1_type = scale_1_type; + spec->scale_level_0_int = std::move(scale_level_0_int); + spec->scale_level_1_fp = std::move(scale_level_1_fp); + return spec; + } + + static inline ptr_t create() { + auto spec = std::make_shared(); + spec->type = QuantizationSpecType::kLPBQ; + return spec; + } +}; + +struct QuantizationAnnotation { + std::vector inputs; + std::vector outputs; + std::unordered_map weights; +}; + +class LinalgIRQuantizatonAnnotationAttr final : public LinalgIRAttr { + public: + QuantizationAnnotation annotation_; + + DEFINE_SPECIFIC_IR_CLASS(LinalgIRQuantizatonAnnotationAttr); + + ~LinalgIRQuantizatonAnnotationAttr() override; + + LinalgIRQuantizatonAnnotationAttr(); + + explicit LinalgIRQuantizatonAnnotationAttr(const NodeKind& kind); + + static inline bool classof(const Node* node) { RTTI_RK_ATTR_LINALGIRATTR_QUANTIZATIONANNOTATION_IMPL(node); } +}; + +} // namespace mllm::ir::linalg diff --git a/mllm/compile/ir/rtti_kind_gen.py b/mllm/compile/ir/rtti_kind_gen.py index 4d31ad861..03759db01 100644 --- a/mllm/compile/ir/rtti_kind_gen.py +++ b/mllm/compile/ir/rtti_kind_gen.py @@ -215,6 +215,8 @@ def define_lianlg_ir(ir: dict): val: Cls = ir["Value"] attr: Cls = ir["Attribute"] + attr.derive(Cls("QuantizationAnnotation")) + # op op.derive(Cls("RegisterOp")) op.derive(Cls("CustomKernelOp")) diff --git a/mllm/core/DataTypes.hpp b/mllm/core/DataTypes.hpp index f49b38fa2..a30dafc95 100644 --- a/mllm/core/DataTypes.hpp +++ b/mllm/core/DataTypes.hpp @@ -339,6 +339,12 @@ enum DataTypes : int32_t { kByte = 134, kMXFP4 = 135, + // Int4 and low bits + kInt4 = 136, + KUInt4 = 137, + kInt2 = 138, + kUInt2 = 139, + // complex dtypes for STFT and other ops kComplexFloat32 = 201, kComplexFloat64 = 202, diff --git a/mllm/ffi/Extension.cc b/mllm/ffi/Extension.cc index 4744e172d..22449f883 100644 --- a/mllm/ffi/Extension.cc +++ b/mllm/ffi/Extension.cc @@ -45,8 +45,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("mllm.echo", mllm::ffi::echo); refl::GlobalDef().def("mllm.initialize_context", mllm::initializeContext); refl::GlobalDef().def("mllm.shutdown_context", mllm::shutdownContext); + refl::GlobalDef().def("mllm.is_qnn_aot_on_x86_enabled", mllm::isQnnAOTOnX86Enabled); // Primitives + refl::ObjectDef<::mllm::ffi::DeviceObj>(); + refl::ObjectDef<::mllm::ffi::DTypeObj>(); refl::GlobalDef().def("mllm.cpu_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kCPU); }); refl::GlobalDef().def("mllm.cuda_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kCUDA); }); refl::GlobalDef().def("mllm.qnn_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kQNN); }); @@ -325,6 +328,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + refl::ObjectDef<::mllm::ffi::BaseOpObj>(); refl::GlobalDef().def("mllm.BaseOp.load", [](const mllm::ffi::BaseOp& self, const mllm::ffi::ParameterFile& obj) -> void { self.get()->op_ptr_->load(obj.get()->pf_ptr_); }); @@ -336,6 +340,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; + refl::ObjectDef<::mllm::ffi::SessionObj>(); refl::GlobalDef().def("mllm.service.startService", [](int work_threads = 1) -> void { ::mllm::service::startService(work_threads); }); refl::GlobalDef().def("mllm.service.stopService", []() -> void { ::mllm::service::stopService(); }); diff --git a/mllm/ffi/Nn.cc b/mllm/ffi/Nn.cc index e69de29bb..71c85aea3 100644 --- a/mllm/ffi/Nn.cc +++ b/mllm/ffi/Nn.cc @@ -0,0 +1,38 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#include +#include + +#include "mllm/core/aops/SoftmaxOp.hpp" +#include "mllm/engine/Context.hpp" +#include "mllm/ffi/Object.hh" + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + + refl::ObjectDef<::mllm::ffi::SoftmaxOpOptionsObj>().def_static("__create__", [](int dim) -> mllm::ffi::SoftmaxOpOptions { + auto v = ::mllm::aops::SoftmaxOpOptions{.axis = dim}; + return mllm::ffi::SoftmaxOpOptions(v); + }); + + refl::ObjectDef<::mllm::ffi::SoftmaxOpObj>(); + refl::GlobalDef().def( + "mllm.aops.__ctx_create_softmax_op", [](const mllm::ffi::Device& d, const mllm::ffi::SoftmaxOpOptions& o) { + auto v = mllm::Context::instance().getBackend(d.get()->device)->createOp(mllm::OpTypes::kSoftmax, o.get()->options_); + return mllm::ffi::BaseOp(v); + }); + // =============================================================== + // Dispatcher things + // =============================================================== + refl::GlobalDef().def("mllm.engine.dispatch", [](const mllm::ffi::Device& d, const mllm::ffi::BaseOp& op, + const tvm::ffi::Array& input_ffi) { + mllm::DispatcherManager::dispatcher_id_t id = (int32_t)(d.get()->device); + std::vector inputs; + for (auto& t : input_ffi) { inputs.push_back(t.get()->mllm_tensor_); } + auto task = mllm::Task::createExecuteOpTask(op.get()->op_ptr_, inputs, {}); + mllm::Context::instance().dispatcherManager()->submit(id, task); + tvm::ffi::Array ret; + for (auto& o : task->outputs) { ret.push_back(mllm::ffi::Tensor(o)); } + return ret; + }); +} diff --git a/mllm/ffi/Object.hh b/mllm/ffi/Object.hh index 13021c871..7cfd1ae21 100644 --- a/mllm/ffi/Object.hh +++ b/mllm/ffi/Object.hh @@ -100,7 +100,7 @@ class BaseOpObj : public tvm::ffi::Object { explicit BaseOpObj(const ::mllm::BaseOp::ptr_t& op_ptr) : op_ptr_(op_ptr) { MLLM_EMPTY_SCOPE; } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.BaseOp", BaseOpObj, tvm::ffi::Object); + TVM_FFI_DECLARE_OBJECT_INFO("mllm.BaseOp", BaseOpObj, tvm::ffi::Object); }; class BaseOp : public tvm::ffi::ObjectRef { @@ -129,4 +129,40 @@ class ParameterFile : public tvm::ffi::ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ParameterFile, tvm::ffi::ObjectRef, ParameterFileObj); // NOLINT }; +//===----------------------------------------------------------------------===// +// MLLM Ops +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// MLLM Softmax Op +//===----------------------------------------------------------------------===// +class SoftmaxOpOptionsObj : public tvm::ffi::Object { + public: + ::mllm::aops::SoftmaxOpOptions options_; + + explicit SoftmaxOpOptionsObj(const ::mllm::aops::SoftmaxOpOptions& opt) : options_(opt) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.aops.SoftmaxOpOptions", SoftmaxOpOptionsObj, tvm::ffi::Object); +}; + +class SoftmaxOpOptions : public tvm::ffi::ObjectRef { + public: + explicit SoftmaxOpOptions(::mllm::aops::SoftmaxOpOptions& opt) { data_ = tvm::ffi::make_object(opt); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SoftmaxOpOptions, tvm::ffi::ObjectRef, SoftmaxOpOptionsObj); // NOLINT +}; + +class SoftmaxOpObj : public BaseOpObj { + public: + explicit SoftmaxOpObj(const ::mllm::BaseOp::ptr_t& opt) : BaseOpObj(opt) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.aops.SoftmaxOp", SoftmaxOpObj, BaseOpObj); +}; + +class SoftmaxOp : public BaseOp { + public: + explicit SoftmaxOp(::mllm::aops::SoftmaxOp::ptr_t& opt) { data_ = tvm::ffi::make_object(opt); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SoftmaxOp, BaseOp, SoftmaxOpObj); // NOLINT +}; + } // namespace mllm::ffi diff --git a/mllm/ffi/qualcomm/QnnAOT.cc b/mllm/ffi/qualcomm/QnnAOT.cc index dd1fac055..e36cad641 100644 --- a/mllm/ffi/qualcomm/QnnAOT.cc +++ b/mllm/ffi/qualcomm/QnnAOT.cc @@ -8,6 +8,7 @@ #include #include +#include "mllm/backends/qnn/aot/QnnTargetMachine.hpp" #include "mllm/ffi/qualcomm/QnnAOT.hh" #ifdef MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE @@ -15,20 +16,196 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_static("__create__", [](const std::string& path) -> mllm::ffi::QnnAOTEnv { - if (path.empty()) { - auto s = std::make_shared<::mllm::qnn::aot::QnnAOTEnv>(); - return ::mllm::ffi::QnnAOTEnv(s); - } else { - auto s = std::make_shared<::mllm::qnn::aot::QnnAOTEnv>(path); - return ::mllm::ffi::QnnAOTEnv(s); - } + refl::ObjectDef<::mllm::ffi::QcomHTPArchObj>(); + + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.NONE", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::NONE; + return mllm::ffi::QcomHTPArch(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V68", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::V68; + return mllm::ffi::QcomHTPArch(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V69", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::V69; + return mllm::ffi::QcomHTPArch(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V73", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::V73; + return mllm::ffi::QcomHTPArch(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V75", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::V75; + return mllm::ffi::QcomHTPArch(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V79", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::V79; + return mllm::ffi::QcomHTPArch(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomHTPArch.V81", []() { + auto ret = mllm::qnn::aot::QcomHTPArch::V81; + return mllm::ffi::QcomHTPArch(ret); + }); + + refl::ObjectDef<::mllm::ffi::QcomChipsetObj>(); + + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.UNKNOWN_SM", []() { + auto ret = mllm::qnn::aot::QcomChipset::UNKNOWN_SM; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SA8295", []() { + auto ret = mllm::qnn::aot::QcomChipset::SA8295; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8350", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8350; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8450", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8450; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8475", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8475; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8550", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8550; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8650", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8650; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8750", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8750; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SM8850", []() { + auto ret = mllm::qnn::aot::QcomChipset::SM8850; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SSG2115P", []() { + auto ret = mllm::qnn::aot::QcomChipset::SSG2115P; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SSG2125P", []() { + auto ret = mllm::qnn::aot::QcomChipset::SSG2125P; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SXR1230P", []() { + auto ret = mllm::qnn::aot::QcomChipset::SXR1230P; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SXR2230P", []() { + auto ret = mllm::qnn::aot::QcomChipset::SXR2230P; + return mllm::ffi::QcomChipset(ret); }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SXR2330P", []() { + auto ret = mllm::qnn::aot::QcomChipset::SXR2330P; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.QCS9100", []() { + auto ret = mllm::qnn::aot::QcomChipset::QCS9100; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SAR2230P", []() { + auto ret = mllm::qnn::aot::QcomChipset::SAR2230P; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SA8255", []() { + auto ret = mllm::qnn::aot::QcomChipset::SA8255; + return mllm::ffi::QcomChipset(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomChipset.SW6100", []() { + auto ret = mllm::qnn::aot::QcomChipset::SW6100; + return mllm::ffi::QcomChipset(ret); + }); + + refl::ObjectDef<::mllm::ffi::QcomTryBestPerformanceObj>(); + + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpDefault", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpDefault; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpSustainedHighPerformance", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpSustainedHighPerformance; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpBurst", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpBurst; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpHighPerformance", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpHighPerformance; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpPowerSaver", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpPowerSaver; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpLowPowerSaver", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpLowPowerSaver; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpHighPowerSaver", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpHighPowerSaver; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpLowBalanced", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpLowBalanced; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + refl::GlobalDef().def("mllm.qualcomm.QcomTryBestPerformance.HtpBalanced", []() { + auto ret = mllm::qnn::aot::QcomTryBestPerformance::kHtpBalanced; + return mllm::ffi::QcomTryBestPerformance(ret); + }); + + refl::ObjectDef<::mllm::ffi::QcomSecurityPDSessionObj>(); - refl::GlobalDef().def("mllm.qualcomm.QnnAOTEnv.createContext", [](const mllm::ffi::QnnAOTEnv& self, const std::string& name) { - auto s = self.get()->qnn_aot_env_ptr_->createContext(name); - return mllm::ffi::QnnDeviceAndContext(s); + refl::GlobalDef().def("mllm.qualcomm.QcomSecurityPDSession.HtpUnsignedPd", []() { + auto ret = mllm::qnn::aot::QcomSecurityPDSession::kHtpUnsignedPd; + return mllm::ffi::QcomSecurityPDSession(ret); }); + refl::GlobalDef().def("mllm.qualcomm.QcomSecurityPDSession.HtpSignedPd", []() { + auto ret = mllm::qnn::aot::QcomSecurityPDSession::kHtpSignedPd; + return mllm::ffi::QcomSecurityPDSession(ret); + }); + + refl::ObjectDef().def_static( + "__create__", + [](const mllm::ffi::QcomChipset& chipset, const mllm::ffi::QcomHTPArch& arch, + const mllm::ffi::QcomTryBestPerformance& perf, const mllm::ffi::QcomSecurityPDSession& pd_session, uint32_t htp_vtcm) { + auto tm = mllm::qnn::aot::QcomTargetMachine{ + .soc_htp_chipset = chipset.get()->chipset_, + .soc_htp_arch = arch.get()->htp_arch_, + .soc_htp_performance = perf.get()->perf_, + .soc_htp_security_pd_session = pd_session.get()->pd_, + .soc_htp_vtcm_total_memory_size = htp_vtcm, + }; + return ::mllm::ffi::QcomTargetMachine(tm); + }); + + refl::ObjectDef().def_static( + "__create__", [](const mllm::ffi::QcomTargetMachine& machine, const std::string& path) -> mllm::ffi::QnnAOTEnv { + if (path.empty()) { + auto tm = machine.get()->target_machine_; + auto s = std::make_shared<::mllm::qnn::aot::QnnAOTEnv>(tm); + return ::mllm::ffi::QnnAOTEnv(s); + } else { + auto tm = machine.get()->target_machine_; + auto s = std::make_shared<::mllm::qnn::aot::QnnAOTEnv>(path, tm); + return ::mllm::ffi::QnnAOTEnv(s); + } + }); + + refl::ObjectDef<::mllm::ffi::QnnDeviceAndContextObj>(); + + refl::GlobalDef().def("mllm.qualcomm.QnnAOTEnv.createContext", + [](const mllm::ffi::QnnAOTEnv& self, const std::string& name, bool weights_sharing) { + auto s = self.get()->qnn_aot_env_ptr_->createContext(name, weights_sharing); + return mllm::ffi::QnnDeviceAndContext(s); + }); } #endif diff --git a/mllm/ffi/qualcomm/QnnAOT.hh b/mllm/ffi/qualcomm/QnnAOT.hh index 321acc6ea..f0feb46f3 100644 --- a/mllm/ffi/qualcomm/QnnAOT.hh +++ b/mllm/ffi/qualcomm/QnnAOT.hh @@ -59,6 +59,107 @@ class QnnDeviceAndContext : public tvm::ffi::ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QnnDeviceAndContext, tvm::ffi::ObjectRef, QnnDeviceAndContextObj); // NOLINT }; +//===----------------------------------------------------------------------===// +// MLLM QcomHTPArch Define +//===----------------------------------------------------------------------===// +class QcomHTPArchObj : public tvm::ffi::Object { + public: + mllm::qnn::aot::QcomHTPArch htp_arch_; + + explicit QcomHTPArchObj(const mllm::qnn::aot::QcomHTPArch& obj) : htp_arch_(obj) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.qualcomm.QcomHTPArch", QcomHTPArchObj, tvm::ffi::Object); +}; + +class QcomHTPArch : public tvm::ffi::ObjectRef { + public: + explicit QcomHTPArch(mllm::qnn::aot::QcomHTPArch& ptr) { data_ = tvm::ffi::make_object(ptr); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QcomHTPArch, tvm::ffi::ObjectRef, QcomHTPArchObj); // NOLINT +}; + +//===----------------------------------------------------------------------===// +// MLLM QcomChipset Define +//===----------------------------------------------------------------------===// +class QcomChipsetObj : public tvm::ffi::Object { + public: + mllm::qnn::aot::QcomChipset chipset_; + + explicit QcomChipsetObj(const mllm::qnn::aot::QcomChipset& obj) : chipset_(obj) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.qualcomm.QcomChipset", QcomChipsetObj, tvm::ffi::Object); +}; + +class QcomChipset : public tvm::ffi::ObjectRef { + public: + explicit QcomChipset(mllm::qnn::aot::QcomChipset& ptr) { data_ = tvm::ffi::make_object(ptr); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QcomChipset, tvm::ffi::ObjectRef, QcomChipsetObj); // NOLINT +}; + +//===----------------------------------------------------------------------===// +// MLLM QcomTryBestPerformance Define +//===----------------------------------------------------------------------===// +class QcomTryBestPerformanceObj : public tvm::ffi::Object { + public: + mllm::qnn::aot::QcomTryBestPerformance perf_; + + explicit QcomTryBestPerformanceObj(const mllm::qnn::aot::QcomTryBestPerformance& obj) : perf_(obj) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.qualcomm.QcomTryBestPerformance", QcomTryBestPerformanceObj, tvm::ffi::Object); +}; + +class QcomTryBestPerformance : public tvm::ffi::ObjectRef { + public: + explicit QcomTryBestPerformance(mllm::qnn::aot::QcomTryBestPerformance& ptr) { + data_ = tvm::ffi::make_object(ptr); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QcomTryBestPerformance, tvm::ffi::ObjectRef, QcomTryBestPerformanceObj); // NOLINT +}; + +//===----------------------------------------------------------------------===// +// MLLM QcomSecurityPDSession Define +//===----------------------------------------------------------------------===// +class QcomSecurityPDSessionObj : public tvm::ffi::Object { + public: + mllm::qnn::aot::QcomSecurityPDSession pd_; + + explicit QcomSecurityPDSessionObj(const mllm::qnn::aot::QcomSecurityPDSession& obj) : pd_(obj) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.qualcomm.QcomSecurityPDSession", QcomSecurityPDSessionObj, tvm::ffi::Object); +}; + +class QcomSecurityPDSession : public tvm::ffi::ObjectRef { + public: + explicit QcomSecurityPDSession(mllm::qnn::aot::QcomSecurityPDSession& ptr) { + data_ = tvm::ffi::make_object(ptr); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QcomSecurityPDSession, tvm::ffi::ObjectRef, QcomSecurityPDSessionObj); // NOLINT +}; + +//===----------------------------------------------------------------------===// +// MLLM QcomTargetMachine Define +//===----------------------------------------------------------------------===// +class QcomTargetMachineObj : public tvm::ffi::Object { + public: + mllm::qnn::aot::QcomTargetMachine target_machine_; + + explicit QcomTargetMachineObj(const mllm::qnn::aot::QcomTargetMachine& obj) : target_machine_(obj) { MLLM_EMPTY_SCOPE; } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("mllm.qualcomm.QcomTargetMachine", QcomTargetMachineObj, tvm::ffi::Object); +}; + +class QcomTargetMachine : public tvm::ffi::ObjectRef { + public: + explicit QcomTargetMachine(mllm::qnn::aot::QcomTargetMachine& ptr) { + data_ = tvm::ffi::make_object(ptr); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(QcomTargetMachine, tvm::ffi::ObjectRef, QcomTargetMachineObj); // NOLINT +}; + #endif } // namespace mllm::ffi diff --git a/mllm/mllm.cpp b/mllm/mllm.cpp index 81aa44309..08a45aefd 100644 --- a/mllm/mllm.cpp +++ b/mllm/mllm.cpp @@ -82,6 +82,14 @@ bool isOpenCLAvailable() { return false; } +bool isQnnAOTOnX86Enabled() { +#ifdef MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE + return true; +#else + return false; +#endif +} + bool isQnnAvailable() { #ifdef MLLM_QNN_BACKEND return true; diff --git a/mllm/mllm.hpp b/mllm/mllm.hpp index 9b5ec3581..4a07f0ee7 100644 --- a/mllm/mllm.hpp +++ b/mllm/mllm.hpp @@ -187,6 +187,8 @@ void memoryReport(); bool isOpenCLAvailable(); +bool isQnnAOTOnX86Enabled(); + extern void initOpenCLBackend(); extern void initCudaBackend(); diff --git a/mllm/models/qwen3/modeling_qwen3_qnn_aot.hpp b/mllm/models/qwen3/modeling_qwen3_qnn_aot.hpp new file mode 100644 index 000000000..bea991a55 --- /dev/null +++ b/mllm/models/qwen3/modeling_qwen3_qnn_aot.hpp @@ -0,0 +1,355 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/mllm.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/models/ARGeneration.hpp" + +namespace mllm::models::qwen3 { + +inline auto makeRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0 / std::pow(rope_theta, 2.0 * i / output_dim); } + return inv_freq; +} + +inline auto makeRotaryPosEmbedding(Tensor& position_ids, const Tensor& inv_freq, float attention_scaling = 1.0f) + -> std::pair { + auto batch_size = position_ids.shape()[0]; + auto seq_len = position_ids.shape()[1]; + auto inv_freq_len = inv_freq.shape()[0]; + auto dim = inv_freq_len * 2; + + // Create freqs tensor: position_ids @ inv_freq + auto freqs = Tensor::empty({batch_size, seq_len, inv_freq_len}, kFloat32, kCPU).alloc(); + auto freqs_ptr = freqs.ptr(); + auto position_ids_ptr = position_ids.ptr(); + auto inv_freq_ptr = inv_freq.ptr(); + + // Compute freqs = position_ids[:, :, None] @ inv_freq[None, :] + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + auto pos = position_ids_ptr[b * seq_len + s]; + for (int d = 0; d < inv_freq_len; ++d) { + freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d] = static_cast(pos) * inv_freq_ptr[d]; + } + } + } + + // Create sin and cos tensors with shape [batch_size, seq_len, dim] + auto sin_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto cos_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto sin_ptr = sin_emb.ptr(); + auto cos_ptr = cos_emb.ptr(); + + // Compute sin and cos embeddings: emb = [freqs, freqs] + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + for (int d = 0; d < inv_freq_len; ++d) { + auto freq = freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d]; + auto sin_val = std::sin(freq) * attention_scaling; + auto cos_val = std::cos(freq) * attention_scaling; + + // Store the same values in both halves: [freqs, freqs] + sin_ptr[b * seq_len * dim + s * dim + d] = sin_val; + sin_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = sin_val; + cos_ptr[b * seq_len * dim + s * dim + d] = cos_val; + cos_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = cos_val; + } + } + } + + return {sin_emb, cos_emb}; +} + +class Qwen3MLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen3MLP() = default; + Qwen3MLP(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, false, cfg.linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen3Attention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear o_proj_; + nn::RMSNorm rms_norm_q_; + nn::RMSNorm rms_norm_k_; + nn::RoPE q_rope_; + nn::RoPE k_rope_; + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + + public: + Qwen3Attention() = default; + + Qwen3Attention(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + q_proj_ = + reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, cfg.attention_bias, cfg.linear_impl_type); + k_proj_ = + reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); + v_proj_ = + reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); + o_proj_ = + reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, cfg.attention_bias, cfg.linear_impl_type); + + rms_norm_q_ = reg("q_norm", cfg.rms_norm_eps); + rms_norm_k_ = reg("k_norm", cfg.rms_norm_eps); + + q_rope_ = reg("q_rope", cfg.rope_theta, cfg.max_position_embeddings); + k_rope_ = reg("k_rope", cfg.rope_theta, cfg.max_position_embeddings); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + // [B, S, H * D] + auto query_states = q_proj_(x); + auto key_states = k_proj_(x); + auto value_states = v_proj_(x); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + // [B, S, H, D] + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + value_states = value_states.view({B, S, num_key_value_heads_, head_dim_}); + + // [B, S, H, D] + query_states = rms_norm_q_(query_states); + key_states = rms_norm_k_(key_states); + + // [B, H, S, D] + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + // [B, H, S, D] + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + // [B, H, S, D] + auto [key_states_new, value_states_new] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states); + key_states = key_states_new; + value_states = value_states_new; + + Tensor attn; + if (key_states.dtype() == kFloat32) { + // attention weight + // [B, H, S, S] + attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + } else if (key_states.dtype() == kFloat16) { + attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + attn = attn.to(kFloat16); + } + + // attn output + // [B, H, S, S] @ [B, H, S, D] -> [B, H, S, D] + auto output = nn::functional::matmul(attn, value_states); + // [B, H, S, D] -> [B, S, H, D] -> [B, S, H * D] + output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); + + return {output}; + } + + int layer_idx_; +}; + +class Qwen3Decoder final : public nn::Module { + public: + Qwen3Attention self_attn_; + Qwen3MLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen3Decoder() = default; + + Qwen3Decoder(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + x = mlp_(x)[0]; + x = x + tmp; + return {x}; + } +}; + +class Qwen3Text final : public nn::Module { + nn::ModuleList decode_blocks_; + nn::RMSNorm norm_; + nn::Embedding embedding_; + + public: + Qwen3Text() = default; + + Qwen3Text(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + + x = norm_(x); + + return {x}; + } +}; + +class Qwen3ForCausalLM : public ARGeneration, public nn::Module { + public: + explicit Qwen3ForCausalLM(const Qwen3Config& cfg) : cfg(cfg) { + kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers, + cfg.num_attention_heads, // q_heads + cfg.num_key_value_heads, // kv_heads + cfg.head_dim, // kv_dim + kFloat32, // k_dtype + kFloat32, // v_dtype + kCPU, // device_type + false // use_fa2 + ); + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head_out", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); + } + + // Init inv freq + auto inv = makeRoPEInvFreq(cfg.head_dim, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + auto sequence = input.at("sequence"); + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({0, position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({batch_size, 1}, kInt64, kCPU).alloc(); + *position_ids.offsettedPtr({0, 0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({batch_size, seq_len}, kInt64, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[b * seq_len + s] = s; } + } + } + + // Generate RoPE embeddings using the inv_freq buffer + auto [llm_embedding_sin, llm_embedding_cos] = makeRotaryPosEmbedding(position_ids, getBuffer("inv_freq"), 1.0f); + + sequence = llm(sequence, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + + // clip x to one seq length + { + auto S = sequence.shape()[1]; + sequence = sequence[{kAll, {S - 1}, kAll}]; + } + if (tie_word_embeddings_) { sequence = lm_head_(sequence); } + + return { + {"sequence", sequence}, + {"position_ids", position_ids}, + }; + } + + inline nn::StaticCache& kvCache() { return kv_cache_; } + + private: + const Qwen3Config& cfg; + Qwen3Text llm; + nn::Linear lm_head_; + bool tie_word_embeddings_; + nn::StaticCache kv_cache_; +}; + +} // namespace mllm::models::qwen3 diff --git a/pymllm/backends/qualcomm/qnn_aot_env.py b/pymllm/backends/qualcomm/qnn_aot_env.py index 7737a5c02..af4bf2c1a 100644 --- a/pymllm/backends/qualcomm/qnn_aot_env.py +++ b/pymllm/backends/qualcomm/qnn_aot_env.py @@ -1 +1,9 @@ -from pymllm.ffi import QnnDeviceAndContext, QnnAOTEnv +from pymllm.ffi import ( + QnnDeviceAndContext, + QnnAOTEnv, + QcomChipset, + QcomHTPArch, + QcomSecurityPDSession, + QcomTargetMachine, + QcomTryBestPerformance, +) diff --git a/pymllm/ffi/__init__.py b/pymllm/ffi/__init__.py index d49a6fc93..17bd04c19 100644 --- a/pymllm/ffi/__init__.py +++ b/pymllm/ffi/__init__.py @@ -18,6 +18,10 @@ def echo(rec: str) -> None: return _ffi_api.echo(rec) +def is_qnn_aot_on_x86_enabled() -> bool: + return _ffi_api.is_qnn_aot_on_x86_enabled() + + def initialize_context() -> None: return _ffi_api.initialize_context() @@ -320,23 +324,261 @@ def load(self, pf: ParameterFile): return tvm_ffi.get_global_func("mllm.BaseOp.load")(self, pf) -@tvm_ffi.register_object("mllm.qualcomm.QnnDeviceAndContext") -class QnnDeviceAndContext(tvm_ffi.Object): - def __init__(self): - pass +if is_qnn_aot_on_x86_enabled(): + print("pymllm: is_qnn_aot_on_x86_enabled is true") + @tvm_ffi.register_object("mllm.qualcomm.QnnDeviceAndContext") + class QnnDeviceAndContext(tvm_ffi.Object): + def __init__(self): + pass -@tvm_ffi.register_object("mllm.qualcomm.QnnAOTEnv") -class QnnAOTEnv(tvm_ffi.Object): - def __init__(self, path=None): - if path is None or path == "": - self.__init_handle_by_constructor__(QnnAOTEnv.__create__, "") - else: - self.__init_handle_by_constructor__(QnnAOTEnv.__create__, path) + @tvm_ffi.register_object("mllm.qualcomm.QcomHTPArch") + class QcomHTPArch(tvm_ffi.Object): + def __init__(self): + pass + + @staticmethod + def NONE() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.NONE")() + + @staticmethod + def V68() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V68")() + + @staticmethod + def V69() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V69")() + + @staticmethod + def V73() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V73")() + + @staticmethod + def V75() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V75")() + + @staticmethod + def V79() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V79")() + + @staticmethod + def V81() -> QcomHTPArch: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomHTPArch.V81")() + + @tvm_ffi.register_object("mllm.qualcomm.QcomChipset") + class QcomChipset(tvm_ffi.Object): + def __init__(self): + pass + + @staticmethod + def UNKNOWN_SM() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.UNKNOWN_SM")() + + @staticmethod + def SA8295() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SA8295")() + + @staticmethod + def SM8350() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8350")() + + @staticmethod + def SM8450() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8450")() + + @staticmethod + def SM8475() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8475")() + + @staticmethod + def SM8550() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8550")() + + @staticmethod + def SM8650() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8650")() + + @staticmethod + def SM8750() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8750")() + + @staticmethod + def SM8850() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SM8850")() + + @staticmethod + def SSG2115P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SSG2115P")() + + @staticmethod + def SSG2125P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SSG2125P")() + + @staticmethod + def SXR1230P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SXR1230P")() + + @staticmethod + def SXR2230P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SXR2230P")() + + @staticmethod + def SXR2330P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SXR2330P")() + + @staticmethod + def QCS9100() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.QCS9100")() + + @staticmethod + def SAR2230P() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SAR2230P")() + + @staticmethod + def SA8255() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SA8255")() + + @staticmethod + def SW6100() -> QcomChipset: + return tvm_ffi.get_global_func("mllm.qualcomm.QcomChipset.SW6100")() + + @tvm_ffi.register_object("mllm.qualcomm.QcomTryBestPerformance") + class QcomTryBestPerformance(tvm_ffi.Object): + def __init__(self): + pass + + @staticmethod + def HtpDefault() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpDefault" + )() + + @staticmethod + def HtpSustainedHighPerformance() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpSustainedHighPerformance" + )() + + @staticmethod + def HtpBurst() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpBurst" + )() + + @staticmethod + def HtpHighPerformance() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpHighPerformance" + )() + + @staticmethod + def HtpPowerSaver() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpPowerSaver" + )() + + @staticmethod + def HtpLowPowerSaver() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpLowPowerSaver" + )() + + @staticmethod + def HtpHighPowerSaver() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpHighPowerSaver" + )() + + @staticmethod + def HtpLowBalanced() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpLowBalanced" + )() + + @staticmethod + def HtpBalanced() -> QcomTryBestPerformance: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomTryBestPerformance.HtpBalanced" + )() + + @tvm_ffi.register_object("mllm.qualcomm.QcomSecurityPDSession") + class QcomSecurityPDSession(tvm_ffi.Object): + def __init__(self): + pass + + @staticmethod + def HtpUnsignedPd() -> QcomSecurityPDSession: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomSecurityPDSession.HtpUnsignedPd" + )() + + @staticmethod + def HtpSignedPd() -> QcomSecurityPDSession: + return tvm_ffi.get_global_func( + "mllm.qualcomm.QcomSecurityPDSession.HtpSignedPd" + )() + + @tvm_ffi.register_object("mllm.qualcomm.QcomTargetMachine") + class QcomTargetMachine(tvm_ffi.Object): + def __init__( + self, + soc_htp_chipset: QcomChipset, + soc_htp_arch: QcomHTPArch, + soc_htp_performance: QcomTryBestPerformance, + soc_htp_security_pd_session: QcomSecurityPDSession, + soc_htp_vtcm: int, + ): + self.__init_handle_by_constructor__( + QcomTargetMachine.__create__, + soc_htp_chipset, + soc_htp_arch, + soc_htp_performance, + soc_htp_security_pd_session, + soc_htp_vtcm, + ) + + @tvm_ffi.register_object("mllm.qualcomm.QnnAOTEnv") + class QnnAOTEnv(tvm_ffi.Object): + def __init__( + self, + machine: QcomTargetMachine | None = None, + path: str = None, + ): + if machine is None: + raise ValueError("QnnAOTEnv requires a non-None QcomTargetMachine") + if path is None or path == "": + self.__init_handle_by_constructor__(QnnAOTEnv.__create__, machine, "") + else: + self.__init_handle_by_constructor__(QnnAOTEnv.__create__, machine, path) + + def create_context( + self, name: str, weights_sharing: bool = False + ) -> QnnDeviceAndContext: + return tvm_ffi.get_global_func("mllm.qualcomm.QnnAOTEnv.createContext")( + self, name, weights_sharing + ) + + +# ============================================================================= +# Mllm Ops Binding +# +# ============================================================================= +@tvm_ffi.register_object("mllm.aops.SoftmaxOpOptions") +class SoftmaxOpOptions(tvm_ffi.Object): + def __init__(self, dim=-1): + super().__init__() + self.__init_handle_by_constructor__(SoftmaxOpOptions.__create__, dim) + + +@tvm_ffi.register_object("mllm.aops.SoftmaxOp") +class SoftmaxOp(BaseOp): + def __init__(self): + super().__init__() - def create_context(self, name: str) -> QnnDeviceAndContext: - return tvm_ffi.get_global_func("mllm.qualcomm.QnnAOTEnv.createContext")( - self, name + @staticmethod + def create(device: Device, options: SoftmaxOpOptions): + return tvm_ffi.get_global_func("mllm.aops.__ctx_create_softmax_op")( + device, options ) diff --git a/pymllm/nn/__init__.py b/pymllm/nn/__init__.py index baf5de492..f6be0719b 100644 --- a/pymllm/nn/__init__.py +++ b/pymllm/nn/__init__.py @@ -3,4 +3,4 @@ from . import functional from ._module import Module -from ._layers import Linear +from ._layers import Linear, Softmax diff --git a/pymllm/nn/_layers.py b/pymllm/nn/_layers.py index 4dc6c0d66..5adc79cbf 100644 --- a/pymllm/nn/_layers.py +++ b/pymllm/nn/_layers.py @@ -1,11 +1,13 @@ # Copyright (c) MLLM Team. # Licensed under the MIT License. +import tvm_ffi from .. import ffi class _Layer: def __init__(self): + self.device: ffi.Device = ffi.cpu_() self.this_layer_name: str = None self.absolute_name: str = None self._mllm_c_op_ptr: ffi.BaseOp = None @@ -18,22 +20,50 @@ def load(self, pf: ffi.ParameterFile): def trace(self): pass - def forward(self): - # TODO dispatch op - pass + def forward(self, *args): + inputs = [] + for arg in args: + if isinstance(arg, (ffi.Tensor)): + inputs.append(arg) + else: + print( + f"The layer's forward function received a none Tensor type of {type(arg)}. Which is not supported." + ) + ret = tvm_ffi.get_global_func("mllm.engine.dispatch")( + self.device, self._mllm_c_op_ptr, inputs + ) + if len(ret) == 1: + return ret[0] + return ret def __call__(self, *args, **kwds): - pass + return self.forward(*args) + + def __repr__(self): + return "_Layer" class Linear(_Layer): def __init__(self): super().__init__() + def __repr__(self): + return "nn.Linear" + class Softmax(_Layer): - def __init__(self): + def __init__( + self, + dim=-1, + ): super().__init__() + self.dim = dim + self._mllm_c_op_ptr = ffi.SoftmaxOp.create( + self.device, ffi.SoftmaxOpOptions(dim) + ) + + def __repr__(self): + return f"mllm.aops.Softmax(dim={self.dim})" class RoPE(_Layer): diff --git a/pymllm/nn/_module.py b/pymllm/nn/_module.py index 309721ec2..a4f28e17c 100644 --- a/pymllm/nn/_module.py +++ b/pymllm/nn/_module.py @@ -6,13 +6,44 @@ class Module: - def __init__(self): - self.this_module_name: str = None - self.absolute_name: str = None - self.module_layer_list: list = [] + def __init__(self, name: str = "model"): + super().__setattr__("this_module_name", name) + super().__setattr__("absolute_name", name) + super().__setattr__("module_layer_list", {}) + super().__setattr__("_is_initializing", True) + self.device: ffi.Device = ffi.cpu_() + + def __setattr__(self, name: str, value): + super().__setattr__(name, value) + if ( + getattr(self, "_is_initializing", False) + and not name.startswith("_") + and isinstance(value, (Module, _Layer)) + ): + value.this_module_name = name + value.absolute_name = f"{self.absolute_name}.{name}" + self.module_layer_list[name] = value + if isinstance(value, Module): + value._is_initializing = True + + def to(self, x): + if isinstance(x, str): + if x == "qnn": + pass + self.device = ffi.device(x) + elif isinstance(x, ffi.Device): + self.device = x + elif isinstance(x, ffi.DType): + raise NotImplementedError("Module.to(DType) is not supported") + else: + raise TypeError("device must be str or Device, but got {}".format(type(x))) + return self + + def re_naming_finish_initialization(self): + self._is_initializing = False def load(self, pf: ffi.ParameterFile): - for module_layer in self.module_layer_list: + for module_layer in self.module_layer_list.values(): if isinstance(module_layer, Module, _Layer): module_layer.load(pf) else: @@ -22,11 +53,10 @@ def load(self, pf: ffi.ParameterFile): ) ) - def trace(self): + def trace(self, *args): pass def forward(self, *args): - # TODO send to engine's dispatcher pass def __call__(self, *args, **kwds): @@ -35,3 +65,31 @@ def __call__(self, *args, **kwds): return self.trace(*args, **kwds) return self.forward(*args, **kwds) # __send_graph_end() + + def __str__(self): + return self._repr_helper() + + def __repr__(self): + return self._repr_helper() + + def _repr_helper(self, indent_level: int = 0) -> str: + indent = " " * indent_level + next_indent = " " * (indent_level + 1) + module_str = f"{self.__class__.__name__}(" + if self.module_layer_list: + child_lines = [] + for name, child in self.module_layer_list.items(): + if isinstance(child, Module): + child_repr = child._repr_helper(indent_level + 1) + child_lines.append(f"{next_indent}({name}): {child_repr}") + elif isinstance(child, _Layer): + child_lines.append(f"{next_indent}({name}): {repr(child)}") + else: + child_lines.append( + f"{next_indent}({name}): {type(child).__name__}()" + ) + module_str += "\n" + "\n".join(child_lines) + "\n" + indent + ")" + else: + module_str += ")" + + return module_str diff --git a/pymllm/tests/qualcomm/test_context_create.py b/pymllm/tests/qualcomm/test_context_create.py index f34ef2393..18983daa7 100644 --- a/pymllm/tests/qualcomm/test_context_create.py +++ b/pymllm/tests/qualcomm/test_context_create.py @@ -1,8 +1,28 @@ import pymllm as mllm -from pymllm.backends.qualcomm.qnn_aot_env import QnnAOTEnv, QnnDeviceAndContext +from pymllm.backends.qualcomm.qnn_aot_env import ( + QnnAOTEnv, + QnnDeviceAndContext, + QcomTryBestPerformance, + QcomSecurityPDSession, + QcomTargetMachine, + QcomChipset, + QcomHTPArch, +) -qnn_aot_env: QnnAOTEnv = QnnAOTEnv() + +qnn_aot_env: QnnAOTEnv = QnnAOTEnv( + machine=QcomTargetMachine( + soc_htp_chipset=QcomChipset.SM8850(), + soc_htp_arch=QcomHTPArch.V81(), + soc_htp_performance=QcomTryBestPerformance.HtpBurst(), + soc_htp_security_pd_session=QcomSecurityPDSession.HtpUnsignedPd(), + soc_htp_vtcm=8, # in MB + ), + path="/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", +) if __name__ == "__main__": - mllm.echo("Testing mllm's tvm-ffi abi compatibility") - qnn_context: QnnDeviceAndContext = qnn_aot_env.create_context("model.layer.0") + mllm.echo("Testing tvm-ffi compatibility") + qnn_context: QnnDeviceAndContext = qnn_aot_env.create_context( + "context.0", weights_sharing=False + ) diff --git a/pymllm/tests/test_nn.py b/pymllm/tests/test_nn.py new file mode 100644 index 000000000..d9a3db2d8 --- /dev/null +++ b/pymllm/tests/test_nn.py @@ -0,0 +1,19 @@ +import pymllm as mllm +from pymllm import nn + + +class FooModule(nn.Module): + def __init__(self): + super().__init__() + self.sf = nn.Softmax(dim=-1) + + def forward(self, x): + x = self.sf(x) + return x + + +if __name__ == "__main__": + x = mllm.ones([6, 10]) + foo = FooModule() + print(foo) + print(foo(x)) diff --git a/tasks/build_sdk_x86_qnn_aot.yaml b/tasks/build_sdk_x86_qnn_aot.yaml new file mode 100644 index 000000000..f33281616 --- /dev/null +++ b/tasks/build_sdk_x86_qnn_aot.yaml @@ -0,0 +1,21 @@ +Tasks: + - CMakeConfigTask: + cmake_cfg_path: "build-qnn-aot" + cmake_build_type: "Release" + cmake_extra_args: + # Optional, If use Highway + - "-DHWY_ENABLE_TESTS=OFF" + - "-DHWY_ENABLE_EXAMPLES=OFF" + - "-DHWY_ENABLE_CONTRIB=OFF" + # Optional + - '-DMLLM_CPU_BACKEND_COMPILE_OPTIONS="-march=native"' + - "-DMLLM_KERNEL_USE_THREADS=ON" + - "-DMLLM_KERNEL_THREADS_VENDOR_OPENMP=ON" + - "-DMLLM_KERNEL_USE_THREADS_VENDOR_MLLM=OFF" + - "-DMLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE=ON" + - "-DCMAKE_INSTALL_PREFIX=./mllm-sdk-x86-qnn-aot" + + - CMakeBuildTask: + cmake_cfg_path: "build-qnn-aot" + - CMakeInstallTask: + cmake_cfg_path: "build-qnn-aot"