diff --git a/examples/qwen3_qnn_aot/qnn_aot_cfg.json b/examples/qwen3_qnn_aot/qnn_aot_cfg.json index 349554271..7511686b7 100644 --- a/examples/qwen3_qnn_aot/qnn_aot_cfg.json +++ b/examples/qwen3_qnn_aot/qnn_aot_cfg.json @@ -15,6 +15,7 @@ "split_graph": 1, "quant_recipe": { "llm_recipe": true, + "layers": 28, "builtin_llm_pass": { "model": "qwen3", "lm_head": { diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp index 5644821fe..0fd354de3 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp @@ -4,18 +4,19 @@ #include +#include #include #include #include #include -#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp" -#include "mllm/core/DataTypes.hpp" #include "mllm/utils/Common.hpp" +#include "mllm/core/DataTypes.hpp" #include "mllm/backends/qnn/QNNTypeMacros.hpp" #include "mllm/compile/ir/linalg/Attribute.hpp" #include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp" #include "mllm/backends/qnn/aot/QnnTargetMachine.hpp" +#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp" namespace mllm::qnn::aot { @@ -139,13 +140,6 @@ Qnn_Param_t* QnnAOTParamTensor::getQnnParam() { return &qnn_param_; } Qnn_Tensor_t* QnnAOTParamTensor::getQnnTensor() { return &qnn_param_.tensorParam; } QnnAOTNodeTensor::QnnAOTNodeTensor(const ir::tensor::TensorValue::ptr_t& v, bool force_static_weight) { - // TODO Constant value should also use Static!!! And they can be pruned - // TODO Constant value should also use Static!!! And they can be pruned - // TODO Constant value should also use Static!!! And they can be pruned - // TODO Constant value should also use Static!!! And they can be pruned - // TODO Constant value should also use Static!!! And they can be pruned - // TODO Constant value should also use Static!!! And they can be pruned - name_ = v->name(); mllm_tensor_ = v->tensor_; quant_spec_ = v->getAttr("quant_recipe")->cast_()->spec_; @@ -232,6 +226,7 @@ Qnn_TensorType_t QnnAOTNodeTensor::parseQnnTensorTypeFromIR(const ir::tensor::Te // Check Attribute. The Attribute priority is higher than tensor type if (v->getAttr("qnn_graph_outputs")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_APP_READ; } if (v->getAttr("qnn_graph_inputs")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_APP_READWRITE; } + if (v->getAttr("constant")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_STATIC; } return ret_qnn_tensor_type; } @@ -470,6 +465,17 @@ QnnAOTNodeOperation::ptr_t QnnAOTNodeOperation::setPackageName(const std::string return shared_from_this(); } +QnnAOTGraph::QnnAOTGraph(const std::string& g_name, const std::shared_ptr& context) + : graph_name_(g_name), qnn_context_(context) { + belongs_context_name_ = context->name_; + + auto env = AOTCompileContext::getInstance().getEnv(); + auto qnn_interface = env->getFuncSymbol().qnn_interface_; + + auto ok = qnn_interface.graphCreate(context->qnn_ctx_handle_, g_name.c_str(), nullptr /*graph_config*/, &qnn_graph_handle_); + MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS); +} + void QnnAOTGraph::addOperation(const QnnAOTNodeOperation::ptr_t& qnn_op) { auto env = AOTCompileContext::getInstance().getEnv(); auto qnn_interface = env->getFuncSymbol().qnn_interface_; @@ -481,20 +487,52 @@ void QnnAOTGraph::addOperation(const QnnAOTNodeOperation::ptr_t& qnn_op) { qnn_op_config.v1.packageName = qnn_op->package_name_.c_str(); qnn_op_config.v1.typeName = qnn_op->op_name_.c_str(); - // TODO PARAMs - // TODO Inputs - // TODO Outputs + // Params + uint32_t param_counter = 0; + size_t total_param_size = qnn_op->param_scalar.size() + qnn_op->param_tensor.size(); + Qnn_Param_t* qnn_param_array = (Qnn_Param_t*)malloc(total_param_size * sizeof(Qnn_Param_t)); + qnn_op->unreachable_handle_.emplace_back(qnn_param_array); + { + // Tensor Param + for (const auto& p : qnn_op->param_tensor) { + auto ok = qnn_interface.tensorCreateGraphTensor(qnn_graph_handle_, p->getQnnTensor()); + MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS); + qnn_param_array[param_counter++] = *p->getQnnParam(); + } + for (const auto& p : qnn_op->param_scalar) { qnn_param_array[param_counter++] = *p->getQnnParam(); } + } + + // Inputs + Qnn_Tensor_t* qnn_inputs_array = (Qnn_Tensor_t*)malloc(qnn_op->inputs.size() * sizeof(Qnn_Tensor_t)); + qnn_op->unreachable_handle_.emplace_back(qnn_inputs_array); + for (int i = 0; i < qnn_op->inputs.size(); ++i) { qnn_inputs_array[i] = *qnn_op->inputs[i]->getQnnTensor(); } - // TODO node validations + // Outputs + Qnn_Tensor_t* qnn_outputs_array = (Qnn_Tensor_t*)malloc(qnn_op->outputs.size() * sizeof(Qnn_Tensor_t)); + qnn_op->unreachable_handle_.emplace_back(qnn_outputs_array); + for (int i = 0; i < qnn_op->outputs.size(); ++i) { qnn_outputs_array[i] = *qnn_op->outputs[i]->getQnnTensor(); } - // TODO add node to graph. + qnn_op_config.v1.params = qnn_param_array; + qnn_op_config.v1.numOfParams = total_param_size; + qnn_op_config.v1.inputTensors = qnn_inputs_array; + qnn_op_config.v1.numOfInputs = qnn_op->inputs.size(); + qnn_op_config.v1.outputTensors = qnn_outputs_array; + qnn_op_config.v1.numOfOutputs = qnn_op->outputs.size(); + + auto ok = qnn_interface.backendValidateOpConfig(env->getContext(belongs_context_name_)->bk_handle_, qnn_op_config); + MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS); + ok = qnn_interface.graphAddNode(qnn_graph_handle_, qnn_op_config); + MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS); op_node_.insert({qnn_op->getName(), qnn_op}); } bool QnnAOTGraph::compile() { if (is_compiled_) { return true; } - // TODO + + auto env = AOTCompileContext::getInstance().getEnv(); + auto qnn_interface = env->getFuncSymbol().qnn_interface_; + qnn_interface.graphFinalize(qnn_graph_handle_, env->getContext(belongs_context_name_)->profile_bk_handle_, nullptr); is_compiled_ = true; return true; @@ -692,25 +730,6 @@ std::shared_ptr QnnAOTEnv::createContext(const std::string& // clang-format off { // FIXME(wch): we need to register our own opset of qnn. - // struct OpPackageInfo { - // std::string path; - // std::string interface_provider; - // std::string target; - // }; - - // std::vector op_packages = { - // {.path = "libQnnMllmPackageCPU.so", .interface_provider = "MllmPackageInterfaceProvider", .target = "CPU"}, - // {.path = "libQnnMllmPackageHTP.so", .interface_provider = "MllmPackageInterfaceProvider", .target = "HTP"}, - // }; - - // for (const auto& pkg : op_packages) { - // if (!qnn_htp_func_symbols_.qnn_interface_.backendRegisterOpPackage) { - // MLLM_ERROR_EXIT(ExitCode::kCoreError, "qnn_htp_func_symbols_.qnn_interface_.backendRegisterOpPackage is nullptr."); - // } - // auto status = qnn_htp_func_symbols_.qnn_interface_.backendRegisterOpPackage(context->bk_handle_, pkg.path.c_str(), pkg.interface_provider.c_str(), pkg.target.c_str()); - // MLLM_RT_ASSERT_EQ(status, QNN_BACKEND_NO_ERROR); - // MLLM_INFO("QNN Registered op package: {}, interface provider: {}, target: {}", pkg.path, pkg.interface_provider, pkg.target); - // } } // clang-format on @@ -800,8 +819,11 @@ std::vector QnnAOTEnv::createContextCustomConfig(bool } QnnAOTGraph::ptr_t QnnAOTEnv::captureAOTGraph(const std::string& qnn_context_name, const std::string& g_name) { - // TODO - return nullptr; + MLLM_RT_ASSERT(contexts_.count(qnn_context_name) == 1); + auto ret = QnnAOTGraph::create(g_name, contexts_[qnn_context_name]); + ret->belongs_context_name_ = qnn_context_name; + contexts_[qnn_context_name]->graphs_.insert({g_name, ret}); + return ret; } void QnnAOTEnv::captureAOTNodeOp(const std::string& qnn_context_name, const std::string& graph_name, @@ -813,18 +835,13 @@ void QnnAOTEnv::captureAOTNodeOp(const std::string& qnn_context_name, const std: QnnAOTNodeTensor::ptr_t QnnAOTEnv::captureQnnAOTNodeTensor(const std::string& qnn_context_name, const std::string& graph_name, const ir::tensor::TensorValue::ptr_t& v, bool force_static_weight) { - // TODO Constant value should also use Static!!! And they can be pruned - // TODO Constant value should also use Static!!! And they can be pruned - // TODO Constant value should also use Static!!! And they can be pruned - // TODO Constant value should also use Static!!! And they can be pruned - // TODO Constant value should also use Static!!! And they can be pruned - // TODO Constant value should also use Static!!! And they can be pruned auto __qnn_tensor_name = v->name(); bool __qnn_enable_static_weight = force_static_weight; // Check if this value want static qnn weight. The static qnn weight will be shared through one context in diff graphs! - if (v->tensor_.memType() == kGlobal || (v->tensor_.memType() <= kParams_End && v->tensor_.memType() >= kParams_Start)) { + if (v->tensor_.memType() == kGlobal || (v->tensor_.memType() <= kParams_End && v->tensor_.memType() >= kParams_Start) + || v->getAttr("constant")) { __qnn_enable_static_weight = true; } @@ -848,11 +865,17 @@ QnnAOTNodeTensor::ptr_t QnnAOTEnv::captureQnnAOTNodeTensor(const std::string& qn auto ret = QnnAOTNodeTensor::create(v, __qnn_enable_static_weight); if (__qnn_enable_static_weight) { contexts_[qnn_context_name]->static_tensor_.insert({__qnn_tensor_name, ret}); + qnn_htp_func_symbols_.qnn_interface_.tensorCreateContextTensor(contexts_[qnn_context_name]->qnn_ctx_handle_, + ret->getQnnTensor()); } else { contexts_[qnn_context_name]->graphs_[graph_name]->all_tensors_.insert({__qnn_tensor_name, ret}); + qnn_htp_func_symbols_.qnn_interface_.tensorCreateGraphTensor( + contexts_[qnn_context_name]->graphs_[graph_name]->qnn_graph_handle_, ret->getQnnTensor()); } return ret; } +std::shared_ptr QnnAOTEnv::getContext(const std::string& name) { return contexts_[name]; } + } // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp index d06e84f33..718c0219b 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.hpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.hpp @@ -106,6 +106,8 @@ class QnnAOTNodeTensor : public std::enable_shared_from_this { explicit QnnAOTNodeTensor(const ir::tensor::TensorValue::ptr_t& v, bool force_static_weight = false); + inline Qnn_Tensor_t* getQnnTensor() { return &qnn_tensor_; } + private: Qnn_TensorType_t parseQnnTensorTypeFromIR(const ir::tensor::TensorValue::ptr_t& v); @@ -171,10 +173,18 @@ class QnnAOTNodeOperation : public std::enable_shared_from_this unreachable_handle_; }; +struct QnnDeviceAndContext; class QnnAOTGraph : public std::enable_shared_from_this { public: using ptr_t = std::shared_ptr; + QnnAOTGraph(const std::string& g_name, const std::shared_ptr& context); + + static inline ptr_t create(const std::string& g_name, const std::shared_ptr& context) { + auto ret = std::make_shared(g_name, context); + return ret; + } + void addOperation(const QnnAOTNodeOperation::ptr_t& qnn_op); bool compile(); @@ -183,13 +193,15 @@ class QnnAOTGraph : public std::enable_shared_from_this { std::unordered_map op_node_; std::unordered_map all_tensors_; - private: std::string graph_name_; std::string belongs_context_name_; Qnn_GraphHandle_t qnn_graph_handle_ = nullptr; + std::shared_ptr qnn_context_ = nullptr; }; struct QnnDeviceAndContext { + using ptr_t = std::shared_ptr; + std::string name_; Qnn_LogHandle_t log_ = nullptr; Qnn_BackendHandle_t bk_handle_ = nullptr; @@ -283,6 +295,8 @@ class QnnAOTEnv { inline QnnFuncSymbols& getFuncSymbol() { return qnn_htp_func_symbols_; } + std::shared_ptr getContext(const std::string& name); + private: void _setup(const std::string& path = ""); diff --git a/mllm/backends/qnn/aot/passes/AOTPipeline.cpp b/mllm/backends/qnn/aot/passes/AOTPipeline.cpp index d89e90bdb..b1caa2d13 100644 --- a/mllm/backends/qnn/aot/passes/AOTPipeline.cpp +++ b/mllm/backends/qnn/aot/passes/AOTPipeline.cpp @@ -1,10 +1,12 @@ #include "mllm/backends/qnn/aot/passes/AOTPipeline.hpp" #include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp" +#include "mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.hpp" #include "mllm/backends/qnn/aot/passes/LLMQuantRecipePass.hpp" #include "mllm/backends/qnn/aot/passes/MarkQnnGraphPass.hpp" #include "mllm/backends/qnn/aot/passes/MarkTensorIO.hpp" #include "mllm/backends/qnn/aot/passes/MergeLLMHeadIntoMainGraphPass.hpp" #include "mllm/backends/qnn/aot/passes/OpNamingPass.hpp" +#include "mllm/backends/qnn/aot/passes/SplitLLMGraphPass.hpp" namespace mllm::qnn::aot { std::vector> createQnnAOTLoweringPipeline(QnnAOTEnv* env, const std::string& config_path) { @@ -20,6 +22,9 @@ std::vector> createQnnAOTLoweringPipeline(QnnAOTEnv* e ret.emplace_back(createOpNamingPass()); ret.emplace_back(createMergeLLMHeadIntoMainGraphPass()); ret.emplace_back(createLLMQuantRecipePass()); + ret.emplace_back(createSplitLLMGraphPass()); + ret.emplace_back(createMarkTensorIOPass()); + ret.emplace_back(createLLM2QnnLoweringPass()); } else { MLLM_WARN("This pass currently only supports LLM applications. Please ensure your config contains 'quant_recipe.llm_recipe " "= true'."); diff --git a/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp b/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp new file mode 100644 index 000000000..78fa9d2f3 --- /dev/null +++ b/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp @@ -0,0 +1,168 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include + +#include "mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.hpp" +#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp" +#include "mllm/compile/ir/builtin/Op.hpp" +#include "mllm/compile/ir/graph/Op.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" +#include "mllm/compile/ir/Node.hpp" +#include "mllm/compile/passes/Pass.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/backends/qnn/aot/visitor/Elewise.hpp" + +namespace mllm::qnn::aot { + +LLM2QnnLoweringPass::LLM2QnnLoweringPass() { + named_pattern_.insert(QnnAOTAddPattern::create()); + // TODO reg other patterns here. + // TODO reg other patterns here. + // TODO reg other patterns here. + // TODO reg other patterns here. + // TODO reg other patterns here. + // TODO reg other patterns here. +} + +uint8_t LLM2QnnLoweringPass::run(const ir::node_ptr_t& op) { + // The top op should be modelOp + MLLM_RT_ASSERT(op->isa_()); + + auto model_op = op->cast_(); + auto writer = ir::IRWriter(getCtx(), model_op->getTopRegion()); + + // Check only has 1 call graph op in model_op + ir::graph::CallGraphOp::ptr_t call_graph_op = nullptr; + writer.walk( + [&](ir::IRWriter& /*writer*/, const ir::graph::CallGraphOp::ptr_t& call_op) -> ir::IRWriter::WalkResult { + MLLM_RT_ASSERT(call_graph_op == nullptr); // Should only have one CallGraphOp + call_graph_op = call_op; + return ir::IRWriter::WalkResult::WALK_CONTINUE; + }); + + if (call_graph_op == nullptr) { + MLLM_ERROR("LLM2QnnLoweringPass: No CallGraphOp found in ModuleOp"); + return ir::PASS_RET_FAILURE; + } + + // Check call graph op point to a subgraph named "model" + auto symbol_attr = call_graph_op->getSymbolAttr(); + if (symbol_attr == nullptr || symbol_attr->str() != "model") { + MLLM_ERROR("LLM2QnnLoweringPass: CallGraphOp should point to a subgraph named 'model'"); + return ir::PASS_RET_FAILURE; + } + + // Get the "model" subgraph + auto model_subgraph = getCtx()->lookupSymbolTable("model")->cast_(); + if (model_subgraph == nullptr) { + MLLM_ERROR("LLM2QnnLoweringPass: Cannot find 'model' subgraph in symbol table"); + return ir::PASS_RET_FAILURE; + } + + // Collect all subgraphs from the modelOp's top region + std::unordered_map subgraphs; + + for (auto& region_op : model_op->getTopRegion()->ops()) { + if (auto sub_graph_op = std::dynamic_pointer_cast(region_op)) { + auto symbol_attr = sub_graph_op->getSymbolAttr(); + if (symbol_attr) { subgraphs[symbol_attr->str()] = sub_graph_op; } + } + } + + // Validate that we only have the expected subgraphs: model, model.0.s32, model.1.s16, etc. + // Pattern: model.x.sN where x is a number and N can be 16, 32, 64, 96, etc. + std::regex model_pattern(R"(^model(\.\d+\.s\d+)?$)"); + for (const auto& [name, _] : subgraphs) { + if (!std::regex_match(name, model_pattern)) { + MLLM_ERROR("LLM2QnnLoweringPass: Unexpected subgraph name {}, expected pattern: model or model.x.sx", name); + return ir::PASS_RET_FAILURE; + } + } + + // Store subgraphs in the member variable + subgraph_map_.clear(); + for (const auto& [name, subgraph] : subgraphs) { + if (name != "model") { subgraph_map_[name] = subgraph; } + } + + // Validate that at least one model.x.sN subgraph exists (required for the lowering) + // We don't require specifically model.0.s32, but any model.x.sN pattern + bool has_valid_subgraph = false; + for (const auto& [name, _] : subgraph_map_) { + if (std::regex_match(name, std::regex(R"(^model\.\d+\.s\d+$)"))) { + has_valid_subgraph = true; + break; + } + } + + if (!has_valid_subgraph) { + MLLM_ERROR("LLM2QnnLoweringPass: No valid subgraph found (expected model.x.sN pattern)"); + return ir::PASS_RET_FAILURE; + } + + // Sort subgraphs by name to ensure deterministic processing order + std::vector sorted_names; + sorted_names.reserve(subgraph_map_.size()); + for (const auto& [name, _] : subgraph_map_) { sorted_names.push_back(name); } + std::sort(sorted_names.begin(), sorted_names.end()); + + // Get AOT Compile Context + auto aot_cfg = AOTCompileContext::getInstance().getConfig(); + auto aot_env = AOTCompileContext::getInstance().getEnv(); + + // FIXME: Only support one context right now. + { + int split_graph = aot_cfg["split_graph"]; + MLLM_RT_ASSERT_EQ(split_graph, 1); + aot_env->createContext("context.0", true); + } + + // Process each subgraph in order + for (const auto& subgraph_name : sorted_names) { + auto subgraph = subgraph_map_[subgraph_name]; + auto region = subgraph->getTopRegion(); + if (!region) continue; + + // Create IRWriter for this subgraph + auto subgraph_writer = ir::IRWriter(getCtx(), region); + + auto aot_graph = aot_env->captureAOTGraph("context.0", subgraph_name); + + // Walk through all linalg operations in the subgraph + subgraph_writer.walk( + [&](ir::IRWriter& this_tough_writer, const ir::linalg::LinalgIROp::ptr_t& linalg_op) -> ir::IRWriter::WalkResult { + if (!linalg_op->belongsTo()->getAttr("use_qnn")) { + MLLM_WARN("Found none qnn op: {} in graph: {}", linalg_op->getAOp()->getName(), subgraph_name); + return ir::IRWriter::WalkResult::WALK_BREAK; + } + bool processed = false; + for (auto& [op_type, pass] : named_pattern_) { + if (pass->isMatch(linalg_op)) { + if (!pass->rewrite(this_tough_writer, linalg_op)) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Failed when processing op {} with pass {}", + linalg_op->getAOp()->getName(), optype2Str(op_type)); + } else { + processed = true; + break; + } + } + } + + if (!processed) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Failed processing op {} on all passes", linalg_op->getAOp()->getName()); + } + + return ir::IRWriter::WalkResult::WALK_CONTINUE; + }); + + // Compile + MLLM_RT_ASSERT(aot_graph->compile()); + } + + return ir::PASS_RET_SUCCESS; +} + +ir::Pass::ptr_t createLLM2QnnLoweringPass() { return std::make_shared(); } + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.hpp b/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.hpp new file mode 100644 index 000000000..a4ab0335e --- /dev/null +++ b/mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.hpp @@ -0,0 +1,31 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "mllm/core/OpTypes.hpp" +#include "mllm/compile/ir/Node.hpp" +#include "mllm/compile/ir/graph/Op.hpp" +#include "mllm/compile/passes/Pass.hpp" +#include "mllm/backends/qnn/aot/visitor/Base.hpp" + +namespace mllm::qnn::aot { + +class LLM2QnnLoweringPass final : public ir::Pass { + public: + LLM2QnnLoweringPass(); + + ~LLM2QnnLoweringPass() override = default; + + uint8_t run(const ir::node_ptr_t& op) override; + + private: + std::unordered_map> named_pattern_; + std::unordered_map subgraph_map_; +}; + +ir::Pass::ptr_t createLLM2QnnLoweringPass(); + +} // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot/passes/SplitLLMGraphPass.cpp b/mllm/backends/qnn/aot/passes/SplitLLMGraphPass.cpp index 1d6e10e05..5c2ff94fd 100644 --- a/mllm/backends/qnn/aot/passes/SplitLLMGraphPass.cpp +++ b/mllm/backends/qnn/aot/passes/SplitLLMGraphPass.cpp @@ -2,6 +2,8 @@ // Licensed under the MIT License. #include "mllm/backends/qnn/aot/passes/SplitLLMGraphPass.hpp" +#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp" +#include "mllm/compile/ir/builtin/Attribute.hpp" #include "mllm/compile/ir/builtin/Op.hpp" #include "mllm/compile/ir/graph/Op.hpp" #include "mllm/compile/ir/linalg/Op.hpp" @@ -13,14 +15,217 @@ namespace mllm::qnn::aot { +namespace { + +void recursiveAttachGraphNameAndContextName(const ir::IRContext::ptr_t& ctx, const std::string& qnn_context_name, + const std::string& qnn_graph_name, ir::graph::SubGraphOp::ptr_t& g) { + auto _ = ir::IRWriter(ctx, g->getTopRegion()); + _.walk([&](ir::IRWriter& w /*writer*/, const ir::Op::ptr_t& owo) -> ir::IRWriter::WalkResult { + if (owo->isa_()) { + auto p_owo_g = + ctx->lookupSymbolTable(owo->cast_()->getSymbolAttr()->str())->cast_(); + recursiveAttachGraphNameAndContextName(ctx, qnn_context_name, qnn_graph_name, p_owo_g); + } + if (owo->isa_()) { + owo->setAttr("qnn_context_name", ctx->create(qnn_context_name)); + owo->setAttr("qnn_graph_name", ctx->create(qnn_graph_name)); + } + return ir::IRWriter::WalkResult::WALK_CONTINUE; + }); +} + +void recursiveRemoveOpsIntoNewGraph(const ir::IRContext::ptr_t& ctx, ir::graph::SubGraphOp::ptr_t& g) { + auto _ = ir::IRWriter(ctx, g->getTopRegion()); + _.walk([&](ir::IRWriter& w /*writer*/, const ir::Op::ptr_t& owo) -> ir::IRWriter::WalkResult { + if (owo->isa_()) { + auto p_owo_g = + ctx->lookupSymbolTable(owo->cast_()->getSymbolAttr()->str())->cast_(); + recursiveRemoveOpsIntoNewGraph(ctx, p_owo_g); + } + if (owo->isa_()) { + auto _g_name = owo->getAttr("qnn_graph_name")->cast_()->data(); + auto p_owo_g = ctx->lookupSymbolTable(_g_name)->cast_(); + auto temp_w = ir::IRWriter(ctx, p_owo_g->getTopRegion()); + w.removeOp(owo); + temp_w.insertOpAtLast(owo); + } + return ir::IRWriter::WalkResult::WALK_CONTINUE; + }); +} + +} // namespace + uint8_t SplitLLMGraphPass::run(const ir::node_ptr_t& op) { - // The top op should be ModuleOp + // The top op should be modelOp MLLM_RT_ASSERT(op->isa_()); - auto module_op = op->cast_(); - auto writer = ir::IRWriter(getCtx(), module_op->getTopRegion()); + auto model_op = op->cast_(); + auto top_model_writer = ir::IRWriter(getCtx(), model_op->getTopRegion()); + + // Check only has 1 call graph op in model_op + ir::graph::CallGraphOp::ptr_t call_graph_op = nullptr; + top_model_writer.walk( + [&](ir::IRWriter& /*writer*/, const ir::graph::CallGraphOp::ptr_t& call_op) -> ir::IRWriter::WalkResult { + MLLM_RT_ASSERT(call_graph_op == nullptr); // Should only have one CallGraphOp + call_graph_op = call_op; + return ir::IRWriter::WalkResult::WALK_CONTINUE; + }); + + if (call_graph_op == nullptr) { + MLLM_ERROR("LLM2QnnLoweringPass: No CallGraphOp found in ModuleOp"); + return ir::PASS_RET_FAILURE; + } + + // Check call graph op point to a subgraph named "model" + auto symbol_attr = call_graph_op->getSymbolAttr(); + if (symbol_attr == nullptr || symbol_attr->str() != "model") { + MLLM_ERROR("LLM2QnnLoweringPass: CallGraphOp should point to a subgraph named 'model'"); + return ir::PASS_RET_FAILURE; + } + + // Get the "model" subgraph + auto model_subgraph = getCtx()->lookupSymbolTable("model")->cast_(); + if (model_subgraph == nullptr) { + MLLM_ERROR("LLM2QnnLoweringPass: Cannot find 'model' subgraph in symbol table"); + return ir::PASS_RET_FAILURE; + } + + // Split all layers, and fuse some op into one layer. + // e.g. + // op0 + // op1 + // op2 + // layer.0 + // layer.1 + // op3 + // When we split graphs. op0, op1 and op2 will be merged into layer.0. And op3 will be merged into op3 + + // Count how many layers we have. + auto cfg = AOTCompileContext::getInstance().getConfig(); + int32_t __global_total_layers = cfg["quant_recipe"]["layers"]; + int32_t __global_split_graphs = cfg["split_graph"]; + + // Check seq length + // FIXME: We suppose the first input to LLM is tokend_ids! Whose shape is [Batch, Sequence] + int32_t __global_seq_len = + model_subgraph->getTopRegion()->inputs().front()->cast_()->tensor_.size(1); + + // Create merged graph first! + for (int i = 0; i < __global_split_graphs; ++i) { + auto op = top_model_writer.create( + top_model_writer.create("model." + std::to_string(i) + ".s" + std::to_string(__global_seq_len))); + op->setAttr("use_qnn", top_model_writer.create(true)); + } + + // Solve op's scopes. Attach qnn_context_name and qnn_graph_name on them. + // Suppose all layer's name is xxx.xxx.number + { + int graph_counter = 0; + auto model_graph_writer = ir::IRWriter(getCtx(), model_subgraph->getTopRegion()); + model_graph_writer.walk([&](ir::IRWriter& w /*writer*/, const ir::Op::ptr_t& one_op) -> ir::IRWriter::WalkResult { + if (one_op->isa_()) { + auto g_w_g = getCtx() + ->lookupSymbolTable(one_op->cast_()->getSymbolAttr()->str()) + ->cast_(); + + auto g_w_g_name = g_w_g->getSymbolAttr()->str(); + + // Extract layer number from g_w_g_name (format: xxx.xxx.number) + // The layer number is the last part after the final dot + size_t last_dot_pos = g_w_g_name.find_last_of('.'); + int layer_num = 0; + if (last_dot_pos != std::string::npos && last_dot_pos + 1 < g_w_g_name.length()) { + layer_num = std::stoi(g_w_g_name.substr(last_dot_pos + 1)); + } + + // Calculate which graph counter to use based on layer number + // Each graph will contain approximately __global_total_layers / __global_split_graphs layers + if (__global_split_graphs > 0) { + int layers_per_graph = __global_total_layers / __global_split_graphs; + graph_counter = layer_num / layers_per_graph; + // Ensure graph_counter doesn't exceed the number of split graphs + if (graph_counter >= __global_split_graphs) { graph_counter = __global_split_graphs - 1; } + } + + recursiveAttachGraphNameAndContextName( + getCtx(), "context." + std::to_string(graph_counter), + "model." + std::to_string(graph_counter) + ".s" + std::to_string(__global_seq_len), g_w_g); + } + + if (one_op->isa_()) { + one_op->setAttr("qnn_context_name", getCtx()->create("context." + std::to_string(graph_counter))); + one_op->setAttr("qnn_graph_name", getCtx()->create("model." + std::to_string(graph_counter) + ".s" + + std::to_string(__global_seq_len))); + } + return ir::IRWriter::WalkResult::WALK_CONTINUE; + }); + } + + // Loop all ops in model_subgraph recursively. Using qnn_graph_name to merge those ops in on exists graph. + recursiveRemoveOpsIntoNewGraph(getCtx(), model_subgraph); + + // Solve the inputs and output of splitted graph + { + std::vector original_inputs; + std::vector original_outputs; + for (auto item : model_subgraph->inputs()) { original_inputs.emplace_back(item->cast_()); } + for (auto item : model_subgraph->outputs()) { original_outputs.emplace_back(item->cast_()); } + + // FIXME: currently only support one graph! + MLLM_RT_ASSERT_EQ(__global_split_graphs, 1); + auto one_graph = + getCtx()->lookupSymbolTable("model.0.s" + std::to_string(__global_seq_len))->cast_(); + for (auto item : original_inputs) { + one_graph->inputs().emplace_back(item); + one_graph->getTopRegion()->inputs().emplace_back(item); + } + for (auto item : original_outputs) { + one_graph->outputs().emplace_back(item); + one_graph->getTopRegion()->outputs().emplace_back(item); + } + auto wwww = ir::IRWriter(getCtx(), one_graph->getTopRegion()); + auto return_op = wwww.create(original_outputs); + } + + // Remove old graphs + { + top_model_writer.walk( + [&](ir::IRWriter& wvw, const ir::graph::SubGraphOp::ptr_t& sub_g_op) -> ir::IRWriter::WalkResult { + auto name = sub_g_op->getSymbolAttr()->str(); + bool matched = false; + for (int i = 0; i < __global_split_graphs; ++i) { + if (name == "model." + std::to_string(i) + ".s" + std::to_string(__global_seq_len)) { matched = true; } + if (name == "model") { matched = true; } + } + if (!matched) { wvw.removeOp(sub_g_op); } + return ir::IRWriter::WalkResult::WALK_CONTINUE; + }); + } + + // Insert call graph ops into top model subgraph. + { + // 1. remove all call graph ops in model_subgraph + auto model_graph_writer = ir::IRWriter(getCtx(), model_subgraph->getTopRegion()); + model_graph_writer.walk([&](ir::IRWriter& wvw /*writer*/, const ir::Op::ptr_t& one_op) -> ir::IRWriter::WalkResult { + if (one_op->isa_()) { wvw.removeOp(one_op); } + if (one_op->isa_()) { wvw.removeOp(one_op); } + return ir::IRWriter::WalkResult::WALK_CONTINUE; + }); - // TODO: Implement graph splitting logic here + // 2. Insert new call ops. + // FIXME: currently only support one graph! + MLLM_RT_ASSERT_EQ(__global_split_graphs, 1); + auto call_op = model_graph_writer.create( + model_graph_writer.create("model.0.s" + std::to_string(__global_seq_len))); + std::vector original_inputs; + std::vector original_outputs; + for (auto item : model_subgraph->inputs()) { (*item)-- > call_op; } + for (auto item : model_subgraph->outputs()) { + (*call_op)-- > item->cast_(); + original_outputs.emplace_back(item->cast_()); + } + auto return_op = model_graph_writer.create(original_outputs); + } return ir::PASS_RET_SUCCESS; } diff --git a/mllm/backends/qnn/aot/visitor/Embedding.cpp b/mllm/backends/qnn/aot/visitor/Embedding.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/backends/qnn/aot/visitor/Embedding.hpp b/mllm/backends/qnn/aot/visitor/Embedding.hpp new file mode 100644 index 000000000..e69de29bb diff --git a/mllm/compile/ir/Node.hpp b/mllm/compile/ir/Node.hpp index 1bf95c540..5f3031237 100644 --- a/mllm/compile/ir/Node.hpp +++ b/mllm/compile/ir/Node.hpp @@ -191,6 +191,8 @@ class Attr : public Node { class Val : public Node, public DeviceInterface { public: + using ptr_t = val_ptr_t; + ~Val() override; Val(); explicit Val(const NodeKind& kind); diff --git a/pymllm/backends/qualcomm/transformers/static_qwen3_quantizer.py b/pymllm/backends/qualcomm/transformers/static_qwen3_quantizer.py new file mode 100644 index 000000000..e69de29bb