feat: add LLM2QnnLoweringPass and update graph splitting logic #577
Changes from all commits: 9d499b3, cdb7387, 1f345cd, d47693c, 65c7818
```diff
@@ -4,18 +4,19 @@
 #include <QNN/QnnTypes.h>

 #include <QNN/QnnGraph.h>
 #include <QNN/QnnContext.h>
 #include <QNN/HTP/QnnHtpDevice.h>
 #include <QNN/HTP/QnnHtpCommon.h>
 #include <QNN/HTP/QnnHtpContext.h>

 #include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp"
 #include "mllm/core/DataTypes.hpp"
 #include "mllm/utils/Common.hpp"
 #include "mllm/core/DataTypes.hpp"
 #include "mllm/backends/qnn/QNNTypeMacros.hpp"
 #include "mllm/compile/ir/linalg/Attribute.hpp"
 #include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp"
 #include "mllm/backends/qnn/aot/QnnTargetMachine.hpp"
 #include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp"

 namespace mllm::qnn::aot {
```
```diff
@@ -139,13 +140,6 @@ Qnn_Param_t* QnnAOTParamTensor::getQnnParam() { return &qnn_param_; }
 Qnn_Tensor_t* QnnAOTParamTensor::getQnnTensor() { return &qnn_param_.tensorParam; }

 QnnAOTNodeTensor::QnnAOTNodeTensor(const ir::tensor::TensorValue::ptr_t& v, bool force_static_weight) {
-  // TODO Constant value should also use Static!!! And they can be pruned
-  // TODO Constant value should also use Static!!! And they can be pruned
-  // TODO Constant value should also use Static!!! And they can be pruned
-  // TODO Constant value should also use Static!!! And they can be pruned
-  // TODO Constant value should also use Static!!! And they can be pruned
-  // TODO Constant value should also use Static!!! And they can be pruned
-
   name_ = v->name();
   mllm_tensor_ = v->tensor_;
   quant_spec_ = v->getAttr("quant_recipe")->cast_<ir::linalg::LinalgIRQuantizatonSpecAttr>()->spec_;
```
```diff
@@ -232,6 +226,7 @@ Qnn_TensorType_t QnnAOTNodeTensor::parseQnnTensorTypeFromIR(const ir::tensor::Te
   // Check Attribute. The Attribute priority is higher than tensor type
   if (v->getAttr("qnn_graph_outputs")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_APP_READ; }
   if (v->getAttr("qnn_graph_inputs")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_APP_READWRITE; }
+  if (v->getAttr("constant")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_STATIC; }

   return ret_qnn_tensor_type;
 }
```
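One nuance worth noting (my reading of the code, not stated in the PR): the checks run in sequence and each one overwrites `ret_qnn_tensor_type`, so when a value carries several of these attributes the last match wins, and the new `constant` check takes precedence over the graph input/output attributes. A minimal sketch of this last-write-wins pattern, with hypothetical boolean flags standing in for the IR attribute lookups:

```cpp
// Sketch of the override order above; the Qnn_TensorType_t values are the
// real QNN enum, the flags stand in for v->getAttr(...) lookups.
Qnn_TensorType_t t = QNN_TENSOR_TYPE_NATIVE;                // default from the tensor itself
if (is_graph_output) { t = QNN_TENSOR_TYPE_APP_READ; }
if (is_graph_input)  { t = QNN_TENSOR_TYPE_APP_READWRITE; } // overrides the output case
if (is_constant)     { t = QNN_TENSOR_TYPE_STATIC; }        // overrides everything else
```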
```diff
@@ -470,6 +465,17 @@ QnnAOTNodeOperation::ptr_t QnnAOTNodeOperation::setPackageName(const std::string
   return shared_from_this();
 }

+QnnAOTGraph::QnnAOTGraph(const std::string& g_name, const std::shared_ptr<QnnDeviceAndContext>& context)
+    : graph_name_(g_name), qnn_context_(context) {
+  belongs_context_name_ = context->name_;
+
+  auto env = AOTCompileContext::getInstance().getEnv();
+  auto qnn_interface = env->getFuncSymbol().qnn_interface_;
+
+  auto ok = qnn_interface.graphCreate(context->qnn_ctx_handle_, g_name.c_str(), nullptr /*graph_config*/, &qnn_graph_handle_);
+  MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS);
+}
+
 void QnnAOTGraph::addOperation(const QnnAOTNodeOperation::ptr_t& qnn_op) {
   auto env = AOTCompileContext::getInstance().getEnv();
   auto qnn_interface = env->getFuncSymbol().qnn_interface_;
```
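The constructor passes `nullptr /*graph_config*/` to `graphCreate`. In the QNN SDK this parameter is a null-terminated array of `const QnnGraph_Config_t*`, so HTP-specific graph options could be wired in later. A hedged sketch, under the assumption that `<QNN/HTP/QnnHtpGraph.h>` is also included and with a purely illustrative VTCM size:

```cpp
// Illustrative only: replacing the nullptr graph_config with an HTP custom config.
QnnHtpGraph_CustomConfig_t htp_cfg{};
htp_cfg.option = QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE;
htp_cfg.vtcmSizeInMB = 8;  // hypothetical value

QnnGraph_Config_t cfg{};
cfg.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM_CONFIG;
cfg.customConfig = &htp_cfg;

const QnnGraph_Config_t* cfgs[] = {&cfg, nullptr};
auto ok = qnn_interface.graphCreate(context->qnn_ctx_handle_, g_name.c_str(), cfgs, &qnn_graph_handle_);
MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS);
```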
```diff
@@ -481,20 +487,52 @@ void QnnAOTGraph::addOperation(const QnnAOTNodeOperation::ptr_t& qnn_op) {
   qnn_op_config.v1.packageName = qnn_op->package_name_.c_str();
   qnn_op_config.v1.typeName = qnn_op->op_name_.c_str();

-  // TODO PARAMs
-  // TODO Inputs
-  // TODO Outputs
+  // Params
+  uint32_t param_counter = 0;
+  size_t total_param_size = qnn_op->param_scalar.size() + qnn_op->param_tensor.size();
+  Qnn_Param_t* qnn_param_array = (Qnn_Param_t*)malloc(total_param_size * sizeof(Qnn_Param_t));
+  qnn_op->unreachable_handle_.emplace_back(qnn_param_array);
+  {
+    // Tensor Param
+    for (const auto& p : qnn_op->param_tensor) {
+      auto ok = qnn_interface.tensorCreateGraphTensor(qnn_graph_handle_, p->getQnnTensor());
+      MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS);
+      qnn_param_array[param_counter++] = *p->getQnnParam();
+    }
+    for (const auto& p : qnn_op->param_scalar) { qnn_param_array[param_counter++] = *p->getQnnParam(); }
+  }

+  // Inputs
+  Qnn_Tensor_t* qnn_inputs_array = (Qnn_Tensor_t*)malloc(qnn_op->inputs.size() * sizeof(Qnn_Tensor_t));
+  qnn_op->unreachable_handle_.emplace_back(qnn_inputs_array);
+  for (int i = 0; i < qnn_op->inputs.size(); ++i) { qnn_inputs_array[i] = *qnn_op->inputs[i]->getQnnTensor(); }

-  // TODO node validations
+  // Outputs
+  Qnn_Tensor_t* qnn_outputs_array = (Qnn_Tensor_t*)malloc(qnn_op->outputs.size() * sizeof(Qnn_Tensor_t));
+  qnn_op->unreachable_handle_.emplace_back(qnn_outputs_array);
+  for (int i = 0; i < qnn_op->outputs.size(); ++i) { qnn_outputs_array[i] = *qnn_op->outputs[i]->getQnnTensor(); }

-  // TODO add node to graph.
+  qnn_op_config.v1.params = qnn_param_array;
+  qnn_op_config.v1.numOfParams = total_param_size;
+  qnn_op_config.v1.inputTensors = qnn_inputs_array;
+  qnn_op_config.v1.numOfInputs = qnn_op->inputs.size();
+  qnn_op_config.v1.outputTensors = qnn_outputs_array;
+  qnn_op_config.v1.numOfOutputs = qnn_op->outputs.size();

+  auto ok = qnn_interface.backendValidateOpConfig(env->getContext(belongs_context_name_)->bk_handle_, qnn_op_config);
+  MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS);
+  ok = qnn_interface.graphAddNode(qnn_graph_handle_, qnn_op_config);
+  MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS);

   op_node_.insert({qnn_op->getName(), qnn_op});
 }
```
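The `malloc`ed arrays above are intentionally parked in `qnn_op->unreachable_handle_` so the `Qnn_OpConfig_t` keeps pointing at live memory while the node is validated and added. If the manual `malloc` bookkeeping ever becomes a burden, one alternative is to let `std::vector` own the storage; a sketch under the assumption that `QnnAOTNodeOperation` grows owning members (`params_`, `inputs_`, `outputs_` are hypothetical names, not part of this PR):

```cpp
// Hypothetical owning members on QnnAOTNodeOperation (not in this PR):
//   std::vector<Qnn_Param_t>  params_;
//   std::vector<Qnn_Tensor_t> inputs_, outputs_;

qnn_op->params_.reserve(total_param_size);
for (const auto& p : qnn_op->param_tensor) {
  auto ok = qnn_interface.tensorCreateGraphTensor(qnn_graph_handle_, p->getQnnTensor());
  MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS);
  qnn_op->params_.push_back(*p->getQnnParam());
}
for (const auto& p : qnn_op->param_scalar) { qnn_op->params_.push_back(*p->getQnnParam()); }

qnn_op_config.v1.params      = qnn_op->params_.data();
qnn_op_config.v1.numOfParams = static_cast<uint32_t>(qnn_op->params_.size());
// Same pattern for inputs/outputs; the vectors' lifetime replaces unreachable_handle_.
```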
```diff
 bool QnnAOTGraph::compile() {
   if (is_compiled_) { return true; }
-  // TODO
+
+  auto env = AOTCompileContext::getInstance().getEnv();
+  auto qnn_interface = env->getFuncSymbol().qnn_interface_;
+  qnn_interface.graphFinalize(qnn_graph_handle_, env->getContext(belongs_context_name_)->profile_bk_handle_, nullptr);

   is_compiled_ = true;
   return true;
```
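One observation in the same vein as the review comments at the bottom of this page: `graphFinalize` is the only QNN call in `compile()` whose status is dropped, while the neighboring code asserts on `QNN_SUCCESS`. A possible follow-up (a suggestion, not part of the PR):

```diff
-  qnn_interface.graphFinalize(qnn_graph_handle_, env->getContext(belongs_context_name_)->profile_bk_handle_, nullptr);
+  auto ok = qnn_interface.graphFinalize(qnn_graph_handle_, env->getContext(belongs_context_name_)->profile_bk_handle_, nullptr);
+  MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS);
```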
```diff
@@ -692,25 +730,6 @@ std::shared_ptr<QnnDeviceAndContext> QnnAOTEnv::createContext(const std::string&
   // clang-format off
   {
     // FIXME(wch): we need to register our own opset of qnn.
-    // struct OpPackageInfo {
-    //   std::string path;
-    //   std::string interface_provider;
-    //   std::string target;
-    // };
-
-    // std::vector<OpPackageInfo> op_packages = {
-    //   {.path = "libQnnMllmPackageCPU.so", .interface_provider = "MllmPackageInterfaceProvider", .target = "CPU"},
-    //   {.path = "libQnnMllmPackageHTP.so", .interface_provider = "MllmPackageInterfaceProvider", .target = "HTP"},
-    // };
-
-    // for (const auto& pkg : op_packages) {
-    //   if (!qnn_htp_func_symbols_.qnn_interface_.backendRegisterOpPackage) {
-    //     MLLM_ERROR_EXIT(ExitCode::kCoreError, "qnn_htp_func_symbols_.qnn_interface_.backendRegisterOpPackage is nullptr.");
-    //   }
-    //   auto status = qnn_htp_func_symbols_.qnn_interface_.backendRegisterOpPackage(context->bk_handle_, pkg.path.c_str(), pkg.interface_provider.c_str(), pkg.target.c_str());
-    //   MLLM_RT_ASSERT_EQ(status, QNN_BACKEND_NO_ERROR);
-    //   MLLM_INFO("QNN Registered op package: {}, interface provider: {}, target: {}", pkg.path, pkg.interface_provider, pkg.target);
-    // }
   }
   // clang-format on
```
```diff
@@ -800,8 +819,11 @@ std::vector<QnnContext_CustomConfig_t> QnnAOTEnv::createContextCustomConfig(bool
 }

 QnnAOTGraph::ptr_t QnnAOTEnv::captureAOTGraph(const std::string& qnn_context_name, const std::string& g_name) {
-  // TODO
-  return nullptr;
+  MLLM_RT_ASSERT(contexts_.count(qnn_context_name) == 1);
+  auto ret = QnnAOTGraph::create(g_name, contexts_[qnn_context_name]);
+  ret->belongs_context_name_ = qnn_context_name;
+  contexts_[qnn_context_name]->graphs_.insert({g_name, ret});
+  return ret;
 }
```
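With `captureAOTGraph` now implemented, the expected call pattern from a lowering pass reads roughly as follows (a sketch assembled from the names in this diff; `"ctx0"` and `"prefill"` are hypothetical):

```cpp
// Sketch: how a pass would obtain a graph from the AOT environment.
void lowerOneGraph(QnnAOTEnv& env) {
  // The context must already exist; captureAOTGraph asserts on this.
  auto graph = env.captureAOTGraph(/*qnn_context_name=*/"ctx0", /*g_name=*/"prefill");
  // graph is now registered in the context's graphs_ map and its QNN handle
  // was created via graphCreate() in the constructor; tensors and ops are
  // then captured through captureQnnAOTNodeTensor / captureAOTNodeOp, and
  // graph->compile() finalizes it.
}
```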
```diff
 void QnnAOTEnv::captureAOTNodeOp(const std::string& qnn_context_name, const std::string& graph_name,

@@ -813,18 +835,13 @@ void QnnAOTEnv::captureAOTNodeOp(const std::string& qnn_context_name, const std:
 QnnAOTNodeTensor::ptr_t QnnAOTEnv::captureQnnAOTNodeTensor(const std::string& qnn_context_name, const std::string& graph_name,
                                                            const ir::tensor::TensorValue::ptr_t& v, bool force_static_weight) {
-  // TODO Constant value should also use Static!!! And they can be pruned
-  // TODO Constant value should also use Static!!! And they can be pruned
-  // TODO Constant value should also use Static!!! And they can be pruned
-  // TODO Constant value should also use Static!!! And they can be pruned
-  // TODO Constant value should also use Static!!! And they can be pruned
-  // TODO Constant value should also use Static!!! And they can be pruned
   auto __qnn_tensor_name = v->name();

   bool __qnn_enable_static_weight = force_static_weight;

   // Check if this value want static qnn weight. The static qnn weight will be shared through one context in diff graphs!
-  if (v->tensor_.memType() == kGlobal || (v->tensor_.memType() <= kParams_End && v->tensor_.memType() >= kParams_Start)) {
+  if (v->tensor_.memType() == kGlobal || (v->tensor_.memType() <= kParams_End && v->tensor_.memType() >= kParams_Start)
+      || v->getAttr("constant")) {
     __qnn_enable_static_weight = true;
   }
```
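Restated as a predicate, the static-weight rule after this change is: forced by the caller, global memory, a params-range memory type, or an explicit `constant` attribute. A hedged helper equivalent (`isStaticQnnWeight` is a hypothetical name; the enums and accessors come from this diff):

```cpp
// Hypothetical restatement of the static-weight check above.
static bool isStaticQnnWeight(const ir::tensor::TensorValue::ptr_t& v, bool force_static_weight) {
  const auto mt = v->tensor_.memType();
  return force_static_weight
      || mt == kGlobal
      || (mt >= kParams_Start && mt <= kParams_End)
      || static_cast<bool>(v->getAttr("constant"));
}
```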
```diff
@@ -848,11 +865,17 @@ QnnAOTNodeTensor::ptr_t QnnAOTEnv::captureQnnAOTNodeTensor(const std::string& qn
   auto ret = QnnAOTNodeTensor::create(v, __qnn_enable_static_weight);
   if (__qnn_enable_static_weight) {
     contexts_[qnn_context_name]->static_tensor_.insert({__qnn_tensor_name, ret});
+    qnn_htp_func_symbols_.qnn_interface_.tensorCreateContextTensor(contexts_[qnn_context_name]->qnn_ctx_handle_,
+                                                                   ret->getQnnTensor());
   } else {
     contexts_[qnn_context_name]->graphs_[graph_name]->all_tensors_.insert({__qnn_tensor_name, ret});
+    qnn_htp_func_symbols_.qnn_interface_.tensorCreateGraphTensor(
+        contexts_[qnn_context_name]->graphs_[graph_name]->qnn_graph_handle_, ret->getQnnTensor());
   }

   return ret;
 }

 std::shared_ptr<QnnDeviceAndContext> QnnAOTEnv::getContext(const std::string& name) { return contexts_[name]; }
```
Contributor comment on lines +879 to +880:

Using `contexts_[name]` with `operator[]` default-constructs and stores a null entry when the name is unknown, hiding lookup failures. Proposed fix using `at()` for fail-fast behavior:

```diff
-std::shared_ptr<QnnDeviceAndContext> QnnAOTEnv::getContext(const std::string& name) { return contexts_[name]; }
+std::shared_ptr<QnnDeviceAndContext> QnnAOTEnv::getContext(const std::string& name) {
+  auto it = contexts_.find(name);
+  if (it == contexts_.end()) {
+    MLLM_ERROR("QnnAOTEnv::getContext: context '{}' not found", name);
+    return nullptr;
+  }
+  return it->second;
+}
```
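For reference, the literal `at()`-based one-liner the comment's title alludes to would also fail fast, but by throwing rather than logging: `std::unordered_map::at` raises `std::out_of_range` for unknown keys, whereas the `find`-based version above keeps the error on mllm's logging path.

```cpp
// Throwing alternative to the find-based proposal above.
std::shared_ptr<QnnDeviceAndContext> QnnAOTEnv::getContext(const std::string& name) {
  return contexts_.at(name);  // throws std::out_of_range if name is unknown
}
```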
```cpp
} // namespace mllm::qnn::aot
```
Review comment:

Missing error handling for tensor creation API calls.

The `tensorCreateContextTensor` and `tensorCreateGraphTensor` return values are not checked. Other QNN API calls in this file consistently check for `QNN_SUCCESS`. If tensor creation fails silently, it could cause hard-to-debug issues downstream.

Proposed fix:

```diff
   if (__qnn_enable_static_weight) {
     contexts_[qnn_context_name]->static_tensor_.insert({__qnn_tensor_name, ret});
-    qnn_htp_func_symbols_.qnn_interface_.tensorCreateContextTensor(contexts_[qnn_context_name]->qnn_ctx_handle_,
-                                                                   ret->getQnnTensor());
+    auto status = qnn_htp_func_symbols_.qnn_interface_.tensorCreateContextTensor(
+        contexts_[qnn_context_name]->qnn_ctx_handle_, ret->getQnnTensor());
+    MLLM_RT_ASSERT_EQ(status, QNN_SUCCESS);
   } else {
     contexts_[qnn_context_name]->graphs_[graph_name]->all_tensors_.insert({__qnn_tensor_name, ret});
-    qnn_htp_func_symbols_.qnn_interface_.tensorCreateGraphTensor(
+    auto status = qnn_htp_func_symbols_.qnn_interface_.tensorCreateGraphTensor(
         contexts_[qnn_context_name]->graphs_[graph_name]->qnn_graph_handle_, ret->getQnnTensor());
+    MLLM_RT_ASSERT_EQ(status, QNN_SUCCESS);
   }
```