1 change: 1 addition & 0 deletions examples/qwen3_qnn_aot/qnn_aot_cfg.json
@@ -15,6 +15,7 @@
"split_graph": 1,
"quant_recipe": {
"llm_recipe": true,
"layers": 28,
"builtin_llm_pass": {
"model": "qwen3",
"lm_head": {
109 changes: 66 additions & 43 deletions mllm/backends/qnn/aot/QnnWrappersAPI.cpp
@@ -4,18 +4,19 @@

#include <QNN/QnnTypes.h>

#include <QNN/QnnGraph.h>
#include <QNN/QnnContext.h>
#include <QNN/HTP/QnnHtpDevice.h>
#include <QNN/HTP/QnnHtpCommon.h>
#include <QNN/HTP/QnnHtpContext.h>

#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp"
#include "mllm/core/DataTypes.hpp"
#include "mllm/utils/Common.hpp"
#include "mllm/core/DataTypes.hpp"
#include "mllm/backends/qnn/QNNTypeMacros.hpp"
#include "mllm/compile/ir/linalg/Attribute.hpp"
#include "mllm/backends/qnn/aot/QnnWrappersAPI.hpp"
#include "mllm/backends/qnn/aot/QnnTargetMachine.hpp"
#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp"

namespace mllm::qnn::aot {

@@ -139,13 +140,6 @@ Qnn_Param_t* QnnAOTParamTensor::getQnnParam() { return &qnn_param_; }
Qnn_Tensor_t* QnnAOTParamTensor::getQnnTensor() { return &qnn_param_.tensorParam; }

QnnAOTNodeTensor::QnnAOTNodeTensor(const ir::tensor::TensorValue::ptr_t& v, bool force_static_weight) {
// TODO Constant value should also use Static!!! And they can be pruned

name_ = v->name();
mllm_tensor_ = v->tensor_;
quant_spec_ = v->getAttr("quant_recipe")->cast_<ir::linalg::LinalgIRQuantizatonSpecAttr>()->spec_;
@@ -232,6 +226,7 @@ Qnn_TensorType_t QnnAOTNodeTensor::parseQnnTensorTypeFromIR(const ir::tensor::Te
// Check Attribute. The Attribute priority is higher than tensor type
if (v->getAttr("qnn_graph_outputs")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_APP_READ; }
if (v->getAttr("qnn_graph_inputs")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_APP_READWRITE; }
if (v->getAttr("constant")) { ret_qnn_tensor_type = QNN_TENSOR_TYPE_STATIC; }

return ret_qnn_tensor_type;
}
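Since these attribute checks run in order, the last matching attribute wins; a value tagged both qnn_graph_inputs and constant therefore resolves to QNN_TENSOR_TYPE_STATIC. A one-line sketch of the effective precedence (inferred from the code order, not stated in the PR):

// precedence (low → high): IR tensor kind < qnn_graph_outputs < qnn_graph_inputs < constant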
@@ -470,6 +465,17 @@ QnnAOTNodeOperation::ptr_t QnnAOTNodeOperation::setPackageName(const std::string
return shared_from_this();
}

QnnAOTGraph::QnnAOTGraph(const std::string& g_name, const std::shared_ptr<QnnDeviceAndContext>& context)
: graph_name_(g_name), qnn_context_(context) {
belongs_context_name_ = context->name_;

auto env = AOTCompileContext::getInstance().getEnv();
auto qnn_interface = env->getFuncSymbol().qnn_interface_;

auto ok = qnn_interface.graphCreate(context->qnn_ctx_handle_, g_name.c_str(), nullptr /*graph_config*/, &qnn_graph_handle_);
MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS);
}
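For orientation, a minimal usage sketch of the graph-capture path this constructor enables; the context name "htp_ctx_0" and graph name "decoder_0" are illustrative assumptions, not values from the PR:

// Sketch, assuming an initialized QnnAOTEnv with a context named "htp_ctx_0".
auto env = AOTCompileContext::getInstance().getEnv();
auto graph = env->captureAOTGraph("htp_ctx_0", "decoder_0");  // wraps QnnAOTGraph::create
// ... ops and tensors are then captured into the graph ...
graph->compile();  // finalizes the underlying QNN graph handle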

void QnnAOTGraph::addOperation(const QnnAOTNodeOperation::ptr_t& qnn_op) {
auto env = AOTCompileContext::getInstance().getEnv();
auto qnn_interface = env->getFuncSymbol().qnn_interface_;
@@ -481,20 +487,52 @@ void QnnAOTGraph::addOperation(const QnnAOTNodeOperation::ptr_t& qnn_op) {
qnn_op_config.v1.packageName = qnn_op->package_name_.c_str();
qnn_op_config.v1.typeName = qnn_op->op_name_.c_str();

// TODO PARAMs
// TODO Inputs
// TODO Outputs
// Params
uint32_t param_counter = 0;
size_t total_param_size = qnn_op->param_scalar.size() + qnn_op->param_tensor.size();
Qnn_Param_t* qnn_param_array = (Qnn_Param_t*)malloc(total_param_size * sizeof(Qnn_Param_t));
qnn_op->unreachable_handle_.emplace_back(qnn_param_array);
{
// Tensor Param
for (const auto& p : qnn_op->param_tensor) {
auto ok = qnn_interface.tensorCreateGraphTensor(qnn_graph_handle_, p->getQnnTensor());
MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS);
qnn_param_array[param_counter++] = *p->getQnnParam();
}
for (const auto& p : qnn_op->param_scalar) { qnn_param_array[param_counter++] = *p->getQnnParam(); }
}

// Inputs
Qnn_Tensor_t* qnn_inputs_array = (Qnn_Tensor_t*)malloc(qnn_op->inputs.size() * sizeof(Qnn_Tensor_t));
qnn_op->unreachable_handle_.emplace_back(qnn_inputs_array);
for (int i = 0; i < qnn_op->inputs.size(); ++i) { qnn_inputs_array[i] = *qnn_op->inputs[i]->getQnnTensor(); }

// TODO node validations
// Outputs
Qnn_Tensor_t* qnn_outputs_array = (Qnn_Tensor_t*)malloc(qnn_op->outputs.size() * sizeof(Qnn_Tensor_t));
qnn_op->unreachable_handle_.emplace_back(qnn_outputs_array);
for (int i = 0; i < qnn_op->outputs.size(); ++i) { qnn_outputs_array[i] = *qnn_op->outputs[i]->getQnnTensor(); }

// TODO add node to graph.
qnn_op_config.v1.params = qnn_param_array;
qnn_op_config.v1.numOfParams = total_param_size;
qnn_op_config.v1.inputTensors = qnn_inputs_array;
qnn_op_config.v1.numOfInputs = qnn_op->inputs.size();
qnn_op_config.v1.outputTensors = qnn_outputs_array;
qnn_op_config.v1.numOfOutputs = qnn_op->outputs.size();

auto ok = qnn_interface.backendValidateOpConfig(env->getContext(belongs_context_name_)->bk_handle_, qnn_op_config);
MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS);
ok = qnn_interface.graphAddNode(qnn_graph_handle_, qnn_op_config);
MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS);

op_node_.insert({qnn_op->getName(), qnn_op});
}
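The malloc'd param/input/output arrays are parked in unreachable_handle_, presumably so the memory Qnn_OpConfig_t points into stays valid for the node's lifetime. A vector-backed alternative, sketched here as an assumption rather than something proposed by the PR, would give the same stable storage without manual malloc:

// Sketch only: containers owned by the op wrapper keep the pointers stable.
struct OpConfigStorage {
  std::vector<Qnn_Param_t> params;
  std::vector<Qnn_Tensor_t> inputs;
  std::vector<Qnn_Tensor_t> outputs;
};
// qnn_op_config.v1.params = storage.params.data(); and so on.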

bool QnnAOTGraph::compile() {
if (is_compiled_) { return true; }
// TODO

auto env = AOTCompileContext::getInstance().getEnv();
auto qnn_interface = env->getFuncSymbol().qnn_interface_;
qnn_interface.graphFinalize(qnn_graph_handle_, env->getContext(belongs_context_name_)->profile_bk_handle_, nullptr);

is_compiled_ = true;
return true;
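Note that graphFinalize is the one QNN call in this function whose return value goes unchecked; a minimal sketch applying the same fail-fast pattern used elsewhere in this file:

auto ok = qnn_interface.graphFinalize(qnn_graph_handle_, env->getContext(belongs_context_name_)->profile_bk_handle_, nullptr);
MLLM_RT_ASSERT_EQ(ok, QNN_SUCCESS);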
@@ -692,25 +730,6 @@ std::shared_ptr<QnnDeviceAndContext> QnnAOTEnv::createContext(const std::string&
// clang-format off
{
// FIXME(wch): we need to register our own opset of qnn.
// struct OpPackageInfo {
// std::string path;
// std::string interface_provider;
// std::string target;
// };

// std::vector<OpPackageInfo> op_packages = {
// {.path = "libQnnMllmPackageCPU.so", .interface_provider = "MllmPackageInterfaceProvider", .target = "CPU"},
// {.path = "libQnnMllmPackageHTP.so", .interface_provider = "MllmPackageInterfaceProvider", .target = "HTP"},
// };

// for (const auto& pkg : op_packages) {
// if (!qnn_htp_func_symbols_.qnn_interface_.backendRegisterOpPackage) {
// MLLM_ERROR_EXIT(ExitCode::kCoreError, "qnn_htp_func_symbols_.qnn_interface_.backendRegisterOpPackage is nullptr.");
// }
// auto status = qnn_htp_func_symbols_.qnn_interface_.backendRegisterOpPackage(context->bk_handle_, pkg.path.c_str(), pkg.interface_provider.c_str(), pkg.target.c_str());
// MLLM_RT_ASSERT_EQ(status, QNN_BACKEND_NO_ERROR);
// MLLM_INFO("QNN Registered op package: {}, interface provider: {}, target: {}", pkg.path, pkg.interface_provider, pkg.target);
// }
}
// clang-format on

@@ -800,8 +819,11 @@ std::vector<QnnContext_CustomConfig_t> QnnAOTEnv::createContextCustomConfig(bool
}

QnnAOTGraph::ptr_t QnnAOTEnv::captureAOTGraph(const std::string& qnn_context_name, const std::string& g_name) {
// TODO
return nullptr;
MLLM_RT_ASSERT(contexts_.count(qnn_context_name) == 1);
auto ret = QnnAOTGraph::create(g_name, contexts_[qnn_context_name]);
ret->belongs_context_name_ = qnn_context_name;
contexts_[qnn_context_name]->graphs_.insert({g_name, ret});
return ret;
}

void QnnAOTEnv::captureAOTNodeOp(const std::string& qnn_context_name, const std::string& graph_name,
@@ -813,18 +835,13 @@ void QnnAOTEnv::captureAOTNodeOp(const std::string& qnn_context_name, const std:

QnnAOTNodeTensor::ptr_t QnnAOTEnv::captureQnnAOTNodeTensor(const std::string& qnn_context_name, const std::string& graph_name,
const ir::tensor::TensorValue::ptr_t& v, bool force_static_weight) {
// TODO Constant value should also use Static!!! And they can be pruned
auto __qnn_tensor_name = v->name();

bool __qnn_enable_static_weight = force_static_weight;

// Check if this value wants a static qnn weight. The static qnn weight will be shared through one context in diff graphs!
if (v->tensor_.memType() == kGlobal || (v->tensor_.memType() <= kParams_End && v->tensor_.memType() >= kParams_Start)) {
if (v->tensor_.memType() == kGlobal || (v->tensor_.memType() <= kParams_End && v->tensor_.memType() >= kParams_Start)
|| v->getAttr("constant")) {
__qnn_enable_static_weight = true;
}

@@ -848,11 +865,17 @@ QnnAOTNodeTensor::ptr_t QnnAOTEnv::captureQnnAOTNodeTensor(const std::string& qn
auto ret = QnnAOTNodeTensor::create(v, __qnn_enable_static_weight);
if (__qnn_enable_static_weight) {
contexts_[qnn_context_name]->static_tensor_.insert({__qnn_tensor_name, ret});
qnn_htp_func_symbols_.qnn_interface_.tensorCreateContextTensor(contexts_[qnn_context_name]->qnn_ctx_handle_,
ret->getQnnTensor());
} else {
contexts_[qnn_context_name]->graphs_[graph_name]->all_tensors_.insert({__qnn_tensor_name, ret});
qnn_htp_func_symbols_.qnn_interface_.tensorCreateGraphTensor(
contexts_[qnn_context_name]->graphs_[graph_name]->qnn_graph_handle_, ret->getQnnTensor());
Comment on lines +868 to +873 (Contributor):

⚠️ Potential issue | 🟠 Major

Missing error handling for tensor creation API calls.

The tensorCreateContextTensor and tensorCreateGraphTensor return values are not checked. Other QNN API calls in this file consistently check for QNN_SUCCESS. If tensor creation fails silently, it could cause hard-to-debug issues downstream.

🔎 Proposed fix
   if (__qnn_enable_static_weight) {
     contexts_[qnn_context_name]->static_tensor_.insert({__qnn_tensor_name, ret});
-    qnn_htp_func_symbols_.qnn_interface_.tensorCreateContextTensor(contexts_[qnn_context_name]->qnn_ctx_handle_,
-                                                                   ret->getQnnTensor());
+    auto status = qnn_htp_func_symbols_.qnn_interface_.tensorCreateContextTensor(
+        contexts_[qnn_context_name]->qnn_ctx_handle_, ret->getQnnTensor());
+    MLLM_RT_ASSERT_EQ(status, QNN_SUCCESS);
   } else {
     contexts_[qnn_context_name]->graphs_[graph_name]->all_tensors_.insert({__qnn_tensor_name, ret});
-    qnn_htp_func_symbols_.qnn_interface_.tensorCreateGraphTensor(
+    auto status = qnn_htp_func_symbols_.qnn_interface_.tensorCreateGraphTensor(
         contexts_[qnn_context_name]->graphs_[graph_name]->qnn_graph_handle_, ret->getQnnTensor());
+    MLLM_RT_ASSERT_EQ(status, QNN_SUCCESS);
   }
🤖 Prompt for AI Agents
In mllm/backends/qnn/aot/QnnWrappersAPI.cpp around lines 868 to 873, the calls
to qnn_interface_.tensorCreateContextTensor and tensorCreateGraphTensor are not
checking return values; capture each call's return value, compare it to
QNN_SUCCESS, and on failure log an error including the tensor name and the
context/graph name, perform any necessary cleanup (remove inserted map entry or
release allocated objects), and propagate the failure (return an error/nullptr
or throw) consistent with other QNN API error handling in this file.

}

return ret;
}
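A hedged usage sketch of the capture path (the context/graph names and the weight_value handle are illustrative): per the sharing comment above, a parameter or constant captured from two graphs of the same context should resolve to one shared static tensor:

auto w  = env->captureQnnAOTNodeTensor("htp_ctx_0", "decoder_0", weight_value);
auto w2 = env->captureQnnAOTNodeTensor("htp_ctx_0", "decoder_1", weight_value);  // presumably reuses the cached static_tensor_ entry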

std::shared_ptr<QnnDeviceAndContext> QnnAOTEnv::getContext(const std::string& name) { return contexts_[name]; }

Comment on lines +879 to +880 (Contributor):

⚠️ Potential issue | 🟡 Minor

getContext returns null shared_ptr for missing keys without indication.

Using operator[] on an unordered_map inserts a default-constructed value (null shared_ptr) if the key doesn't exist. This could lead to null pointer dereferences if callers don't check. Consider using at() (throws on missing key) or find() with explicit null handling.

🔎 Proposed fix using find() with explicit null handling
-std::shared_ptr<QnnDeviceAndContext> QnnAOTEnv::getContext(const std::string& name) { return contexts_[name]; }
+std::shared_ptr<QnnDeviceAndContext> QnnAOTEnv::getContext(const std::string& name) {
+  auto it = contexts_.find(name);
+  if (it == contexts_.end()) {
+    MLLM_ERROR("QnnAOTEnv::getContext: context '{}' not found", name);
+    return nullptr;
+  }
+  return it->second;
+}
🤖 Prompt for AI Agents
In mllm/backends/qnn/aot/QnnWrappersAPI.cpp around lines 879-880, replace the
use of operator[] which inserts a default null shared_ptr for missing keys with
a fail-fast access: use contexts_.at(name) so a missing key throws
std::out_of_range (or, if you prefer explicit handling, use contexts_.find(name)
and return nullptr or throw a descriptive exception). Ensure the function either
allows the exception to propagate or throws a clear std::runtime_error with
context about the missing device name so callers are not left with a null
shared_ptr.

} // namespace mllm::qnn::aot
16 changes: 15 additions & 1 deletion mllm/backends/qnn/aot/QnnWrappersAPI.hpp
@@ -106,6 +106,8 @@ class QnnAOTNodeTensor : public std::enable_shared_from_this<QnnAOTNodeTensor> {

explicit QnnAOTNodeTensor(const ir::tensor::TensorValue::ptr_t& v, bool force_static_weight = false);

inline Qnn_Tensor_t* getQnnTensor() { return &qnn_tensor_; }

private:
Qnn_TensorType_t parseQnnTensorTypeFromIR(const ir::tensor::TensorValue::ptr_t& v);

@@ -171,10 +173,18 @@ class QnnAOTNodeOperation : public std::enable_shared_from_this<QnnAOTNodeOperat
std::vector<void*> unreachable_handle_;
};

struct QnnDeviceAndContext;
class QnnAOTGraph : public std::enable_shared_from_this<QnnAOTGraph> {
public:
using ptr_t = std::shared_ptr<QnnAOTGraph>;

QnnAOTGraph(const std::string& g_name, const std::shared_ptr<QnnDeviceAndContext>& context);

static inline ptr_t create(const std::string& g_name, const std::shared_ptr<QnnDeviceAndContext>& context) {
auto ret = std::make_shared<QnnAOTGraph>(g_name, context);
return ret;
}

void addOperation(const QnnAOTNodeOperation::ptr_t& qnn_op);

bool compile();
@@ -183,13 +193,15 @@ class QnnAOTGraph : public std::enable_shared_from_this<QnnAOTGraph> {
std::unordered_map<std::string, QnnAOTNodeOperation::ptr_t> op_node_;
std::unordered_map<std::string, QnnAOTNodeTensor::ptr_t> all_tensors_;

private:
std::string graph_name_;
std::string belongs_context_name_;
Qnn_GraphHandle_t qnn_graph_handle_ = nullptr;
std::shared_ptr<QnnDeviceAndContext> qnn_context_ = nullptr;
};

struct QnnDeviceAndContext {
using ptr_t = std::shared_ptr<QnnDeviceAndContext>;

std::string name_;
Qnn_LogHandle_t log_ = nullptr;
Qnn_BackendHandle_t bk_handle_ = nullptr;
@@ -283,6 +295,8 @@ class QnnAOTEnv {

inline QnnFuncSymbols& getFuncSymbol() { return qnn_htp_func_symbols_; }

std::shared_ptr<QnnDeviceAndContext> getContext(const std::string& name);

private:
void _setup(const std::string& path = "");

5 changes: 5 additions & 0 deletions mllm/backends/qnn/aot/passes/AOTPipeline.cpp
@@ -1,10 +1,12 @@
#include "mllm/backends/qnn/aot/passes/AOTPipeline.hpp"
#include "mllm/backends/qnn/aot/passes/AOTCompileContext.hpp"
#include "mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.hpp"
#include "mllm/backends/qnn/aot/passes/LLMQuantRecipePass.hpp"
#include "mllm/backends/qnn/aot/passes/MarkQnnGraphPass.hpp"
#include "mllm/backends/qnn/aot/passes/MarkTensorIO.hpp"
#include "mllm/backends/qnn/aot/passes/MergeLLMHeadIntoMainGraphPass.hpp"
#include "mllm/backends/qnn/aot/passes/OpNamingPass.hpp"
#include "mllm/backends/qnn/aot/passes/SplitLLMGraphPass.hpp"

namespace mllm::qnn::aot {
std::vector<std::shared_ptr<ir::Pass>> createQnnAOTLoweringPipeline(QnnAOTEnv* env, const std::string& config_path) {
@@ -20,6 +22,9 @@ std::vector<std::shared_ptr<ir::Pass>> createQnnAOTLoweringPipeline(QnnAOTEnv* e
ret.emplace_back(createOpNamingPass());
ret.emplace_back(createMergeLLMHeadIntoMainGraphPass());
ret.emplace_back(createLLMQuantRecipePass());
ret.emplace_back(createSplitLLMGraphPass());
ret.emplace_back(createMarkTensorIOPass());
ret.emplace_back(createLLM2QnnLoweringPass());
} else {
MLLM_WARN("This pass currently only supports LLM applications. Please ensure your config contains 'quant_recipe.llm_recipe "
"= true'.");