From a37d02276dc188656121fc12225539c6607bc483 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Wed, 28 Jan 2026 09:59:40 +0000 Subject: [PATCH 1/8] feat(qnn): Add Qwen2 AOT support and enhance performance logging in QNN runtime. Update default log level and improve token generation timing metrics for better performance analysis. --- examples/CMakeLists.txt | 1 + examples/qwen2_qnn_aot/CMakeLists.txt | 10 + examples/qwen2_qnn_aot/aot_run.cpp | 58 ++ examples/qwen2_qnn_aot/compile.cpp | 159 ++++ examples/qwen2_qnn_aot/config_1.5B.json | 32 + examples/qwen2_qnn_aot/config_3B.json | 31 + examples/qwen2_qnn_aot/config_7B.json | 21 + .../qwen2_qnn_aot/modeling_qwen2_qnn_aot.hpp | 508 +++++++++++ examples/qwen2_qnn_aot/qnn_aot_cfg_1.5B.json | 51 ++ examples/qwen2_qnn_aot/qnn_aot_cfg_3B.json | 51 ++ examples/qwen2_qnn_aot/qnn_aot_cfg_7B.json | 51 ++ mllm/backends/qnn/QNNBackend.cpp | 2 +- mllm/backends/qnn/aot_rt/PromptProcessor.cpp | 2 + mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp | 56 +- mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp | 3 +- mllm/backends/qnn/aot_rt/TokenGenerator.cpp | 4 +- mllm/backends/qnn/aot_rt/TokenGenerator.hpp | 2 +- .../transformers/qwen2/modeling_qwen2.py | 804 ++++++++++++++++++ .../qualcomm/transformers/qwen2/runner.py | 352 ++++++++ .../qualcomm/transformers/qwen2/train.py | 56 ++ 20 files changed, 2246 insertions(+), 8 deletions(-) create mode 100644 examples/qwen2_qnn_aot/CMakeLists.txt create mode 100644 examples/qwen2_qnn_aot/aot_run.cpp create mode 100644 examples/qwen2_qnn_aot/compile.cpp create mode 100644 examples/qwen2_qnn_aot/config_1.5B.json create mode 100644 examples/qwen2_qnn_aot/config_3B.json create mode 100644 examples/qwen2_qnn_aot/config_7B.json create mode 100644 examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot.hpp create mode 100644 examples/qwen2_qnn_aot/qnn_aot_cfg_1.5B.json create mode 100644 examples/qwen2_qnn_aot/qnn_aot_cfg_3B.json create mode 100644 examples/qwen2_qnn_aot/qnn_aot_cfg_7B.json create mode 100644 pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py create mode 100644 pymllm/backends/qualcomm/transformers/qwen2/runner.py create mode 100644 pymllm/backends/qualcomm/transformers/qwen2/train.py diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 180c3cbe6..ea02dd3c6 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -19,4 +19,5 @@ endif() if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE OR MLLM_BUILD_QNN_BACKEND) add_subdirectory(qwen3_qnn_aot) + add_subdirectory(qwen2_qnn_aot) endif() diff --git a/examples/qwen2_qnn_aot/CMakeLists.txt b/examples/qwen2_qnn_aot/CMakeLists.txt new file mode 100644 index 000000000..eafbe1952 --- /dev/null +++ b/examples/qwen2_qnn_aot/CMakeLists.txt @@ -0,0 +1,10 @@ +# AOT targets run on x86 +if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) + add_executable(mllm-qwen2-aot-c compile.cpp) + target_link_libraries(mllm-qwen2-aot-c PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) + target_include_directories(mllm-qwen2-aot-c PRIVATE ${MLLM_INCLUDE_DIR}) +endif() + +add_executable(mllm-qwen2-aot-runner aot_run.cpp) +target_link_libraries(mllm-qwen2-aot-runner PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) +target_include_directories(mllm-qwen2-aot-runner PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen2_qnn_aot/aot_run.cpp b/examples/qwen2_qnn_aot/aot_run.cpp new file mode 100644 index 000000000..7c0eccc0a --- /dev/null +++ b/examples/qwen2_qnn_aot/aot_run.cpp @@ -0,0 +1,58 @@ +#include +#include +#include +#include +#include 
"mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" +#include "mllm/models/qwen3/tokenization_qwen3.hpp" + +using mllm::Argparse; +using namespace mllm::qnn::aot; // NOLINT + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model").help("Model path").def("qwen2_qnn.mllm"); + auto& tokenizer_path = Argparse::add("-t|--tokenizer").help("Tokenizer path").def("tokenizer.json"); + auto& config_path = Argparse::add("-c|--config").help("Config path").required(true); + auto& ar_len = Argparse::add("--ar_len").help("Autoregressive length (chunk size)").def(128); + + Argparse::parse(argc, argv); + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + mllm::initQnnBackend(model_path.get()); + + auto qwen2_cfg = mllm::models::qwen3::Qwen3Config(config_path.get()); + + RunnerConfig config; + config.num_layers = qwen2_cfg.num_hidden_layers; + config.num_heads = qwen2_cfg.num_attention_heads; + config.head_dim = qwen2_cfg.head_dim; + config.vocab_size = qwen2_cfg.vocab_size; + config.context_len = 1024; + config.ar_len = ar_len.get(); + + auto tokenizer = mllm::models::qwen3::Qwen3Tokenizer(tokenizer_path.get()); + + auto input_tensor = tokenizer.convertMessage({.prompt = "hello"}); + + input_tensor["sequence"] = mllm::Tensor::arange(0, 256, 1, mllm::kInt64, mllm::kCPU).view({1, -1}); + + // DBG: + mllm::print(input_tensor["sequence"].shape()); + mllm::print(input_tensor["sequence"]); + + Runner runner(config, &tokenizer); + if (!runner.load()) { + std::cerr << "Failed to load model\n"; + return 1; + } + + runner.generate(input_tensor["sequence"], 128, [](const std::string& token) { std::cout << token << std::flush; }, true); + std::cout << "\n"; + + return 0; +}); diff --git a/examples/qwen2_qnn_aot/compile.cpp b/examples/qwen2_qnn_aot/compile.cpp new file mode 100644 index 000000000..17857ac48 --- /dev/null +++ b/examples/qwen2_qnn_aot/compile.cpp @@ -0,0 +1,159 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include +#include +#include + +#include "modeling_qwen2_qnn_aot.hpp" + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); + auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); + auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + + Argparse::parse(argc, argv); + + int N = 32; + int CL = 1024; + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + if (!qnn_aot_cfg_files.isSet()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided"); + Argparse::printHelp(); + return -1; + } + + auto model_cfg = mllm::models::qwen3::Qwen3Config(model_cfg_path.get()); + auto model = mllm::models::qwen2::Qwen2ForCausalLM(model_cfg); + auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); + // Add params for causal mask + { + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + } + model.load(params); + + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + // Model length 32. + + { + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." 
+ std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + auto ir = model.trace(trace_inputs, {}); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("qwen2_qnn_aot_32.mir", [&]() { mllm::print(ir["model"]); }); + } + + // Model length 1. + { + N = 1; + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." 
+ std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + auto ir = model.trace(trace_inputs, {}); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("qwen2_qnn_aot_1.mir", [&]() { mllm::print(ir["model"]); }); + } + + qnn_aot_env.saveContext("context.0", "qwen2-lpbq.bin"); +}); diff --git a/examples/qwen2_qnn_aot/config_1.5B.json b/examples/qwen2_qnn_aot/config_1.5B.json new file mode 100644 index 000000000..e04d581bf --- /dev/null +++ b/examples/qwen2_qnn_aot/config_1.5B.json @@ -0,0 +1,32 @@ +{ + "architectures": [ + "Qwen2ForCausalLM" + ], + "bos_token_id": 151643, + "eos_token_id": 151645, + "attention_bias": true, + "attention_dropout": 0.0, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 1536, + "initializer_range": 0.02, + "intermediate_size": 4096, + "max_position_embeddings": 32768, + "max_window_layers": 0, + "model_type": "qwen2", + "num_attention_heads": 12, + "num_hidden_layers": 28, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "max_cache_length": 2048, + "linear_impl_type": "QNN_LPBQ_w4a16o16_G32" +} diff --git a/examples/qwen2_qnn_aot/config_3B.json b/examples/qwen2_qnn_aot/config_3B.json new file mode 100644 index 000000000..e532f0d23 --- /dev/null +++ b/examples/qwen2_qnn_aot/config_3B.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "Qwen2ForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "attention_bias": true, + "hidden_act": "silu", + "hidden_size": 2048, + "head_dim": 128, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 32768, + "max_window_layers": 70, + "model_type": "qwen2", + "num_attention_heads": 16, + "num_hidden_layers": 36, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "transformers_version": "4.43.1", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "max_cache_length": 2048, + "linear_impl_type": "QNN_LPBQ_w4a16o16_G32" +} diff --git a/examples/qwen2_qnn_aot/config_7B.json b/examples/qwen2_qnn_aot/config_7B.json new file mode 100644 index 000000000..8673b3102 --- /dev/null +++ b/examples/qwen2_qnn_aot/config_7B.json @@ -0,0 +1,21 @@ +{ + "architectures": [ + "Qwen2ForCausalLM" + ], + "bos_token_id": 151643, + "eos_token_id": 151645, + "attention_bias": true, + "hidden_size": 3584, + "head_dim": 128, + "intermediate_size": 9728, + "num_attention_heads": 28, + "num_key_value_heads": 4, + "num_hidden_layers": 32, + "max_position_embeddings": 32768, + "rms_norm_eps": 1e-06, + "vocab_size": 151936, + "max_cache_length": 2048, + "rope_theta": 1000000.0, + "tie_word_embeddings": true, + "linear_impl_type": "QNN_LPBQ_w4a16o16_G32" +} diff --git a/examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot.hpp b/examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot.hpp new file mode 100644 index 000000000..26d57e676 --- /dev/null +++ b/examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot.hpp @@ -0,0 +1,508 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/compile/ir/Trace.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" + +namespace mllm::models::qwen2 { + +namespace ptq { + +Tensor QDQ_CONSTANT(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = qdq_name_in_pytorch + ".scale"; + std::string zp_name = qdq_name_in_pytorch + ".zero_point"; + switch (in.dtype()) { + case kFloat32: + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + return in; +} + +Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + std::string zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + if (m->getModuleName().empty()) { + scale_name = qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = qdq_name_in_pytorch + ".fake_quant.zero_point"; + } else { + scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + } + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + // For Constant! + case kFloat32: { + MLLM_RT_ASSERT_EQ(in.rank(), 1); + MLLM_RT_ASSERT_EQ(in.size(-1), 1); + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_KV(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + // The inputs is int8 sym. which means zero_point should be changed. + switch (in.dtype()) { + case kUInt8PerTensorSym: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + MLLM_RT_ASSERT_EQ(zp.item(), 0); + + // Is 128! not 127! + auto new_zp = Tensor::constant(128, kInt32).setName(zp_name).setMemType(kParamsNormal); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", new_zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." 
+ qdq_name_in_pytorch + ".fake_quant.zero_point"; + + (void)in.__unsafeSetDType(kUInt16PerTensorAsy); + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +} // namespace ptq + +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + +using vi32 = std::vector; +#define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32 + +// Using Conv2D to replace Linear. +// Conv2D Filter Weight is [1, 1, In, Out] +// Conv2D Activation is [N, H=1, W=Seq, In] + +class Qwen2MLP final : public nn::Module { + nn::Conv2D gate_proj_; + nn::Conv2D up_proj_; + nn::Conv2D down_proj_; + nn::SiLU silu_; + int hidden_size_; + int intermediate_size_; + + public: + Qwen2MLP() = default; + Qwen2MLP(const std::string& name, const qwen3::Qwen3Config& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, CONV2D_PROPERTY); + hidden_size_ = cfg.hidden_size; + intermediate_size_ = cfg.intermediate_size; + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = ptq::QDQ(this, x, "up_proj_input_qdq"); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto up_result = ptq::QDQ(this, up_proj_(x), "up_proj_output_qdq").view({1, -1, intermediate_size_}, true); + auto gate_result = ptq::QDQ(this, gate_proj_(x), "gate_proj_output_qdq").view({1, -1, intermediate_size_}, true); + + // SiLU + gate_result = ptq::QDQ(this, (gate_result * ptq::QDQ(this, nn::functional::sigmoid(gate_result), "sigmoid_output_qdq")), + "act_output_qdq"); + + auto o = ptq::QDQ(this, gate_result * up_result, "down_proj_input_qdq"); + o = o.view({1, 1, -1, intermediate_size_}, true); + o = down_proj_(o).view({1, -1, hidden_size_}, true); + + return {o}; + } +}; + +class Qwen2Attention final : public nn::Module { + nn::Conv2D q_proj_; + nn::Conv2D k_proj_; + nn::Conv2D v_proj_; + nn::Conv2D o_proj_; + // Qwen2 does NOT have RMSNorm after QK projection (unlike Qwen3) + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + float scale_; + + public: + Qwen2Attention() = default; + + Qwen2Attention(const std::string& name, const qwen3::Qwen3Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + scale_ = (1.f / sqrtf((float)head_dim_)); + + // clang-format off + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, 
CONV2D_PROPERTY); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, CONV2D_PROPERTY); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, CONV2D_PROPERTY); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, CONV2D_PROPERTY); + // clang-format on + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + // [B, S, D] + hidden_states = ptq::QDQ(this, hidden_states, "q_proj_input_qdq"); + hidden_states = hidden_states.view({1, 1, -1, hidden_size_}, true); + + // [B, S, H * D] + auto query_states = q_proj_(hidden_states); + auto key_states = k_proj_(hidden_states); + auto value_states = v_proj_(hidden_states); + + query_states = ptq::QDQ(this, query_states, "q_proj_output_qdq"); + key_states = ptq::QDQ(this, key_states, "k_proj_output_qdq"); + + // [B, H, S, D] + query_states = query_states.view({1, -1, num_attention_heads_, head_dim_}, /*ssa=*/true).transpose(1, 2); + key_states = key_states.view({1, -1, num_key_value_heads_, head_dim_}, /*ssa=*/true).transpose(1, 2); + value_states = value_states.view({1, -1, num_key_value_heads_, head_dim_}, /*ssa=*/true).transpose(1, 2); + + // Qwen2 does NOT have RMSNorm here (unlike Qwen3) + // Directly apply RoPE + + // [B, H, S, D] + auto cos = llm_embedding_cos.unsqueeze(1, true); + auto sin = llm_embedding_sin.unsqueeze(1, true); + query_states = + ptq::QDQ(this, + ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(query_states, this, "q_rope_neg_half_qdq") * sin, "q_rope_mul_1_output_qdq"), + "q_rope_add_0_output_qdq"); + key_states = + ptq::QDQ(this, + ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(key_states, this, "k_rope_neg_half_qdq") * sin, "k_rope_mul_1_output_qdq"), + "k_rope_add_0_output_qdq"); + + // De-quantization and quantization again + key_states = key_states.to(kFloat32); + key_states = key_states.to(kUInt8PerTensorSym); + key_states = ptq::QDQ_KV(this, key_states, "k_cast_to_int8_qdq"); + + // [B, H, D, S] + key_states = key_states.transpose(2, 3); + + // Handle KV Cache + value_states = ptq::QDQ(this, value_states, "v_cast_to_int16_qdq"); + value_states = value_states.to(kFloat32); + value_states = value_states.to(kUInt8PerTensorSym); + value_states = ptq::QDQ_KV(this, value_states, "v_cast_to_int8_qdq"); + + auto kh = nn::functional::concat({past_key, key_states}, -1); // [B, H, D, S] + auto vh = nn::functional::concat({past_value, value_states}, 2); // [B, H, S, D] + + // Repeat + kh = kh.repeat(num_key_value_groups_, 1); + vh = vh.repeat(num_key_value_groups_, 1); + + // Attn + auto attn = ptq::QDQ(this, nn::functional::matmul(query_states, kh), "qk_matmul_output_qdq"); + auto scale = Tensor::constant(scale_, kFloat32); + scale = ptq::QDQ(this, scale, "scaling_qdq"); + attn = ptq::QDQ(this, attn.mulConstant(scale), "mul_0_output_qdq"); + + // Masked Softmax + auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq"); + auto minus_value = Tensor::constant(-20, kFloat32); + minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq"); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq"); + auto zero_constant = Tensor::constant(0.f, kFloat32); 
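+    // Quantization-friendly masking: instead of adding -inf at masked positions,
+    // the where() below substitutes (row minimum - 20) computed above, so the
+    // attention logits stay representable under the 16-bit activation quantization
+    // rather than blowing up the scale with a huge negative constant. A causal_mask
+    // value of 0 marks positions that are kept.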
+ zero_constant = ptq::QDQ_CONSTANT(this, zero_constant, "constant_zero"); + attn = nn::functional::where(causal_mask.equalConstant(zero_constant), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq"); + attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq"); + auto y = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq"); + y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); + y = o_proj_(y).view({1, -1, hidden_size_}, true); + + return {y, key_states, value_states}; + } + + int layer_idx_; +}; + +class Qwen2Decoder final : public nn::Module { + public: + int layer_idx_; + Qwen2Attention self_attn_; + Qwen2MLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen2Decoder() = default; + + Qwen2Decoder(const std::string& name, const qwen3::Qwen3Config& cfg, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + auto hidden_states = inputs[0]; + if (layer_idx_ != 0) { hidden_states = ptq::QDQ(this, hidden_states, "input_layernorm_input_qdq"); } + auto residual = hidden_states; + hidden_states = input_layer_norm_(hidden_states); + auto _ = self_attn_(hidden_states, llm_embedding_sin, llm_embedding_cos, causal_mask, past_key, past_value); + hidden_states = _[0]; + hidden_states = ptq::QDQ(this, residual + ptq::QDQ(this, hidden_states, "add_0_lhs_input_qdq"), "add_0_output_qdq"); + residual = hidden_states; + hidden_states = post_attention_layer_norm_(hidden_states); + hidden_states = mlp_(hidden_states)[0]; + hidden_states = residual + ptq::QDQ(this, hidden_states, "add_1_lhs_input_qdq"); + return {hidden_states, _[1], _[2]}; + } +}; + +class Qwen2Text final : public nn::Module { + nn::ModuleListWithIdx decode_blocks_; + nn::RMSNorm norm_; + nn::Embedding embedding_; + nn::Param rope_sin_; + nn::Param rope_cos_; + int32_t num_hidden_layers_; + int32_t hidden_size_; + + public: + Qwen2Text() = default; + + Qwen2Text(const std::string& name, const qwen3::Qwen3Config& cfg) : nn::Module(name) { + num_hidden_layers_ = cfg.num_hidden_layers; + hidden_size_ = cfg.hidden_size; + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + rope_sin_ = reg("mllm_max_sin_embedding", "model.mllm_max_sin_embedding"); + rope_cos_ = reg("mllm_max_cos_embedding", "model.mllm_max_cos_embedding"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + const auto& position_ids = inputs[1]; + auto causal_mask = inputs[2]; + + // clang-format off + auto llm_embedding_sin = nn::functional::gather(ptq::QDQ_ROPE(this, rope_sin_(), "sin_embedding_input_qdq"), 1, position_ids); + auto llm_embedding_cos = nn::functional::gather(ptq::QDQ_ROPE(this, rope_cos_(), 
"cos_embedding_input_qdq"), 1, position_ids); + // clang-format on + + std::vector keys; + std::vector values; + for (auto [index, block] : enumerate(blocks)) { + auto pk = inputs[3 + index]; + auto pv = inputs[3 + index + num_hidden_layers_]; + auto _ = block(x, llm_embedding_sin, llm_embedding_cos, causal_mask, pk, pv); + x = _[0]; + keys.push_back(_[1]); + values.push_back(_[2]); + } + + x = norm_(ptq::QDQ(this, x, "norm_input_qdq")); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto ret = std::vector{x}; + for (const auto& item : keys) { ret.push_back(item); } + for (const auto& item : values) { ret.push_back(item); } + + return ret; + } +}; + +class Qwen2ForCausalLM : public ARGeneration, public nn::Module { + public: + explicit Qwen2ForCausalLM(const qwen3::Qwen3Config& cfg) : cfg(cfg) { + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY); + } + } + + IROutput trace(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + // Things we need to return + ir::IRContext::ptr_t llm_ir = nullptr; + + auto sequence = input.at("sequence"); + auto causal_mask = input.at("causal_mask"); + + std::vector kv_caches; + + // Append Key + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + if (input.count(past_key_name)) { + kv_caches.push_back(input.at(past_key_name)); + } else { + // If KV cache doesn't exist, we need to handle this case + // For now, we'll create empty tensors or handle it appropriately + // This might need adjustment based on your initialization logic + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Append Value + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_value_name = "past_value_" + std::to_string(i); + if (input.count(past_value_name)) { + kv_caches.push_back(input.at(past_value_name)); + } else { + // If KV cache doesn't exist, we need to handle this case + // For now, we'll create empty tensors or handle it appropriately + // This might need adjustment based on your initialization logic + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({1}, kInt32, kCPU).alloc(); + *position_ids.offsettedPtr({0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({seq_len}, kInt32, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[s] = s; } + } + + ir::lowlevel::traceStart(); + + // Build inputs for llm: sequence, llm_embedding_sin, llm_embedding_cos, causal_mask, then all KV caches + std::vector llm_inputs = {sequence, position_ids, causal_mask}; + llm_inputs.insert(llm_inputs.end(), kv_caches.begin(), 
kv_caches.end()); + + sequence = llm(llm_inputs)[0]; + sequence = lm_head_(ptq::QDQ(this, sequence, "lm_head_input_qdq")); + sequence = ptq::QDQ(this, sequence, "lm_head_output_qdq"); + ir::lowlevel::traceComment(" ╔═════╗ "); + ir::lowlevel::traceComment(" ║ o o ║ "); + ir::lowlevel::traceComment(" ║ ▽ ║ "); + ir::lowlevel::traceComment(" ╚═════╝ "); + ir::lowlevel::traceComment(" ║ ║ "); + ir::lowlevel::traceComment(" ╱╩╦╦╩╲ "); + llm_ir = ir::lowlevel::traceStop(); + + return {{"model", llm_ir}}; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } + + private: + const qwen3::Qwen3Config& cfg; + Qwen2Text llm; + nn::Conv2D lm_head_; + bool tie_word_embeddings_; +}; + +} // namespace mllm::models::qwen2 diff --git a/examples/qwen2_qnn_aot/qnn_aot_cfg_1.5B.json b/examples/qwen2_qnn_aot/qnn_aot_cfg_1.5B.json new file mode 100644 index 000000000..3ddf11c9e --- /dev/null +++ b/examples/qwen2_qnn_aot/qnn_aot_cfg_1.5B.json @@ -0,0 +1,51 @@ +{ + "target_machine": { + "htp_arch": "V75", + "htp_chipset": "SM8650", + "htp_try_best_performance": "HtpBurst", + "htp_security_pd_session": "HtpSignedPd", + "htp_vtcm_capability_in_mb": 8 + }, + "graph_on_qnn": [ + "model" + ], + "op_on_qnn": [ + "lm_head" + ], + "split_graph": 1, + "quant_recipe": { + "llm_recipe": true, + "layers": 28, + "builtin_llm_pass": { + "model": "qwen2", + "lm_head": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "linear": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "kv_cache": { + "key": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + }, + "value": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + } + } + } + } +} diff --git a/examples/qwen2_qnn_aot/qnn_aot_cfg_3B.json b/examples/qwen2_qnn_aot/qnn_aot_cfg_3B.json new file mode 100644 index 000000000..d765567f3 --- /dev/null +++ b/examples/qwen2_qnn_aot/qnn_aot_cfg_3B.json @@ -0,0 +1,51 @@ +{ + "target_machine": { + "htp_arch": "V75", + "htp_chipset": "SM8650", + "htp_try_best_performance": "HtpBurst", + "htp_security_pd_session": "HtpSignedPd", + "htp_vtcm_capability_in_mb": 8 + }, + "graph_on_qnn": [ + "model" + ], + "op_on_qnn": [ + "lm_head" + ], + "split_graph": 1, + "quant_recipe": { + "llm_recipe": true, + "layers": 36, + "builtin_llm_pass": { + "model": "qwen2", + "lm_head": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "linear": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "kv_cache": { + "key": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + }, + "value": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + } + } + } + } +} diff --git a/examples/qwen2_qnn_aot/qnn_aot_cfg_7B.json b/examples/qwen2_qnn_aot/qnn_aot_cfg_7B.json new file mode 100644 index 000000000..85f66abd8 --- /dev/null +++ b/examples/qwen2_qnn_aot/qnn_aot_cfg_7B.json @@ -0,0 +1,51 @@ +{ + "target_machine": { + "htp_arch": "V79", + "htp_chipset": "SM8750", + "htp_try_best_performance": "HtpBurst", + "htp_security_pd_session": "HtpSignedPd", + "htp_vtcm_capability_in_mb": 8 + }, + "graph_on_qnn": [ + "model" + ], + "op_on_qnn": [ + "lm_head" + ], + "split_graph": 1, + "quant_recipe": { + "llm_recipe": true, + "layers": 32, + "builtin_llm_pass": { + "model": "qwen2", + "lm_head": { + "fallback": { + "method": "LPBQ", + 
"sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "linear": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "kv_cache": { + "key": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + }, + "value": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + } + } + } + } +} diff --git a/mllm/backends/qnn/QNNBackend.cpp b/mllm/backends/qnn/QNNBackend.cpp index 3fdd5dbd0..f0cc0a952 100644 --- a/mllm/backends/qnn/QNNBackend.cpp +++ b/mllm/backends/qnn/QNNBackend.cpp @@ -29,7 +29,7 @@ QNNBackend::QNNBackend() : Backend(kQNN, createQNNAllocator()) { QNNViewOpFactory, QNNRMSNormOpFactory, QNNTransposeOpFactory, QNNX2XOpFactory, QNNCastTypeOpFactory, QNNParamOpFactory, QNNSiLUOpFactory, QNNEmbeddingOpFactory>(); - QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_VERBOSE; // default QNN log level + QnnLog_Level_t qnnLogLevel = QNN_LOG_LEVEL_ERROR; // default QNN log level profilingLevel_ = ProfilingLevel::OFF; debug_ = false; // when set true, NATIVE tensor will be regared as APP_READ tensor diff --git a/mllm/backends/qnn/aot_rt/PromptProcessor.cpp b/mllm/backends/qnn/aot_rt/PromptProcessor.cpp index bee78b1bc..f9eae7157 100644 --- a/mllm/backends/qnn/aot_rt/PromptProcessor.cpp +++ b/mllm/backends/qnn/aot_rt/PromptProcessor.cpp @@ -126,6 +126,8 @@ int64_t PromptProcessor::prefill(const std::vector& prompt_tokens, i module_->setOutputTensors(output_tensors_); + MLLM_INFO("num_tokens: {}", num_tokens); + while (processed_tokens < num_tokens) { int64_t chunk_size = std::min((int64_t)config_.ar_len, num_tokens - processed_tokens); diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp index 3bbd077dd..ae1fafa29 100644 --- a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp +++ b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp @@ -3,6 +3,7 @@ #include #include +#include #include "mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp" #include "mllm/core/DataTypes.hpp" @@ -10,6 +11,8 @@ #include "mllm/preprocessor/tokenizers/Unicode.hpp" #include "mllm/utils/Common.hpp" #include "mllm/utils/Log.hpp" +#include +#include namespace mllm::qnn::aot { Runner::Runner(const RunnerConfig& config, mllm::preprocessor::AutoTokenizer* tokenizer) @@ -56,16 +59,25 @@ bool Runner::load() { } void Runner::generate(const Tensor& prompt_tokens, int32_t seq_len, - const std::function& token_callback) { + const std::function& token_callback, bool perf) { MLLM_RT_ASSERT(prompt_tokens.rank() == 2 && prompt_tokens.dtype() == kInt64); int64_t start_pos = 0; std::vector prompt_tokens_i64; prompt_tokens_i64.reserve(prompt_tokens.shape()[1]); + for (int i = 0; i < prompt_tokens.shape()[1]; i++) { prompt_tokens_i64.push_back(prompt_tokens.ptr()[i]); } + // Measure prefill time + std::chrono::high_resolution_clock::time_point prefill_start, prefill_end; + if (perf) { prefill_start = std::chrono::high_resolution_clock::now(); } + + int64_t prefill_token_count = prompt_tokens_i64.size(); int64_t next_token = prompt_processor_->prefill(prompt_tokens_i64, start_pos); + prompt_tokens_i64.push_back(next_token); + + if (perf) { prefill_end = std::chrono::high_resolution_clock::now(); } if (token_callback) { std::wstring wstr = tokenizer_->detokenize(next_token); @@ -73,9 +85,47 @@ void Runner::generate(const Tensor& prompt_tokens, int32_t seq_len, token_callback(str); } - // int64_t cur_pos = prompt_tokens.size(-1); + int64_t cur_pos = prompt_tokens.size(-1); + + // Measure decode time + 
std::chrono::high_resolution_clock::time_point decode_start, decode_end; + if (perf) { decode_start = std::chrono::high_resolution_clock::now(); } + + int64_t generated_count = token_generator_->generate(prompt_tokens_i64, cur_pos, seq_len, token_callback, false); + + if (perf) { + decode_end = std::chrono::high_resolution_clock::now(); + + // Calculate durations in microseconds + auto prefill_duration = std::chrono::duration_cast(prefill_end - prefill_start).count(); + auto decode_duration = std::chrono::duration_cast(decode_end - decode_start).count(); - // token_generator_->generate(prompt_tokens, cur_pos, seq_len, token_callback, false); + // Calculate TPS + double prefill_tps = 0.0; + double decode_tps = 0.0; + + if (prefill_duration > 0) { prefill_tps = (double)prefill_token_count / (prefill_duration / 1000000.0); } + + if (decode_duration > 0 && generated_count > 0) { decode_tps = (double)generated_count / (decode_duration / 1000000.0); } + + // Print performance summary + fmt::print(fg(fmt::color::cyan), "\n{:=^50}\n", " Performance Summary "); + fmt::print(fg(fmt::color::white), "{:<20}: ", "Prefill time"); + fmt::print(fg(fmt::color::yellow), "{:>10.2f} μs", (double)prefill_duration); + if (prefill_tps > 0) { fmt::print(fg(fmt::color::white), " ({:>6.2f} tokens/s)", prefill_tps); } + fmt::print("\n"); + + fmt::print(fg(fmt::color::white), "{:<20}: ", "Decode time"); + fmt::print(fg(fmt::color::yellow), "{:>10.2f} μs", (double)decode_duration); + if (decode_tps > 0) { fmt::print(fg(fmt::color::white), " ({:>6.2f} tokens/s)", decode_tps); } + fmt::print("\n"); + + fmt::print(fg(fmt::color::white), "{:<20}: ", "Prefill tokens"); + fmt::print(fg(fmt::color::green), "{:>10}\n", prefill_token_count); + + fmt::print(fg(fmt::color::white), "{:<20}: ", "Decode tokens"); + fmt::print(fg(fmt::color::green), "{:>10}\n", generated_count); + } } } // namespace mllm::qnn::aot diff --git a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp index 12a13e675..51ce86c70 100644 --- a/mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp +++ b/mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp @@ -22,7 +22,8 @@ class Runner { ~Runner() = default; bool load(); - void generate(const Tensor& prompt_tokens, int32_t seq_len, const std::function& token_callback); + void generate(const Tensor& prompt_tokens, int32_t seq_len, const std::function& token_callback, + bool perf = false); private: RunnerConfig config_; diff --git a/mllm/backends/qnn/aot_rt/TokenGenerator.cpp b/mllm/backends/qnn/aot_rt/TokenGenerator.cpp index d2cbbf43f..4e0884358 100644 --- a/mllm/backends/qnn/aot_rt/TokenGenerator.cpp +++ b/mllm/backends/qnn/aot_rt/TokenGenerator.cpp @@ -98,10 +98,10 @@ void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) { } template -int64_t TokenGenerator::generate(std::vector& tokens, int64_t start_pos, int32_t seq_len, +int64_t TokenGenerator::generate(std::vector& tokens, int64_t start_pos, int32_t seq_len, const std::function& token_callback, bool dump_logits) { int64_t current_pos = start_pos; - uint64_t next_token = tokens.back(); + int64_t next_token = tokens.back(); int64_t generated_count = 0; // Ensure KV cache is arranged for decode (1 token) diff --git a/mllm/backends/qnn/aot_rt/TokenGenerator.hpp b/mllm/backends/qnn/aot_rt/TokenGenerator.hpp index da50836d2..b40b9725a 100644 --- a/mllm/backends/qnn/aot_rt/TokenGenerator.hpp +++ b/mllm/backends/qnn/aot_rt/TokenGenerator.hpp @@ -25,7 +25,7 @@ class TokenGenerator { virtual const std::vector& get_all_logits(); - 
virtual int64_t generate(std::vector& tokens, int64_t start_pos, int32_t seq_len, + virtual int64_t generate(std::vector& tokens, int64_t start_pos, int32_t seq_len, const std::function& token_callback, bool dump_logits); protected: diff --git a/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py b/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py new file mode 100644 index 000000000..1e32266a5 --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py @@ -0,0 +1,804 @@ +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import check_model_inputs +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config + +# Replace linear, rms_norm with: +from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.backends.qualcomm.transformers.core.qlinear import ( + QLinearLPBQ, +) +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) +from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = QLinearLPBQ( + self.hidden_size, self.intermediate_size, bias=False, block_size=16 + ) + self.up_proj = QLinearLPBQ( + self.hidden_size, self.intermediate_size, bias=False, block_size=16 + ) + self.down_proj = QLinearLPBQ( + self.intermediate_size, self.hidden_size, bias=False, block_size=16 + ) + + # QDQ + self.up_proj_input_qdq = ActivationQDQ(bits=16) + self.up_proj_output_qdq = ActivationQDQ(bits=16) + self.gate_proj_output_qdq = ActivationQDQ(bits=16) + self.act_output_qdq = ActivationQDQ(bits=16) + self.down_proj_input_qdq = ActivationQDQ(bits=16) + # For sigmoid output: scale = 1 / (q_max - q_min + 1), zp = 0 + # For 16-bit: q_min = 0, q_max = 65535 + sigmoid_scale = 1.0 / (65535 - 0 + 1) # 1 / 65536 + self.sigmoid_output_qdq = FixedActivationQDQ( + scale=sigmoid_scale, zero_point=0, bits=16 + ) + + def forward(self, x): + x = self.up_proj_input_qdq(x) + up_result = self.up_proj_output_qdq(self.up_proj(x)) + gate_result = self.gate_proj_output_qdq(self.gate_proj(x)) + + # SiLU + gate_result = self.act_output_qdq( + gate_result * self.sigmoid_output_qdq(F.sigmoid(gate_result)) + ) + + o = 
self.down_proj_input_qdq(gate_result * up_result) + o = self.down_proj(o) + return o + + +def rotate_half( + x, x_observer, x2_neg_fake_quant: ActivationQDQ, concat_observer: ConcatObserver +): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1)) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + attention_bias = getattr(config, "attention_bias", True) + self.q_proj = QLinearLPBQ( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=attention_bias, + block_size=16, + ) + self.k_proj = QLinearLPBQ( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=attention_bias, + block_size=16, + ) + self.v_proj = QLinearLPBQ( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=attention_bias, + block_size=16, + ) + self.o_proj = QLinearLPBQ( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=False, + block_size=16, + ) + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + + # QDQ + self.q_proj_input_qdq = ActivationQDQ(bits=16) + self.k_proj_input_qdq = ActivationQDQ(bits=16) + + self.q_proj_output_qdq = ActivationQDQ(bits=16) + self.k_proj_output_qdq = ActivationQDQ(bits=16) + + self.v_proj_input_qdq = ActivationQDQ(bits=16) + self.q_rope_mul_0_output_qdq = ActivationQDQ(bits=16) + self.q_rope_mul_1_output_qdq = ActivationQDQ(bits=16) + self.q_rope_add_0_output_qdq = ActivationQDQ(bits=16) + self.k_rope_mul_0_output_qdq = ActivationQDQ(bits=16) + self.k_rope_mul_1_output_qdq = ActivationQDQ(bits=16) + self.k_rope_add_0_output_qdq = ActivationQDQ(bits=16) + + self.q_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.q_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer.add_observer( + self.k_proj_input_qdq.fake_quant.activation_post_process + ) + self.k_rope_concat_observer.add_observer( + self.k_rope_neg_half_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_proj_input_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_rope_neg_half_qdq.fake_quant.activation_post_process + ) + + # In qnn, is uint8 sym. 
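+        # (The KV-cache fake-quant below is 8-bit symmetric; the C++ AOT side, see
+        # QDQ_KV in modeling_qwen2_qnn_aot.hpp, requires the exported zero point to
+        # be 0 and remaps it to 128 so the cache is stored as unsigned
+        # kUInt8PerTensorSym tensors.)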
+ self.k_cast_to_int8_qdq = ActivationQDQ( + bits=8, qscheme=torch.per_tensor_symmetric + ) + self.v_cast_to_int8_qdq = ActivationQDQ( + bits=8, qscheme=torch.per_tensor_symmetric + ) + + self.v_cast_to_int16_qdq = ActivationQDQ(bits=16) + self.qk_matmul_output_qdq = ActivationQDQ(bits=16) + self.scaling_qdq = ActivationQDQ(bits=16) + self.neg_20_qdq = ActivationQDQ(bits=16) + self.reduce_min_output_qdq = ActivationQDQ(bits=16) + self.mul_0_output_qdq = ActivationQDQ(bits=16) + self.minus_0_output_qdq = ActivationQDQ(bits=16) + self.softmax_output_qdq = ActivationQDQ(bits=16) + self.attn_value_matmul_output_qdq = ActivationQDQ(bits=16) + self.where_attn_qdq = ActivationQDQ(bits=16) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + hidden_states = self.q_proj_input_qdq(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states = self.q_proj_output_qdq(query_states) + + hidden_states_k = self.k_proj_input_qdq(hidden_states) + key_states = self.k_proj(hidden_states_k).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj_output_qdq(key_states) + + hidden_states_v = self.v_proj_input_qdq(hidden_states) + value_states = self.v_proj(hidden_states_v).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + query_states = self.q_rope_add_0_output_qdq( + self.q_rope_mul_0_output_qdq(query_states * cos) + + self.q_rope_mul_1_output_qdq( + rotate_half( + query_states, + self.q_proj_input_qdq.fake_quant.activation_post_process, + self.q_rope_neg_half_qdq, + self.q_rope_concat_observer, + ) + * sin + ) + ) + key_states = self.k_rope_add_0_output_qdq( + self.k_rope_mul_0_output_qdq(key_states * cos) + + self.k_rope_mul_1_output_qdq( + rotate_half( + key_states, + self.k_proj_input_qdq.fake_quant.activation_post_process, + self.k_rope_neg_half_qdq, + self.k_rope_concat_observer, + ) + * sin + ) + ) + + key_states = self.k_cast_to_int8_qdq(key_states) + value_states = self.v_cast_to_int8_qdq(self.v_cast_to_int16_qdq(value_states)) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = self.mul_0_output_qdq( + self.qk_matmul_output_qdq( + torch.matmul(query_states, key_states.transpose(2, 3)) + ) + * self.scaling_qdq( + torch.ones(1, dtype=value_states.dtype, device=value_states.device) + * self.scaling + ) + ) + + attn_min = self.reduce_min_output_qdq( + torch.amin(attn_weights, dim=-1, keepdim=True) + ) + attn_vv = self.minus_0_output_qdq( + attn_min + + self.neg_20_qdq( + torch.ones(1, dtype=value_states.dtype, device=value_states.device) + * (-20) + ) + ) + attn_weights = self.where_attn_qdq( + 
torch.where(attention_mask == 0, attn_weights, attn_vv) + ) + + attn_weights = self.softmax_output_qdq( + nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + ) + attn_output = self.attn_value_matmul_output_qdq( + torch.matmul(attn_weights, value_states) + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen2RMSNorm(QRMSNorm): + def __init__(self, hidden_size, eps: float = 1e-6, quant_bits=16) -> None: + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__(hidden_size, eps=eps, quant_bits=quant_bits) + + +class Qwen2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.layer_dix = layer_idx + self.hidden_size = config.hidden_size + + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, quant_bits=16 + ) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, quant_bits=16 + ) + self.attention_type = config.layer_types[layer_idx] + + # QDQ + if self.layer_dix != 0: + self.input_layernorm_input_qdq = ActivationQDQ(bits=16) + self.add_0_lhs_input_qdq = ActivationQDQ(bits=16) + self.add_0_output_qdq = ActivationQDQ(bits=16) + self.add_1_lhs_input_qdq = ActivationQDQ(bits=16) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + if self.layer_dix != 0: + hidden_states = self.input_layernorm_input_qdq(hidden_states) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.add_0_output_qdq( + residual + self.add_0_lhs_input_qdq(hidden_states) + ) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.add_1_lhs_input_qdq(hidden_states) + return hidden_states + + +@auto_docstring +class Qwen2PreTrainedModel(PreTrainedModel): + config: Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Qwen2DecoderLayer, + "attentions": Qwen2Attention, + } + + +class Qwen2RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def 
__init__(self, config: Qwen2Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Qwen2Model(Qwen2PreTrainedModel): + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = QEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, quant_bits=16 + ) + self.layers = nn.ModuleList( + [ + Qwen2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, quant_bits=16 + ) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + # Register sin and cos as buffers + self.register_buffer("mllm_max_sin_embedding", None) + self.register_buffer("mllm_max_cos_embedding", None) + self.sin_embedding_input_qdq = ActivationQDQ(bits=16) + self.cos_embedding_input_qdq = ActivationQDQ(bits=16) + self.norm_input_qdq = ActivationQDQ(bits=16) + + # Initialize weights and apply final processing + self.post_init() + + @torch.no_grad() + def convert_rope_for_deploy(self): + sin_scale = self.sin_embedding_input_qdq.fake_quant.scale + sin_zero_point = self.sin_embedding_input_qdq.fake_quant.zero_point + sin_quant_min = self.sin_embedding_input_qdq.fake_quant.quant_min + sin_quant_max = self.sin_embedding_input_qdq.fake_quant.quant_max + + cos_scale = self.cos_embedding_input_qdq.fake_quant.scale + cos_zero_point = self.cos_embedding_input_qdq.fake_quant.zero_point + cos_quant_min = self.cos_embedding_input_qdq.fake_quant.quant_min + cos_quant_max = self.cos_embedding_input_qdq.fake_quant.quant_max + + sin_int = torch.round( + self.mllm_max_sin_embedding / sin_scale + sin_zero_point + ).clamp(sin_quant_min, sin_quant_max) + self.mllm_max_sin_embedding = sin_int.to(torch.uint16) + + cos_int = torch.round( + self.mllm_max_cos_embedding / cos_scale + cos_zero_point + ).clamp(cos_quant_min, cos_quant_max) + self.mllm_max_cos_embedding = 
cos_int.to(torch.uint16) + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = ( + create_sliding_window_causal_mask(**mask_kwargs) + ) + + hidden_states = inputs_embeds + + if self.mllm_max_sin_embedding is None and self.mllm_max_cos_embedding is None: + mllm_qualcomm_max_length = kwargs.get("mllm_qualcomm_max_length", None) + assert mllm_qualcomm_max_length is not None + max_position_ids = torch.arange( + 0, + mllm_qualcomm_max_length, + dtype=position_ids.dtype, + device=position_ids.device, + ).unsqueeze(0) + self.mllm_max_cos_embedding, self.mllm_max_sin_embedding = self.rotary_emb( + hidden_states, max_position_ids + ) + self.mllm_max_cos_embedding = self.mllm_max_cos_embedding.to( + inputs_embeds.dtype + ) + self.mllm_max_sin_embedding = self.mllm_max_sin_embedding.to( + inputs_embeds.dtype + ) + self.mllm_max_cos_embedding = self.cos_embedding_input_qdq( + self.mllm_max_cos_embedding + ) + self.mllm_max_sin_embedding = self.sin_embedding_input_qdq( + self.mllm_max_sin_embedding + ) + + # create position embeddings to be shared across the decoder layers + position_embeddings = ( + self.mllm_max_cos_embedding[:, position_ids.squeeze(0), :], + self.mllm_max_sin_embedding[:, position_ids.squeeze(0), :], + ) + + # Generate causal mask based on position_ids length + # For prefill, we need a lower triangular mask + _, seq_len = input_ids.shape + if seq_len != 1: + causal_mask = 1 - torch.tril( + torch.ones(seq_len, seq_len, dtype=torch.int8, device=input_ids.device) + ) + # [1, 1, seq_len, seq_len] + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + else: + # [1, 1, seq_len, seq_len] + causal_mask = torch.zeros( + (1, 1, 1, seq_len), dtype=torch.int8, device=input_ids.device + ) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, 
+ attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(self.norm_input_qdq(hidden_states)) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +@auto_docstring +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.config = config + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = QLinearLPBQ( + config.hidden_size, config.vocab_size, bias=False, block_size=16 + ) + self.mllm_qualcomm_max_length = None + + self.lm_head_input_qdq = ActivationQDQ(bits=16) + self.lm_head_output_qdq = ActivationQDQ(bits=16) + + # Initialize weights and apply final processing + self.post_init() + + @torch.no_grad() + def copy_lm_head_weight_from_embed_tokens(self): + if self.config.tie_word_embeddings: + self.lm_head.weight.copy_(self.model.embed_tokens.weight) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + kwargs.update({"mllm_qualcomm_max_length": self.mllm_qualcomm_max_length}) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head( + self.lm_head_input_qdq(hidden_states[:, slice_indices, :]) + ) + logits = self.lm_head_output_qdq(logits) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Qwen2ForSequenceClassification( + GenericForSequenceClassification, Qwen2PreTrainedModel +): + pass + + +class Qwen2ForTokenClassification(GenericForTokenClassification, Qwen2PreTrainedModel): + pass + + +class Qwen2ForQuestionAnswering(GenericForQuestionAnswering, Qwen2PreTrainedModel): + base_model_prefix = ( + "transformer" # For BC, where `transformer` was used instead of `model` + ) + + +__all__ = [ + "Qwen2PreTrainedModel", + "Qwen2Model", + "Qwen2ForCausalLM", + "Qwen2RMSNorm", + "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", + "Qwen2ForQuestionAnswering", +] diff --git a/pymllm/backends/qualcomm/transformers/qwen2/runner.py b/pymllm/backends/qualcomm/transformers/qwen2/runner.py new file mode 100644 index 000000000..d2f5be05b --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/qwen2/runner.py @@ -0,0 +1,352 @@ +import torch +from tqdm import tqdm +from modelscope.msdatasets import MsDataset +from transformers import AutoTokenizer +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) +from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.backends.qualcomm.transformers.core.qlinear import ( + QLinearLPBQ, + QLinearW8A16_PerChannelSym, +) +from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.backends.qualcomm.transformers.qwen2.modeling_qwen2 import Qwen2ForCausalLM +from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver + + +def recompute_scale_zp(module): + """ + Callback function: Used to forcefully refresh scale and zero_point of all FakeQuantize modules after calibration. + + Problem solved: + When using ConcatObserver, min/max may be updated during forward pass, + but at the end of forward, the scale/zp stored in FakeQuantize's internal buffer are still computed from old min/max. + This function forces a calculate_qparams call to sync the latest parameters to the buffer. + + Usage: + model.apply(recompute_scale_zp) + """ + + # We mainly focus on FakeQuantize modules since they store the scale/zero_point buffers + # Note: model.apply recursively traverses all submodules, so self.fake_quant inside ActivationQDQ will also be visited + if isinstance(module, ActivationQDQ): + observer = module.fake_quant.activation_post_process + + # 2. 
Check if observer is valid and contains statistics + # We only care about MinMaxObserver or MovingAverageMinMaxObserver that have min_val/max_val + if hasattr(observer, "min_val") and hasattr(observer, "max_val"): + # 3. Check if data is initialized + # If min_val is still the initial inf, this layer hasn't processed data, skip to avoid errors + if observer.min_val.numel() == 0 or observer.max_val.numel() == 0: + return + if ( + torch.isinf(observer.min_val).any() + or torch.isinf(observer.max_val).any() + ): + return + + # 4. Recompute Scale and Zero Point + # calculate_qparams reads the current min_val/max_val from observer (may have been modified by ConcatObserver) + try: + scale, zero_point = observer.calculate_qparams() + except Exception as e: + # Some special Observers (e.g., FixedQParams) may not support recomputation or behave differently, safely skip + print(e) + return + + # 5. Force overwrite the computed results to FakeQuantize's Buffer + # Use copy_ to keep reference unchanged, ensuring the new values are used during export + if ( + hasattr(module.fake_quant, "scale") + and module.fake_quant.scale is not None + ): + # Ensure dimension match (handle per-channel vs per-tensor) + if module.fake_quant.scale.shape != scale.shape: + module.fake_quant.scale.resize_(scale.shape) + module.fake_quant.scale.copy_(scale) + # Try to get the registered name of module scale from _parameters or _buffers + for key, value in module.fake_quant.named_parameters(): + if value is module.fake_quant.scale: + print(f"{module._get_name()}.{key}: {module.scale}") + break + + if ( + hasattr(module.fake_quant, "zero_point") + and module.fake_quant.zero_point is not None + ): + if module.fake_quant.zero_point.shape != zero_point.shape: + module.fake_quant.zero_point.resize_(zero_point.shape) + module.fake_quant.zero_point.copy_(zero_point) + + +def validate_concat_observer_fn(module, results: list, name: str = ""): + """ + Callback function: Validate that all input_observers in ConcatObserver have consistent scale and zero_point. 
+ + Usage: + results = [] + for name, m in model.named_modules(): + validate_concat_observer_fn(m, results, name) + """ + if not isinstance(module, ConcatObserver): + return + + input_observers = module.input_observers + if len(input_observers) == 0: + return + + # Collect scale and zero_point from all observers + scales_zps = [] + for i, observer in enumerate(input_observers): + try: + scale, zp = observer.calculate_qparams() + scales_zps.append(f"[{i}] s={scale.item():.8f} zp={zp.item()}") + except Exception: + scales_zps.append(f"[{i}] failed") + + # Print one line: scale and zp of all inputs for each concat observer + print(f"ConcatObserver [{name}]: {' | '.join(scales_zps)}") + + # Original validation logic + if len(input_observers) <= 1: + return + + # Get scale and zero_point from the first observer as reference + first_observer = input_observers[0] + try: + ref_scale, ref_zp = first_observer.calculate_qparams() + except Exception: + return + + # Check if all other observers have the same scale and zero_point + for i, observer in enumerate(input_observers[1:], start=1): + try: + scale, zp = observer.calculate_qparams() + except Exception: + results.append(f"Failed to calculate qparams for observer[{i}]") + continue + + scale_match = torch.allclose(ref_scale, scale, rtol=1e-5, atol=1e-8) + zp_match = torch.equal(ref_zp, zp) + + if not scale_match or not zp_match: + results.append( + f"observer[{i}] mismatch: ref_scale={ref_scale.item():.8f}, " + f"scale={scale.item():.8f}, ref_zp={ref_zp.item()}, zp={zp.item()}" + ) + + +def freeze_qwen2_rmsnorm_weight(m): + if isinstance(m, QRMSNorm): + m.freeze_weight() + + +def freeze_qwen2_linear_weight(m): + if isinstance(m, QLinearLPBQ) or isinstance(m, QLinearW8A16_PerChannelSym): + m.freeze_weight() + + +def freeze_qwen2_embed_tokens_weight(m): + if isinstance(m, QEmbedding): + m.freeze_weight() + + +def disable_qdq_observer(m): + if isinstance(m, ActivationQDQ): + m.disable_observer() + + +def enable_qdq_observer(m): + if isinstance(m, ActivationQDQ): + m.enable_observer() + + +def enable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.enable_fakequant() + if isinstance(m, QLinearLPBQ): + m.enable_fakequant() + if isinstance(m, QRMSNorm): + m.enable_fakequant() + if isinstance(m, QEmbedding): + m.enable_fakequant() + + +def disable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.disable_fakequant() + if isinstance(m, QLinearLPBQ): + m.disable_fakequant() + if isinstance(m, QRMSNorm): + m.disable_fakequant() + if isinstance(m, QEmbedding): + m.disable_fakequant() + + +def convert_weight(m): + if isinstance(m, QLinearLPBQ) or isinstance(m, QLinearW8A16_PerChannelSym): + m.convert_to_conv2d_deploy_hwio() + if isinstance(m, QRMSNorm): + m.convert_to_deploy() + if isinstance(m, QEmbedding): + m.convert_to_deploy() + + +class Qwen2Quantizer: + def __init__(self, model_path: str, mllm_qualcomm_max_length=2048): + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.model = Qwen2ForCausalLM.from_pretrained( + model_path, + attn_implementation="eager", + dtype=torch.float32, + ) + self.model.cuda() + self.mllm_qualcomm_max_length = mllm_qualcomm_max_length + self.model.mllm_qualcomm_max_length = mllm_qualcomm_max_length + + if self.model.config.tie_word_embeddings: + self.model.copy_lm_head_weight_from_embed_tokens() + + # PTQ All Weights. 
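+        # Only the weight quantization parameters are frozen here; activation
+        # observers stay live and are calibrated later in calibrate(), then
+        # frozen via freeze_activation().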
+ self.model.apply(freeze_qwen2_rmsnorm_weight) + self.model.apply(freeze_qwen2_linear_weight) + self.model.apply(freeze_qwen2_embed_tokens_weight) + print("All PTQ weights preparation done.") + + def freeze_activation(self): + self.model.apply(disable_qdq_observer) + + def enable_activation_update(self): + self.model.apply(enable_qdq_observer) + + def enable_fake_quant(self): + self.model.apply(enable_fake_quant) + + def disable_fake_quant(self): + self.model.apply(disable_fake_quant) + + def compile(self): + print("Compile Start.") + self.model = torch.compile( + self.model, mode="reduce-overhead", fullgraph=False, backend="inductor" + ) + print("Compile done.") + + def infer(self, prompt: str): + messages = [{"role": "user", "content": prompt}] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) + + # conduct text completion + generated_ids = self.model.generate( + **model_inputs, + max_new_tokens=self.mllm_qualcomm_max_length + - len(model_inputs.input_ids[0]) + - 1, + do_sample=False, + temperature=None, + top_p=None, + top_k=None, + ) + output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist() + content = self.tokenizer.decode( + output_ids, skip_special_tokens=True + ).strip("\n") + + print("content:", content) + + def calibrate(self, num_samples=64, max_seq_length=512): + """ + Perform calibration using Wikipedia dataset (PTQ) + :param num_samples: Number of samples for calibration + :param max_seq_length: Maximum length for each sample (not exceeding mllm_qualcomm_max_length) + """ + print( + f"Starting calibration, samples: {num_samples}, max length: {max_seq_length}" + ) + + # 1. Enable QDQ Observer for activation values + self.enable_activation_update() + self.model.eval() + + # 2. Load Wikipedia dataset (English version example) + # Use streaming=True to download and process on the fly, without downloading the full几十G dataset + dataset = MsDataset.load( + "modelscope/wikitext", + subset_name="wikitext-103-v1", + split="train", + trust_remote_code=True, + ) + + # 3. Execute forward pass (Prefill stage) + samples_processed = 0 + + # Ensure no gradient calculation during inference + with torch.no_grad(): + pbar = tqdm(total=num_samples, desc="Calibrating") + for entry in dataset: + if samples_processed >= num_samples: + break + + if len(entry["text"].strip()) < 1024: + continue + + messages = [{"role": "user", "content": entry["text"]}] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + model_inputs = self.tokenizer( + [text], + return_tensors="pt", + max_length=max_seq_length, + truncation=True, + padding=False, + ).to(self.model.device) + + # Only need Prefill stage: directly call forward + # This will trigger observer update statistics in ActivationQDQ + self.model.generate( + **model_inputs, + max_new_tokens=1, + do_sample=False, + temperature=None, + top_p=None, + top_k=None, + ) + + samples_processed += 1 + pbar.update(1) + + # 4. 
Close Observer, freeze calibrated quantization parameters + self.freeze_activation() + print("\nCalibration completed, activation quantization parameters frozen.") + + def convert(self): + self.model.apply(convert_weight) + self.model.model.convert_rope_for_deploy() + + def recompute_scale_zp(self): + self.model.apply(recompute_scale_zp) + + def validate_concat_observer(self): + results = [] + for name, module in self.model.named_modules(): + validate_concat_observer_fn(module, results, name) + if results: + print("ConcatObserver validation FAILED:") + for msg in results: + print(f" {msg}") + raise ValueError("ConcatObserver validation FAILED") + else: + print( + "ConcatObserver validation PASSED: all observers have matching scale and zp" + ) + print("ConcatObserver validation done.", flush=True) diff --git a/pymllm/backends/qualcomm/transformers/qwen2/train.py b/pymllm/backends/qualcomm/transformers/qwen2/train.py new file mode 100644 index 000000000..fec5fdfca --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/qwen2/train.py @@ -0,0 +1,56 @@ +import os +import torch +import argparse +from safetensors.torch import save_model +from pymllm.backends.qualcomm.transformers.qwen2.runner import Qwen2Quantizer + + +def main(): + parser = argparse.ArgumentParser(description="Qwen2 Quantizer for Qualcomm backend") + parser.add_argument( + "--model_path", + type=str, + default="Qwen/Qwen2-1.5B", + help="Path to the Qwen2 model directory", + ) + parser.add_argument( + "--max_length", + type=int, + default=2048, + help="Maximum sequence length for quantization", + ) + parser.add_argument( + "--num_samples", type=int, default=128, help="Number of samples for calibration" + ) + parser.add_argument( + "--infer_text", + type=str, + default="为什么伟大不能被计划", + help="Text to run inference on", + ) + parser.add_argument( + "--output_dir", + type=str, + help="Directory to save the quantized model", + ) + + args = parser.parse_args() + + m = Qwen2Quantizer(args.model_path, mllm_qualcomm_max_length=args.max_length) + + # FIXME: Should disable or not. + m.disable_fake_quant() + m.calibrate(num_samples=args.num_samples, max_seq_length=args.max_length) + m.enable_fake_quant() + m.recompute_scale_zp() + m.validate_concat_observer() + m.infer(args.infer_text) + m.convert() + + os.makedirs(args.output_dir, exist_ok=True) + model_save_path = os.path.join(args.output_dir, "model.safetensors") + save_model(m.model, model_save_path) + + +if __name__ == "__main__": + main() From 534431807505d94bacc9a9e9bedc9d4206ffa63d Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Wed, 28 Jan 2026 12:07:32 +0000 Subject: [PATCH 2/8] feat(qwen3): Introduce Single Head Attention (SHA) optimization for QNN AOT compilation. Add new executable for SHA model and implement weight slicing utilities to enhance performance and reduce compilation time. Update CMake configuration and include necessary headers for SHA implementation. 
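The weight slicing itself is mechanical. As a rough illustration (PyTorch-style
sketch with made-up dimensions, not the mllm ParameterFile API used by
prepareParametersForSHA), a fused Q projection is cut along its output channels:

    import torch

    # Example sizes only: 16 heads of dim 128 over a hidden size of 2048.
    num_heads, head_dim, hidden_size = 16, 128, 2048
    q_proj_weight = torch.randn(num_heads * head_dim, hidden_size)

    # One [head_dim, hidden_size] weight per head, sliced on the output dim.
    per_head = [
        q_proj_weight[h * head_dim:(h + 1) * head_dim, :].contiguous()
        for h in range(num_heads)
    ]
    assert all(w.shape == (head_dim, hidden_size) for w in per_head)

K/V projections are sliced the same way (per KV head), and the LPBQ scale1/scale2
tensors are split over the matching per-head ranges.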
--- examples/qwen3_qnn_aot/CMakeLists.txt | 8 +- examples/qwen3_qnn_aot/compile_sha.cpp | 140 +++ .../modeling_qwen_qnn_aot_sha.hpp | 885 ++++++++++++++++++ 3 files changed, 1032 insertions(+), 1 deletion(-) create mode 100644 examples/qwen3_qnn_aot/compile_sha.cpp create mode 100644 examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp diff --git a/examples/qwen3_qnn_aot/CMakeLists.txt b/examples/qwen3_qnn_aot/CMakeLists.txt index 18041bdcb..3c97427e1 100644 --- a/examples/qwen3_qnn_aot/CMakeLists.txt +++ b/examples/qwen3_qnn_aot/CMakeLists.txt @@ -3,8 +3,14 @@ if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) add_executable(mllm-qwen3-aot-c compile.cpp) target_link_libraries(mllm-qwen3-aot-c PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) target_include_directories(mllm-qwen3-aot-c PRIVATE ${MLLM_INCLUDE_DIR}) + + # SHA (Single Head Attention) version - MHA to SHA optimization + # Similar to ExecuTorch's ConvertMhaToSha pass for better QNN AOT performance + add_executable(mllm-qwen3-aot-sha-c compile_sha.cpp) + target_link_libraries(mllm-qwen3-aot-sha-c PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) + target_include_directories(mllm-qwen3-aot-sha-c PRIVATE ${MLLM_INCLUDE_DIR}) endif() add_executable(mllm-qwen3-aot-runner aot_run.cpp) target_link_libraries(mllm-qwen3-aot-runner PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) -target_include_directories(mllm-qwen3-aot-runner PRIVATE ${MLLM_INCLUDE_DIR}) \ No newline at end of file +target_include_directories(mllm-qwen3-aot-runner PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen3_qnn_aot/compile_sha.cpp b/examples/qwen3_qnn_aot/compile_sha.cpp new file mode 100644 index 000000000..71b3ceb03 --- /dev/null +++ b/examples/qwen3_qnn_aot/compile_sha.cpp @@ -0,0 +1,140 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// This file demonstrates the MHA -> SHA (Multi-Head Attention to Single-Head Attention) +// optimization for QNN AOT compilation. Similar to ExecuTorch's ConvertMhaToSha pass, +// this approach splits large Q/K/V projections into per-head projections. +// +// Benefits: +// 1. Reduces QNN AOT compilation time +// 2. Improves HTP runtime performance +// 3. 
Enables better memory locality per head +// +// Usage: +// ./compile_sha -m /path/to/model.mllm -c /path/to/config.json -aot_cfg /path/to/qnn_aot_cfg.json + +#include +#include +#include +#include +#include +#include + +#include "modeling_qwen_qnn_aot_sha.hpp" + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); + auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); + auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + + Argparse::parse(argc, argv); + + constexpr int N = 32; + constexpr int CL = 1024; + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + if (!qnn_aot_cfg_files.isSet()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided"); + Argparse::printHelp(); + return -1; + } + + auto model_cfg = mllm::models::qwen3::Qwen3Config(model_cfg_path.get()); + + // Load original parameters + auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); + + // ============================================================================ + // Key Step: Prepare SHA parameters by slicing MHA weights + // ============================================================================ + // This is the critical step that transforms MHA weights into SHA weights. + // For each Q/K/V projection, we slice the weight matrix into per-head pieces. + // + // Original: q_proj.weight [num_heads * head_dim, hidden_size, 1, 1] + // SHA: q_proj.{h}.weight [head_dim, hidden_size, 1, 1] for each head h + // + mllm::print("Preparing SHA parameters (slicing MHA weights)..."); + mllm::models::qwen3::sha::prepareParametersForSHA(params, model_cfg); + mllm::print("SHA parameters prepared."); + + // Create SHA model + auto model = mllm::models::qwen3::sha::Qwen3ForCausalLM_SHA(model_cfg); + + // Add params for causal mask + { + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + } + model.load(params); + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + 
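+    // Note: the key cache is created pre-transposed as [B, H, D, CL-N] so the
+    // per-head attention can compute matmul(Q, K_cache) without a runtime
+    // transpose, while the value cache keeps the natural [B, H, CL-N, D] layout.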
trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + mllm::print("Tracing SHA model..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); + + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("qwen3_qnn_aot_sha.mir", [&]() { mllm::print(ir["model"]); }); + + qnn_aot_env.saveContext("context.0", "qwen3-1.7B-lpbq-sha.bin"); + + mllm::print("SHA compilation completed successfully!"); + mllm::print("Output files:"); + mllm::print(" - qwen3_qnn_aot_sha.mir (IR dump)"); + mllm::print(" - qwen3-1.7B-lpbq-sha.bin (QNN context)"); +}); diff --git a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp new file mode 100644 index 000000000..e7f0b4df7 --- /dev/null +++ b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp @@ -0,0 +1,885 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// This file implements the MHA -> SHA (Multi-Head Attention to Single-Head Attention) +// transformation for QNN AOT compilation, similar to ExecuTorch's ConvertMhaToSha pass. +// +// The optimization splits large Q/K/V projections into per-head projections, +// allowing QNN to optimize each head separately, reducing AOT compilation time +// and improving HTP performance. 
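+//
+// In shape terms (example numbers only): a fused q_proj weight of
+// [num_heads * head_dim, hidden_size] = [16 * 128, 2048] becomes 16 separate
+// [128, 2048] projections. GQA is preserved: Q head h attends against KV head
+// h / num_key_value_groups, the same mapping as in the fused MHA formulation.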
+ +#pragma once + +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/compile/ir/Trace.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" + +namespace mllm::models::qwen3::sha { + +namespace ptq { + +Tensor QDQ_CONSTANT(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = qdq_name_in_pytorch + ".scale"; + std::string zp_name = qdq_name_in_pytorch + ".zero_point"; + switch (in.dtype()) { + case kFloat32: + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + return in; +} + +Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + std::string zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + if (m->getModuleName().empty()) { + scale_name = qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = qdq_name_in_pytorch + ".fake_quant.zero_point"; + } else { + scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + } + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + // For Constant! + case kFloat32: { + MLLM_RT_ASSERT_EQ(in.rank(), 1); + MLLM_RT_ASSERT_EQ(in.size(-1), 1); + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_KV(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + // The inputs is int8 sym. which means zero_point should be changed. + switch (in.dtype()) { + case kUInt8PerTensorSym: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + MLLM_RT_ASSERT_EQ(zp.item(), 0); + + // Is 128! not 127! + auto new_zp = Tensor::constant(128, kInt32).setName(zp_name).setMemType(kParamsNormal); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", new_zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." 
+ qdq_name_in_pytorch + ".fake_quant.zero_point"; + + (void)in.__unsafeSetDType(kUInt16PerTensorAsy); + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +} // namespace ptq + +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + +using vi32 = std::vector; +#define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G16 + +// Using Conv2D to replace Linear. +// Conv2D Filter Weight is [1, 1, In, Out] +// Conv2D Activation is [N, H=1, W=Seq, In] + +class Qwen3MLP final : public nn::Module { + nn::Conv2D gate_proj_; + nn::Conv2D up_proj_; + nn::Conv2D down_proj_; + nn::SiLU silu_; + int hidden_size_; + int intermediate_size_; + + public: + Qwen3MLP() = default; + Qwen3MLP(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, CONV2D_PROPERTY); + hidden_size_ = cfg.hidden_size; + intermediate_size_ = cfg.intermediate_size; + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = ptq::QDQ(this, x, "up_proj_input_qdq"); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto up_result = ptq::QDQ(this, up_proj_(x), "up_proj_output_qdq").view({1, -1, intermediate_size_}, true); + auto gate_result = ptq::QDQ(this, gate_proj_(x), "gate_proj_output_qdq").view({1, -1, intermediate_size_}, true); + + // SiLU + gate_result = ptq::QDQ(this, (gate_result * ptq::QDQ(this, nn::functional::sigmoid(gate_result), "sigmoid_output_qdq")), + "act_output_qdq"); + + auto o = ptq::QDQ(this, gate_result * up_result, "down_proj_input_qdq"); + o = o.view({1, 1, -1, intermediate_size_}, true); + o = down_proj_(o).view({1, -1, hidden_size_}, true); + + return {o}; + } +}; + +// ============================================================================ +// Single Head Attention (SHA) Implementation +// ============================================================================ +// +// This class implements SHA where each attention head has its own separate +// Conv2D projection, instead of one large MHA projection that processes all +// heads at once. +// +// Benefits: +// 1. Reduces QNN AOT compilation time +// 2. Improves HTP runtime performance +// 3. 
Enables better memory locality per head +// + +class Qwen3AttentionSHA final : public nn::Module { + // Per-head Q projections: num_attention_heads Conv2D(hidden_size, head_dim) + std::vector q_projs_; + // Per-head K projections: num_key_value_heads Conv2D(hidden_size, head_dim) + std::vector k_projs_; + // Per-head V projections: num_key_value_heads Conv2D(hidden_size, head_dim) + std::vector v_projs_; + // Single O projection remains unchanged (concatenated heads -> hidden_size) + nn::Conv2D o_proj_; + + // Per-head RMSNorm for Q + std::vector rms_norm_q_; + // Per-head RMSNorm for K (shared across GQA groups) + std::vector rms_norm_k_; + + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + float scale_; + + public: + Qwen3AttentionSHA() = default; + + Qwen3AttentionSHA(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + scale_ = (1.f / sqrtf((float)head_dim_)); + + // Register per-head Q projections + for (int h = 0; h < num_attention_heads_; ++h) { + q_projs_.emplace_back(reg("q_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // Register per-head K projections + for (int h = 0; h < num_key_value_heads_; ++h) { + k_projs_.emplace_back(reg("k_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // Register per-head V projections + for (int h = 0; h < num_key_value_heads_; ++h) { + v_projs_.emplace_back(reg("v_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // O projection remains the same (combines all heads) + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, CONV2D_PROPERTY); + + // Per-head Q RMSNorm + for (int h = 0; h < num_attention_heads_; ++h) { + rms_norm_q_.emplace_back(reg("q_norm." + std::to_string(h), cfg.rms_norm_eps)); + } + + // Per-head K RMSNorm (for KV heads) + for (int h = 0; h < num_key_value_heads_; ++h) { + rms_norm_k_.emplace_back(reg("k_norm." + std::to_string(h), cfg.rms_norm_eps)); + } + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + const auto& past_key = inputs[4]; // [B, num_kv_heads, D, S] + const auto& past_value = inputs[5]; // [B, num_kv_heads, S, D] + + // [B, S, D] - shared QDQ for input to all Q/K/V projections + hidden_states = ptq::QDQ(this, hidden_states, "q_proj_input_qdq"); + hidden_states = hidden_states.view({1, 1, -1, hidden_size_}, true); + + // ======================================================================== + // Per-head Q/K/V Projections + // ======================================================================== + // This is the key SHA optimization: instead of one large projection for all + // heads, we have separate smaller projections per head. 
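+    //
+    // Shape flow per head: hidden_states is [1, 1, S, hidden_size]; each per-head
+    // Conv2D yields [1, 1, S, head_dim], kept with a singleton head axis so the
+    // concat over dim 1 at the end rebuilds [1, num_heads, S, head_dim] before
+    // the shared o_proj.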
+ + // Compute per-head Q projections: each outputs [1, 1, S, head_dim] + std::vector query_states_per_head; + for (int h = 0; h < num_attention_heads_; ++h) { + auto q_h = q_projs_[h](hidden_states); + q_h = q_h.view({1, -1, 1, head_dim_}, /*ssa=*/true).transpose(1, 2); + query_states_per_head.push_back(q_h); + } + + // Compute per-head K projections + std::vector key_states_per_head; + for (int h = 0; h < num_key_value_heads_; ++h) { + auto k_h = k_projs_[h](hidden_states); + k_h = k_h.view({1, -1, 1, head_dim_}, /*ssa=*/true).transpose(1, 2); + key_states_per_head.push_back(k_h); + } + + // Compute per-head V projections + std::vector value_states_per_head; + for (int h = 0; h < num_key_value_heads_; ++h) { + auto v_h = v_projs_[h](hidden_states); + v_h = v_h.view({1, -1, 1, head_dim_}, /*ssa=*/true).transpose(1, 2); + value_states_per_head.push_back(v_h); + } + + // ======================================================================== + // Per-head RMSNorm and RoPE + // ======================================================================== + + auto cos = llm_embedding_cos.unsqueeze(1, true); + auto sin = llm_embedding_sin.unsqueeze(1, true); + + // Apply RMSNorm and RoPE per Q head + for (int h = 0; h < num_attention_heads_; ++h) { + std::string h_str = std::to_string(h); + query_states_per_head[h] = rms_norm_q_[h](ptq::QDQ(this, query_states_per_head[h], "q_norm_input_qdq_h" + h_str)); + query_states_per_head[h] = ptq::QDQ(this, query_states_per_head[h], "q_norm_output_qdq_h" + h_str); + + // Apply RoPE + query_states_per_head[h] = + ptq::QDQ(this, + ptq::QDQ(this, query_states_per_head[h] * cos, "q_rope_mul_0_output_qdq_h" + h_str) + + ptq::QDQ(this, rotateHalf(query_states_per_head[h], this, "q_rope_neg_half_qdq_h" + h_str) * sin, + "q_rope_mul_1_output_qdq_h" + h_str), + "q_rope_add_0_output_qdq_h" + h_str); + } + + // Apply RMSNorm and RoPE per K head + for (int h = 0; h < num_key_value_heads_; ++h) { + std::string h_str = std::to_string(h); + key_states_per_head[h] = rms_norm_k_[h](ptq::QDQ(this, key_states_per_head[h], "k_norm_input_qdq_h" + h_str)); + key_states_per_head[h] = ptq::QDQ(this, key_states_per_head[h], "k_norm_output_qdq_h" + h_str); + + // Apply RoPE + key_states_per_head[h] = + ptq::QDQ(this, + ptq::QDQ(this, key_states_per_head[h] * cos, "k_rope_mul_0_output_qdq_h" + h_str) + + ptq::QDQ(this, rotateHalf(key_states_per_head[h], this, "k_rope_neg_half_qdq_h" + h_str) * sin, + "k_rope_mul_1_output_qdq_h" + h_str), + "k_rope_add_0_output_qdq_h" + h_str); + } + + // ======================================================================== + // KV Cache Processing per head + // ======================================================================== + + std::vector new_key_per_head; + std::vector new_value_per_head; + std::vector key_cache_per_head; + std::vector value_cache_per_head; + + for (int h = 0; h < num_key_value_heads_; ++h) { + std::string h_str = std::to_string(h); + + // K: De-quantize and re-quantize to int8 + auto k_h = key_states_per_head[h].to(kFloat32); + k_h = k_h.to(kUInt8PerTensorSym); + k_h = ptq::QDQ_KV(this, k_h, "k_cast_to_int8_qdq_h" + h_str); + k_h = k_h.transpose(2, 3); // [B, 1, D, S] + + // V: Quantize to int16 then int8 + auto v_h = ptq::QDQ(this, value_states_per_head[h], "v_cast_to_int16_qdq_h" + h_str); + v_h = v_h.to(kFloat32); + v_h = v_h.to(kUInt8PerTensorSym); + v_h = ptq::QDQ_KV(this, v_h, "v_cast_to_int8_qdq_h" + h_str); + + new_key_per_head.push_back(k_h); + new_value_per_head.push_back(v_h); + + // Slice past cache for this 
head + auto past_k_h = past_key.slice({kAll, {h, h + 1}, kAll, kAll}, true); + auto past_v_h = past_value.slice({kAll, {h, h + 1}, kAll, kAll}, true); + + // Concat current with past + key_cache_per_head.push_back(nn::functional::concat({past_k_h, k_h}, -1)); + value_cache_per_head.push_back(nn::functional::concat({past_v_h, v_h}, 2)); + } + + // ======================================================================== + // Per-head Attention Computation + // ======================================================================== + // Each Q head computes attention with its corresponding KV head (GQA support) + + std::vector attn_outputs; + + for (int h = 0; h < num_attention_heads_; ++h) { + std::string h_str = std::to_string(h); + int kv_head_idx = h / num_key_value_groups_; + + const auto& q_h = query_states_per_head[h]; + const auto& kh = key_cache_per_head[kv_head_idx]; + const auto& vh = value_cache_per_head[kv_head_idx]; + + // QK^T + auto attn = ptq::QDQ(this, nn::functional::matmul(q_h, kh), "qk_matmul_output_qdq_h" + h_str); + + // Scale + auto scale = Tensor::constant(scale_, kFloat32); + scale = ptq::QDQ(this, scale, "scaling_qdq_h" + h_str); + attn = ptq::QDQ(this, attn.mulConstant(scale), "mul_0_output_qdq_h" + h_str); + + // Masked Softmax + auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq_h" + h_str); + auto minus_value = Tensor::constant(-20, kFloat32); + minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq_h" + h_str); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq_h" + h_str); + auto zero_constant = Tensor::constant(0.f, kFloat32); + zero_constant = ptq::QDQ_CONSTANT(this, zero_constant, "constant_zero"); + attn = nn::functional::where(causal_mask.equalConstant(zero_constant), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq_h" + h_str); + attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq_h" + h_str); + + // Output: attn @ V + auto y_h = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq_h" + h_str); + attn_outputs.push_back(y_h); + } + + // ======================================================================== + // Concatenate and Output Projection + // ======================================================================== + + // Concat all head outputs: [B, num_heads, S, D] + auto y = nn::functional::concat(attn_outputs, 1); + + // Reshape and apply O projection + y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); + y = o_proj_(y).view({1, -1, hidden_size_}, true); + + // Concat new keys and values back to original format + auto new_key = nn::functional::concat(new_key_per_head, 1); + auto new_value = nn::functional::concat(new_value_per_head, 1); + + return {y, new_key, new_value}; + } + + int layer_idx_; +}; + +class Qwen3DecoderSHA final : public nn::Module { + public: + int layer_idx_; + Qwen3AttentionSHA self_attn_; + Qwen3MLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen3DecoderSHA() = default; + + Qwen3DecoderSHA(const std::string& name, const Qwen3Config& cfg, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = 
inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + auto hidden_states = inputs[0]; + if (layer_idx_ != 0) { hidden_states = ptq::QDQ(this, hidden_states, "input_layernorm_input_qdq"); } + auto residual = hidden_states; + hidden_states = input_layer_norm_(hidden_states); + auto _ = self_attn_(hidden_states, llm_embedding_sin, llm_embedding_cos, causal_mask, past_key, past_value); + hidden_states = _[0]; + hidden_states = ptq::QDQ(this, residual + ptq::QDQ(this, hidden_states, "add_0_lhs_input_qdq"), "add_0_output_qdq"); + residual = hidden_states; + hidden_states = post_attention_layer_norm_(hidden_states); + hidden_states = mlp_(hidden_states)[0]; + hidden_states = residual + ptq::QDQ(this, hidden_states, "add_1_lhs_input_qdq"); + return {hidden_states, _[1], _[2]}; + } +}; + +class Qwen3TextSHA final : public nn::Module { + nn::ModuleListWithIdx decode_blocks_; + nn::RMSNorm norm_; + nn::Embedding embedding_; + nn::Param rope_sin_; + nn::Param rope_cos_; + int32_t num_hidden_layers_; + int32_t hidden_size_; + + public: + Qwen3TextSHA() = default; + + Qwen3TextSHA(const std::string& name, const Qwen3Config& cfg) : nn::Module(name) { + num_hidden_layers_ = cfg.num_hidden_layers; + hidden_size_ = cfg.hidden_size; + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + rope_sin_ = reg("mllm_max_sin_embedding", "model.mllm_max_sin_embedding"); + rope_cos_ = reg("mllm_max_cos_embedding", "model.mllm_max_cos_embedding"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + const auto& position_ids = inputs[1]; + auto causal_mask = inputs[2]; + + // clang-format off + auto llm_embedding_sin = nn::functional::gather(ptq::QDQ_ROPE(this, rope_sin_(), "sin_embedding_input_qdq"), 1, position_ids); + auto llm_embedding_cos = nn::functional::gather(ptq::QDQ_ROPE(this, rope_cos_(), "cos_embedding_input_qdq"), 1, position_ids); + // clang-format on + + std::vector keys; + std::vector values; + for (auto [index, block] : enumerate(blocks)) { + auto pk = inputs[3 + index]; + auto pv = inputs[3 + index + num_hidden_layers_]; + auto _ = block(x, llm_embedding_sin, llm_embedding_cos, causal_mask, pk, pv); + x = _[0]; + keys.push_back(_[1]); + values.push_back(_[2]); + } + + x = norm_(ptq::QDQ(this, x, "norm_input_qdq")); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto ret = std::vector{x}; + for (const auto& item : keys) { ret.push_back(item); } + for (const auto& item : values) { ret.push_back(item); } + + return ret; + } +}; + +class Qwen3ForCausalLM_SHA : public ARGeneration, public nn::Module { + public: + explicit Qwen3ForCausalLM_SHA(const Qwen3Config& cfg) : cfg(cfg) { + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY); + } + } + + IROutput trace(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + // Things we need to 
return + ir::IRContext::ptr_t llm_ir = nullptr; + + auto sequence = input.at("sequence"); + auto causal_mask = input.at("causal_mask"); + + std::vector kv_caches; + + // Append Key + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + if (input.count(past_key_name)) { + kv_caches.push_back(input.at(past_key_name)); + } else { + // If KV cache doesn't exist, we need to handle this case + // For now, we'll create empty tensors or handle it appropriately + // This might need adjustment based on your initialization logic + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Append Value + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_value_name = "past_value_" + std::to_string(i); + if (input.count(past_value_name)) { + kv_caches.push_back(input.at(past_value_name)); + } else { + // If KV cache doesn't exist, we need to handle this case + // For now, we'll create empty tensors or handle it appropriately + // This might need adjustment based on your initialization logic + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({1}, kInt32, kCPU).alloc(); + *position_ids.offsettedPtr({0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({seq_len}, kInt32, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[s] = s; } + } + + ir::lowlevel::traceStart(); + + // Build inputs for llm: sequence, llm_embedding_sin, llm_embedding_cos, causal_mask, then all KV caches + std::vector llm_inputs = {sequence, position_ids, causal_mask}; + llm_inputs.insert(llm_inputs.end(), kv_caches.begin(), kv_caches.end()); + + sequence = llm(llm_inputs)[0]; + sequence = lm_head_(ptq::QDQ(this, sequence, "lm_head_input_qdq")); + sequence = ptq::QDQ(this, sequence, "lm_head_output_qdq"); + ir::lowlevel::traceComment(" ╔═════╗ "); + ir::lowlevel::traceComment(" ║ o o ║ "); + ir::lowlevel::traceComment(" ║ ▽ ║ "); + ir::lowlevel::traceComment(" ╚═════╝ "); + ir::lowlevel::traceComment(" ║ ║ "); + ir::lowlevel::traceComment(" ╱╩╦╦╩╲ "); + llm_ir = ir::lowlevel::traceStop(); + + return {{"model", llm_ir}}; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } + + private: + const Qwen3Config& cfg; + Qwen3TextSHA llm; + nn::Conv2D lm_head_; + bool tie_word_embeddings_; +}; + +// ============================================================================ +// Weight Slicing Utilities for SHA +// ============================================================================ +// +// These functions are used during the compile phase to slice the original +// MHA weights into per-head SHA weights. +// + +/** + * @brief Prepares the parameter file by slicing MHA weights into SHA weights. + * + * This function takes the original parameter file with MHA weights and creates + * new per-head weights for the SHA model. 
+ * + * Original weight layout for Conv2D: [out_channels, in_channels, 1, 1] + * - q_proj.weight: [num_heads * head_dim, hidden_size, 1, 1] + * - k_proj.weight: [num_kv_heads * head_dim, hidden_size, 1, 1] + * - v_proj.weight: [num_kv_heads * head_dim, hidden_size, 1, 1] + * + * For LPBQ quantization, also need to slice: + * - scale1: flattened scale for block quantization + * - scale2: flattened scale for block quantization + * + * SHA weight layout: + * - q_proj.{h}.weight: [head_dim, hidden_size, 1, 1] for h in [0, num_heads) + * - q_proj.{h}.scale1: sliced scale for head h + * - q_proj.{h}.scale2: sliced scale for head h + * - Similar for k_proj and v_proj + */ +inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const Qwen3Config& cfg) { + int num_heads = cfg.num_attention_heads; + int num_kv_heads = cfg.num_key_value_heads; + int head_dim = cfg.head_dim; + int num_layers = cfg.num_hidden_layers; + + // Helper lambda to slice and push Conv2D params (weight, scale1, scale2) + // For LPBQ, scale1 and scale2 are flattened along the output channel dimension + // Scale size per head = total_scale_size / num_heads_for_this_proj + auto sliceAndPushConv2DParams = [&](const std::string& orig_name_prefix, const std::string& new_name_prefix, + int total_out_channels, int out_channels_per_head, int num_splits) { + // Process weight: HWIO format [H=1, W=1, In_channels, Out_channels] + // For q_proj: [1, 1, hidden_size, num_heads * head_dim] + // Slice on the last dimension (Out_channels) + std::string orig_weight_name = orig_name_prefix + ".weight"; + if (params->has(orig_weight_name)) { + auto orig_weight = params->pull(orig_weight_name); + + for (int h = 0; h < num_splits; ++h) { + std::string new_weight_name = new_name_prefix + "." + std::to_string(h) + ".weight"; + int start_idx = h * out_channels_per_head; + int end_idx = (h + 1) * out_channels_per_head; + // HWIO format: slice on dim 3 (Out_channels) + auto sliced = orig_weight.slice({kAll, kAll, kAll, {start_idx, end_idx}}, false); + params->push(new_weight_name, sliced.contiguous()); + } + } + + // Process scale1: flattened, size = total_out_channels / block_size (or similar) + // Slice index: (total_scale_size / num_splits) * h + std::string orig_scale1_name = orig_name_prefix + ".scale1"; + if (params->has(orig_scale1_name)) { + auto orig_scale1 = params->pull(orig_scale1_name); + int total_scale_size = orig_scale1.numel(); + int scale_per_head = total_scale_size / num_splits; + + for (int h = 0; h < num_splits; ++h) { + std::string new_scale1_name = new_name_prefix + "." + std::to_string(h) + ".scale1"; + int start_idx = h * scale_per_head; + int end_idx = (h + 1) * scale_per_head; + auto sliced = orig_scale1.slice({{start_idx, end_idx}}, false); + params->push(new_scale1_name, sliced.contiguous()); + } + } + + // Process scale2: flattened, same logic as scale1 + std::string orig_scale2_name = orig_name_prefix + ".scale2"; + if (params->has(orig_scale2_name)) { + auto orig_scale2 = params->pull(orig_scale2_name); + int total_scale_size = orig_scale2.numel(); + int scale_per_head = total_scale_size / num_splits; + + for (int h = 0; h < num_splits; ++h) { + std::string new_scale2_name = new_name_prefix + "." 
+ std::to_string(h) + ".scale2"; + int start_idx = h * scale_per_head; + int end_idx = (h + 1) * scale_per_head; + auto sliced = orig_scale2.slice({{start_idx, end_idx}}, false); + params->push(new_scale2_name, sliced.contiguous()); + } + } + }; + + for (int layer = 0; layer < num_layers; ++layer) { + std::string layer_prefix = "model.layers." + std::to_string(layer) + ".self_attn."; + + // Process Q projection: split into num_heads parts + sliceAndPushConv2DParams(layer_prefix + "q_proj", layer_prefix + "q_proj", num_heads * head_dim, head_dim, num_heads); + + // Process K projection: split into num_kv_heads parts + sliceAndPushConv2DParams(layer_prefix + "k_proj", layer_prefix + "k_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); + + // Process V projection: split into num_kv_heads parts + sliceAndPushConv2DParams(layer_prefix + "v_proj", layer_prefix + "v_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); + + Dbg(); + + // Process Q norm params (per head) + // RMSNorm has: weight (needs slicing), scale (scalar, copy), zero_point (scalar, copy) + { + std::string orig_weight_name = layer_prefix + "q_norm.weight"; + std::string orig_scale_name = layer_prefix + "q_norm.scale"; + std::string orig_zp_name = layer_prefix + "q_norm.zero_point"; + + if (params->has(orig_weight_name)) { + auto orig_weight = params->pull(orig_weight_name); // [num_heads * head_dim] + + // scale and zp are scalars, just need to copy + Tensor orig_scale, orig_zp; + bool has_scale = params->has(orig_scale_name); + bool has_zp = params->has(orig_zp_name); + if (has_scale) orig_scale = params->pull(orig_scale_name); + if (has_zp) orig_zp = params->pull(orig_zp_name); + + for (int h = 0; h < num_heads; ++h) { + std::string h_str = std::to_string(h); + // Weight: slice per head + std::string new_weight_name = layer_prefix + "q_norm." + h_str + ".weight"; // NOLINT + auto sliced = orig_weight.slice({{h * head_dim, (h + 1) * head_dim}}, false); + params->push(new_weight_name, sliced.contiguous()); + + // Scale: copy (scalar) + if (has_scale) { + std::string new_scale_name = layer_prefix + "q_norm." + h_str + ".scale"; // NOLINT + params->push(new_scale_name, orig_scale.contiguous()); + } + + // Zero point: copy (scalar) + if (has_zp) { + std::string new_zp_name = layer_prefix + "q_norm." + h_str + ".zero_point"; // NOLINT + params->push(new_zp_name, orig_zp.contiguous()); + } + } + } + } + + // Process K norm params (per KV head) + { + std::string orig_weight_name = layer_prefix + "k_norm.weight"; + std::string orig_scale_name = layer_prefix + "k_norm.scale"; + std::string orig_zp_name = layer_prefix + "k_norm.zero_point"; + + if (params->has(orig_weight_name)) { + auto orig_weight = params->pull(orig_weight_name); + + Tensor orig_scale, orig_zp; + bool has_scale = params->has(orig_scale_name); + bool has_zp = params->has(orig_zp_name); + if (has_scale) orig_scale = params->pull(orig_scale_name); + if (has_zp) orig_zp = params->pull(orig_zp_name); + + for (int h = 0; h < num_kv_heads; ++h) { + std::string h_str = std::to_string(h); + // Weight: slice per head + std::string new_weight_name = layer_prefix + "k_norm." + h_str + ".weight"; // NOLINT + auto sliced = orig_weight.slice({{h * head_dim, (h + 1) * head_dim}}, false); + params->push(new_weight_name, sliced.contiguous()); + + // Scale: copy (scalar) + if (has_scale) { + std::string new_scale_name = layer_prefix + "k_norm." 
+ h_str + ".scale"; // NOLINT + params->push(new_scale_name, orig_scale.contiguous()); + } + + // Zero point: copy (scalar) + if (has_zp) { + std::string new_zp_name = layer_prefix + "k_norm." + h_str + ".zero_point"; // NOLINT + params->push(new_zp_name, orig_zp.contiguous()); + } + } + } + } + + // ======================================================================== + // Duplicate QDQ parameters for each head + // ======================================================================== + // The original MHA uses shared QDQ params for all heads. For SHA, we + // duplicate these to per-head versions using "_h{N}" suffix naming. + // This allows each head to have its own quantization parameters. + + auto copyQDQParams = [&](const std::string& base_name, const std::string& new_base_name, int count) { + std::string scale_name = layer_prefix + base_name + ".fake_quant.scale"; + std::string zp_name = layer_prefix + base_name + ".fake_quant.zero_point"; + + if (params->has(scale_name)) { + auto scale = params->pull(scale_name); + auto zp = params->pull(zp_name); + + for (int h = 0; h < count; ++h) { + std::string new_scale_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.scale"; + std::string new_zp_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.zero_point"; + // QDQ scale/zp are typically scalar or small tensors, clone to ensure contiguous + params->push(new_scale_name, scale.contiguous()); + params->push(new_zp_name, zp.contiguous()); + } + } + }; + + // Copy QDQ params for Q-related nodes (per Q head) + copyQDQParams("q_norm_input_qdq", "q_norm_input_qdq_h", num_heads); + copyQDQParams("q_norm_output_qdq", "q_norm_output_qdq_h", num_heads); + copyQDQParams("q_rope_mul_0_output_qdq", "q_rope_mul_0_output_qdq_h", num_heads); + copyQDQParams("q_rope_mul_1_output_qdq", "q_rope_mul_1_output_qdq_h", num_heads); + copyQDQParams("q_rope_neg_half_qdq", "q_rope_neg_half_qdq_h", num_heads); + copyQDQParams("q_rope_add_0_output_qdq", "q_rope_add_0_output_qdq_h", num_heads); + + // Copy QDQ params for K-related nodes (per KV head) + copyQDQParams("k_norm_input_qdq", "k_norm_input_qdq_h", num_kv_heads); + copyQDQParams("k_norm_output_qdq", "k_norm_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_mul_0_output_qdq", "k_rope_mul_0_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_mul_1_output_qdq", "k_rope_mul_1_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_neg_half_qdq", "k_rope_neg_half_qdq_h", num_kv_heads); + copyQDQParams("k_rope_add_0_output_qdq", "k_rope_add_0_output_qdq_h", num_kv_heads); + copyQDQParams("k_cast_to_int8_qdq", "k_cast_to_int8_qdq_h", num_kv_heads); + + // Copy QDQ params for V-related nodes (per KV head) + copyQDQParams("v_cast_to_int16_qdq", "v_cast_to_int16_qdq_h", num_kv_heads); + copyQDQParams("v_cast_to_int8_qdq", "v_cast_to_int8_qdq_h", num_kv_heads); + + // Copy QDQ params for attention computation (per Q head) + copyQDQParams("qk_matmul_output_qdq", "qk_matmul_output_qdq_h", num_heads); + copyQDQParams("scaling_qdq", "scaling_qdq_h", num_heads); + copyQDQParams("mul_0_output_qdq", "mul_0_output_qdq_h", num_heads); + copyQDQParams("reduce_min_output_qdq", "reduce_min_output_qdq_h", num_heads); + copyQDQParams("neg_20_qdq", "neg_20_qdq_h", num_heads); + copyQDQParams("minus_0_output_qdq", "minus_0_output_qdq_h", num_heads); + copyQDQParams("where_attn_qdq", "where_attn_qdq_h", num_heads); + copyQDQParams("softmax_output_qdq", "softmax_output_qdq_h", num_heads); + copyQDQParams("attn_value_matmul_output_qdq", 
"attn_value_matmul_output_qdq_h", num_heads); + } +} + +} // namespace mllm::models::qwen3::sha From 981423733523f6ad981ea685d368413cad6ab7e5 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Wed, 28 Jan 2026 16:16:46 +0000 Subject: [PATCH 3/8] feat(qwen2, qwen3): Enhance AOT compilation with SHA optimization and new executable. Update input tensor dimensions and block sizes for improved performance. Introduce compile_sha utility for weight slicing and quantization support. Adjust CMake configurations and add new model files for better integration. --- examples/qwen2_qnn_aot/CMakeLists.txt | 4 + examples/qwen2_qnn_aot/aot_run.cpp | 4 +- examples/qwen2_qnn_aot/compile_sha.cpp | 196 +++++ .../modeling_qwen2_qnn_aot_sha.hpp | 788 +++++++++++++++++ examples/qwen3_qnn_aot/CMakeLists.txt | 1 - examples/qwen3_qnn_aot/compile.cpp | 123 ++- examples/qwen3_qnn_aot/compile_sha.cpp | 136 ++- .../modeling_qwen_qnn_aot_sha.hpp | 34 +- mllm/backends/qnn/QNNBackend.cpp | 4 +- mllm/backends/qnn/aot/QnnWrappersAPI.cpp | 9 +- .../qnn/aot/passes/LLMQuantRecipePass.cpp | 39 +- .../transformers/llama/modeling_llama.py | 831 ++++++++++++++++++ .../qualcomm/transformers/llama/runner.py | 345 ++++++++ .../qualcomm/transformers/llama/train.py | 56 ++ .../transformers/qwen2/modeling_qwen2.py | 16 +- 15 files changed, 2459 insertions(+), 127 deletions(-) create mode 100644 examples/qwen2_qnn_aot/compile_sha.cpp create mode 100644 examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot_sha.hpp create mode 100644 pymllm/backends/qualcomm/transformers/llama/modeling_llama.py create mode 100644 pymllm/backends/qualcomm/transformers/llama/runner.py create mode 100644 pymllm/backends/qualcomm/transformers/llama/train.py diff --git a/examples/qwen2_qnn_aot/CMakeLists.txt b/examples/qwen2_qnn_aot/CMakeLists.txt index eafbe1952..4db6131c0 100644 --- a/examples/qwen2_qnn_aot/CMakeLists.txt +++ b/examples/qwen2_qnn_aot/CMakeLists.txt @@ -3,6 +3,10 @@ if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) add_executable(mllm-qwen2-aot-c compile.cpp) target_link_libraries(mllm-qwen2-aot-c PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) target_include_directories(mllm-qwen2-aot-c PRIVATE ${MLLM_INCLUDE_DIR}) + + add_executable(mllm-qwen2-aot-c-sha compile_sha.cpp) + target_link_libraries(mllm-qwen2-aot-c-sha PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) + target_include_directories(mllm-qwen2-aot-c-sha PRIVATE ${MLLM_INCLUDE_DIR}) endif() add_executable(mllm-qwen2-aot-runner aot_run.cpp) diff --git a/examples/qwen2_qnn_aot/aot_run.cpp b/examples/qwen2_qnn_aot/aot_run.cpp index 7c0eccc0a..41437a85e 100644 --- a/examples/qwen2_qnn_aot/aot_run.cpp +++ b/examples/qwen2_qnn_aot/aot_run.cpp @@ -39,7 +39,7 @@ MLLM_MAIN({ auto input_tensor = tokenizer.convertMessage({.prompt = "hello"}); - input_tensor["sequence"] = mllm::Tensor::arange(0, 256, 1, mllm::kInt64, mllm::kCPU).view({1, -1}); + input_tensor["sequence"] = mllm::Tensor::arange(0, 800, 1, mllm::kInt64, mllm::kCPU).view({1, -1}); // DBG: mllm::print(input_tensor["sequence"].shape()); @@ -51,7 +51,7 @@ MLLM_MAIN({ return 1; } - runner.generate(input_tensor["sequence"], 128, [](const std::string& token) { std::cout << token << std::flush; }, true); + runner.generate(input_tensor["sequence"], 32, [](const std::string& token) { std::cout << token << std::flush; }, true); std::cout << "\n"; return 0; diff --git a/examples/qwen2_qnn_aot/compile_sha.cpp b/examples/qwen2_qnn_aot/compile_sha.cpp new file mode 100644 index 000000000..2ae85f3ae --- /dev/null +++ 
b/examples/qwen2_qnn_aot/compile_sha.cpp @@ -0,0 +1,196 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// Benefits: +// 1. Reduces QNN AOT compilation time +// 2. Improves HTP runtime performance +// 3. Enables better memory locality per head +// +// Usage: +// ./compile_sha -m /path/to/model.mllm -c /path/to/config.json -aot_cfg /path/to/qnn_aot_cfg.json + +#include +#include +#include +#include +#include +#include + +#include "modeling_qwen2_qnn_aot_sha.hpp" + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); + auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); + auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + + Argparse::parse(argc, argv); + + int N = 32; + int CL = 1024; + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + if (!qnn_aot_cfg_files.isSet()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided"); + Argparse::printHelp(); + return -1; + } + + auto model_cfg = mllm::models::qwen3::Qwen3Config(model_cfg_path.get()); + + // Load original parameters + auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); + + // ============================================================================ + // Key Step: Prepare SHA parameters by slicing MHA weights + // ============================================================================ + // This is the critical step that transforms MHA weights into SHA weights. + // For each Q/K/V projection, we slice the weight matrix into per-head pieces. + // + // Original: q_proj.weight [num_heads * head_dim, hidden_size, 1, 1] + // SHA: q_proj.{h}.weight [head_dim, hidden_size, 1, 1] for each head h + // + mllm::print("Preparing SHA parameters (slicing MHA weights)..."); + mllm::models::qwen2::sha::prepareParametersForSHA(params, model_cfg); + mllm::print("SHA parameters prepared."); + + // Create SHA model + auto model = mllm::models::qwen2::sha::Qwen2ForCausalLM_SHA(model_cfg); + + // Add params for causal mask + { + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + } + model.load(params); + + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + // Model length 32. 
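+  //
+  // For reference, with N = 32, CL = 1024 and a hypothetical config using
+  // num_key_value_heads = 2 and head_dim = 128, the traced inputs below are:
+  //   sequence     : [1, 32]           int32 token ids
+  //   causal_mask  : [1, 1, 32, 1024]  uint16, per-tensor asymmetric
+  //   past_key_i   : [1, 2, 128, 992]  uint8, per-tensor symmetric (992 = CL - N)
+  //   past_value_i : [1, 2, 992, 128]  uint8, per-tensor symmetric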
+ + { + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + mllm::print("Tracing SHA model (seq=32)..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("qwen2_qnn_aot_sha_32.mir", [&]() { mllm::print(ir["model"]); }); + } + + // Model length 1. 
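+  //
+  // The decode graph is traced with N = 1 against the same CL = 1024, so the
+  // per-layer cache inputs become [1, H, D, 1023] / [1, H, 1023, D]. Both graphs
+  // share one QnnAOTEnv, and a single context ("context.0") is serialized by
+  // saveContext() at the end of this program.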
+ { + N = 1; + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + mllm::print("Tracing SHA model (seq=1)..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("qwen2_qnn_aot_sha_1.mir", [&]() { mllm::print(ir["model"]); }); + } + + qnn_aot_env.saveContext("context.0", "qwen2-lpbq-sha.bin"); + + mllm::print("SHA compilation completed successfully!"); + mllm::print("Output files:"); + mllm::print(" - qwen2_qnn_aot_sha_32.mir (IR dump for seq=32)"); + mllm::print(" - qwen2_qnn_aot_sha_1.mir (IR dump for seq=1)"); + mllm::print(" - qwen2-lpbq-sha.bin (QNN context)"); +}); diff --git a/examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot_sha.hpp b/examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot_sha.hpp new file mode 100644 index 000000000..db69c6017 --- /dev/null +++ b/examples/qwen2_qnn_aot/modeling_qwen2_qnn_aot_sha.hpp @@ -0,0 +1,788 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// The optimization splits large Q/K/V projections into per-head projections, +// allowing QNN to optimize each head separately, reducing AOT compilation time +// and improving HTP performance. 
+ +#pragma once + +#include "mllm/core/TensorStorage.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/compile/ir/Trace.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" + +namespace mllm::models::qwen2::sha { + +namespace ptq { + +Tensor QDQ_CONSTANT(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = qdq_name_in_pytorch + ".scale"; + std::string zp_name = qdq_name_in_pytorch + ".zero_point"; + switch (in.dtype()) { + case kFloat32: + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + return in; +} + +Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + std::string zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + if (m->getModuleName().empty()) { + scale_name = qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = qdq_name_in_pytorch + ".fake_quant.zero_point"; + } else { + scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + } + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + // For Constant! + case kFloat32: { + MLLM_RT_ASSERT_EQ(in.rank(), 1); + MLLM_RT_ASSERT_EQ(in.size(-1), 1); + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_KV(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + // The inputs is int8 sym. which means zero_point should be changed. + switch (in.dtype()) { + case kUInt8PerTensorSym: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + MLLM_RT_ASSERT_EQ(zp.item(), 0); + + // Is 128! not 127! + auto new_zp = Tensor::constant(128, kInt32).setName(zp_name).setMemType(kParamsNormal); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", new_zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." 
+ qdq_name_in_pytorch + ".fake_quant.zero_point"; + + (void)in.__unsafeSetDType(kUInt16PerTensorAsy); + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +} // namespace ptq + +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + +using vi32 = std::vector; +#define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32 + +// Using Conv2D to replace Linear. +// Conv2D Filter Weight is [1, 1, In, Out] +// Conv2D Activation is [N, H=1, W=Seq, In] + +class Qwen2MLP final : public nn::Module { + nn::Conv2D gate_proj_; + nn::Conv2D up_proj_; + nn::Conv2D down_proj_; + nn::SiLU silu_; + int hidden_size_; + int intermediate_size_; + + public: + Qwen2MLP() = default; + Qwen2MLP(const std::string& name, const qwen3::Qwen3Config& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, CONV2D_PROPERTY); + hidden_size_ = cfg.hidden_size; + intermediate_size_ = cfg.intermediate_size; + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = ptq::QDQ(this, x, "up_proj_input_qdq"); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto up_result = ptq::QDQ(this, up_proj_(x), "up_proj_output_qdq").view({1, -1, intermediate_size_}, true); + auto gate_result = ptq::QDQ(this, gate_proj_(x), "gate_proj_output_qdq").view({1, -1, intermediate_size_}, true); + + // SiLU + gate_result = ptq::QDQ(this, (gate_result * ptq::QDQ(this, nn::functional::sigmoid(gate_result), "sigmoid_output_qdq")), + "act_output_qdq"); + + auto o = ptq::QDQ(this, gate_result * up_result, "down_proj_input_qdq"); + o = o.view({1, 1, -1, intermediate_size_}, true); + o = down_proj_(o).view({1, -1, hidden_size_}, true); + + return {o}; + } +}; + +// ============================================================================ +// Single Head Attention (SHA) Implementation +// ============================================================================ +// +// This class implements SHA where each attention head has its own separate +// Conv2D projection, instead of one large MHA projection that processes all +// heads at once. +// +// Benefits: +// 1. Reduces QNN AOT compilation time +// 2. Improves HTP runtime performance +// 3. 
Enables better memory locality per head +// +// Note: Qwen2 does NOT have RMSNorm after Q/K projection (unlike Qwen3) + +class Qwen2AttentionSHA final : public nn::Module { + // Per-head Q projections: num_attention_heads Conv2D(hidden_size, head_dim) + std::vector q_projs_; + // Per-head K projections: num_key_value_heads Conv2D(hidden_size, head_dim) + std::vector k_projs_; + // Per-head V projections: num_key_value_heads Conv2D(hidden_size, head_dim) + std::vector v_projs_; + // Single O projection remains unchanged (concatenated heads -> hidden_size) + nn::Conv2D o_proj_; + + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + float scale_; + + public: + Qwen2AttentionSHA() = default; + + Qwen2AttentionSHA(const std::string& name, const qwen3::Qwen3Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + scale_ = (1.f / sqrtf((float)head_dim_)); + + // Register per-head Q projections + for (int h = 0; h < num_attention_heads_; ++h) { + q_projs_.emplace_back(reg("q_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // Register per-head K projections + for (int h = 0; h < num_key_value_heads_; ++h) { + k_projs_.emplace_back(reg("k_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // Register per-head V projections + for (int h = 0; h < num_key_value_heads_; ++h) { + v_projs_.emplace_back(reg("v_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // O projection remains the same (combines all heads) + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, CONV2D_PROPERTY); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + const auto& past_key = inputs[4]; // [B, num_kv_heads, D, S] + const auto& past_value = inputs[5]; // [B, num_kv_heads, S, D] + + // [B, S, D] - shared QDQ for input to all Q/K/V projections + hidden_states = ptq::QDQ(this, hidden_states, "q_proj_input_qdq"); + hidden_states = hidden_states.view({1, 1, -1, hidden_size_}, true); + + // ======================================================================== + // Per-head Q/K/V Projections + // ======================================================================== + // This is the key SHA optimization: instead of one large projection for all + // heads, we have separate smaller projections per head. 
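+    //
+    // Rough shape walkthrough, assuming a hypothetical config with
+    // hidden_size = 1536, num_attention_heads = 12, num_key_value_heads = 2,
+    // head_dim = 128:
+    //   hidden_states              : [1, 1, S, 1536]
+    //   q_projs_[h](hidden_states) : [1, 1, S, 128]  (12 heads)
+    //   k_projs_[h](hidden_states) : [1, 1, S, 128]  ( 2 heads)
+    //   v_projs_[h](hidden_states) : [1, 1, S, 128]  ( 2 heads)
+    // A single MHA q_proj would instead emit [1, 1, S, 1536] in one shot.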
+ + // Compute per-head Q projections: each outputs [1, 1, S, head_dim] + std::vector query_states_per_head; + for (int h = 0; h < num_attention_heads_; ++h) { + auto q_h = q_projs_[h](hidden_states); + q_h = q_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + query_states_per_head.push_back(q_h); + } + + // Compute per-head K projections + std::vector key_states_per_head; + for (int h = 0; h < num_key_value_heads_; ++h) { + auto k_h = k_projs_[h](hidden_states); + k_h = k_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + key_states_per_head.push_back(k_h); + } + + // Compute per-head V projections + std::vector value_states_per_head; + for (int h = 0; h < num_key_value_heads_; ++h) { + auto v_h = v_projs_[h](hidden_states); + v_h = v_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + value_states_per_head.push_back(v_h); + } + + // ======================================================================== + // Reshape and Transpose for RoPE + // ======================================================================== + // Qwen2 does NOT have RMSNorm here (unlike Qwen3) + // Directly apply RoPE after reshaping to [B, H, S, D] format + // Each head tensor is [1, 1, S, head_dim], need to reshape to [1, 1, S, head_dim] for RoPE + // (The shape is already correct, but we need to ensure QDQ is applied) + + auto cos = llm_embedding_cos.unsqueeze(1, true); + auto sin = llm_embedding_sin.unsqueeze(1, true); + + // Apply QDQ and RoPE per Q head + // Each query_states_per_head[h] is [1, 1, S, head_dim] + for (int h = 0; h < num_attention_heads_; ++h) { + std::string h_str = std::to_string(h); + query_states_per_head[h] = ptq::QDQ(this, query_states_per_head[h], "q_proj_output_qdq_h" + h_str); + // Reshape to [1, 1, S, head_dim] for RoPE (already correct shape) + query_states_per_head[h] = + ptq::QDQ(this, + ptq::QDQ(this, query_states_per_head[h] * cos, "q_rope_mul_0_output_qdq_h" + h_str) + + ptq::QDQ(this, rotateHalf(query_states_per_head[h], this, "q_rope_neg_half_qdq_h" + h_str) * sin, + "q_rope_mul_1_output_qdq_h" + h_str), + "q_rope_add_0_output_qdq_h" + h_str); + } + + // Apply QDQ and RoPE per K head + // Each key_states_per_head[h] is [1, 1, S, head_dim] + for (int h = 0; h < num_key_value_heads_; ++h) { + std::string h_str = std::to_string(h); + key_states_per_head[h] = ptq::QDQ(this, key_states_per_head[h], "k_proj_output_qdq_h" + h_str); + // Reshape to [1, 1, S, head_dim] for RoPE (already correct shape) + key_states_per_head[h] = + ptq::QDQ(this, + ptq::QDQ(this, key_states_per_head[h] * cos, "k_rope_mul_0_output_qdq_h" + h_str) + + ptq::QDQ(this, rotateHalf(key_states_per_head[h], this, "k_rope_neg_half_qdq_h" + h_str) * sin, + "k_rope_mul_1_output_qdq_h" + h_str), + "k_rope_add_0_output_qdq_h" + h_str); + } + + // ======================================================================== + // KV Cache Processing per head + // ======================================================================== + + std::vector new_key_per_head; + std::vector new_value_per_head; + std::vector key_cache_per_head; + std::vector value_cache_per_head; + + for (int h = 0; h < num_key_value_heads_; ++h) { + std::string h_str = std::to_string(h); + + // K: De-quantize and re-quantize to int8 + auto k_h = key_states_per_head[h].to(kFloat32); + k_h = k_h.to(kUInt8PerTensorSym); + k_h = ptq::QDQ_KV(this, k_h, "k_cast_to_int8_qdq_h" + h_str); + k_h = k_h.transpose(2, 3); // [B, 1, D, S] + + // V: Quantize to int16 then int8 + auto v_h = ptq::QDQ(this, value_states_per_head[h], "v_cast_to_int16_qdq_h" + h_str); + v_h = 
v_h.to(kFloat32); + v_h = v_h.to(kUInt8PerTensorSym); + v_h = ptq::QDQ_KV(this, v_h, "v_cast_to_int8_qdq_h" + h_str); + + new_key_per_head.push_back(k_h); + new_value_per_head.push_back(v_h); + + // Slice past cache for this head + auto past_k_h = past_key.slice({kAll, {h, h + 1}, kAll, kAll}, true); + auto past_v_h = past_value.slice({kAll, {h, h + 1}, kAll, kAll}, true); + + // Concat current with past + key_cache_per_head.push_back(nn::functional::concat({past_k_h, k_h}, -1)); + value_cache_per_head.push_back(nn::functional::concat({past_v_h, v_h}, 2)); + } + + // ======================================================================== + // Per-head Attention Computation + // ======================================================================== + // Each Q head computes attention with its corresponding KV head (GQA support) + // For GQA, multiple Q heads share the same KV head + + std::vector attn_outputs; + + for (int h = 0; h < num_attention_heads_; ++h) { + std::string h_str = std::to_string(h); + int kv_head_idx = h / num_key_value_groups_; + + const auto& q_h = query_states_per_head[h]; + const auto& kh = key_cache_per_head[kv_head_idx]; + const auto& vh = value_cache_per_head[kv_head_idx]; + + // QK^T + auto attn = ptq::QDQ(this, nn::functional::matmul(q_h, kh), "qk_matmul_output_qdq_h" + h_str); + + // Scale + auto scale = Tensor::constant(scale_, kFloat32); + scale = ptq::QDQ(this, scale, "scaling_qdq_h" + h_str); + attn = ptq::QDQ(this, attn.mulConstant(scale), "mul_0_output_qdq_h" + h_str); + + // Masked Softmax + auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq_h" + h_str); + auto minus_value = Tensor::constant(-20, kFloat32); + minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq_h" + h_str); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq_h" + h_str); + auto zero_constant = Tensor::constant(0.f, kFloat32); + zero_constant = ptq::QDQ_CONSTANT(this, zero_constant, "constant_zero"); + attn = nn::functional::where(causal_mask.equalConstant(zero_constant), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq_h" + h_str); + attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq_h" + h_str); + + // Output: attn @ V + auto y_h = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq_h" + h_str); + attn_outputs.push_back(y_h); + } + + // ======================================================================== + // Concatenate and Output Projection + // ======================================================================== + + // Concat all head outputs: [B, num_heads, S, D] + auto y = nn::functional::concat(attn_outputs, 1); + + // Reshape and apply O projection + y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); + y = o_proj_(y).view({1, -1, hidden_size_}, true); + + // Concat new keys and values back to original format + auto new_key = nn::functional::concat(new_key_per_head, 1); + auto new_value = nn::functional::concat(new_value_per_head, 1); + + return {y, new_key, new_value}; + } + + int layer_idx_; +}; + +class Qwen2DecoderSHA final : public nn::Module { + public: + int layer_idx_; + Qwen2AttentionSHA self_attn_; + Qwen2MLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen2DecoderSHA() = default; + + Qwen2DecoderSHA(const std::string& name, const qwen3::Qwen3Config& cfg, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + self_attn_ = reg("self_attn", cfg); + mlp_ 
= reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + auto hidden_states = inputs[0]; + if (layer_idx_ != 0) { hidden_states = ptq::QDQ(this, hidden_states, "input_layernorm_input_qdq"); } + auto residual = hidden_states; + hidden_states = input_layer_norm_(hidden_states); + auto _ = self_attn_(hidden_states, llm_embedding_sin, llm_embedding_cos, causal_mask, past_key, past_value); + hidden_states = _[0]; + hidden_states = ptq::QDQ(this, residual + ptq::QDQ(this, hidden_states, "add_0_lhs_input_qdq"), "add_0_output_qdq"); + residual = hidden_states; + hidden_states = post_attention_layer_norm_(hidden_states); + hidden_states = mlp_(hidden_states)[0]; + hidden_states = residual + ptq::QDQ(this, hidden_states, "add_1_lhs_input_qdq"); + return {hidden_states, _[1], _[2]}; + } +}; + +class Qwen2TextSHA final : public nn::Module { + nn::ModuleListWithIdx decode_blocks_; + nn::RMSNorm norm_; + nn::Embedding embedding_; + nn::Param rope_sin_; + nn::Param rope_cos_; + int32_t num_hidden_layers_; + int32_t hidden_size_; + + public: + Qwen2TextSHA() = default; + + Qwen2TextSHA(const std::string& name, const qwen3::Qwen3Config& cfg) : nn::Module(name) { + num_hidden_layers_ = cfg.num_hidden_layers; + hidden_size_ = cfg.hidden_size; + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + rope_sin_ = reg("mllm_max_sin_embedding", "model.mllm_max_sin_embedding"); + rope_cos_ = reg("mllm_max_cos_embedding", "model.mllm_max_cos_embedding"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + const auto& position_ids = inputs[1]; + auto causal_mask = inputs[2]; + + // clang-format off + auto llm_embedding_sin = nn::functional::gather(ptq::QDQ_ROPE(this, rope_sin_(), "sin_embedding_input_qdq"), 1, position_ids); + auto llm_embedding_cos = nn::functional::gather(ptq::QDQ_ROPE(this, rope_cos_(), "cos_embedding_input_qdq"), 1, position_ids); + // clang-format on + + std::vector keys; + std::vector values; + for (auto [index, block] : enumerate(blocks)) { + auto pk = inputs[3 + index]; + auto pv = inputs[3 + index + num_hidden_layers_]; + auto _ = block(x, llm_embedding_sin, llm_embedding_cos, causal_mask, pk, pv); + x = _[0]; + keys.push_back(_[1]); + values.push_back(_[2]); + } + + x = norm_(ptq::QDQ(this, x, "norm_input_qdq")); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto ret = std::vector{x}; + for (const auto& item : keys) { ret.push_back(item); } + for (const auto& item : values) { ret.push_back(item); } + + return ret; + } +}; + +class Qwen2ForCausalLM_SHA : public ARGeneration, public nn::Module { + public: + explicit Qwen2ForCausalLM_SHA(const qwen3::Qwen3Config& cfg) : cfg(cfg) { + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + 
// NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY); + } + } + + IROutput trace(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + // Things we need to return + ir::IRContext::ptr_t llm_ir = nullptr; + + auto sequence = input.at("sequence"); + auto causal_mask = input.at("causal_mask"); + + std::vector kv_caches; + + // Append Key + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + if (input.count(past_key_name)) { + kv_caches.push_back(input.at(past_key_name)); + } else { + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Append Value + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_value_name = "past_value_" + std::to_string(i); + if (input.count(past_value_name)) { + kv_caches.push_back(input.at(past_value_name)); + } else { + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({1}, kInt32, kCPU).alloc(); + *position_ids.offsettedPtr({0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({seq_len}, kInt32, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[s] = s; } + } + + ir::lowlevel::traceStart(); + + // Build inputs for llm: sequence, llm_embedding_sin, llm_embedding_cos, causal_mask, then all KV caches + std::vector llm_inputs = {sequence, position_ids, causal_mask}; + llm_inputs.insert(llm_inputs.end(), kv_caches.begin(), kv_caches.end()); + + sequence = llm(llm_inputs)[0]; + sequence = lm_head_(ptq::QDQ(this, sequence, "lm_head_input_qdq")); + sequence = ptq::QDQ(this, sequence, "lm_head_output_qdq"); + ir::lowlevel::traceComment(" ╔═════╗ "); + ir::lowlevel::traceComment(" ║ o o ║ "); + ir::lowlevel::traceComment(" ║ ▽ ║ "); + ir::lowlevel::traceComment(" ╚═════╝ "); + ir::lowlevel::traceComment(" ║ ║ "); + ir::lowlevel::traceComment(" ╱╩╦╦╩╲ "); + llm_ir = ir::lowlevel::traceStop(); + + return {{"model", llm_ir}}; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } + + private: + const qwen3::Qwen3Config& cfg; + Qwen2TextSHA llm; + nn::Conv2D lm_head_; + bool tie_word_embeddings_; +}; + +// ============================================================================ +// Weight Slicing Utilities for SHA +// ============================================================================ +// +// These functions are used during the compile phase to slice the original +// MHA weights into per-head SHA weights. +// +// Note: Qwen2 does NOT have q_norm and k_norm (RMSNorm), so we don't need +// to slice those parameters. + +/** + * @brief Prepares the parameter file by slicing MHA weights into SHA weights. 
+ * + * This function takes the original parameter file with MHA weights and creates + * new per-head weights for the SHA model. + * + * Original weight layout for Conv2D: [out_channels, in_channels, 1, 1] + * - q_proj.weight: [num_heads * head_dim, hidden_size, 1, 1] + * - k_proj.weight: [num_kv_heads * head_dim, hidden_size, 1, 1] + * - v_proj.weight: [num_kv_heads * head_dim, hidden_size, 1, 1] + * + * For LPBQ quantization, also need to slice: + * - scale1: flattened scale for block quantization + * - scale2: flattened scale for block quantization + * + * SHA weight layout: + * - q_proj.{h}.weight: [head_dim, hidden_size, 1, 1] for h in [0, num_heads) + * - q_proj.{h}.scale1: sliced scale for head h + * - q_proj.{h}.scale2: sliced scale for head h + * - Similar for k_proj and v_proj + */ +inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const qwen3::Qwen3Config& cfg) { + int num_heads = cfg.num_attention_heads; + int num_kv_heads = cfg.num_key_value_heads; + int head_dim = cfg.head_dim; + int num_layers = cfg.num_hidden_layers; + + // Helper lambda to slice and push Conv2D params (weight, scale1, scale2) + // For LPBQ, scale1 and scale2 are flattened along the output channel dimension + // Scale size per head = total_scale_size / num_heads_for_this_proj + auto sliceAndPushConv2DParams = [&](const std::string& orig_name_prefix, const std::string& new_name_prefix, + int total_out_channels, int out_channels_per_head, int num_splits) { + // Process weight: HWIO format [H=1, W=1, In_channels, Out_channels] + // For q_proj: [1, 1, hidden_size, num_heads * head_dim] + // Slice on the last dimension (Out_channels) + std::string orig_weight_name = orig_name_prefix + ".weight"; + if (params->has(orig_weight_name)) { + auto orig_weight = params->pull(orig_weight_name); + + for (int h = 0; h < num_splits; ++h) { + std::string new_weight_name = new_name_prefix + "." + std::to_string(h) + ".weight"; + int start_idx = h * out_channels_per_head; + int end_idx = (h + 1) * out_channels_per_head; + // HWIO format: slice on dim 3 (Out_channels) + auto sliced = orig_weight.slice({kAll, kAll, kAll, {start_idx, end_idx}}, false); + params->push(new_weight_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_weight_name)); + } + } + + // Process scale1: flattened, size = total_out_channels / block_size (or similar) + // Slice index: (total_scale_size / num_splits) * h + std::string orig_scale1_name = orig_name_prefix + ".scale1"; + if (params->has(orig_scale1_name)) { + auto orig_scale1 = params->pull(orig_scale1_name); + int total_scale_size = orig_scale1.numel(); + int scale_per_head = total_scale_size / num_splits; + + for (int h = 0; h < num_splits; ++h) { + std::string new_scale1_name = new_name_prefix + "." + std::to_string(h) + ".scale1"; + int start_idx = h * scale_per_head; + int end_idx = (h + 1) * scale_per_head; + auto sliced = orig_scale1.slice({{start_idx, end_idx}}, false); + params->push(new_scale1_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_scale1_name)); + } + } + + // Process scale2: flattened, same logic as scale1 + std::string orig_scale2_name = orig_name_prefix + ".scale2"; + if (params->has(orig_scale2_name)) { + auto orig_scale2 = params->pull(orig_scale2_name); + int total_scale_size = orig_scale2.numel(); + int scale_per_head = total_scale_size / num_splits; + + for (int h = 0; h < num_splits; ++h) { + std::string new_scale2_name = new_name_prefix + "." 
+ std::to_string(h) + ".scale2"; + int start_idx = h * scale_per_head; + int end_idx = (h + 1) * scale_per_head; + auto sliced = orig_scale2.slice({{start_idx, end_idx}}, false); + params->push(new_scale2_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_scale2_name)); + } + } + }; + + for (int layer = 0; layer < num_layers; ++layer) { + std::string layer_prefix = "model.layers." + std::to_string(layer) + ".self_attn."; + + // Process Q projection: split into num_heads parts + sliceAndPushConv2DParams(layer_prefix + "q_proj", layer_prefix + "q_proj", num_heads * head_dim, head_dim, num_heads); + + // Process K projection: split into num_kv_heads parts + sliceAndPushConv2DParams(layer_prefix + "k_proj", layer_prefix + "k_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); + + // Process V projection: split into num_kv_heads parts + sliceAndPushConv2DParams(layer_prefix + "v_proj", layer_prefix + "v_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); + + // ======================================================================== + // Duplicate QDQ parameters for each head + // ======================================================================== + // The original MHA uses shared QDQ params for all heads. For SHA, we + // duplicate these to per-head versions using "_h{N}" suffix naming. + // This allows each head to have its own quantization parameters. + + auto copyQDQParams = [&](const std::string& base_name, const std::string& new_base_name, int count) { + std::string scale_name = layer_prefix + base_name + ".fake_quant.scale"; + std::string zp_name = layer_prefix + base_name + ".fake_quant.zero_point"; + + if (params->has(scale_name)) { + auto scale = params->pull(scale_name); + auto zp = params->pull(zp_name); + + for (int h = 0; h < count; ++h) { + std::string new_scale_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.scale"; + std::string new_zp_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.zero_point"; + // QDQ scale/zp are typically scalar or small tensors, clone to ensure contiguous + params->push(new_scale_name, scale.contiguous().setMemType(kParamsNormal).setName(new_scale_name)); + params->push(new_zp_name, zp.contiguous().setMemType(kParamsNormal).setName(new_zp_name)); + } + } + }; + + // Copy QDQ params for Q-related nodes (per Q head) + copyQDQParams("q_proj_output_qdq", "q_proj_output_qdq_h", num_heads); + copyQDQParams("q_rope_mul_0_output_qdq", "q_rope_mul_0_output_qdq_h", num_heads); + copyQDQParams("q_rope_mul_1_output_qdq", "q_rope_mul_1_output_qdq_h", num_heads); + copyQDQParams("q_rope_neg_half_qdq", "q_rope_neg_half_qdq_h", num_heads); + copyQDQParams("q_rope_add_0_output_qdq", "q_rope_add_0_output_qdq_h", num_heads); + + // Copy QDQ params for K-related nodes (per KV head) + copyQDQParams("k_proj_output_qdq", "k_proj_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_mul_0_output_qdq", "k_rope_mul_0_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_mul_1_output_qdq", "k_rope_mul_1_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_neg_half_qdq", "k_rope_neg_half_qdq_h", num_kv_heads); + copyQDQParams("k_rope_add_0_output_qdq", "k_rope_add_0_output_qdq_h", num_kv_heads); + copyQDQParams("k_cast_to_int8_qdq", "k_cast_to_int8_qdq_h", num_kv_heads); + + // Copy QDQ params for V-related nodes (per KV head) + copyQDQParams("v_cast_to_int16_qdq", "v_cast_to_int16_qdq_h", num_kv_heads); + copyQDQParams("v_cast_to_int8_qdq", "v_cast_to_int8_qdq_h", num_kv_heads); + + // Copy QDQ params for 
attention computation (per Q head) + copyQDQParams("qk_matmul_output_qdq", "qk_matmul_output_qdq_h", num_heads); + copyQDQParams("scaling_qdq", "scaling_qdq_h", num_heads); + copyQDQParams("mul_0_output_qdq", "mul_0_output_qdq_h", num_heads); + copyQDQParams("reduce_min_output_qdq", "reduce_min_output_qdq_h", num_heads); + copyQDQParams("neg_20_qdq", "neg_20_qdq_h", num_heads); + copyQDQParams("minus_0_output_qdq", "minus_0_output_qdq_h", num_heads); + copyQDQParams("where_attn_qdq", "where_attn_qdq_h", num_heads); + copyQDQParams("softmax_output_qdq", "softmax_output_qdq_h", num_heads); + copyQDQParams("attn_value_matmul_output_qdq", "attn_value_matmul_output_qdq_h", num_heads); + } +} + +} // namespace mllm::models::qwen2::sha diff --git a/examples/qwen3_qnn_aot/CMakeLists.txt b/examples/qwen3_qnn_aot/CMakeLists.txt index 3c97427e1..6556d0557 100644 --- a/examples/qwen3_qnn_aot/CMakeLists.txt +++ b/examples/qwen3_qnn_aot/CMakeLists.txt @@ -5,7 +5,6 @@ if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) target_include_directories(mllm-qwen3-aot-c PRIVATE ${MLLM_INCLUDE_DIR}) # SHA (Single Head Attention) version - MHA to SHA optimization - # Similar to ExecuTorch's ConvertMhaToSha pass for better QNN AOT performance add_executable(mllm-qwen3-aot-sha-c compile_sha.cpp) target_link_libraries(mllm-qwen3-aot-sha-c PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) target_include_directories(mllm-qwen3-aot-sha-c PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen3_qnn_aot/compile.cpp b/examples/qwen3_qnn_aot/compile.cpp index f47b0dee4..cdf38583c 100644 --- a/examples/qwen3_qnn_aot/compile.cpp +++ b/examples/qwen3_qnn_aot/compile.cpp @@ -20,8 +20,8 @@ MLLM_MAIN({ Argparse::parse(argc, argv); - constexpr int N = 32; - constexpr int CL = 1024; + int N = 32; + int CL = 1024; if (help.isSet()) { Argparse::printHelp(); @@ -46,31 +46,91 @@ MLLM_MAIN({ } model.load(params); - // Sequence: [B, N] - // past_key_i: [B, H, D, CL-N] for each layer i - // past_value_i: [B, H, CL-N, D] for each layer i - // causal_mask: [B, 1, N, CL] - auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); - auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + // Model length 32. 
- // NOTE: force set causal mask to UInt16Asy - // NOTE: Attach scale and zero point to causal mask { - causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); - causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); - causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); - } + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } - // Create KV cache inputs for all layers - std::unordered_map trace_inputs; - trace_inputs["sequence"] = sequence; - trace_inputs["causal_mask"] = causal_mask; + auto ir = model.trace(trace_inputs, {}); - for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { - auto past_key_name = "past_key_" + std::to_string(i); - auto past_value_name = "past_value_" + std::to_string(i); + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); - // clang-format off + mllm::redirect("qwen3_qnn_aot_32.mir", [&]() { mllm::print(ir["model"]); }); + } + + // Model length 1. 
+ { + N = 1; + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off trace_inputs[past_key_name] = mllm::Tensor::empty({ 1, model_cfg.num_key_value_heads, @@ -84,20 +144,17 @@ MLLM_MAIN({ trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); - // clang-format on - } - - auto ir = model.trace(trace_inputs, {}); + // clang-format on + } - // Create Qnn AOT Model - auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", - mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + auto ir = model.trace(trace_inputs, {}); - mllm::ir::PassManager pm(ir["model"]); - pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); - pm.run(); + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); - mllm::redirect("qwen3_qnn_aot.mir", [&]() { mllm::print(ir["model"]); }); + mllm::redirect("qwen3_qnn_aot_1.mir", [&]() { mllm::print(ir["model"]); }); + } qnn_aot_env.saveContext("context.0", "qwen3-1.7B-lpbq.bin"); }); diff --git a/examples/qwen3_qnn_aot/compile_sha.cpp b/examples/qwen3_qnn_aot/compile_sha.cpp index 71b3ceb03..840a5e5ef 100644 --- a/examples/qwen3_qnn_aot/compile_sha.cpp +++ b/examples/qwen3_qnn_aot/compile_sha.cpp @@ -1,10 +1,6 @@ // Copyright (c) MLLM Team. // Licensed under the MIT License. // -// This file demonstrates the MHA -> SHA (Multi-Head Attention to Single-Head Attention) -// optimization for QNN AOT compilation. Similar to ExecuTorch's ConvertMhaToSha pass, -// this approach splits large Q/K/V projections into per-head projections. -// // Benefits: // 1. Reduces QNN AOT compilation time // 2. 
Improves HTP runtime performance @@ -32,8 +28,8 @@ MLLM_MAIN({ Argparse::parse(argc, argv); - constexpr int N = 32; - constexpr int CL = 1024; + int N = 32; + int CL = 1024; if (help.isSet()) { Argparse::printHelp(); @@ -76,31 +72,93 @@ MLLM_MAIN({ } model.load(params); - // Sequence: [B, N] - // past_key_i: [B, H, D, CL-N] for each layer i - // past_value_i: [B, H, CL-N, D] for each layer i - // causal_mask: [B, 1, N, CL] - auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); - auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + // Model length 32. - // NOTE: force set causal mask to UInt16Asy - // NOTE: Attach scale and zero point to causal mask { - causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); - causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); - causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); - } + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); - // Create KV cache inputs for all layers - std::unordered_map trace_inputs; - trace_inputs["sequence"] = sequence; - trace_inputs["causal_mask"] = causal_mask; + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." 
+ std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } - for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { - auto past_key_name = "past_key_" + std::to_string(i); - auto past_value_name = "past_value_" + std::to_string(i); + mllm::print("Tracing SHA model (seq=32)..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("qwen3_qnn_aot_sha_32.mir", [&]() { mllm::print(ir["model"]); }); + } - // clang-format off + // Model length 1. + { + N = 1; + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off trace_inputs[past_key_name] = mllm::Tensor::empty({ 1, model_cfg.num_key_value_heads, @@ -114,27 +172,25 @@ MLLM_MAIN({ trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." 
+ std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); - // clang-format on - } + // clang-format on + } - mllm::print("Tracing SHA model..."); - auto ir = model.trace(trace_inputs, {}); - mllm::print("SHA model traced successfully."); + mllm::print("Tracing SHA model (seq=1)..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); - // Create Qnn AOT Model - auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", - mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); - - mllm::ir::PassManager pm(ir["model"]); - pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); - pm.run(); + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); - mllm::redirect("qwen3_qnn_aot_sha.mir", [&]() { mllm::print(ir["model"]); }); + mllm::redirect("qwen3_qnn_aot_sha_1.mir", [&]() { mllm::print(ir["model"]); }); + } qnn_aot_env.saveContext("context.0", "qwen3-1.7B-lpbq-sha.bin"); mllm::print("SHA compilation completed successfully!"); mllm::print("Output files:"); - mllm::print(" - qwen3_qnn_aot_sha.mir (IR dump)"); + mllm::print(" - qwen3_qnn_aot_sha_32.mir (IR dump for seq=32)"); + mllm::print(" - qwen3_qnn_aot_sha_1.mir (IR dump for seq=1)"); mllm::print(" - qwen3-1.7B-lpbq-sha.bin (QNN context)"); }); diff --git a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp index e7f0b4df7..272535d52 100644 --- a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp +++ b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot_sha.hpp @@ -1,15 +1,13 @@ // Copyright (c) MLLM Team. // Licensed under the MIT License. // -// This file implements the MHA -> SHA (Multi-Head Attention to Single-Head Attention) -// transformation for QNN AOT compilation, similar to ExecuTorch's ConvertMhaToSha pass. -// // The optimization splits large Q/K/V projections into per-head projections, // allowing QNN to optimize each head separately, reducing AOT compilation time // and improving HTP performance. 
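+// For example, a q_proj with num_attention_heads * head_dim output channels is
+// replaced by num_attention_heads independent head_dim-wide projections (see
+// prepareParametersForSHA below, which slices the original weights per head
+// along the output-channel dimension).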
#pragma once +#include "mllm/core/TensorStorage.hpp" #include "mllm/mllm.hpp" #include "mllm/nn/Nn.hpp" #include "mllm/nn/Module.hpp" @@ -288,7 +286,7 @@ class Qwen3AttentionSHA final : public nn::Module { std::vector query_states_per_head; for (int h = 0; h < num_attention_heads_; ++h) { auto q_h = q_projs_[h](hidden_states); - q_h = q_h.view({1, -1, 1, head_dim_}, /*ssa=*/true).transpose(1, 2); + q_h = q_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); query_states_per_head.push_back(q_h); } @@ -296,7 +294,7 @@ class Qwen3AttentionSHA final : public nn::Module { std::vector key_states_per_head; for (int h = 0; h < num_key_value_heads_; ++h) { auto k_h = k_projs_[h](hidden_states); - k_h = k_h.view({1, -1, 1, head_dim_}, /*ssa=*/true).transpose(1, 2); + k_h = k_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); key_states_per_head.push_back(k_h); } @@ -304,7 +302,7 @@ class Qwen3AttentionSHA final : public nn::Module { std::vector value_states_per_head; for (int h = 0; h < num_key_value_heads_; ++h) { auto v_h = v_projs_[h](hidden_states); - v_h = v_h.view({1, -1, 1, head_dim_}, /*ssa=*/true).transpose(1, 2); + v_h = v_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); value_states_per_head.push_back(v_h); } @@ -695,7 +693,7 @@ inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const Qw int end_idx = (h + 1) * out_channels_per_head; // HWIO format: slice on dim 3 (Out_channels) auto sliced = orig_weight.slice({kAll, kAll, kAll, {start_idx, end_idx}}, false); - params->push(new_weight_name, sliced.contiguous()); + params->push(new_weight_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_weight_name)); } } @@ -712,7 +710,7 @@ inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const Qw int start_idx = h * scale_per_head; int end_idx = (h + 1) * scale_per_head; auto sliced = orig_scale1.slice({{start_idx, end_idx}}, false); - params->push(new_scale1_name, sliced.contiguous()); + params->push(new_scale1_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_scale1_name)); } } @@ -728,7 +726,7 @@ inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const Qw int start_idx = h * scale_per_head; int end_idx = (h + 1) * scale_per_head; auto sliced = orig_scale2.slice({{start_idx, end_idx}}, false); - params->push(new_scale2_name, sliced.contiguous()); + params->push(new_scale2_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_scale2_name)); } } }; @@ -745,8 +743,6 @@ inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const Qw // Process V projection: split into num_kv_heads parts sliceAndPushConv2DParams(layer_prefix + "v_proj", layer_prefix + "v_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); - Dbg(); - // Process Q norm params (per head) // RMSNorm has: weight (needs slicing), scale (scalar, copy), zero_point (scalar, copy) { @@ -769,18 +765,18 @@ inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const Qw // Weight: slice per head std::string new_weight_name = layer_prefix + "q_norm." + h_str + ".weight"; // NOLINT auto sliced = orig_weight.slice({{h * head_dim, (h + 1) * head_dim}}, false); - params->push(new_weight_name, sliced.contiguous()); + params->push(new_weight_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_weight_name)); // Scale: copy (scalar) if (has_scale) { std::string new_scale_name = layer_prefix + "q_norm." 
+ h_str + ".scale"; // NOLINT - params->push(new_scale_name, orig_scale.contiguous()); + params->push(new_scale_name, orig_scale.contiguous().setMemType(kParamsNormal).setName(new_scale_name)); } // Zero point: copy (scalar) if (has_zp) { std::string new_zp_name = layer_prefix + "q_norm." + h_str + ".zero_point"; // NOLINT - params->push(new_zp_name, orig_zp.contiguous()); + params->push(new_zp_name, orig_zp.contiguous().setMemType(kParamsNormal).setName(new_zp_name)); } } } @@ -806,18 +802,18 @@ inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const Qw // Weight: slice per head std::string new_weight_name = layer_prefix + "k_norm." + h_str + ".weight"; // NOLINT auto sliced = orig_weight.slice({{h * head_dim, (h + 1) * head_dim}}, false); - params->push(new_weight_name, sliced.contiguous()); + params->push(new_weight_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_weight_name)); // Scale: copy (scalar) if (has_scale) { std::string new_scale_name = layer_prefix + "k_norm." + h_str + ".scale"; // NOLINT - params->push(new_scale_name, orig_scale.contiguous()); + params->push(new_scale_name, orig_scale.contiguous().setMemType(kParamsNormal).setName(new_scale_name)); } // Zero point: copy (scalar) if (has_zp) { std::string new_zp_name = layer_prefix + "k_norm." + h_str + ".zero_point"; // NOLINT - params->push(new_zp_name, orig_zp.contiguous()); + params->push(new_zp_name, orig_zp.contiguous().setMemType(kParamsNormal).setName(new_zp_name)); } } } @@ -842,8 +838,8 @@ inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const Qw std::string new_scale_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.scale"; std::string new_zp_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.zero_point"; // QDQ scale/zp are typically scalar or small tensors, clone to ensure contiguous - params->push(new_scale_name, scale.contiguous()); - params->push(new_zp_name, zp.contiguous()); + params->push(new_scale_name, scale.contiguous().setMemType(kParamsNormal).setName(new_scale_name)); + params->push(new_zp_name, zp.contiguous().setMemType(kParamsNormal).setName(new_zp_name)); } } }; diff --git a/mllm/backends/qnn/QNNBackend.cpp b/mllm/backends/qnn/QNNBackend.cpp index f0cc0a952..1a891cb6b 100644 --- a/mllm/backends/qnn/QNNBackend.cpp +++ b/mllm/backends/qnn/QNNBackend.cpp @@ -98,8 +98,8 @@ QNNPerf::QNNPerf(const QNN_INTERFACE_VER_TYPE* qnnInterface) { .powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_PERFORMANCE_MODE, .setSleepLatency = 1, // True to consider Latency parameter otherwise False .sleepLatency = 40, // set dsp sleep latency ranges 10-65535 micro sec, refer hexagon sdk - .setSleepDisable = 1, // True to consider sleep disable/enable parameter otherwise False - .sleepDisable = 1, // True to disable sleep, False to re-enable sleep + .setSleepDisable = 0, // True to consider sleep disable/enable parameter otherwise False + .sleepDisable = 0, // True to disable sleep, False to re-enable sleep .setBusParams = 1, // True to consider Bus parameter otherwise False .busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_MAX_VOLTAGE_CORNER, .busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_MAX_VOLTAGE_CORNER, diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp index e74781052..a585c17f4 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp @@ -270,10 +270,11 @@ QnnAOTGraph::QnnAOTGraph(QNN_INTERFACE_VER_TYPE& qnnInterface, 
Qnn_BackendHandle // Short Depth Conv On HMX Off QnnHtpGraph_CustomConfig_t* p_custom_config = nullptr; - p_custom_config = (QnnHtpGraph_CustomConfig_t*)malloc(sizeof(QnnHtpGraph_CustomConfig_t)); - p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_SHORT_DEPTH_CONV_ON_HMX_OFF; - p_custom_config->shortDepthConvOnHmxOff = true; - htp_graph_configs.push_back(static_cast(p_custom_config)); + // FIXME: @chenghuaWang The code below will make llm inference slow!!! + // p_custom_config = (QnnHtpGraph_CustomConfig_t*)malloc(sizeof(QnnHtpGraph_CustomConfig_t)); + // p_custom_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_SHORT_DEPTH_CONV_ON_HMX_OFF; + // p_custom_config->shortDepthConvOnHmxOff = true; + // htp_graph_configs.push_back(static_cast(p_custom_config)); // Fold Relu Activation Into Conv Off p_custom_config = (QnnHtpGraph_CustomConfig_t*)malloc(sizeof(QnnHtpGraph_CustomConfig_t)); diff --git a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp index 444abf575..7e2a63220 100644 --- a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp +++ b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp @@ -645,34 +645,37 @@ bool LLMQuantRecipeConcatPattern::isMatch(const mllm::ir::op_ptr_t& op) { } bool LLMQuantRecipeConcatPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& node) { - // Current support concat two Tensor. Inherent first tensor's Quant Spec. + // Support concat with multiple inputs. Inherit first tensor's Quant Spec. auto concat_ir = node->cast_(); - auto i_0 = *(node->inputs().begin()); // t1 - auto i_1 = *(std::next(node->inputs().begin())); // t2 - auto o_0 = *(node->outputs().begin()); // to1 + auto o_0 = *(node->outputs().begin()); // to1 - if (concat_ir->inputs().size() != 2) { - MLLM_WARN("Current support concat two Tensor. 
Inherent first tensor's setting."); + if (concat_ir->inputs().empty()) { + MLLM_WARN("Concat op has no inputs."); return false; } - // Create quant_recipe if not present - if (!i_0->getAttr("quant_recipe")) { - auto i_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_0->cast_()); - i_0->setAttr("quant_recipe", i_0_spec); - } - if (!i_1->getAttr("quant_recipe")) { - auto i_1_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_1->cast_()); - i_1->setAttr("quant_recipe", i_1_spec); + auto i_0 = *(node->inputs().begin()); // First input + + // Create quant_recipe for all inputs if not present + for (auto input : node->inputs()) { + if (!input->getAttr("quant_recipe")) { + auto input_spec = genSimpleQuantizationSpecAttr(writer.getContext(), input->cast_()); + input->setAttr("quant_recipe", input_spec); + } } + // Output inherits first tensor's quant_recipe o_0->setAttr("quant_recipe", i_0->getAttr("quant_recipe")); auto annotation_attr = writer.create(); - annotation_attr->annotation_.inputs.emplace_back( - i_0->getAttr("quant_recipe")->cast_()->spec_); - annotation_attr->annotation_.inputs.emplace_back( - i_1->getAttr("quant_recipe")->cast_()->spec_); + + // Add quant_recipe for all inputs + for (auto input : node->inputs()) { + annotation_attr->annotation_.inputs.emplace_back( + input->getAttr("quant_recipe")->cast_()->spec_); + } + + // Add quant_recipe for output annotation_attr->annotation_.outputs.emplace_back( o_0->getAttr("quant_recipe")->cast_()->spec_); diff --git a/pymllm/backends/qualcomm/transformers/llama/modeling_llama.py b/pymllm/backends/qualcomm/transformers/llama/modeling_llama.py new file mode 100644 index 000000000..119ec04bc --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/llama/modeling_llama.py @@ -0,0 +1,831 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
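+#
+# This module mirrors Hugging Face's Llama modeling code, with nn.Linear,
+# nn.Embedding and RMSNorm replaced by quantization-aware variants
+# (QLinearLPBQ, QEmbedding, QRMSNorm) and QDQ observers wrapped around
+# activations so the model can be calibrated and exported for the QNN AOT flow.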
+from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations import use_kernel_forward_from_hub +from transformers.masking_utils import create_causal_mask +from transformers.modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + logging, +) +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import check_model_inputs +from transformers.models.llama.configuration_llama import LlamaConfig + +# Replace linear, rms_norm with: +from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.backends.qualcomm.transformers.core.qlinear import ( + QLinearLPBQ, +) +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) +from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver + + +logger = logging.get_logger(__name__) + + +class LlamaRMSNorm(QRMSNorm): + def __init__(self, hidden_size, eps=1e-6, quant_bits=16): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__(hidden_size, eps=eps, quant_bits=quant_bits) + + +class LlamaRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: LlamaConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half( + x, x_observer, x2_neg_fake_quant: ActivationQDQ, concat_observer: ConcatObserver +): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1)) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = QLinearLPBQ( + self.hidden_size, + self.intermediate_size, + bias=config.mlp_bias, + block_size=32, + ) + self.up_proj = QLinearLPBQ( + self.hidden_size, + self.intermediate_size, + bias=config.mlp_bias, + block_size=32, + ) + self.down_proj = QLinearLPBQ( + self.intermediate_size, + self.hidden_size, + bias=config.mlp_bias, + block_size=32, + ) + self.act_fn = ACT2FN[config.hidden_act] + + # QDQ + self.up_proj_input_qdq = ActivationQDQ(bits=16) + self.up_proj_output_qdq = ActivationQDQ(bits=16) + self.gate_proj_output_qdq = ActivationQDQ(bits=16) + self.act_output_qdq = ActivationQDQ(bits=16) + self.down_proj_input_qdq = ActivationQDQ(bits=16) + # For sigmoid output: scale = 1 / (q_max - q_min + 1), zp = 0 + # For 16-bit: q_min = 0, q_max = 65535 + sigmoid_scale = 1.0 / (65535 - 0 + 1) # 1 / 65536 + self.sigmoid_output_qdq = FixedActivationQDQ( + scale=sigmoid_scale, zero_point=0, bits=16 + ) + + def forward(self, x): + x = self.up_proj_input_qdq(x) + up_result = self.up_proj_output_qdq(self.up_proj(x)) + gate_result = self.gate_proj_output_qdq(self.gate_proj(x)) + + # SiLU or other activation + gate_result = self.act_output_qdq( + gate_result * self.sigmoid_output_qdq(F.sigmoid(gate_result)) + ) + + o = self.down_proj_input_qdq(gate_result * up_result) + o = self.down_proj(o) + return o + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training + ) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = QLinearLPBQ( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + block_size=32, + ) + self.k_proj = QLinearLPBQ( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + block_size=32, + ) + self.v_proj = QLinearLPBQ( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + block_size=32, + ) + self.o_proj = QLinearLPBQ( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + block_size=32, + ) + + # QDQ + self.q_proj_input_qdq = ActivationQDQ(bits=16) + self.k_proj_input_qdq = ActivationQDQ(bits=16) + self.v_proj_input_qdq = ActivationQDQ(bits=16) + + self.q_proj_output_qdq = ActivationQDQ(bits=16) + self.k_proj_output_qdq = ActivationQDQ(bits=16) + + self.q_rope_mul_0_output_qdq = ActivationQDQ(bits=16) + self.q_rope_mul_1_output_qdq = ActivationQDQ(bits=16) + self.q_rope_add_0_output_qdq = ActivationQDQ(bits=16) + self.k_rope_mul_0_output_qdq = ActivationQDQ(bits=16) + self.k_rope_mul_1_output_qdq = ActivationQDQ(bits=16) + self.k_rope_add_0_output_qdq = ActivationQDQ(bits=16) + + self.q_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.q_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + 
eps=0.0001 / 65535, + is_dynamic=False, + ) + self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer.add_observer( + self.k_proj_input_qdq.fake_quant.activation_post_process + ) + self.k_rope_concat_observer.add_observer( + self.k_rope_neg_half_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_proj_input_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_rope_neg_half_qdq.fake_quant.activation_post_process + ) + + # In qnn, is uint8 sym. + self.k_cast_to_int8_qdq = ActivationQDQ( + bits=8, qscheme=torch.per_tensor_symmetric + ) + self.v_cast_to_int8_qdq = ActivationQDQ( + bits=8, qscheme=torch.per_tensor_symmetric + ) + + self.v_cast_to_int16_qdq = ActivationQDQ(bits=16) + self.qk_matmul_output_qdq = ActivationQDQ(bits=16) + self.scaling_qdq = ActivationQDQ(bits=16) + self.neg_20_qdq = ActivationQDQ(bits=16) + self.reduce_min_output_qdq = ActivationQDQ(bits=16) + self.mul_0_output_qdq = ActivationQDQ(bits=16) + self.minus_0_output_qdq = ActivationQDQ(bits=16) + self.softmax_output_qdq = ActivationQDQ(bits=16) + self.attn_value_matmul_output_qdq = ActivationQDQ(bits=16) + self.where_attn_qdq = ActivationQDQ(bits=16) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + hidden_states = self.q_proj_input_qdq(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states = self.q_proj_output_qdq(query_states) + + hidden_states_k = self.k_proj_input_qdq(hidden_states) + key_states = self.k_proj(hidden_states_k).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj_output_qdq(key_states) + + hidden_states_v = self.v_proj_input_qdq(hidden_states) + value_states = self.v_proj(hidden_states_v).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + query_states = self.q_rope_add_0_output_qdq( + self.q_rope_mul_0_output_qdq(query_states * cos) + + self.q_rope_mul_1_output_qdq( + rotate_half( + query_states, + self.q_proj_input_qdq.fake_quant.activation_post_process, + self.q_rope_neg_half_qdq, + self.q_rope_concat_observer, + ) + * sin + ) + ) + key_states = self.k_rope_add_0_output_qdq( + self.k_rope_mul_0_output_qdq(key_states * cos) + + self.k_rope_mul_1_output_qdq( + rotate_half( + key_states, + self.k_proj_input_qdq.fake_quant.activation_post_process, + self.k_rope_neg_half_qdq, + self.k_rope_concat_observer, + ) + * sin + ) + ) + + key_states = self.k_cast_to_int8_qdq(key_states) + value_states = self.v_cast_to_int8_qdq(self.v_cast_to_int16_qdq(value_states)) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + 
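+ # NOTE: masked positions are filled with (row-min - 20) selected by torch.where
+ # instead of adding a large negative bias, and the 1/sqrt(head_dim) scaling is
+ # itself passed through a QDQ node, so every intermediate stays in a calibrated
+ # 16-bit activation range.
+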
attn_weights = self.mul_0_output_qdq( + self.qk_matmul_output_qdq( + torch.matmul(query_states, key_states.transpose(2, 3)) + ) + * self.scaling_qdq( + torch.ones(1, dtype=value_states.dtype, device=value_states.device) + * self.scaling + ) + ) + + attn_min = self.reduce_min_output_qdq( + torch.amin(attn_weights, dim=-1, keepdim=True) + ) + attn_vv = self.minus_0_output_qdq( + attn_min + + self.neg_20_qdq( + torch.ones(1, dtype=value_states.dtype, device=value_states.device) + * (-20) + ) + ) + attn_weights = self.where_attn_qdq( + torch.where(attention_mask == 0, attn_weights, attn_vv) + ) + + attn_weights = self.softmax_output_qdq( + nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + ) + attn_output = self.attn_value_matmul_output_qdq( + torch.matmul(attn_weights, value_states) + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class LlamaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, quant_bits=16 + ) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, quant_bits=16 + ) + + # QDQ + if self.layer_idx != 0: + self.input_layernorm_input_qdq = ActivationQDQ(bits=16) + self.add_0_lhs_input_qdq = ActivationQDQ(bits=16) + self.add_0_output_qdq = ActivationQDQ(bits=16) + self.add_1_lhs_input_qdq = ActivationQDQ(bits=16) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + if self.layer_idx != 0: + hidden_states = self.input_layernorm_input_qdq(hidden_states) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = self.add_0_output_qdq( + residual + self.add_0_lhs_input_qdq(hidden_states) + ) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.add_1_lhs_input_qdq(hidden_states) + return hidden_states + + +@auto_docstring +class LlamaPreTrainedModel(PreTrainedModel): + config: LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + 
_supports_attention_backend = True + _can_record_outputs = { + "hidden_states": LlamaDecoderLayer, + "attentions": LlamaAttention, + } + + +@auto_docstring +class LlamaModel(LlamaPreTrainedModel): + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = QEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, quant_bits=16 + ) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, quant_bits=16 + ) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Register sin and cos as buffers + self.register_buffer("mllm_max_sin_embedding", None) + self.register_buffer("mllm_max_cos_embedding", None) + self.sin_embedding_input_qdq = ActivationQDQ(bits=16) + self.cos_embedding_input_qdq = ActivationQDQ(bits=16) + self.norm_input_qdq = ActivationQDQ(bits=16) + + # Initialize weights and apply final processing + self.post_init() + + @torch.no_grad() + def convert_rope_for_deploy(self): + sin_scale = self.sin_embedding_input_qdq.fake_quant.scale + sin_zero_point = self.sin_embedding_input_qdq.fake_quant.zero_point + sin_quant_min = self.sin_embedding_input_qdq.fake_quant.quant_min + sin_quant_max = self.sin_embedding_input_qdq.fake_quant.quant_max + + cos_scale = self.cos_embedding_input_qdq.fake_quant.scale + cos_zero_point = self.cos_embedding_input_qdq.fake_quant.zero_point + cos_quant_min = self.cos_embedding_input_qdq.fake_quant.quant_min + cos_quant_max = self.cos_embedding_input_qdq.fake_quant.quant_max + + sin_int = torch.round( + self.mllm_max_sin_embedding / sin_scale + sin_zero_point + ).clamp(sin_quant_min, sin_quant_max) + self.mllm_max_sin_embedding = sin_int.to(torch.uint16) + + cos_int = torch.round( + self.mllm_max_cos_embedding / cos_scale + cos_zero_point + ).clamp(cos_quant_min, cos_quant_max) + self.mllm_max_cos_embedding = cos_int.to(torch.uint16) + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + if 
self.mllm_max_sin_embedding is None and self.mllm_max_cos_embedding is None: + mllm_qualcomm_max_length = kwargs.get("mllm_qualcomm_max_length", None) + assert mllm_qualcomm_max_length is not None + max_position_ids = torch.arange( + 0, + mllm_qualcomm_max_length, + dtype=position_ids.dtype, + device=position_ids.device, + ).unsqueeze(0) + self.mllm_max_cos_embedding, self.mllm_max_sin_embedding = self.rotary_emb( + hidden_states, max_position_ids + ) + self.mllm_max_cos_embedding = self.mllm_max_cos_embedding.to( + inputs_embeds.dtype + ) + self.mllm_max_sin_embedding = self.mllm_max_sin_embedding.to( + inputs_embeds.dtype + ) + self.mllm_max_cos_embedding = self.cos_embedding_input_qdq( + self.mllm_max_cos_embedding + ) + self.mllm_max_sin_embedding = self.sin_embedding_input_qdq( + self.mllm_max_sin_embedding + ) + + # create position embeddings to be shared across the decoder layers + position_embeddings = ( + self.mllm_max_cos_embedding[:, position_ids.squeeze(0), :], + self.mllm_max_sin_embedding[:, position_ids.squeeze(0), :], + ) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(self.norm_input_qdq(hidden_states)) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = QLinearLPBQ( + config.hidden_size, config.vocab_size, bias=False, block_size=32 + ) + self.mllm_qualcomm_max_length = None + + self.lm_head_input_qdq = ActivationQDQ(bits=16) + self.lm_head_output_qdq = ActivationQDQ(bits=16) + + # Initialize weights and apply final processing + self.post_init() + + @torch.no_grad() + def copy_lm_head_weight_from_embed_tokens(self): + if self.config.tie_word_embeddings: + self.lm_head.weight.copy_(self.model.embed_tokens.weight) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? 
Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + kwargs.update({"mllm_qualcomm_max_length": self.mllm_qualcomm_max_length}) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head( + self.lm_head_input_qdq(hidden_states[:, slice_indices, :]) + ) + logits = self.lm_head_output_qdq(logits) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class LlamaForSequenceClassification( + GenericForSequenceClassification, LlamaPreTrainedModel +): ... + + +class LlamaForQuestionAnswering(GenericForQuestionAnswering, LlamaPreTrainedModel): + base_model_prefix = ( + "transformer" # For BC, where `transformer` was used instead of `model` + ) + + +class LlamaForTokenClassification( + GenericForTokenClassification, LlamaPreTrainedModel +): ... + + +__all__ = [ + "LlamaForCausalLM", + "LlamaModel", + "LlamaPreTrainedModel", + "LlamaForSequenceClassification", + "LlamaForQuestionAnswering", + "LlamaForTokenClassification", +] diff --git a/pymllm/backends/qualcomm/transformers/llama/runner.py b/pymllm/backends/qualcomm/transformers/llama/runner.py new file mode 100644 index 000000000..8aa4627bf --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/llama/runner.py @@ -0,0 +1,345 @@ +import torch +from tqdm import tqdm +from modelscope.msdatasets import MsDataset +from transformers import AutoTokenizer +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) +from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.backends.qualcomm.transformers.core.qlinear import ( + QLinearLPBQ, + QLinearW8A16_PerChannelSym, +) +from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.backends.qualcomm.transformers.llama.modeling_llama import LlamaForCausalLM +from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver + + +def recompute_scale_zp(module): + """ + Callback function: Used to forcefully refresh scale and zero_point of all FakeQuantize modules after calibration. + + Problem solved: + When using ConcatObserver, min/max may be updated during forward pass, + but at the end of forward, the scale/zp stored in FakeQuantize's internal buffer are still computed from old min/max. + This function forces a calculate_qparams call to sync the latest parameters to the buffer. + + Usage: + model.apply(recompute_scale_zp) + """ + + # We mainly focus on FakeQuantize modules since they store the scale/zero_point buffers + # Note: model.apply recursively traverses all submodules, so self.fake_quant inside ActivationQDQ will also be visited + if isinstance(module, ActivationQDQ): + observer = module.fake_quant.activation_post_process + + # 2. 
Check if observer is valid and contains statistics + # We only care about MinMaxObserver or MovingAverageMinMaxObserver that have min_val/max_val + if hasattr(observer, "min_val") and hasattr(observer, "max_val"): + # 3. Check if data is initialized + # If min_val is still the initial inf, this layer hasn't processed data, skip to avoid errors + if observer.min_val.numel() == 0 or observer.max_val.numel() == 0: + return + if ( + torch.isinf(observer.min_val).any() + or torch.isinf(observer.max_val).any() + ): + return + + # 4. Recompute Scale and Zero Point + # calculate_qparams reads the current min_val/max_val from observer (may have been modified by ConcatObserver) + try: + scale, zero_point = observer.calculate_qparams() + except Exception as e: + # Some special Observers (e.g., FixedQParams) may not support recomputation or behave differently, safely skip + print(e) + return + + # 5. Force overwrite the computed results to FakeQuantize's Buffer + # Use copy_ to keep reference unchanged, ensuring the new values are used during export + if ( + hasattr(module.fake_quant, "scale") + and module.fake_quant.scale is not None + ): + # Ensure dimension match (handle per-channel vs per-tensor) + if module.fake_quant.scale.shape != scale.shape: + module.fake_quant.scale.resize_(scale.shape) + module.fake_quant.scale.copy_(scale) + # Try to get the registered name of module scale from _parameters or _buffers + for key, value in module.fake_quant.named_parameters(): + if value is module.fake_quant.scale: + print(f"{module._get_name()}.{key}: {module.scale}") + break + + if ( + hasattr(module.fake_quant, "zero_point") + and module.fake_quant.zero_point is not None + ): + if module.fake_quant.zero_point.shape != zero_point.shape: + module.fake_quant.zero_point.resize_(zero_point.shape) + module.fake_quant.zero_point.copy_(zero_point) + + +def validate_concat_observer_fn(module, results: list, name: str = ""): + """ + Callback function: Validate that all input_observers in ConcatObserver have consistent scale and zero_point. 
+ + Usage: + results = [] + for name, m in model.named_modules(): + validate_concat_observer_fn(m, results, name) + """ + if not isinstance(module, ConcatObserver): + return + + input_observers = module.input_observers + if len(input_observers) == 0: + return + + # Collect scale and zero_point from all observers + scales_zps = [] + for i, observer in enumerate(input_observers): + try: + scale, zp = observer.calculate_qparams() + scales_zps.append(f"[{i}] s={scale.item():.8f} zp={zp.item()}") + except Exception: + scales_zps.append(f"[{i}] failed") + + # Print one line: scale and zp of all inputs for each concat observer + print(f"ConcatObserver [{name}]: {' | '.join(scales_zps)}") + + # Original validation logic + if len(input_observers) <= 1: + return + + # Get scale and zero_point from the first observer as reference + first_observer = input_observers[0] + try: + ref_scale, ref_zp = first_observer.calculate_qparams() + except Exception: + return + + # Check if all other observers have the same scale and zero_point + for i, observer in enumerate(input_observers[1:], start=1): + try: + scale, zp = observer.calculate_qparams() + except Exception: + results.append(f"Failed to calculate qparams for observer[{i}]") + continue + + scale_match = torch.allclose(ref_scale, scale, rtol=1e-5, atol=1e-8) + zp_match = torch.equal(ref_zp, zp) + + if not scale_match or not zp_match: + results.append( + f"observer[{i}] mismatch: ref_scale={ref_scale.item():.8f}, " + f"scale={scale.item():.8f}, ref_zp={ref_zp.item()}, zp={zp.item()}" + ) + + +def freeze_llama_rmsnorm_weight(m): + if isinstance(m, QRMSNorm): + m.freeze_weight() + + +def freeze_llama_linear_weight(m): + if isinstance(m, QLinearLPBQ) or isinstance(m, QLinearW8A16_PerChannelSym): + m.freeze_weight() + + +def freeze_llama_embed_tokens_weight(m): + if isinstance(m, QEmbedding): + m.freeze_weight() + + +def disable_qdq_observer(m): + if isinstance(m, ActivationQDQ): + m.disable_observer() + + +def enable_qdq_observer(m): + if isinstance(m, ActivationQDQ): + m.enable_observer() + + +def enable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.enable_fakequant() + if isinstance(m, QLinearLPBQ): + m.enable_fakequant() + if isinstance(m, QRMSNorm): + m.enable_fakequant() + if isinstance(m, QEmbedding): + m.enable_fakequant() + + +def disable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.disable_fakequant() + if isinstance(m, QLinearLPBQ): + m.disable_fakequant() + if isinstance(m, QRMSNorm): + m.disable_fakequant() + if isinstance(m, QEmbedding): + m.disable_fakequant() + + +def convert_weight(m): + if isinstance(m, QLinearLPBQ) or isinstance(m, QLinearW8A16_PerChannelSym): + m.convert_to_conv2d_deploy_hwio() + if isinstance(m, QRMSNorm): + m.convert_to_deploy() + if isinstance(m, QEmbedding): + m.convert_to_deploy() + + +class LlamaQuantizer: + def __init__(self, model_path: str, mllm_qualcomm_max_length=2048): + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.model = LlamaForCausalLM.from_pretrained( + model_path, + attn_implementation="eager", + dtype=torch.float32, + ) + self.model.cuda() + self.mllm_qualcomm_max_length = mllm_qualcomm_max_length + self.model.mllm_qualcomm_max_length = mllm_qualcomm_max_length + + if self.model.config.tie_word_embeddings: + self.model.copy_lm_head_weight_from_embed_tokens() + + # PTQ All Weights. 
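+ # Weight quantization parameters (LPBQ linears, RMSNorm, embeddings) are
+ # frozen up front; activation ranges are collected later in calibrate().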
+ self.model.apply(freeze_llama_rmsnorm_weight) + self.model.apply(freeze_llama_linear_weight) + self.model.apply(freeze_llama_embed_tokens_weight) + print("All PTQ weights preparation done.") + + def freeze_activation(self): + self.model.apply(disable_qdq_observer) + + def enable_activation_update(self): + self.model.apply(enable_qdq_observer) + + def enable_fake_quant(self): + self.model.apply(enable_fake_quant) + + def disable_fake_quant(self): + self.model.apply(disable_fake_quant) + + def compile(self): + print("Compile Start.") + self.model = torch.compile( + self.model, mode="reduce-overhead", fullgraph=False, backend="inductor" + ) + print("Compile done.") + + def infer(self, prompt: str): + # Llama models typically don't use chat templates, so we tokenize directly + model_inputs = self.tokenizer([prompt], return_tensors="pt").to( + self.model.device + ) + + # conduct text completion + generated_ids = self.model.generate( + **model_inputs, + max_new_tokens=self.mllm_qualcomm_max_length + - len(model_inputs.input_ids[0]) + - 1, + do_sample=False, + temperature=None, + top_p=None, + top_k=None, + ) + output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist() + content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip( + "\n" + ) + + print("content:", content) + + def calibrate(self, num_samples=64, max_seq_length=512): + """ + Perform calibration using Wikipedia dataset (PTQ) + :param num_samples: Number of samples for calibration + :param max_seq_length: Maximum length for each sample (not exceeding mllm_qualcomm_max_length) + """ + print( + f"Starting calibration, samples: {num_samples}, max length: {max_seq_length}" + ) + + # 1. Enable QDQ Observer for activation values + self.enable_activation_update() + self.model.eval() + + # 2. Load Wikipedia dataset (English version example) + # Use streaming=True to download and process on the fly, without downloading the full dataset (tens of GB) + dataset = MsDataset.load( + "modelscope/wikitext", + subset_name="wikitext-103-v1", + split="train", + trust_remote_code=True, + ) + + # 3. Execute forward pass (Prefill stage) + samples_processed = 0 + + # Ensure no gradient calculation during inference + with torch.no_grad(): + pbar = tqdm(total=num_samples, desc="Calibrating") + for entry in dataset: + if samples_processed >= num_samples: + break + + if len(entry["text"].strip()) < 1024: + continue + + # Llama models typically don't use chat templates + text = entry["text"] + model_inputs = self.tokenizer( + [text], + return_tensors="pt", + max_length=max_seq_length, + truncation=True, + padding=False, + ).to(self.model.device) + + # Only need Prefill stage: directly call forward + # This will trigger observer update statistics in ActivationQDQ + self.model.generate( + **model_inputs, + max_new_tokens=1, + do_sample=False, + temperature=None, + top_p=None, + top_k=None, + ) + + samples_processed += 1 + pbar.update(1) + + # 4.
Close Observer, freeze calibrated quantization parameters + self.freeze_activation() + print("\nCalibration completed, activation quantization parameters frozen.") + + def convert(self): + self.model.apply(convert_weight) + self.model.model.convert_rope_for_deploy() + + def recompute_scale_zp(self): + self.model.apply(recompute_scale_zp) + + def validate_concat_observer(self): + results = [] + for name, module in self.model.named_modules(): + validate_concat_observer_fn(module, results, name) + if results: + print("ConcatObserver validation FAILED:") + for msg in results: + print(f" {msg}") + raise ValueError("ConcatObserver validation FAILED") + else: + print( + "ConcatObserver validation PASSED: all observers have matching scale and zp" + ) + print("ConcatObserver validation done.", flush=True) diff --git a/pymllm/backends/qualcomm/transformers/llama/train.py b/pymllm/backends/qualcomm/transformers/llama/train.py new file mode 100644 index 000000000..cd10befba --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/llama/train.py @@ -0,0 +1,56 @@ +import os +import torch +import argparse +from safetensors.torch import save_model +from pymllm.backends.qualcomm.transformers.llama.runner import LlamaQuantizer + + +def main(): + parser = argparse.ArgumentParser(description="Llama Quantizer for Qualcomm backend") + parser.add_argument( + "--model_path", + type=str, + default="", + help="Path to the Llama model directory", + ) + parser.add_argument( + "--max_length", + type=int, + default=2048, + help="Maximum sequence length for quantization", + ) + parser.add_argument( + "--num_samples", type=int, default=128, help="Number of samples for calibration" + ) + parser.add_argument( + "--infer_text", + type=str, + default="为什么伟大不能被计划", + help="Text to run inference on", + ) + parser.add_argument( + "--output_dir", + type=str, + help="Directory to save the quantized model", + ) + + args = parser.parse_args() + + m = LlamaQuantizer(args.model_path, mllm_qualcomm_max_length=args.max_length) + + # FIXME: Should disable or not. 
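+ # Current flow: calibrate with fake-quant disabled so observers see float
+ # activation ranges, then re-enable fake-quant and refresh scale/zero_point
+ # before validation, inference and export.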
+ m.disable_fake_quant() + m.calibrate(num_samples=args.num_samples, max_seq_length=args.max_length) + m.enable_fake_quant() + m.recompute_scale_zp() + m.validate_concat_observer() + m.infer(args.infer_text) + m.convert() + + os.makedirs(args.output_dir, exist_ok=True) + model_save_path = os.path.join(args.output_dir, "model.safetensors") + save_model(m.model, model_save_path) + + +if __name__ == "__main__": + main() diff --git a/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py b/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py index 1e32266a5..56b19c421 100644 --- a/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py +++ b/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py @@ -50,13 +50,13 @@ def __init__(self, config): self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = QLinearLPBQ( - self.hidden_size, self.intermediate_size, bias=False, block_size=16 + self.hidden_size, self.intermediate_size, bias=False, block_size=32 ) self.up_proj = QLinearLPBQ( - self.hidden_size, self.intermediate_size, bias=False, block_size=16 + self.hidden_size, self.intermediate_size, bias=False, block_size=32 ) self.down_proj = QLinearLPBQ( - self.intermediate_size, self.hidden_size, bias=False, block_size=16 + self.intermediate_size, self.hidden_size, bias=False, block_size=32 ) # QDQ @@ -158,25 +158,25 @@ def __init__(self, config: Qwen2Config, layer_idx: int): config.hidden_size, config.num_attention_heads * self.head_dim, bias=attention_bias, - block_size=16, + block_size=32, ) self.k_proj = QLinearLPBQ( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=attention_bias, - block_size=16, + block_size=32, ) self.v_proj = QLinearLPBQ( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=attention_bias, - block_size=16, + block_size=32, ) self.o_proj = QLinearLPBQ( config.num_attention_heads * self.head_dim, config.hidden_size, bias=False, - block_size=16, + block_size=32, ) self.sliding_window = ( config.sliding_window @@ -688,7 +688,7 @@ def __init__(self, config): self.model = Qwen2Model(config) self.vocab_size = config.vocab_size self.lm_head = QLinearLPBQ( - config.hidden_size, config.vocab_size, bias=False, block_size=16 + config.hidden_size, config.vocab_size, bias=False, block_size=32 ) self.mllm_qualcomm_max_length = None From dae8eb60a723759030bcaebbdb7ca0ca6fc32876 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 29 Jan 2026 02:28:49 +0000 Subject: [PATCH 4/8] feat(llama_qnn_aot): Add Llama AOT support with SHA optimization. Introduce new executables for model compilation and runtime, enhance CMake configurations, and implement weight slicing utilities for improved performance. Include configuration files and update input handling for tensor dimensions in AOT run. 
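As background for the "weight slicing utilities" mentioned above: the SHA build splits each fused Q/K/V projection into one small projection per head, so a weight of shape [num_heads * head_dim, hidden_size] becomes num_heads weights of shape [head_dim, hidden_size]. The snippet below is only an illustrative sketch of that row-wise split over a plain row-major buffer; the function name and types are hypothetical and are not the actual prepareParametersForSHA API.

#include <cstddef>
#include <vector>

// Hypothetical helper: copy contiguous row ranges of a fused [num_heads * head_dim, hidden]
// weight into per-head [head_dim, hidden] slices.
std::vector<std::vector<float>> sliceFusedProjection(const std::vector<float>& fused,
                                                     std::size_t num_heads,
                                                     std::size_t head_dim,
                                                     std::size_t hidden) {
  std::vector<std::vector<float>> per_head(num_heads);
  for (std::size_t h = 0; h < num_heads; ++h) {
    const std::size_t begin = h * head_dim * hidden;  // offset of head h's first row
    per_head[h].assign(fused.begin() + begin, fused.begin() + begin + head_dim * hidden);
  }
  return per_head;
}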
--- examples/CMakeLists.txt | 1 + examples/llama_qnn_aot/CMakeLists.txt | 14 + examples/llama_qnn_aot/aot_run.cpp | 64 ++ examples/llama_qnn_aot/compile.cpp | 159 ++++ examples/llama_qnn_aot/compile_sha.cpp | 196 +++++ examples/llama_qnn_aot/config_3B.json | 28 + .../llama_qnn_aot/configuration_llama3.hpp | 97 +++ .../llama_qnn_aot/modeling_llama_qnn_aot.hpp | 508 +++++++++++ .../modeling_llama_qnn_aot_sha.hpp | 788 ++++++++++++++++++ examples/llama_qnn_aot/qnn_aot_cfg_3B.json | 6 + examples/qwen2_qnn_aot/aot_run.cpp | 7 +- 11 files changed, 1866 insertions(+), 2 deletions(-) create mode 100644 examples/llama_qnn_aot/CMakeLists.txt create mode 100644 examples/llama_qnn_aot/aot_run.cpp create mode 100644 examples/llama_qnn_aot/compile.cpp create mode 100644 examples/llama_qnn_aot/compile_sha.cpp create mode 100644 examples/llama_qnn_aot/config_3B.json create mode 100644 examples/llama_qnn_aot/configuration_llama3.hpp create mode 100644 examples/llama_qnn_aot/modeling_llama_qnn_aot.hpp create mode 100644 examples/llama_qnn_aot/modeling_llama_qnn_aot_sha.hpp create mode 100644 examples/llama_qnn_aot/qnn_aot_cfg_3B.json diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index ea02dd3c6..31bd8e1b1 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -20,4 +20,5 @@ endif() if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE OR MLLM_BUILD_QNN_BACKEND) add_subdirectory(qwen3_qnn_aot) add_subdirectory(qwen2_qnn_aot) + add_subdirectory(llama_qnn_aot) endif() diff --git a/examples/llama_qnn_aot/CMakeLists.txt b/examples/llama_qnn_aot/CMakeLists.txt new file mode 100644 index 000000000..029d8d1e7 --- /dev/null +++ b/examples/llama_qnn_aot/CMakeLists.txt @@ -0,0 +1,14 @@ +# AOT targets run on x86 +if(MLLM_QUALCOMM_QNN_AOT_ON_X86_ENABLE) + add_executable(mllm-llama-aot-c compile.cpp) + target_link_libraries(mllm-llama-aot-c PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) + target_include_directories(mllm-llama-aot-c PRIVATE ${MLLM_INCLUDE_DIR}) + + add_executable(mllm-llama-aot-c-sha compile_sha.cpp) + target_link_libraries(mllm-llama-aot-c-sha PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) + target_include_directories(mllm-llama-aot-c-sha PRIVATE ${MLLM_INCLUDE_DIR}) +endif() + +add_executable(mllm-llama-aot-runner aot_run.cpp) +target_link_libraries(mllm-llama-aot-runner PRIVATE MllmRT MllmCPUBackend MllmQNNBackend) +target_include_directories(mllm-llama-aot-runner PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/llama_qnn_aot/aot_run.cpp b/examples/llama_qnn_aot/aot_run.cpp new file mode 100644 index 000000000..b54b7ba48 --- /dev/null +++ b/examples/llama_qnn_aot/aot_run.cpp @@ -0,0 +1,64 @@ +#include +#include +#include +#include +#include "mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp" +#include "configuration_llama3.hpp" +#include "mllm/models/qwen3/tokenization_qwen3.hpp" + +using mllm::Argparse; +using namespace mllm::qnn::aot; // NOLINT + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model").help("Model path").def("llama_qnn.mllm"); + auto& tokenizer_path = Argparse::add("-t|--tokenizer").help("Tokenizer path").def("tokenizer.json"); + auto& config_path = Argparse::add("-c|--config").help("Config path").required(true); + auto& ar_len = Argparse::add("--ar_len").help("Autoregressive length (chunk size)").def(128); + auto& seq_len = Argparse::add("--seq_len").help("Input sequence length").def(800); + auto& gen_len = Argparse::add("--gen_len").help("Generate token length").def(32); + + Argparse::parse(argc, 
argv); + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + mllm::initQnnBackend(model_path.get()); + + auto llama_cfg = mllm::models::llama3::Llama3Config(config_path.get()); + + RunnerConfig config; + config.num_layers = llama_cfg.num_hidden_layers; + config.num_heads = llama_cfg.num_attention_heads; + config.head_dim = llama_cfg.head_dim; + config.vocab_size = llama_cfg.vocab_size; + config.context_len = 1024; + config.ar_len = ar_len.get(); + + // Note: Using Qwen3 tokenizer as a placeholder. + // For production use, you should implement a Llama3Tokenizer or use + // the appropriate tokenizer for your model. + auto tokenizer = mllm::models::qwen3::Qwen3Tokenizer(tokenizer_path.get()); + + auto input_tensor = tokenizer.convertMessage({.prompt = "hello"}); + + input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1}); + + // DBG: + mllm::print(input_tensor["sequence"].shape()); + mllm::print(input_tensor["sequence"]); + + Runner runner(config, &tokenizer); + if (!runner.load()) { + std::cerr << "Failed to load model\n"; + return 1; + } + + runner.generate( + input_tensor["sequence"], gen_len.get(), [](const std::string& token) { std::cout << token << std::flush; }, true); + std::cout << "\n"; + + return 0; +}); diff --git a/examples/llama_qnn_aot/compile.cpp b/examples/llama_qnn_aot/compile.cpp new file mode 100644 index 000000000..e8260484b --- /dev/null +++ b/examples/llama_qnn_aot/compile.cpp @@ -0,0 +1,159 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +#include "modeling_llama_qnn_aot.hpp" + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); + auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); + auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + + Argparse::parse(argc, argv); + + int N = 32; + int CL = 1024; + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + if (!qnn_aot_cfg_files.isSet()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided"); + Argparse::printHelp(); + return -1; + } + + auto model_cfg = mllm::models::llama3::Llama3Config(model_cfg_path.get()); + auto model = mllm::models::llama3::LlamaForCausalLM(model_cfg); + auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); + // Add params for causal mask + { + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + } + model.load(params); + + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + // Model length 32. 
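// With N = 32 and CL = 1024 as set above, each per-layer KV-cache input below spans
// CL - N = 992 previously cached positions. The scale / zero_point tensors attached to the
// causal mask and to every past_key_i / past_value_i follow the usual affine convention,
// real_value = scale * (quantized_value - zero_point), so the traced graph carries full
// quantization annotations for its inputs.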
+ + { + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + auto ir = model.trace(trace_inputs, {}); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("llama_qnn_aot_32.mir", [&]() { mllm::print(ir["model"]); }); + } + + // Model length 1. 
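// The second trace below builds the single-token decode graph: N = 1, so the causal mask
// shrinks to [1, 1, 1, 1024] and each per-layer KV-cache input grows to CL - N = 1023
// positions. Both graphs are later saved into the same QNN context (see saveContext below).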
+ { + N = 1; + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + auto ir = model.trace(trace_inputs, {}); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("llama_qnn_aot_1.mir", [&]() { mllm::print(ir["model"]); }); + } + + qnn_aot_env.saveContext("context.0", "llama-lpbq.bin"); +}); diff --git a/examples/llama_qnn_aot/compile_sha.cpp b/examples/llama_qnn_aot/compile_sha.cpp new file mode 100644 index 000000000..96acd8076 --- /dev/null +++ b/examples/llama_qnn_aot/compile_sha.cpp @@ -0,0 +1,196 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// Benefits: +// 1. Reduces QNN AOT compilation time +// 2. Improves HTP runtime performance +// 3. 
Enables better memory locality per head +// +// Usage: +// ./compile_sha -m /path/to/model.mllm -c /path/to/config.json -aot_cfg /path/to/qnn_aot_cfg.json + +#include +#include +#include +#include +#include +#include + +#include "modeling_llama_qnn_aot_sha.hpp" + +using mllm::Argparse; + +MLLM_MAIN({ + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model file path."); + auto& model_cfg_path = Argparse::add("-c|--config").help("Model config file path."); + auto& qnn_aot_cfg_files = Argparse::add("-aot_cfg|--aot_config").help("AOT Config file path."); + + Argparse::parse(argc, argv); + + int N = 32; + int CL = 1024; + + if (help.isSet()) { + Argparse::printHelp(); + return 0; + } + + if (!qnn_aot_cfg_files.isSet()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No input aot config file path provided"); + Argparse::printHelp(); + return -1; + } + + auto model_cfg = mllm::models::llama3::Llama3Config(model_cfg_path.get()); + + // Load original parameters + auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); + + // ============================================================================ + // Key Step: Prepare SHA parameters by slicing MHA weights + // ============================================================================ + // This is the critical step that transforms MHA weights into SHA weights. + // For each Q/K/V projection, we slice the weight matrix into per-head pieces. + // + // Original: q_proj.weight [num_heads * head_dim, hidden_size, 1, 1] + // SHA: q_proj.{h}.weight [head_dim, hidden_size, 1, 1] for each head h + // + mllm::print("Preparing SHA parameters (slicing MHA weights)..."); + mllm::models::llama3::sha::prepareParametersForSHA(params, model_cfg); + mllm::print("SHA parameters prepared."); + + // Create SHA model + auto model = mllm::models::llama3::sha::LlamaForCausalLM_SHA(model_cfg); + + // Add params for causal mask + { + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + } + model.load(params); + + // Create Qnn AOT Model + auto qnn_aot_env = mllm::qnn::aot::QnnAOTEnv("/opt/qcom/aistack/qairt/2.41.0.251128/lib/x86_64-linux-clang/", + mllm::qnn::aot::parseQcomTargetMachineFromJSONFile(qnn_aot_cfg_files.get())); + + // Model length 32. 
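// Even though the SHA model uses one Conv2D per attention head, the traced KV-cache inputs
// below keep the fused per-layer layout ([1, num_kv_heads, head_dim, CL-N] for keys,
// [1, num_kv_heads, CL-N, head_dim] for values); LlamaAttentionSHA slices the per-head views
// out of these tensors in its forward pass, so only the Q/K/V projection weights are split
// per head by prepareParametersForSHA above.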
+ + { + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + mllm::print("Tracing SHA model (seq=32)..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("llama_qnn_aot_sha_32.mir", [&]() { mllm::print(ir["model"]); }); + } + + // Model length 1. 
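// As in the fused-head build, the decode graph is traced with N = 1 below using the same
// per-layer scale / zero_point parameters; both SHA graphs are then bundled into
// llama-lpbq-sha.bin via saveContext at the end of this file.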
+ { + N = 1; + + // Sequence: [B, N] + // past_key_i: [B, H, D, CL-N] for each layer i + // past_value_i: [B, H, CL-N, D] for each layer i + // causal_mask: [B, 1, N, CL] + auto sequence = mllm::Tensor::zeros({1, N}, mllm::kInt32); + auto causal_mask = mllm::Tensor::zeros({1, 1, N, CL}, mllm::kUInt16); + + // NOTE: force set causal mask to UInt16Asy + // NOTE: Attach scale and zero point to causal mask + { + causal_mask = causal_mask.__unsafeSetDType(mllm::kUInt16PerTensorAsy); + causal_mask.attach("scale", params->pull("causal_mask.scale").impl(), true); + causal_mask.attach("zero_point", params->pull("causal_mask.zero_point").impl(), true); + } + + // Create KV cache inputs for all layers + std::unordered_map trace_inputs; + trace_inputs["sequence"] = sequence; + trace_inputs["causal_mask"] = causal_mask; + for (int i = 0; i < model_cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + auto past_value_name = "past_value_" + std::to_string(i); + + // clang-format off + trace_inputs[past_key_name] = mllm::Tensor::empty({ + 1, + model_cfg.num_key_value_heads, + model_cfg.head_dim, + CL - N, + }, mllm::kUInt8PerTensorSym); + trace_inputs[past_value_name] = mllm::Tensor::empty({1, model_cfg.num_key_value_heads, CL - N, model_cfg.head_dim}, mllm::kUInt8PerTensorSym); + + trace_inputs[past_key_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_key_name].attach("zero_point", params->pull("model.layers." + std::to_string(i) + ".self_attn.k_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + + trace_inputs[past_value_name].attach("scale", params->pull("model.layers." + std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.scale").impl(), true); + trace_inputs[past_value_name].attach("zero_point", params->pull("model.layers." 
+ std::to_string(i) + ".self_attn.v_cast_to_int8_qdq.fake_quant.zero_point").impl(), true); + // clang-format on + } + + mllm::print("Tracing SHA model (seq=1)..."); + auto ir = model.trace(trace_inputs, {}); + mllm::print("SHA model traced successfully."); + + mllm::ir::PassManager pm(ir["model"]); + pm.reg(mllm::qnn::aot::createQnnAOTLoweringPipeline(&qnn_aot_env, qnn_aot_cfg_files.get(), params)); + pm.run(); + + mllm::redirect("llama_qnn_aot_sha_1.mir", [&]() { mllm::print(ir["model"]); }); + } + + qnn_aot_env.saveContext("context.0", "llama-lpbq-sha.bin"); + + mllm::print("SHA compilation completed successfully!"); + mllm::print("Output files:"); + mllm::print(" - llama_qnn_aot_sha_32.mir (IR dump for seq=32)"); + mllm::print(" - llama_qnn_aot_sha_1.mir (IR dump for seq=1)"); + mllm::print(" - llama-lpbq-sha.bin (QNN context)"); +}); diff --git a/examples/llama_qnn_aot/config_3B.json b/examples/llama_qnn_aot/config_3B.json new file mode 100644 index 000000000..ef7a3e94e --- /dev/null +++ b/examples/llama_qnn_aot/config_3B.json @@ -0,0 +1,28 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128009, + "attention_bias": false, + "hidden_act": "silu", + "hidden_size": 3072, + "head_dim": 128, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 131072, + "model_type": "llama", + "num_attention_heads": 24, + "num_hidden_layers": 28, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 500000.0, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "transformers_version": "4.45.0", + "use_cache": true, + "vocab_size": 128256, + "max_cache_length": 2048, + "linear_impl_type": "QNN_LPBQ_w4a16o16_G32" +} diff --git a/examples/llama_qnn_aot/configuration_llama3.hpp b/examples/llama_qnn_aot/configuration_llama3.hpp new file mode 100644 index 000000000..16375ff6c --- /dev/null +++ b/examples/llama_qnn_aot/configuration_llama3.hpp @@ -0,0 +1,97 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include "mllm/core/aops/LinearOp.hpp" +#include "mllm/engine/ConfigFile.hpp" + +namespace mllm::models::llama3 { + +/** + * @brief Configuration for Llama 3.x models used in QNN AOT compilation. + * + * This configuration is designed to support Llama 3.2 3B Instruct and similar models. 
+ * It includes all necessary fields for QNN AOT compilation, including: + * - head_dim: Dimension of each attention head + * - max_cache_length: Maximum KV cache length + * - end_of_text_token_id: Token ID for end of text + */ +struct Llama3Config : protected ConfigFile { + Llama3Config() = default; + + explicit Llama3Config(const std::string& file_path) : ConfigFile(file_path) { + // Init all + vocab_size = data()["vocab_size"]; + hidden_size = data()["hidden_size"]; + intermediate_size = data()["intermediate_size"]; + num_hidden_layers = data()["num_hidden_layers"]; + num_attention_heads = data()["num_attention_heads"]; + num_key_value_heads = data()["num_key_value_heads"]; + hidden_act = data()["hidden_act"]; + max_position_embeddings = data()["max_position_embeddings"]; + rms_norm_eps = data()["rms_norm_eps"]; + rope_theta = data()["rope_theta"]; + attention_bias = data()["attention_bias"]; + + // Handle head_dim - compute from hidden_size/num_attention_heads if not provided + if (data().contains("head_dim")) { + head_dim = data()["head_dim"]; + } else { + head_dim = hidden_size / num_attention_heads; + } + + // Handle default values for optional parameters + if (num_key_value_heads == 0) { num_key_value_heads = num_attention_heads; } + + // Token IDs + bos_token_id = data()["bos_token_id"]; + eos_token_id = data()["eos_token_id"]; + + // End of text token - use eos_token_id as default + if (data().contains("end_of_text_token_id")) { + end_of_text_token_id = data()["end_of_text_token_id"]; + } else { + end_of_text_token_id = eos_token_id; + } + + tie_word_embeddings = data()["tie_word_embeddings"]; + max_cache_length = data()["max_cache_length"]; + + linear_impl_type = aops::str2LinearImplTypes(data()["linear_impl_type"]); + } + + // Model architecture parameters + int32_t vocab_size = 128256; // Llama 3.2 vocabulary size + int32_t hidden_size = 3072; // Llama 3.2 3B hidden size + int32_t head_dim = 128; // Head dimension + int32_t intermediate_size = 8192; // FFN intermediate size + int32_t num_hidden_layers = 28; // Number of transformer layers + int32_t num_attention_heads = 24; // Number of attention heads + int32_t num_key_value_heads = 8; // Number of KV heads (GQA) + std::string hidden_act = "silu"; // Activation function + int32_t max_position_embeddings = 131072; // Max sequence length + + // Normalization and RoPE + float rms_norm_eps = 1e-5; // RMSNorm epsilon + float rope_theta = 500000.0; // RoPE base frequency + + // Attention + bool attention_bias = false; // Whether to use bias in attention + + // Token IDs + int64_t bos_token_id = 128000; // Begin of sequence token + int64_t eos_token_id = 128009; // End of sequence token + int32_t end_of_text_token_id = 128009; // End of text token for generation + + // Word embedding + bool tie_word_embeddings = true; // Tie input/output embeddings + + // Cache configuration + int32_t max_cache_length = 2048; // Maximum KV cache length + + // Linear implementation type for quantization + aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; +}; + +} // namespace mllm::models::llama3 diff --git a/examples/llama_qnn_aot/modeling_llama_qnn_aot.hpp b/examples/llama_qnn_aot/modeling_llama_qnn_aot.hpp new file mode 100644 index 000000000..a129cd3bd --- /dev/null +++ b/examples/llama_qnn_aot/modeling_llama_qnn_aot.hpp @@ -0,0 +1,508 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
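//
// Fused-head (MHA) Llama modeling used by compile.cpp. Linear projections are expressed as
// 1x1 Conv2D ops (filter layout [1, 1, In, Out], activations [N, 1, Seq, In]), and the ptq::
// helpers below attach scale / zero_point tensors pulled from the parameter file so every
// traced activation carries its quantization parameters.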
+ +#pragma once + +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/compile/ir/Trace.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "configuration_llama3.hpp" + +namespace mllm::models::llama3 { + +namespace ptq { + +Tensor QDQ_CONSTANT(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = qdq_name_in_pytorch + ".scale"; + std::string zp_name = qdq_name_in_pytorch + ".zero_point"; + switch (in.dtype()) { + case kFloat32: + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + return in; +} + +Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + std::string zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + if (m->getModuleName().empty()) { + scale_name = qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = qdq_name_in_pytorch + ".fake_quant.zero_point"; + } else { + scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + } + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + // For Constant! + case kFloat32: { + MLLM_RT_ASSERT_EQ(in.rank(), 1); + MLLM_RT_ASSERT_EQ(in.size(-1), 1); + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_KV(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + // The inputs is int8 sym. which means zero_point should be changed. + switch (in.dtype()) { + case kUInt8PerTensorSym: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + MLLM_RT_ASSERT_EQ(zp.item(), 0); + + // Is 128! not 127! + auto new_zp = Tensor::constant(128, kInt32).setName(zp_name).setMemType(kParamsNormal); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", new_zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." 
+ qdq_name_in_pytorch + ".fake_quant.zero_point"; + + (void)in.__unsafeSetDType(kUInt16PerTensorAsy); + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +} // namespace ptq + +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + +using vi32 = std::vector; +#define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32 + +// Using Conv2D to replace Linear. +// Conv2D Filter Weight is [1, 1, In, Out] +// Conv2D Activation is [N, H=1, W=Seq, In] + +class LlamaMLP final : public nn::Module { + nn::Conv2D gate_proj_; + nn::Conv2D up_proj_; + nn::Conv2D down_proj_; + nn::SiLU silu_; + int hidden_size_; + int intermediate_size_; + + public: + LlamaMLP() = default; + LlamaMLP(const std::string& name, const Llama3Config& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, CONV2D_PROPERTY); + hidden_size_ = cfg.hidden_size; + intermediate_size_ = cfg.intermediate_size; + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = ptq::QDQ(this, x, "up_proj_input_qdq"); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto up_result = ptq::QDQ(this, up_proj_(x), "up_proj_output_qdq").view({1, -1, intermediate_size_}, true); + auto gate_result = ptq::QDQ(this, gate_proj_(x), "gate_proj_output_qdq").view({1, -1, intermediate_size_}, true); + + // SiLU + gate_result = ptq::QDQ(this, (gate_result * ptq::QDQ(this, nn::functional::sigmoid(gate_result), "sigmoid_output_qdq")), + "act_output_qdq"); + + auto o = ptq::QDQ(this, gate_result * up_result, "down_proj_input_qdq"); + o = o.view({1, 1, -1, intermediate_size_}, true); + o = down_proj_(o).view({1, -1, hidden_size_}, true); + + return {o}; + } +}; + +class LlamaAttention final : public nn::Module { + nn::Conv2D q_proj_; + nn::Conv2D k_proj_; + nn::Conv2D v_proj_; + nn::Conv2D o_proj_; + // Llama does NOT have RMSNorm after QK projection (unlike Qwen3) + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + float scale_; + + public: + LlamaAttention() = default; + + LlamaAttention(const std::string& name, const Llama3Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + scale_ = (1.f / sqrtf((float)head_dim_)); + + // clang-format off + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, 
CONV2D_PROPERTY); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, CONV2D_PROPERTY); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, CONV2D_PROPERTY); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, CONV2D_PROPERTY); + // clang-format on + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + // [B, S, D] + hidden_states = ptq::QDQ(this, hidden_states, "q_proj_input_qdq"); + hidden_states = hidden_states.view({1, 1, -1, hidden_size_}, true); + + // [B, S, H * D] + auto query_states = q_proj_(hidden_states); + auto key_states = k_proj_(hidden_states); + auto value_states = v_proj_(hidden_states); + + query_states = ptq::QDQ(this, query_states, "q_proj_output_qdq"); + key_states = ptq::QDQ(this, key_states, "k_proj_output_qdq"); + + // [B, H, S, D] + query_states = query_states.view({1, -1, num_attention_heads_, head_dim_}, /*ssa=*/true).transpose(1, 2); + key_states = key_states.view({1, -1, num_key_value_heads_, head_dim_}, /*ssa=*/true).transpose(1, 2); + value_states = value_states.view({1, -1, num_key_value_heads_, head_dim_}, /*ssa=*/true).transpose(1, 2); + + // Llama does NOT have RMSNorm here (unlike Qwen3) + // Directly apply RoPE + + // [B, H, S, D] + auto cos = llm_embedding_cos.unsqueeze(1, true); + auto sin = llm_embedding_sin.unsqueeze(1, true); + query_states = + ptq::QDQ(this, + ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(query_states, this, "q_rope_neg_half_qdq") * sin, "q_rope_mul_1_output_qdq"), + "q_rope_add_0_output_qdq"); + key_states = + ptq::QDQ(this, + ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(key_states, this, "k_rope_neg_half_qdq") * sin, "k_rope_mul_1_output_qdq"), + "k_rope_add_0_output_qdq"); + + // De-quantization and quantization again + key_states = key_states.to(kFloat32); + key_states = key_states.to(kUInt8PerTensorSym); + key_states = ptq::QDQ_KV(this, key_states, "k_cast_to_int8_qdq"); + + // [B, H, D, S] + key_states = key_states.transpose(2, 3); + + // Handle KV Cache + value_states = ptq::QDQ(this, value_states, "v_cast_to_int16_qdq"); + value_states = value_states.to(kFloat32); + value_states = value_states.to(kUInt8PerTensorSym); + value_states = ptq::QDQ_KV(this, value_states, "v_cast_to_int8_qdq"); + + auto kh = nn::functional::concat({past_key, key_states}, -1); // [B, H, D, S] + auto vh = nn::functional::concat({past_value, value_states}, 2); // [B, H, S, D] + + // Repeat + kh = kh.repeat(num_key_value_groups_, 1); + vh = vh.repeat(num_key_value_groups_, 1); + + // Attn + auto attn = ptq::QDQ(this, nn::functional::matmul(query_states, kh), "qk_matmul_output_qdq"); + auto scale = Tensor::constant(scale_, kFloat32); + scale = ptq::QDQ(this, scale, "scaling_qdq"); + attn = ptq::QDQ(this, attn.mulConstant(scale), "mul_0_output_qdq"); + + // Masked Softmax + auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq"); + auto minus_value = Tensor::constant(-20, kFloat32); + minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq"); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq"); + auto zero_constant = Tensor::constant(0.f, kFloat32); 
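// Masking trick used below: rather than adding a large negative bias, masked scores are
// replaced with (row-wise min - 20), which vanishes after softmax and, presumably, keeps the
// pre-softmax range bounded for uint16 quantization. Assuming where(cond, a, b) keeps `a`
// where the condition holds, positions whose causal_mask entry equals the quantized zero
// constant retain their raw attention score.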
+ zero_constant = ptq::QDQ_CONSTANT(this, zero_constant, "constant_zero"); + attn = nn::functional::where(causal_mask.equalConstant(zero_constant), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq"); + attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq"); + auto y = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq"); + y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); + y = o_proj_(y).view({1, -1, hidden_size_}, true); + + return {y, key_states, value_states}; + } + + int layer_idx_; +}; + +class LlamaDecoder final : public nn::Module { + public: + int layer_idx_; + LlamaAttention self_attn_; + LlamaMLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + LlamaDecoder() = default; + + LlamaDecoder(const std::string& name, const Llama3Config& cfg, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + auto hidden_states = inputs[0]; + if (layer_idx_ != 0) { hidden_states = ptq::QDQ(this, hidden_states, "input_layernorm_input_qdq"); } + auto residual = hidden_states; + hidden_states = input_layer_norm_(hidden_states); + auto _ = self_attn_(hidden_states, llm_embedding_sin, llm_embedding_cos, causal_mask, past_key, past_value); + hidden_states = _[0]; + hidden_states = ptq::QDQ(this, residual + ptq::QDQ(this, hidden_states, "add_0_lhs_input_qdq"), "add_0_output_qdq"); + residual = hidden_states; + hidden_states = post_attention_layer_norm_(hidden_states); + hidden_states = mlp_(hidden_states)[0]; + hidden_states = residual + ptq::QDQ(this, hidden_states, "add_1_lhs_input_qdq"); + return {hidden_states, _[1], _[2]}; + } +}; + +class LlamaText final : public nn::Module { + nn::ModuleListWithIdx decode_blocks_; + nn::RMSNorm norm_; + nn::Embedding embedding_; + nn::Param rope_sin_; + nn::Param rope_cos_; + int32_t num_hidden_layers_; + int32_t hidden_size_; + + public: + LlamaText() = default; + + LlamaText(const std::string& name, const Llama3Config& cfg) : nn::Module(name) { + num_hidden_layers_ = cfg.num_hidden_layers; + hidden_size_ = cfg.hidden_size; + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + rope_sin_ = reg("mllm_max_sin_embedding", "model.mllm_max_sin_embedding"); + rope_cos_ = reg("mllm_max_cos_embedding", "model.mllm_max_cos_embedding"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + const auto& position_ids = inputs[1]; + auto causal_mask = inputs[2]; + + // clang-format off + auto llm_embedding_sin = nn::functional::gather(ptq::QDQ_ROPE(this, rope_sin_(), "sin_embedding_input_qdq"), 1, position_ids); + auto llm_embedding_cos = nn::functional::gather(ptq::QDQ_ROPE(this, rope_cos_(), "cos_embedding_input_qdq"), 
1, position_ids); + // clang-format on + + std::vector keys; + std::vector values; + for (auto [index, block] : enumerate(blocks)) { + auto pk = inputs[3 + index]; + auto pv = inputs[3 + index + num_hidden_layers_]; + auto _ = block(x, llm_embedding_sin, llm_embedding_cos, causal_mask, pk, pv); + x = _[0]; + keys.push_back(_[1]); + values.push_back(_[2]); + } + + x = norm_(ptq::QDQ(this, x, "norm_input_qdq")); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto ret = std::vector{x}; + for (const auto& item : keys) { ret.push_back(item); } + for (const auto& item : values) { ret.push_back(item); } + + return ret; + } +}; + +class LlamaForCausalLM : public ARGeneration, public nn::Module { + public: + explicit LlamaForCausalLM(const Llama3Config& cfg) : cfg(cfg) { + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY); + } + } + + IROutput trace(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + // Things we need to return + ir::IRContext::ptr_t llm_ir = nullptr; + + auto sequence = input.at("sequence"); + auto causal_mask = input.at("causal_mask"); + + std::vector kv_caches; + + // Append Key + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + if (input.count(past_key_name)) { + kv_caches.push_back(input.at(past_key_name)); + } else { + // If KV cache doesn't exist, we need to handle this case + // For now, we'll create empty tensors or handle it appropriately + // This might need adjustment based on your initialization logic + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Append Value + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_value_name = "past_value_" + std::to_string(i); + if (input.count(past_value_name)) { + kv_caches.push_back(input.at(past_value_name)); + } else { + // If KV cache doesn't exist, we need to handle this case + // For now, we'll create empty tensors or handle it appropriately + // This might need adjustment based on your initialization logic + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({1}, kInt32, kCPU).alloc(); + *position_ids.offsettedPtr({0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({seq_len}, kInt32, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[s] = s; } + } + + ir::lowlevel::traceStart(); + + // Build inputs for llm: sequence, llm_embedding_sin, llm_embedding_cos, causal_mask, then all KV caches + std::vector llm_inputs = {sequence, position_ids, causal_mask}; + llm_inputs.insert(llm_inputs.end(), kv_caches.begin(), kv_caches.end()); + + sequence = 
llm(llm_inputs)[0]; + sequence = lm_head_(ptq::QDQ(this, sequence, "lm_head_input_qdq")); + sequence = ptq::QDQ(this, sequence, "lm_head_output_qdq"); + ir::lowlevel::traceComment(" ╔═════╗ "); + ir::lowlevel::traceComment(" ║ o o ║ "); + ir::lowlevel::traceComment(" ║ ▽ ║ "); + ir::lowlevel::traceComment(" ╚═════╝ "); + ir::lowlevel::traceComment(" ║ ║ "); + ir::lowlevel::traceComment(" ╱╩╦╦╩╲ "); + llm_ir = ir::lowlevel::traceStop(); + + return {{"model", llm_ir}}; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } + + private: + const Llama3Config& cfg; + LlamaText llm; + nn::Conv2D lm_head_; + bool tie_word_embeddings_; +}; + +} // namespace mllm::models::llama3 diff --git a/examples/llama_qnn_aot/modeling_llama_qnn_aot_sha.hpp b/examples/llama_qnn_aot/modeling_llama_qnn_aot_sha.hpp new file mode 100644 index 000000000..a26ebef1e --- /dev/null +++ b/examples/llama_qnn_aot/modeling_llama_qnn_aot_sha.hpp @@ -0,0 +1,788 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +// +// The optimization splits large Q/K/V projections into per-head projections, +// allowing QNN to optimize each head separately, reducing AOT compilation time +// and improving HTP performance. + +#pragma once + +#include "mllm/core/TensorStorage.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/core/DataTypes.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/compile/ir/Trace.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "configuration_llama3.hpp" + +namespace mllm::models::llama3::sha { + +namespace ptq { + +Tensor QDQ_CONSTANT(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = qdq_name_in_pytorch + ".scale"; + std::string zp_name = qdq_name_in_pytorch + ".zero_point"; + switch (in.dtype()) { + case kFloat32: + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + return in; +} + +Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + std::string scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + std::string zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + if (m->getModuleName().empty()) { + scale_name = qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = qdq_name_in_pytorch + ".fake_quant.zero_point"; + } else { + scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + } + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + // For Constant! 
+ case kFloat32: { + MLLM_RT_ASSERT_EQ(in.rank(), 1); + MLLM_RT_ASSERT_EQ(in.size(-1), 1); + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_KV(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + // The inputs is int8 sym. which means zero_point should be changed. + switch (in.dtype()) { + case kUInt8PerTensorSym: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + MLLM_RT_ASSERT_EQ(zp.item(), 0); + + // Is 128! not 127! + auto new_zp = Tensor::constant(128, kInt32).setName(zp_name).setMemType(kParamsNormal); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", new_zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { + auto scale_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.scale"; + auto zp_name = m->getModuleName() + "." + qdq_name_in_pytorch + ".fake_quant.zero_point"; + + (void)in.__unsafeSetDType(kUInt16PerTensorAsy); + + switch (in.dtype()) { + case kUInt16PerTensorAsy: { + auto scale = m->getTopParameterFile()->pull(scale_name); + auto zp = m->getTopParameterFile()->pull(zp_name); + in.attach("scale", scale.impl(), true); + in.attach("zero_point", zp.impl(), true); + break; + } + default: { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Can't Process dtype={}", nameOfType(in.dtype())); + } + } + + return in; +} + +} // namespace ptq + +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + +using vi32 = std::vector; +#define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32 + +// Using Conv2D to replace Linear. 
+// Conv2D Filter Weight is [1, 1, In, Out] +// Conv2D Activation is [N, H=1, W=Seq, In] + +class LlamaMLP final : public nn::Module { + nn::Conv2D gate_proj_; + nn::Conv2D up_proj_; + nn::Conv2D down_proj_; + nn::SiLU silu_; + int hidden_size_; + int intermediate_size_; + + public: + LlamaMLP() = default; + LlamaMLP(const std::string& name, const Llama3Config& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, CONV2D_PROPERTY); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, CONV2D_PROPERTY); + hidden_size_ = cfg.hidden_size; + intermediate_size_ = cfg.intermediate_size; + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = ptq::QDQ(this, x, "up_proj_input_qdq"); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto up_result = ptq::QDQ(this, up_proj_(x), "up_proj_output_qdq").view({1, -1, intermediate_size_}, true); + auto gate_result = ptq::QDQ(this, gate_proj_(x), "gate_proj_output_qdq").view({1, -1, intermediate_size_}, true); + + // SiLU + gate_result = ptq::QDQ(this, (gate_result * ptq::QDQ(this, nn::functional::sigmoid(gate_result), "sigmoid_output_qdq")), + "act_output_qdq"); + + auto o = ptq::QDQ(this, gate_result * up_result, "down_proj_input_qdq"); + o = o.view({1, 1, -1, intermediate_size_}, true); + o = down_proj_(o).view({1, -1, hidden_size_}, true); + + return {o}; + } +}; + +// ============================================================================ +// Single Head Attention (SHA) Implementation +// ============================================================================ +// +// This class implements SHA where each attention head has its own separate +// Conv2D projection, instead of one large MHA projection that processes all +// heads at once. +// +// Benefits: +// 1. Reduces QNN AOT compilation time +// 2. Improves HTP runtime performance +// 3. Enables better memory locality per head +// +// Note: Llama does NOT have RMSNorm after Q/K projection (unlike Qwen3) + +class LlamaAttentionSHA final : public nn::Module { + // Per-head Q projections: num_attention_heads Conv2D(hidden_size, head_dim) + std::vector q_projs_; + // Per-head K projections: num_key_value_heads Conv2D(hidden_size, head_dim) + std::vector k_projs_; + // Per-head V projections: num_key_value_heads Conv2D(hidden_size, head_dim) + std::vector v_projs_; + // Single O projection remains unchanged (concatenated heads -> hidden_size) + nn::Conv2D o_proj_; + + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + float scale_; + + public: + LlamaAttentionSHA() = default; + + LlamaAttentionSHA(const std::string& name, const Llama3Config& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + scale_ = (1.f / sqrtf((float)head_dim_)); + + // Register per-head Q projections + for (int h = 0; h < num_attention_heads_; ++h) { + q_projs_.emplace_back(reg("q_proj." 
+ std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // Register per-head K projections + for (int h = 0; h < num_key_value_heads_; ++h) { + k_projs_.emplace_back(reg("k_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // Register per-head V projections + for (int h = 0; h < num_key_value_heads_; ++h) { + v_projs_.emplace_back(reg("v_proj." + std::to_string(h), hidden_size_, head_dim_, CONV2D_PROPERTY)); + } + + // O projection remains the same (combines all heads) + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, CONV2D_PROPERTY); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + const auto& past_key = inputs[4]; // [B, num_kv_heads, D, S] + const auto& past_value = inputs[5]; // [B, num_kv_heads, S, D] + + // [B, S, D] - shared QDQ for input to all Q/K/V projections + hidden_states = ptq::QDQ(this, hidden_states, "q_proj_input_qdq"); + hidden_states = hidden_states.view({1, 1, -1, hidden_size_}, true); + + // ======================================================================== + // Per-head Q/K/V Projections + // ======================================================================== + // This is the key SHA optimization: instead of one large projection for all + // heads, we have separate smaller projections per head. + + // Compute per-head Q projections: each outputs [1, 1, S, head_dim] + std::vector query_states_per_head; + for (int h = 0; h < num_attention_heads_; ++h) { + auto q_h = q_projs_[h](hidden_states); + q_h = q_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + query_states_per_head.push_back(q_h); + } + + // Compute per-head K projections + std::vector key_states_per_head; + for (int h = 0; h < num_key_value_heads_; ++h) { + auto k_h = k_projs_[h](hidden_states); + k_h = k_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + key_states_per_head.push_back(k_h); + } + + // Compute per-head V projections + std::vector value_states_per_head; + for (int h = 0; h < num_key_value_heads_; ++h) { + auto v_h = v_projs_[h](hidden_states); + v_h = v_h.view({1, 1, -1, head_dim_}, /*ssa=*/true); + value_states_per_head.push_back(v_h); + } + + // ======================================================================== + // Reshape and Transpose for RoPE + // ======================================================================== + // Llama does NOT have RMSNorm here (unlike Qwen3) + // Directly apply RoPE after reshaping to [B, H, S, D] format + // Each head tensor is [1, 1, S, head_dim], need to reshape to [1, 1, S, head_dim] for RoPE + // (The shape is already correct, but we need to ensure QDQ is applied) + + auto cos = llm_embedding_cos.unsqueeze(1, true); + auto sin = llm_embedding_sin.unsqueeze(1, true); + + // Apply QDQ and RoPE per Q head + // Each query_states_per_head[h] is [1, 1, S, head_dim] + for (int h = 0; h < num_attention_heads_; ++h) { + std::string h_str = std::to_string(h); + query_states_per_head[h] = ptq::QDQ(this, query_states_per_head[h], "q_proj_output_qdq_h" + h_str); + // Reshape to [1, 1, S, head_dim] for RoPE (already correct shape) + query_states_per_head[h] = + ptq::QDQ(this, + ptq::QDQ(this, query_states_per_head[h] * cos, "q_rope_mul_0_output_qdq_h" + h_str) + + ptq::QDQ(this, rotateHalf(query_states_per_head[h], this, "q_rope_neg_half_qdq_h" + 
h_str) * sin, + "q_rope_mul_1_output_qdq_h" + h_str), + "q_rope_add_0_output_qdq_h" + h_str); + } + + // Apply QDQ and RoPE per K head + // Each key_states_per_head[h] is [1, 1, S, head_dim] + for (int h = 0; h < num_key_value_heads_; ++h) { + std::string h_str = std::to_string(h); + key_states_per_head[h] = ptq::QDQ(this, key_states_per_head[h], "k_proj_output_qdq_h" + h_str); + // Reshape to [1, 1, S, head_dim] for RoPE (already correct shape) + key_states_per_head[h] = + ptq::QDQ(this, + ptq::QDQ(this, key_states_per_head[h] * cos, "k_rope_mul_0_output_qdq_h" + h_str) + + ptq::QDQ(this, rotateHalf(key_states_per_head[h], this, "k_rope_neg_half_qdq_h" + h_str) * sin, + "k_rope_mul_1_output_qdq_h" + h_str), + "k_rope_add_0_output_qdq_h" + h_str); + } + + // ======================================================================== + // KV Cache Processing per head + // ======================================================================== + + std::vector new_key_per_head; + std::vector new_value_per_head; + std::vector key_cache_per_head; + std::vector value_cache_per_head; + + for (int h = 0; h < num_key_value_heads_; ++h) { + std::string h_str = std::to_string(h); + + // K: De-quantize and re-quantize to int8 + auto k_h = key_states_per_head[h].to(kFloat32); + k_h = k_h.to(kUInt8PerTensorSym); + k_h = ptq::QDQ_KV(this, k_h, "k_cast_to_int8_qdq_h" + h_str); + k_h = k_h.transpose(2, 3); // [B, 1, D, S] + + // V: Quantize to int16 then int8 + auto v_h = ptq::QDQ(this, value_states_per_head[h], "v_cast_to_int16_qdq_h" + h_str); + v_h = v_h.to(kFloat32); + v_h = v_h.to(kUInt8PerTensorSym); + v_h = ptq::QDQ_KV(this, v_h, "v_cast_to_int8_qdq_h" + h_str); + + new_key_per_head.push_back(k_h); + new_value_per_head.push_back(v_h); + + // Slice past cache for this head + auto past_k_h = past_key.slice({kAll, {h, h + 1}, kAll, kAll}, true); + auto past_v_h = past_value.slice({kAll, {h, h + 1}, kAll, kAll}, true); + + // Concat current with past + key_cache_per_head.push_back(nn::functional::concat({past_k_h, k_h}, -1)); + value_cache_per_head.push_back(nn::functional::concat({past_v_h, v_h}, 2)); + } + + // ======================================================================== + // Per-head Attention Computation + // ======================================================================== + // Each Q head computes attention with its corresponding KV head (GQA support) + // For GQA, multiple Q heads share the same KV head + + std::vector attn_outputs; + + for (int h = 0; h < num_attention_heads_; ++h) { + std::string h_str = std::to_string(h); + int kv_head_idx = h / num_key_value_groups_; + + const auto& q_h = query_states_per_head[h]; + const auto& kh = key_cache_per_head[kv_head_idx]; + const auto& vh = value_cache_per_head[kv_head_idx]; + + // QK^T + auto attn = ptq::QDQ(this, nn::functional::matmul(q_h, kh), "qk_matmul_output_qdq_h" + h_str); + + // Scale + auto scale = Tensor::constant(scale_, kFloat32); + scale = ptq::QDQ(this, scale, "scaling_qdq_h" + h_str); + attn = ptq::QDQ(this, attn.mulConstant(scale), "mul_0_output_qdq_h" + h_str); + + // Masked Softmax + auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq_h" + h_str); + auto minus_value = Tensor::constant(-20, kFloat32); + minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq_h" + h_str); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq_h" + h_str); + auto zero_constant = Tensor::constant(0.f, kFloat32); + zero_constant = ptq::QDQ_CONSTANT(this, zero_constant, 
"constant_zero"); + attn = nn::functional::where(causal_mask.equalConstant(zero_constant), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq_h" + h_str); + attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq_h" + h_str); + + // Output: attn @ V + auto y_h = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq_h" + h_str); + attn_outputs.push_back(y_h); + } + + // ======================================================================== + // Concatenate and Output Projection + // ======================================================================== + + // Concat all head outputs: [B, num_heads, S, D] + auto y = nn::functional::concat(attn_outputs, 1); + + // Reshape and apply O projection + y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); + y = o_proj_(y).view({1, -1, hidden_size_}, true); + + // Concat new keys and values back to original format + auto new_key = nn::functional::concat(new_key_per_head, 1); + auto new_value = nn::functional::concat(new_value_per_head, 1); + + return {y, new_key, new_value}; + } + + int layer_idx_; +}; + +class LlamaDecoderSHA final : public nn::Module { + public: + int layer_idx_; + LlamaAttentionSHA self_attn_; + LlamaMLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + LlamaDecoderSHA() = default; + + LlamaDecoderSHA(const std::string& name, const Llama3Config& cfg, int layer_idx) : nn::Module(name) { + layer_idx_ = layer_idx; + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto causal_mask = inputs[3]; + auto past_key = inputs[4]; + auto past_value = inputs[5]; + + auto hidden_states = inputs[0]; + if (layer_idx_ != 0) { hidden_states = ptq::QDQ(this, hidden_states, "input_layernorm_input_qdq"); } + auto residual = hidden_states; + hidden_states = input_layer_norm_(hidden_states); + auto _ = self_attn_(hidden_states, llm_embedding_sin, llm_embedding_cos, causal_mask, past_key, past_value); + hidden_states = _[0]; + hidden_states = ptq::QDQ(this, residual + ptq::QDQ(this, hidden_states, "add_0_lhs_input_qdq"), "add_0_output_qdq"); + residual = hidden_states; + hidden_states = post_attention_layer_norm_(hidden_states); + hidden_states = mlp_(hidden_states)[0]; + hidden_states = residual + ptq::QDQ(this, hidden_states, "add_1_lhs_input_qdq"); + return {hidden_states, _[1], _[2]}; + } +}; + +class LlamaTextSHA final : public nn::Module { + nn::ModuleListWithIdx decode_blocks_; + nn::RMSNorm norm_; + nn::Embedding embedding_; + nn::Param rope_sin_; + nn::Param rope_cos_; + int32_t num_hidden_layers_; + int32_t hidden_size_; + + public: + LlamaTextSHA() = default; + + LlamaTextSHA(const std::string& name, const Llama3Config& cfg) : nn::Module(name) { + num_hidden_layers_ = cfg.num_hidden_layers; + hidden_size_ = cfg.hidden_size; + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + rope_sin_ = reg("mllm_max_sin_embedding", "model.mllm_max_sin_embedding"); + rope_cos_ = 
reg("mllm_max_cos_embedding", "model.mllm_max_cos_embedding"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + + // X is already embedded + auto x = embedding_(inputs[0]); + + const auto& position_ids = inputs[1]; + auto causal_mask = inputs[2]; + + // clang-format off + auto llm_embedding_sin = nn::functional::gather(ptq::QDQ_ROPE(this, rope_sin_(), "sin_embedding_input_qdq"), 1, position_ids); + auto llm_embedding_cos = nn::functional::gather(ptq::QDQ_ROPE(this, rope_cos_(), "cos_embedding_input_qdq"), 1, position_ids); + // clang-format on + + std::vector keys; + std::vector values; + for (auto [index, block] : enumerate(blocks)) { + auto pk = inputs[3 + index]; + auto pv = inputs[3 + index + num_hidden_layers_]; + auto _ = block(x, llm_embedding_sin, llm_embedding_cos, causal_mask, pk, pv); + x = _[0]; + keys.push_back(_[1]); + values.push_back(_[2]); + } + + x = norm_(ptq::QDQ(this, x, "norm_input_qdq")); + x = x.view({1, 1, -1, hidden_size_}, true); + + auto ret = std::vector{x}; + for (const auto& item : keys) { ret.push_back(item); } + for (const auto& item : values) { ret.push_back(item); } + + return ret; + } +}; + +class LlamaForCausalLM_SHA : public ARGeneration, public nn::Module { + public: + explicit LlamaForCausalLM_SHA(const Llama3Config& cfg) : cfg(cfg) { + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + + llm = reg("model", cfg); + + if (cfg.tie_word_embeddings) { + // NOTE: + // model.lm_head.weight is quantization weights of model.embed_tokens.weight + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, CONV2D_PROPERTY); + } + } + + IROutput trace(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + // Things we need to return + ir::IRContext::ptr_t llm_ir = nullptr; + + auto sequence = input.at("sequence"); + auto causal_mask = input.at("causal_mask"); + + std::vector kv_caches; + + // Append Key + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_key_name = "past_key_" + std::to_string(i); + if (input.count(past_key_name)) { + kv_caches.push_back(input.at(past_key_name)); + } else { + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Append Value + for (int i = 0; i < cfg.num_hidden_layers; ++i) { + auto past_value_name = "past_value_" + std::to_string(i); + if (input.count(past_value_name)) { + kv_caches.push_back(input.at(past_value_name)); + } else { + throw std::runtime_error("Missing KV cache for layer " + std::to_string(i)); + } + } + + // Generate position_ids for the current sequence + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + // Use existing position_ids for decode phase + position_ids = input.at("position_ids"); + + // For decode phase, increment the last position + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({1}, kInt32, kCPU).alloc(); + *position_ids.offsettedPtr({0}) = last_pos + 1; + } + } else { + // Generate position_ids for prefill phase + position_ids = Tensor::empty({seq_len}, kInt32, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[s] = s; } + } + + ir::lowlevel::traceStart(); + + // Build inputs for llm: sequence, llm_embedding_sin, 
llm_embedding_cos, causal_mask, then all KV caches + std::vector llm_inputs = {sequence, position_ids, causal_mask}; + llm_inputs.insert(llm_inputs.end(), kv_caches.begin(), kv_caches.end()); + + sequence = llm(llm_inputs)[0]; + sequence = lm_head_(ptq::QDQ(this, sequence, "lm_head_input_qdq")); + sequence = ptq::QDQ(this, sequence, "lm_head_output_qdq"); + ir::lowlevel::traceComment(" ╔═════╗ "); + ir::lowlevel::traceComment(" ║ o o ║ "); + ir::lowlevel::traceComment(" ║ ▽ ║ "); + ir::lowlevel::traceComment(" ╚═════╝ "); + ir::lowlevel::traceComment(" ║ ║ "); + ir::lowlevel::traceComment(" ╱╩╦╦╩╲ "); + llm_ir = ir::lowlevel::traceStop(); + + return {{"model", llm_ir}}; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { return {}; } + + private: + const Llama3Config& cfg; + LlamaTextSHA llm; + nn::Conv2D lm_head_; + bool tie_word_embeddings_; +}; + +// ============================================================================ +// Weight Slicing Utilities for SHA +// ============================================================================ +// +// These functions are used during the compile phase to slice the original +// MHA weights into per-head SHA weights. +// +// Note: Llama does NOT have q_norm and k_norm (RMSNorm), so we don't need +// to slice those parameters. + +/** + * @brief Prepares the parameter file by slicing MHA weights into SHA weights. + * + * This function takes the original parameter file with MHA weights and creates + * new per-head weights for the SHA model. + * + * Original weight layout for Conv2D: [out_channels, in_channels, 1, 1] + * - q_proj.weight: [num_heads * head_dim, hidden_size, 1, 1] + * - k_proj.weight: [num_kv_heads * head_dim, hidden_size, 1, 1] + * - v_proj.weight: [num_kv_heads * head_dim, hidden_size, 1, 1] + * + * For LPBQ quantization, also need to slice: + * - scale1: flattened scale for block quantization + * - scale2: flattened scale for block quantization + * + * SHA weight layout: + * - q_proj.{h}.weight: [head_dim, hidden_size, 1, 1] for h in [0, num_heads) + * - q_proj.{h}.scale1: sliced scale for head h + * - q_proj.{h}.scale2: sliced scale for head h + * - Similar for k_proj and v_proj + */ +inline void prepareParametersForSHA(const ParameterFile::ptr_t& params, const Llama3Config& cfg) { + int num_heads = cfg.num_attention_heads; + int num_kv_heads = cfg.num_key_value_heads; + int head_dim = cfg.head_dim; + int num_layers = cfg.num_hidden_layers; + + // Helper lambda to slice and push Conv2D params (weight, scale1, scale2) + // For LPBQ, scale1 and scale2 are flattened along the output channel dimension + // Scale size per head = total_scale_size / num_heads_for_this_proj + auto sliceAndPushConv2DParams = [&](const std::string& orig_name_prefix, const std::string& new_name_prefix, + int total_out_channels, int out_channels_per_head, int num_splits) { + // Process weight: HWIO format [H=1, W=1, In_channels, Out_channels] + // For q_proj: [1, 1, hidden_size, num_heads * head_dim] + // Slice on the last dimension (Out_channels) + std::string orig_weight_name = orig_name_prefix + ".weight"; + if (params->has(orig_weight_name)) { + auto orig_weight = params->pull(orig_weight_name); + + for (int h = 0; h < num_splits; ++h) { + std::string new_weight_name = new_name_prefix + "." 
+ std::to_string(h) + ".weight"; + int start_idx = h * out_channels_per_head; + int end_idx = (h + 1) * out_channels_per_head; + // HWIO format: slice on dim 3 (Out_channels) + auto sliced = orig_weight.slice({kAll, kAll, kAll, {start_idx, end_idx}}, false); + params->push(new_weight_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_weight_name)); + } + } + + // Process scale1: flattened, size = total_out_channels / block_size (or similar) + // Slice index: (total_scale_size / num_splits) * h + std::string orig_scale1_name = orig_name_prefix + ".scale1"; + if (params->has(orig_scale1_name)) { + auto orig_scale1 = params->pull(orig_scale1_name); + int total_scale_size = orig_scale1.numel(); + int scale_per_head = total_scale_size / num_splits; + + for (int h = 0; h < num_splits; ++h) { + std::string new_scale1_name = new_name_prefix + "." + std::to_string(h) + ".scale1"; + int start_idx = h * scale_per_head; + int end_idx = (h + 1) * scale_per_head; + auto sliced = orig_scale1.slice({{start_idx, end_idx}}, false); + params->push(new_scale1_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_scale1_name)); + } + } + + // Process scale2: flattened, same logic as scale1 + std::string orig_scale2_name = orig_name_prefix + ".scale2"; + if (params->has(orig_scale2_name)) { + auto orig_scale2 = params->pull(orig_scale2_name); + int total_scale_size = orig_scale2.numel(); + int scale_per_head = total_scale_size / num_splits; + + for (int h = 0; h < num_splits; ++h) { + std::string new_scale2_name = new_name_prefix + "." + std::to_string(h) + ".scale2"; + int start_idx = h * scale_per_head; + int end_idx = (h + 1) * scale_per_head; + auto sliced = orig_scale2.slice({{start_idx, end_idx}}, false); + params->push(new_scale2_name, sliced.contiguous().setMemType(kParamsNormal).setName(new_scale2_name)); + } + } + }; + + for (int layer = 0; layer < num_layers; ++layer) { + std::string layer_prefix = "model.layers." + std::to_string(layer) + ".self_attn."; + + // Process Q projection: split into num_heads parts + sliceAndPushConv2DParams(layer_prefix + "q_proj", layer_prefix + "q_proj", num_heads * head_dim, head_dim, num_heads); + + // Process K projection: split into num_kv_heads parts + sliceAndPushConv2DParams(layer_prefix + "k_proj", layer_prefix + "k_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); + + // Process V projection: split into num_kv_heads parts + sliceAndPushConv2DParams(layer_prefix + "v_proj", layer_prefix + "v_proj", num_kv_heads * head_dim, head_dim, num_kv_heads); + + // ======================================================================== + // Duplicate QDQ parameters for each head + // ======================================================================== + // The original MHA uses shared QDQ params for all heads. For SHA, we + // duplicate these to per-head versions using "_h{N}" suffix naming. + // This allows each head to have its own quantization parameters. 
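+    // For example, with three Q heads the shared entry
+    //   model.layers.{layer}.self_attn.q_proj_output_qdq.fake_quant.scale
+    // is duplicated to
+    //   model.layers.{layer}.self_attn.q_proj_output_qdq_h0.fake_quant.scale
+    //   model.layers.{layer}.self_attn.q_proj_output_qdq_h1.fake_quant.scale
+    //   model.layers.{layer}.self_attn.q_proj_output_qdq_h2.fake_quant.scale
+    // matching the per-head "..._qdq_h" + h_str names used in LlamaAttentionSHA.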
+ + auto copyQDQParams = [&](const std::string& base_name, const std::string& new_base_name, int count) { + std::string scale_name = layer_prefix + base_name + ".fake_quant.scale"; + std::string zp_name = layer_prefix + base_name + ".fake_quant.zero_point"; + + if (params->has(scale_name)) { + auto scale = params->pull(scale_name); + auto zp = params->pull(zp_name); + + for (int h = 0; h < count; ++h) { + std::string new_scale_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.scale"; + std::string new_zp_name = layer_prefix + new_base_name + std::to_string(h) + ".fake_quant.zero_point"; + // QDQ scale/zp are typically scalar or small tensors, clone to ensure contiguous + params->push(new_scale_name, scale.contiguous().setMemType(kParamsNormal).setName(new_scale_name)); + params->push(new_zp_name, zp.contiguous().setMemType(kParamsNormal).setName(new_zp_name)); + } + } + }; + + // Copy QDQ params for Q-related nodes (per Q head) + copyQDQParams("q_proj_output_qdq", "q_proj_output_qdq_h", num_heads); + copyQDQParams("q_rope_mul_0_output_qdq", "q_rope_mul_0_output_qdq_h", num_heads); + copyQDQParams("q_rope_mul_1_output_qdq", "q_rope_mul_1_output_qdq_h", num_heads); + copyQDQParams("q_rope_neg_half_qdq", "q_rope_neg_half_qdq_h", num_heads); + copyQDQParams("q_rope_add_0_output_qdq", "q_rope_add_0_output_qdq_h", num_heads); + + // Copy QDQ params for K-related nodes (per KV head) + copyQDQParams("k_proj_output_qdq", "k_proj_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_mul_0_output_qdq", "k_rope_mul_0_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_mul_1_output_qdq", "k_rope_mul_1_output_qdq_h", num_kv_heads); + copyQDQParams("k_rope_neg_half_qdq", "k_rope_neg_half_qdq_h", num_kv_heads); + copyQDQParams("k_rope_add_0_output_qdq", "k_rope_add_0_output_qdq_h", num_kv_heads); + copyQDQParams("k_cast_to_int8_qdq", "k_cast_to_int8_qdq_h", num_kv_heads); + + // Copy QDQ params for V-related nodes (per KV head) + copyQDQParams("v_cast_to_int16_qdq", "v_cast_to_int16_qdq_h", num_kv_heads); + copyQDQParams("v_cast_to_int8_qdq", "v_cast_to_int8_qdq_h", num_kv_heads); + + // Copy QDQ params for attention computation (per Q head) + copyQDQParams("qk_matmul_output_qdq", "qk_matmul_output_qdq_h", num_heads); + copyQDQParams("scaling_qdq", "scaling_qdq_h", num_heads); + copyQDQParams("mul_0_output_qdq", "mul_0_output_qdq_h", num_heads); + copyQDQParams("reduce_min_output_qdq", "reduce_min_output_qdq_h", num_heads); + copyQDQParams("neg_20_qdq", "neg_20_qdq_h", num_heads); + copyQDQParams("minus_0_output_qdq", "minus_0_output_qdq_h", num_heads); + copyQDQParams("where_attn_qdq", "where_attn_qdq_h", num_heads); + copyQDQParams("softmax_output_qdq", "softmax_output_qdq_h", num_heads); + copyQDQParams("attn_value_matmul_output_qdq", "attn_value_matmul_output_qdq_h", num_heads); + } +} + +} // namespace mllm::models::llama3::sha diff --git a/examples/llama_qnn_aot/qnn_aot_cfg_3B.json b/examples/llama_qnn_aot/qnn_aot_cfg_3B.json new file mode 100644 index 000000000..e4691eccb --- /dev/null +++ b/examples/llama_qnn_aot/qnn_aot_cfg_3B.json @@ -0,0 +1,6 @@ +{ + "target_device": "SM8750", + "soc_model": "59", + "htp_arch": "79", + "vtcm_mb": 8 +} diff --git a/examples/qwen2_qnn_aot/aot_run.cpp b/examples/qwen2_qnn_aot/aot_run.cpp index 41437a85e..14d2dadfc 100644 --- a/examples/qwen2_qnn_aot/aot_run.cpp +++ b/examples/qwen2_qnn_aot/aot_run.cpp @@ -15,6 +15,8 @@ MLLM_MAIN({ auto& tokenizer_path = Argparse::add("-t|--tokenizer").help("Tokenizer path").def("tokenizer.json"); auto& 
config_path = Argparse::add("-c|--config").help("Config path").required(true); auto& ar_len = Argparse::add("--ar_len").help("Autoregressive length (chunk size)").def(128); + auto& seq_len = Argparse::add("--seq_len").help("Input sequence length").def(800); + auto& gen_len = Argparse::add("--gen_len").help("Generate token length").def(32); Argparse::parse(argc, argv); @@ -39,7 +41,7 @@ MLLM_MAIN({ auto input_tensor = tokenizer.convertMessage({.prompt = "hello"}); - input_tensor["sequence"] = mllm::Tensor::arange(0, 800, 1, mllm::kInt64, mllm::kCPU).view({1, -1}); + input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1}); // DBG: mllm::print(input_tensor["sequence"].shape()); @@ -51,7 +53,8 @@ MLLM_MAIN({ return 1; } - runner.generate(input_tensor["sequence"], 32, [](const std::string& token) { std::cout << token << std::flush; }, true); + runner.generate( + input_tensor["sequence"], gen_len.get(), [](const std::string& token) { std::cout << token << std::flush; }, true); std::cout << "\n"; return 0; From c848097f209062db46bf9b7bc1dea630ad6f6438 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 29 Jan 2026 04:25:25 +0000 Subject: [PATCH 5/8] feat(llama_qnn_aot): Update tokenizer to TinyLlama and enhance input handling in AOT run. Modify configuration for improved performance with new target machine settings and quantization parameters. --- examples/llama_qnn_aot/aot_run.cpp | 8 +++- examples/llama_qnn_aot/qnn_aot_cfg_3B.json | 53 ++++++++++++++++++++-- 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/examples/llama_qnn_aot/aot_run.cpp b/examples/llama_qnn_aot/aot_run.cpp index b54b7ba48..c19183533 100644 --- a/examples/llama_qnn_aot/aot_run.cpp +++ b/examples/llama_qnn_aot/aot_run.cpp @@ -4,6 +4,7 @@ #include #include "mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp" #include "configuration_llama3.hpp" +#include "mllm/models/llama/tokenization_tiny_llama.hpp" #include "mllm/models/qwen3/tokenization_qwen3.hpp" using mllm::Argparse; @@ -40,9 +41,12 @@ MLLM_MAIN({ // Note: Using Qwen3 tokenizer as a placeholder. // For production use, you should implement a Llama3Tokenizer or use // the appropriate tokenizer for your model. 
- auto tokenizer = mllm::models::qwen3::Qwen3Tokenizer(tokenizer_path.get()); + auto tokenizer = mllm::models::llama::TinyLlamaTokenizer(tokenizer_path.get()); - auto input_tensor = tokenizer.convertMessage({.prompt = "hello"}); + auto input_tensor = tokenizer.convertMessage({{ + .role = "user", + .content = "hello", + }}); input_tensor["sequence"] = mllm::Tensor::arange(0, seq_len.get(), 1, mllm::kInt64, mllm::kCPU).view({1, -1}); diff --git a/examples/llama_qnn_aot/qnn_aot_cfg_3B.json b/examples/llama_qnn_aot/qnn_aot_cfg_3B.json index e4691eccb..97240f5a2 100644 --- a/examples/llama_qnn_aot/qnn_aot_cfg_3B.json +++ b/examples/llama_qnn_aot/qnn_aot_cfg_3B.json @@ -1,6 +1,51 @@ { - "target_device": "SM8750", - "soc_model": "59", - "htp_arch": "79", - "vtcm_mb": 8 + "target_machine": { + "htp_arch": "V75", + "htp_chipset": "SM8650", + "htp_try_best_performance": "HtpBurst", + "htp_security_pd_session": "HtpSignedPd", + "htp_vtcm_capability_in_mb": 8 + }, + "graph_on_qnn": [ + "model" + ], + "op_on_qnn": [ + "lm_head" + ], + "split_graph": 1, + "quant_recipe": { + "llm_recipe": true, + "layers": 28, + "builtin_llm_pass": { + "model": "llama", + "lm_head": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "linear": { + "fallback": { + "method": "LPBQ", + "sym": true, + "precision": "w4a16", + "block_size": 32 + } + }, + "kv_cache": { + "key": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + }, + "value": { + "method": "per-tensor", + "sym": true, + "precision": "w8a8" + } + } + } + } } From 30d9a052b43d792c9dbb494939b460445255ba84 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 29 Jan 2026 08:37:52 +0000 Subject: [PATCH 6/8] feat(qnn): Enhance quantization checks and tensor handling in QnnWrappersAPI and PTQPass. Update weight type limits and add masking for Qnn's packing requirements in QLinearLPBQ for improved compatibility and performance. --- mllm/backends/qnn/aot/QnnWrappersAPI.cpp | 3 +++ mllm/backends/qnn/aot/passes/PTQPass.cpp | 2 +- pymllm/backends/qualcomm/transformers/core/qlinear.py | 7 +++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp index a585c17f4..0f67bab56 100644 --- a/mllm/backends/qnn/aot/QnnWrappersAPI.cpp +++ b/mllm/backends/qnn/aot/QnnWrappersAPI.cpp @@ -182,6 +182,9 @@ void QnnAOTNodeTensor::setupComplexTensorQuantization(const ir::tensor::TensorVa std::vector scale_offsets(num_scale_offsets); MLLM_RT_ASSERT_EQ(num_scale_offsets, cfg->scale_level_1_fp.size(-1)); MLLM_RT_ASSERT_EQ(cfg->scale_level_0_int.dtype(), kUInt8); + MLLM_RT_ASSERT_EQ(cfg->scale_level_1_fp.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(cfg->scale_level_0_int.rank(), 1); + MLLM_RT_ASSERT_EQ(cfg->scale_level_1_fp.rank(), 1); for (int i = 0; i < num_scale_offsets; ++i) { scale_offsets[i].scale = cfg->scale_level_1_fp.at({i}); scale_offsets[i].offset = 0; diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp index b1d46f908..43e591fec 100644 --- a/mllm/backends/qnn/aot/passes/PTQPass.cpp +++ b/mllm/backends/qnn/aot/passes/PTQPass.cpp @@ -47,7 +47,7 @@ void solveLinearWeight(const ir::IRContext::ptr_t& ctx, const ParameterFile::ptr auto weight = pf->pull(mllm_op->getName() + ".weight"); // FIXME weight maybe error, Check qnn eats int8 or uint8. Here weight using int8 to store int4. 
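    // After convert_to_conv2d_deploy_hwio the int4 weights are already masked with
    // 0x0F on the Python side (see the qlinear.py change in this same commit), so the
    // int8 buffer holds unsigned nibbles: a signed int4 value v is stored as its
    // two's-complement nibble v & 0x0F in [0, 15] (e.g. -1 -> 15), hence the new
    // [0, 15] limit check below.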
- checkTypeLimits(weight, -8, 7); // Int4 + checkTypeLimits(weight, 0, 15); // Int4 checkTypeLimits(scale1, 0, 16); // UInt4 this_spec->scale_level_0_int = scale1; diff --git a/pymllm/backends/qualcomm/transformers/core/qlinear.py b/pymllm/backends/qualcomm/transformers/core/qlinear.py index d3bc8150d..9e90ba8a5 100644 --- a/pymllm/backends/qualcomm/transformers/core/qlinear.py +++ b/pymllm/backends/qualcomm/transformers/core/qlinear.py @@ -242,6 +242,13 @@ def convert_to_conv2d_deploy_hwio(self): .contiguous() ) + # Packing for Qnn to use + # Qnn's packing will not do x & 0x0F, so we need to do it here. + mask = torch.full( + weight_int4.size(), 0x0F, dtype=torch.int8, device=weight_int4.device + ) + weight_int4 = torch.bitwise_and(mask, weight_int4) + del self.weight self.register_buffer("weight", weight_int4) self.register_buffer("scale1", quantized_scales.flatten()) From 7afb5f0dc42c8e5b05a3b68b3db0e63f65e453cd Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 29 Jan 2026 09:47:45 +0000 Subject: [PATCH 7/8] fix(qnn): Update causal mask and constant zero parameters for improved quantization accuracy in AOT compilation across multiple models. --- examples/llama_qnn_aot/compile.cpp | 8 ++++---- examples/llama_qnn_aot/compile_sha.cpp | 8 ++++---- examples/qwen2_qnn_aot/compile.cpp | 8 ++++---- examples/qwen2_qnn_aot/compile_sha.cpp | 8 ++++---- examples/qwen3_qnn_aot/compile.cpp | 8 ++++---- examples/qwen3_qnn_aot/compile_sha.cpp | 8 ++++---- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/examples/llama_qnn_aot/compile.cpp b/examples/llama_qnn_aot/compile.cpp index e8260484b..16e819d48 100644 --- a/examples/llama_qnn_aot/compile.cpp +++ b/examples/llama_qnn_aot/compile.cpp @@ -39,10 +39,10 @@ MLLM_MAIN({ auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); // Add params for causal mask { - params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); - params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); } model.load(params); diff --git a/examples/llama_qnn_aot/compile_sha.cpp b/examples/llama_qnn_aot/compile_sha.cpp index 96acd8076..a2a111568 100644 --- a/examples/llama_qnn_aot/compile_sha.cpp +++ b/examples/llama_qnn_aot/compile_sha.cpp @@ -65,10 +65,10 @@ MLLM_MAIN({ // Add params for causal mask { - params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); - params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + 
params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); } model.load(params); diff --git a/examples/qwen2_qnn_aot/compile.cpp b/examples/qwen2_qnn_aot/compile.cpp index 17857ac48..fe491a76f 100644 --- a/examples/qwen2_qnn_aot/compile.cpp +++ b/examples/qwen2_qnn_aot/compile.cpp @@ -39,10 +39,10 @@ MLLM_MAIN({ auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); // Add params for causal mask { - params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); - params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); } model.load(params); diff --git a/examples/qwen2_qnn_aot/compile_sha.cpp b/examples/qwen2_qnn_aot/compile_sha.cpp index 2ae85f3ae..fd3748dd7 100644 --- a/examples/qwen2_qnn_aot/compile_sha.cpp +++ b/examples/qwen2_qnn_aot/compile_sha.cpp @@ -65,10 +65,10 @@ MLLM_MAIN({ // Add params for causal mask { - params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); - params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); } model.load(params); diff --git a/examples/qwen3_qnn_aot/compile.cpp b/examples/qwen3_qnn_aot/compile.cpp index cdf38583c..4b56bdb3b 100644 --- a/examples/qwen3_qnn_aot/compile.cpp +++ b/examples/qwen3_qnn_aot/compile.cpp @@ -39,10 +39,10 @@ MLLM_MAIN({ auto params = mllm::load(model_path.get(), mllm::ModelFileVersion::kV2); // Add params for causal mask { - params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); - params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); } model.load(params); diff --git a/examples/qwen3_qnn_aot/compile_sha.cpp b/examples/qwen3_qnn_aot/compile_sha.cpp index 840a5e5ef..e20072580 100644 --- a/examples/qwen3_qnn_aot/compile_sha.cpp +++ b/examples/qwen3_qnn_aot/compile_sha.cpp @@ -65,10 +65,10 @@ 
MLLM_MAIN({ // Add params for causal mask { - params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); - params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65536.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65536, mllm::kInt8)); + params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); } model.load(params); From af2e4c61f077b0487e2481d9d9d4cea8c2e19d92 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Thu, 29 Jan 2026 10:22:31 +0000 Subject: [PATCH 8/8] fix(qnn): Change zero point data type from Int8 to Int32 for causal mask and constant zero parameters in AOT compilation to enhance quantization accuracy across multiple models. --- examples/llama_qnn_aot/compile.cpp | 4 ++-- examples/llama_qnn_aot/compile_sha.cpp | 4 ++-- examples/qwen2_qnn_aot/compile.cpp | 4 ++-- examples/qwen2_qnn_aot/compile_sha.cpp | 4 ++-- examples/qwen3_qnn_aot/compile.cpp | 4 ++-- examples/qwen3_qnn_aot/compile_sha.cpp | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/llama_qnn_aot/compile.cpp b/examples/llama_qnn_aot/compile.cpp index 16e819d48..3568a2f44 100644 --- a/examples/llama_qnn_aot/compile.cpp +++ b/examples/llama_qnn_aot/compile.cpp @@ -40,9 +40,9 @@ MLLM_MAIN({ // Add params for causal mask { params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); } model.load(params); diff --git a/examples/llama_qnn_aot/compile_sha.cpp b/examples/llama_qnn_aot/compile_sha.cpp index a2a111568..bd938b7a9 100644 --- a/examples/llama_qnn_aot/compile_sha.cpp +++ b/examples/llama_qnn_aot/compile_sha.cpp @@ -66,9 +66,9 @@ MLLM_MAIN({ // Add params for causal mask { params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); } model.load(params); diff --git a/examples/qwen2_qnn_aot/compile.cpp b/examples/qwen2_qnn_aot/compile.cpp index fe491a76f..288501966 100644 --- a/examples/qwen2_qnn_aot/compile.cpp +++ b/examples/qwen2_qnn_aot/compile.cpp @@ -40,9 +40,9 @@ MLLM_MAIN({ // Add params for causal mask { params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", 
mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); } model.load(params); diff --git a/examples/qwen2_qnn_aot/compile_sha.cpp b/examples/qwen2_qnn_aot/compile_sha.cpp index fd3748dd7..50aa9b5e5 100644 --- a/examples/qwen2_qnn_aot/compile_sha.cpp +++ b/examples/qwen2_qnn_aot/compile_sha.cpp @@ -66,9 +66,9 @@ MLLM_MAIN({ // Add params for causal mask { params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); } model.load(params); diff --git a/examples/qwen3_qnn_aot/compile.cpp b/examples/qwen3_qnn_aot/compile.cpp index 4b56bdb3b..cc813fe32 100644 --- a/examples/qwen3_qnn_aot/compile.cpp +++ b/examples/qwen3_qnn_aot/compile.cpp @@ -40,9 +40,9 @@ MLLM_MAIN({ // Add params for causal mask { params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); } model.load(params); diff --git a/examples/qwen3_qnn_aot/compile_sha.cpp b/examples/qwen3_qnn_aot/compile_sha.cpp index e20072580..8e6dd2323 100644 --- a/examples/qwen3_qnn_aot/compile_sha.cpp +++ b/examples/qwen3_qnn_aot/compile_sha.cpp @@ -66,9 +66,9 @@ MLLM_MAIN({ // Add params for causal mask { params->push("causal_mask.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); - params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("causal_mask.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); params->push("constant_zero.scale", mllm::Tensor::constant(0.001 / 65535.f, mllm::kFloat32)); - params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt8)); + params->push("constant_zero.zero_point", mllm::Tensor::constant(65535, mllm::kInt32)); } model.load(params);
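For reference, a minimal standalone sketch (not part of the patch) of the arithmetic behind the last two fixes, assuming the causal-mask constant uses the usual affine de-quantization real = scale * (q - zero_point) on an unsigned 16-bit grid: with scale = 0.001 / 65535 and zero_point = 65535, q = 65535 decodes to exactly 0 (unmasked) and q = 0 decodes to roughly -0.001 (masked); the old zero point 65536 lies outside the uint16 grid, and 65535 itself does not fit in an int8 tensor, which is why the parameter is now pushed as kInt32.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  const float scale = 0.001f / 65535.f;
  const int32_t zero_point = 65535;  // largest value on the uint16 grid

  // q = 65535 decodes to the "keep" value 0.0; q = 0 decodes to the most
  // negative mask value, roughly -0.001.
  const float unmasked = scale * static_cast<float>(65535 - zero_point);
  const float masked = scale * static_cast<float>(0 - zero_point);
  std::printf("unmasked = %g, masked = %g\n", unmasked, masked);
  assert(unmasked == 0.0f);

  // 65535 overflows int8, hence the switch of the zero-point tensor to int32.
  assert(zero_point > std::numeric_limits<int8_t>::max());
  // The previously used 65536 is not representable on a uint16 grid at all.
  assert(65536 > std::numeric_limits<std::uint16_t>::max());
  return 0;
}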